Skip to main content

bids_modeling/
graph.rs

//! BIDS-StatsModels directed acyclic graph.
//!
//! [`StatsModelsGraph`] is the top-level type for a BIDS-StatsModels
//! specification. It contains analysis nodes (run/session/subject/dataset
//! levels) connected by edges defining how data flows through the pipeline.
6
7use crate::node::{ContrastInfo, StatsModelsEdge, StatsModelsNode, StatsModelsNodeOutput};
8use crate::transformations::TransformSpec;
9use bids_core::error::{BidsError, Result};
10use bids_core::utils::convert_json_keys;
11use std::collections::HashMap;
12
13/// Rooted directed acyclic graph representing a BIDS-StatsModel specification.
14///
15/// A `StatsModelsGraph` is loaded from a JSON file conforming to the
16/// [BIDS-StatsModels](https://bids-standard.github.io/stats-models/)
17/// specification. It contains analysis nodes at different levels of the
18/// BIDS hierarchy (run, session, subject, dataset) connected by edges
19/// that define how contrasts and data flow between levels.
20///
21/// # Lifecycle
22///
23/// 1. **Load** — Parse from JSON file or value via [`from_file()`](Self::from_file)
24///    or [`from_json()`](Self::from_json)
25/// 2. **Validate** — Check structure via [`validate()`](Self::validate)
26/// 3. **Load data** — Populate nodes with variables from a layout via
27///    [`load_collections()`](Self::load_collections)
28/// 4. **Execute** — Run the analysis pipeline via [`run()`](Self::run)
29/// 5. **Export** — Generate DOT graph via [`write_graph()`](Self::write_graph)
30///
31/// Corresponds to PyBIDS' `StatsModelsGraph` class.
32#[derive(Debug)]
33pub struct StatsModelsGraph {
34    pub name: String,
35    pub description: String,
36    pub nodes: Vec<StatsModelsNode>,
37    node_map: HashMap<String, usize>,
38    pub edges: Vec<StatsModelsEdge>,
39    root_idx: usize,
40}
41
42impl StatsModelsGraph {
43    /// Load from a JSON model spec (path or parsed value).
44    pub fn from_json(model_json: &serde_json::Value) -> Result<Self> {
45        let model = convert_json_keys(model_json);
46        Self::from_parsed(&model)
47    }
48
49    /// Load from a file path.
50    pub fn from_file(path: &std::path::Path) -> Result<Self> {
51        let contents = std::fs::read_to_string(path)?;
52        let json: serde_json::Value = serde_json::from_str(&contents)?;
53        Self::from_json(&json)
54    }
55
56    fn from_parsed(model: &serde_json::Value) -> Result<Self> {
57        let name = model
58            .get("name")
59            .and_then(|v| v.as_str())
60            .unwrap_or("unnamed")
61            .into();
62        let description = model
63            .get("description")
64            .and_then(|v| v.as_str())
65            .unwrap_or("")
66            .into();
67
68        // Load nodes
69        let node_specs = model
70            .get("nodes")
71            .and_then(|v| v.as_array())
72            .ok_or_else(|| BidsError::Validation("Model must have 'nodes' array".into()))?;
73
74        let mut nodes = Vec::new();
75        let mut node_map = HashMap::new();
76
77        for spec in node_specs {
78            let level = spec.get("level").and_then(|v| v.as_str()).unwrap_or("run");
79            let node_name = spec
80                .get("name")
81                .and_then(|v| v.as_str())
82                .unwrap_or("unnamed");
83            let model_spec = spec
84                .get("model")
85                .cloned()
86                .unwrap_or(serde_json::Value::Null);
87            let group_by: Vec<String> = spec
88                .get("group_by")
89                .and_then(|v| v.as_array())
90                .map(|arr| {
91                    arr.iter()
92                        .filter_map(|v| v.as_str().map(String::from))
93                        .collect()
94                })
95                .unwrap_or_default();
96
97            let transformations = spec
98                .get("transformations")
99                .and_then(|v| serde_json::from_value::<TransformSpec>(v.clone()).ok());
100
101            let contrasts: Vec<serde_json::Value> = spec
102                .get("contrasts")
103                .and_then(|v| v.as_array())
104                .cloned()
105                .unwrap_or_default();
106            let dummy_contrasts = spec
107                .get("dummy_contrasts")
108                .cloned()
109                .filter(|v| !v.is_null() && *v != serde_json::Value::Bool(false));
110
111            node_map.insert(node_name.to_string(), nodes.len());
112            nodes.push(StatsModelsNode::new(
113                level,
114                node_name,
115                model_spec,
116                group_by,
117                transformations,
118                contrasts,
119                dummy_contrasts,
120            ));
121        }
122
123        // Load edges
124        let mut edges = Vec::new();
125        if let Some(edge_specs) = model.get("edges").and_then(|v| v.as_array()) {
126            for edge_spec in edge_specs {
127                let src = edge_spec
128                    .get("source")
129                    .and_then(|v| v.as_str())
130                    .unwrap_or("");
131                let dst = edge_spec
132                    .get("destination")
133                    .and_then(|v| v.as_str())
134                    .unwrap_or("");
135                let filter: HashMap<String, String> = edge_spec
136                    .get("filter")
137                    .and_then(|v| serde_json::from_value(v.clone()).ok())
138                    .unwrap_or_default();
139                edges.push(StatsModelsEdge {
140                    source: src.into(),
141                    destination: dst.into(),
142                    filter,
143                });
144            }
145        }
146
147        // If no edges, create implicit pipeline order
148        if edges.is_empty() && nodes.len() > 1 {
149            for i in 0..nodes.len() - 1 {
150                edges.push(StatsModelsEdge {
151                    source: nodes[i].name.clone(),
152                    destination: nodes[i + 1].name.clone(),
153                    filter: HashMap::new(),
154                });
155            }
156        }
157
158        // Wire edges to nodes
159        for edge in &edges {
160            if let Some(&src_idx) = node_map.get(&edge.source) {
161                nodes[src_idx].add_child(edge.clone());
162            }
163            if let Some(&dst_idx) = node_map.get(&edge.destination) {
164                nodes[dst_idx].add_parent(edge.clone());
165            }
166        }
167
168        let root_idx = model
169            .get("root")
170            .and_then(|v| v.as_str())
171            .and_then(|name| node_map.get(name).copied())
172            .unwrap_or(0);
173
174        Ok(Self {
175            name,
176            description,
177            nodes,
178            node_map,
179            edges,
180            root_idx,
181        })
182    }
183
184    /// Validate the model structure.
185    pub fn validate(&self) -> Result<()> {
186        // Check unique names
187        let mut names = std::collections::HashSet::new();
188        for node in &self.nodes {
189            if !names.insert(&node.name) {
190                return Err(BidsError::Validation(format!(
191                    "Duplicate node name: '{}'",
192                    node.name
193                )));
194            }
195        }
196        // Check edge references
197        for edge in &self.edges {
198            if !self.node_map.contains_key(&edge.source) {
199                return Err(BidsError::Validation(format!(
200                    "Edge references unknown source: '{}'",
201                    edge.source
202                )));
203            }
204            if !self.node_map.contains_key(&edge.destination) {
205                return Err(BidsError::Validation(format!(
206                    "Edge references unknown destination: '{}'",
207                    edge.destination
208                )));
209            }
210        }
211        Ok(())
212    }
213
214    /// Get a node by name.
215    pub fn get_node(&self, name: &str) -> Option<&StatsModelsNode> {
216        self.node_map.get(name).map(|&i| &self.nodes[i])
217    }
218
219    /// Get the root node.
220    pub fn root_node(&self) -> &StatsModelsNode {
221        &self.nodes[self.root_idx]
222    }
223
224    /// Load collections from a layout into all nodes.
225    pub fn load_collections(&mut self, layout: &bids_layout::BidsLayout) {
226        // Use bids_variables to load, then convert to VariableCollections
227        for node in &mut self.nodes {
228            if let Ok(index) = bids_variables::load_variables(layout, None, Some(&node.level)) {
229                let entities = bids_core::entities::StringEntities::new();
230                let run_colls = index.get_run_collections(&entities);
231                for rc in run_colls {
232                    let vars: Vec<bids_variables::SimpleVariable> = rc
233                        .sparse
234                        .iter()
235                        .map(|v| {
236                            bids_variables::SimpleVariable::new(
237                                &v.name,
238                                &v.source,
239                                v.str_amplitude.clone(),
240                                v.index.clone(),
241                            )
242                        })
243                        .collect();
244                    if !vars.is_empty() {
245                        node.add_collections(vec![bids_variables::VariableCollection::new(vars)]);
246                    }
247                }
248            }
249        }
250    }
251
252    /// Write graph structure as a DOT file (text-based graphviz).
253    pub fn write_graph(&self) -> String {
254        let mut dot = format!("digraph \"{}\" {{\n  node [shape=record];\n", self.name);
255        for node in &self.nodes {
256            dot.push_str(&format!(
257                "  \"{}\" [label=\"{{name: {}|level: {}}}\"];\n",
258                node.name, node.name, node.level
259            ));
260        }
261        for edge in &self.edges {
262            dot.push_str(&format!(
263                "  \"{}\" -> \"{}\";\n",
264                edge.source, edge.destination
265            ));
266        }
267        dot.push_str("}\n");
268        dot
269    }
270
271    /// Render graph to a file using the `dot` command (requires graphviz installed).
272    pub fn render_graph(&self, output_path: &std::path::Path, format: &str) -> std::io::Result<()> {
273        let dot = self.write_graph();
274        let tmp = std::env::temp_dir().join("bids_model.dot");
275        std::fs::write(&tmp, &dot)?;
276        let status = std::process::Command::new("dot")
277            .arg(format!("-T{format}"))
278            .arg("-o")
279            .arg(output_path)
280            .arg(&tmp)
281            .status()?;
282        if !status.success() {
283            return Err(std::io::Error::other("dot command failed"));
284        }
285        Ok(())
286    }
287
288    /// Run the entire graph recursively.
289    pub fn run(&self) -> Vec<StatsModelsNodeOutput> {
290        let mut all_outputs = Vec::new();
291        self.run_node_recursive(self.root_idx, &[], &mut all_outputs);
292        all_outputs
293    }
294
295    fn run_node_recursive(
296        &self,
297        node_idx: usize,
298        inputs: &[ContrastInfo],
299        all_outputs: &mut Vec<StatsModelsNodeOutput>,
300    ) {
301        let node = &self.nodes[node_idx];
302        let outputs = node.run(inputs, true, "TR");
303        let contrasts: Vec<ContrastInfo> =
304            outputs.iter().flat_map(|o| o.contrasts.clone()).collect();
305        all_outputs.extend(outputs);
306
307        for edge in &node.children {
308            if let Some(&dst_idx) = self.node_map.get(&edge.destination) {
309                self.run_node_recursive(dst_idx, &contrasts, all_outputs);
310            }
311        }
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// A two-node model without explicit edges parses, receives one implicit
    /// edge between the nodes, and passes validation.
    #[test]
    fn test_parse_model() {
        let run_node = serde_json::json!({
            "Level": "Run",
            "Name": "run",
            "GroupBy": ["run", "subject"],
            "Model": {"Type": "glm", "X": ["trial_type.face"]},
            "DummyContrasts": {"Test": "t"}
        });
        let subject_node = serde_json::json!({
            "Level": "Subject",
            "Name": "subject",
            "GroupBy": ["subject", "contrast"],
            "Model": {"Type": "glm", "X": [1]},
            "DummyContrasts": {"Test": "t"}
        });
        let spec = serde_json::json!({
            "Name": "test_model",
            "Description": "A test",
            "BIDSModelVersion": "1.0.0",
            "Nodes": [run_node, subject_node]
        });

        let parsed = StatsModelsGraph::from_json(&spec).unwrap();

        assert_eq!(parsed.name, "test_model");
        assert_eq!(parsed.nodes.len(), 2);
        // No "Edges" key: the two nodes are chained implicitly.
        assert_eq!(parsed.edges.len(), 1);
        parsed.validate().unwrap();
    }
}
349}