// trueno/matrix/ops/linear/mod.rs
1//! Linear algebra operations for Matrix
2//!
3//! This module provides linear operations:
4//! - `transpose()` - Matrix transpose
5//! - `matvec()` - Matrix-vector multiplication
6//! - `vecmat()` - Vector-matrix multiplication
7
8use crate::{Backend, TruenoError, Vector};
9
10#[cfg(feature = "tracing")]
11use tracing::instrument;
12
13/// Backend dispatch macro for dot product - centralizes platform-specific SIMD dispatch
/// Backend dispatch macro for dot product - centralizes platform-specific SIMD dispatch
///
/// Expands to a `match` on a `Backend` value that routes a dot product to the
/// best kernel compiled for the current target architecture. Wider ISA tiers
/// fall back to the nearest implemented kernel (AVX → SSE2 path, AVX512 → AVX2
/// path); on non-matching architectures every SIMD variant degrades to the
/// scalar backend so the match stays exhaustive on all targets.
macro_rules! dispatch_dot {
    ($backend:expr, $a:expr, $b:expr) => {{
        // SIMD backends are only compiled in on x86_64; scalar is always available.
        #[cfg(target_arch = "x86_64")]
        use crate::backends::{avx2::Avx2Backend, sse2::Sse2Backend};
        use crate::backends::{scalar::ScalarBackend, VectorBackend};
        // SAFETY: CPU features verified at runtime before backend selection
        unsafe {
            match $backend {
                Backend::Scalar => ScalarBackend::dot($a, $b),
                // AVX has no dedicated kernel; it shares the 128-bit SSE2 path.
                #[cfg(target_arch = "x86_64")]
                Backend::SSE2 | Backend::AVX => Sse2Backend::dot($a, $b),
                // AVX512 has no dedicated kernel; it shares the 256-bit AVX2 path.
                #[cfg(target_arch = "x86_64")]
                Backend::AVX2 | Backend::AVX512 => Avx2Backend::dot($a, $b),
                // Non-x86_64: all x86 tiers degrade to scalar.
                #[cfg(not(target_arch = "x86_64"))]
                Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
                    ScalarBackend::dot($a, $b)
                }
                #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
                Backend::NEON => {
                    use crate::backends::neon::NeonBackend;
                    NeonBackend::dot($a, $b)
                }
                #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
                Backend::NEON => ScalarBackend::dot($a, $b),
                #[cfg(target_arch = "wasm32")]
                Backend::WasmSIMD => {
                    use crate::backends::wasm::WasmBackend;
                    WasmBackend::dot($a, $b)
                }
                #[cfg(not(target_arch = "wasm32"))]
                Backend::WasmSIMD => ScalarBackend::dot($a, $b),
                // NOTE(review): GPU and Auto fall through to scalar here —
                // presumably Auto is resolved to a concrete backend before
                // reaching this dispatch; confirm with backend selection code.
                Backend::GPU | Backend::Auto => ScalarBackend::dot($a, $b),
            }
        }
    }};
}
50
51use super::super::Matrix;
52
53impl Matrix<f32> {
54 /// Transpose this matrix (swap rows and columns)
55 ///
56 /// Returns a new matrix with dimensions swapped: `self.rows → result.cols`,
57 /// `self.cols → result.rows`.
58 ///
59 /// # Performance
60 ///
61 /// Uses cache-optimized block-wise transpose with 32x32 blocks.
62 /// Sequential writes for output ensure good cache behavior.
63 ///
64 /// # Example
65 ///
66 /// ```
67 /// use trueno::Matrix;
68 ///
69 /// let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
70 /// let t = m.transpose();
71 ///
72 /// // [[1, 2, 3], [[1, 4],
73 /// // [4, 5, 6]] → [2, 5],
74 /// // [3, 6]]
75 /// assert_eq!(t.rows(), 3);
76 /// assert_eq!(t.cols(), 2);
77 /// assert_eq!(t.get(0, 0), Some(&1.0));
78 /// assert_eq!(t.get(0, 1), Some(&4.0));
79 /// assert_eq!(t.get(1, 0), Some(&2.0));
80 /// ```
81 // KAIZEN-040: Delegate to crate::blis::transpose which has AVX2 8×8
82 // in-register micro-kernel with 64×64 L1-resident tiling and prefetch.
83 // Previous implementation used scalar 32×32 blocks.
84 #[cfg_attr(feature = "tracing", instrument(skip(self), fields(dims = %format!("{}x{}", self.rows, self.cols))))]
85 pub fn transpose(&self) -> Matrix<f32> {
86 // Uninit allocation: transpose writes every element (plus remainder edges).
87 // Skipping the zero-fill saves ~300µs at 2048×2048 (16MB).
88 let n = self.cols * self.rows;
89 let mut data: Vec<f32> = Vec::with_capacity(n);
90 // SAFETY: transpose() writes every element of result.data:
91 // - 8×8 AVX2 tiles cover rows/8 × cols/8 blocks
92 // - Scalar remainder writes cover the edge rows/cols
93 unsafe {
94 data.set_len(n);
95 }
96 let mut result = Matrix { rows: self.cols, cols: self.rows, data, backend: self.backend };
97
98 // BLIS transpose handles AVX2 dispatch, remainder edges, and shape-adaptive
99 // loop ordering internally. Dimensions are correct by construction so
100 // the only possible error (size mismatch) cannot occur.
101 if let Err(e) =
102 crate::blis::transpose::transpose(self.rows, self.cols, &self.data, &mut result.data)
103 {
104 // Unreachable: result is allocated as cols×rows which matches rows×cols elements.
105 // If somehow triggered, fall back to scalar element-wise transpose.
106 debug_assert!(false, "BLIS transpose dimension mismatch: {e}");
107 for i in 0..self.rows {
108 for j in 0..self.cols {
109 result.data[j * self.rows + i] = self.data[i * self.cols + j];
110 }
111 }
112 }
113
114 result
115 }
116
117 /// Matrix-vector multiplication (column vector): A × v
118 ///
119 /// Multiplies this matrix by a column vector, computing `A × v` where the result
120 /// is a column vector with length equal to the number of rows in `A`.
121 ///
122 /// # Mathematical Definition
123 ///
124 /// For an m×n matrix A and an n-dimensional vector v:
125 /// ```text
126 /// result[i] = Σ(j=0 to n-1) A[i,j] × v[j]
127 /// ```
128 ///
129 /// # Arguments
130 ///
131 /// * `v` - Column vector with length equal to `self.cols()`
132 ///
133 /// # Returns
134 ///
135 /// A new vector with length `self.rows()`
136 ///
137 /// # Errors
138 ///
139 /// Returns `InvalidInput` if `v.len() != self.cols()`
140 ///
141 /// # Example
142 ///
143 /// ```
144 /// use trueno::{Matrix, Vector};
145 ///
146 /// let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
147 /// let v = Vector::from_slice(&[1.0, 2.0, 3.0]);
148 /// let result = m.matvec(&v).unwrap();
149 ///
150 /// // [[1, 2, 3] [1] [1×1 + 2×2 + 3×3] [14]
151 /// // [4, 5, 6]] × [2] = [4×1 + 5×2 + 6×3] = [32]
152 /// // [3]
153 /// assert_eq!(result.as_slice(), &[14.0, 32.0]);
154 /// ```
155 pub fn matvec(&self, v: &Vector<f32>) -> Result<Vector<f32>, TruenoError> {
156 if v.len() != self.cols {
157 return Err(TruenoError::InvalidInput(format!(
158 "Vector length {} does not match matrix columns {} for matrix-vector multiplication",
159 v.len(),
160 self.cols
161 )));
162 }
163
164 let v_slice = v.as_slice();
165
166 // Uninit allocation: every element is SET (not accumulated) by
167 // `*result = dispatch_dot!(...)` or parallel `*out = dispatch_dot!(...)`.
168 let n = self.rows;
169 let mut result_data: Vec<f32> = Vec::with_capacity(n);
170 // SAFETY: Both serial and parallel paths write every element via
171 // `*out = dispatch_dot!(...)` (SET, not accumulate). No reads before writes.
172 unsafe {
173 result_data.set_len(n);
174 }
175
176 // Parallel execution for large matrices (≥2048 rows)
177 // CGP-DBUF: lowered from 4096 to 2048. Previous regression at 2048 was
178 // from thread::scope (~40µs). Rayon par_chunks_mut is ~3µs overhead.
179 // 2048×2048 matvec: ~180µs compute → 3µs is 1.7% acceptable.
180 #[cfg(feature = "parallel")]
181 {
182 const PARALLEL_THRESHOLD: usize = 2048;
183
184 if self.rows >= PARALLEL_THRESHOLD {
185 use rayon::prelude::*;
186
187 // Chunk rows into slices per thread (amortizes task overhead).
188 // Previous per-row parallelism spawned rows-many tasks; chunked
189 // spawns num_threads tasks, each processing rows/num_threads rows.
190 let num_threads = rayon::current_num_threads().min(8);
191 let rows_per = (self.rows + num_threads - 1) / num_threads;
192 let cols = self.cols;
193 let data = &self.data;
194
195 result_data.par_chunks_mut(rows_per).enumerate().for_each(|(tid, out_chunk)| {
196 let row_start = tid * rows_per;
197 for (i, out) in out_chunk.iter_mut().enumerate() {
198 let r = row_start + i;
199 let row = &data[r * cols..(r + 1) * cols];
200 *out = dispatch_dot!(self.backend, row, v_slice);
201 }
202 });
203
204 // Move result_data — avoids redundant from_slice copy.
205 return Ok(Vector::from_vec(result_data));
206 }
207 }
208
209 // SIMD-optimized execution: each row-vector product is a dot product
210 for (i, result) in result_data.iter_mut().enumerate() {
211 let row_start = i * self.cols;
212 let row = &self.data[row_start..(row_start + self.cols)];
213
214 // Use SIMD dot product for each row
215 *result = dispatch_dot!(self.backend, row, v_slice);
216 }
217
218 // Move result_data — avoids redundant from_slice copy.
219 Ok(Vector::from_vec(result_data))
220 }
221
222 /// Vector-matrix multiplication (row vector): v^T × A
223 ///
224 /// Multiplies a row vector by this matrix, computing `v^T × A` where the result
225 /// is a row vector with length equal to the number of columns in `A`.
226 ///
227 /// # Mathematical Definition
228 ///
229 /// For an m-dimensional vector v and an m×n matrix A:
230 /// ```text
231 /// result[j] = Σ(i=0 to m-1) v[i] × A[i,j]
232 /// ```
233 ///
234 /// # Arguments
235 ///
236 /// * `v` - Row vector with length equal to `m.rows()`
237 /// * `m` - Matrix to multiply
238 ///
239 /// # Returns
240 ///
241 /// A new vector with length `m.cols()`
242 ///
243 /// # Errors
244 ///
245 /// Returns `InvalidInput` if `v.len() != m.rows()`
246 ///
247 /// # Example
248 ///
249 /// ```
250 /// use trueno::{Matrix, Vector};
251 ///
252 /// let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
253 /// let v = Vector::from_slice(&[1.0, 2.0]);
254 /// let result = Matrix::vecmat(&v, &m).unwrap();
255 ///
256 /// // [1, 2] × [[1, 2, 3] = [1×1 + 2×4, 1×2 + 2×5, 1×3 + 2×6]
257 /// // [4, 5, 6]]
258 /// // = [9, 12, 15]
259 /// assert_eq!(result.as_slice(), &[9.0, 12.0, 15.0]);
260 /// ```
261 // KAIZEN-041: Uses crate::blis::gemv with AVX2 VFMADD,
262 // 4-way K-unrolling and N-tiled accumulators.
263 pub fn vecmat(v: &Vector<f32>, m: &Matrix<f32>) -> Result<Vector<f32>, TruenoError> {
264 if v.len() != m.rows {
265 return Err(TruenoError::InvalidInput(format!(
266 "Vector length {} does not match matrix rows {} for vector-matrix multiplication",
267 v.len(),
268 m.rows
269 )));
270 }
271
272 let mut result_data = vec![0.0f32; m.cols];
273
274 // Parallelize along K dimension for large matrices (DRAM-bound → multi-channel).
275 // Threshold: K * N >= 4M (e.g., 2048×2048). Below this, thread overhead dominates.
276 #[cfg(feature = "parallel")]
277 {
278 const PARALLEL_THRESHOLD: usize = 4_000_000;
279 if m.rows * m.cols >= PARALLEL_THRESHOLD {
280 use rayon::prelude::*;
281 let n = m.cols;
282 let k = m.rows;
283 let num_threads = rayon::current_num_threads().min(8); // cap at 8 for DRAM BW
284 let k_per = (k + num_threads - 1) / num_threads;
285
286 // Each thread computes partial c for its slice of K rows
287 let partials: Vec<Vec<f32>> = (0..num_threads)
288 .into_par_iter()
289 .map(|t| {
290 let k_start = t * k_per;
291 let k_end = (k_start + k_per).min(k);
292 if k_start >= k_end {
293 return vec![0.0f32; n];
294 }
295 let mut local = vec![0.0f32; n];
296 let v_slice = &v.as_slice()[k_start..k_end];
297 let b_slice = &m.data[k_start * n..k_end * n];
298 crate::blis::gemv::gemv(k_end - k_start, n, v_slice, b_slice, &mut local);
299 local
300 })
301 .collect();
302
303 // Reduce partials
304 for p in &partials {
305 for (i, &v) in p.iter().enumerate() {
306 result_data[i] += v;
307 }
308 }
309 return Ok(Vector::from_vec(result_data));
310 }
311 }
312
313 crate::blis::gemv::gemv(m.rows, m.cols, v.as_slice(), &m.data, &mut result_data);
314 Ok(Vector::from_vec(result_data))
315 }
316}
317
318#[cfg(test)]
319mod tests;