trueno/matrix/ops/ml_ops/mod.rs
1//! Machine learning operations for Matrix
2//!
3//! This module provides ML-specific operations:
4//! - `convolve2d()` - 2D convolution
5//! - `embedding_lookup()` - Embedding table lookup
6//! - `embedding_lookup_sparse()` - Embedding lookup with gradient tracking
7//! - `max_pool2d()` - Max pooling
8//! - `avg_pool2d()` - Average pooling
9//! - `topk()` - Top-K selection
10//! - `gather()` - Gather elements along axis
11//! - `pad()` - Pad matrix with constant value
12
13mod convolution;
14mod pooling;
15
16use crate::TruenoError;
17
18use super::super::Matrix;
19
20impl Matrix<f32> {
21 /// Lookup embeddings by indices
22 ///
23 /// Performs embedding lookup where self is the embedding table with shape
24 /// `[vocab_size, embed_dim]` and indices specify which rows to select.
25 ///
26 /// # Arguments
27 ///
28 /// * `indices` - Slice of indices into the embedding table
29 ///
30 /// # Returns
31 ///
32 /// A matrix with shape `[indices.len(), embed_dim]` containing the selected rows
33 ///
34 /// # Errors
35 ///
36 /// Returns `InvalidInput` if any index is out of bounds
37 ///
38 /// # Example
39 ///
40 /// ```
41 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
42 /// use trueno::Matrix;
43 ///
44 /// // Create embedding table: 4 words, 3-dimensional embeddings
45 /// let embeddings = Matrix::from_vec(4, 3, vec![
46 /// 1.0, 2.0, 3.0, // word 0
47 /// 4.0, 5.0, 6.0, // word 1
48 /// 7.0, 8.0, 9.0, // word 2
49 /// 10.0, 11.0, 12.0 // word 3
50 /// ])?;
51 ///
52 /// // Lookup embeddings for indices [1, 3, 0]
53 /// let result = embeddings.embedding_lookup(&[1, 3, 0])?;
54 ///
55 /// assert_eq!(result.rows(), 3);
56 /// assert_eq!(result.cols(), 3);
57 /// assert_eq!(result.get(0, 0), Some(&4.0)); // word 1
58 /// assert_eq!(result.get(1, 0), Some(&10.0)); // word 3
59 /// assert_eq!(result.get(2, 0), Some(&1.0)); // word 0
60 /// # Ok(())
61 /// # }
62 /// ```
63 pub fn embedding_lookup(&self, indices: &[usize]) -> Result<Matrix<f32>, TruenoError> {
64 // Validate indices
65 contract_pre_embedding_lookup!(indices);
66 for (i, &idx) in indices.iter().enumerate() {
67 if idx >= self.rows {
68 return Err(TruenoError::InvalidInput(format!(
69 "Index {} at position {} is out of bounds for embedding table with {} rows",
70 idx, i, self.rows
71 )));
72 }
73 }
74
75 // Handle empty indices
76 if indices.is_empty() {
77 return Ok(Matrix::zeros_with_backend(0, self.cols, self.backend));
78 }
79
80 // Allocate output matrix: [seq_len, embed_dim]
81 // Uninit: every element gets written by copy_from_slice below.
82 let seq_len = indices.len();
83 let embed_dim = self.cols;
84 let total = seq_len * embed_dim;
85 let mut data: Vec<f32> = Vec::with_capacity(total);
86 // SAFETY: Loop below writes every element via copy_from_slice.
87 unsafe {
88 data.set_len(total);
89 }
90 let mut result = Matrix { rows: seq_len, cols: embed_dim, data, backend: self.backend };
91
92 // Copy rows from embedding table to result
93 for (out_row, &idx) in indices.iter().enumerate() {
94 let src_start = idx * embed_dim;
95 let dst_start = out_row * embed_dim;
96
97 result.data[dst_start..dst_start + embed_dim]
98 .copy_from_slice(&self.data[src_start..src_start + embed_dim]);
99 }
100
101 contract_post_embedding_lookup!(&result.data);
102 Ok(result)
103 }
104
105 /// Lookup embeddings with gradient tracking support (for training)
106 ///
107 /// Returns both the embeddings and a sparse gradient accumulator.
108 /// This is useful for sparse gradient updates in training.
109 ///
110 /// # Arguments
111 ///
112 /// * `indices` - Slice of indices into the embedding table
113 ///
114 /// # Returns
115 ///
116 /// Tuple of (embeddings, unique_indices) where unique_indices can be used
117 /// for sparse gradient updates
118 ///
119 /// # Errors
120 ///
121 /// Returns `InvalidInput` if any index is out of bounds
122 pub fn embedding_lookup_sparse(
123 &self,
124 indices: &[usize],
125 ) -> Result<(Matrix<f32>, Vec<usize>), TruenoError> {
126 let embeddings = self.embedding_lookup(indices)?;
127
128 // Get unique indices for sparse gradient updates
129 let mut unique: Vec<usize> = indices.to_vec();
130 unique.sort_unstable();
131 unique.dedup();
132
133 Ok((embeddings, unique))
134 }
135
136 /// Top-K selection: returns the k largest elements and their indices
137 ///
138 /// Useful for beam search, sampling, and ranking operations.
139 /// Searches row-major order and returns (values, indices) sorted descending.
140 ///
141 /// # Examples
142 /// ```
143 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
144 /// use trueno::matrix::Matrix;
145 /// let m = Matrix::from_vec(2, 3, vec![1.0, 5.0, 3.0, 2.0, 6.0, 4.0])?;
146 /// let (values, indices) = m.topk(2)?;
147 /// assert_eq!(values, vec![6.0, 5.0]);
148 /// assert_eq!(indices, vec![4, 1]); // flat indices
149 /// # Ok(())
150 /// # }
151 /// ```
152 pub fn topk(&self, k: usize) -> Result<(Vec<f32>, Vec<usize>), TruenoError> {
153 if k == 0 {
154 return Ok((vec![], vec![]));
155 }
156
157 let k = k.min(self.data.len());
158 let mut indexed: Vec<(usize, f32)> = self.data.iter().copied().enumerate().collect();
159
160 // Partial sort - only sort k elements
161 indexed.select_nth_unstable_by(k.saturating_sub(1), |a, b| {
162 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
163 });
164
165 indexed.truncate(k);
166 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
167
168 let values: Vec<f32> = indexed.iter().map(|(_, v)| *v).collect();
169 let indices: Vec<usize> = indexed.iter().map(|(i, _)| *i).collect();
170
171 Ok((values, indices))
172 }
173
174 /// Gather elements along axis using indices
175 ///
176 /// For 2D matrix with axis=0: output[i] = self[indices[i], :]
177 /// For 2D matrix with axis=1: output[:, i] = self[:, indices[i]]
178 ///
179 /// # Examples
180 /// ```
181 /// use trueno::matrix::Matrix;
182 /// let m = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
183 /// let gathered = m.gather(&[2, 0], 0).unwrap(); // Select rows 2 and 0
184 /// assert_eq!(gathered.shape(), (2, 2));
185 /// assert_eq!(gathered.get(0, 0), Some(&5.0)); // Row 2
186 /// assert_eq!(gathered.get(1, 0), Some(&1.0)); // Row 0
187 /// ```
188 pub fn gather(&self, indices: &[usize], axis: usize) -> Result<Matrix<f32>, TruenoError> {
189 match axis {
190 0 => self.gather_rows(indices),
191 1 => self.gather_cols(indices),
192 _ => Err(TruenoError::InvalidInput(format!(
193 "Axis {} not supported for 2D matrix (use 0 or 1)",
194 axis
195 ))),
196 }
197 }
198
199 fn gather_rows(&self, indices: &[usize]) -> Result<Matrix<f32>, TruenoError> {
200 let mut result = Matrix::new(indices.len(), self.cols);
201 for (out_i, &idx) in indices.iter().enumerate() {
202 if idx >= self.rows {
203 return Err(TruenoError::InvalidInput(format!(
204 "Index {} out of bounds for axis 0 with size {}",
205 idx, self.rows
206 )));
207 }
208 result.data[out_i * self.cols..(out_i + 1) * self.cols]
209 .copy_from_slice(&self.data[idx * self.cols..(idx + 1) * self.cols]);
210 }
211 Ok(result)
212 }
213
214 fn gather_cols(&self, indices: &[usize]) -> Result<Matrix<f32>, TruenoError> {
215 let mut result = Matrix::new(self.rows, indices.len());
216 for i in 0..self.rows {
217 for (out_j, &idx) in indices.iter().enumerate() {
218 if idx >= self.cols {
219 return Err(TruenoError::InvalidInput(format!(
220 "Index {} out of bounds for axis 1 with size {}",
221 idx, self.cols
222 )));
223 }
224 result.data[i * indices.len() + out_j] = self.data[i * self.cols + idx];
225 }
226 }
227 Ok(result)
228 }
229
230 /// Pad matrix with a constant value
231 ///
232 /// # Arguments
233 /// * `padding` - ((top, bottom), (left, right)) padding amounts
234 /// * `value` - constant value to pad with (usually 0.0)
235 ///
236 /// # Examples
237 /// ```
238 /// use trueno::matrix::Matrix;
239 /// let m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
240 /// let padded = m.pad(((1, 1), (1, 1)), 0.0).unwrap();
241 /// assert_eq!(padded.shape(), (4, 4));
242 /// assert_eq!(padded.get(0, 0), Some(&0.0)); // top-left padding
243 /// assert_eq!(padded.get(1, 1), Some(&1.0)); // original (0,0)
244 /// ```
245 pub fn pad(
246 &self,
247 padding: ((usize, usize), (usize, usize)),
248 value: f32,
249 ) -> Result<Matrix<f32>, TruenoError> {
250 let ((top, bottom), (left, right)) = padding;
251 let new_rows = self.rows + top + bottom;
252 let new_cols = self.cols + left + right;
253
254 let mut result = Matrix::from_vec(new_rows, new_cols, vec![value; new_rows * new_cols])?;
255
256 // Copy original data
257 for i in 0..self.rows {
258 for j in 0..self.cols {
259 result.data[(i + top) * new_cols + (j + left)] = self.data[i * self.cols + j];
260 }
261 }
262
263 Ok(result)
264 }
265}
266
267#[cfg(test)]
268mod tests;