1use bids_core::entities::StringEntities;
10use bids_variables::collections::VariableCollection;
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone)]
19pub struct StatsModelsEdge {
20 pub source: String,
21 pub destination: String,
22 pub filter: StringEntities,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ContrastInfo {
33 pub name: String,
34 pub conditions: Vec<String>,
35 pub weights: Vec<f64>,
36 pub test: Option<String>,
37 pub entities: StringEntities,
38}
39
40#[derive(Debug, Clone)]
52pub struct StatsModelsNode {
53 pub level: String,
54 pub name: String,
55 pub model: serde_json::Value,
56 pub group_by: Vec<String>,
57 pub transformations: Option<crate::transformations::TransformSpec>,
58 pub contrasts: Vec<serde_json::Value>,
59 pub dummy_contrasts: Option<serde_json::Value>,
60 pub children: Vec<StatsModelsEdge>,
61 pub parents: Vec<StatsModelsEdge>,
62 collections: Vec<VariableCollection>,
63}
64
65impl StatsModelsNode {
66 pub fn new(
67 level: &str,
68 name: &str,
69 model: serde_json::Value,
70 group_by: Vec<String>,
71 transformations: Option<crate::transformations::TransformSpec>,
72 contrasts: Vec<serde_json::Value>,
73 dummy_contrasts: Option<serde_json::Value>,
74 ) -> Self {
75 Self {
76 level: level.to_lowercase(),
77 name: name.into(),
78 model,
79 group_by,
80 transformations,
81 contrasts,
82 dummy_contrasts,
83 children: Vec::new(),
84 parents: Vec::new(),
85 collections: Vec::new(),
86 }
87 }
88
89 pub fn add_child(&mut self, edge: StatsModelsEdge) {
90 self.children.push(edge);
91 }
92 pub fn add_parent(&mut self, edge: StatsModelsEdge) {
93 self.parents.push(edge);
94 }
95
96 pub fn add_collections(&mut self, collections: Vec<VariableCollection>) {
97 self.collections.extend(collections);
98 }
99
100 pub fn get_collections(&self) -> &[VariableCollection] {
101 &self.collections
102 }
103
104 pub fn run(
106 &self,
107 inputs: &[ContrastInfo],
108 _force_dense: bool,
109 _sampling_rate: &str,
110 ) -> Vec<StatsModelsNodeOutput> {
111 let mut results = Vec::new();
113
114 if self.collections.is_empty() && inputs.is_empty() {
115 return results;
116 }
117
118 for collection in &self.collections {
120 let mut coll = collection.clone();
121
122 if let Some(ref spec) = self.transformations {
124 crate::transformations::apply_transformations(&mut coll, spec);
125 }
126
127 let x_vars: Vec<String> = self
129 .model
130 .get("x")
131 .or_else(|| self.model.get("X"))
132 .and_then(|v| v.as_array())
133 .map(|arr| {
134 arr.iter()
135 .filter_map(|v| {
136 if v.is_number() {
137 Some("intercept".into())
138 } else {
139 v.as_str().map(String::from)
140 }
141 })
142 .collect()
143 })
144 .unwrap_or_default();
145
146 let mut contrasts = Vec::new();
148
149 if let Some(ref dc) = self.dummy_contrasts {
151 let test = dc
152 .get("test")
153 .or(dc.get("Test"))
154 .and_then(|v| v.as_str())
155 .unwrap_or("t")
156 .to_string();
157 for var_name in &x_vars {
158 if var_name == "intercept" {
159 continue;
160 }
161 contrasts.push(ContrastInfo {
162 name: var_name.clone(),
163 conditions: vec![var_name.clone()],
164 weights: vec![1.0],
165 test: Some(test.clone()),
166 entities: collection.entities.clone(),
167 });
168 }
169 }
170
171 for con_spec in &self.contrasts {
173 let name = con_spec
174 .get("name")
175 .or(con_spec.get("Name"))
176 .and_then(|v| v.as_str())
177 .unwrap_or("unnamed");
178 let conditions: Vec<String> = con_spec
179 .get("condition_list")
180 .or(con_spec.get("ConditionList"))
181 .and_then(|v| v.as_array())
182 .map(|arr| {
183 arr.iter()
184 .filter_map(|v| v.as_str().map(String::from))
185 .collect()
186 })
187 .unwrap_or_default();
188 let weights: Vec<f64> = con_spec
189 .get("weights")
190 .or(con_spec.get("Weights"))
191 .and_then(|v| v.as_array())
192 .map(|arr| arr.iter().filter_map(serde_json::Value::as_f64).collect())
193 .unwrap_or_default();
194 let test = con_spec
195 .get("test")
196 .or(con_spec.get("Test"))
197 .and_then(|v| v.as_str())
198 .map(String::from);
199
200 let mut entities = collection.entities.clone();
201 entities.insert("contrast".into(), name.into());
202
203 contrasts.push(ContrastInfo {
204 name: name.into(),
205 conditions,
206 weights,
207 test,
208 entities,
209 });
210 }
211
212 let dm = if !x_vars.is_empty() {
214 let mut cols = Vec::new();
215 let mut col_names = Vec::new();
216 for var_name in &x_vars {
217 if var_name == "intercept" {
218 let n = coll
219 .variables
220 .values()
221 .next()
222 .map_or(0, bids_variables::SimpleVariable::len);
223 cols.push(vec![1.0; n]);
224 col_names.push("intercept".into());
225 } else if let Some(var) = coll.variables.get(var_name) {
226 cols.push(var.values.clone());
227 col_names.push(var_name.clone());
228 }
229 }
230 if !cols.is_empty() {
231 Some((col_names, cols))
232 } else {
233 None
234 }
235 } else {
236 None
237 };
238
239 results.push(StatsModelsNodeOutput {
240 node_name: self.name.clone(),
241 entities: collection.entities.clone(),
242 x_variables: x_vars.clone(),
243 contrasts,
244 design_matrix: dm,
245 });
246 }
247
248 results
249 }
250}
251
252#[derive(Debug, Clone)]
258pub struct StatsModelsNodeOutput {
259 pub node_name: String,
260 pub entities: StringEntities,
261 pub x_variables: Vec<String>,
262 pub contrasts: Vec<ContrastInfo>,
263 pub design_matrix: Option<(Vec<String>, Vec<Vec<f64>>)>,
265}
266
267impl StatsModelsNodeOutput {
268 pub fn x_columns(&self) -> &[String] {
270 &self.x_variables
271 }
272}
273
274pub fn build_groups(
277 entity_maps: &[StringEntities],
278 group_by: &[String],
279) -> std::collections::HashMap<Vec<(String, String)>, Vec<usize>> {
280 let mut groups: std::collections::HashMap<Vec<(String, String)>, Vec<usize>> =
281 std::collections::HashMap::new();
282
283 if group_by.is_empty() {
284 groups.insert(vec![], (0..entity_maps.len()).collect());
285 return groups;
286 }
287
288 let mut unique_vals: std::collections::HashMap<&str, Vec<String>> =
290 std::collections::HashMap::new();
291 for col in group_by {
292 let vals: std::collections::BTreeSet<String> = entity_maps
293 .iter()
294 .filter_map(|e| e.get(col.as_str()).cloned())
295 .collect();
296 unique_vals.insert(col.as_str(), vals.into_iter().collect());
297 }
298
299 for (i, ents) in entity_maps.iter().enumerate() {
300 let mut base: Vec<(String, String)> = Vec::new();
301 let mut missing: Vec<&str> = Vec::new();
302
303 for col in group_by {
304 if let Some(val) = ents.get(col.as_str()) {
305 base.push((col.clone(), val.clone()));
306 } else {
307 missing.push(col.as_str());
308 }
309 }
310
311 if missing.is_empty() {
312 base.sort();
313 groups.entry(base).or_default().push(i);
314 } else {
315 let mut combos = vec![base.clone()];
317 for col in &missing {
318 if let Some(vals) = unique_vals.get(col) {
319 let mut new_combos = Vec::new();
320 for combo in &combos {
321 for val in vals {
322 let mut c = combo.clone();
323 c.push((col.to_string(), val.clone()));
324 new_combos.push(c);
325 }
326 }
327 combos = new_combos;
328 }
329 }
330 for mut combo in combos {
331 combo.sort();
332 groups.entry(combo).or_default().push(i);
333 }
334 }
335 }
336
337 groups
338}