1#[cfg(feature = "blas")]
8extern crate blas_crate;
9#[cfg(feature = "lapack")]
10extern crate lapack_crate;
11
12#[cfg(all(target_os = "macos", any(feature = "blas", feature = "lapack")))]
13extern crate accelerate_src;
14#[cfg(all(not(target_os = "macos"), any(feature = "blas", feature = "lapack")))]
15extern crate openblas_src;
16
17#[cfg(feature = "blas")]
18use blas_crate::{ddot, dgemm};
19#[cfg(feature = "lapack")]
20use lapack_crate::{dgesv, dgetrf, dgetri};
21
22#[cfg(not(feature = "lapack"))]
23use super::decomposition::*;
24use super::Vector;
25use crate::prelude::max;
26
27pub fn arange(start: f64, stop: f64, step: f64) -> Vector {
30 let n = (stop - start) / step;
31 (0..n as usize)
32 .map(|i| start as f64 + i as f64 * step)
33 .collect::<Vector>()
34}
35
36pub fn linspace(start: f64, stop: f64, num: usize) -> Vector {
39 let width = (stop - start) / (num - 1) as f64;
40 (0..num)
41 .map(|i| start + i as f64 * width)
42 .collect::<Vector>()
43}
44
/// Checks that `m` can be viewed as a matrix with `nrows` rows and returns
/// the inferred column count.
///
/// # Errors
/// Returns `Err("Not a matrix")` if `nrows` is zero or if `m.len()` is not a
/// multiple of `nrows`. A zero `nrows` previously panicked with a division by
/// zero.
pub fn is_matrix(m: &[f64], nrows: usize) -> Result<usize, String> {
    if nrows == 0 || m.len() % nrows != 0 {
        return Err("Not a matrix".to_string());
    }
    Ok(m.len() / nrows)
}
54
/// Checks that `m.len()` is a perfect square and returns the side length.
///
/// The root is computed exactly in integers. The previous `f32` square root
/// cannot represent lengths above 2^24 exactly, so some non-square lengths
/// (e.g. 2^24 + 1) were reported as square.
///
/// # Errors
/// Returns `Err("Matrix not square")` if the length is not a perfect square.
pub fn is_square(m: &[f64]) -> Result<usize, String> {
    let len = m.len();
    // f64 sqrt gives a close estimate; nudge it to the exact integer root.
    let mut n = (len as f64).sqrt() as usize;
    while n * n > len {
        n -= 1;
    }
    while (n + 1) * (n + 1) <= len {
        n += 1;
    }
    if n * n == len {
        Ok(n)
    } else {
        Err("Matrix not square".to_string())
    }
}
64
65pub fn is_design(m: &[f64], nrows: usize) -> bool {
67 let mut is_design = true;
68 let ncols = is_matrix(m, nrows).unwrap();
69 for i in 0..nrows {
70 if (m[i * ncols] - 1.).abs() > f64::EPSILON {
71 is_design = false;
72 }
73 }
74 is_design
75}
76
77pub fn is_symmetric(m: &[f64]) -> bool {
79 let n = is_square(m).unwrap();
80 for i in 0..n {
81 for j in i..n {
82 if (m[i * n + j] - m[j * n + i]).abs() > f64::EPSILON {
83 return false;
84 }
85 }
86 }
87 true
88}
89
90pub fn is_positive_definite(m: &[f64]) -> bool {
92 if !is_symmetric(m) {
93 return false;
94 }
95 let n = is_square(m).unwrap();
96 for i in 0..n {
97 if m[i * n + i] <= 0. {
98 return false;
99 }
100 }
101 true
102}
103
/// Returns the sign (`+1` or `-1`) of the row permutation described by `ipiv`.
///
/// Works on a copy of `ipiv`. For each position `i` whose current entry is
/// not `i`, it swaps entries `i` and `ipiv[i]` and counts one transposition.
/// The result is `(-1)^count`. Entries are used directly as 0-based indices,
/// so an out-of-range (or negative) entry that reaches the swap panics.
///
/// NOTE(review): this does not look correct for either common pivot format.
/// Confirm what callers pass before relying on it:
/// * As a 0-based permutation vector: a single `if` per position does not
///   resolve longer cycles. `[1, 2, 3, 0]` (a 4-cycle, odd) returns `+1`.
///   A `while` loop (cycle sort) would be needed.
/// * As a 0-based LAPACK-style sequence of row interchanges (row `i` swapped
///   with row `ipiv[i]`): the sign should be `(-1)^#{i : ipiv[i] != i}` with no
///   swapping. `[1, 2, 2]` should give `+1`, but this returns `-1`.
/// * Raw LAPACK `dgetrf` output is 1-based. It is not converted here, so it
///   would be misread, and an entry equal to `ipiv.len()` can index out of
///   bounds.
pub fn ipiv_parity(ipiv: &[i32]) -> i32 {
    let mut perm = ipiv.to_owned();
    let mut par = 0;
    for i in 0..perm.len() {
        if perm[i] != i as i32 {
            let j = perm[i] as usize;
            perm.swap(i, j);
            par += 1;
        }
    }
    (-1_i32).pow(par)
}
118
119pub fn diag(a: &[f64]) -> Vector {
120 let n = is_square(a).unwrap();
121 let mut results = Vector::new(Vec::with_capacity(n));
122 for i in 0..n {
123 results.push(a[i * n + i]);
124 }
125 results
126}
127
128pub fn row_to_col_major(a: &[f64], nrows: usize) -> Vector {
130 let ncols = is_matrix(a, nrows).unwrap();
131 let mut x = Vector::new(a.to_vec());
132 for i in 0..nrows {
133 for j in 0..ncols {
134 x[j * nrows + i] = a[i * ncols + j];
135 }
136 }
137 x
138}
139
140pub fn col_to_row_major(a: &[f64], nrows: usize) -> Vec<f64> {
142 let ncols = is_matrix(a, nrows).unwrap();
143 let mut x = a.to_vec();
144 for i in 0..nrows {
145 for j in 0..ncols {
146 x[i * ncols + j] = a[j * nrows + i];
147 }
148 }
149 x
150}
151
152pub fn transpose(a: &[f64], nrows: usize) -> Vec<f64> {
154 let ncols = is_matrix(a, nrows).unwrap();
155
156 let mut at = Vec::with_capacity(a.len());
157
158 for j in 0..ncols {
159 for i in 0..nrows {
160 at.push(a[i * ncols + j]);
161 }
162 }
163
164 at
165}
166
167pub fn diag_matrix(a: &[f64]) -> Vector {
169 let n = a.len();
170 let mut new = Vector::new(vec![0.; n * n]);
171 for i in 0..n {
172 new[i * n + i] = a[i];
173 }
174 new
175}
176
177pub fn invert_matrix(matrix: &[f64]) -> Vec<f64> {
179 let n = is_square(matrix).unwrap();
180 #[cfg(feature = "lapack")]
181 {
182 let n = n as i32;
183 let mut a = matrix.to_vec();
184 let mut ipiv = vec![0; n as usize];
185 let mut info: i32 = 0;
186 let mut work = vec![0.; 1];
187 unsafe {
188 dgetri(n, &mut a, n, &ipiv, &mut work, -1, &mut info);
189 assert_eq!(info, 0, "dgetri failed");
190 }
191 let lwork = work[0] as usize;
192 work.extend_from_slice(&vec![0.; lwork - 1]);
193 unsafe {
194 dgetrf(n, n, &mut a, n, &mut ipiv, &mut info);
195 assert_eq!(info, 0, "dgetrf failed");
196 }
197 unsafe {
198 dgetri(n, &mut a, n, &ipiv, &mut work, lwork as i32, &mut info);
199 assert_eq!(info, 0, "dgetri failed");
200 }
201 a
202 }
203
204 #[cfg(not(feature = "lapack"))]
205 {
206 let ones = diag_matrix(&vec![1.; n]);
207 solve_sys(matrix, &ones)
208 }
209}
210
211pub fn xtx(x: &[f64], k: usize) -> Vec<f64> {
213 matmul(x, x, k, k, true, false)
214}
215
216pub fn solve_sys(a: &[f64], b: &[f64]) -> Vec<f64> {
218 let n = is_square(a).unwrap();
219 let nsys = is_matrix(b, n).unwrap();
220
221 #[cfg(feature = "lapack")]
222 {
223 let mut a = row_to_col_major(&a, n);
224 let mut b = row_to_col_major(&b, n);
225 let mut ipiv = vec![0; n as usize];
226 let mut info = 0;
227 unsafe {
228 dgesv(
229 n as i32,
230 nsys as i32,
231 &mut a,
232 n as i32,
233 &mut ipiv,
234 &mut b,
235 n as i32,
236 &mut info,
237 );
238 assert_eq!(info, 0, "dgesv failed");
239 }
240 col_to_row_major(&b, n)
241 }
242
243 #[cfg(not(feature = "lapack"))]
244 {
245 let mut solutions = Vec::with_capacity(b.len());
246 let b = row_to_col_major(b, n);
247
248 if is_positive_definite(a) {
249 let l = cholesky(a);
250 for i in 0..nsys {
251 let sol = cholesky_solve(&l, &b[(i * n)..((i + 1) * n)]);
252 assert_eq!(sol.len(), n);
253 solutions.extend_from_slice(&sol);
254 }
255 } else {
256 let (lu, piv) = lu(a);
257 for i in 0..nsys {
258 let sol = lu_solve(&lu, &piv, &b[(i * n)..((i + 1) * n)]);
259 assert_eq!(sol.len(), n);
260 solutions.extend_from_slice(&sol);
261 }
262 }
263
264 col_to_row_major(&solutions, n)
265 }
266}
267
268pub fn solve(a: &[f64], b: &[f64]) -> Vec<f64> {
270 let n = b.len();
271 assert!(a.len() == n * n);
272
273 #[cfg(feature = "lapack")]
274 {
275 let mut lu = row_to_col_major(&a, n);
276 let mut ipiv = vec![0; n as usize];
277 let mut result = b.to_vec();
278 let mut info = 0;
279 unsafe {
280 dgesv(
281 n as i32,
282 1,
283 &mut lu,
284 n as i32,
285 &mut ipiv,
286 &mut result,
287 n as i32,
288 &mut info,
289 );
290 assert_eq!(info, 0, "dgesv failed");
291 }
292 result
293 }
294
295 #[cfg(not(feature = "lapack"))]
296 {
297 if is_positive_definite(a) {
298 let l = cholesky(a);
299 cholesky_solve(&l, b)
300 } else {
301 let (lu, piv) = lu(a);
302 lu_solve(&lu, &piv, b)
303 }
304 }
305}
306
307pub fn matmul_blocked(
310 a: &[f64],
311 b: &[f64],
312 rows_a: usize,
313 rows_b: usize,
314 transpose_a: bool,
315 transpose_b: bool,
316 bsize: usize,
317) -> Vec<f64> {
318 let cols_a = is_matrix(a, rows_a).unwrap();
319 let cols_b = is_matrix(b, rows_b).unwrap();
320
321 let m = if transpose_a { cols_a } else { rows_a };
322 let l = if transpose_a { rows_a } else { cols_a };
323 let n = if transpose_b { rows_b } else { cols_b };
324
325 let mut c = vec![0.; m * n];
326
327 let a = if transpose_a {
328 transpose(a, rows_a)
329 } else {
330 a.to_vec()
331 };
332 let b = if transpose_b {
333 transpose(b, rows_b)
334 } else {
335 b.to_vec()
336 };
337
338 for jj in 0..(n / bsize + 1) {
340 for kk in 0..(l / bsize + 1) {
341 for i in 0..m {
342 for k in (kk * bsize)..std::cmp::min((kk * bsize) + bsize, l) {
343 let temp = a[i * l + k];
344 for j in (jj * bsize)..std::cmp::min((jj * bsize) + bsize, n) {
345 c[i * n + j] += temp * b[k * n + j];
346 }
347 }
348 }
349 }
350 }
351
352 c
353}
354
355pub fn matmul(
357 a: &[f64],
358 b: &[f64],
359 rows_a: usize,
360 rows_b: usize,
361 transpose_a: bool,
362 transpose_b: bool,
363) -> Vec<f64> {
364 let cols_a = is_matrix(a, rows_a).unwrap();
365 let cols_b = is_matrix(b, rows_b).unwrap();
366
367 #[cfg(feature = "blas")]
368 {
369 let (cols_a, rows_a) = (rows_a, cols_a);
376 let (cols_b, rows_b) = (rows_b, cols_b);
377
378 let (transpose_a, transpose_b) = (!transpose_a, !transpose_b);
379
380 let m = if transpose_a { cols_a } else { rows_a };
382 let n = if transpose_b { rows_b } else { cols_b };
383 let k = if transpose_a { rows_a } else { cols_a };
384
385 let trans_a = if transpose_a { b'T' } else { b'N' };
386 let trans_b = if transpose_b { b'T' } else { b'N' };
387
388 let alpha = 1.;
389 let beta = 0.;
390
391 let lda = rows_a;
392 let ldb = rows_b;
393 let ldc = m;
394
395 if transpose_a {
397 assert!(lda >= k, "lda={} must be at least as large as k={}", lda, k);
398 } else {
399 assert!(lda >= m, "lda={} must be at least as large as m={}", lda, m);
400 }
401
402 if transpose_b {
403 assert!(ldb >= n, "ldb={} must be at least as large as n={}", ldb, n);
404 } else {
405 assert!(ldb >= k, "ldb={} must be at least as large as k={}", ldb, k);
406 }
407
408 let mut c = vec![0.; (ldc * n) as usize];
409
410 unsafe {
411 dgemm(
412 trans_a, trans_b, m as i32, n as i32, k as i32, alpha, a, lda as i32, b,
413 ldb as i32, beta, &mut c, ldc as i32,
414 );
415 }
416
417 transpose(&c, n)
419 }
420
421 #[cfg(not(feature = "blas"))]
422 {
423 if transpose_a && transpose_b {
425 return transpose(&matmul(a, b, rows_a, rows_b, false, false), cols_a);
426 }
427
428 let m = if transpose_a { cols_a } else { rows_a };
429 let l = if transpose_a { rows_a } else { cols_a };
430 let n = if transpose_b { rows_b } else { cols_b };
431
432 let mut c = vec![0.; m * n];
433
434 let a = if transpose_a {
435 transpose(a, rows_a)
436 } else {
437 a.to_vec()
438 };
439 let b = if transpose_b {
440 transpose(b, rows_b)
441 } else {
442 b.to_vec()
443 };
444
445 for i in 0..m {
446 for k in 0..l {
447 let temp = a[i * l + k];
448 for j in 0..n {
449 c[i * n + j] += temp * b[k * n + j];
450 }
451 }
452 }
453
454 c
455 }
456}
457
458pub fn design(x: &[f64], rows: usize) -> Vec<f64> {
460 let mut ones = vec![1.; rows];
461 ones.extend_from_slice(x);
462 col_to_row_major(&ones, rows)
463}
464
/// Builds the row-major Vandermonde matrix of `x` with `n` columns.
///
/// Row `i` is `[x[i]^0, x[i]^1, …, x[i]^(n-1)]`.
pub fn vandermonde(x: &[f64], n: usize) -> Vec<f64> {
    x.iter()
        .flat_map(|&v| (0..n).map(move |power| v.powi(power as i32)))
        .collect()
}
478
/// Builds the symmetric `n × n` row-major Toeplitz matrix whose first row
/// (and column) is `x`, where `n = x.len()`.
///
/// Entry `(i, j)` is `x[|i - j|]`. Indices are computed in `usize`; the
/// previous `i32` arithmetic (`i * n as i32`) could overflow for large `n`.
pub fn toeplitz(x: &[f64]) -> Vec<f64> {
    let n = x.len();
    let mut v = vec![0.; n * n];
    for i in 0..n {
        for j in 0..n {
            let lag = if i >= j { i - j } else { j - i };
            v[i * n + j] = x[lag];
        }
    }
    v
}
492
/// Returns the sum of the elements of `x`.
///
/// With the `blas` feature this is `ddot(x, [1.])`, using a zero increment
/// on the ones vector. NOTE(review): this relies on the BLAS accepting
/// `incy = 0` (reference BLAS does). Confirm for the linked implementation.
/// Without `blas`, elements are added in blocks of eight (each block summed
/// left to right, then added to the total), followed by the leftover tail.
pub fn sum(x: &[f64]) -> f64 {
    #[cfg(feature = "blas")]
    {
        let n = x.len();
        let y = [1.];
        unsafe { ddot(n as i32, x, 1, &y, 0) }
    }

    #[cfg(not(feature = "blas"))]
    {
        let blocks = x.chunks_exact(8);
        let tail = blocks.remainder();
        let mut total = 0.;
        for c in blocks {
            total += c[0] + c[1] + c[2] + c[3] + c[4] + c[5] + c[6] + c[7];
        }
        for v in tail {
            total += v;
        }
        total
    }
}
536
537pub fn logsumexp(x: &[f64]) -> f64 {
540 let xmax = max(x);
541 x.iter().map(|v| (v - xmax).exp()).sum::<f64>().ln() + xmax
542 }
545
546pub fn logmeanexp(x: &[f64]) -> f64 {
549 let xmax = max(x);
550 (x.iter().map(|v| (v - xmax).exp()).sum::<f64>() / x.len() as f64).ln() + xmax
551 }
554
/// Returns the product of the elements of `x` (`1.0` for an empty slice).
pub fn prod(x: &[f64]) -> f64 {
    x.iter().fold(1., |acc, v| acc * v)
}
559
/// Returns the dot product of `x` and `y`.
///
/// Uses BLAS `ddot` with the `blas` feature. Otherwise products are
/// accumulated in blocks of eight (each block summed left to right, then added
/// to the total), followed by the leftover tail.
///
/// # Panics
/// Panics if `x` and `y` have different lengths.
pub fn dot(x: &[f64], y: &[f64]) -> f64 {
    assert_eq!(x.len(), y.len());

    #[cfg(feature = "blas")]
    {
        unsafe { ddot(x.len() as i32, x, 1, y, 1) }
    }

    #[cfg(not(feature = "blas"))]
    {
        let xs = x.chunks_exact(8);
        let ys = y.chunks_exact(8);
        let (x_tail, y_tail) = (xs.remainder(), ys.remainder());
        let mut total = 0.;
        for (p, q) in xs.zip(ys) {
            total += p[0] * q[0]
                + p[1] * q[1]
                + p[2] * q[2]
                + p[3] * q[3]
                + p[4] * q[4]
                + p[5] * q[5]
                + p[6] * q[6]
                + p[7] * q[7];
        }
        for (u, v) in x_tail.iter().zip(y_tail) {
            total += u * v;
        }
        total
    }
}
599
600pub fn norm(x: &[f64]) -> f64 {
602 dot(x, x).sqrt()
603}
604
605pub fn inf_norm(x: &[f64], nrows: usize) -> f64 {
608 let ncols = is_matrix(x, nrows).unwrap();
609 let mut abs_row_sums = Vec::with_capacity(nrows);
610 for i in 0..nrows {
611 let mut s = 0.;
612 for j in 0..ncols {
613 s += x[i * ncols + j].abs();
614 }
615 abs_row_sums.push(s);
616 }
617 max(&abs_row_sums)
618}
619
#[cfg(test)]
mod tests {
    use super::*;
    use approx_eq::assert_approx_eq;

