1use crate::node::{ContrastInfo, StatsModelsEdge, StatsModelsNode, StatsModelsNodeOutput};
8use crate::transformations::TransformSpec;
9use bids_core::error::{BidsError, Result};
10use bids_core::utils::convert_json_keys;
11use std::collections::HashMap;
12
/// A parsed stats-model graph: the model's nodes, the edges connecting them,
/// and the bookkeeping needed to look nodes up by name.
#[derive(Debug)]
pub struct StatsModelsGraph {
    /// Model name taken from the spec's `name` key.
    pub name: String,
    /// Free-text description taken from the spec's `description` key.
    pub description: String,
    /// Nodes in the order they appear in the spec's `nodes` array.
    pub nodes: Vec<StatsModelsNode>,
    // Maps a node name to that node's position in `nodes`.
    node_map: HashMap<String, usize>,
    /// Directed edges between nodes, identified by node name.
    pub edges: Vec<StatsModelsEdge>,
    // Position in `nodes` of the node returned by `root_node` and used as the
    // starting point of `run`.
    root_idx: usize,
}
41
42impl StatsModelsGraph {
43 pub fn from_json(model_json: &serde_json::Value) -> Result<Self> {
45 let model = convert_json_keys(model_json);
46 Self::from_parsed(&model)
47 }
48
49 pub fn from_file(path: &std::path::Path) -> Result<Self> {
51 let contents = std::fs::read_to_string(path)?;
52 let json: serde_json::Value = serde_json::from_str(&contents)?;
53 Self::from_json(&json)
54 }
55
56 fn from_parsed(model: &serde_json::Value) -> Result<Self> {
57 let name = model
58 .get("name")
59 .and_then(|v| v.as_str())
60 .unwrap_or("unnamed")
61 .into();
62 let description = model
63 .get("description")
64 .and_then(|v| v.as_str())
65 .unwrap_or("")
66 .into();
67
68 let node_specs = model
70 .get("nodes")
71 .and_then(|v| v.as_array())
72 .ok_or_else(|| BidsError::Validation("Model must have 'nodes' array".into()))?;
73
74 let mut nodes = Vec::new();
75 let mut node_map = HashMap::new();
76
77 for spec in node_specs {
78 let level = spec.get("level").and_then(|v| v.as_str()).unwrap_or("run");
79 let node_name = spec
80 .get("name")
81 .and_then(|v| v.as_str())
82 .unwrap_or("unnamed");
83 let model_spec = spec
84 .get("model")
85 .cloned()
86 .unwrap_or(serde_json::Value::Null);
87 let group_by: Vec<String> = spec
88 .get("group_by")
89 .and_then(|v| v.as_array())
90 .map(|arr| {
91 arr.iter()
92 .filter_map(|v| v.as_str().map(String::from))
93 .collect()
94 })
95 .unwrap_or_default();
96
97 let transformations = spec
98 .get("transformations")
99 .and_then(|v| serde_json::from_value::<TransformSpec>(v.clone()).ok());
100
101 let contrasts: Vec<serde_json::Value> = spec
102 .get("contrasts")
103 .and_then(|v| v.as_array())
104 .cloned()
105 .unwrap_or_default();
106 let dummy_contrasts = spec
107 .get("dummy_contrasts")
108 .cloned()
109 .filter(|v| !v.is_null() && *v != serde_json::Value::Bool(false));
110
111 node_map.insert(node_name.to_string(), nodes.len());
112 nodes.push(StatsModelsNode::new(
113 level,
114 node_name,
115 model_spec,
116 group_by,
117 transformations,
118 contrasts,
119 dummy_contrasts,
120 ));
121 }
122
123 let mut edges = Vec::new();
125 if let Some(edge_specs) = model.get("edges").and_then(|v| v.as_array()) {
126 for edge_spec in edge_specs {
127 let src = edge_spec
128 .get("source")
129 .and_then(|v| v.as_str())
130 .unwrap_or("");
131 let dst = edge_spec
132 .get("destination")
133 .and_then(|v| v.as_str())
134 .unwrap_or("");
135 let filter: HashMap<String, String> = edge_spec
136 .get("filter")
137 .and_then(|v| serde_json::from_value(v.clone()).ok())
138 .unwrap_or_default();
139 edges.push(StatsModelsEdge {
140 source: src.into(),
141 destination: dst.into(),
142 filter,
143 });
144 }
145 }
146
147 if edges.is_empty() && nodes.len() > 1 {
149 for i in 0..nodes.len() - 1 {
150 edges.push(StatsModelsEdge {
151 source: nodes[i].name.clone(),
152 destination: nodes[i + 1].name.clone(),
153 filter: HashMap::new(),
154 });
155 }
156 }
157
158 for edge in &edges {
160 if let Some(&src_idx) = node_map.get(&edge.source) {
161 nodes[src_idx].add_child(edge.clone());
162 }
163 if let Some(&dst_idx) = node_map.get(&edge.destination) {
164 nodes[dst_idx].add_parent(edge.clone());
165 }
166 }
167
168 let root_idx = model
169 .get("root")
170 .and_then(|v| v.as_str())
171 .and_then(|name| node_map.get(name).copied())
172 .unwrap_or(0);
173
174 Ok(Self {
175 name,
176 description,
177 nodes,
178 node_map,
179 edges,
180 root_idx,
181 })
182 }
183
184 pub fn validate(&self) -> Result<()> {
186 let mut names = std::collections::HashSet::new();
188 for node in &self.nodes {
189 if !names.insert(&node.name) {
190 return Err(BidsError::Validation(format!(
191 "Duplicate node name: '{}'",
192 node.name
193 )));
194 }
195 }
196 for edge in &self.edges {
198 if !self.node_map.contains_key(&edge.source) {
199 return Err(BidsError::Validation(format!(
200 "Edge references unknown source: '{}'",
201 edge.source
202 )));
203 }
204 if !self.node_map.contains_key(&edge.destination) {
205 return Err(BidsError::Validation(format!(
206 "Edge references unknown destination: '{}'",
207 edge.destination
208 )));
209 }
210 }
211 Ok(())
212 }
213
214 pub fn get_node(&self, name: &str) -> Option<&StatsModelsNode> {
216 self.node_map.get(name).map(|&i| &self.nodes[i])
217 }
218
219 pub fn root_node(&self) -> &StatsModelsNode {
221 &self.nodes[self.root_idx]
222 }
223
224 pub fn load_collections(&mut self, layout: &bids_layout::BidsLayout) {
226 for node in &mut self.nodes {
228 if let Ok(index) = bids_variables::load_variables(layout, None, Some(&node.level)) {
229 let entities = bids_core::entities::StringEntities::new();
230 let run_colls = index.get_run_collections(&entities);
231 for rc in run_colls {
232 let vars: Vec<bids_variables::SimpleVariable> = rc
233 .sparse
234 .iter()
235 .map(|v| {
236 bids_variables::SimpleVariable::new(
237 &v.name,
238 &v.source,
239 v.str_amplitude.clone(),
240 v.index.clone(),
241 )
242 })
243 .collect();
244 if !vars.is_empty() {
245 node.add_collections(vec![bids_variables::VariableCollection::new(vars)]);
246 }
247 }
248 }
249 }
250 }
251
252 pub fn write_graph(&self) -> String {
254 let mut dot = format!("digraph \"{}\" {{\n node [shape=record];\n", self.name);
255 for node in &self.nodes {
256 dot.push_str(&format!(
257 " \"{}\" [label=\"{{name: {}|level: {}}}\"];\n",
258 node.name, node.name, node.level
259 ));
260 }
261 for edge in &self.edges {
262 dot.push_str(&format!(
263 " \"{}\" -> \"{}\";\n",
264 edge.source, edge.destination
265 ));
266 }
267 dot.push_str("}\n");
268 dot
269 }
270
271 pub fn render_graph(&self, output_path: &std::path::Path, format: &str) -> std::io::Result<()> {
273 let dot = self.write_graph();
274 let tmp = std::env::temp_dir().join("bids_model.dot");
275 std::fs::write(&tmp, &dot)?;
276 let status = std::process::Command::new("dot")
277 .arg(format!("-T{format}"))
278 .arg("-o")
279 .arg(output_path)
280 .arg(&tmp)
281 .status()?;
282 if !status.success() {
283 return Err(std::io::Error::other("dot command failed"));
284 }
285 Ok(())
286 }
287
288 pub fn run(&self) -> Vec<StatsModelsNodeOutput> {
290 let mut all_outputs = Vec::new();
291 self.run_node_recursive(self.root_idx, &[], &mut all_outputs);
292 all_outputs
293 }
294
295 fn run_node_recursive(
296 &self,
297 node_idx: usize,
298 inputs: &[ContrastInfo],
299 all_outputs: &mut Vec<StatsModelsNodeOutput>,
300 ) {
301 let node = &self.nodes[node_idx];
302 let outputs = node.run(inputs, true, "TR");
303 let contrasts: Vec<ContrastInfo> =
304 outputs.iter().flat_map(|o| o.contrasts.clone()).collect();
305 all_outputs.extend(outputs);
306
307 for edge in &node.children {
308 if let Some(&dst_idx) = self.node_map.get(&edge.destination) {
309 self.run_node_recursive(dst_idx, &contrasts, all_outputs);
310 }
311 }
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// A two-node model (run level, then subject level) that declares no
    /// explicit edges.
    fn two_level_model() -> serde_json::Value {
        serde_json::json!({
            "Name": "test_model",
            "Description": "A test",
            "BIDSModelVersion": "1.0.0",
            "Nodes": [
                {
                    "Level": "Run",
                    "Name": "run",
                    "GroupBy": ["run", "subject"],
                    "Model": {"Type": "glm", "X": ["trial_type.face"]},
                    "DummyContrasts": {"Test": "t"}
                },
                {
                    "Level": "Subject",
                    "Name": "subject",
                    "GroupBy": ["subject", "contrast"],
                    "Model": {"Type": "glm", "X": [1]},
                    "DummyContrasts": {"Test": "t"}
                }
            ]
        })
    }

    /// Parsing keeps the model name and both nodes, links them with a single
    /// implicit edge, and the result passes validation.
    #[test]
    fn test_parse_model() {
        let document = two_level_model();
        let parsed = StatsModelsGraph::from_json(&document).unwrap();

        assert_eq!(parsed.name, "test_model");
        assert_eq!(parsed.nodes.len(), 2);
        assert_eq!(parsed.edges.len(), 1);
        parsed.validate().unwrap();
    }
}