1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Term {
16 pub name: String,
17 pub values: Vec<f64>,
18}
19
20#[derive(Debug, Clone)]
27pub struct GlmSpec {
28 pub terms: Vec<Term>,
29 pub family: String,
30 pub link: String,
31}
32
33impl GlmSpec {
34 pub fn new(terms: Vec<Term>) -> Self {
35 Self {
36 terms,
37 family: "gaussian".into(),
38 link: "identity".into(),
39 }
40 }
41
42 pub fn from_rows(
44 data: &[HashMap<String, f64>],
45 x_names: &[String],
46 model: &serde_json::Value,
47 ) -> Self {
48 let family = model
49 .get("family")
50 .and_then(|v| v.as_str())
51 .unwrap_or("gaussian")
52 .to_string();
53 let link = model
54 .get("link")
55 .and_then(|v| v.as_str())
56 .unwrap_or("identity")
57 .to_string();
58
59 let mut terms = Vec::new();
60 for name in x_names {
61 if name == "intercept" || name == "1" {
62 terms.push(Term {
63 name: "intercept".into(),
64 values: vec![1.0; data.len()],
65 });
66 } else {
67 let values: Vec<f64> = data
68 .iter()
69 .map(|row| row.get(name).copied().unwrap_or(0.0))
70 .collect();
71 terms.push(Term {
72 name: name.clone(),
73 values,
74 });
75 }
76 }
77
78 Self {
79 terms,
80 family,
81 link,
82 }
83 }
84
85 pub fn design_matrix(&self) -> (Vec<String>, Vec<Vec<f64>>) {
87 let names: Vec<String> = self.terms.iter().map(|t| t.name.clone()).collect();
88 let columns: Vec<Vec<f64>> = self.terms.iter().map(|t| t.values.clone()).collect();
89 (names, columns)
90 }
91
92 pub fn n_obs(&self) -> usize {
93 self.terms.first().map_or(0, |t| t.values.len())
94 }
95 pub fn n_predictors(&self) -> usize {
96 self.terms.len()
97 }
98
99 pub fn x(&self) -> &[Term] {
101 &self.terms
102 }
103
104 pub fn with_random_effects(self, _z_terms: Vec<Term>) -> Self {
106 self
108 }
109}
110
111#[derive(Debug, Clone)]
113pub struct MetaAnalysisSpec {
114 pub terms: Vec<Term>,
115}
116
117impl MetaAnalysisSpec {
118 pub fn new(terms: Vec<Term>) -> Self {
119 Self { terms }
120 }
121
122 pub fn from_rows(data: &[HashMap<String, f64>], x_names: &[String]) -> Self {
123 let terms: Vec<Term> = x_names
124 .iter()
125 .map(|name| {
126 let values: Vec<f64> = data
127 .iter()
128 .map(|row| row.get(name).copied().unwrap_or(0.0))
129 .collect();
130 Term {
131 name: name.clone(),
132 values,
133 }
134 })
135 .collect();
136 Self { terms }
137 }
138}
139
/// Scatters `weights` into a vector laid out like `all_columns`.
///
/// The result has one slot per entry of `all_columns`, initialised to `0.0`.
/// Each `(condition, weight)` pair is written to the slot of the first column
/// whose name equals the condition. Conditions that match no column are
/// skipped, pairs beyond the shorter of the two inputs are ignored, and a
/// later pair overwrites an earlier one that targets the same slot.
#[must_use]
pub fn dummies_to_vec(
    condition_list: &[String],
    all_columns: &[String],
    weights: &[f64],
) -> Vec<f64> {
    let mut contrast = vec![0.0; all_columns.len()];
    let placements = condition_list
        .iter()
        .zip(weights.iter())
        .filter_map(|(label, weight)| {
            all_columns
                .iter()
                .position(|column| column == label)
                .map(|slot| (slot, *weight))
        });
    for (slot, weight) in placements {
        contrast[slot] = weight;
    }
    contrast
}
155
/// Computes the variance inflation factor (VIF) of every column.
///
/// For column `i` the VIF is `1 / (1 - R²)`, where `R²` comes from regressing
/// that column (with an implicit intercept) on all other non-constant columns
/// jointly. The regression is carried out by Gram–Schmidt residualisation, so
/// correlation among the other predictors is accounted for instead of adding
/// up pairwise squared correlations.
///
/// Conventions:
/// * fewer than two columns, or fewer than two common rows, yields `1.0` for
///   every column;
/// * when columns differ in length only the rows present in all of them are
///   used;
/// * a column whose variance is at most `1e-15` is treated as constant: it
///   gets a VIF of `1.0` and is not used as a predictor for the others;
/// * a column that is (numerically) an exact linear combination of the
///   others gets `f64::INFINITY`.
///
/// The returned vector has one entry per input column, in input order.
#[must_use]
pub fn compute_vif(columns: &[Vec<f64>]) -> Vec<f64> {
    // Variance at or below this marks a column as constant.
    const CONSTANT_VARIANCE: f64 = 1e-15;
    // Relative share of a column's variation below which nothing independent
    // is considered to be left.
    const COLLINEAR_TOL: f64 = 1e-12;

    // Inner product of two equally long slices.
    fn dot(a: &[f64], b: &[f64]) -> f64 {
        a.iter().zip(b).map(|(p, q)| p * q).sum()
    }

    // Subtracts from `target` its projection onto `direction`, whose squared
    // norm is `norm_sq`.
    fn remove_projection(target: &mut [f64], direction: &[f64], norm_sq: f64) {
        let coef = dot(target, direction) / norm_sq;
        for (t, d) in target.iter_mut().zip(direction) {
            *t -= coef * d;
        }
    }

    let n_cols = columns.len();
    if n_cols < 2 {
        return vec![1.0; n_cols];
    }
    let n_rows = columns.iter().map(Vec::len).min().unwrap_or(0);
    if n_rows < 2 {
        return vec![1.0; n_cols];
    }
    let n = n_rows as f64;

    // Centre each column once; centring plays the role of the intercept.
    let centered: Vec<Vec<f64>> = columns
        .iter()
        .map(|col| {
            let rows = &col[..n_rows];
            let mean = rows.iter().sum::<f64>() / n;
            rows.iter().map(|v| v - mean).collect()
        })
        .collect();
    let sum_sq: Vec<f64> = centered.iter().map(|col| dot(col, col)).collect();
    let usable: Vec<bool> = sum_sq
        .iter()
        .map(|ss| ss / n > CONSTANT_VARIANCE)
        .collect();

    (0..n_cols)
        .map(|i| {
            if sum_sq[i] / n <= CONSTANT_VARIANCE {
                return 1.0;
            }

            // Orthogonalise the other predictors one by one and strip each
            // orthogonal direction out of column `i`; what remains is the
            // regression residual.
            let mut residual = centered[i].clone();
            let mut basis: Vec<(Vec<f64>, f64)> = Vec::new();
            for j in (0..n_cols).filter(|&j| j != i && usable[j]) {
                let mut direction = centered[j].clone();
                for (earlier, norm_sq) in &basis {
                    remove_projection(&mut direction, earlier, *norm_sq);
                }
                let norm_sq = dot(&direction, &direction);
                if norm_sq <= COLLINEAR_TOL * sum_sq[j] {
                    // Already spanned by earlier predictors: adds nothing.
                    continue;
                }
                remove_projection(&mut residual, &direction, norm_sq);
                basis.push((direction, norm_sq));
            }

            // Share of the column's variation left unexplained, i.e. 1 - R².
            let unexplained = dot(&residual, &residual) / sum_sq[i];
            if unexplained <= COLLINEAR_TOL {
                f64::INFINITY
            } else if unexplained >= 1.0 {
                // Rounding can push the share marginally above one.
                1.0
            } else {
                1.0 / unexplained
            }
        })
        .collect()
}
211
/// Renders columns as a right-aligned text table.
///
/// * `names` - header labels; each is cut to its first 10 characters and
///   right-aligned in a 10-character cell.
/// * `columns` - column-major data; the row count is taken from the first
///   column, and a column that is shorter prints `0` for its missing rows.
/// * `max_rows` - maximum number of data rows to print.
///
/// The output is the header, a dashed rule of `names.len() * 11` characters,
/// up to `max_rows` data rows and, when rows were left out, a trailing
/// `... (N more rows)` line. Whole numbers with magnitude below 1000 print
/// without decimals; every other value prints with three decimals.
///
/// Labels are truncated by characters rather than bytes, so names containing
/// multi-byte UTF-8 characters cannot cause a slicing panic.
#[must_use]
pub fn format_design_matrix(names: &[String], columns: &[Vec<f64>], max_rows: usize) -> String {
    // Maximum number of characters of a name shown in the header.
    const NAME_CHARS: usize = 10;

    let n_rows = columns.first().map_or(0, Vec::len);
    let show = n_rows.min(max_rows);
    let mut lines = Vec::with_capacity(show + 3);

    let header: String = names
        .iter()
        .map(|name| {
            let label: String = name.chars().take(NAME_CHARS).collect();
            format!("{label:>10}")
        })
        .collect::<Vec<_>>()
        .join(" ");
    lines.push(header);
    lines.push("-".repeat(names.len() * 11));

    for i in 0..show {
        let row: String = columns
            .iter()
            .map(|col| {
                let v = col.get(i).copied().unwrap_or(0.0);
                if v == v.round() && v.abs() < 1000.0 {
                    format!("{v:>10.0}")
                } else {
                    format!("{v:>10.3}")
                }
            })
            .collect::<Vec<_>>()
            .join(" ");
        lines.push(row);
    }

    if n_rows > max_rows {
        lines.push(format!("... ({} more rows)", n_rows - max_rows));
    }
    lines.join("\n")
}
249
250#[must_use]
252pub fn format_correlation_matrix(names: &[String], columns: &[Vec<f64>]) -> String {
253 let n = columns.len();
254 let mut corr = vec![vec![0.0f64; n]; n];
255
256 for i in 0..n {
257 for j in 0..n {
258 corr[i][j] = pearson_r(&columns[i], &columns[j]);
259 }
260 }
261
262 let blocks = [' ', '░', '▒', '▓', '█'];
263 let mut lines = Vec::new();
264
265 let header: String = std::iter::once(format!("{:>10}", ""))
267 .chain(names.iter().map(|n| format!("{:>5}", &n[..n.len().min(5)])))
268 .collect::<Vec<_>>()
269 .join("");
270 lines.push(header);
271
272 for i in 0..n {
273 let row: String = std::iter::once(format!("{:>10}", &names[i][..names[i].len().min(10)]))
274 .chain((0..n).map(|j| {
275 let r = corr[i][j].abs();
276 let idx = (r * 4.0).round().min(4.0) as usize;
277 format!(" {:>1} ", blocks[idx])
278 }))
279 .collect::<Vec<_>>()
280 .join("");
281 lines.push(row);
282 }
283 lines.join("\n")
284}
285
/// Pearson correlation coefficient of `x` and `y`.
///
/// Only the first `min(x.len(), y.len())` entries of each slice are used.
/// Returns `0.0` when fewer than two pairs are available or when the product
/// of the two sums of squared deviations has a square root below `1e-15`
/// (for instance when either input is constant).
fn pearson_r(x: &[f64], y: &[f64]) -> f64 {
    let n = x.len().min(y.len());
    if n < 2 {
        return 0.0;
    }
    let (xs, ys) = (&x[..n], &y[..n]);
    let count = n as f64;
    let mean_x = xs.iter().sum::<f64>() / count;
    let mean_y = ys.iter().sum::<f64>() / count;

    // Accumulate the cross product and both sums of squares in one pass.
    let (cross, ss_x, ss_y) = xs.iter().zip(ys).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(cross, ss_x, ss_y), (&a, &b)| {
            let (da, db) = (a - mean_x, b - mean_y);
            (cross + da * db, ss_x + da * da, ss_y + db * db)
        },
    );

    let scale = (ss_x * ss_y).sqrt();
    if scale < 1e-15 {
        return 0.0;
    }
    cross / scale
}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Weights land in the slots of their named columns; other slots stay zero.
    #[test]
    fn test_dummies_to_vec() {
        let all_columns: Vec<String> = ["a", "b", "c"].iter().map(|s| s.to_string()).collect();
        let conditions: Vec<String> = ["a", "c"].iter().map(|s| s.to_string()).collect();
        let weights = [1.0, -1.0];

        let contrast = dummies_to_vec(&conditions, &all_columns, &weights);

        assert_eq!(contrast, vec![1.0, 0.0, -1.0]);
    }

    /// A linear ramp paired with a short-period modular sequence should give
    /// a small VIF for the ramp.
    #[test]
    fn test_compute_vif() {
        let ramp: Vec<f64> = (0..100).map(f64::from).collect();
        let cyclic: Vec<f64> = (0..100).map(|i| f64::from(i * 7 % 13)).collect();

        let vifs = compute_vif(&[ramp, cyclic]);

        assert!(
            vifs[0] < 5.0,
            "VIF should be low for independent vars, got {}",
            vifs[0]
        );
    }
}