Skip to main content

trueno/matrix/ops/ml_ops/
mod.rs

1//! Machine learning operations for Matrix
2//!
3//! This module provides ML-specific operations:
4//! - `convolve2d()` - 2D convolution
5//! - `embedding_lookup()` - Embedding table lookup
6//! - `embedding_lookup_sparse()` - Embedding lookup with gradient tracking
7//! - `max_pool2d()` - Max pooling
8//! - `avg_pool2d()` - Average pooling
9//! - `topk()` - Top-K selection
10//! - `gather()` - Gather elements along axis
11//! - `pad()` - Pad matrix with constant value
12
13mod convolution;
14mod pooling;
15
16use crate::TruenoError;
17
18use super::super::Matrix;
19
20impl Matrix<f32> {
21    /// Lookup embeddings by indices
22    ///
23    /// Performs embedding lookup where self is the embedding table with shape
24    /// `[vocab_size, embed_dim]` and indices specify which rows to select.
25    ///
26    /// # Arguments
27    ///
28    /// * `indices` - Slice of indices into the embedding table
29    ///
30    /// # Returns
31    ///
32    /// A matrix with shape `[indices.len(), embed_dim]` containing the selected rows
33    ///
34    /// # Errors
35    ///
36    /// Returns `InvalidInput` if any index is out of bounds
37    ///
38    /// # Example
39    ///
40    /// ```
41    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
42    /// use trueno::Matrix;
43    ///
44    /// // Create embedding table: 4 words, 3-dimensional embeddings
45    /// let embeddings = Matrix::from_vec(4, 3, vec![
46    ///     1.0, 2.0, 3.0,   // word 0
47    ///     4.0, 5.0, 6.0,   // word 1
48    ///     7.0, 8.0, 9.0,   // word 2
49    ///     10.0, 11.0, 12.0 // word 3
50    /// ])?;
51    ///
52    /// // Lookup embeddings for indices [1, 3, 0]
53    /// let result = embeddings.embedding_lookup(&[1, 3, 0])?;
54    ///
55    /// assert_eq!(result.rows(), 3);
56    /// assert_eq!(result.cols(), 3);
57    /// assert_eq!(result.get(0, 0), Some(&4.0)); // word 1
58    /// assert_eq!(result.get(1, 0), Some(&10.0)); // word 3
59    /// assert_eq!(result.get(2, 0), Some(&1.0)); // word 0
60    /// # Ok(())
61    /// # }
62    /// ```
63    pub fn embedding_lookup(&self, indices: &[usize]) -> Result<Matrix<f32>, TruenoError> {
64        // Validate indices
65        contract_pre_embedding_lookup!(indices);
66        for (i, &idx) in indices.iter().enumerate() {
67            if idx >= self.rows {
68                return Err(TruenoError::InvalidInput(format!(
69                    "Index {} at position {} is out of bounds for embedding table with {} rows",
70                    idx, i, self.rows
71                )));
72            }
73        }
74
75        // Handle empty indices
76        if indices.is_empty() {
77            return Ok(Matrix::zeros_with_backend(0, self.cols, self.backend));
78        }
79
80        // Allocate output matrix: [seq_len, embed_dim]
81        // Uninit: every element gets written by copy_from_slice below.
82        let seq_len = indices.len();
83        let embed_dim = self.cols;
84        let total = seq_len * embed_dim;
85        let mut data: Vec<f32> = Vec::with_capacity(total);
86        // SAFETY: Loop below writes every element via copy_from_slice.
87        unsafe {
88            data.set_len(total);
89        }
90        let mut result = Matrix { rows: seq_len, cols: embed_dim, data, backend: self.backend };
91
92        // Copy rows from embedding table to result
93        for (out_row, &idx) in indices.iter().enumerate() {
94            let src_start = idx * embed_dim;
95            let dst_start = out_row * embed_dim;
96
97            result.data[dst_start..dst_start + embed_dim]
98                .copy_from_slice(&self.data[src_start..src_start + embed_dim]);
99        }
100
101        contract_post_embedding_lookup!(&result.data);
102        Ok(result)
103    }
104
105    /// Lookup embeddings with gradient tracking support (for training)
106    ///
107    /// Returns both the embeddings and a sparse gradient accumulator.
108    /// This is useful for sparse gradient updates in training.
109    ///
110    /// # Arguments
111    ///
112    /// * `indices` - Slice of indices into the embedding table
113    ///
114    /// # Returns
115    ///
116    /// Tuple of (embeddings, unique_indices) where unique_indices can be used
117    /// for sparse gradient updates
118    ///
119    /// # Errors
120    ///
121    /// Returns `InvalidInput` if any index is out of bounds
122    pub fn embedding_lookup_sparse(
123        &self,
124        indices: &[usize],
125    ) -> Result<(Matrix<f32>, Vec<usize>), TruenoError> {
126        let embeddings = self.embedding_lookup(indices)?;
127
128        // Get unique indices for sparse gradient updates
129        let mut unique: Vec<usize> = indices.to_vec();
130        unique.sort_unstable();
131        unique.dedup();
132
133        Ok((embeddings, unique))
134    }
135
136    /// Top-K selection: returns the k largest elements and their indices
137    ///
138    /// Useful for beam search, sampling, and ranking operations.
139    /// Searches row-major order and returns (values, indices) sorted descending.
140    ///
141    /// # Examples
142    /// ```
143    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
144    /// use trueno::matrix::Matrix;
145    /// let m = Matrix::from_vec(2, 3, vec![1.0, 5.0, 3.0, 2.0, 6.0, 4.0])?;
146    /// let (values, indices) = m.topk(2)?;
147    /// assert_eq!(values, vec![6.0, 5.0]);
148    /// assert_eq!(indices, vec![4, 1]);  // flat indices
149    /// # Ok(())
150    /// # }
151    /// ```
152    pub fn topk(&self, k: usize) -> Result<(Vec<f32>, Vec<usize>), TruenoError> {
153        if k == 0 {
154            return Ok((vec![], vec![]));
155        }
156
157        let k = k.min(self.data.len());
158        let mut indexed: Vec<(usize, f32)> = self.data.iter().copied().enumerate().collect();
159
160        // Partial sort - only sort k elements
161        indexed.select_nth_unstable_by(k.saturating_sub(1), |a, b| {
162            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
163        });
164
165        indexed.truncate(k);
166        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
167
168        let values: Vec<f32> = indexed.iter().map(|(_, v)| *v).collect();
169        let indices: Vec<usize> = indexed.iter().map(|(i, _)| *i).collect();
170
171        Ok((values, indices))
172    }
173
174    /// Gather elements along axis using indices
175    ///
176    /// For 2D matrix with axis=0: output[i] = self[indices[i], :]
177    /// For 2D matrix with axis=1: output[:, i] = self[:, indices[i]]
178    ///
179    /// # Examples
180    /// ```
181    /// use trueno::matrix::Matrix;
182    /// let m = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
183    /// let gathered = m.gather(&[2, 0], 0).unwrap();  // Select rows 2 and 0
184    /// assert_eq!(gathered.shape(), (2, 2));
185    /// assert_eq!(gathered.get(0, 0), Some(&5.0));  // Row 2
186    /// assert_eq!(gathered.get(1, 0), Some(&1.0));  // Row 0
187    /// ```
188    pub fn gather(&self, indices: &[usize], axis: usize) -> Result<Matrix<f32>, TruenoError> {
189        match axis {
190            0 => self.gather_rows(indices),
191            1 => self.gather_cols(indices),
192            _ => Err(TruenoError::InvalidInput(format!(
193                "Axis {} not supported for 2D matrix (use 0 or 1)",
194                axis
195            ))),
196        }
197    }
198
199    fn gather_rows(&self, indices: &[usize]) -> Result<Matrix<f32>, TruenoError> {
200        let mut result = Matrix::new(indices.len(), self.cols);
201        for (out_i, &idx) in indices.iter().enumerate() {
202            if idx >= self.rows {
203                return Err(TruenoError::InvalidInput(format!(
204                    "Index {} out of bounds for axis 0 with size {}",
205                    idx, self.rows
206                )));
207            }
208            result.data[out_i * self.cols..(out_i + 1) * self.cols]
209                .copy_from_slice(&self.data[idx * self.cols..(idx + 1) * self.cols]);
210        }
211        Ok(result)
212    }
213
214    fn gather_cols(&self, indices: &[usize]) -> Result<Matrix<f32>, TruenoError> {
215        let mut result = Matrix::new(self.rows, indices.len());
216        for i in 0..self.rows {
217            for (out_j, &idx) in indices.iter().enumerate() {
218                if idx >= self.cols {
219                    return Err(TruenoError::InvalidInput(format!(
220                        "Index {} out of bounds for axis 1 with size {}",
221                        idx, self.cols
222                    )));
223                }
224                result.data[i * indices.len() + out_j] = self.data[i * self.cols + idx];
225            }
226        }
227        Ok(result)
228    }
229
230    /// Pad matrix with a constant value
231    ///
232    /// # Arguments
233    /// * `padding` - ((top, bottom), (left, right)) padding amounts
234    /// * `value` - constant value to pad with (usually 0.0)
235    ///
236    /// # Examples
237    /// ```
238    /// use trueno::matrix::Matrix;
239    /// let m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
240    /// let padded = m.pad(((1, 1), (1, 1)), 0.0).unwrap();
241    /// assert_eq!(padded.shape(), (4, 4));
242    /// assert_eq!(padded.get(0, 0), Some(&0.0));  // top-left padding
243    /// assert_eq!(padded.get(1, 1), Some(&1.0));  // original (0,0)
244    /// ```
245    pub fn pad(
246        &self,
247        padding: ((usize, usize), (usize, usize)),
248        value: f32,
249    ) -> Result<Matrix<f32>, TruenoError> {
250        let ((top, bottom), (left, right)) = padding;
251        let new_rows = self.rows + top + bottom;
252        let new_cols = self.cols + left + right;
253
254        let mut result = Matrix::from_vec(new_rows, new_cols, vec![value; new_rows * new_cols])?;
255
256        // Copy original data
257        for i in 0..self.rows {
258            for j in 0..self.cols {
259                result.data[(i + top) * new_cols + (j + left)] = self.data[i * self.cols + j];
260            }
261        }
262
263        Ok(result)
264    }
265}
266
267#[cfg(test)]
268mod tests;