Skip to main content

bids_variables/
collections.rs

1//! Variable collections for grouping and merging BIDS variables.
2//!
3//! Collections group multiple variables at the same level of the BIDS hierarchy,
4//! providing operations like variable lookup, wildcard matching, filtering,
5//! and merging across runs or subjects.
6
7use crate::variables::{DenseRunVariable, SimpleVariable, SparseRunVariable, merge_simple};
8use bids_core::entities::StringEntities;
9use std::collections::HashMap;
10
/// Map a variable's source file type to the BIDS hierarchy level it lives at.
///
/// `events`, `physio`, `stim` and `regressors` map to `"run"`, `scans` to
/// `"session"`, `sessions` to `"subject"` and `participants` to `"dataset"`.
/// Any other source string is passed through unchanged.
fn source_to_level(source: &str) -> &str {
    const LEVELS: [(&str, &str); 7] = [
        ("events", "run"),
        ("physio", "run"),
        ("stim", "run"),
        ("regressors", "run"),
        ("scans", "session"),
        ("sessions", "subject"),
        ("participants", "dataset"),
    ];
    LEVELS
        .iter()
        .find(|(src, _)| *src == source)
        .map_or(source, |&(_, level)| level)
}
20
21/// A collection of simple BIDS variables at a single level of the hierarchy.
22///
23/// Groups [`SimpleVariable`]s by name, automatically merging variables with
24/// the same name and inferring the hierarchy level from the variable source
25/// (events→run, participants→dataset, etc.).
26///
27/// Supports variable lookup by name, wildcard pattern matching, entity-based
28/// filtering, and subset selection.
29///
30/// Corresponds to PyBIDS' `BIDSVariableCollection` class.
31#[derive(Debug, Default, Clone)]
32pub struct VariableCollection {
33    pub variables: HashMap<String, SimpleVariable>,
34    pub level: String,
35    pub entities: StringEntities,
36}
37
38impl VariableCollection {
39    pub fn new(vars: Vec<SimpleVariable>) -> Self {
40        let level = vars
41            .first()
42            .map(|v| source_to_level(&v.source).to_string())
43            .unwrap_or_default();
44
45        // Merge variables with same name
46        let mut by_name: HashMap<String, Vec<&SimpleVariable>> = HashMap::new();
47        for v in &vars {
48            by_name.entry(v.name.clone()).or_default().push(v);
49        }
50        let mut variables = HashMap::new();
51        for (name, var_list) in &by_name {
52            if let Some(m) = merge_simple(var_list) {
53                variables.insert(name.clone(), m);
54            }
55        }
56
57        let entities = index_common_entities(&variables);
58        Self {
59            variables,
60            level,
61            entities,
62        }
63    }
64
65    pub fn get(&self, name: &str) -> Option<&SimpleVariable> {
66        self.variables.get(name)
67    }
68
69    pub fn names(&self) -> Vec<&str> {
70        self.variables
71            .keys()
72            .map(std::string::String::as_str)
73            .collect()
74    }
75
76    /// Match variable names against a pattern (regex or glob).
77    pub fn match_variables(&self, pattern: &str, use_regex: bool) -> Vec<&str> {
78        if use_regex {
79            let re =
80                regex::Regex::new(pattern).unwrap_or_else(|_| regex::Regex::new("$^").unwrap());
81            self.variables
82                .keys()
83                .filter(|k| re.is_match(k))
84                .map(std::string::String::as_str)
85                .collect()
86        } else {
87            self.variables
88                .keys()
89                .filter(|k| glob_match(pattern, k))
90                .map(std::string::String::as_str)
91                .collect()
92        }
93    }
94
95    /// Convert all variables to tabular rows (long format).
96    pub fn to_rows(&self) -> Vec<StringEntities> {
97        self.variables
98            .values()
99            .flat_map(super::variables::SimpleVariable::to_rows)
100            .collect()
101    }
102
103    /// Convert to wide format: each variable becomes a column.
104    /// Returns `(column_names, rows)` where each row is a `Vec<String>`.
105    pub fn to_wide(&self) -> (Vec<String>, Vec<Vec<String>>) {
106        // Collect all entity columns
107        let mut entity_cols: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
108        for var in self.variables.values() {
109            for row in &var.index {
110                entity_cols.extend(row.keys().cloned());
111            }
112        }
113        let var_names: Vec<String> = self.variables.keys().cloned().collect();
114        let mut col_names: Vec<String> = entity_cols.into_iter().collect();
115        col_names.extend(var_names.iter().cloned());
116
117        // Build rows grouped by entity combination
118        let n_rows = self
119            .variables
120            .values()
121            .next()
122            .map_or(0, super::variables::SimpleVariable::len);
123        let mut rows = Vec::with_capacity(n_rows);
124        for i in 0..n_rows {
125            let mut row = Vec::with_capacity(col_names.len());
126            // Entity columns
127            let first_var = self.variables.values().next();
128            let ent_row = first_var.and_then(|v| v.index.get(i));
129            for col in &col_names {
130                if self.variables.contains_key(col) {
131                    // Variable column
132                    let val = self
133                        .variables
134                        .get(col)
135                        .and_then(|v| v.str_values.get(i))
136                        .cloned()
137                        .unwrap_or_default();
138                    row.push(val);
139                } else {
140                    // Entity column
141                    let val = ent_row
142                        .and_then(|r| r.get(col))
143                        .cloned()
144                        .unwrap_or_default();
145                    row.push(val);
146                }
147            }
148            rows.push(row);
149        }
150        (col_names, rows)
151    }
152
153    /// Create collection from tabular rows.
154    pub fn from_rows(rows: &[StringEntities], source: &str) -> Self {
155        let mut by_name: HashMap<String, (Vec<String>, Vec<StringEntities>)> = HashMap::new();
156        for row in rows {
157            let name = row
158                .get("condition")
159                .cloned()
160                .unwrap_or_else(|| "unknown".into());
161            let amp = row.get("amplitude").cloned().unwrap_or_default();
162            let ents: StringEntities = row
163                .iter()
164                .filter(|(k, _)| k.as_str() != "condition" && k.as_str() != "amplitude")
165                .map(|(k, v)| (k.clone(), v.clone()))
166                .collect();
167            let entry = by_name.entry(name).or_default();
168            entry.0.push(amp);
169            entry.1.push(ents);
170        }
171        let vars: Vec<SimpleVariable> = by_name
172            .into_iter()
173            .map(|(name, (values, index))| SimpleVariable::new(&name, source, values, index))
174            .collect();
175        Self::new(vars)
176    }
177}
178
179/// A collection of run-level BIDS variables.
180#[derive(Debug, Default, Clone)]
181pub struct RunVariableCollection {
182    pub sparse: Vec<SparseRunVariable>,
183    pub dense: Vec<DenseRunVariable>,
184    pub sampling_rate: f64,
185}
186
187impl RunVariableCollection {
188    pub fn new(
189        sparse: Vec<SparseRunVariable>,
190        dense: Vec<DenseRunVariable>,
191        sampling_rate: Option<f64>,
192    ) -> Self {
193        Self {
194            sparse,
195            dense,
196            sampling_rate: sampling_rate.unwrap_or(10.0),
197        }
198    }
199
200    pub fn get_sparse(&self, name: &str) -> Option<&SparseRunVariable> {
201        self.sparse.iter().find(|v| v.name == name)
202    }
203    pub fn get_dense(&self, name: &str) -> Option<&DenseRunVariable> {
204        self.dense.iter().find(|v| v.name == name)
205    }
206
207    pub fn sparse_names(&self) -> Vec<&str> {
208        self.sparse.iter().map(|v| v.name.as_str()).collect()
209    }
210    pub fn dense_names(&self) -> Vec<&str> {
211        self.dense.iter().map(|v| v.name.as_str()).collect()
212    }
213    pub fn all_sparse(&self) -> bool {
214        self.dense.is_empty()
215    }
216    pub fn all_dense(&self) -> bool {
217        self.sparse.is_empty()
218    }
219
220    pub fn names(&self) -> Vec<&str> {
221        let mut n: Vec<&str> = self
222            .sparse_names()
223            .into_iter()
224            .chain(self.dense_names())
225            .collect();
226        n.sort();
227        n.dedup();
228        n
229    }
230
231    /// Convert all sparse variables to dense.
232    pub fn to_dense(&mut self, sampling_rate: Option<f64>) {
233        let sr = sampling_rate.unwrap_or(self.sampling_rate);
234        let sparse = std::mem::take(&mut self.sparse);
235        for var in sparse {
236            if var.amplitude.iter().all(|v| v.is_finite()) {
237                self.dense.push(var.to_dense(Some(sr)));
238            }
239        }
240        self.sampling_rate = sr;
241    }
242
243    /// Resample all dense variables.
244    pub fn resample(&mut self, sampling_rate: f64) {
245        self.dense = self
246            .dense
247            .iter()
248            .map(|v| v.resample(sampling_rate))
249            .collect();
250        self.sampling_rate = sampling_rate;
251    }
252
253    /// Combined densify + resample pipeline.
254    pub fn densify_and_resample(
255        &mut self,
256        sampling_rate: Option<f64>,
257        force_dense: bool,
258        resample_dense: bool,
259    ) {
260        let sr = sampling_rate.unwrap_or(self.sampling_rate);
261        if force_dense {
262            self.to_dense(Some(sr));
263        }
264        if resample_dense {
265            self.resample(sr);
266        }
267        self.sampling_rate = sr;
268    }
269
270    /// Convert to rows with sampling rate option.
271    pub fn to_rows_with_options(
272        &self,
273        include_sparse: bool,
274        include_dense: bool,
275        sampling_rate: Option<f64>,
276    ) -> Vec<StringEntities> {
277        let mut all_rows = Vec::new();
278        if include_sparse {
279            for var in &self.sparse {
280                all_rows.extend(var.to_rows());
281            }
282        }
283        if include_dense {
284            let dense_vars: Vec<_> = if let Some(sr) = sampling_rate {
285                self.dense.iter().map(|v| v.resample(sr)).collect()
286            } else {
287                self.dense.clone()
288            };
289            for var in &dense_vars {
290                all_rows.extend(var.to_rows());
291            }
292        }
293        all_rows
294    }
295
296    /// Resolve sampling rate from string ('TR', 'highest', or numeric).
297    pub fn resolve_sampling_rate(&self, requested: Option<&str>) -> f64 {
298        match requested {
299            Some("TR") => {
300                let trs: std::collections::HashSet<i64> = self
301                    .dense
302                    .iter()
303                    .flat_map(|v| v.run_info.iter())
304                    .map(|r| (r.tr * 1_000_000.0).round() as i64)
305                    .collect();
306                if trs.len() == 1 {
307                    1.0 / (trs.into_iter().next().unwrap() as f64 / 1_000_000.0)
308                } else {
309                    self.sampling_rate
310                }
311            }
312            Some("highest") => self
313                .dense
314                .iter()
315                .map(|v| v.sampling_rate)
316                .fold(self.sampling_rate, f64::max),
317            Some(s) => s.parse().unwrap_or(self.sampling_rate),
318            None => self.sampling_rate,
319        }
320    }
321}
322
323/// Merge multiple variable collections.
324pub fn merge_collections(collections: &[VariableCollection]) -> Option<VariableCollection> {
325    if collections.is_empty() {
326        return None;
327    }
328    let all_vars: Vec<SimpleVariable> = collections
329        .iter()
330        .flat_map(|c| c.variables.values().cloned())
331        .collect();
332    Some(VariableCollection::new(all_vars))
333}
334
335// ─────────── Helpers ───────────
336
337fn index_common_entities(variables: &HashMap<String, SimpleVariable>) -> StringEntities {
338    let all_ents: Vec<&StringEntities> = variables.values().map(|v| &v.entities).collect();
339    let mut common = StringEntities::new();
340    if let Some(first) = all_ents.first() {
341        for (k, v) in *first {
342            if all_ents.iter().all(|e| e.get(k) == Some(v)) {
343                common.insert(k.clone(), v.clone());
344            }
345        }
346    }
347    common
348}
349
/// Match `text` against a Unix shell-style glob `pattern` (whole-string match).
///
/// Supported syntax, following Python's `fnmatch`:
/// - `*` matches any run of characters (including none);
/// - `?` matches exactly one character;
/// - `[seq]` matches one character in `seq`, and `[!seq]` one character not
///   in it. `a-z` inside brackets is an inclusive range, and a `]` right after
///   `[` or `[!` is a literal. An unterminated `[` is matched literally.
///
/// Every other character is matched literally, including regex
/// metacharacters such as `.`, `+`, `(` or `|`. Matching is case-sensitive
/// and operates on `char`s.
fn glob_match(pattern: &str, text: &str) -> bool {
    let pat: Vec<char> = pattern.chars().collect();
    let txt: Vec<char> = text.chars().collect();
    let (mut pi, mut ti) = (0, 0);
    // Most recent `*`: (pattern index just past it, text index it has absorbed
    // up to). Only the latest star ever needs revisiting. This is the classic
    // greedy wildcard algorithm, so no regex is compiled per call.
    let mut star: Option<(usize, usize)> = None;

    while ti < txt.len() {
        if pat.get(pi) == Some(&'*') {
            pi += 1;
            star = Some((pi, ti));
        } else if let Some(next) = glob_match_one(&pat, pi, txt[ti]) {
            pi = next;
            ti += 1;
        } else if let Some((star_pi, star_ti)) = star {
            // Mismatch: let the last `*` swallow one more character and retry.
            pi = star_pi;
            ti = star_ti + 1;
            star = Some((star_pi, ti));
        } else {
            return false;
        }
    }
    // Text exhausted: only trailing `*`s may remain in the pattern.
    pat[pi..].iter().all(|&c| c == '*')
}

