Skip to main content

bids_variables/
node.rs

1//! BIDS hierarchy nodes: run-level and non-run (subject/session/dataset).
2//!
3//! Nodes represent levels of the BIDS hierarchy that hold variables. Run-level
4//! nodes carry sparse and dense time-series variables; non-run nodes carry
5//! simple demographic/session variables.
6
7use bids_core::entities::StringEntities;
8
9use crate::collections::RunVariableCollection;
10use crate::variables::{DenseRunVariable, SimpleVariable, SparseRunVariable};
11
12/// Metadata about a single run in a BIDS dataset.
13///
14/// Contains the run's BIDS entities, temporal parameters (duration, TR),
15/// the path to the associated image file, and the number of volumes.
16/// Used by variable types to track which runs their data belongs to.
17#[derive(Debug, Clone)]
18pub struct RunInfo {
19    pub entities: StringEntities,
20    pub duration: f64,
21    pub tr: f64,
22    pub image: Option<String>,
23    pub n_vols: usize,
24}
25
26/// A non-run node in the BIDS hierarchy (dataset, subject, or session level).
27///
28/// Holds simple variables (e.g., participant demographics, session metadata)
29/// at a specific level of the BIDS hierarchy.
30pub struct Node {
31    pub level: String,
32    pub entities: StringEntities,
33    pub variables: Vec<SimpleVariable>,
34}
35
36impl Node {
37    pub fn new(level: &str, entities: StringEntities) -> Self {
38        Self {
39            level: level.to_lowercase(),
40            entities,
41            variables: Vec::new(),
42        }
43    }
44
45    pub fn add_variable(&mut self, var: SimpleVariable) {
46        self.variables.push(var);
47    }
48}
49
50/// A run-level node with timing information and both sparse and dense variables.
51///
52/// Represents a single functional run with its temporal parameters (duration,
53/// TR, number of volumes) and holds both event-based sparse variables and
54/// continuous dense variables for that run.
55pub struct RunNode {
56    pub level: String,
57    pub entities: StringEntities,
58    pub image_file: Option<String>,
59    pub duration: f64,
60    pub repetition_time: f64,
61    pub n_vols: usize,
62    pub sparse_variables: Vec<SparseRunVariable>,
63    pub dense_variables: Vec<DenseRunVariable>,
64}
65
66impl RunNode {
67    pub fn new(
68        entities: StringEntities,
69        image_file: Option<String>,
70        duration: f64,
71        repetition_time: f64,
72        n_vols: usize,
73    ) -> Self {
74        Self {
75            level: "run".into(),
76            entities,
77            image_file,
78            duration,
79            repetition_time,
80            n_vols,
81            sparse_variables: Vec::new(),
82            dense_variables: Vec::new(),
83        }
84    }
85
86    pub fn get_info(&self) -> RunInfo {
87        RunInfo {
88            entities: self.entities.clone(),
89            duration: self.duration,
90            tr: self.repetition_time,
91            image: self.image_file.clone(),
92            n_vols: self.n_vols,
93        }
94    }
95
96    pub fn add_sparse_variable(&mut self, var: SparseRunVariable) {
97        self.sparse_variables.push(var);
98    }
99
100    pub fn add_dense_variable(&mut self, var: DenseRunVariable) {
101        self.dense_variables.push(var);
102    }
103}
104
105/// Top-level index organizing all variable nodes in a BIDS dataset.
106///
107/// The `NodeIndex` maintains a flat list of nodes (both run-level and
108/// higher-level) and provides methods to find, create, and query nodes
109/// by level and entity values. Nodes are created during variable loading
110/// and can be queried to extract variable collections for statistical
111/// modeling.
112#[derive(Default)]
113pub struct NodeIndex {
114    nodes: Vec<NodeEntry>,
115}
116
117enum NodeEntry {
118    Run(RunNode),
119    Other(Node),
120}
121
122impl NodeIndex {
123    pub fn new() -> Self {
124        Self::default()
125    }
126
127    pub fn create_run_node(
128        &mut self,
129        entities: StringEntities,
130        image_file: Option<String>,
131        duration: f64,
132        tr: f64,
133        n_vols: usize,
134    ) -> usize {
135        self.nodes.push(NodeEntry::Run(RunNode::new(
136            entities, image_file, duration, tr, n_vols,
137        )));
138        self.nodes.len() - 1
139    }
140
141    pub fn create_node(&mut self, level: &str, entities: StringEntities) -> usize {
142        self.nodes
143            .push(NodeEntry::Other(Node::new(level, entities)));
144        self.nodes.len() - 1
145    }
146
147    pub fn get_run_node_mut(&mut self, index: usize) -> Option<&mut RunNode> {
148        match self.nodes.get_mut(index) {
149            Some(NodeEntry::Run(n)) => Some(n),
150            _ => None,
151        }
152    }
153
154    pub fn get_node_mut(&mut self, index: usize) -> Option<&mut Node> {
155        match self.nodes.get_mut(index) {
156            Some(NodeEntry::Other(n)) => Some(n),
157            _ => None,
158        }
159    }
160
161    /// Find nodes matching level and entities, sorted by subject/session/task/run.
162    pub fn find_nodes(&self, level: &str, entities: &StringEntities) -> Vec<usize> {
163        let sort_keys = ["subject", "session", "task", "run"];
164        let mut results: Vec<(usize, Vec<String>)> = self
165            .nodes
166            .iter()
167            .enumerate()
168            .filter(|(_, entry)| {
169                let (node_level, node_ents) = match entry {
170                    NodeEntry::Run(r) => (r.level.as_str(), &r.entities),
171                    NodeEntry::Other(n) => (n.level.as_str(), &n.entities),
172                };
173                node_level == level
174                    && entities
175                        .iter()
176                        .all(|(k, v)| node_ents.get(k).is_none_or(|nv| nv == v))
177            })
178            .map(|(i, entry)| {
179                let ents = match entry {
180                    NodeEntry::Run(r) => &r.entities,
181                    NodeEntry::Other(n) => &n.entities,
182                };
183                let key: Vec<String> = sort_keys
184                    .iter()
185                    .map(|k| ents.get(*k).cloned().unwrap_or_default())
186                    .collect();
187                (i, key)
188            })
189            .collect();
190        results.sort_by(|(_, a), (_, b)| a.cmp(b));
191        results.into_iter().map(|(i, _)| i).collect()
192    }
193
194    /// Find or create a node.
195    pub fn get_or_create_node(&mut self, level: &str, entities: StringEntities) -> usize {
196        let existing = self.find_nodes(level, &entities);
197        if let Some(&idx) = existing.first() {
198            idx
199        } else {
200            self.create_node(level, entities)
201        }
202    }
203
204    /// Find or create a run node.
205    pub fn get_or_create_run_node(
206        &mut self,
207        entities: StringEntities,
208        image_file: Option<String>,
209        duration: f64,
210        tr: f64,
211        n_vols: usize,
212    ) -> usize {
213        let existing = self.find_nodes("run", &entities);
214        if let Some(&idx) = existing.first() {
215            idx
216        } else {
217            self.create_run_node(entities, image_file, duration, tr, n_vols)
218        }
219    }
220
221    /// Collect run-level variables into collections.
222    pub fn get_run_collections(&self, entities: &StringEntities) -> Vec<RunVariableCollection> {
223        let indices = self.find_nodes("run", entities);
224        indices
225            .iter()
226            .filter_map(|&idx| {
227                if let NodeEntry::Run(rn) = &self.nodes[idx] {
228                    if rn.sparse_variables.is_empty() && rn.dense_variables.is_empty() {
229                        return None;
230                    }
231                    Some(RunVariableCollection::new(
232                        rn.sparse_variables.clone(),
233                        rn.dense_variables.clone(),
234                        None,
235                    ))
236                } else {
237                    None
238                }
239            })
240            .collect()
241    }
242}