    /// Non-symmetric 4×4 row-major matrix shared by the inversion and
    /// linear-solve tests.
    const A4: [f64; 16] = [
        -0.46519316,
        -3.1042875,
        -5.01766541,
        -1.86300107,
        2.7692825,
        2.3097699,
        -12.3854289,
        -8.70520295,
        6.02201052,
        -6.71212792,
        -1.74683781,
        -6.08893455,
        -2.53731118,
        2.72112893,
        4.70204472,
        -1.03387848,
    ];

    #[test]
    fn test_invert() {
        let inv = invert_matrix(&A4);

        let inv_ref = [
            -0.25572126,
            0.03156201,
            0.06146028,
            -0.16691749,
            -0.16856104,
            0.07197315,
            -0.0498292,
            -0.00880639,
            -0.05192178,
            -0.033113,
            0.0482877,
            0.08798427,
            -0.0522019,
            -0.03862469,
            -0.06237155,
            -0.18061673,
        ];

        assert_eq!(inv.len(), inv_ref.len());
        for (got, want) in inv.iter().zip(inv_ref.iter()) {
            assert_approx_eq!(*got, *want);
        }
    }

    #[test]
    fn test_matmul() {
        let x = [
            7., 2., 6., 5., 5., 5., 3., 9., 2., 2., 3., 9., 7., 9., 7., 8., 2., 7., 4., 5.,
        ];
        let nrows = 4;
        let ncols = is_matrix(&x, nrows).unwrap();
        assert_eq!(ncols, 5);

        let xtx1 = matmul(&x, &x, 4, 4, true, false);
        let xtx2 = matmul(&transpose(&x, 4), &x, 5, 4, false, false);
        let xtx1b = matmul_blocked(&x, &x, 4, 4, true, false, 4);
        let xtx2b = matmul_blocked(&transpose(&x, 4), &x, 5, 4, false, false, 4);

        let y = vec![5., 5., 4., 6., 8., 5., 6., 4., 3., 6.];
        let xty1 = matmul(&x, &y, 4, 5, false, false);
        let xty2 = matmul(&y, &x, 2, 4, false, true);
        let xty1b = matmul_blocked(&x, &y, 4, 5, false, false, 4);
        let xty2b = matmul_blocked(&y, &x, 2, 4, false, true, 4);

        assert_eq!(
            xtx1,
            vec![
                147., 72., 164., 104., 106., 72., 98., 116., 105., 89., 164., 116., 215., 139.,
                132., 104., 105., 139., 126., 112., 106., 89., 132., 112., 103.
            ]
        );
        assert_eq!(xtx1, xtx2);
        assert_eq!(xtx1, xtx1b);
        assert_eq!(xtx2, xtx2b);
        assert_eq!(xtx1b, xtx2b);
        assert_eq!(xty1, vec![136., 127., 127., 108., 182., 182., 143., 133.]);
        assert_eq!(xty2, vec![139., 104., 198., 142., 116., 97., 166., 122.]);
        assert_eq!(xty1, xty1b);
        assert_eq!(xty2, xty2b);
    }

