// trueno/matrix/ops/linear/mod.rs
1//! Linear algebra operations for Matrix
2//!
3//! This module provides linear operations:
4//! - `transpose()` - Matrix transpose
5//! - `matvec()` - Matrix-vector multiplication
6//! - `vecmat()` - Vector-matrix multiplication
7
8use crate::{Backend, TruenoError, Vector};
9
10#[cfg(feature = "tracing")]
11use tracing::instrument;
12
/// Backend dispatch macro for dot product - centralizes platform-specific SIMD dispatch
///
/// Expands to an `unsafe` match over `Backend` that routes `dot($a, $b)` to the
/// best available SIMD kernel for the current target architecture:
/// - x86_64: `SSE2`/`AVX` → SSE2 kernel; `AVX2`/`AVX512` → AVX2 kernel
///   (no dedicated AVX/AVX-512 kernels exist, so they map down one tier)
/// - aarch64/arm: `NEON` → NEON kernel
/// - wasm32: `WasmSIMD` → WASM SIMD kernel
/// - Any variant that is impossible on the compile target, plus `GPU` and
///   `Auto`, falls back to the scalar kernel.
macro_rules! dispatch_dot {
    ($backend:expr, $a:expr, $b:expr) => {{
        #[cfg(target_arch = "x86_64")]
        use crate::backends::{avx2::Avx2Backend, sse2::Sse2Backend};
        use crate::backends::{scalar::ScalarBackend, VectorBackend};
        // SAFETY: CPU features verified at runtime before backend selection
        unsafe {
            match $backend {
                Backend::Scalar => ScalarBackend::dot($a, $b),
                // AVX has no dedicated kernel; the SSE2 kernel is the widest safe subset.
                #[cfg(target_arch = "x86_64")]
                Backend::SSE2 | Backend::AVX => Sse2Backend::dot($a, $b),
                // AVX-512 likewise reuses the AVX2 kernel.
                #[cfg(target_arch = "x86_64")]
                Backend::AVX2 | Backend::AVX512 => Avx2Backend::dot($a, $b),
                // x86-only backends selected on a non-x86 target: scalar fallback.
                #[cfg(not(target_arch = "x86_64"))]
                Backend::SSE2 | Backend::AVX | Backend::AVX2 | Backend::AVX512 => {
                    ScalarBackend::dot($a, $b)
                }
                #[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
                Backend::NEON => {
                    use crate::backends::neon::NeonBackend;
                    NeonBackend::dot($a, $b)
                }
                #[cfg(not(any(target_arch = "aarch64", target_arch = "arm")))]
                Backend::NEON => ScalarBackend::dot($a, $b),
                #[cfg(target_arch = "wasm32")]
                Backend::WasmSIMD => {
                    use crate::backends::wasm::WasmBackend;
                    WasmBackend::dot($a, $b)
                }
                #[cfg(not(target_arch = "wasm32"))]
                Backend::WasmSIMD => ScalarBackend::dot($a, $b),
                // GPU dispatch is not wired up here; Auto should have been
                // resolved before this point — both degrade to scalar.
                Backend::GPU | Backend::Auto => ScalarBackend::dot($a, $b),
            }
        }
    }};
}
50
51use super::super::Matrix;
52
53impl Matrix<f32> {
54    /// Transpose this matrix (swap rows and columns)
55    ///
56    /// Returns a new matrix with dimensions swapped: `self.rows → result.cols`,
57    /// `self.cols → result.rows`.
58    ///
59    /// # Performance
60    ///
61    /// Uses cache-optimized block-wise transpose with 32x32 blocks.
62    /// Sequential writes for output ensure good cache behavior.
63    ///
64    /// # Example
65    ///
66    /// ```
67    /// use trueno::Matrix;
68    ///
69    /// let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
70    /// let t = m.transpose();
71    ///
72    /// // [[1, 2, 3],     [[1, 4],
73    /// //  [4, 5, 6]]  →   [2, 5],
74    /// //                  [3, 6]]
75    /// assert_eq!(t.rows(), 3);
76    /// assert_eq!(t.cols(), 2);
77    /// assert_eq!(t.get(0, 0), Some(&1.0));
78    /// assert_eq!(t.get(0, 1), Some(&4.0));
79    /// assert_eq!(t.get(1, 0), Some(&2.0));
80    /// ```
81    // KAIZEN-040: Delegate to crate::blis::transpose which has AVX2 8×8
82    // in-register micro-kernel with 64×64 L1-resident tiling and prefetch.
83    // Previous implementation used scalar 32×32 blocks.
84    #[cfg_attr(feature = "tracing", instrument(skip(self), fields(dims = %format!("{}x{}", self.rows, self.cols))))]
85    pub fn transpose(&self) -> Matrix<f32> {
86        // Uninit allocation: transpose writes every element (plus remainder edges).
87        // Skipping the zero-fill saves ~300µs at 2048×2048 (16MB).
88        let n = self.cols * self.rows;
89        let mut data: Vec<f32> = Vec::with_capacity(n);
90        // SAFETY: transpose() writes every element of result.data:
91        //   - 8×8 AVX2 tiles cover rows/8 × cols/8 blocks
92        //   - Scalar remainder writes cover the edge rows/cols
93        unsafe {
94            data.set_len(n);
95        }
96        let mut result = Matrix { rows: self.cols, cols: self.rows, data, backend: self.backend };
97
98        // BLIS transpose handles AVX2 dispatch, remainder edges, and shape-adaptive
99        // loop ordering internally. Dimensions are correct by construction so
100        // the only possible error (size mismatch) cannot occur.
101        if let Err(e) =
102            crate::blis::transpose::transpose(self.rows, self.cols, &self.data, &mut result.data)
103        {
104            // Unreachable: result is allocated as cols×rows which matches rows×cols elements.
105            // If somehow triggered, fall back to scalar element-wise transpose.
106            debug_assert!(false, "BLIS transpose dimension mismatch: {e}");
107            for i in 0..self.rows {
108                for j in 0..self.cols {
109                    result.data[j * self.rows + i] = self.data[i * self.cols + j];
110                }
111            }
112        }
113
114        result
115    }
116
117    /// Matrix-vector multiplication (column vector): A × v
118    ///
119    /// Multiplies this matrix by a column vector, computing `A × v` where the result
120    /// is a column vector with length equal to the number of rows in `A`.
121    ///
122    /// # Mathematical Definition
123    ///
124    /// For an m×n matrix A and an n-dimensional vector v:
125    /// ```text
126    /// result[i] = Σ(j=0 to n-1) A[i,j] × v[j]
127    /// ```
128    ///
129    /// # Arguments
130    ///
131    /// * `v` - Column vector with length equal to `self.cols()`
132    ///
133    /// # Returns
134    ///
135    /// A new vector with length `self.rows()`
136    ///
137    /// # Errors
138    ///
139    /// Returns `InvalidInput` if `v.len() != self.cols()`
140    ///
141    /// # Example
142    ///
143    /// ```
144    /// use trueno::{Matrix, Vector};
145    ///
146    /// let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
147    /// let v = Vector::from_slice(&[1.0, 2.0, 3.0]);
148    /// let result = m.matvec(&v).unwrap();
149    ///
150    /// // [[1, 2, 3]   [1]   [1×1 + 2×2 + 3×3]   [14]
151    /// //  [4, 5, 6]] × [2] = [4×1 + 5×2 + 6×3] = [32]
152    /// //               [3]
153    /// assert_eq!(result.as_slice(), &[14.0, 32.0]);
154    /// ```
155    pub fn matvec(&self, v: &Vector<f32>) -> Result<Vector<f32>, TruenoError> {
156        if v.len() != self.cols {
157            return Err(TruenoError::InvalidInput(format!(
158                "Vector length {} does not match matrix columns {} for matrix-vector multiplication",
159                v.len(),
160                self.cols
161            )));
162        }
163
164        let v_slice = v.as_slice();
165
166        // Uninit allocation: every element is SET (not accumulated) by
167        // `*result = dispatch_dot!(...)` or parallel `*out = dispatch_dot!(...)`.
168        let n = self.rows;
169        let mut result_data: Vec<f32> = Vec::with_capacity(n);
170        // SAFETY: Both serial and parallel paths write every element via
171        // `*out = dispatch_dot!(...)` (SET, not accumulate). No reads before writes.
172        unsafe {
173            result_data.set_len(n);
174        }
175
176        // Parallel execution for large matrices (≥2048 rows)
177        // CGP-DBUF: lowered from 4096 to 2048. Previous regression at 2048 was
178        // from thread::scope (~40µs). Rayon par_chunks_mut is ~3µs overhead.
179        // 2048×2048 matvec: ~180µs compute → 3µs is 1.7% acceptable.
180        #[cfg(feature = "parallel")]
181        {
182            const PARALLEL_THRESHOLD: usize = 2048;
183
184            if self.rows >= PARALLEL_THRESHOLD {
185                use rayon::prelude::*;
186
187                // Chunk rows into slices per thread (amortizes task overhead).
188                // Previous per-row parallelism spawned rows-many tasks; chunked
189                // spawns num_threads tasks, each processing rows/num_threads rows.
190                let num_threads = rayon::current_num_threads().min(8);
191                let rows_per = (self.rows + num_threads - 1) / num_threads;
192                let cols = self.cols;
193                let data = &self.data;
194
195                result_data.par_chunks_mut(rows_per).enumerate().for_each(|(tid, out_chunk)| {
196                    let row_start = tid * rows_per;
197                    for (i, out) in out_chunk.iter_mut().enumerate() {
198                        let r = row_start + i;
199                        let row = &data[r * cols..(r + 1) * cols];
200                        *out = dispatch_dot!(self.backend, row, v_slice);
201                    }
202                });
203
204                // Move result_data — avoids redundant from_slice copy.
205                return Ok(Vector::from_vec(result_data));
206            }
207        }
208
209        // SIMD-optimized execution: each row-vector product is a dot product
210        for (i, result) in result_data.iter_mut().enumerate() {
211            let row_start = i * self.cols;
212            let row = &self.data[row_start..(row_start + self.cols)];
213
214            // Use SIMD dot product for each row
215            *result = dispatch_dot!(self.backend, row, v_slice);
216        }
217
218        // Move result_data — avoids redundant from_slice copy.
219        Ok(Vector::from_vec(result_data))
220    }
221
222    /// Vector-matrix multiplication (row vector): v^T × A
223    ///
224    /// Multiplies a row vector by this matrix, computing `v^T × A` where the result
225    /// is a row vector with length equal to the number of columns in `A`.
226    ///
227    /// # Mathematical Definition
228    ///
229    /// For an m-dimensional vector v and an m×n matrix A:
230    /// ```text
231    /// result[j] = Σ(i=0 to m-1) v[i] × A[i,j]
232    /// ```
233    ///
234    /// # Arguments
235    ///
236    /// * `v` - Row vector with length equal to `m.rows()`
237    /// * `m` - Matrix to multiply
238    ///
239    /// # Returns
240    ///
241    /// A new vector with length `m.cols()`
242    ///
243    /// # Errors
244    ///
245    /// Returns `InvalidInput` if `v.len() != m.rows()`
246    ///
247    /// # Example
248    ///
249    /// ```
250    /// use trueno::{Matrix, Vector};
251    ///
252    /// let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
253    /// let v = Vector::from_slice(&[1.0, 2.0]);
254    /// let result = Matrix::vecmat(&v, &m).unwrap();
255    ///
256    /// // [1, 2] × [[1, 2, 3]  = [1×1 + 2×4, 1×2 + 2×5, 1×3 + 2×6]
257    /// //           [4, 5, 6]]
258    /// //         = [9, 12, 15]
259    /// assert_eq!(result.as_slice(), &[9.0, 12.0, 15.0]);
260    /// ```
261    // KAIZEN-041: Uses crate::blis::gemv with AVX2 VFMADD,
262    // 4-way K-unrolling and N-tiled accumulators.
263    pub fn vecmat(v: &Vector<f32>, m: &Matrix<f32>) -> Result<Vector<f32>, TruenoError> {
264        if v.len() != m.rows {
265            return Err(TruenoError::InvalidInput(format!(
266                "Vector length {} does not match matrix rows {} for vector-matrix multiplication",
267                v.len(),
268                m.rows
269            )));
270        }
271
272        let mut result_data = vec![0.0f32; m.cols];
273
274        // Parallelize along K dimension for large matrices (DRAM-bound → multi-channel).
275        // Threshold: K * N >= 4M (e.g., 2048×2048). Below this, thread overhead dominates.
276        #[cfg(feature = "parallel")]
277        {
278            const PARALLEL_THRESHOLD: usize = 4_000_000;
279            if m.rows * m.cols >= PARALLEL_THRESHOLD {
280                use rayon::prelude::*;
281                let n = m.cols;
282                let k = m.rows;
283                let num_threads = rayon::current_num_threads().min(8); // cap at 8 for DRAM BW
284                let k_per = (k + num_threads - 1) / num_threads;
285
286                // Each thread computes partial c for its slice of K rows
287                let partials: Vec<Vec<f32>> = (0..num_threads)
288                    .into_par_iter()
289                    .map(|t| {
290                        let k_start = t * k_per;
291                        let k_end = (k_start + k_per).min(k);
292                        if k_start >= k_end {
293                            return vec![0.0f32; n];
294                        }
295                        let mut local = vec![0.0f32; n];
296                        let v_slice = &v.as_slice()[k_start..k_end];
297                        let b_slice = &m.data[k_start * n..k_end * n];
298                        crate::blis::gemv::gemv(k_end - k_start, n, v_slice, b_slice, &mut local);
299                        local
300                    })
301                    .collect();
302
303                // Reduce partials
304                for p in &partials {
305                    for (i, &v) in p.iter().enumerate() {
306                        result_data[i] += v;
307                    }
308                }
309                return Ok(Vector::from_vec(result_data));
310            }
311        }
312
313        crate::blis::gemv::gemv(m.rows, m.cols, v.as_slice(), &m.data, &mut result_data);
314        Ok(Vector::from_vec(result_data))
315    }
316}
317
318#[cfg(test)]
319mod tests;