/// Raw bindings to the C matrix library.
///
/// Everything in here is `unsafe` to call; the safe wrapper types outside this
/// module are responsible for upholding pointer validity and shape
/// requirements.
mod ffi {
    extern "C" {
        /// Returns a pointer to a new `rows` x `cols` matrix.
        /// NOTE(review): initial element values are not visible from here.
        pub fn mat_mat(rows: usize, cols: usize) -> *mut Mat;
        /// Returns a pointer to a new `rows` x `cols` matrix built from `values`.
        /// No length is passed, so `values` must point to enough elements for
        /// the requested shape (presumably `rows * cols` — confirm in the C code).
        pub fn mat_from(rows: usize, cols: usize, values: *const f32) -> *mut Mat;
        /// Returns a pointer to a new `dim` x `dim` identity matrix.
        pub fn mat_reye(dim: usize) -> *mut Mat;
        /// Returns a pointer to a new matrix holding the product `a * b`.
        pub fn mat_rmul(a: *const Mat, b: *const Mat) -> *mut Mat;
        /// Returns a pointer to a new matrix holding the sum `a + b`.
        pub fn mat_radd(a: *const Mat, b: *const Mat) -> *mut Mat;
        /// Reads the element of `m` at (`row`, `col`).
        pub fn mat_at(m: *const Mat, row: usize, col: usize) -> f32;
        /// Compares `a` and `b`.
        /// NOTE(review): exact vs. tolerance-based comparison is decided by the
        /// C implementation — confirm before relying on either.
        pub fn mat_equals(a: *const Mat, b: *const Mat) -> bool;
        /// Prints `m` (destination and format are defined by the C side).
        pub fn mat_print(m: *const Mat);
        /// Releases a matrix previously returned by this library.
        pub fn mat_free_mat(m: *mut Mat);
        /// Returns a pointer to a new vector with `dim` entries.
        pub fn mat_vec(dim: usize) -> *mut Mat;
        /// Writes the eigenvalues of `a` into the caller-allocated `out`.
        pub fn mat_eigvals(out: *mut Mat, a: *const Mat);
        /// Symmetric-matrix variant of `mat_eigvals`.
        pub fn mat_eigvals_sym(out: *mut Mat, a: *const Mat);
        /// Writes eigenvectors of the symmetric matrix `a` into `v` and its
        /// eigenvalues into `eigenvalues`; both outputs are caller-allocated.
        pub fn mat_eigen_sym(v: *mut Mat, eigenvalues: *mut Mat, a: *const Mat);
        /// General-matrix variant of `mat_eigen_sym`.
        pub fn mat_eigen(v: *mut Mat, eigenvalues: *mut Mat, a: *const Mat);
    }