    #[test]
    fn test_solve() {
        let b = [-4.13075599, -1.28124453, 4.65406058, 3.69106842];

        let x = solve(&A4, &b);
        let x_ref = [0.68581948, 0.33965616, 0.8063919, -0.69182874];

        assert_eq!(x.len(), x_ref.len());
        for (got, want) in x.iter().zip(x_ref.iter()) {
            assert_approx_eq!(*got, *want);
        }
    }

    #[test]
    fn test_linspace() {
        let r1 = vec![2., 2.4, 2.8, 3.2, 3.6, 4.];
        let e1 = linspace(2., 4., 6);
        for i in 0..6 {
            assert_approx_eq!(r1[i], e1[i]);
        }

        let r2 = vec![
            -50.,
            -48.47457627,
            -46.94915254,
            -45.42372881,
            -43.89830508,
            -42.37288136,
            -40.84745763,
            -39.3220339,
            -37.79661017,
            -36.27118644,
            -34.74576271,
            -33.22033898,
            -31.69491525,
            -30.16949153,
            -28.6440678,
            -27.11864407,
            -25.59322034,
            -24.06779661,
            -22.54237288,
            -21.01694915,
            -19.49152542,
            -17.96610169,
            -16.44067797,
            -14.91525424,
            -13.38983051,
            -11.86440678,
            -10.33898305,
            -8.81355932,
            -7.28813559,
            -5.76271186,
            -4.23728814,
            -2.71186441,
            -1.18644068,
            0.33898305,
            1.86440678,
            3.38983051,
            4.91525424,
            6.44067797,
            7.96610169,
            9.49152542,
            11.01694915,
            12.54237288,
            14.06779661,
            15.59322034,
            17.11864407,
            18.6440678,
            20.16949153,
            21.69491525,
            23.22033898,
            24.74576271,
            26.27118644,
            27.79661017,
            29.3220339,
            30.84745763,
            32.37288136,
            33.89830508,
            35.42372881,
            36.94915254,
            38.47457627,
            40.,
        ];
        let e2 = linspace(-50., 40., 60);
        for i in 0..60 {
            assert_approx_eq!(r2[i], e2[i]);
        }
    }

    #[test]
    fn test_symmetric() {
        let x = vec![1., 2., 3., 2., 4., 5., 3., 5., 8.];
        let y = vec![1., 2., 3., 1., 2., 3., 1., 2., 3.];
        assert!(is_symmetric(&x));
        assert!(!is_symmetric(&y));
    }
}