1use anyhow::{anyhow, Result};
24use faer::Mat as FaerMat;
25use nalgebra::{DMatrix, DVector, SymmetricEigen};
26
/// Criterion used to estimate the variance components in [`mixed_solve`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Method {
    /// Maximum likelihood.
    ML,
    /// Restricted maximum likelihood (the default).
    #[default]
    REML,
}
36
37#[derive(Debug, Clone)]
39pub struct MixedSolveOptions {
40 pub method: Method,
42 pub bounds: (f64, f64),
44 pub se: bool,
46 pub return_hinv: bool,
48}
49
50impl Default for MixedSolveOptions {
51 fn default() -> Self {
52 Self {
53 method: Method::REML,
54 bounds: (1e-9, 1e9),
55 se: false,
56 return_hinv: false,
57 }
58 }
59}
60
61#[derive(Debug, Clone)]
67pub struct MixedSolveResult {
68 pub vu: f64,
70 pub ve: f64,
72 pub beta: DVector<f64>,
74 pub beta_se: Option<DVector<f64>>,
76 pub u: DVector<f64>,
78 pub u_se: Option<DVector<f64>>,
80 pub ll: f64,
82 pub hinv: Option<DMatrix<f64>>,
84}
85
86pub fn mixed_solve(
127 y: &[f64],
128 z: Option<&DMatrix<f64>>,
129 k: Option<&DMatrix<f64>>,
130 x: Option<&DMatrix<f64>>,
131 options: Option<MixedSolveOptions>,
132) -> Result<MixedSolveResult> {
133 let opts = options.unwrap_or_default();
134 let pi = std::f64::consts::PI;
135
136 let n_full = y.len();
138
139 let not_na: Vec<usize> = (0..n_full).filter(|&i| y[i].is_finite()).collect();
141
142 if not_na.is_empty() {
143 return Err(anyhow!("All y values are NA"));
144 }
145
146 let x_full: DMatrix<f64> = match x {
149 Some(x_mat) => x_mat.clone(),
150 None => DMatrix::from_element(n_full, 1, 1.0),
151 };
152 let p = x_full.ncols();
153
154 let z_full: DMatrix<f64> = match z {
157 Some(z_mat) => z_mat.clone(),
158 None => DMatrix::identity(n_full, n_full),
159 };
160 let m = z_full.ncols();
161
162 if z_full.nrows() != n_full {
165 return Err(anyhow!(
166 "nrow(Z) = {} != n = {}",
167 z_full.nrows(),
168 n_full
169 ));
170 }
171 if x_full.nrows() != n_full {
173 return Err(anyhow!(
174 "nrow(X) = {} != n = {}",
175 x_full.nrows(),
176 n_full
177 ));
178 }
179
180 if let Some(k_mat) = k {
182 if k_mat.nrows() != m || k_mat.ncols() != m {
185 return Err(anyhow!(
186 "K must be {} x {}, got {} x {}",
187 m,
188 m,
189 k_mat.nrows(),
190 k_mat.ncols()
191 ));
192 }
193 }
194
195 let n = not_na.len();
201 let y_vec: DVector<f64> = DVector::from_iterator(n, not_na.iter().map(|&i| y[i]));
202 let z_mat = DMatrix::from_fn(n, m, |i, j| z_full[(not_na[i], j)]);
203 let x_mat = DMatrix::from_fn(n, p, |i, j| x_full[(not_na[i], j)]);
204
205 let xtx = x_mat.transpose() * &x_mat;
207
208 let xtx_inv = xtx
212 .clone()
213 .try_inverse()
214 .ok_or_else(|| anyhow!("X not full rank"))?;
215
216 let x_xtxinv = &x_mat * &xtx_inv;
219 let s_mat = DMatrix::identity(n, n) - &x_xtxinv * x_mat.transpose();
220
221 let use_eigen = n <= m + p;
224
225 let phi: Vec<f64>;
227 let theta: Vec<f64>;
228 let u_mat: DMatrix<f64>;
229 let q_mat: DMatrix<f64>;
230
231 if use_eigen {
232 let offset = (n as f64).sqrt();
238
239 let hb: DMatrix<f64> = if let Some(k_mat) = k {
241 let zk = &z_mat * k_mat;
243 let zkzt = &zk * z_mat.transpose();
244 let mut hb = zkzt;
245 for i in 0..n {
246 hb[(i, i)] += offset;
247 }
248 hb
249 } else {
250 let zzt = &z_mat * z_mat.transpose();
252 let mut hb = zzt;
253 for i in 0..n {
254 hb[(i, i)] += offset;
255 }
256 hb
257 };
258
259 let hb_eig = SymmetricEigen::new(hb.clone());
261
262 let mut hb_indices: Vec<usize> = (0..n).collect();
268 hb_indices.sort_by(|&a, &b| {
269 hb_eig.eigenvalues[b]
270 .partial_cmp(&hb_eig.eigenvalues[a])
271 .unwrap_or(std::cmp::Ordering::Equal)
272 });
273
274 phi = hb_indices
276 .iter()
277 .map(|&i| hb_eig.eigenvalues[i] - offset)
278 .collect();
279
280 let min_phi = phi.iter().cloned().fold(f64::INFINITY, f64::min);
282 if min_phi < -1e-6 {
283 return Err(anyhow!("K not positive semi-definite (min phi = {})", min_phi));
284 }
285
286 u_mat = DMatrix::from_fn(n, n, |i, j| hb_eig.eigenvectors[(i, hb_indices[j])]);
288
289 let shbs = &s_mat * &hb * &s_mat;
291
292 let shbs_eig = SymmetricEigen::new(shbs);
294
295 let mut shbs_indices: Vec<usize> = (0..n).collect();
297 shbs_indices.sort_by(|&a, &b| {
298 shbs_eig.eigenvalues[b]
299 .partial_cmp(&shbs_eig.eigenvalues[a])
300 .unwrap_or(std::cmp::Ordering::Equal)
301 });
302
303 let n_theta = n - p;
306 theta = shbs_indices
307 .iter()
308 .take(n_theta) .map(|&i| shbs_eig.eigenvalues[i] - offset)
310 .collect();
311
312 q_mat = DMatrix::from_fn(n, n_theta, |i, j| {
315 shbs_eig.eigenvectors[(i, shbs_indices[j])]
316 });
317 } else {
318 let zbt: DMatrix<f64> = if let Some(k_mat) = k {
324 let mut k_jittered = k_mat.clone();
327 for i in 0..m {
328 k_jittered[(i, i)] += 1e-6;
329 }
330 let chol = k_jittered
331 .cholesky()
332 .ok_or_else(|| anyhow!("K not positive semi-definite"))?;
333 let b_t = chol.l().transpose(); &z_mat * &b_t.transpose() } else {
341 z_mat.clone()
343 };
344
345 let zbt_faer = nalgebra_to_faer(&zbt);
347 let svd_zbt = zbt_faer.svd();
348 let u_full = faer_to_nalgebra(&svd_zbt.u());
349 let d_vals = svd_zbt.s_diagonal();
350
351 u_mat = u_full;
353
354 phi = (0..n)
356 .map(|i| {
357 if i < d_vals.nrows() {
358 let d = d_vals.read(i);
359 d * d
360 } else {
361 0.0
362 }
363 })
364 .collect();
365
366 let szbt = &s_mat * &zbt;
368
369 let szbt_faer = nalgebra_to_faer(&szbt);
374 let svd_szbt = szbt_faer.thin_svd();
375 let u_szbt = faer_to_nalgebra(&svd_szbt.u());
376 let d_szbt = svd_szbt.s_diagonal();
377
378 let n_u_cols = u_szbt.ncols();
380 let mut combined = DMatrix::zeros(n, p + n_u_cols);
381 for i in 0..n {
382 for j in 0..p {
383 combined[(i, j)] = x_mat[(i, j)];
384 }
385 for j in 0..n_u_cols {
386 combined[(i, p + j)] = u_szbt[(i, j)];
387 }
388 }
389
390 let combined_faer = nalgebra_to_faer(&combined);
391 let qr = combined_faer.qr();
392
393 let q_full = faer_to_nalgebra(&qr.compute_q().as_ref());
395 let q_complement = q_full.columns(p, n - p).into_owned();
396 q_mat = q_complement;
397
398 let r_faer = qr.compute_r();
400 let r_full = faer_mat_to_nalgebra(&r_faer);
401
402 let r22_size = m.min(r_full.nrows().saturating_sub(p)).min(r_full.ncols().saturating_sub(p));
405
406 let theta_result: Result<Vec<f64>, ()> = if r22_size > 0 && d_szbt.nrows() > 0 {
407 let mut r22_sq = DMatrix::zeros(r22_size, r22_size);
409 for i in 0..r22_size {
410 for j in 0..r22_size {
411 let val = r_full[(p + i, p + j)];
412 r22_sq[(i, j)] = val * val;
413 }
414 }
415
416 let t_r22_sq = r22_sq.transpose();
418
419 let d_sq_len = r22_size.min(d_szbt.nrows());
421 let d_sq: Vec<f64> = (0..d_sq_len)
422 .map(|i| {
423 let d = d_szbt.read(i);
424 d * d
425 })
426 .collect();
427 let d_sq_vec = DVector::from_row_slice(&d_sq);
428
429 match t_r22_sq.clone().try_inverse() {
431 Some(inv) => {
432 let ans = inv * &d_sq_vec;
433 let n_theta = n - p;
434 Ok((0..n_theta)
435 .map(|i| {
436 if i < ans.len() {
437 ans[i]
438 } else {
439 0.0
440 }
441 })
442 .collect())
443 }
444 None => Err(()),
445 }
446 } else {
447 Err(())
448 };
449
450 theta = match theta_result {
451 Ok(t) => t,
452 Err(_) => {
453 vec![0.0; n - p]
457 }
458 };
459 }
460
461 let omega = q_mat.transpose() * &y_vec;
463
464 let omega_sq: Vec<f64> = omega.iter().map(|v| v * v).collect();
466
467 let (lambda_opt, obj_val, df): (f64, f64, usize);
469
470 if opts.method == Method::ML {
471 let f_ml = |lambda: f64| -> f64 {
475 if lambda <= 0.0 {
476 return f64::INFINITY;
477 }
478 let sum_ratio: f64 = omega_sq
479 .iter()
480 .zip(theta.iter())
481 .map(|(o, t)| o / (t + lambda))
482 .sum();
483 if sum_ratio <= 0.0 {
484 return f64::INFINITY;
485 }
486 let sum_log_phi: f64 = phi.iter().map(|p| (p + lambda).ln()).sum();
487 (n as f64) * sum_ratio.ln() + sum_log_phi
488 };
489
490 let (opt_lambda, opt_obj) = golden_section_minimize(f_ml, opts.bounds.0, opts.bounds.1);
491 lambda_opt = opt_lambda;
492 obj_val = opt_obj;
493 df = n;
494 } else {
495 let n_p = n - p;
499 let f_reml = |lambda: f64| -> f64 {
500 if lambda <= 0.0 {
501 return f64::INFINITY;
502 }
503 let sum_ratio: f64 = omega_sq
504 .iter()
505 .zip(theta.iter())
506 .map(|(o, t)| o / (t + lambda))
507 .sum();
508 if sum_ratio <= 0.0 {
509 return f64::INFINITY;
510 }
511 let sum_log_theta: f64 = theta.iter().map(|t| (t + lambda).ln()).sum();
512 (n_p as f64) * sum_ratio.ln() + sum_log_theta
513 };
514
515 let (opt_lambda, opt_obj) = golden_section_minimize(f_reml, opts.bounds.0, opts.bounds.1);
516 lambda_opt = opt_lambda;
517 obj_val = opt_obj;
518 df = n - p;
519 }
520
521 let vu_opt: f64 = omega_sq
523 .iter()
524 .zip(theta.iter())
525 .map(|(o, t)| o / (t + lambda_opt))
526 .sum::<f64>()
527 / (df as f64);
528
529 let ve_opt = lambda_opt * vu_opt;
531
532 let mut hinv = DMatrix::zeros(n, n);
535 for i in 0..n {
536 for j in 0..n {
537 let mut sum = 0.0;
538 for kk in 0..n {
539 sum += u_mat[(i, kk)] * u_mat[(j, kk)] / (phi[kk] + lambda_opt);
540 }
541 hinv[(i, j)] = sum;
542 }
543 }
544
545 let hinv_x = &hinv * &x_mat;
547 let w = x_mat.transpose() * &hinv_x;
548
549 let w_inv = w
551 .clone()
552 .try_inverse()
553 .ok_or_else(|| anyhow!("W not invertible"))?;
554 let hinv_y = &hinv * &y_vec;
555 let beta = &w_inv * (x_mat.transpose() * &hinv_y);
556
557 let kzt: DMatrix<f64> = if let Some(k_mat) = k {
560 k_mat * z_mat.transpose()
562 } else {
563 z_mat.transpose()
565 };
566
567 let kzt_hinv = &kzt * &hinv;
569
570 let resid = &y_vec - &x_mat * β
572 let u_blup = &kzt_hinv * &resid;
573
574 let ll = -0.5 * (obj_val + (df as f64) + (df as f64) * (2.0 * pi / (df as f64)).ln());
576
577 let (beta_se, u_se) = if opts.se {
579 let winv = w_inv.clone();
581
582 let beta_se_vec: DVector<f64> =
584 DVector::from_fn(p, |i, _| (vu_opt * winv[(i, i)]).sqrt());
585
586 let ww = &kzt_hinv * kzt.transpose();
588
589 let www = &kzt_hinv * &x_mat;
591
592 let u_se_vec: DVector<f64> = if k.is_none() {
594 let www_winv = &www * &winv;
596 let www_term = &www_winv * www.transpose();
597 DVector::from_fn(m, |i, _| {
598 let val = vu_opt * (1.0 - ww[(i, i)] + www_term[(i, i)]);
599 if val > 0.0 {
600 val.sqrt()
601 } else {
602 0.0
603 }
604 })
605 } else {
606 let k_mat = k.unwrap();
608 let www_winv = &www * &winv;
609 let www_term = &www_winv * www.transpose();
610 DVector::from_fn(m, |i, _| {
611 let val = vu_opt * (k_mat[(i, i)] - ww[(i, i)] + www_term[(i, i)]);
612 if val > 0.0 {
613 val.sqrt()
614 } else {
615 0.0
616 }
617 })
618 };
619
620 (Some(beta_se_vec), Some(u_se_vec))
621 } else {
622 (None, None)
623 };
624
625 let hinv_return = if opts.return_hinv { Some(hinv) } else { None };
627
628 Ok(MixedSolveResult {
629 vu: vu_opt,
630 ve: ve_opt,
631 beta,
632 beta_se,
633 u: u_blup,
634 u_se,
635 ll,
636 hinv: hinv_return,
637 })
638}
639
/// Minimises `f` on the bracket `[a, b]` by golden-section search.
///
/// Assumes `f` is unimodal on the bracket; otherwise a local minimum may be
/// returned. The search stops once the bracket is narrower than `1e-8` or after
/// 100 shrink steps. It returns `(argmin, min)` for the better of the two
/// final interior probes.
fn golden_section_minimize<F>(f: F, mut a: f64, mut b: f64) -> (f64, f64)
where
    F: Fn(f64) -> f64,
{
    const TOLERANCE: f64 = 1e-8;
    const MAX_STEPS: usize = 100;
    let ratio = 0.5 * (1.0 + 5f64.sqrt());

    // Two interior probes: `left` nearer `a`, `right` nearer `b`.
    let mut left = b - (b - a) / ratio;
    let mut right = a + (b - a) / ratio;
    let mut f_left = f(left);
    let mut f_right = f(right);

    for _ in 0..MAX_STEPS {
        if (b - a).abs() < TOLERANCE {
            break;
        }
        if f_left < f_right {
            // Keep [a, right]; the old left probe becomes the new right probe.
            b = right;
            right = left;
            f_right = f_left;
            left = b - (b - a) / ratio;
            f_left = f(left);
        } else {
            // Keep [left, b]; the old right probe becomes the new left probe.
            a = left;
            left = right;
            f_left = f_right;
            right = a + (b - a) / ratio;
            f_right = f(right);
        }
    }

    if f_left < f_right {
        (left, f_left)
    } else {
        (right, f_right)
    }
}
677
678fn nalgebra_to_faer(m: &DMatrix<f64>) -> FaerMat<f64> {
683 let nrows = m.nrows();
684 let ncols = m.ncols();
685 FaerMat::from_fn(nrows, ncols, |i, j| m[(i, j)])
686}
687
688fn faer_to_nalgebra(m: &faer::MatRef<f64>) -> DMatrix<f64> {
689 let nrows = m.nrows();
690 let ncols = m.ncols();
691 DMatrix::from_fn(nrows, ncols, |i, j| m.read(i, j))
692}
693
694fn faer_mat_to_nalgebra(m: &FaerMat<f64>) -> DMatrix<f64> {
695 let nrows = m.nrows();
696 let ncols = m.ncols();
697 DMatrix::from_fn(nrows, ncols, |i, j| m.read(i, j))
698}
699
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_mixed_solve_simple_intercept() {
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let result = mixed_solve(&y, None, None, None, None).unwrap();

        assert_relative_eq!(result.beta[0], 3.0, epsilon = 0.5);
        assert!(result.vu >= 0.0);
        assert!(result.ve >= 0.0);
    }

    #[test]
    fn test_mixed_solve_with_na() {
        let y = vec![1.0, f64::NAN, 3.0, f64::NAN, 5.0];
        let result = mixed_solve(&y, None, None, None, None).unwrap();

        assert_relative_eq!(result.beta[0], 3.0, epsilon = 0.5);
    }

    #[test]
    fn test_mixed_solve_with_se() {
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let opts = MixedSolveOptions {
            se: true,
            ..Default::default()
        };
        let result = mixed_solve(&y, None, None, None, Some(opts)).unwrap();

        assert!(result.beta_se.is_some());
        assert!(result.u_se.is_some());
        assert!(result.beta_se.unwrap()[0] > 0.0);
    }

    #[test]
    fn test_mixed_solve_ml_vs_reml() {
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let opts_reml = MixedSolveOptions {
            method: Method::REML,
            ..Default::default()
        };
        let result_reml = mixed_solve(&y, None, None, None, Some(opts_reml)).unwrap();

        let opts_ml = MixedSolveOptions {
            method: Method::ML,
            ..Default::default()
        };
        let result_ml = mixed_solve(&y, None, None, None, Some(opts_ml)).unwrap();

        assert_relative_eq!(result_reml.beta[0], result_ml.beta[0], epsilon = 0.5);
        assert!(result_reml.vu >= 0.0);
        assert!(result_ml.vu >= 0.0);
    }

    #[test]
    fn test_mixed_solve_with_fixed_effects() {
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let x = DMatrix::from_row_slice(6, 2, &[
            1.0, 0.0,
            1.0, 1.0,
            1.0, 2.0,
            1.0, 3.0,
            1.0, 4.0,
            1.0, 5.0,
        ]);

        let result = mixed_solve(&y, None, None, Some(&x), None).unwrap();

        assert_eq!(result.beta.len(), 2);
    }

    #[test]
    fn test_mixed_solve_return_hinv() {
        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let opts = MixedSolveOptions {
            return_hinv: true,
            ..Default::default()
        };
        let result = mixed_solve(&y, None, None, None, Some(opts)).unwrap();

        assert!(result.hinv.is_some());
        let hinv = result.hinv.unwrap();
        assert_eq!(hinv.nrows(), 5);
        assert_eq!(hinv.ncols(), 5);
    }

    #[test]
    fn test_mixed_solve_all_na_is_error() {
        let y = vec![f64::NAN, f64::NAN, f64::NAN];
        let err = mixed_solve(&y, None, None, None, None).unwrap_err();
        assert!(err.to_string().contains("All y values are NA"));
    }

    #[test]
    fn test_mixed_solve_rejects_mismatched_z_rows() {
        let y = vec![1.0, 2.0, 3.0, 4.0];
        let z = DMatrix::identity(3, 3);
        assert!(mixed_solve(&y, Some(&z), None, None, None).is_err());
    }

    #[test]
    fn test_mixed_solve_rejects_mismatched_k() {
        // Default Z is 4 x 4, so K must be 4 x 4.
        let y = vec![1.0, 2.0, 3.0, 4.0];
        let k = DMatrix::identity(3, 3);
        assert!(mixed_solve(&y, None, Some(&k), None, None).is_err());
    }

    #[test]
    fn test_mixed_solve_more_obs_than_effects() {
        // n = 8 > m + p = 3, which exercises the SVD/QR spectral path.
        let y = vec![1.2, 2.3, 2.9, 4.1, 5.2, 5.8, 7.1, 8.3];
        let z = DMatrix::from_row_slice(8, 2, &[
            1.0, 0.5,
            0.2, 1.0,
            0.9, 0.1,
            0.4, 0.8,
            1.3, 0.3,
            0.1, 1.2,
            0.7, 0.6,
            0.5, 0.9,
        ]);

        let result = mixed_solve(&y, Some(&z), None, None, None).unwrap();

        assert_eq!(result.beta.len(), 1);
        assert_eq!(result.u.len(), 2);
        assert!(result.beta[0].is_finite());
        assert!(result.u.iter().all(|v| v.is_finite()));
    }
}