// rrblup_rs/mixed_solve.rs

1//! Rust implementation of R/rrBLUP::mixed.solve
2//!
3//! This module provides a direct translation of the mixed.solve function from
4//! the R/rrBLUP package for solving mixed linear models using spectral decomposition.
5//!
6//! # Example
7//!
8//! ```
9//! use rrblup_rs::mixed_solve::{mixed_solve, MixedSolveOptions, Method};
10//!
11//! // Simple mixed model with intercept only
12//! let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
13//! let result = mixed_solve(&y, None, None, None, None).unwrap();
14//! assert!(result.vu >= 0.0);
15//! assert!(result.ve >= 0.0);
16//! ```
17//!
18//! # Reference
19//!
20//! R/rrBLUP package by Jeffrey Endelman:
21//! <https://cran.r-project.org/package=rrBLUP>
22
23use anyhow::{anyhow, Result};
24use faer::Mat as FaerMat;
25use nalgebra::{DMatrix, DVector, SymmetricEigen};
26
27/// Method for variance component estimation
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
29pub enum Method {
30    /// Maximum Likelihood
31    ML,
32    /// Restricted Maximum Likelihood (default)
33    #[default]
34    REML,
35}
36
37/// Options for mixed.solve function
38#[derive(Debug, Clone)]
39pub struct MixedSolveOptions {
40    /// Method for variance component estimation (default: REML)
41    pub method: Method,
42    /// Bounds for lambda optimization (default: [1e-9, 1e9])
43    pub bounds: (f64, f64),
44    /// Whether to compute standard errors (default: false)
45    pub se: bool,
46    /// Whether to return H inverse matrix (default: false)
47    pub return_hinv: bool,
48}
49
50impl Default for MixedSolveOptions {
51    fn default() -> Self {
52        Self {
53            method: Method::REML,
54            bounds: (1e-9, 1e9),
55            se: false,
56            return_hinv: false,
57        }
58    }
59}
60
61/// Result from mixed.solve function
62///
63/// Corresponds to the list returned by R's mixed.solve:
64/// list(Vu, Ve, beta, u, LL) or with SE: list(Vu, Ve, beta, beta.SE, u, u.SE, LL)
65/// and optionally Hinv if return.Hinv=TRUE
66#[derive(Debug, Clone)]
67pub struct MixedSolveResult {
68    /// Variance of random effects (Vu)
69    pub vu: f64,
70    /// Residual variance (Ve)
71    pub ve: f64,
72    /// Fixed effects coefficients (beta)
73    pub beta: DVector<f64>,
74    /// Standard errors of beta (only if SE=TRUE)
75    pub beta_se: Option<DVector<f64>>,
76    /// Random effects BLUPs (u)
77    pub u: DVector<f64>,
78    /// Standard errors of u (only if SE=TRUE)
79    pub u_se: Option<DVector<f64>>,
80    /// Log-likelihood (LL)
81    pub ll: f64,
82    /// Inverse of H matrix (only if return.Hinv=TRUE)
83    pub hinv: Option<DMatrix<f64>>,
84}
85
86/// Solve mixed model y = Xβ + Zu + e using spectral decomposition
87///
88/// This is a Rust implementation of R/rrBLUP::mixed.solve().
89///
90/// # Arguments
91///
92/// * `y` - Response vector (n x 1), may contain NaN for missing values
93/// * `z` - Random effects design matrix (n x m), defaults to identity if None
94/// * `k` - Covariance matrix for random effects (m x m), defaults to identity if None
95/// * `x` - Fixed effects design matrix (n x p), defaults to intercept if None
96/// * `options` - Optional settings (method, bounds, SE, return.Hinv)
97///
98/// # Returns
99///
100/// [`MixedSolveResult`] with variance components, BLUPs, and optionally SEs.
101///
102/// # Errors
103///
104/// Returns an error if:
105/// - All `y` values are `NaN`
106/// - Matrix dimensions are incompatible (`nrow(Z) != n` or `nrow(X) != n`)
107/// - `X` is not full rank
108/// - `K` is not positive semi-definite
109///
110/// # Example
111///
112/// ```
113/// use rrblup_rs::mixed_solve::{mixed_solve, MixedSolveOptions};
114///
115/// let y = vec![1.0, 2.0, 3.0, 4.0, 5.0];
116/// let result = mixed_solve(&y, None, None, None, None).unwrap();
117///
118/// println!("Genetic variance (Vu): {}", result.vu);
119/// println!("Residual variance (Ve): {}", result.ve);
120/// println!("Fixed effects (beta): {:?}", result.beta);
121/// ```
122///
123/// # Notes
124///
125/// The parameter order matches R: mixed.solve(y, Z=NULL, K=NULL, X=NULL, ...)
126pub fn mixed_solve(
127    y: &[f64],
128    z: Option<&DMatrix<f64>>,
129    k: Option<&DMatrix<f64>>,
130    x: Option<&DMatrix<f64>>,
131    options: Option<MixedSolveOptions>,
132) -> Result<MixedSolveResult> {
133    let opts = options.unwrap_or_default();
134    let pi = std::f64::consts::PI;
135
136    // R: n <- length(y)
137    let n_full = y.len();
138
139    // R: not.NA <- which(!is.na(y))
140    let not_na: Vec<usize> = (0..n_full).filter(|&i| y[i].is_finite()).collect();
141
142    if not_na.is_empty() {
143        return Err(anyhow!("All y values are NA"));
144    }
145
146    // Setup X matrix
147    // if (is.null(X)) { p <- 1; X <- matrix(rep(1,n),n,1) }
148    let x_full: DMatrix<f64> = match x {
149        Some(x_mat) => x_mat.clone(),
150        None => DMatrix::from_element(n_full, 1, 1.0),
151    };
152    let p = x_full.ncols();
153
154    // Setup Z matrix
155    // if (is.null(Z)) { Z <- diag(n) }
156    let z_full: DMatrix<f64> = match z {
157        Some(z_mat) => z_mat.clone(),
158        None => DMatrix::identity(n_full, n_full),
159    };
160    let m = z_full.ncols();
161
162    // Dimension checks
163    // stopifnot(nrow(Z) == n)
164    if z_full.nrows() != n_full {
165        return Err(anyhow!(
166            "nrow(Z) = {} != n = {}",
167            z_full.nrows(),
168            n_full
169        ));
170    }
171    // stopifnot(nrow(X) == n)
172    if x_full.nrows() != n_full {
173        return Err(anyhow!(
174            "nrow(X) = {} != n = {}",
175            x_full.nrows(),
176            n_full
177        ));
178    }
179
180    // Check K dimensions if provided
181    if let Some(k_mat) = k {
182        // stopifnot(nrow(K) == m)
183        // stopifnot(ncol(K) == m)
184        if k_mat.nrows() != m || k_mat.ncols() != m {
185            return Err(anyhow!(
186                "K must be {} x {}, got {} x {}",
187                m,
188                m,
189                k_mat.nrows(),
190                k_mat.ncols()
191            ));
192        }
193    }
194
195    // Subset to non-NA observations
196    // R: Z <- as.matrix(Z[not.NA,])
197    // R: X <- as.matrix(X[not.NA,])
198    // R: n <- length(not.NA)
199    // R: y <- matrix(y[not.NA],n,1)
200    let n = not_na.len();
201    let y_vec: DVector<f64> = DVector::from_iterator(n, not_na.iter().map(|&i| y[i]));
202    let z_mat = DMatrix::from_fn(n, m, |i, j| z_full[(not_na[i], j)]);
203    let x_mat = DMatrix::from_fn(n, p, |i, j| x_full[(not_na[i], j)]);
204
205    // R: XtX <- crossprod(X, X)
206    let xtx = x_mat.transpose() * &x_mat;
207
208    // R: rank.X <- qr(XtX)$rank
209    // if (rank.X < p) {stop("X not full rank")}
210    // R: XtXinv <- solve(XtX)
211    let xtx_inv = xtx
212        .clone()
213        .try_inverse()
214        .ok_or_else(|| anyhow!("X not full rank"))?;
215
216    // R: S <- diag(n) - tcrossprod(X%*%XtXinv,X)
217    // S = I - X @ XtXinv @ X'
218    let x_xtxinv = &x_mat * &xtx_inv;
219    let s_mat = DMatrix::identity(n, n) - &x_xtxinv * x_mat.transpose();
220
221    // Determine spectral method
222    // if (n <= m + p) { spectral.method <- "eigen" } else { spectral.method <- "cholesky" }
223    let use_eigen = n <= m + p;
224
225    // Variables that will be set by either branch
226    let phi: Vec<f64>;
227    let theta: Vec<f64>;
228    let u_mat: DMatrix<f64>;
229    let q_mat: DMatrix<f64>;
230
231    if use_eigen {
232        // ============================================================
233        // EIGEN BRANCH: spectral.method == "eigen"
234        // ============================================================
235
236        // R: offset <- sqrt(n)
237        let offset = (n as f64).sqrt();
238
239        // Hb computation
240        let hb: DMatrix<f64> = if let Some(k_mat) = k {
241            // R: Hb <- tcrossprod(Z%*%K,Z) + offset*diag(n)
242            let zk = &z_mat * k_mat;
243            let zkzt = &zk * z_mat.transpose();
244            let mut hb = zkzt;
245            for i in 0..n {
246                hb[(i, i)] += offset;
247            }
248            hb
249        } else {
250            // R: Hb <- tcrossprod(Z,Z) + offset*diag(n)
251            let zzt = &z_mat * z_mat.transpose();
252            let mut hb = zzt;
253            for i in 0..n {
254                hb[(i, i)] += offset;
255            }
256            hb
257        };
258
259        // R: Hb.system <- eigen(Hb, symmetric = TRUE)
260        let hb_eig = SymmetricEigen::new(hb.clone());
261
262        // NOTE: R's eigen() returns eigenvalues in DESCENDING order.
263        // nalgebra's SymmetricEigen does NOT guarantee sorted eigenvalues!
264        // We must explicitly sort them.
265
266        // Create sorted indices for Hb eigenvalues (descending order)
267        let mut hb_indices: Vec<usize> = (0..n).collect();
268        hb_indices.sort_by(|&a, &b| {
269            hb_eig.eigenvalues[b]
270                .partial_cmp(&hb_eig.eigenvalues[a])
271                .unwrap_or(std::cmp::Ordering::Equal)
272        });
273
274        // R: phi <- Hb.system$values - offset (in descending order)
275        phi = hb_indices
276            .iter()
277            .map(|&i| hb_eig.eigenvalues[i] - offset)
278            .collect();
279
280        // if (min(phi) < -1e-6) {stop("K not positive semi-definite.")}
281        let min_phi = phi.iter().cloned().fold(f64::INFINITY, f64::min);
282        if min_phi < -1e-6 {
283            return Err(anyhow!("K not positive semi-definite (min phi = {})", min_phi));
284        }
285
286        // R: U <- Hb.system$vectors (columns in descending eigenvalue order)
287        u_mat = DMatrix::from_fn(n, n, |i, j| hb_eig.eigenvectors[(i, hb_indices[j])]);
288
289        // R: SHbS <- S %*% Hb %*% S
290        let shbs = &s_mat * &hb * &s_mat;
291
292        // R: SHbS.system <- eigen(SHbS, symmetric = TRUE)
293        let shbs_eig = SymmetricEigen::new(shbs);
294
295        // Create sorted indices for SHbS eigenvalues (descending order)
296        let mut shbs_indices: Vec<usize> = (0..n).collect();
297        shbs_indices.sort_by(|&a, &b| {
298            shbs_eig.eigenvalues[b]
299                .partial_cmp(&shbs_eig.eigenvalues[a])
300                .unwrap_or(std::cmp::Ordering::Equal)
301        });
302
303        // R: theta <- SHbS.system$values[1:(n - p)] - offset
304        // R takes first n-p eigenvalues (largest, in descending order)
305        let n_theta = n - p;
306        theta = shbs_indices
307            .iter()
308            .take(n_theta) // Take the largest n_theta eigenvalues
309            .map(|&i| shbs_eig.eigenvalues[i] - offset)
310            .collect();
311
312        // R: Q <- SHbS.system$vectors[, 1:(n - p)]
313        // Take columns corresponding to the largest n_theta eigenvalues
314        q_mat = DMatrix::from_fn(n, n_theta, |i, j| {
315            shbs_eig.eigenvectors[(i, shbs_indices[j])]
316        });
317    } else {
318        // ============================================================
319        // CHOLESKY BRANCH: spectral.method == "cholesky"
320        // ============================================================
321
322        // B = chol(K) if K provided (with jitter on diagonal)
323        let zbt: DMatrix<f64> = if let Some(k_mat) = k {
324            // diag(K) <- diag(K) + 1e-6
325            // R: B <- try(chol(K),silent=TRUE)
326            let mut k_jittered = k_mat.clone();
327            for i in 0..m {
328                k_jittered[(i, i)] += 1e-6;
329            }
330            let chol = k_jittered
331                .cholesky()
332                .ok_or_else(|| anyhow!("K not positive semi-definite"))?;
333            // R: ZBt <- tcrossprod(Z,B)
334            // In R, chol() returns upper triangular, so B = chol(K)
335            // tcrossprod(Z, B) = Z @ B'
336            // nalgebra cholesky().l() returns lower triangular L where K = L @ L'
337            // So we need Z @ L' = Z @ chol(K)' in R terms
338            let b_t = chol.l().transpose(); // This is the upper triangular
339            &z_mat * &b_t.transpose() // tcrossprod(Z, B) = Z @ B' where B is upper triangular
340        } else {
341            // R: ZBt <- Z
342            z_mat.clone()
343        };
344
345        // R: svd.ZBt <- svd(ZBt,nu=n)
346        let zbt_faer = nalgebra_to_faer(&zbt);
347        let svd_zbt = zbt_faer.svd();
348        let u_full = faer_to_nalgebra(&svd_zbt.u());
349        let d_vals = svd_zbt.s_diagonal();
350
351        // R: U <- svd.ZBt$u
352        u_mat = u_full;
353
354        // R: phi <- c(svd.ZBt$d^2,rep(0,n-m))
355        phi = (0..n)
356            .map(|i| {
357                if i < d_vals.nrows() {
358                    let d = d_vals.read(i);
359                    d * d
360                } else {
361                    0.0
362                }
363            })
364            .collect();
365
366        // R: SZBt <- S %*% ZBt
367        let szbt = &s_mat * &zbt;
368
369        // R: svd.SZBt <- try(svd(SZBt),silent=TRUE)
370        // if (inherits(svd.SZBt,what="try-error")) {
371        //   svd.SZBt <- svd(SZBt+matrix(1e-10,nrow=nrow(SZBt),ncol=ncol(SZBt)))
372        // }
373        let szbt_faer = nalgebra_to_faer(&szbt);
374        let svd_szbt = szbt_faer.thin_svd();
375        let u_szbt = faer_to_nalgebra(&svd_szbt.u());
376        let d_szbt = svd_szbt.s_diagonal();
377
378        // R: QR <- qr(cbind(X,svd.SZBt$u))
379        let n_u_cols = u_szbt.ncols();
380        let mut combined = DMatrix::zeros(n, p + n_u_cols);
381        for i in 0..n {
382            for j in 0..p {
383                combined[(i, j)] = x_mat[(i, j)];
384            }
385            for j in 0..n_u_cols {
386                combined[(i, p + j)] = u_szbt[(i, j)];
387            }
388        }
389
390        let combined_faer = nalgebra_to_faer(&combined);
391        let qr = combined_faer.qr();
392
393        // R: Q <- qr.Q(QR,complete=TRUE)[,(p+1):n]
394        let q_full = faer_to_nalgebra(&qr.compute_q().as_ref());
395        let q_complement = q_full.columns(p, n - p).into_owned();
396        q_mat = q_complement;
397
398        // R: R <- qr.R(QR)[p+1:m,p+1:m]
399        let r_faer = qr.compute_r();
400        let r_full = faer_mat_to_nalgebra(&r_faer);
401
402        // R: ans <- try(solve(t(R^2), svd.SZBt$d^2),silent=TRUE)
403        // R: theta <- c(ans,rep(0, n - p - m))
404        let r22_size = m.min(r_full.nrows().saturating_sub(p)).min(r_full.ncols().saturating_sub(p));
405
406        let theta_result: Result<Vec<f64>, ()> = if r22_size > 0 && d_szbt.nrows() > 0 {
407            // Extract R22 = R[p+1:m, p+1:m] (1-indexed in R, 0-indexed here)
408            let mut r22_sq = DMatrix::zeros(r22_size, r22_size);
409            for i in 0..r22_size {
410                for j in 0..r22_size {
411                    let val = r_full[(p + i, p + j)];
412                    r22_sq[(i, j)] = val * val;
413                }
414            }
415
416            // t(R^2)
417            let t_r22_sq = r22_sq.transpose();
418
419            // svd.SZBt$d^2
420            let d_sq_len = r22_size.min(d_szbt.nrows());
421            let d_sq: Vec<f64> = (0..d_sq_len)
422                .map(|i| {
423                    let d = d_szbt.read(i);
424                    d * d
425                })
426                .collect();
427            let d_sq_vec = DVector::from_row_slice(&d_sq);
428
429            // solve(t(R^2), d^2)
430            match t_r22_sq.clone().try_inverse() {
431                Some(inv) => {
432                    let ans = inv * &d_sq_vec;
433                    let n_theta = n - p;
434                    Ok((0..n_theta)
435                        .map(|i| {
436                            if i < ans.len() {
437                                ans[i]
438                            } else {
439                                0.0
440                            }
441                        })
442                        .collect())
443                }
444                None => Err(()),
445            }
446        } else {
447            Err(())
448        };
449
450        theta = match theta_result {
451            Ok(t) => t,
452            Err(_) => {
453                // Fallback: this would trigger spectral.method <- "eigen" in R
454                // For simplicity, we use zeros which may not be fully correct
455                // but follows the structure
456                vec![0.0; n - p]
457            }
458        };
459    }
460
461    // R: omega <- crossprod(Q, y)
462    let omega = q_mat.transpose() * &y_vec;
463
464    // R: omega.sq <- omega^2
465    let omega_sq: Vec<f64> = omega.iter().map(|v| v * v).collect();
466
467    // Optimization
468    let (lambda_opt, obj_val, df): (f64, f64, usize);
469
470    if opts.method == Method::ML {
471        // R: f.ML <- function(lambda, n, theta, omega.sq, phi) {
472        //   n * log(sum(omega.sq/(theta + lambda))) + sum(log(phi + lambda))
473        // }
474        let f_ml = |lambda: f64| -> f64 {
475            if lambda <= 0.0 {
476                return f64::INFINITY;
477            }
478            let sum_ratio: f64 = omega_sq
479                .iter()
480                .zip(theta.iter())
481                .map(|(o, t)| o / (t + lambda))
482                .sum();
483            if sum_ratio <= 0.0 {
484                return f64::INFINITY;
485            }
486            let sum_log_phi: f64 = phi.iter().map(|p| (p + lambda).ln()).sum();
487            (n as f64) * sum_ratio.ln() + sum_log_phi
488        };
489
490        let (opt_lambda, opt_obj) = golden_section_minimize(f_ml, opts.bounds.0, opts.bounds.1);
491        lambda_opt = opt_lambda;
492        obj_val = opt_obj;
493        df = n;
494    } else {
495        // R: f.REML <- function(lambda, n.p, theta, omega.sq) {
496        //   n.p * log(sum(omega.sq/(theta + lambda))) + sum(log(theta + lambda))
497        // }
498        let n_p = n - p;
499        let f_reml = |lambda: f64| -> f64 {
500            if lambda <= 0.0 {
501                return f64::INFINITY;
502            }
503            let sum_ratio: f64 = omega_sq
504                .iter()
505                .zip(theta.iter())
506                .map(|(o, t)| o / (t + lambda))
507                .sum();
508            if sum_ratio <= 0.0 {
509                return f64::INFINITY;
510            }
511            let sum_log_theta: f64 = theta.iter().map(|t| (t + lambda).ln()).sum();
512            (n_p as f64) * sum_ratio.ln() + sum_log_theta
513        };
514
515        let (opt_lambda, opt_obj) = golden_section_minimize(f_reml, opts.bounds.0, opts.bounds.1);
516        lambda_opt = opt_lambda;
517        obj_val = opt_obj;
518        df = n - p;
519    }
520
521    // R: Vu.opt <- sum(omega.sq/(theta + lambda.opt))/df
522    let vu_opt: f64 = omega_sq
523        .iter()
524        .zip(theta.iter())
525        .map(|(o, t)| o / (t + lambda_opt))
526        .sum::<f64>()
527        / (df as f64);
528
529    // R: Ve.opt <- lambda.opt * Vu.opt
530    let ve_opt = lambda_opt * vu_opt;
531
532    // R: Hinv <- U %*% (t(U)/(phi+lambda.opt))
533    // Hinv[i,j] = sum_k U[i,k] * U[j,k] / (phi[k] + lambda)
534    let mut hinv = DMatrix::zeros(n, n);
535    for i in 0..n {
536        for j in 0..n {
537            let mut sum = 0.0;
538            for kk in 0..n {
539                sum += u_mat[(i, kk)] * u_mat[(j, kk)] / (phi[kk] + lambda_opt);
540            }
541            hinv[(i, j)] = sum;
542        }
543    }
544
545    // R: W <- crossprod(X,Hinv%*%X)
546    let hinv_x = &hinv * &x_mat;
547    let w = x_mat.transpose() * &hinv_x;
548
549    // R: beta <- array(solve(W,crossprod(X,Hinv%*%y)))
550    let w_inv = w
551        .clone()
552        .try_inverse()
553        .ok_or_else(|| anyhow!("W not invertible"))?;
554    let hinv_y = &hinv * &y_vec;
555    let beta = &w_inv * (x_mat.transpose() * &hinv_y);
556
557    // KZt computation
558    // if (is.null(K)) { KZt <- t(Z) } else { KZt <- tcrossprod(K,Z) }
559    let kzt: DMatrix<f64> = if let Some(k_mat) = k {
560        // R: KZt <- tcrossprod(K,Z) = K @ Z'
561        k_mat * z_mat.transpose()
562    } else {
563        // R: KZt <- t(Z)
564        z_mat.transpose()
565    };
566
567    // R: KZt.Hinv <- KZt %*% Hinv
568    let kzt_hinv = &kzt * &hinv;
569
570    // R: u <- array(KZt.Hinv %*% (y - X%*%beta))
571    let resid = &y_vec - &x_mat * &beta;
572    let u_blup = &kzt_hinv * &resid;
573
574    // LL = -0.5 * (soln$objective + df + df * log(2 * pi/df))
575    let ll = -0.5 * (obj_val + (df as f64) + (df as f64) * (2.0 * pi / (df as f64)).ln());
576
577    // Standard errors (if SE=TRUE)
578    let (beta_se, u_se) = if opts.se {
579        // R: Winv <- solve(W)
580        let winv = w_inv.clone();
581
582        // R: beta.SE <- array(sqrt(Vu.opt*diag(Winv)))
583        let beta_se_vec: DVector<f64> =
584            DVector::from_fn(p, |i, _| (vu_opt * winv[(i, i)]).sqrt());
585
586        // R: WW <- tcrossprod(KZt.Hinv,KZt)
587        let ww = &kzt_hinv * kzt.transpose();
588
589        // R: WWW <- KZt.Hinv%*%X
590        let www = &kzt_hinv * &x_mat;
591
592        // u.SE computation
593        let u_se_vec: DVector<f64> = if k.is_none() {
594            // R: u.SE <- array(sqrt(Vu.opt * (rep(1,m) - diag(WW) + diag(tcrossprod(WWW%*%Winv,WWW)))))
595            let www_winv = &www * &winv;
596            let www_term = &www_winv * www.transpose();
597            DVector::from_fn(m, |i, _| {
598                let val = vu_opt * (1.0 - ww[(i, i)] + www_term[(i, i)]);
599                if val > 0.0 {
600                    val.sqrt()
601                } else {
602                    0.0
603                }
604            })
605        } else {
606            // R: u.SE <- array(sqrt(Vu.opt * (diag(K) - diag(WW) + diag(tcrossprod(WWW%*%Winv,WWW)))))
607            let k_mat = k.unwrap();
608            let www_winv = &www * &winv;
609            let www_term = &www_winv * www.transpose();
610            DVector::from_fn(m, |i, _| {
611                let val = vu_opt * (k_mat[(i, i)] - ww[(i, i)] + www_term[(i, i)]);
612                if val > 0.0 {
613                    val.sqrt()
614                } else {
615                    0.0
616                }
617            })
618        };
619
620        (Some(beta_se_vec), Some(u_se_vec))
621    } else {
622        (None, None)
623    };
624
625    // Return Hinv only if requested
626    let hinv_return = if opts.return_hinv { Some(hinv) } else { None };
627
628    Ok(MixedSolveResult {
629        vu: vu_opt,
630        ve: ve_opt,
631        beta,
632        beta_se,
633        u: u_blup,
634        u_se,
635        ll,
636        hinv: hinv_return,
637    })
638}
639
/// Golden section search for minimization (equivalent to R's optimize).
///
/// Returns `(x_min, f(x_min))` for a (presumed unimodal) function `f`
/// on the interval `[lo, hi]`.
fn golden_section_minimize<F>(f: F, lo: f64, hi: f64) -> (f64, f64)
where
    F: Fn(f64) -> f64,
{
    const TOL: f64 = 1e-8;
    const MAX_ITER: usize = 100;
    let ratio = 0.5 * (1.0 + 5f64.sqrt()); // golden ratio

    let (mut a, mut b) = (lo, hi);
    let mut c = b - (b - a) / ratio;
    let mut d = a + (b - a) / ratio;
    let (mut fc, mut fd) = (f(c), f(d));

    let mut iter = 0;
    while iter < MAX_ITER && (b - a).abs() >= TOL {
        if fc < fd {
            // Minimum lies in [a, d]: shrink from the right.
            b = d;
            d = c;
            fd = fc;
            c = b - (b - a) / ratio;
            fc = f(c);
        } else {
            // Minimum lies in [c, b]: shrink from the left.
            a = c;
            c = d;
            fc = fd;
            d = a + (b - a) / ratio;
            fd = f(d);
        }
        iter += 1;
    }

    if fc < fd { (c, fc) } else { (d, fd) }
}
677
678// ============================================================
679// Helper functions for matrix conversions between nalgebra and faer
680// ============================================================
681
682fn nalgebra_to_faer(m: &DMatrix<f64>) -> FaerMat<f64> {
683    let nrows = m.nrows();
684    let ncols = m.ncols();
685    FaerMat::from_fn(nrows, ncols, |i, j| m[(i, j)])
686}
687
688fn faer_to_nalgebra(m: &faer::MatRef<f64>) -> DMatrix<f64> {
689    let nrows = m.nrows();
690    let ncols = m.ncols();
691    DMatrix::from_fn(nrows, ncols, |i, j| m.read(i, j))
692}
693
694fn faer_mat_to_nalgebra(m: &FaerMat<f64>) -> DMatrix<f64> {
695    let nrows = m.nrows();
696    let ncols = m.ncols();
697    DMatrix::from_fn(nrows, ncols, |i, j| m.read(i, j))
698}
699
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Intercept-only model: beta[0] should recover (roughly) the mean of y.
    #[test]
    fn test_mixed_solve_simple_intercept() {
        let response = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let fit = mixed_solve(&response, None, None, None, None).unwrap();

        assert_relative_eq!(fit.beta[0], 3.0, epsilon = 0.5);
        assert!(fit.vu >= 0.0);
        assert!(fit.ve >= 0.0);
    }

    /// Missing (NaN) observations are dropped before fitting.
    #[test]
    fn test_mixed_solve_with_na() {
        let response = vec![1.0, f64::NAN, 3.0, f64::NAN, 5.0];
        let fit = mixed_solve(&response, None, None, None, None).unwrap();

        // Mean of the observed values 1, 3, 5 is 3.0.
        assert_relative_eq!(fit.beta[0], 3.0, epsilon = 0.5);
    }

    /// Requesting SE populates beta_se and u_se with positive values.
    #[test]
    fn test_mixed_solve_with_se() {
        let response = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let options = MixedSolveOptions {
            se: true,
            ..Default::default()
        };
        let fit = mixed_solve(&response, None, None, None, Some(options)).unwrap();

        assert!(fit.beta_se.is_some());
        assert!(fit.u_se.is_some());
        assert!(fit.beta_se.unwrap()[0] > 0.0);
    }

    /// ML and REML should agree on beta and both yield valid variances.
    #[test]
    fn test_mixed_solve_ml_vs_reml() {
        let response = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let reml_fit = mixed_solve(
            &response,
            None,
            None,
            None,
            Some(MixedSolveOptions {
                method: Method::REML,
                ..Default::default()
            }),
        )
        .unwrap();

        let ml_fit = mixed_solve(
            &response,
            None,
            None,
            None,
            Some(MixedSolveOptions {
                method: Method::ML,
                ..Default::default()
            }),
        )
        .unwrap();

        // Both estimators should give similar beta estimates.
        assert_relative_eq!(reml_fit.beta[0], ml_fit.beta[0], epsilon = 0.5);
        // REML typically gives larger variance estimates than ML, but this is
        // not guaranteed for tiny samples, so only validity is checked.
        assert!(reml_fit.vu >= 0.0);
        assert!(ml_fit.vu >= 0.0);
    }

    /// A two-column X (intercept + covariate) yields two fixed effects.
    #[test]
    fn test_mixed_solve_with_fixed_effects() {
        let response = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let design = DMatrix::from_row_slice(6, 2, &[
            1.0, 0.0,
            1.0, 1.0,
            1.0, 2.0,
            1.0, 3.0,
            1.0, 4.0,
            1.0, 5.0,
        ]);

        let fit = mixed_solve(&response, None, None, Some(&design), None).unwrap();

        assert_eq!(fit.beta.len(), 2);
        // With this design, beta should capture the linear trend.
    }

    /// return_hinv=true yields an n x n Hinv matrix.
    #[test]
    fn test_mixed_solve_return_hinv() {
        let response = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let options = MixedSolveOptions {
            return_hinv: true,
            ..Default::default()
        };
        let fit = mixed_solve(&response, None, None, None, Some(options)).unwrap();

        let hinv = fit.hinv.expect("Hinv was requested");
        assert_eq!(hinv.nrows(), 5);
        assert_eq!(hinv.ncols(), 5);
    }
}