Skip to main content

libmat_rs/
lib.rs

1//! Rust bindings for [libmat](https://github.com/alvgaona/libmat), an stb-style single-header
2//! linear algebra library in pure C.
3//!
4//! # Usage
5//!
6//! ```
7//! use libmat_rs::Mat;
8//!
9//! let a = Mat::from_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
10//! let b = Mat::eye(2);
11//! let c = a.mul(&b);
12//! assert!(c.equals(&a));
13//! ```
14//!
15//! # Storage
16//!
17//! Matrices use **column-major** storage (BLAS-compatible). When constructing from a slice,
18//! values are laid out column by column.
19
mod ffi {
    //! Raw declarations mirroring libmat's C API.
    //!
    //! `usize` is assumed to match the C side's `size_t`, and `f32` its element type;
    //! NOTE(review): confirm against the libmat header when upgrading the library.

    /// C layout of a libmat matrix header: its dimensions plus a pointer to the
    /// column-major element buffer. Field order must match the C struct exactly.
    #[repr(C)]
    pub struct Mat {
        pub rows: usize,
        pub cols: usize,
        pub data: *mut f32,
    }

    extern "C" {
        // Allocation and release of matrix headers.
        pub fn mat_mat(rows: usize, cols: usize) -> *mut Mat;
        pub fn mat_vec(dim: usize) -> *mut Mat;
        pub fn mat_from(rows: usize, cols: usize, values: *const f32) -> *mut Mat;
        pub fn mat_reye(dim: usize) -> *mut Mat;
        pub fn mat_free_mat(m: *mut Mat);

        // Arithmetic producing a newly allocated result.
        pub fn mat_rmul(a: *const Mat, b: *const Mat) -> *mut Mat;
        pub fn mat_radd(a: *const Mat, b: *const Mat) -> *mut Mat;

        // Read-only queries.
        pub fn mat_at(m: *const Mat, row: usize, col: usize) -> f32;
        pub fn mat_equals(a: *const Mat, b: *const Mat) -> bool;
        pub fn mat_print(m: *const Mat);

        // Eigen solvers writing into caller-allocated output matrices.
        pub fn mat_eigvals(out: *mut Mat, a: *const Mat);
        pub fn mat_eigvals_sym(out: *mut Mat, a: *const Mat);
        pub fn mat_eigen_sym(v: *mut Mat, eigenvalues: *mut Mat, a: *const Mat);
        pub fn mat_eigen(v: *mut Mat, eigenvalues: *mut Mat, a: *const Mat);
    }
}
45
46/// A matrix backed by libmat's C implementation.
47///
48/// Owns its underlying C allocation and frees it on drop.
49/// Uses column-major storage.
50pub struct Mat(*mut ffi::Mat);
51
52impl Mat {
53    /// Creates a zero-initialized matrix with the given dimensions.
54    pub fn new(rows: usize, cols: usize) -> Self {
55        Mat(unsafe { ffi::mat_mat(rows, cols) })
56    }
57
58    /// Creates a matrix from a slice of values in column-major order.
59    pub fn from_slice(rows: usize, cols: usize, data: &[f32]) -> Self {
60        Mat(unsafe { ffi::mat_from(rows, cols, data.as_ptr()) })
61    }
62
63    /// Creates an identity matrix of the given dimension.
64    pub fn eye(dim: usize) -> Self {
65        Mat(unsafe { ffi::mat_reye(dim) })
66    }
67
68    /// Returns the matrix product `self * other`.
69    pub fn mul(&self, other: &Mat) -> Self {
70        Mat(unsafe { ffi::mat_rmul(self.0, other.0) })
71    }
72
73    /// Returns the element-wise sum `self + other`.
74    pub fn add(&self, other: &Mat) -> Self {
75        Mat(unsafe { ffi::mat_radd(self.0, other.0) })
76    }
77
78    /// Returns the element at `(row, col)`.
79    pub fn at(&self, row: usize, col: usize) -> f32 {
80        unsafe { ffi::mat_at(self.0, row, col) }
81    }
82
83    /// Returns `true` if all elements are equal.
84    pub fn equals(&self, other: &Mat) -> bool {
85        unsafe { ffi::mat_equals(self.0, other.0) }
86    }
87
88    /// Prints the matrix to stdout.
89    pub fn print(&self) {
90        unsafe { ffi::mat_print(self.0) }
91    }
92
93    /// Returns the number of rows.
94    pub fn rows(&self) -> usize {
95        unsafe { (*self.0).rows }
96    }
97
98    /// Returns the number of columns.
99    pub fn cols(&self) -> usize {
100        unsafe { (*self.0).cols }
101    }
102
103    /// Computes eigenvalues of a general square matrix.
104    ///
105    /// Uses Hessenberg reduction followed by implicit QR iteration.
106    /// For complex eigenvalues (conjugate pairs), only the real part is stored.
107    pub fn eigvals(&self) -> Mat {
108        let dim = self.rows();
109        let out = Mat(unsafe { ffi::mat_vec(dim) });
110        unsafe { ffi::mat_eigvals(out.0, self.0) };
111        out
112    }
113
114    /// Computes eigenvalues of a symmetric matrix.
115    ///
116    /// Faster than [`eigvals`](Self::eigvals) for symmetric input.
117    /// Uses tridiagonal reduction + implicit QR iteration.
118    pub fn eigvals_sym(&self) -> Mat {
119        let dim = self.rows();
120        let out = Mat(unsafe { ffi::mat_vec(dim) });
121        unsafe { ffi::mat_eigvals_sym(out.0, self.0) };
122        out
123    }
124
125    /// Computes eigendecomposition of a symmetric matrix: `A = V * diag(eigenvalues) * V^T`.
126    ///
127    /// Returns an [`Eigen`] where `eigenvectors` columns are orthogonal eigenvectors
128    /// and `eigenvalues` are sorted in ascending order.
129    pub fn eigen_sym(&self) -> Eigen {
130        let dim = self.rows();
131        let v = Mat(unsafe { ffi::mat_mat(dim, dim) });
132        let eigenvalues = Mat(unsafe { ffi::mat_vec(dim) });
133        unsafe { ffi::mat_eigen_sym(v.0, eigenvalues.0, self.0) };
134        Eigen {
135            eigenvalues,
136            eigenvectors: v,
137        }
138    }
139
140    /// Computes eigendecomposition of a general square matrix.
141    ///
142    /// Returns an [`Eigen`] where `eigenvectors` columns are the eigenvectors
143    /// and `eigenvalues` holds eigenvalues (real parts for complex conjugate pairs).
144    pub fn eigen(&self) -> Eigen {
145        let dim = self.rows();
146        let v = Mat(unsafe { ffi::mat_mat(dim, dim) });
147        let eigenvalues = Mat(unsafe { ffi::mat_vec(dim) });
148        unsafe { ffi::mat_eigen(v.0, eigenvalues.0, self.0) };
149        Eigen {
150            eigenvalues,
151            eigenvectors: v,
152        }
153    }
154}
155
156/// Result of an eigendecomposition.
157///
158/// Contains eigenvalues as a column vector and eigenvectors as columns of a matrix.
159pub struct Eigen {
160    pub eigenvalues: Mat,
161    pub eigenvectors: Mat,
162}
163
164impl Drop for Mat {
165    fn drop(&mut self) {
166        unsafe { ffi::mat_free_mat(self.0) }
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every floating-point comparison below.
    const TOL: f32 = 1e-5;

    /// Asserts that `actual` lies strictly within `TOL` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn test_eye() {
        let identity = Mat::eye(3);
        for d in 0..3 {
            assert_eq!(identity.at(d, d), 1.0);
        }
        assert_eq!(identity.at(0, 1), 0.0);
    }

    #[test]
    fn test_mul() {
        let lhs = Mat::from_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let product = lhs.mul(&Mat::eye(2));
        assert!(product.equals(&lhs));
    }

    #[test]
    fn test_eigen_sym_identity() {
        let decomposition = Mat::eye(3).eigen_sym();
        (0..3).for_each(|row| assert_close(decomposition.eigenvalues.at(row, 0), 1.0));
        assert!(decomposition.eigenvectors.equals(&Mat::eye(3)));
    }

    #[test]
    fn test_eigen_sym_diagonal() {
        let decomposition = Mat::from_slice(2, 2, &[3.0, 0.0, 0.0, 7.0]).eigen_sym();
        let expected = [3.0, 7.0];
        for (row, want) in expected.iter().enumerate() {
            assert_close(decomposition.eigenvalues.at(row, 0), *want);
        }
    }

    #[test]
    fn test_eigvals_sym() {
        let spectrum = Mat::from_slice(2, 2, &[2.0, 1.0, 1.0, 2.0]).eigvals_sym();
        let mut sorted = [spectrum.at(0, 0), spectrum.at(1, 0)];
        sorted.sort_by(|x, y| x.partial_cmp(y).unwrap());
        let [smaller, larger] = sorted;
        assert_close(smaller, 1.0);
        assert_close(larger, 3.0);
    }

    #[test]
    fn test_eigvals_general() {
        let spectrum = Mat::eye(2).eigvals();
        for row in 0..2 {
            assert_close(spectrum.at(row, 0), 1.0);
        }
    }
}