Skip to main content

math_linear/
lib.rs

1#![doc = include_str!("../README.md")]
2
3mod backend;
4pub mod surface;
5
6use tensor_data::{F32Tensor, F32TensorView, TensorShape};
7use vector_analysis_core::{cosine_similarity, dot, DenseVector};
8use video_analysis_core::{DetectError, Result};
9
10fn invalid_argument(message: impl Into<String>) -> DetectError {
11    DetectError::InvalidArgument(message.into())
12}
13
/// Checked dense matrix dimensions.
///
/// The fields are public, so a value built with a struct literal may hold zero
/// dimensions; `MatrixShape::new` / `MatrixShape::validate` enforce the invariants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MatrixShape {
    /// Number of logical rows.
    pub rows: usize,
    /// Number of logical columns.
    pub cols: usize,
}
22
23impl MatrixShape {
24    /// Creates a shape with non-zero rows and columns.
25    pub fn new(rows: usize, cols: usize) -> Result<Self> {
26        let shape = Self { rows, cols };
27        shape.validate()?;
28        Ok(shape)
29    }
30
31    /// Verifies non-zero dimensions and element-count overflow safety.
32    pub fn validate(self) -> Result<()> {
33        if self.rows == 0 || self.cols == 0 {
34            return Err(invalid_argument(
35                "matrix rows and cols must be greater than zero",
36            ));
37        }
38        let _ = self.element_count()?;
39        Ok(())
40    }
41
42    /// Multiplies rows by columns and fails on `usize` overflow.
43    pub fn element_count(self) -> Result<usize> {
44        self.rows
45            .checked_mul(self.cols)
46            .ok_or_else(|| invalid_argument("matrix element count overflowed usize"))
47    }
48}
49
/// How a flat value slice maps onto logical `(row, col)` positions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatrixLayout {
    /// Row after row: logical `(r, c)` is stored at `r * cols + c`.
    RowMajor,
    /// Column after column: logical `(r, c)` is stored at `c * rows + r`; transpose
    /// views use this so no values are copied.
    ColumnMajor,
}
58
59#[derive(Debug, Clone, PartialEq)]
60/// Owned finite `f32` matrix stored in row-major order.
61pub struct F32Matrix {
62    shape: MatrixShape,
63    layout: MatrixLayout,
64    values: Vec<f32>,
65}
66
67impl F32Matrix {
68    /// Creates a row-major matrix after shape and finite-value validation.
69    pub fn new(shape: MatrixShape, values: Vec<f32>) -> Result<Self> {
70        let matrix = Self {
71            shape,
72            layout: MatrixLayout::RowMajor,
73            values,
74        };
75        matrix.validate()?;
76        Ok(matrix)
77    }
78
79    /// Creates a row-major matrix filled with zeros.
80    pub fn zeros(rows: usize, cols: usize) -> Result<Self> {
81        let shape = MatrixShape::new(rows, cols)?;
82        Self::new(shape, vec![0.0; shape.element_count()?])
83    }
84
85    /// Creates a row-major square identity matrix.
86    pub fn identity(size: usize) -> Result<Self> {
87        let shape = MatrixShape::new(size, size)?;
88        let mut values = vec![0.0; shape.element_count()?];
89        for index in 0..size {
90            values[index * size + index] = 1.0;
91        }
92        Self::new(shape, values)
93    }
94
95    /// Builds a matrix from compile-time-sized row arrays.
96    pub fn from_rows<const R: usize, const C: usize>(rows: [[f32; C]; R]) -> Result<Self> {
97        let mut values = Vec::with_capacity(R * C);
98        for row in rows {
99            values.extend(row);
100        }
101        Self::new(MatrixShape::new(R, C)?, values)
102    }
103
104    /// Returns the checked row and column dimensions.
105    pub fn shape(&self) -> MatrixShape {
106        self.shape
107    }
108
109    /// Returns the matrix storage layout.
110    pub fn layout(&self) -> MatrixLayout {
111        self.layout
112    }
113
114    /// Borrows the contiguous matrix values.
115    pub fn values(&self) -> &[f32] {
116        &self.values
117    }
118
119    /// Consumes the matrix and returns its contiguous values.
120    pub fn into_values(self) -> Vec<f32> {
121        self.values
122    }
123
124    /// Borrows this matrix as a view without copying values.
125    pub fn as_view(&self) -> F32MatrixView<'_> {
126        F32MatrixView {
127            shape: self.shape,
128            layout: self.layout,
129            values: &self.values,
130        }
131    }
132
133    /// Borrows one row, respecting the current matrix layout.
134    pub fn row(&self, index: usize) -> Result<RowView<'_>> {
135        self.as_view().row(index)
136    }
137
138    /// Borrows one column, respecting the current matrix layout.
139    pub fn column(&self, index: usize) -> Result<ColumnView<'_>> {
140        self.as_view().column(index)
141    }
142
143    /// Multiplies this matrix by `right`.
144    pub fn matmul(&self, right: &F32MatrixView<'_>) -> Result<Self> {
145        self.as_view().matmul(right)
146    }
147
148    /// Multiplies this matrix by a finite dense vector.
149    pub fn matvec(&self, vector: &[f32]) -> Result<DenseVector> {
150        self.as_view().matvec(vector)
151    }
152
153    /// Creates a transposed view without copying values.
154    pub fn transpose_view(&self) -> F32MatrixView<'_> {
155        self.as_view().transpose()
156    }
157
158    /// Returns a row-major owned transpose of this matrix.
159    pub fn transpose_owned(&self) -> Result<Self> {
160        self.as_view().transpose_owned()
161    }
162
163    /// Returns a row-major matrix whose rows have unit L2 norm.
164    pub fn l2_normalize_rows(&self) -> Result<Self> {
165        self.as_view().l2_normalize_rows()
166    }
167
168    /// Returns a row-major matrix whose columns have unit L2 norm.
169    pub fn l2_normalize_columns(&self) -> Result<Self> {
170        self.as_view().l2_normalize_columns()
171    }
172
173    /// Computes all pairwise row cosine similarities against `right`.
174    pub fn pairwise_row_cosine(&self, right: &F32MatrixView<'_>) -> Result<Self> {
175        self.as_view().pairwise_row_cosine(right)
176    }
177
178    /// Decomposes this square matrix using LU factorization with partial pivoting.
179    pub fn lu_decompose(&self) -> Result<LuDecomposition> {
180        self.as_view().lu_decompose()
181    }
182
183    /// Computes this square matrix determinant through LU decomposition.
184    pub fn determinant(&self) -> Result<f32> {
185        self.as_view().determinant()
186    }
187
188    /// Solves `A x = b` for a finite vector `b`.
189    pub fn solve_vector(&self, b: &[f32]) -> Result<Vec<f32>> {
190        self.as_view().solve_vector(b)
191    }
192
193    /// Solves `A X = B` for matrix `B`.
194    pub fn solve_matrix(&self, b: &F32MatrixView<'_>) -> Result<Self> {
195        self.as_view().solve_matrix(b)
196    }
197
198    /// Computes this square matrix inverse.
199    pub fn inverse(&self) -> Result<Self> {
200        self.as_view().inverse()
201    }
202
203    /// Verifies shape/value count agreement and rejects non-finite values.
204    pub fn validate(&self) -> Result<()> {
205        self.shape.validate()?;
206        if self.values.len() != self.shape.element_count()? {
207            return Err(invalid_argument(format!(
208                "matrix shape expects {} values but matrix has {}",
209                self.shape.element_count()?,
210                self.values.len()
211            )));
212        }
213        if self.values.iter().any(|value| !value.is_finite()) {
214            return Err(invalid_argument("matrix values must be finite"));
215        }
216        Ok(())
217    }
218}
219
220#[derive(Debug, Clone, Copy, PartialEq)]
221/// Borrowed finite `f32` matrix values with shape and layout metadata.
222pub struct F32MatrixView<'a> {
223    shape: MatrixShape,
224    layout: MatrixLayout,
225    values: &'a [f32],
226}
227
228impl<'a> F32MatrixView<'a> {
229    /// Creates a row-major matrix view after validating shape and values.
230    pub fn new(shape: MatrixShape, values: &'a [f32]) -> Result<Self> {
231        let view = Self {
232            shape,
233            layout: MatrixLayout::RowMajor,
234            values,
235        };
236        view.validate()?;
237        Ok(view)
238    }
239
240    /// Returns the checked row and column dimensions.
241    pub fn shape(&self) -> MatrixShape {
242        self.shape
243    }
244
245    /// Returns how the borrowed value slice is interpreted.
246    pub fn layout(&self) -> MatrixLayout {
247        self.layout
248    }
249
250    /// Borrows the underlying contiguous values.
251    pub fn values(&self) -> &'a [f32] {
252        self.values
253    }
254
255    /// Creates a transposed view by swapping dimensions and layout metadata.
256    pub fn transpose(self) -> Self {
257        Self {
258            shape: MatrixShape {
259                rows: self.shape.cols,
260                cols: self.shape.rows,
261            },
262            layout: match self.layout {
263                MatrixLayout::RowMajor => MatrixLayout::ColumnMajor,
264                MatrixLayout::ColumnMajor => MatrixLayout::RowMajor,
265            },
266            values: self.values,
267        }
268    }
269
270    /// Returns a row-major owned transpose of this view.
271    pub fn transpose_owned(&self) -> Result<F32Matrix> {
272        self.transpose().into_owned()
273    }
274
275    /// Borrows one logical row from this view.
276    pub fn row(self, index: usize) -> Result<RowView<'a>> {
277        if index >= self.shape.rows {
278            return Err(invalid_argument(format!(
279                "row index {index} is out of bounds for {} rows",
280                self.shape.rows
281            )));
282        }
283        let (offset, stride) = match self.layout {
284            MatrixLayout::RowMajor => (index * self.shape.cols, 1),
285            MatrixLayout::ColumnMajor => (index, self.shape.rows),
286        };
287        Ok(RowView {
288            values: self.values,
289            len: self.shape.cols,
290            offset,
291            stride,
292        })
293    }
294
295    /// Borrows one logical column from this view.
296    pub fn column(self, index: usize) -> Result<ColumnView<'a>> {
297        if index >= self.shape.cols {
298            return Err(invalid_argument(format!(
299                "column index {index} is out of bounds for {} cols",
300                self.shape.cols
301            )));
302        }
303        let (offset, stride) = match self.layout {
304            MatrixLayout::RowMajor => (index, self.shape.cols),
305            MatrixLayout::ColumnMajor => (index * self.shape.rows, 1),
306        };
307        Ok(ColumnView {
308            values: self.values,
309            len: self.shape.rows,
310            offset,
311            stride,
312        })
313    }
314
315    /// Reads one value by logical row and column.
316    pub fn get(self, row: usize, col: usize) -> Result<f32> {
317        if row >= self.shape.rows || col >= self.shape.cols {
318            return Err(invalid_argument("matrix indices are out of bounds"));
319        }
320        let index = match self.layout {
321            MatrixLayout::RowMajor => row * self.shape.cols + col,
322            MatrixLayout::ColumnMajor => col * self.shape.rows + row,
323        };
324        Ok(self.values[index])
325    }
326
327    /// Returns whether this view is square.
328    pub fn is_square(&self) -> bool {
329        self.shape.rows == self.shape.cols
330    }
331
332    /// Adds two matrices with equal shape and returns a row-major matrix.
333    pub fn add(&self, right: &F32MatrixView<'_>) -> Result<F32Matrix> {
334        self.elementwise_binary(right, |left, right| left + right)
335    }
336
337    /// Subtracts two matrices with equal shape and returns a row-major matrix.
338    pub fn sub(&self, right: &F32MatrixView<'_>) -> Result<F32Matrix> {
339        self.elementwise_binary(right, |left, right| left - right)
340    }
341
342    /// Scales every matrix value by a finite factor.
343    pub fn scale(&self, factor: f32) -> Result<F32Matrix> {
344        self.validate()?;
345        if !factor.is_finite() {
346            return Err(invalid_argument("matrix scale factor must be finite"));
347        }
348        let mut values = Vec::with_capacity(self.shape.element_count()?);
349        for row in 0..self.shape.rows {
350            for col in 0..self.shape.cols {
351                values.push(self.get(row, col)? * factor);
352            }
353        }
354        F32Matrix::new(self.shape, values)
355    }
356
357    /// Computes the Frobenius norm.
358    pub fn frobenius_norm(&self) -> Result<f32> {
359        self.validate()?;
360        let sum_squares = self.values.iter().map(|value| value * value).sum::<f32>();
361        if !sum_squares.is_finite() {
362            return Err(invalid_argument(
363                "matrix Frobenius norm produced a non-finite value",
364            ));
365        }
366        Ok(sum_squares.sqrt())
367    }
368
369    /// Computes the trace of a square matrix.
370    pub fn trace(&self) -> Result<f32> {
371        self.validate()?;
372        self.require_square("matrix trace")?;
373        let mut trace = 0.0;
374        for index in 0..self.shape.rows {
375            trace += self.get(index, index)?;
376        }
377        if !trace.is_finite() {
378            return Err(invalid_argument("matrix trace produced a non-finite value"));
379        }
380        Ok(trace)
381    }
382
383    /// Multiplies this view by another matrix view.
384    pub fn matmul(self, right: &F32MatrixView<'_>) -> Result<F32Matrix> {
385        self.validate()?;
386        right.validate()?;
387        if self.shape.cols != right.shape.rows {
388            return Err(invalid_argument("matrix multiply shapes are incompatible"));
389        }
390        let shape = MatrixShape::new(self.shape.rows, right.shape.cols)?;
391        let mut values = vec![0.0; shape.element_count()?];
392        for row in 0..self.shape.rows {
393            for col in 0..right.shape.cols {
394                let mut acc = 0.0;
395                for inner in 0..self.shape.cols {
396                    acc += self.get(row, inner)? * right.get(inner, col)?;
397                }
398                values[row * shape.cols + col] = acc;
399            }
400        }
401        F32Matrix::new(shape, values)
402    }
403
404    /// Multiplies this view by a finite dense vector.
405    pub fn matvec(self, vector: &[f32]) -> Result<DenseVector> {
406        self.validate()?;
407        if vector.len() != self.shape.cols {
408            return Err(invalid_argument(
409                "matrix/vector dimensions are incompatible",
410            ));
411        }
412        if vector.iter().any(|value| !value.is_finite()) {
413            return Err(invalid_argument("matrix/vector values must be finite"));
414        }
415        let mut output = vec![0.0; self.shape.rows];
416        for (row, output_value) in output.iter_mut().enumerate() {
417            *output_value = dot(self.row(row)?.as_slice().as_slice(), vector)?;
418        }
419        DenseVector::new(output)
420    }
421
422    /// Sums each logical row.
423    pub fn row_sums(self) -> Result<Vec<f32>> {
424        (0..self.shape.rows)
425            .map(|index| Ok(self.row(index)?.iter().sum()))
426            .collect()
427    }
428
429    /// Sums each logical column.
430    pub fn column_sums(self) -> Result<Vec<f32>> {
431        (0..self.shape.cols)
432            .map(|index| Ok(self.column(index)?.iter().sum()))
433            .collect()
434    }
435
436    /// Averages each logical row.
437    pub fn row_means(&self) -> Result<Vec<f32>> {
438        self.validate()?;
439        (0..self.shape.rows)
440            .map(|index| Ok(self.row(index)?.iter().sum::<f32>() / self.shape.cols as f32))
441            .collect()
442    }
443
444    /// Averages each logical column.
445    pub fn column_means(&self) -> Result<Vec<f32>> {
446        self.validate()?;
447        (0..self.shape.cols)
448            .map(|index| Ok(self.column(index)?.iter().sum::<f32>() / self.shape.rows as f32))
449            .collect()
450    }
451
452    /// Returns a row-major matrix whose rows have unit L2 norm.
453    pub fn l2_normalize_rows(self) -> Result<F32Matrix> {
454        let mut values = Vec::with_capacity(self.values.len());
455        for row in 0..self.shape.rows {
456            let row_view = self.row(row)?;
457            let vector = DenseVector::new(row_view.as_slice())?.l2_normalized()?;
458            values.extend(vector.into_values());
459        }
460        F32Matrix::new(self.shape, values)
461    }
462
463    /// Returns a row-major matrix whose columns have unit L2 norm.
464    pub fn l2_normalize_columns(self) -> Result<F32Matrix> {
465        let mut values = vec![0.0; self.shape.element_count()?];
466        let normalized = (0..self.shape.cols)
467            .map(|col| DenseVector::new(self.column(col)?.as_slice())?.l2_normalized())
468            .collect::<Result<Vec<_>>>()?;
469        for row in 0..self.shape.rows {
470            for col in 0..self.shape.cols {
471                values[row * self.shape.cols + col] = normalized[col].as_slice()[row];
472            }
473        }
474        F32Matrix::new(self.shape, values)
475    }
476
477    /// Computes the dot product for every pair of rows in two matrices.
478    pub fn pairwise_row_dot(self, right: &F32MatrixView<'_>) -> Result<F32Matrix> {
479        if self.shape.cols != right.shape.cols {
480            return Err(invalid_argument(
481                "row pairwise dot requires equal column counts",
482            ));
483        }
484        let shape = MatrixShape::new(self.shape.rows, right.shape.rows)?;
485        let mut values = vec![0.0; shape.element_count()?];
486        for row in 0..self.shape.rows {
487            let left = self.row(row)?;
488            for other_row in 0..right.shape.rows {
489                let right_row = right.row(other_row)?;
490                values[row * shape.cols + other_row] =
491                    dot(left.as_slice().as_slice(), right_row.as_slice().as_slice())?;
492            }
493        }
494        F32Matrix::new(shape, values)
495    }
496
497    /// Computes cosine similarity for every pair of rows in two matrices.
498    pub fn pairwise_row_cosine(self, right: &F32MatrixView<'_>) -> Result<F32Matrix> {
499        if self.shape.cols != right.shape.cols {
500            return Err(invalid_argument(
501                "row pairwise cosine requires equal column counts",
502            ));
503        }
504        let shape = MatrixShape::new(self.shape.rows, right.shape.rows)?;
505        let mut values = vec![0.0; shape.element_count()?];
506        for row in 0..self.shape.rows {
507            let left = self.row(row)?;
508            for other_row in 0..right.shape.rows {
509                let right_row = right.row(other_row)?;
510                values[row * shape.cols + other_row] =
511                    cosine_similarity(left.as_slice().as_slice(), right_row.as_slice().as_slice())?;
512            }
513        }
514        F32Matrix::new(shape, values)
515    }
516
517    /// Decomposes this square matrix using LU factorization with partial pivoting.
518    pub fn lu_decompose(&self) -> Result<LuDecomposition> {
519        self.validate()?;
520        self.require_square("LU decomposition")?;
521        backend::pure::lu_decompose(*self)
522    }
523
524    /// Computes this square matrix determinant through LU decomposition.
525    pub fn determinant(&self) -> Result<f32> {
526        self.lu_decompose()?.determinant()
527    }
528
529    /// Solves `A x = b` for a finite vector `b`.
530    pub fn solve_vector(&self, b: &[f32]) -> Result<Vec<f32>> {
531        self.lu_decompose()?.solve_vector(b)
532    }
533
534    /// Solves `A X = B` for matrix `B`.
535    pub fn solve_matrix(&self, b: &F32MatrixView<'_>) -> Result<F32Matrix> {
536        self.lu_decompose()?.solve_matrix(b)
537    }
538
539    /// Computes this square matrix inverse.
540    pub fn inverse(&self) -> Result<F32Matrix> {
541        let identity = F32Matrix::identity(self.shape.rows)?;
542        self.solve_matrix(&identity.as_view())
543    }
544
545    /// Verifies shape/value count agreement and rejects non-finite values.
546    pub fn validate(self) -> Result<()> {
547        self.shape.validate()?;
548        if self.values.len() != self.shape.element_count()? {
549            return Err(invalid_argument(format!(
550                "matrix shape expects {} values but matrix view has {}",
551                self.shape.element_count()?,
552                self.values.len()
553            )));
554        }
555        if self.values.iter().any(|value| !value.is_finite()) {
556            return Err(invalid_argument("matrix values must be finite"));
557        }
558        Ok(())
559    }
560
561    /// Copies this view into an owned row-major matrix.
562    pub fn into_owned(self) -> Result<F32Matrix> {
563        let mut values = Vec::with_capacity(self.shape.element_count()?);
564        for row in 0..self.shape.rows {
565            for col in 0..self.shape.cols {
566                values.push(self.get(row, col)?);
567            }
568        }
569        F32Matrix::new(self.shape, values)
570    }
571
572    fn elementwise_binary(
573        &self,
574        right: &F32MatrixView<'_>,
575        op: impl Fn(f32, f32) -> f32,
576    ) -> Result<F32Matrix> {
577        self.validate()?;
578        right.validate()?;
579        if self.shape != right.shape {
580            return Err(invalid_argument("matrix shapes are incompatible"));
581        }
582        let mut values = Vec::with_capacity(self.shape.element_count()?);
583        for row in 0..self.shape.rows {
584            for col in 0..self.shape.cols {
585                values.push(op(self.get(row, col)?, right.get(row, col)?));
586            }
587        }
588        F32Matrix::new(self.shape, values)
589    }
590
591    fn require_square(&self, operation: &str) -> Result<()> {
592        if !self.is_square() {
593            return Err(invalid_argument(format!(
594                "{operation} requires a square matrix"
595            )));
596        }
597        Ok(())
598    }
599}
600
/// Selects which packed factor to extract from an LU decomposition.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatrixTriangle {
    /// Lower factor: strictly-lower entries plus an implicit diagonal of ones.
    Lower,
    /// Upper factor: the diagonal and every entry above it.
    Upper,
}
609
610#[derive(Debug, Clone, PartialEq)]
611/// LU decomposition with partial pivoting for a finite square `f32` matrix.
612pub struct LuDecomposition {
613    shape: MatrixShape,
614    lu: Vec<f32>,
615    pivots: Vec<usize>,
616    swap_count: usize,
617}
618
619impl LuDecomposition {
620    pub(crate) fn new(
621        shape: MatrixShape,
622        lu: Vec<f32>,
623        pivots: Vec<usize>,
624        swap_count: usize,
625    ) -> Self {
626        Self {
627            shape,
628            lu,
629            pivots,
630            swap_count,
631        }
632    }
633
634    /// Returns the square decomposition shape.
635    pub fn shape(&self) -> MatrixShape {
636        self.shape
637    }
638
639    /// Returns the row permutation after partial pivoting.
640    pub fn pivots(&self) -> &[usize] {
641        &self.pivots
642    }
643
644    /// Returns the number of pivot row swaps performed.
645    pub fn swap_count(&self) -> usize {
646        self.swap_count
647    }
648
649    /// Computes the determinant from the upper-triangular diagonal and swap parity.
650    pub fn determinant(&self) -> Result<f32> {
651        self.validate()?;
652        let size = self.shape.rows;
653        let mut determinant = if self.swap_count.is_multiple_of(2) {
654            1.0
655        } else {
656            -1.0
657        };
658        for index in 0..size {
659            determinant *= self.lu[index * size + index];
660        }
661        if !determinant.is_finite() {
662            return Err(invalid_argument(
663                "matrix determinant produced a non-finite value",
664            ));
665        }
666        Ok(determinant)
667    }
668
669    /// Solves `A x = b` for a finite vector `b`.
670    pub fn solve_vector(&self, b: &[f32]) -> Result<Vec<f32>> {
671        self.validate()?;
672        let size = self.shape.rows;
673        if b.len() != size {
674            return Err(invalid_argument(
675                "linear solve vector length is incompatible",
676            ));
677        }
678        if b.iter().any(|value| !value.is_finite()) {
679            return Err(invalid_argument("linear solve values must be finite"));
680        }
681
682        let mut y = vec![0.0; size];
683        for row in 0..size {
684            let mut sum = b[self.pivots[row]];
685            for (col, value) in y.iter().enumerate().take(row) {
686                sum -= self.lu[row * size + col] * value;
687            }
688            y[row] = sum;
689        }
690
691        let tolerance = backend::pure::pivot_tolerance(&self.lu);
692        let mut x = vec![0.0; size];
693        for row in (0..size).rev() {
694            let mut sum = y[row];
695            for (col, value) in x.iter().enumerate().take(size).skip(row + 1) {
696                sum -= self.lu[row * size + col] * value;
697            }
698            let pivot = self.lu[row * size + row];
699            if pivot.abs() <= tolerance {
700                return Err(invalid_argument("matrix is singular or near-singular"));
701            }
702            x[row] = sum / pivot;
703        }
704        if x.iter().any(|value| !value.is_finite()) {
705            return Err(invalid_argument("linear solve produced non-finite values"));
706        }
707        Ok(x)
708    }
709
710    /// Solves `A X = B` for a finite matrix `B`.
711    pub fn solve_matrix(&self, b: &F32MatrixView<'_>) -> Result<F32Matrix> {
712        self.validate()?;
713        b.validate()?;
714        if b.shape.rows != self.shape.rows {
715            return Err(invalid_argument(
716                "linear solve matrix rows are incompatible",
717            ));
718        }
719        let shape = MatrixShape::new(self.shape.rows, b.shape.cols)?;
720        let mut values = vec![0.0; shape.element_count()?];
721        for col in 0..b.shape.cols {
722            let rhs = b.column(col)?.as_slice();
723            let solution = self.solve_vector(&rhs)?;
724            for row in 0..shape.rows {
725                values[row * shape.cols + col] = solution[row];
726            }
727        }
728        F32Matrix::new(shape, values)
729    }
730
731    /// Extracts the unit diagonal lower triangular factor.
732    pub fn lower_matrix(&self) -> Result<F32Matrix> {
733        self.triangle_matrix(MatrixTriangle::Lower)
734    }
735
736    /// Extracts the upper triangular factor.
737    pub fn upper_matrix(&self) -> Result<F32Matrix> {
738        self.triangle_matrix(MatrixTriangle::Upper)
739    }
740
741    fn triangle_matrix(&self, triangle: MatrixTriangle) -> Result<F32Matrix> {
742        self.validate()?;
743        let size = self.shape.rows;
744        let mut values = vec![0.0; self.shape.element_count()?];
745        for row in 0..size {
746            for col in 0..size {
747                values[row * size + col] = match triangle {
748                    MatrixTriangle::Lower if row > col => self.lu[row * size + col],
749                    MatrixTriangle::Lower if row == col => 1.0,
750                    MatrixTriangle::Upper if row <= col => self.lu[row * size + col],
751                    _ => 0.0,
752                };
753            }
754        }
755        F32Matrix::new(self.shape, values)
756    }
757
758    fn validate(&self) -> Result<()> {
759        self.shape.validate()?;
760        if self.shape.rows != self.shape.cols {
761            return Err(invalid_argument(
762                "LU decomposition requires a square matrix",
763            ));
764        }
765        if self.lu.len() != self.shape.element_count()? {
766            return Err(invalid_argument(
767                "LU decomposition values do not match shape",
768            ));
769        }
770        if self.pivots.len() != self.shape.rows {
771            return Err(invalid_argument(
772                "LU decomposition pivots do not match shape",
773            ));
774        }
775        if self.lu.iter().any(|value| !value.is_finite()) {
776            return Err(invalid_argument("LU decomposition values must be finite"));
777        }
778        Ok(())
779    }
780}
781
/// Borrowed, possibly strided window over one logical matrix row.
///
/// Logical element `i` lives at `values[offset + i * stride]`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RowView<'a> {
    /// Backing storage shared with the parent matrix.
    values: &'a [f32],
    /// Number of logical elements in the row.
    len: usize,
    /// Storage index of the first element.
    offset: usize,
    /// Storage distance between consecutive elements.
    stride: usize,
}
790
791impl<'a> RowView<'a> {
792    /// Returns the number of values in the row.
793    pub fn len(&self) -> usize {
794        self.len
795    }
796
797    /// Returns whether the row has no values.
798    pub fn is_empty(&self) -> bool {
799        self.len == 0
800    }
801
802    /// Iterates over row values in logical column order.
803    pub fn iter(&self) -> impl Iterator<Item = f32> + '_ {
804        (0..self.len).map(|index| self.values[self.offset + index * self.stride])
805    }
806
807    /// Collects the possibly strided row into a contiguous vector.
808    pub fn as_slice(&self) -> Vec<f32> {
809        self.iter().collect()
810    }
811
812    /// Copies the row into a validated dense vector.
813    pub fn to_dense_vector(&self) -> Result<DenseVector> {
814        DenseVector::new(self.as_slice())
815    }
816}
817
/// Borrowed, possibly strided window over one logical matrix column.
///
/// Logical element `i` lives at `values[offset + i * stride]`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ColumnView<'a> {
    /// Backing storage shared with the parent matrix.
    values: &'a [f32],
    /// Number of logical elements in the column.
    len: usize,
    /// Storage index of the first element.
    offset: usize,
    /// Storage distance between consecutive elements.
    stride: usize,
}
826
827impl<'a> ColumnView<'a> {
828    /// Returns the number of values in the column.
829    pub fn len(&self) -> usize {
830        self.len
831    }
832
833    /// Returns whether the column has no values.
834    pub fn is_empty(&self) -> bool {
835        self.len == 0
836    }
837
838    /// Iterates over column values in logical row order.
839    pub fn iter(&self) -> impl Iterator<Item = f32> + '_ {
840        (0..self.len).map(|index| self.values[self.offset + index * self.stride])
841    }
842
843    /// Collects the possibly strided column into a contiguous vector.
844    pub fn as_slice(&self) -> Vec<f32> {
845        self.iter().collect()
846    }
847
848    /// Copies the column into a validated dense vector.
849    pub fn to_dense_vector(&self) -> Result<DenseVector> {
850        DenseVector::new(self.as_slice())
851    }
852}
853
/// Finite 2D convolution kernel with `height` rows of `width` coefficients each,
/// stored in row-major order.
#[derive(Debug, Clone, PartialEq)]
pub struct Kernel2d {
    /// Number of columns.
    width: usize,
    /// Number of rows.
    height: usize,
    /// Row-major coefficients; `validate` requires `width * height` finite values.
    values: Vec<f32>,
}
861
862impl Kernel2d {
863    /// Creates a kernel with non-zero dimensions and matching finite values.
864    pub fn new(width: usize, height: usize, values: impl Into<Vec<f32>>) -> Result<Self> {
865        let kernel = Self {
866            width,
867            height,
868            values: values.into(),
869        };
870        kernel.validate()?;
871        Ok(kernel)
872    }
873
874    /// Returns the number of columns in the kernel.
875    pub fn width(&self) -> usize {
876        self.width
877    }
878
879    /// Returns the number of rows in the kernel.
880    pub fn height(&self) -> usize {
881        self.height
882    }
883
884    /// Borrows row-major kernel values.
885    pub fn values(&self) -> &[f32] {
886        &self.values
887    }
888
889    /// Verifies dimensions, value count, and finite values.
890    pub fn validate(&self) -> Result<()> {
891        if self.width == 0 || self.height == 0 {
892            return Err(invalid_argument(
893                "kernel width and height must be greater than zero",
894            ));
895        }
896        if self.values.len()
897            != self
898                .width
899                .checked_mul(self.height)
900                .ok_or_else(|| invalid_argument("kernel element count overflowed usize"))?
901        {
902            return Err(invalid_argument("kernel dimensions do not match values"));
903        }
904        if self.values.iter().any(|value| !value.is_finite()) {
905            return Err(invalid_argument("kernel values must be finite"));
906        }
907        Ok(())
908    }
909
910    /// Creates a 3x3 identity kernel with a `1.0` center coefficient.
911    pub fn identity_3x3() -> Self {
912        Self {
913            width: 3,
914            height: 3,
915            values: vec![0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
916        }
917    }
918
919    /// Creates a standard 3x3 sharpen kernel.
920    pub fn sharpen_3x3() -> Self {
921        Self {
922            width: 3,
923            height: 3,
924            values: vec![0.0, -1.0, 0.0, -1.0, 5.0, -1.0, 0.0, -1.0, 0.0],
925        }
926    }
927
928    /// Creates a standard 3x3 edge-detection kernel.
929    pub fn edge_3x3() -> Self {
930        Self {
931            width: 3,
932            height: 3,
933            values: vec![-1.0, -1.0, -1.0, -1.0, 8.0, -1.0, -1.0, -1.0, -1.0],
934        }
935    }
936
937    /// Creates an unnormalized 3x3 box blur kernel.
938    pub fn blur_3x3() -> Self {
939        Self {
940            width: 3,
941            height: 3,
942            values: vec![1.0; 9],
943        }
944    }
945
946    /// Copies a 3x3 kernel into a fixed-size row-major array.
947    pub fn as_array_3x3(&self) -> Result<[f32; 9]> {
948        if self.width != 3 || self.height != 3 {
949            return Err(invalid_argument("kernel is not 3x3"));
950        }
951        Ok(self
952            .values
953            .clone()
954            .try_into()
955            .expect("kernel length is validated"))
956    }
957}
958
959impl From<[f32; 9]> for Kernel2d {
960    fn from(value: [f32; 9]) -> Self {
961        Self {
962            width: 3,
963            height: 3,
964            values: value.to_vec(),
965        }
966    }
967}
968
969#[derive(Debug, Clone, PartialEq)]
970/// Finite 1D convolution kernel.
971pub struct Kernel1d {
972    values: Vec<f32>,
973}
974
975impl Kernel1d {
976    /// Creates a non-empty kernel and rejects non-finite values.
977    pub fn new(values: impl Into<Vec<f32>>) -> Result<Self> {
978        let kernel = Self {
979            values: values.into(),
980        };
981        kernel.validate()?;
982        Ok(kernel)
983    }
984
985    /// Borrows kernel coefficients in storage order.
986    pub fn values(&self) -> &[f32] {
987        &self.values
988    }
989
990    /// Verifies that the kernel is non-empty and finite.
991    pub fn validate(&self) -> Result<()> {
992        if self.values.is_empty() {
993            return Err(invalid_argument("1D kernel must not be empty"));
994        }
995        if self.values.iter().any(|value| !value.is_finite()) {
996            return Err(invalid_argument("1D kernel values must be finite"));
997        }
998        Ok(())
999    }
1000}
1001
1002impl TryFrom<&F32Tensor> for F32Matrix {
1003    type Error = DetectError;
1004
1005    fn try_from(value: &F32Tensor) -> Result<Self> {
1006        if value.shape().rank() != 2 {
1007            return Err(invalid_argument(
1008                "tensor-to-matrix conversion requires rank 2",
1009            ));
1010        }
1011        let dims = value.shape().dimensions();
1012        Self::new(MatrixShape::new(dims[0], dims[1])?, value.values().to_vec())
1013    }
1014}
1015
1016impl TryFrom<F32Tensor> for F32Matrix {
1017    type Error = DetectError;
1018
1019    fn try_from(value: F32Tensor) -> Result<Self> {
1020        if value.shape().rank() != 2 {
1021            return Err(invalid_argument(
1022                "tensor-to-matrix conversion requires rank 2",
1023            ));
1024        }
1025        let dims = value.shape().dimensions().to_vec();
1026        Self::new(MatrixShape::new(dims[0], dims[1])?, value.into_values())
1027    }
1028}
1029
1030impl<'a> TryFrom<F32TensorView<'a>> for F32MatrixView<'a> {
1031    type Error = DetectError;
1032
1033    fn try_from(value: F32TensorView<'a>) -> Result<Self> {
1034        if value.shape().rank() != 2 {
1035            return Err(invalid_argument(
1036                "tensor view to matrix view conversion requires rank 2",
1037            ));
1038        }
1039        let dims = value.shape().dimensions();
1040        Self::new(MatrixShape::new(dims[0], dims[1])?, value.values())
1041    }
1042}
1043
1044impl TryFrom<&F32Matrix> for F32Tensor {
1045    type Error = DetectError;
1046
1047    fn try_from(value: &F32Matrix) -> Result<Self> {
1048        F32Tensor::new(
1049            TensorShape::new([value.shape.rows, value.shape.cols])?,
1050            value.values.clone(),
1051        )
1052    }
1053}
1054
1055impl TryFrom<RowView<'_>> for DenseVector {
1056    type Error = DetectError;
1057
1058    fn try_from(value: RowView<'_>) -> Result<Self> {
1059        DenseVector::new(value.as_slice())
1060    }
1061}
1062
1063impl TryFrom<ColumnView<'_>> for DenseVector {
1064    type Error = DetectError;
1065
1066    fn try_from(value: ColumnView<'_>) -> Result<Self> {
1067        DenseVector::new(value.as_slice())
1068    }
1069}
1070
1071impl TryFrom<&DenseVector> for F32Matrix {
1072    type Error = DetectError;
1073
1074    fn try_from(value: &DenseVector) -> Result<Self> {
1075        F32Matrix::new(
1076            MatrixShape::new(1, value.dimensions())?,
1077            value.as_slice().to_vec(),
1078        )
1079    }
1080}
1081
#[cfg(test)]
mod tests {
    use super::*;

