1#![allow(non_snake_case)] use ndarray::{Array1, Array2};
31use serde::{Deserialize, Serialize};
32
33use so_core::data::DataFrame;
34use so_core::error::{Error, Result};
35use so_core::formula::Formula;
36use so_linalg::{inv, solve};
37
38use crate::glm::{Family, GLM, Link};
39
/// Specification of a single random-effects term in a mixed model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RandomEffect {
    /// Name of the grouping column in the data; one random effect is
    /// estimated per distinct level of this variable.
    pub group_var: String,
    /// Formula for the random-effects design within each group.
    /// NOTE(review): the visible design-matrix builders only create a
    /// random intercept per group and do not parse this formula — confirm
    /// whether slope terms are handled elsewhere.
    pub formula: String,
    /// Covariance structure assumed for this term's effects.
    pub covariance: RandomCovariance,
}
50
/// Covariance structures for a random-effects term.
///
/// NOTE(review): the fitting code visible in this file treats every term as
/// `Independent` (scalar variance per term); the other variants are carried
/// in the model specification but not consulted here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RandomCovariance {
    /// Independent effects: diagonal covariance with a single variance.
    Independent,
    /// Equal variances with a common pairwise correlation.
    CompoundSymmetry,
    /// First-order autoregressive correlation structure.
    AR1,
    /// Fully parameterized (unconstrained) covariance matrix.
    Unstructured,
    /// User-supplied covariance matrix.
    Custom(Array2<f64>),
}
65
/// Results of a fitted linear mixed model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LMMResults {
    /// Estimated fixed-effect coefficients (beta).
    pub fixed_effects: Array1<f64>,
    /// Standard errors of the fixed-effect coefficients.
    pub fixed_se: Array1<f64>,
    /// `(group variable name, estimated variance)` for each random term,
    /// in the order the terms were added to the builder.
    pub variance_components: Vec<(String, f64)>,
    /// Estimated residual (error) variance.
    pub residual_variance: f64,
    /// Log-likelihood at the final estimates (ML or REML flavor,
    /// depending on the estimation method).
    pub log_lik: f64,
    /// Akaike information criterion: -2*logLik + 2*k.
    pub aic: f64,
    /// Bayesian information criterion: -2*logLik + k*ln(n).
    pub bic: f64,
    /// Number of fixed-effect parameters (columns of X).
    pub df_fixed: usize,
    /// Residual degrees of freedom (n minus df_fixed).
    pub df_resid: usize,
    /// Whether the iterative fit converged before hitting the iteration cap.
    pub converged: bool,
    /// Number of iterations actually performed.
    pub iterations: usize,
}
92
/// Builder for fitting a linear mixed model (Gaussian response) with
/// fixed effects plus one or more grouped random-intercept terms.
pub struct LinearMixedModelBuilder {
    // Dataset containing the response, fixed-effect covariates,
    // and grouping columns.
    data: DataFrame,
    // Name of the response column.
    response: String,
    // Fixed-effects formula (with or without a leading `lhs ~`).
    fixed_formula: String,
    // Random-effects terms, one per grouping variable.
    random_effects: Vec<RandomEffect>,
    // REML (default) or ML estimation.
    method: EstimationMethod,
    // Iteration cap for the variance-component updates.
    max_iter: usize,
    // Relative convergence tolerance on the variance components.
    tol: f64,
}
103
/// Estimation criterion for linear mixed models.
#[derive(Debug, Clone, Copy)]
pub enum EstimationMethod {
    /// Restricted maximum likelihood (default; adjusts for fixed effects).
    REML,
    /// Ordinary maximum likelihood.
    ML,
}
112
/// Estimation strategies for generalized linear mixed models.
///
/// NOTE(review): the `GLMMBuilder::fit` path visible in this file always
/// runs PQL; the other variants are accepted but not yet dispatched.
#[derive(Debug, Clone, Copy)]
pub enum GLMMEstimationMethod {
    /// Penalized quasi-likelihood.
    PQL,
    /// Laplace approximation of the marginal likelihood.
    Laplace,
    /// Adaptive Gauss-Hermite quadrature with the given number of points.
    AGHQ(usize),
}
123
124impl LinearMixedModelBuilder {
125 pub fn new(data: DataFrame, response: &str, fixed_formula: &str) -> Self {
127 Self {
128 data,
129 response: response.to_string(),
130 fixed_formula: fixed_formula.to_string(),
131 random_effects: Vec::new(),
132 method: EstimationMethod::REML,
133 max_iter: 100,
134 tol: 1e-6,
135 }
136 }
137
138 pub fn random_effect(mut self, group_var: &str, formula: &str) -> Self {
140 self.random_effects.push(RandomEffect {
141 group_var: group_var.to_string(),
142 formula: formula.to_string(),
143 covariance: RandomCovariance::Independent,
144 });
145 self
146 }
147
148 pub fn method(mut self, method: EstimationMethod) -> Self {
150 self.method = method;
151 self
152 }
153
154 pub fn max_iterations(mut self, max_iter: usize) -> Self {
156 self.max_iter = max_iter;
157 self
158 }
159
160 pub fn tolerance(mut self, tol: f64) -> Self {
162 self.tol = tol;
163 self
164 }
165
166 pub fn fit(self) -> Result<LMMResults> {
168 let y = self.data.column(&self.response).ok_or_else(|| {
174 Error::DataError(format!("Response column '{}' not found", self.response))
175 })?;
176 let y_array = y.data().to_owned();
177
178 let X = self.build_fixed_design_matrix()?;
180
181 let (Z_matrices, group_sizes) = self.build_random_design_matrices()?;
183
184 self.fit_em(&y_array, &X, &Z_matrices, &group_sizes)
186 }
187
188 fn build_fixed_design_matrix(&self) -> Result<Array2<f64>> {
190 let formula_str = if self.fixed_formula.contains('~') {
192 self.fixed_formula.clone()
193 } else {
194 format!("__response__ ~ {}", self.fixed_formula)
196 };
197
198 let formula = Formula::parse(&formula_str)
199 .map_err(|e| Error::FormulaError(format!("Failed to parse fixed formula: {}", e)))?;
200
201 formula
203 .build_matrix(&self.data)
204 .map_err(|e| Error::DataError(format!("Failed to build design matrix: {}", e)))
205 }
206
207 fn build_random_design_matrices(&self) -> Result<(Vec<Array2<f64>>, Vec<usize>)> {
209 let mut Z_matrices = Vec::new();
210 let mut group_sizes = Vec::new();
211
212 for random_effect in &self.random_effects {
213 let group_col = self.data.column(&random_effect.group_var).ok_or_else(|| {
215 Error::DataError(format!(
216 "Group column '{}' not found",
217 random_effect.group_var
218 ))
219 })?;
220
221 let groups: Vec<String> = vec!["group1".to_string(), "group2".to_string()]; let n_groups = groups.len();
224 let n = self.data.n_rows();
225
226 let mut Z = Array2::zeros((n, n_groups));
227
228 let group_data = group_col.data();
230 for j in 0..n {
231 let group_idx = group_data[j] as usize % n_groups; Z[(j, group_idx)] = 1.0;
233 }
234
235 Z_matrices.push(Z);
236 group_sizes.push(n_groups);
237 }
238
239 Ok((Z_matrices, group_sizes))
240 }
241
242 #[allow(unused_assignments, unused_variables)]
244 fn fit_em(
245 &self,
246 y: &Array1<f64>,
247 X: &Array2<f64>,
248 Z_matrices: &[Array2<f64>],
249 group_sizes: &[usize],
250 ) -> Result<LMMResults> {
251 let n = y.len();
252 let p = X.ncols();
253
254 let mut sigma2_e = 1.0; let mut sigma2_u = vec![1.0; Z_matrices.len()]; let Z = self.combine_Z_matrices(Z_matrices, group_sizes);
260 let q = Z.ncols();
261
262 let mut beta = Array1::zeros(p);
263 let mut u = Array1::zeros(q);
264
265 let mut converged = false;
266 let mut iter = 0;
267
268 while !converged && iter < self.max_iter {
269 iter += 1;
270
271 let V_inv = self.compute_V_inv(&Z, sigma2_e, &sigma2_u, group_sizes)?;
273 let XtVX = X.t().dot(&V_inv.dot(X));
274 let XtVy = X.t().dot(&V_inv.dot(y));
275
276 beta = solve(&XtVX, &XtVy).map_err(|e| {
277 Error::LinearAlgebraError(format!("Failed to solve for beta: {}", e))
278 })?;
279
280 let residuals = y - X.dot(&beta);
281 u = Z.t().dot(&V_inv.dot(&residuals));
282
283 let old_sigma2_e = sigma2_e;
285 let old_sigma2_u = sigma2_u.clone();
286
287 let y_Xb = y - X.dot(&beta);
289 let y_Xb_Zu = &y_Xb - Z.dot(&u);
290 sigma2_e = y_Xb_Zu.dot(&y_Xb_Zu) / (n - p) as f64;
291
292 for i in 0..sigma2_u.len() {
294 let start_idx: usize = group_sizes[..i].iter().sum();
295 let end_idx = start_idx + group_sizes[i];
296 let u_i = u.slice(ndarray::s![start_idx..end_idx]);
297 sigma2_u[i] = u_i.dot(&u_i) / group_sizes[i] as f64;
298 }
299
300 let delta_e = (sigma2_e - old_sigma2_e).abs() / old_sigma2_e.max(1e-10);
302 let max_delta_u = sigma2_u
303 .iter()
304 .zip(&old_sigma2_u)
305 .map(|(new, old)| (new - old).abs() / old.max(1e-10))
306 .fold(0.0, f64::max);
307
308 converged = delta_e < self.tol && max_delta_u < self.tol;
309 }
310
311 let V_inv = self.compute_V_inv(&Z, sigma2_e, &sigma2_u, group_sizes)?;
313 let XtVX = X.t().dot(&V_inv.dot(X));
314 let cov_beta = inv(&XtVX).map_err(|e| {
315 Error::LinearAlgebraError(format!("Failed to invert X'V^{{-1}}X: {}", e))
316 })?;
317
318 let fixed_se = cov_beta.diag().mapv(|x| x.sqrt());
319
320 let V = self.compute_V(&Z, sigma2_e, &sigma2_u, group_sizes);
322 let log_lik = self.compute_log_lik(y, X, &V, beta.clone(), self.method);
323
324 let n_params = p + sigma2_u.len() + 1; let aic = -2.0 * log_lik + 2.0 * n_params as f64;
327 let bic = -2.0 * log_lik + (n_params as f64) * (n as f64).ln();
328
329 let mut var_comps = Vec::new();
331 for (i, random_effect) in self.random_effects.iter().enumerate() {
332 var_comps.push((random_effect.group_var.clone(), sigma2_u[i]));
333 }
334
335 Ok(LMMResults {
336 fixed_effects: beta,
337 fixed_se,
338 variance_components: var_comps,
339 residual_variance: sigma2_e,
340 log_lik,
341 aic,
342 bic,
343 df_fixed: p,
344 df_resid: n - p,
345 converged,
346 iterations: iter,
347 })
348 }
349
350 fn combine_Z_matrices(&self, Z_matrices: &[Array2<f64>], group_sizes: &[usize]) -> Array2<f64> {
352 let n = Z_matrices[0].nrows();
353 let total_cols: usize = group_sizes.iter().sum();
354
355 let mut Z = Array2::zeros((n, total_cols));
356 let mut col_offset = 0;
357
358 for (i, Z_i) in Z_matrices.iter().enumerate() {
359 let cols = group_sizes[i];
360 for row in 0..n {
361 for col in 0..cols {
362 Z[(row, col_offset + col)] = Z_i[(row, col)];
363 }
364 }
365 col_offset += cols;
366 }
367
368 Z
369 }
370
371 fn compute_V(
373 &self,
374 Z: &Array2<f64>,
375 sigma2_e: f64,
376 sigma2_u: &[f64],
377 group_sizes: &[usize],
378 ) -> Array2<f64> {
379 let n = Z.nrows();
380 let mut V = Array2::zeros((n, n));
381
382 for i in 0..n {
384 V[(i, i)] = sigma2_e;
385 }
386
387 let mut col_offset = 0;
389 for (k, &sigma2_u_k) in sigma2_u.iter().enumerate() {
390 let cols = group_sizes[k];
391 let Z_k = Z.slice(ndarray::s![.., col_offset..col_offset + cols]);
392
393 let ZkZkt = Z_k.dot(&Z_k.t());
395 V = &V + &(ZkZkt * sigma2_u_k);
396
397 col_offset += cols;
398 }
399
400 V
401 }
402
403 fn compute_V_inv(
405 &self,
406 Z: &Array2<f64>,
407 sigma2_e: f64,
408 _sigma2_u: &[f64],
409 _group_sizes: &[usize],
410 ) -> Result<Array2<f64>> {
411 let n = Z.nrows();
412 let mut V_inv = Array2::zeros((n, n));
413
414 for i in 0..n {
417 V_inv[(i, i)] = 1.0 / sigma2_e;
418 }
419
420 Ok(V_inv)
421 }
422
423 fn compute_log_lik(
425 &self,
426 y: &Array1<f64>,
427 X: &Array2<f64>,
428 V: &Array2<f64>,
429 beta: Array1<f64>,
430 method: EstimationMethod,
431 ) -> f64 {
432 let n = y.len() as f64;
433 let _p = X.ncols() as f64;
434
435 let residuals = y - X.dot(&beta);
437
438 let log_det_V: f64 = V.diag().iter().map(|&v| v.ln()).sum();
440
441 let residuals_clone = residuals.clone();
444 let Vinv_r = residuals_clone / V.diag(); let quad_form = residuals.dot(&Vinv_r);
446
447 let log_lik = -0.5 * (n * (2.0 * std::f64::consts::PI).ln() + log_det_V + quad_form);
448
449 match method {
450 EstimationMethod::ML => log_lik,
451 EstimationMethod::REML => {
452 let inv_diag = V.diag().mapv(|v| 1.0 / v);
455 let X_scaled = X * &inv_diag.insert_axis(ndarray::Axis(1));
456 let XtVX = X.t().dot(&X_scaled);
457 let log_det_XtVX = XtVX.diag().iter().map(|&x| x.ln()).sum::<f64>();
458 log_lik - 0.5 * log_det_XtVX
459 }
460 }
461 }
462}
463
/// Results of a fitted generalized linear mixed model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GLMMResults {
    /// Estimated fixed-effect coefficients (beta).
    pub fixed_effects: Array1<f64>,
    /// Standard errors of the fixed-effect coefficients.
    pub fixed_se: Array1<f64>,
    /// `(group variable name, estimated variance)` for each random term,
    /// in the order the terms were added to the builder.
    pub variance_components: Vec<(String, f64)>,
    /// Estimated dispersion/scale parameter.
    pub scale: f64,
    /// Approximate (quasi-)log-likelihood at the final estimates.
    pub log_lik: f64,
    /// Akaike information criterion: -2*logLik + 2*k.
    pub aic: f64,
    /// Bayesian information criterion: -2*logLik + k*ln(n).
    pub bic: f64,
    /// Number of fixed-effect parameters (columns of X).
    pub df_fixed: usize,
    /// Number of observations used in the fit.
    pub n_obs: usize,
    /// Whether the PQL iteration converged before hitting the cap.
    pub converged: bool,
    /// Number of iterations actually performed.
    pub iterations: usize,
    /// Response distribution family the model was fit with.
    pub family: Family,
    /// Link function the model was fit with.
    pub link: Link,
}
494
/// Builder for fitting a generalized linear mixed model with fixed effects
/// plus one or more grouped random-intercept terms.
pub struct GLMMBuilder {
    // Dataset containing the response, covariates, and grouping columns.
    data: DataFrame,
    // Name of the response column.
    response: String,
    // Fixed-effects formula (with or without a leading `lhs ~`).
    fixed_formula: String,
    // Random-effects terms, one per grouping variable.
    random_effects: Vec<RandomEffect>,
    // Response distribution family.
    family: Family,
    // Optional link override; `None` means the family's default link.
    link: Option<Link>,
    // Requested estimation method (PQL by default).
    method: GLMMEstimationMethod,
    // Iteration cap for the PQL loop.
    max_iter: usize,
    // Relative convergence tolerance.
    tol: f64,
}
507
508impl GLMMBuilder {
509 pub fn new(data: DataFrame, response: &str, fixed_formula: &str, family: Family) -> Self {
511 Self {
512 data,
513 response: response.to_string(),
514 fixed_formula: fixed_formula.to_string(),
515 random_effects: Vec::new(),
516 family,
517 link: None,
518 method: GLMMEstimationMethod::PQL,
519 max_iter: 50,
520 tol: 1e-6,
521 }
522 }
523
524 pub fn random_effect(mut self, group_var: &str, formula: &str) -> Self {
526 self.random_effects.push(RandomEffect {
527 group_var: group_var.to_string(),
528 formula: formula.to_string(),
529 covariance: RandomCovariance::Independent,
530 });
531 self
532 }
533
534 pub fn link(mut self, link: Link) -> Self {
536 self.link = Some(link);
537 self
538 }
539
540 pub fn method(mut self, method: GLMMEstimationMethod) -> Self {
542 self.method = method;
543 self
544 }
545
546 pub fn max_iterations(mut self, max_iter: usize) -> Self {
548 self.max_iter = max_iter;
549 self
550 }
551
552 pub fn tolerance(mut self, tol: f64) -> Self {
554 self.tol = tol;
555 self
556 }
557
558 pub fn fit(self) -> Result<GLMMResults> {
560 let link = self.link.unwrap_or_else(|| self.family.default_link());
562
563 let y = self.data.column(&self.response).ok_or_else(|| {
565 Error::DataError(format!("Response column '{}' not found", self.response))
566 })?;
567 let y_array = y.data().to_owned();
568
569 let glmm_results = self.fit_pql(&y_array, link)?;
571 Ok(glmm_results)
572 }
573
574 #[allow(unused_assignments, unused_variables)]
576 fn fit_pql(&self, y: &Array1<f64>, link: Link) -> Result<GLMMResults> {
577 let n = y.len();
578
579 let _glm_model = GLM::new()
581 .family(self.family)
582 .link(link)
583 .max_iter(self.max_iter)
584 .tol(self.tol)
585 .build();
586
587 let X = self.build_fixed_design_matrix()?;
590
591 let mut eta = Array1::zeros(n); let mut mu = Array1::zeros(n); let mut mu_eta = Array1::zeros(n); match self.family {
598 Family::Binomial => {
599 let y_mean = y.mean().unwrap_or(0.5);
601 let eps = 1e-4;
602 let y_clamped = y_mean.max(eps).min(1.0 - eps);
603 let init_eta = link.link(y_clamped);
604 eta.fill(init_eta);
605 }
606 Family::Poisson => {
607 let y_mean = y.mean().unwrap_or(1.0);
608 let init_eta = link.link(y_mean.max(1e-4));
609 eta.fill(init_eta);
610 }
611 _ => {
612 let y_mean = y.mean().unwrap_or(0.0);
614 let init_eta = link.link(y_mean);
615 eta.fill(init_eta);
616 }
617 }
618
619 for i in 0..n {
621 mu[i] = link.inverse_link(eta[i]);
622 mu_eta[i] = link.derivative(eta[i]);
623 }
624
625 let (Z_matrices, group_sizes) = self.build_random_design_matrices()?;
627 let Z = self.combine_Z_matrices(&Z_matrices, &group_sizes);
628 let q = Z.ncols();
629
630 let mut sigma2_e = 1.0; let mut sigma2_u = vec![1.0; Z_matrices.len()];
633
634 let p = X.ncols();
636 let mut beta = Array1::zeros(p);
637 if p > 0 {
638 beta[0] = eta.mean().unwrap_or(0.0);
640 }
641
642 let mut u = Array1::zeros(q);
643
644 let mut converged = false;
645 let mut iter = 0;
646
647 while !converged && iter < self.max_iter {
648 iter += 1;
649
650 let mut y_star = Array1::zeros(n);
654 for i in 0..n {
655 let d_eta_d_mu = if mu_eta[i].abs() > 1e-10 {
656 1.0 / mu_eta[i]
657 } else {
658 1.0
659 };
660 y_star[i] = eta[i] + (y[i] - mu[i]) * d_eta_d_mu;
661 }
662
663 let mut weights = Array1::zeros(n);
666 for i in 0..n {
667 let d_eta_d_mu = if mu_eta[i].abs() > 1e-10 {
668 1.0 / mu_eta[i]
669 } else {
670 1.0
671 };
672 let v_mu = self.family.variance(mu[i]);
673 weights[i] = 1.0 / (v_mu * d_eta_d_mu * d_eta_d_mu);
674 }
675
676 let W_sqrt = weights.mapv(|w| w.sqrt());
682 let y_star_weighted = &y_star * &W_sqrt;
683 let X_weighted = &X * &W_sqrt.clone().insert_axis(ndarray::Axis(1));
684 let Z_weighted = &Z * &W_sqrt.clone().insert_axis(ndarray::Axis(1));
685
686 let XtWX = X_weighted.t().dot(&X_weighted);
692 let ZtWZ = Z_weighted.t().dot(&Z_weighted);
693 let XtWZ = X_weighted.t().dot(&Z_weighted);
694 let ZtWX = Z_weighted.t().dot(&X_weighted);
695
696 let XtWy = X_weighted.t().dot(&y_star_weighted);
697 let ZtWy = Z_weighted.t().dot(&y_star_weighted);
698
699 let total_cols = p + q;
701 let mut M = Array2::zeros((total_cols, total_cols));
702 let mut rhs = Array1::zeros(total_cols);
703
704 M.slice_mut(ndarray::s![0..p, 0..p]).assign(&XtWX);
706 M.slice_mut(ndarray::s![0..p, p..]).assign(&XtWZ);
708 M.slice_mut(ndarray::s![p.., 0..p]).assign(&ZtWX);
710 let mut ZtWZ_plus_Ginv = ZtWZ.clone();
712
713 let mut col_offset = 0;
715 for (k, sigma2_u_k) in sigma2_u.iter().enumerate() {
716 let cols = group_sizes[k];
717 let g_inv = 1.0 / f64::max(*sigma2_u_k, 1e-10);
718 for i in 0..cols {
719 let idx = col_offset + i;
720 ZtWZ_plus_Ginv[(idx, idx)] += g_inv;
721 }
722 col_offset += cols;
723 }
724
725 M.slice_mut(ndarray::s![p.., p..]).assign(&ZtWZ_plus_Ginv);
726
727 rhs.slice_mut(ndarray::s![0..p]).assign(&XtWy);
729 rhs.slice_mut(ndarray::s![p..]).assign(&ZtWy);
730
731 let solution = solve(&M, &rhs).map_err(|e| {
733 Error::LinearAlgebraError(format!("Failed to solve mixed model equations: {}", e))
734 })?;
735
736 let new_beta = solution.slice(ndarray::s![0..p]).to_owned();
737 let new_u = solution.slice(ndarray::s![p..]).to_owned();
738
739 let new_eta = X.dot(&new_beta) + Z.dot(&new_u);
741
742 let mut new_mu = Array1::zeros(n);
744 let mut new_mu_eta = Array1::zeros(n);
745 for i in 0..n {
746 new_mu[i] = link.inverse_link(new_eta[i]);
747 new_mu_eta[i] = link.derivative(new_eta[i]);
748 }
749
750 let old_sigma2_e = sigma2_e;
752 let old_sigma2_u = sigma2_u.clone();
753
754 col_offset = 0;
756 for (k, sigma2_u_k) in sigma2_u.iter_mut().enumerate() {
757 let cols = group_sizes[k];
758 let u_k = new_u.slice(ndarray::s![col_offset..col_offset + cols]);
759 let trace_term = 0.0; *sigma2_u_k = u_k.dot(&u_k) / (cols as f64 - trace_term).max(1.0);
761 col_offset += cols;
762 }
763
764 let residuals = y - &new_mu;
766 let pearson_residuals =
767 residuals.mapv(|r| r * r / self.family.variance(new_mu[0]).max(1e-10));
768 sigma2_e = pearson_residuals.mean().unwrap_or(1.0);
769
770 let beta_diff = (&new_beta - &beta)
772 .mapv(|x| x.abs())
773 .mean()
774 .unwrap_or(f64::INFINITY);
775 let eta_diff = (&new_eta - &eta)
776 .mapv(|x| x.abs())
777 .mean()
778 .unwrap_or(f64::INFINITY);
779
780 beta = new_beta;
781 u = new_u;
782 eta = new_eta;
783 mu = new_mu;
784 mu_eta = new_mu_eta;
785
786 converged = beta_diff < self.tol && eta_diff < self.tol;
787
788 let sigma2_u_diff = sigma2_u
790 .iter()
791 .zip(&old_sigma2_u)
792 .map(|(new, old)| (new - old).abs() / old.max(1e-10))
793 .fold(0.0, f64::max);
794 let sigma2_e_diff = (sigma2_e - old_sigma2_e).abs() / old_sigma2_e.max(1e-10);
795
796 converged = converged && sigma2_u_diff < self.tol && sigma2_e_diff < self.tol;
797 }
798
799 let mut final_weights = Array1::zeros(n);
803 for i in 0..n {
804 let d_eta_d_mu = if mu_eta[i].abs() > 1e-10 {
805 1.0 / mu_eta[i]
806 } else {
807 1.0
808 };
809 let v_mu = self.family.variance(mu[i]);
810 final_weights[i] = 1.0 / (v_mu * d_eta_d_mu * d_eta_d_mu);
811 }
812 let W_sqrt = final_weights.mapv(|w| w.sqrt());
813 let X_weighted = &X * &W_sqrt.clone().insert_axis(ndarray::Axis(1));
814 let Z_weighted = &Z * &W_sqrt.clone().insert_axis(ndarray::Axis(1));
815
816 let XtWX = X_weighted.t().dot(&X_weighted);
817 let ZtWZ = Z_weighted.t().dot(&Z_weighted);
818 let XtWZ = X_weighted.t().dot(&Z_weighted);
819 let ZtWX = Z_weighted.t().dot(&X_weighted);
820
821 let total_cols = p + q;
822 let mut M = Array2::zeros((total_cols, total_cols));
823 M.slice_mut(ndarray::s![0..p, 0..p]).assign(&XtWX);
824 M.slice_mut(ndarray::s![0..p, p..]).assign(&XtWZ);
825 M.slice_mut(ndarray::s![p.., 0..p]).assign(&ZtWX);
826
827 let mut ZtWZ_plus_Ginv = ZtWZ.clone();
828 let mut col_offset = 0;
829 for (k, sigma2_u_k) in sigma2_u.iter().enumerate() {
830 let cols = group_sizes[k];
831 let g_inv = 1.0 / f64::max(*sigma2_u_k, 1e-10);
832 for i in 0..cols {
833 let idx = col_offset + i;
834 ZtWZ_plus_Ginv[(idx, idx)] += g_inv;
835 }
836 col_offset += cols;
837 }
838 M.slice_mut(ndarray::s![p.., p..]).assign(&ZtWZ_plus_Ginv);
839
840 let Minv = inv(&M).map_err(|e| {
841 Error::LinearAlgebraError(format!("Failed to invert mixed model matrix: {}", e))
842 })?;
843
844 let cov_beta = Minv.slice(ndarray::s![0..p, 0..p]).to_owned();
845 let fixed_se = cov_beta.diag().mapv(|x| x.sqrt());
846
847 let log_lik = self.approximate_log_lik(y, &mu, &eta, &final_weights, sigma2_e);
849
850 let n_params = p + sigma2_u.len() + 1; let aic = -2.0 * log_lik + 2.0 * n_params as f64;
853 let bic = -2.0 * log_lik + (n_params as f64) * (n as f64).ln();
854
855 let mut var_comps = Vec::new();
857 for (i, random_effect) in self.random_effects.iter().enumerate() {
858 var_comps.push((random_effect.group_var.clone(), sigma2_u[i]));
859 }
860
861 Ok(GLMMResults {
862 fixed_effects: beta,
863 fixed_se,
864 variance_components: var_comps,
865 scale: sigma2_e,
866 log_lik,
867 aic,
868 bic,
869 df_fixed: p,
870 n_obs: n,
871 converged,
872 iterations: iter,
873 family: self.family,
874 link,
875 })
876 }
877
878 fn build_fixed_design_matrix(&self) -> Result<Array2<f64>> {
880 let formula_str = if self.fixed_formula.contains('~') {
882 self.fixed_formula.clone()
883 } else {
884 format!("__response__ ~ {}", self.fixed_formula)
886 };
887
888 let formula = Formula::parse(&formula_str)
889 .map_err(|e| Error::FormulaError(format!("Failed to parse fixed formula: {}", e)))?;
890
891 formula
893 .build_matrix(&self.data)
894 .map_err(|e| Error::DataError(format!("Failed to build design matrix: {}", e)))
895 }
896
897 fn build_random_design_matrices(&self) -> Result<(Vec<Array2<f64>>, Vec<usize>)> {
899 let mut Z_matrices = Vec::new();
900 let mut group_sizes = Vec::new();
901
902 for random_effect in &self.random_effects {
903 let group_col = self.data.column(&random_effect.group_var).ok_or_else(|| {
904 Error::DataError(format!(
905 "Group column '{}' not found",
906 random_effect.group_var
907 ))
908 })?;
909
910 let group_data = group_col.data();
913 let max_group = group_data
914 .iter()
915 .map(|&x| x as i64)
916 .max()
917 .unwrap_or(0)
918 .max(0) as usize;
919 let n_groups = max_group + 1;
920
921 let n = self.data.n_rows();
922 let mut Z = Array2::zeros((n, n_groups));
923
924 for j in 0..n {
925 let group_idx = group_data[j] as usize % n_groups.max(1);
926 if n_groups > 0 {
927 Z[(j, group_idx)] = 1.0;
928 }
929 }
930
931 Z_matrices.push(Z);
932 group_sizes.push(n_groups);
933 }
934
935 Ok((Z_matrices, group_sizes))
936 }
937
938 fn combine_Z_matrices(&self, Z_matrices: &[Array2<f64>], group_sizes: &[usize]) -> Array2<f64> {
940 let n = Z_matrices[0].nrows();
941 let total_cols: usize = group_sizes.iter().sum();
942
943 let mut Z = Array2::zeros((n, total_cols));
944 let mut col_offset = 0;
945
946 for (i, Z_i) in Z_matrices.iter().enumerate() {
947 let cols = group_sizes[i];
948 for row in 0..n {
949 for col in 0..cols {
950 Z[(row, col_offset + col)] = Z_i[(row, col)];
951 }
952 }
953 col_offset += cols;
954 }
955
956 Z
957 }
958
959 fn approximate_log_lik(
961 &self,
962 y: &Array1<f64>,
963 mu: &Array1<f64>,
964 _eta: &Array1<f64>,
965 _weights: &Array1<f64>,
966 scale: f64,
967 ) -> f64 {
968 let n = y.len() as f64;
969
970 let mut ql = 0.0;
972
973 for i in 0..y.len() {
974 let deviance = self.family.unit_deviance(y[i], mu[i]);
977 ql += -0.5 * deviance / scale;
978 }
979
980 ql - 0.5 * n * (2.0 * std::f64::consts::PI * scale).ln()
982 }
983}