/// If the single-character token at `pat[pi]` matches `c`, return the index
/// just past that token. Returns `None` on a mismatch or when `pi` is past the
/// end. `glob_match` handles `*` itself before calling this.
fn glob_match_one(pat: &[char], pi: usize, c: char) -> Option<usize> {
    match *pat.get(pi)? {
        '?' => Some(pi + 1),
        '[' => match glob_bracket(pat, pi, c) {
            Some((hit, end)) => hit.then_some(end),
            // Unterminated class: the `[` is an ordinary character.
            None => (c == '[').then_some(pi + 1),
        },
        lit => (lit == c).then_some(pi + 1),
    }
}

/// Evaluate the bracket expression starting at `pat[start]` (a `[`) against `c`.
/// Returns `Some((matched, end))` with `end` just past the closing `]`, or
/// `None` if the expression is never closed.
fn glob_bracket(pat: &[char], start: usize, c: char) -> Option<(bool, usize)> {
    let mut i = start + 1;
    let negate = pat.get(i) == Some(&'!');
    if negate {
        i += 1;
    }
    let body_start = i;
    let mut hit = false;
    loop {
        let lo = *pat.get(i)?;
        // A `]` closes the set unless it is the set's first character.
        if lo == ']' && i > body_start {
            break;
        }
        match (pat.get(i + 1), pat.get(i + 2)) {
            // `lo-hi` range. A `-` directly before the closing `]` is literal.
            (Some(&'-'), Some(&hi)) if hi != ']' => {
                hit |= (lo..=hi).contains(&c);
                i += 3;
            }
            _ => {
                hit |= lo == c;
                i += 1;
            }
        }
    }
    Some((hit != negate, i + 1))
}
360
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build a `participants`-sourced variable whose i-th value belongs to
    /// the i-th subject id.
    fn participants_var(name: &str, values: &[&str], subjects: &[&str]) -> SimpleVariable {
        let index: Vec<StringEntities> = subjects
            .iter()
            .map(|id| HashMap::from([("subject".to_string(), id.to_string())]))
            .collect();
        SimpleVariable::new(
            name,
            "participants",
            values.iter().map(|v| v.to_string()).collect(),
            index,
        )
    }

    #[test]
    fn test_variable_collection() {
        let age = participants_var("age", &["25", "30"], &["01", "02"]);
        let collection = VariableCollection::new(vec![age]);
        assert_eq!(collection.level, "dataset");
        assert!(collection.get("age").is_some());
    }

    #[test]
    fn test_match_variables() {
        let collection = VariableCollection::new(
            ["age", "sex", "age_group"]
                .iter()
                .map(|name| participants_var(name, &[], &[]))
                .collect(),
        );
        assert_eq!(collection.match_variables("age*", false).len(), 2);
    }

    #[test]
    fn test_merge_collections() {
        let first = VariableCollection::new(vec![participants_var("age", &["25"], &["01"])]);
        let second = VariableCollection::new(vec![participants_var("age", &["30"], &["02"])]);
        let merged = merge_collections(&[first, second]).expect("input slice is non-empty");
        assert_eq!(merged.get("age").map(SimpleVariable::len), Some(2));
    }
}