    fn assert_close(left: f32, right: f32) {
        assert!((left - right).abs() < 1.0e-4, "expected {left} ≈ {right}");
    }

    #[test]
    fn validates_shapes_and_stride_backed_views() {
        assert!(MatrixShape::new(0, 2).is_err());
        let matrix = F32Matrix::from_rows([[1.0, 2.0], [3.0, 4.0]]).unwrap();
        assert_eq!(
            matrix.transpose_view().row(0).unwrap().as_slice(),
            vec![1.0, 3.0]
        );
    }

    #[test]
    fn matrix_shape_rejects_element_count_overflow() {
        assert!(MatrixShape::new(usize::MAX, 2).is_err());
        assert_eq!(MatrixShape::new(2, 3).unwrap().element_count().unwrap(), 6);
    }

    #[test]
    fn matrix_multiply_matches_expected_values() {
        let left = F32Matrix::from_rows([[1.0, 2.0], [3.0, 4.0]]).unwrap();
        let right = F32Matrix::from_rows([[2.0, 0.0], [1.0, 2.0]]).unwrap();
        let product = left.matmul(&right.as_view()).unwrap();
        assert_eq!(product.values(), &[4.0, 4.0, 10.0, 8.0]);
    }

    #[test]
    fn pairwise_row_cosine_and_kernel_helpers_work() {
        let matrix = F32Matrix::from_rows([[1.0, 0.0], [0.0, 1.0]]).unwrap();
        let cosine = matrix.pairwise_row_cosine(&matrix.as_view()).unwrap();
        assert_eq!(cosine.values(), &[1.0, 0.0, 0.0, 1.0]);
        assert_eq!(Kernel2d::sharpen_3x3().as_array_3x3().unwrap()[4], 5.0);
    }

