// bids_modeling/node.rs

//! Analysis nodes and edges in a BIDS-StatsModels graph.
//!
//! Each node represents a statistical analysis step at a specific level of
//! the BIDS hierarchy. Nodes contain a model specification, variable
//! transformations, contrast definitions, and grouping criteria. They
//! produce [`StatsModelsNodeOutput`]s containing design matrices and
//! contrast information.
8
9use bids_core::entities::StringEntities;
10use bids_variables::collections::VariableCollection;
11use serde::{Deserialize, Serialize};
12
13/// A directed edge between two nodes in the stats model graph.
14///
15/// Edges define data flow from a source node to a destination node,
16/// optionally filtering which contrasts/outputs are passed through
17/// based on entity values.
18#[derive(Debug, Clone)]
19pub struct StatsModelsEdge {
20    pub source: String,
21    pub destination: String,
22    pub filter: StringEntities,
23}
24
25/// Information about a statistical contrast.
26///
27/// Defines a linear combination of model terms to be tested. Contains
28/// the contrast name, the list of conditions involved, their weights,
29/// the statistical test type (t, F), and the BIDS entities identifying
30/// the data this contrast applies to.
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ContrastInfo {
33    pub name: String,
34    pub conditions: Vec<String>,
35    pub weights: Vec<f64>,
36    pub test: Option<String>,
37    pub entities: StringEntities,
38}
39
40/// A single analysis node in a BIDS-StatsModel graph.
41///
42/// Each node operates at a specific level (run, session, subject, dataset)
43/// and defines:
44/// - A statistical model specification (GLM or meta-analysis)
45/// - Variable transformations to apply before modeling
46/// - Explicit contrasts and/or dummy contrasts
47/// - Grouping criteria for splitting data into independent analyses
48///
49/// Corresponds to a single entry in the `"Nodes"` array of a BIDS-StatsModels
50/// JSON specification.
51#[derive(Debug, Clone)]
52pub struct StatsModelsNode {
53    pub level: String,
54    pub name: String,
55    pub model: serde_json::Value,
56    pub group_by: Vec<String>,
57    pub transformations: Option<crate::transformations::TransformSpec>,
58    pub contrasts: Vec<serde_json::Value>,
59    pub dummy_contrasts: Option<serde_json::Value>,
60    pub children: Vec<StatsModelsEdge>,
61    pub parents: Vec<StatsModelsEdge>,
62    collections: Vec<VariableCollection>,
63}
64
65impl StatsModelsNode {
66    pub fn new(
67        level: &str,
68        name: &str,
69        model: serde_json::Value,
70        group_by: Vec<String>,
71        transformations: Option<crate::transformations::TransformSpec>,
72        contrasts: Vec<serde_json::Value>,
73        dummy_contrasts: Option<serde_json::Value>,
74    ) -> Self {
75        Self {
76            level: level.to_lowercase(),
77            name: name.into(),
78            model,
79            group_by,
80            transformations,
81            contrasts,
82            dummy_contrasts,
83            children: Vec::new(),
84            parents: Vec::new(),
85            collections: Vec::new(),
86        }
87    }
88
89    pub fn add_child(&mut self, edge: StatsModelsEdge) {
90        self.children.push(edge);
91    }
92    pub fn add_parent(&mut self, edge: StatsModelsEdge) {
93        self.parents.push(edge);
94    }
95
96    pub fn add_collections(&mut self, collections: Vec<VariableCollection>) {
97        self.collections.extend(collections);
98    }
99
100    pub fn get_collections(&self) -> &[VariableCollection] {
101        &self.collections
102    }
103
104    /// Run this node, producing outputs.
105    pub fn run(
106        &self,
107        inputs: &[ContrastInfo],
108        _force_dense: bool,
109        _sampling_rate: &str,
110    ) -> Vec<StatsModelsNodeOutput> {
111        // Group collections and inputs by group_by entities
112        let mut results = Vec::new();
113
114        if self.collections.is_empty() && inputs.is_empty() {
115            return results;
116        }
117
118        // For each collection, apply transformations and build output
119        for collection in &self.collections {
120            let mut coll = collection.clone();
121
122            // Apply transformations
123            if let Some(ref spec) = self.transformations {
124                crate::transformations::apply_transformations(&mut coll, spec);
125            }
126
127            // Extract X variable names from model
128            let x_vars: Vec<String> = self
129                .model
130                .get("x")
131                .or_else(|| self.model.get("X"))
132                .and_then(|v| v.as_array())
133                .map(|arr| {
134                    arr.iter()
135                        .filter_map(|v| {
136                            if v.is_number() {
137                                Some("intercept".into())
138                            } else {
139                                v.as_str().map(String::from)
140                            }
141                        })
142                        .collect()
143                })
144                .unwrap_or_default();
145
146            // Build contrasts
147            let mut contrasts = Vec::new();
148
149            // Dummy contrasts
150            if let Some(ref dc) = self.dummy_contrasts {
151                let test = dc
152                    .get("test")
153                    .or(dc.get("Test"))
154                    .and_then(|v| v.as_str())
155                    .unwrap_or("t")
156                    .to_string();
157                for var_name in &x_vars {
158                    if var_name == "intercept" {
159                        continue;
160                    }
161                    contrasts.push(ContrastInfo {
162                        name: var_name.clone(),
163                        conditions: vec![var_name.clone()],
164                        weights: vec![1.0],
165                        test: Some(test.clone()),
166                        entities: collection.entities.clone(),
167                    });
168                }
169            }
170
171            // Explicit contrasts
172            for con_spec in &self.contrasts {
173                let name = con_spec
174                    .get("name")
175                    .or(con_spec.get("Name"))
176                    .and_then(|v| v.as_str())
177                    .unwrap_or("unnamed");
178                let conditions: Vec<String> = con_spec
179                    .get("condition_list")
180                    .or(con_spec.get("ConditionList"))
181                    .and_then(|v| v.as_array())
182                    .map(|arr| {
183                        arr.iter()
184                            .filter_map(|v| v.as_str().map(String::from))
185                            .collect()
186                    })
187                    .unwrap_or_default();
188                let weights: Vec<f64> = con_spec
189                    .get("weights")
190                    .or(con_spec.get("Weights"))
191                    .and_then(|v| v.as_array())
192                    .map(|arr| arr.iter().filter_map(serde_json::Value::as_f64).collect())
193                    .unwrap_or_default();
194                let test = con_spec
195                    .get("test")
196                    .or(con_spec.get("Test"))
197                    .and_then(|v| v.as_str())
198                    .map(String::from);
199
200                let mut entities = collection.entities.clone();
201                entities.insert("contrast".into(), name.into());
202
203                contrasts.push(ContrastInfo {
204                    name: name.into(),
205                    conditions,
206                    weights,
207                    test,
208                    entities,
209                });
210            }
211
212            // Build design matrix from collection
213            let dm = if !x_vars.is_empty() {
214                let mut cols = Vec::new();
215                let mut col_names = Vec::new();
216                for var_name in &x_vars {
217                    if var_name == "intercept" {
218                        let n = coll
219                            .variables
220                            .values()
221                            .next()
222                            .map_or(0, bids_variables::SimpleVariable::len);
223                        cols.push(vec![1.0; n]);
224                        col_names.push("intercept".into());
225                    } else if let Some(var) = coll.variables.get(var_name) {
226                        cols.push(var.values.clone());
227                        col_names.push(var_name.clone());
228                    }
229                }
230                if !cols.is_empty() {
231                    Some((col_names, cols))
232                } else {
233                    None
234                }
235            } else {
236                None
237            };
238
239            results.push(StatsModelsNodeOutput {
240                node_name: self.name.clone(),
241                entities: collection.entities.clone(),
242                x_variables: x_vars.clone(),
243                contrasts,
244                design_matrix: dm,
245            });
246        }
247
248        results
249    }
250}
251
252/// Output produced by running a stats model node.
253///
254/// Contains the design matrix, contrasts, and entity metadata for a single
255/// group within a node's analysis. Multiple outputs may be produced per node
256/// when data is split by grouping variables.
257#[derive(Debug, Clone)]
258pub struct StatsModelsNodeOutput {
259    pub node_name: String,
260    pub entities: StringEntities,
261    pub x_variables: Vec<String>,
262    pub contrasts: Vec<ContrastInfo>,
263    /// Design matrix: (column_names, data_rows).
264    pub design_matrix: Option<(Vec<String>, Vec<Vec<f64>>)>,
265}
266
267impl StatsModelsNodeOutput {
268    /// Get the design matrix column names (X).
269    pub fn x_columns(&self) -> &[String] {
270        &self.x_variables
271    }
272}
273
274/// Build groups from a list of entity maps, grouping by specified keys.
275/// Returns map from group key to indices.
276pub fn build_groups(
277    entity_maps: &[StringEntities],
278    group_by: &[String],
279) -> std::collections::HashMap<Vec<(String, String)>, Vec<usize>> {
280    let mut groups: std::collections::HashMap<Vec<(String, String)>, Vec<usize>> =
281        std::collections::HashMap::new();
282
283    if group_by.is_empty() {
284        groups.insert(vec![], (0..entity_maps.len()).collect());
285        return groups;
286    }
287
288    // Get unique values for each grouping variable
289    let mut unique_vals: std::collections::HashMap<&str, Vec<String>> =
290        std::collections::HashMap::new();
291    for col in group_by {
292        let vals: std::collections::BTreeSet<String> = entity_maps
293            .iter()
294            .filter_map(|e| e.get(col.as_str()).cloned())
295            .collect();
296        unique_vals.insert(col.as_str(), vals.into_iter().collect());
297    }
298
299    for (i, ents) in entity_maps.iter().enumerate() {
300        let mut base: Vec<(String, String)> = Vec::new();
301        let mut missing: Vec<&str> = Vec::new();
302
303        for col in group_by {
304            if let Some(val) = ents.get(col.as_str()) {
305                base.push((col.clone(), val.clone()));
306            } else {
307                missing.push(col.as_str());
308            }
309        }
310
311        if missing.is_empty() {
312            base.sort();
313            groups.entry(base).or_default().push(i);
314        } else {
315            // Cartesian product of missing values
316            let mut combos = vec![base.clone()];
317            for col in &missing {
318                if let Some(vals) = unique_vals.get(col) {
319                    let mut new_combos = Vec::new();
320                    for combo in &combos {
321                        for val in vals {
322                            let mut c = combo.clone();
323                            c.push((col.to_string(), val.clone()));
324                            new_combos.push(c);
325                        }
326                    }
327                    combos = new_combos;
328                }
329            }
330            for mut combo in combos {
331                combo.sort();
332                groups.entry(combo).or_default().push(i);
333            }
334        }
335    }
336
337    groups
338}