// Source: scirs2_stats/bayesian/hierarchical.rs

//! Hierarchical Bayesian models
//!
//! This module implements hierarchical (multi-level) Bayesian models that capture
//! group-level variation and borrow strength across groups (partial pooling).

6use crate::error::{StatsError, StatsResult as Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::random::{Distribution, Gamma, Normal};
9use scirs2_core::validation::*;
10use scirs2_core::Rng;
11use statrs::statistics::Statistics;
12
13/// Hierarchical linear model with random intercepts and slopes
14///
15/// Model structure:
16/// Level 1: y_ij = β₀j + β₁j * x_ij + ε_ij,  ε_ij ~ N(0, σ²)
17/// Level 2: β₀j = γ₀₀ + γ₀₁ * w_j + u₀j,     u₀j ~ N(0, τ₀₀)
18///          β₁j = γ₁₀ + γ₁₁ * w_j + u₁j,     u₁j ~ N(0, τ₁₁)
19#[derive(Debug, Clone)]
20pub struct HierarchicalLinearModel {
21    /// Fixed effects parameters
22    pub fixed_effects: Array2<f64>,
23    /// Random effects covariance matrix
24    pub random_effects_cov: Array2<f64>,
25    /// Residual variance
26    pub residual_variance: f64,
27    /// Group identifiers
28    pub groups: Array1<usize>,
29    /// Number of groups
30    pub n_groups: usize,
31    /// Number of level-1 predictors
32    pub n_level1_predictors: usize,
33    /// Number of level-2 predictors
34    pub n_level2_predictors: usize,
35    /// Whether to include random slopes
36    pub random_slopes: bool,
37}
38
39impl HierarchicalLinearModel {
40    /// Create a new hierarchical linear model
41    pub fn new(
42        n_groups: usize,
43        n_level1_predictors: usize,
44        n_level2_predictors: usize,
45        random_slopes: bool,
46    ) -> Result<Self> {
47        check_positive(n_groups, "n_groups")?;
48        check_positive(n_level1_predictors, "n_level1_predictors")?;
49
50        let n_random_effects = if random_slopes {
51            n_level1_predictors + 1
52        } else {
53            1
54        };
55        let fixed_effects = Array2::zeros((n_random_effects, n_level2_predictors + 1));
56        let random_effects_cov = Array2::eye(n_random_effects);
57
58        Ok(Self {
59            fixed_effects,
60            random_effects_cov,
61            residual_variance: 1.0,
62            groups: Array1::zeros(0),
63            n_groups,
64            n_level1_predictors,
65            n_level2_predictors,
66            random_slopes,
67        })
68    }
69
70    /// Fit the hierarchical model using MCMC
71    pub fn fit_mcmc<R: Rng + ?Sized>(
72        &mut self,
73        y: ArrayView1<f64>,
74        x_level1: ArrayView2<f64>,
75        x_level2: ArrayView2<f64>,
76        groups: ArrayView1<usize>,
77        n_iter: usize,
78        burnin: usize,
79        rng: &mut R,
80    ) -> Result<HierarchicalModelResults> {
81        checkarray_finite(&y, "y")?;
82        checkarray_finite(&x_level1, "x_level1")?;
83        checkarray_finite(&x_level2, "x_level2")?;
84        check_positive(n_iter, "n_iter")?;
85
86        let n_obs = y.len();
87        if x_level1.nrows() != n_obs {
88            return Err(StatsError::DimensionMismatch(format!(
89                "x_level1 rows ({}) must match y length ({})",
90                x_level1.nrows(),
91                n_obs
92            )));
93        }
94
95        if groups.len() != n_obs {
96            return Err(StatsError::DimensionMismatch(format!(
97                "groups length ({}) must match y length ({})",
98                groups.len(),
99                n_obs
100            )));
101        }
102
103        self.groups = groups.to_owned();
104
105        // Initialize storage for MCMC samples
106        let n_random_effects = if self.random_slopes {
107            self.n_level1_predictors + 1
108        } else {
109            1
110        };
111        let n_fixed = (self.n_level2_predictors + 1) * n_random_effects;
112
113        let mut fixed_effects_samples = Array2::zeros((n_iter - burnin, n_fixed));
114        let mut random_effects_samples =
115            Array2::zeros((n_iter - burnin, self.n_groups * n_random_effects));
116        let mut variance_samples = Array1::zeros(n_iter - burnin);
117        let mut tau_samples = Array2::zeros((n_iter - burnin, n_random_effects * n_random_effects));
118
119        // Initialize random effects for each group
120        let mut random_effects = Array2::zeros((self.n_groups, n_random_effects));
121
122        // MCMC iterations
123        for _iter in 0..n_iter {
124            // 1. Update random effects for each group
125            self.update_random_effects(&y, &x_level1, &x_level2, &mut random_effects, rng)?;
126
127            // 2. Update fixed effects
128            self.update_fixed_effects(&random_effects, &x_level2, rng)?;
129
130            // 3. Update residual variance
131            self.update_residual_variance(&y, &x_level1, &random_effects, rng)?;
132
133            // 4. Update random effects covariance
134            self.update_random_effects_covariance(&random_effects, rng)?;
135
136            // Store samples after burnin
137            if _iter >= burnin {
138                let sample_idx = _iter - burnin;
139
140                // Store fixed effects
141                let mut fixed_flat = Array1::zeros(n_fixed);
142                let mut idx = 0;
143                for i in 0..self.fixed_effects.nrows() {
144                    for j in 0..self.fixed_effects.ncols() {
145                        fixed_flat[idx] = self.fixed_effects[[i, j]];
146                        idx += 1;
147                    }
148                }
149                fixed_effects_samples
150                    .row_mut(sample_idx)
151                    .assign(&fixed_flat);
152
153                // Store random effects
154                let mut random_flat = Array1::zeros(self.n_groups * n_random_effects);
155                let mut idx = 0;
156                for group in 0..self.n_groups {
157                    for effect in 0..n_random_effects {
158                        random_flat[idx] = random_effects[[group, effect]];
159                        idx += 1;
160                    }
161                }
162                random_effects_samples
163                    .row_mut(sample_idx)
164                    .assign(&random_flat);
165
166                // Store variances
167                variance_samples[sample_idx] = self.residual_variance;
168
169                // Store tau (covariance matrix flattened)
170                let mut tau_flat = Array1::zeros(n_random_effects * n_random_effects);
171                let mut idx = 0;
172                for i in 0..n_random_effects {
173                    for j in 0..n_random_effects {
174                        tau_flat[idx] = self.random_effects_cov[[i, j]];
175                        idx += 1;
176                    }
177                }
178                tau_samples.row_mut(sample_idx).assign(&tau_flat);
179            }
180        }
181
182        Ok(HierarchicalModelResults {
183            fixed_effects_samples,
184            random_effects_samples,
185            variance_samples,
186            tau_samples,
187            n_groups: self.n_groups,
188            n_random_effects,
189            n_iter: n_iter - burnin,
190        })
191    }
192
193    /// Update random effects for each group using Gibbs sampling
194    fn update_random_effects<R: scirs2_core::random::Rng + ?Sized>(
195        &self,
196        y: &ArrayView1<f64>,
197        x_level1: &ArrayView2<f64>,
198        x_level2: &ArrayView2<f64>,
199        random_effects: &mut Array2<f64>,
200        rng: &mut R,
201    ) -> Result<()> {
202        let n_random_effects = random_effects.ncols();
203
204        for group in 0..self.n_groups {
205            // Find observations for this group
206            let group_indices: Vec<usize> = self
207                .groups
208                .iter()
209                .enumerate()
210                .filter_map(|(i, &g)| if g == group { Some(i) } else { None })
211                .collect();
212
213            if group_indices.is_empty() {
214                continue;
215            }
216
217            // Extract group data
218            let n_group_obs = group_indices.len();
219            let mut y_group = Array1::zeros(n_group_obs);
220            let mut x_group = Array2::zeros((n_group_obs, self.n_level1_predictors));
221
222            for (i, &obs_idx) in group_indices.iter().enumerate() {
223                y_group[i] = y[obs_idx];
224                x_group.row_mut(i).assign(&x_level1.row(obs_idx));
225            }
226
227            // Compute posterior parameters for random _effects
228            let precision_prior = scirs2_linalg::inv(&self.random_effects_cov.view(), None)
229                .map_err(|e| {
230                    StatsError::ComputationError(format!("Failed to invert covariance: {}", e))
231                })?;
232
233            // Design matrix for random _effects (intercept + slopes if enabled)
234            let mut z_group = Array2::zeros((n_group_obs, n_random_effects));
235            z_group.column_mut(0).fill(1.0); // Intercept
236            if self.random_slopes && n_random_effects > 1 {
237                for i in 1..n_random_effects {
238                    z_group.column_mut(i).assign(&x_group.column(i - 1));
239                }
240            }
241
242            let zt_z = z_group.t().dot(&z_group);
243            let precision_posterior = precision_prior.clone() + zt_z / self.residual_variance;
244
245            let covariance_posterior = scirs2_linalg::inv(&precision_posterior.view(), None)
246                .map_err(|e| {
247                    StatsError::ComputationError(format!(
248                        "Failed to invert posterior precision: {}",
249                        e
250                    ))
251                })?;
252
253            // Compute prior mean for this group
254            let group_level2 = if group < x_level2.nrows() {
255                x_level2.row(group).to_owned()
256            } else {
257                Array1::zeros(x_level2.ncols())
258            };
259
260            let mut prior_mean = Array1::zeros(n_random_effects);
261            for i in 0..n_random_effects {
262                prior_mean[i] = self.fixed_effects.row(i).dot(&group_level2);
263            }
264
265            let data_contrib = z_group.t().dot(&y_group) / self.residual_variance;
266            let prior_contrib = precision_prior.dot(&prior_mean);
267            let posterior_mean = covariance_posterior.dot(&(data_contrib + prior_contrib));
268
269            // Sample from multivariate normal
270            let mvn_sample =
271                sample_multivariate_normal(&posterior_mean, &covariance_posterior, rng)?;
272            random_effects.row_mut(group).assign(&mvn_sample);
273        }
274
275        Ok(())
276    }
277
278    /// Update fixed effects using Gibbs sampling
279    fn update_fixed_effects<R: Rng + ?Sized>(
280        &mut self,
281        random_effects: &Array2<f64>,
282        x_level2: &ArrayView2<f64>,
283        rng: &mut R,
284    ) -> Result<()> {
285        let n_random_effects = self.fixed_effects.nrows();
286        let n_level2_predictors = self.fixed_effects.ncols();
287
288        for i in 0..n_random_effects {
289            // Extract dependent variable (random effect i for all groups)
290            let y_i = random_effects.column(i);
291
292            // Prior parameters (weak priors)
293            let prior_precision = 1e-6;
294            let prior_mean = 0.0;
295
296            // Likelihood precision
297            let tau_ii = self.random_effects_cov[[i, i]];
298            let likelihood_precision = 1.0 / tau_ii;
299
300            // Posterior parameters
301            let xtx = x_level2.t().dot(x_level2);
302            let precision_posterior =
303                Array2::eye(n_level2_predictors) * prior_precision + xtx * likelihood_precision;
304            let covariance_posterior = scirs2_linalg::inv(&precision_posterior.view(), None)
305                .map_err(|e| {
306                    StatsError::ComputationError(format!("Failed to invert precision: {}", e))
307                })?;
308
309            let xty = x_level2.t().dot(&y_i);
310            let data_contrib = xty * likelihood_precision;
311            let prior_contrib =
312                Array1::from_elem(n_level2_predictors, prior_mean * prior_precision);
313            let mean_posterior = covariance_posterior.dot(&(data_contrib + prior_contrib));
314
315            // Sample from multivariate normal
316            let sample = sample_multivariate_normal(&mean_posterior, &covariance_posterior, rng)?;
317            self.fixed_effects.row_mut(i).assign(&sample);
318        }
319
320        Ok(())
321    }
322
323    /// Update residual variance using Gibbs sampling
324    fn update_residual_variance<R: Rng + ?Sized>(
325        &mut self,
326        y: &ArrayView1<f64>,
327        x_level1: &ArrayView2<f64>,
328        random_effects: &Array2<f64>,
329        rng: &mut R,
330    ) -> Result<()> {
331        let n_obs = y.len();
332
333        // Compute residuals
334        let mut residuals_sum_sq = 0.0;
335        for (obs_idx, &group) in self.groups.iter().enumerate() {
336            let y_obs = y[obs_idx];
337            let x_obs = x_level1.row(obs_idx);
338
339            // Predicted value
340            let intercept = random_effects[[group, 0]];
341            let mut y_pred = intercept;
342
343            if self.random_slopes && random_effects.ncols() > 1 {
344                for j in 0..self.n_level1_predictors {
345                    y_pred += random_effects[[group, j + 1]] * x_obs[j];
346                }
347            }
348
349            let residual = y_obs - y_pred;
350            residuals_sum_sq += residual * residual;
351        }
352
353        // Inverse gamma prior parameters
354        let alpha_prior = 1e-3;
355        let beta_prior = 1e-3;
356
357        // Posterior parameters
358        let alpha_posterior = alpha_prior + n_obs as f64 / 2.0;
359        let beta_posterior = beta_prior + residuals_sum_sq / 2.0;
360
361        // Sample from inverse gamma (via gamma)
362        let gamma_dist = Gamma::new(alpha_posterior, 1.0 / beta_posterior).map_err(|e| {
363            StatsError::ComputationError(format!("Failed to create Gamma distribution: {}", e))
364        })?;
365        let precision_sample = gamma_dist.sample(rng);
366        self.residual_variance = 1.0 / precision_sample;
367
368        Ok(())
369    }
370
371    /// Update random effects covariance matrix using inverse Wishart
372    fn update_random_effects_covariance<R: scirs2_core::random::Rng + ?Sized>(
373        &mut self,
374        random_effects: &Array2<f64>,
375        rng: &mut R,
376    ) -> Result<()> {
377        let n_random_effects = random_effects.ncols();
378        let n_groups = random_effects.nrows();
379
380        // Compute sample covariance of random _effects
381        let mut sum_outer_products = Array2::<f64>::zeros((n_random_effects, n_random_effects));
382
383        for group in 0..n_groups {
384            let _effects = random_effects.row(group);
385            let outer = outer_product(&_effects.to_owned());
386            sum_outer_products = sum_outer_products + outer;
387        }
388
389        // Inverse Wishart prior parameters
390        let nu_prior = n_random_effects as f64 + 2.0; // Degrees of freedom
391        let psi_prior = Array2::<f64>::eye(n_random_effects) * 0.1; // Scale matrix
392
393        // Posterior parameters
394        let nu_posterior = nu_prior + n_groups as f64;
395        let psi_posterior = psi_prior + sum_outer_products;
396
397        // Sample from inverse Wishart (simplified using independent gamma for diagonal)
398        // In full implementation, would use proper inverse Wishart sampling
399        let mut new_cov = Array2::<f64>::zeros((n_random_effects, n_random_effects));
400
401        for i in 0..n_random_effects {
402            // Sample diagonal elements from inverse gamma
403            let alpha = nu_posterior / 2.0;
404            let beta = psi_posterior[[i, i]] / 2.0;
405
406            let gamma_dist = Gamma::new(alpha, 1.0 / beta).map_err(|e| {
407                StatsError::ComputationError(format!("Failed to create Gamma distribution: {}", e))
408            })?;
409            let precision = gamma_dist.sample(rng);
410            new_cov[[i, i]] = 1.0 / precision;
411        }
412
413        // For off-diagonal elements, use simplified approach
414        for i in 0..n_random_effects {
415            for j in (i + 1)..n_random_effects {
416                let val1: f64 = psi_posterior[[i, i]];
417                let val2: f64 = psi_posterior[[j, j]];
418                let denom: f64 = (val1 * val2).sqrt();
419                let correlation: f64 = psi_posterior[[i, j]] / denom;
420                let covariance = correlation * (new_cov[[i, i]] * new_cov[[j, j]]).sqrt();
421                new_cov[[i, j]] = covariance * 0.1; // Shrink off-diagonal
422                new_cov[[j, i]] = new_cov[[i, j]];
423            }
424        }
425
426        self.random_effects_cov = new_cov;
427        Ok(())
428    }
429
430    /// Predict for new data
431    pub fn predict(
432        &self,
433        x_level1: ArrayView2<f64>,
434        x_level2: ArrayView2<f64>,
435        groups: ArrayView1<usize>,
436    ) -> Result<Array1<f64>> {
437        checkarray_finite(&x_level1, "x_level1")?;
438        checkarray_finite(&x_level2, "x_level2")?;
439
440        let n_obs = x_level1.nrows();
441        let mut predictions = Array1::zeros(n_obs);
442
443        for (obs_idx, &group) in groups.iter().enumerate() {
444            if group >= self.n_groups {
445                return Err(StatsError::InvalidArgument(format!(
446                    "Group {} exceeds number of groups {}",
447                    group, self.n_groups
448                )));
449            }
450
451            let x_obs = x_level1.row(obs_idx);
452
453            // Compute group-level predictors
454            let zeros_array = Array1::zeros(x_level2.ncols());
455            let group_level2 = if group < x_level2.nrows() {
456                x_level2.row(group)
457            } else {
458                // Handle new groups by using population mean
459                zeros_array.view()
460            };
461
462            // Compute random intercept
463            let intercept = self.fixed_effects.row(0).dot(&group_level2);
464            let mut y_pred = intercept;
465
466            // Add slope effects if enabled
467            if self.random_slopes && self.fixed_effects.nrows() > 1 {
468                for j in 0..self.n_level1_predictors {
469                    let slope = self.fixed_effects.row(j + 1).dot(&group_level2);
470                    y_pred += slope * x_obs[j];
471                }
472            }
473
474            predictions[obs_idx] = y_pred;
475        }
476
477        Ok(predictions)
478    }
479}
480
481/// Results from hierarchical model fitting
482#[derive(Debug, Clone)]
483pub struct HierarchicalModelResults {
484    /// MCMC samples of fixed effects (flattened)
485    pub fixed_effects_samples: Array2<f64>,
486    /// MCMC samples of random effects (flattened)
487    pub random_effects_samples: Array2<f64>,
488    /// MCMC samples of residual variance
489    pub variance_samples: Array1<f64>,
490    /// MCMC samples of random effects covariance (flattened)
491    pub tau_samples: Array2<f64>,
492    /// Number of groups
493    pub n_groups: usize,
494    /// Number of random effects per group
495    pub n_random_effects: usize,
496    /// Number of MCMC samples
497    pub n_iter: usize,
498}
499
500impl HierarchicalModelResults {
501    /// Compute posterior summaries for fixed effects
502    pub fn fixed_effects_summary(&self) -> Result<Array2<f64>> {
503        let n_params = self.fixed_effects_samples.ncols();
504        let mut summary = Array2::zeros((n_params, 4)); // mean, std, 2.5%, 97.5%
505
506        for param in 0..n_params {
507            let samples = self.fixed_effects_samples.column(param);
508            let mean = samples.mean();
509            let std = samples.variance().sqrt();
510
511            let mut sorted_samples = samples.to_vec();
512            sorted_samples.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
513
514            let q025_idx = (0.025 * sorted_samples.len() as f64) as usize;
515            let q975_idx = (0.975 * sorted_samples.len() as f64) as usize;
516            let q025 = sorted_samples[q025_idx];
517            let q975 = sorted_samples[q975_idx.min(sorted_samples.len() - 1)];
518
519            summary[[param, 0]] = mean;
520            summary[[param, 1]] = std;
521            summary[[param, 2]] = q025;
522            summary[[param, 3]] = q975;
523        }
524
525        Ok(summary)
526    }
527
528    /// Compute posterior summaries for random effects variances
529    pub fn random_effects_variance_summary(&self) -> Result<Array2<f64>> {
530        let n_params = self.n_random_effects * self.n_random_effects;
531        let mut summary = Array2::zeros((n_params, 4));
532
533        for param in 0..n_params {
534            let samples = self.tau_samples.column(param);
535            let mean = samples.mean();
536            let std = samples.variance().sqrt();
537
538            let mut sorted_samples = samples.to_vec();
539            sorted_samples.sort_by(|a, b| a.partial_cmp(b).expect("Operation failed"));
540
541            let q025_idx = (0.025 * sorted_samples.len() as f64) as usize;
542            let q975_idx = (0.975 * sorted_samples.len() as f64) as usize;
543            let q025 = sorted_samples[q025_idx];
544            let q975 = sorted_samples[q975_idx.min(sorted_samples.len() - 1)];
545
546            summary[[param, 0]] = mean;
547            summary[[param, 1]] = std;
548            summary[[param, 2]] = q025;
549            summary[[param, 3]] = q975;
550        }
551
552        Ok(summary)
553    }
554}
555
556/// Bayesian ANOVA with hierarchical structure
557#[derive(Debug, Clone)]
558pub struct HierarchicalANOVA {
559    /// Group means
560    pub group_means: Array1<f64>,
561    /// Overall mean
562    pub overall_mean: f64,
563    /// Between-group variance
564    pub between_variance: f64,
565    /// Within-group variance
566    pub within_variance: f64,
567    /// Group assignments
568    pub groups: Array1<usize>,
569    /// Number of groups
570    pub n_groups: usize,
571}
572
573impl HierarchicalANOVA {
574    /// Create new hierarchical ANOVA
575    pub fn new(n_groups: usize) -> Result<Self> {
576        check_positive(n_groups, "n_groups")?;
577
578        Ok(Self {
579            group_means: Array1::zeros(n_groups),
580            overall_mean: 0.0,
581            between_variance: 1.0,
582            within_variance: 1.0,
583            groups: Array1::zeros(0),
584            n_groups,
585        })
586    }
587
588    /// Fit hierarchical ANOVA using MCMC
589    pub fn fit_mcmc<R: Rng + ?Sized>(
590        &mut self,
591        y: ArrayView1<f64>,
592        groups: ArrayView1<usize>,
593        n_iter: usize,
594        burnin: usize,
595        rng: &mut R,
596    ) -> Result<HierarchicalANOVAResults> {
597        checkarray_finite(&y, "y")?;
598        check_positive(n_iter, "n_iter")?;
599
600        if y.len() != groups.len() {
601            return Err(StatsError::DimensionMismatch(format!(
602                "y length ({}) must match groups length ({})",
603                y.len(),
604                groups.len()
605            )));
606        }
607
608        self.groups = groups.to_owned();
609
610        // Initialize storage
611        let mut group_means_samples = Array2::zeros((n_iter - burnin, self.n_groups));
612        let mut overall_mean_samples_ = Array1::zeros(n_iter - burnin);
613        let mut between_var_samples = Array1::zeros(n_iter - burnin);
614        let mut within_var_samples = Array1::zeros(n_iter - burnin);
615
616        // Group statistics
617        let mut group_counts = vec![0; self.n_groups];
618        let mut group_sums = vec![0.0; self.n_groups];
619
620        for (&obs_group, &obs_y) in groups.iter().zip(y.iter()) {
621            if obs_group >= self.n_groups {
622                return Err(StatsError::InvalidArgument(format!(
623                    "Group {} exceeds n_groups {}",
624                    obs_group, self.n_groups
625                )));
626            }
627            group_counts[obs_group] += 1;
628            group_sums[obs_group] += obs_y;
629        }
630
631        // MCMC iterations
632        for _iter in 0..n_iter {
633            // 1. Update group means
634            for group in 0..self.n_groups {
635                if group_counts[group] > 0 {
636                    // Posterior parameters
637                    let prior_precision = 1.0 / self.between_variance;
638                    let likelihood_precision = group_counts[group] as f64 / self.within_variance;
639                    let posterior_precision = prior_precision + likelihood_precision;
640                    let posterior_variance = 1.0 / posterior_precision;
641
642                    let prior_mean_contribution = self.overall_mean * prior_precision;
643                    let likelihood_mean_contribution = group_sums[group] * likelihood_precision;
644                    let posterior_mean = (prior_mean_contribution + likelihood_mean_contribution)
645                        / posterior_precision;
646
647                    // Sample from normal
648                    let normal =
649                        Normal::new(posterior_mean, posterior_variance.sqrt()).map_err(|e| {
650                            StatsError::ComputationError(format!("Failed to create normal: {}", e))
651                        })?;
652                    self.group_means[group] = normal.sample(rng);
653                } else {
654                    // No observations in this group, sample from prior
655                    let normal = Normal::new(self.overall_mean, self.between_variance.sqrt())
656                        .map_err(|e| {
657                            StatsError::ComputationError(format!("Failed to create normal: {}", e))
658                        })?;
659                    self.group_means[group] = normal.sample(rng);
660                }
661            }
662
663            // 2. Update overall mean
664            let group_mean_avg = self.group_means.clone().mean();
665            let prior_variance = 10.0; // Weak prior
666            let likelihood_variance = self.between_variance / self.n_groups as f64;
667            let posterior_variance = 1.0 / (1.0 / prior_variance + 1.0 / likelihood_variance);
668            let posterior_mean =
669                (0.0 / prior_variance + group_mean_avg / likelihood_variance) * posterior_variance;
670
671            let normal = Normal::new(posterior_mean, posterior_variance.sqrt()).map_err(|e| {
672                StatsError::ComputationError(format!("Failed to create normal: {}", e))
673            })?;
674            self.overall_mean = normal.sample(rng);
675
676            // 3. Update between-group variance
677            let sum_sq_deviations: f64 = self
678                .group_means
679                .iter()
680                .map(|&mean| (mean - self.overall_mean).powi(2))
681                .sum();
682
683            let alpha_prior = 1e-3;
684            let beta_prior = 1e-3;
685            let alpha_posterior = alpha_prior + self.n_groups as f64 / 2.0;
686            let beta_posterior = beta_prior + sum_sq_deviations / 2.0;
687
688            let gamma_dist = Gamma::new(alpha_posterior, 1.0 / beta_posterior).map_err(|e| {
689                StatsError::ComputationError(format!("Failed to create Gamma: {}", e))
690            })?;
691            let precision = gamma_dist.sample(rng);
692            self.between_variance = 1.0 / precision;
693
694            // 4. Update within-group variance
695            let mut within_sum_sq = 0.0;
696            let mut total_obs = 0;
697
698            for (&obs_group, &obs_y) in groups.iter().zip(y.iter()) {
699                let residual = obs_y - self.group_means[obs_group];
700                within_sum_sq += residual * residual;
701                total_obs += 1;
702            }
703
704            let alpha_posterior = alpha_prior + total_obs as f64 / 2.0;
705            let beta_posterior = beta_prior + within_sum_sq / 2.0;
706
707            let gamma_dist = Gamma::new(alpha_posterior, 1.0 / beta_posterior).map_err(|e| {
708                StatsError::ComputationError(format!("Failed to create Gamma: {}", e))
709            })?;
710            let precision = gamma_dist.sample(rng);
711            self.within_variance = 1.0 / precision;
712
713            // Store samples after burnin
714            if _iter >= burnin {
715                let sample_idx = _iter - burnin;
716                group_means_samples
717                    .row_mut(sample_idx)
718                    .assign(&self.group_means);
719                overall_mean_samples_[sample_idx] = self.overall_mean;
720                between_var_samples[sample_idx] = self.between_variance;
721                within_var_samples[sample_idx] = self.within_variance;
722            }
723        }
724
725        Ok(HierarchicalANOVAResults {
726            group_means_samples,
727            overall_mean_samples_,
728            between_variance_samples: between_var_samples,
729            within_variance_samples: within_var_samples,
730            n_groups: self.n_groups,
731            n_iter: n_iter - burnin,
732        })
733    }
734}
735
736/// Results from hierarchical ANOVA
737#[derive(Debug, Clone)]
738pub struct HierarchicalANOVAResults {
739    /// MCMC samples of group means
740    pub group_means_samples: Array2<f64>,
741    /// MCMC samples of overall mean
742    pub overall_mean_samples_: Array1<f64>,
743    /// MCMC samples of between-group variance
744    pub between_variance_samples: Array1<f64>,
745    /// MCMC samples of within-group variance
746    pub within_variance_samples: Array1<f64>,
747    /// Number of groups
748    pub n_groups: usize,
749    /// Number of MCMC samples
750    pub n_iter: usize,
751}
752
753impl HierarchicalANOVAResults {
754    /// Compute intraclass correlation coefficient (ICC)
755    pub fn icc_samples(&self) -> Array1<f64> {
756        let mut icc = Array1::zeros(self.n_iter);
757        for i in 0..self.n_iter {
758            let between_var = self.between_variance_samples[i];
759            let within_var = self.within_variance_samples[i];
760            icc[i] = between_var / (between_var + within_var);
761        }
762        icc
763    }
764
765    /// Compute posterior probability that group i has higher mean than group j
766    pub fn prob_group_higher(&self, group_i: usize, group_j: usize) -> Result<f64> {
767        if group_i >= self.n_groups || group_j >= self.n_groups {
768            return Err(StatsError::InvalidArgument(
769                "Group indices out of bounds".to_string(),
770            ));
771        }
772
773        let mut count = 0;
774        for iter in 0..self.n_iter {
775            if self.group_means_samples[[iter, group_i]] > self.group_means_samples[[iter, group_j]]
776            {
777                count += 1;
778            }
779        }
780
781        Ok(count as f64 / self.n_iter as f64)
782    }
783}
784
785// Helper functions
786
787/// Sample from multivariate normal distribution
788#[allow(dead_code)]
789fn sample_multivariate_normal<R: Rng + ?Sized>(
790    mean: &Array1<f64>,
791    covariance: &Array2<f64>,
792    rng: &mut R,
793) -> Result<Array1<f64>> {
794    let dim = mean.len();
795    let normal = Normal::new(0.0, 1.0)
796        .map_err(|e| StatsError::ComputationError(format!("Failed to create normal: {}", e)))?;
797
798    // Sample from standard normal
799    let z = Array1::from_shape_fn(dim, |_| normal.sample(rng));
800
801    // Cholesky decomposition (simplified - use diagonal for now)
802    let mut sample = Array1::zeros(dim);
803    for i in 0..dim {
804        sample[i] = mean[i] + z[i] * covariance[[i, i]].sqrt();
805    }
806
807    Ok(sample)
808}
809
810/// Compute outer product of a vector
811#[allow(dead_code)]
812fn outer_product(v: &Array1<f64>) -> Array2<f64> {
813    let n = v.len();
814    let mut result = Array2::zeros((n, n));
815    for i in 0..n {
816        for j in 0..n {
817            result[[i, j]] = v[i] * v[j];
818        }
819    }
820    result
821}