    #[test]
    fn kernel2d_rejects_invalid_construction() {
        assert!(Kernel2d::new(0, 3, Vec::<f32>::new()).is_err());
        assert!(Kernel2d::new(2, 2, vec![1.0_f32; 3]).is_err());
        assert!(Kernel2d::new(1, 1, vec![f32::NAN]).is_err());
        assert!(Kernel2d::new(usize::MAX, 2, vec![1.0_f32]).is_err());
    }

    #[test]
    fn kernel2d_as_array_requires_3x3() {
        let kernel = Kernel2d::new(2, 1, vec![1.0_f32, 2.0]).unwrap();
        assert!(kernel.as_array_3x3().is_err());

        let values = [0.0_f32, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        assert_eq!(Kernel2d::from(values).as_array_3x3().unwrap(), values);
    }

    #[test]
    fn kernel1d_rejects_empty_and_non_finite_values() {
        assert!(Kernel1d::new(Vec::<f32>::new()).is_err());
        assert!(Kernel1d::new(vec![1.0_f32, f32::INFINITY]).is_err());
        let kernel = Kernel1d::new(vec![0.25_f32, 0.5, 0.25]).unwrap();
        assert_eq!(kernel.values(), &[0.25, 0.5, 0.25]);
    }

    #[test]
    fn tensor_and_vector_bridges_round_trip() {
        let tensor = F32Tensor::from_dims([2, 2], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let matrix = F32Matrix::try_from(&tensor).unwrap();
        assert_eq!(
            matrix.row(1).unwrap().to_dense_vector().unwrap().as_slice(),
            &[3.0, 4.0]
        );
        let tensor_round_trip = F32Tensor::try_from(&matrix).unwrap();
        assert_eq!(tensor_round_trip.values(), tensor.values());
    }

    #[test]
    fn identity_matrix_has_unit_diagonal() {
        let matrix = F32Matrix::identity(3).unwrap();

        assert_eq!(
            matrix.values(),
            &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
        );
    }

    #[test]
    fn transpose_owned_round_trip_restores_original() {
        let matrix = F32Matrix::from_rows([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]).unwrap();
        let transposed = matrix.as_view().transpose_owned().unwrap();
        assert_eq!(transposed.shape(), MatrixShape { rows: 3, cols: 2 });
        assert_eq!(transposed.values(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

        let round_trip = transposed.as_view().transpose_owned().unwrap();
        assert_eq!(round_trip, matrix);
    }

    #[test]
    fn add_sub_and_scale_produce_expected_values() {
        let left = F32Matrix::from_rows([[1.0, 2.0], [3.0, 4.0]]).unwrap();
        let right = F32Matrix::from_rows([[5.0, 6.0], [7.0, 8.0]]).unwrap();

        let added = left.as_view().add(&right.as_view()).unwrap();
        assert_eq!(added.values(), &[6.0, 8.0, 10.0, 12.0]);

        let subtracted = right.as_view().sub(&left.as_view()).unwrap();
        assert_eq!(subtracted.values(), &[4.0, 4.0, 4.0, 4.0]);

        let scaled = left.as_view().scale(0.5).unwrap();
        assert_eq!(scaled.values(), &[0.5, 1.0, 1.5, 2.0]);
        assert!(left.as_view().scale(f32::NAN).is_err());
    }

    #[test]
    fn frobenius_norm_and_means_are_correct() {
        let matrix = F32Matrix::from_rows([[1.0, 2.0], [3.0, 4.0]]).unwrap();

        assert_close(matrix.as_view().frobenius_norm().unwrap(), 30.0_f32.sqrt());
        assert_eq!(matrix.as_view().row_means().unwrap(), vec![1.5, 3.5]);
        assert_eq!(matrix.as_view().column_means().unwrap(), vec![2.0, 3.0]);
    }

    #[test]
    fn trace_requires_square_matrix() {
        let matrix = F32Matrix::from_rows([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]).unwrap();

        assert!(matrix.as_view().trace().is_err());
    }

    #[test]
    fn lu_determinant_works_for_small_matrices() {
        let matrix_2x2 = F32Matrix::from_rows([[2.0, 1.0], [1.0, 3.0]]).unwrap();
        assert_close(matrix_2x2.determinant().unwrap(), 5.0);

        let matrix_3x3 =
            F32Matrix::from_rows([[6.0, 1.0, 1.0], [4.0, -2.0, 5.0], [2.0, 8.0, 7.0]]).unwrap();
        assert_close(matrix_3x3.determinant().unwrap(), -306.0);
    }

    #[test]
    fn solving_vector_returns_expected_values() {
        let matrix = F32Matrix::from_rows([[2.0, 1.0], [1.0, 3.0]]).unwrap();
        let solution = matrix.solve_vector(&[1.0, 2.0]).unwrap();

        assert_close(solution[0], 0.2);
        assert_close(solution[1], 0.6);
    }

    #[test]
    fn matrix_inverse_multiplies_to_identity() {
        let matrix = F32Matrix::from_rows([[4.0, 7.0], [2.0, 6.0]]).unwrap();
        let inverse = matrix.inverse().unwrap();
        let product = matrix.matmul(&inverse.as_view()).unwrap();

        assert_close(product.as_view().get(0, 0).unwrap(), 1.0);
        assert_close(product.as_view().get(0, 1).unwrap(), 0.0);
        assert_close(product.as_view().get(1, 0).unwrap(), 0.0);
        assert_close(product.as_view().get(1, 1).unwrap(), 1.0);
    }

    #[test]
    fn singular_matrix_returns_error() {
        let matrix = F32Matrix::from_rows([[1.0, 2.0], [2.0, 4.0]]).unwrap();

        assert!(matrix.lu_decompose().is_err());
        assert!(matrix.inverse().is_err());
    }

    #[test]
    fn pivoting_handles_zero_initial_pivot() {
        let matrix = F32Matrix::from_rows([[0.0, 2.0], [1.0, 3.0]]).unwrap();
        let decomposition = matrix.lu_decompose().unwrap();
        let solution = decomposition.solve_vector(&[4.0, 5.0]).unwrap();

        assert_eq!(decomposition.swap_count(), 1);
        assert_close(decomposition.determinant().unwrap(), -2.0);
        assert_close(solution[0], -1.0);
        assert_close(solution[1], 2.0);
    }
}