Skip to main content

bids_modeling/
spec.rs

1//! Statistical model specifications for BIDS-StatsModels.
2//!
3//! Provides data structures for GLM and meta-analysis model specifications,
4//! design matrix construction, VIF computation, and formatted output for
5//! inspection and reporting.
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// A single term (predictor column) in a statistical model's design matrix.
///
/// Each term has a name (e.g., `"trial_type.face"`, `"intercept"`) and a
/// vector of numeric values — one per observation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Term {
    /// Column label; reported as the column name when a design matrix is
    /// extracted from a spec.
    pub name: String,
    /// Column data: element `i` is this predictor's value for observation `i`.
    pub values: Vec<f64>,
}
19
20/// General Linear Model (GLM) specification.
21///
22/// Contains the design matrix terms and the distributional family/link
23/// function. Can be constructed from raw data rows and a model specification,
24/// and provides methods for extracting the design matrix, computing
25/// dimensions, and adding random effects (placeholder for GLMM).
26#[derive(Debug, Clone)]
27pub struct GlmSpec {
28    pub terms: Vec<Term>,
29    pub family: String,
30    pub link: String,
31}
32
33impl GlmSpec {
34    pub fn new(terms: Vec<Term>) -> Self {
35        Self {
36            terms,
37            family: "gaussian".into(),
38            link: "identity".into(),
39        }
40    }
41
42    /// Build from data rows and a model specification dict.
43    pub fn from_rows(
44        data: &[HashMap<String, f64>],
45        x_names: &[String],
46        model: &serde_json::Value,
47    ) -> Self {
48        let family = model
49            .get("family")
50            .and_then(|v| v.as_str())
51            .unwrap_or("gaussian")
52            .to_string();
53        let link = model
54            .get("link")
55            .and_then(|v| v.as_str())
56            .unwrap_or("identity")
57            .to_string();
58
59        let mut terms = Vec::new();
60        for name in x_names {
61            if name == "intercept" || name == "1" {
62                terms.push(Term {
63                    name: "intercept".into(),
64                    values: vec![1.0; data.len()],
65                });
66            } else {
67                let values: Vec<f64> = data
68                    .iter()
69                    .map(|row| row.get(name).copied().unwrap_or(0.0))
70                    .collect();
71                terms.push(Term {
72                    name: name.clone(),
73                    values,
74                });
75            }
76        }
77
78        Self {
79            terms,
80            family,
81            link,
82        }
83    }
84
85    /// Design matrix as (column_names, column_data).
86    pub fn design_matrix(&self) -> (Vec<String>, Vec<Vec<f64>>) {
87        let names: Vec<String> = self.terms.iter().map(|t| t.name.clone()).collect();
88        let columns: Vec<Vec<f64>> = self.terms.iter().map(|t| t.values.clone()).collect();
89        (names, columns)
90    }
91
92    pub fn n_obs(&self) -> usize {
93        self.terms.first().map_or(0, |t| t.values.len())
94    }
95    pub fn n_predictors(&self) -> usize {
96        self.terms.len()
97    }
98
99    /// Get the design matrix X as column vectors.
100    pub fn x(&self) -> &[Term] {
101        &self.terms
102    }
103
104    /// Add random effects (Z matrix) — placeholder for GLMM.
105    pub fn with_random_effects(self, _z_terms: Vec<Term>) -> Self {
106        // Random effects support is a placeholder
107        self
108    }
109}
110
111/// Meta-analysis specification.
112#[derive(Debug, Clone)]
113pub struct MetaAnalysisSpec {
114    pub terms: Vec<Term>,
115}
116
117impl MetaAnalysisSpec {
118    pub fn new(terms: Vec<Term>) -> Self {
119        Self { terms }
120    }
121
122    pub fn from_rows(data: &[HashMap<String, f64>], x_names: &[String]) -> Self {
123        let terms: Vec<Term> = x_names
124            .iter()
125            .map(|name| {
126                let values: Vec<f64> = data
127                    .iter()
128                    .map(|row| row.get(name).copied().unwrap_or(0.0))
129                    .collect();
130                Term {
131                    name: name.clone(),
132                    values,
133                }
134            })
135            .collect();
136        Self { terms }
137    }
138}
139
/// Expand a sparse contrast into a dense weight vector over all columns.
///
/// `condition_list` and `weights` are paired positionally (extra entries on
/// either side are ignored). Each weight is written to the slot of the first
/// column in `all_columns` whose name equals the condition; conditions that
/// match no column are skipped. Slots that receive no weight stay `0.0`, and
/// when a condition is listed more than once the last weight wins.
#[must_use]
pub fn dummies_to_vec(
    condition_list: &[String],
    all_columns: &[String],
    weights: &[f64],
) -> Vec<f64> {
    let mut contrast = vec![0.0_f64; all_columns.len()];
    condition_list
        .iter()
        .zip(weights.iter())
        .for_each(|(condition, weight)| {
            let hit = all_columns
                .iter()
                .enumerate()
                .find(|(_, column)| *column == condition);
            if let Some((slot, _)) = hit {
                contrast[slot] = *weight;
            }
        });
    contrast
}
155
/// Compute the Variance Inflation Factor (VIF) of every design-matrix column.
///
/// For column `i` the VIF is `1 / (1 - R²)`, where `R²` is the coefficient of
/// determination of an ordinary least-squares regression (with intercept) of
/// column `i` on all remaining columns. The regression is solved exactly by
/// orthogonalising the mean-centred regressors (Gram–Schmidt), so mutually
/// correlated or linearly dependent regressors are handled correctly; summing
/// squared pairwise correlations is only valid for orthogonal regressors.
///
/// # Arguments
///
/// * `columns` — design-matrix columns, one `Vec` per predictor. If the
///   columns differ in length, only the common leading rows are used.
///
/// # Returns
///
/// One value per input column, in order:
///
/// * `1.0` for every column when there are fewer than two columns or fewer
///   than two rows;
/// * `1.0` for a constant column (e.g. an intercept) — such columns are also
///   ignored as regressors, because centring already accounts for them;
/// * `f64::INFINITY` when the column is an exact linear combination of the
///   others;
/// * otherwise the finite VIF (always `>= 1.0`).
#[must_use]
pub fn compute_vif(columns: &[Vec<f64>]) -> Vec<f64> {
    // Centred sum of squares below which a column is treated as constant.
    const CONSTANT_SS: f64 = 1e-15;
    // Relative residual below which a vector counts as linearly dependent.
    const DEPENDENT_TOL: f64 = 1e-12;

    fn dot(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    // Subtract from `target` its projection onto each (mutually orthogonal)
    // basis vector; every basis entry carries its own squared norm.
    fn remove_projections(target: &mut [f64], basis: &[(Vec<f64>, f64)]) {
        for (direction, norm_sq) in basis {
            let coef = dot(target, direction) / norm_sq;
            for (t, d) in target.iter_mut().zip(direction) {
                *t -= coef * d;
            }
        }
    }

    let n_cols = columns.len();
    if n_cols < 2 {
        return vec![1.0; n_cols];
    }
    // Use the rows every column actually has, so ragged input cannot skew
    // the means or index out of range.
    let n_rows = columns.iter().map(Vec::len).min().unwrap_or(0);
    if n_rows < 2 {
        return vec![1.0; n_cols];
    }

    // Centring every column is equivalent to fitting an intercept.
    let centered: Vec<Vec<f64>> = columns
        .iter()
        .map(|col| {
            let rows = &col[..n_rows];
            let mean = rows.iter().sum::<f64>() / n_rows as f64;
            rows.iter().map(|v| v - mean).collect()
        })
        .collect();
    let sum_sq: Vec<f64> = centered.iter().map(|col| dot(col, col)).collect();

    (0..n_cols)
        .map(|i| {
            let ss_tot = sum_sq[i];
            if ss_tot < CONSTANT_SS {
                // A constant column has no variance to inflate.
                return 1.0;
            }

            // Orthogonal basis for the span of the other (non-constant) columns.
            let mut basis: Vec<(Vec<f64>, f64)> = Vec::with_capacity(n_cols - 1);
            for (j, other) in centered.iter().enumerate() {
                if j == i || sum_sq[j] < CONSTANT_SS {
                    continue;
                }
                let mut candidate = other.clone();
                remove_projections(&mut candidate, &basis);
                let norm_sq = dot(&candidate, &candidate);
                // Skip regressors already explained by earlier ones; they add
                // nothing to the span and would cause a division by ~0.
                if norm_sq > DEPENDENT_TOL * sum_sq[j] {
                    basis.push((candidate, norm_sq));
                }
            }

            // What is left of column `i` after projecting out the others is
            // the regression residual, so VIF = SS_tot / SS_res.
            let mut residual = centered[i].clone();
            remove_projections(&mut residual, &basis);
            let ss_res = dot(&residual, &residual);
            if ss_res <= DEPENDENT_TOL * ss_tot {
                return f64::INFINITY;
            }
            let vif = ss_tot / ss_res;
            // Rounding can push the ratio marginally below its lower bound.
            if vif < 1.0 {
                1.0
            } else {
                vif
            }
        })
        .collect()
}
211
/// Format a design matrix as an aligned text table.
///
/// The output consists of a header line with the column names, a dashed
/// separator, up to `max_rows` data rows, and — when rows were omitted — a
/// trailing `... (N more rows)` line. Lines are joined with `\n`.
///
/// * `names` — column labels. Each is truncated to its first 10 characters
///   (not bytes, so multi-byte UTF-8 names are safe) and right-aligned in a
///   10-character cell.
/// * `columns` — column data. The row count is taken from the first column;
///   a column that is shorter than that prints `0` for its missing rows.
/// * `max_rows` — maximum number of data rows to print.
///
/// Whole numbers with magnitude below 1000 are printed without decimals,
/// everything else with three decimal places.
#[must_use]
pub fn format_design_matrix(names: &[String], columns: &[Vec<f64>], max_rows: usize) -> String {
    let n_rows = columns.first().map_or(0, Vec::len);
    let shown = n_rows.min(max_rows);
    let mut lines = Vec::with_capacity(shown + 3);

    // Header: truncate on character boundaries; byte slicing would panic when
    // the cut falls inside a multi-byte character.
    let header = names
        .iter()
        .map(|name| {
            let label: String = name.chars().take(10).collect();
            format!("{label:>10}")
        })
        .collect::<Vec<_>>()
        .join(" ");
    lines.push(header);
    lines.push("-".repeat(names.len() * 11));

    for row_idx in 0..shown {
        let row = columns
            .iter()
            .map(|col| {
                let v = col.get(row_idx).copied().unwrap_or(0.0);
                if v == v.round() && v.abs() < 1000.0 {
                    format!("{v:>10.0}")
                } else {
                    format!("{v:>10.3}")
                }
            })
            .collect::<Vec<_>>()
            .join(" ");
        lines.push(row);
    }

    if n_rows > max_rows {
        lines.push(format!("... ({} more rows)", n_rows - max_rows));
    }
    lines.join("\n")
}
249
250/// Format a correlation matrix as text with Unicode intensity blocks.
251#[must_use]
252pub fn format_correlation_matrix(names: &[String], columns: &[Vec<f64>]) -> String {
253    let n = columns.len();
254    let mut corr = vec![vec![0.0f64; n]; n];
255
256    for i in 0..n {
257        for j in 0..n {
258            corr[i][j] = pearson_r(&columns[i], &columns[j]);
259        }
260    }
261
262    let blocks = [' ', '░', '▒', '▓', '█'];
263    let mut lines = Vec::new();
264
265    // Header
266    let header: String = std::iter::once(format!("{:>10}", ""))
267        .chain(names.iter().map(|n| format!("{:>5}", &n[..n.len().min(5)])))
268        .collect::<Vec<_>>()
269        .join("");
270    lines.push(header);
271
272    for i in 0..n {
273        let row: String = std::iter::once(format!("{:>10}", &names[i][..names[i].len().min(10)]))
274            .chain((0..n).map(|j| {
275                let r = corr[i][j].abs();
276                let idx = (r * 4.0).round().min(4.0) as usize;
277                format!("  {:>1}  ", blocks[idx])
278            }))
279            .collect::<Vec<_>>()
280            .join("");
281        lines.push(row);
282    }
283    lines.join("\n")
284}
285
/// Pearson correlation coefficient of two samples.
///
/// Only the common leading elements of `x` and `y` are used. Returns `0.0`
/// when fewer than two paired values exist or when the product of the two
/// centred sums of squares is numerically zero (at least one sample is
/// constant).
fn pearson_r(x: &[f64], y: &[f64]) -> f64 {
    let n = x.len().min(y.len());
    if n < 2 {
        return 0.0;
    }

    let (xs, ys) = (&x[..n], &y[..n]);
    let count = n as f64;
    let mean_x = xs.iter().sum::<f64>() / count;
    let mean_y = ys.iter().sum::<f64>() / count;

    // Accumulate the cross-product and both sums of squares in one pass.
    let (cross, ss_x, ss_y) = xs.iter().zip(ys).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(cross, ss_x, ss_y), (&a, &b)| {
            let da = a - mean_x;
            let db = b - mean_y;
            (cross + da * db, ss_x + da * da, ss_y + db * db)
        },
    );

    let denom = (ss_x * ss_y).sqrt();
    if denom < 1e-15 {
        return 0.0;
    }
    cross / denom
}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// A two-condition contrast lands on the matching columns and leaves the
    /// unmentioned column at zero.
    #[test]
    fn test_dummies_to_vec() {
        let all_columns: Vec<String> = ["a", "b", "c"].iter().map(|s| s.to_string()).collect();
        let conditions: Vec<String> = ["a", "c"].iter().map(|s| s.to_string()).collect();
        let weights = [1.0, -1.0];

        let contrast = dummies_to_vec(&conditions, &all_columns, &weights);

        assert_eq!(contrast, vec![1.0, 0.0, -1.0]);
    }

    /// A linear ramp and a short-period sawtooth are nearly uncorrelated, so
    /// the ramp's VIF must stay small.
    #[test]
    fn test_compute_vif() {
        let ramp: Vec<f64> = (0_i32..100).map(f64::from).collect();
        let sawtooth: Vec<f64> = (0_i32..100).map(|i| f64::from(i * 7 % 13)).collect();

        let vifs = compute_vif(&[ramp, sawtooth]);
        let first = vifs[0];

        assert!(
            first < 5.0,
            "VIF should be low for independent vars, got {}",
            first
        );
    }
}