    /// Rust view of the C matrix struct.
    ///
    /// `#[repr(C)]` keeps the fields in declaration order with C layout rules.
    /// NOTE(review): the field order and types must match the C header exactly —
    /// verify against it.
    #[repr(C)]
    pub struct Mat {
        /// Number of rows.
        pub rows: usize,
        /// Number of columns.
        pub cols: usize,
        /// Pointer to the element storage owned by the C library.
        pub data: *mut f32,
    }
}
45
46pub struct Mat(*mut ffi::Mat);
51
52impl Mat {
53 pub fn new(rows: usize, cols: usize) -> Self {
55 Mat(unsafe { ffi::mat_mat(rows, cols) })
56 }
57
58 pub fn from_slice(rows: usize, cols: usize, data: &[f32]) -> Self {
60 Mat(unsafe { ffi::mat_from(rows, cols, data.as_ptr()) })
61 }
62
63 pub fn eye(dim: usize) -> Self {
65 Mat(unsafe { ffi::mat_reye(dim) })
66 }
67
68 pub fn mul(&self, other: &Mat) -> Self {
70 Mat(unsafe { ffi::mat_rmul(self.0, other.0) })
71 }
72
73 pub fn add(&self, other: &Mat) -> Self {
75 Mat(unsafe { ffi::mat_radd(self.0, other.0) })
76 }
77
78 pub fn at(&self, row: usize, col: usize) -> f32 {
80 unsafe { ffi::mat_at(self.0, row, col) }
81 }
82
83 pub fn equals(&self, other: &Mat) -> bool {
85 unsafe { ffi::mat_equals(self.0, other.0) }
86 }
87
88 pub fn print(&self) {
90 unsafe { ffi::mat_print(self.0) }
91 }
92
93 pub fn rows(&self) -> usize {
95 unsafe { (*self.0).rows }
96 }
97
98 pub fn cols(&self) -> usize {
100 unsafe { (*self.0).cols }
101 }
102
103 pub fn eigvals(&self) -> Mat {
108 let dim = self.rows();
109 let out = Mat(unsafe { ffi::mat_vec(dim) });
110 unsafe { ffi::mat_eigvals(out.0, self.0) };
111 out
112 }
113
114 pub fn eigvals_sym(&self) -> Mat {
119 let dim = self.rows();
120 let out = Mat(unsafe { ffi::mat_vec(dim) });
121 unsafe { ffi::mat_eigvals_sym(out.0, self.0) };
122 out
123 }
124
125 pub fn eigen_sym(&self) -> Eigen {
130 let dim = self.rows();
131 let v = Mat(unsafe { ffi::mat_mat(dim, dim) });
132 let eigenvalues = Mat(unsafe { ffi::mat_vec(dim) });
133 unsafe { ffi::mat_eigen_sym(v.0, eigenvalues.0, self.0) };
134 Eigen {
135 eigenvalues,
136 eigenvectors: v,
137 }
138 }
139
140 pub fn eigen(&self) -> Eigen {
145 let dim = self.rows();
146 let v = Mat(unsafe { ffi::mat_mat(dim, dim) });
147 let eigenvalues = Mat(unsafe { ffi::mat_vec(dim) });
148 unsafe { ffi::mat_eigen(v.0, eigenvalues.0, self.0) };
149 Eigen {
150 eigenvalues,
151 eigenvectors: v,
152 }
153 }
154}
155
156pub struct Eigen {
160 pub eigenvalues: Mat,
161 pub eigenvectors: Mat,
162}
163
164impl Drop for Mat {
165 fn drop(&mut self) {
166 unsafe { ffi::mat_free_mat(self.0) }
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing computed values with expected ones.
    const TOLERANCE: f32 = 1e-5;

    /// Asserts that `actual` lies within `TOLERANCE` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_eye() {
        let m = Mat::eye(3);
        assert_eq!(m.at(0, 0), 1.0);
        assert_eq!(m.at(1, 1), 1.0);
        assert_eq!(m.at(2, 2), 1.0);
        assert_eq!(m.at(0, 1), 0.0);
    }

    #[test]
    fn test_mul() {
        // Multiplying by the identity must leave the matrix unchanged.
        let a = Mat::from_slice(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let b = Mat::eye(2);
        let c = a.mul(&b);
        assert!(c.equals(&a));
    }

    #[test]
    fn test_eigen_sym_identity() {
        let m = Mat::eye(3);
        let eig = m.eigen_sym();
        for i in 0..3 {
            assert_close(eig.eigenvalues.at(i, 0), 1.0);
        }
        // NOTE(review): any orthonormal basis is a valid eigenbasis of the
        // identity; this assertion relies on the solver returning the
        // canonical one.
        assert!(eig.eigenvectors.equals(&Mat::eye(3)));
    }

    #[test]
    fn test_eigen_sym_diagonal() {
        // NOTE(review): relies on the solver reporting eigenvalues in this
        // order.
        let m = Mat::from_slice(2, 2, &[3.0, 0.0, 0.0, 7.0]);
        let eig = m.eigen_sym();
        assert_close(eig.eigenvalues.at(0, 0), 3.0);
        assert_close(eig.eigenvalues.at(1, 0), 7.0);
    }

    #[test]
    fn test_eigvals_sym() {
        let m = Mat::from_slice(2, 2, &[2.0, 1.0, 1.0, 2.0]);
        let vals = m.eigvals_sym();
        // Sort so the test does not depend on the order the solver reports.
        let mut ev = [vals.at(0, 0), vals.at(1, 0)];
        ev.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_close(ev[0], 1.0);
        assert_close(ev[1], 3.0);
    }

    #[test]
    fn test_eigvals_general() {
        let m = Mat::eye(2);
        let vals = m.eigvals();
        assert_close(vals.at(0, 0), 1.0);
        assert_close(vals.at(1, 0), 1.0);
    }
}