Skip to main content

bids_variables/
variables.rs

1//! BIDS variable types: simple, sparse-run, and dense-run.
2//!
3//! Implements the three variable types used in BIDS statistical modeling:
4//! [`SimpleVariable`] (participant/session-level), [`SparseRunVariable`]
5//! (event-level with onset/duration), and [`DenseRunVariable`] (continuous
6//! regressors sampled at TR).
7
8use bids_core::entities::StringEntities;
9
10/// A simple variable with no timing information.
11///
12/// Represents a column from a BIDS tabular file such as `participants.tsv`,
13/// `sessions.tsv`, or `scans.tsv`. Each variable has a name, source file
14/// identifier, string values, and an entity index that maps each row to
15/// its BIDS entities (e.g., which subject each value belongs to).
16///
17/// Values are stored as both strings and parsed floats (NaN for non-numeric).
18/// The `is_numeric` flag indicates whether all values are parseable as f64.
19///
20/// Corresponds to PyBIDS' `SimpleVariable` class.
21///
22/// # Example
23///
24/// ```
25/// use bids_variables::SimpleVariable;
26/// use std::collections::HashMap;
27///
28/// let index = vec![
29///     HashMap::from([("subject".into(), "01".into())]),
30///     HashMap::from([("subject".into(), "02".into())]),
31/// ];
32/// let var = SimpleVariable::new("age", "participants",
33///     vec!["25".into(), "30".into()], index);
34///
35/// assert_eq!(var.len(), 2);
36/// assert!(var.is_numeric);
37/// ```
38#[derive(Debug, Clone)]
39pub struct SimpleVariable {
40    pub name: String,
41    pub source: String,
42    pub values: Vec<f64>,
43    pub str_values: Vec<String>,
44    pub index: Vec<StringEntities>,
45    pub entities: StringEntities,
46    pub is_numeric: bool,
47}
48
49impl SimpleVariable {
50    pub fn new(name: &str, source: &str, values: Vec<String>, index: Vec<StringEntities>) -> Self {
51        let numeric_values: Vec<f64> = values
52            .iter()
53            .map(|v| v.parse().unwrap_or(f64::NAN))
54            .collect();
55        let is_numeric = values
56            .iter()
57            .all(|v| v.parse::<f64>().is_ok() || v.is_empty());
58        let entities = extract_common_entities(&index);
59
60        Self {
61            name: name.to_string(),
62            source: source.to_string(),
63            values: numeric_values,
64            str_values: values,
65            index,
66            entities,
67            is_numeric,
68        }
69    }
70
71    pub fn len(&self) -> usize {
72        self.str_values.len()
73    }
74    pub fn is_empty(&self) -> bool {
75        self.str_values.is_empty()
76    }
77
78    /// Clone with optional data/name replacement.
79    pub fn clone_with(&self, data: Option<Vec<String>>, name: Option<&str>) -> Self {
80        let mut cloned = self.clone();
81        if let Some(d) = data {
82            cloned.values = d.iter().map(|v| v.parse().unwrap_or(f64::NAN)).collect();
83            cloned.str_values = d;
84        }
85        if let Some(n) = name {
86            cloned.name = n.to_string();
87        }
88        cloned
89    }
90
91    /// Filter rows matching given entity criteria.
92    pub fn filter(&self, filters: &StringEntities) -> Self {
93        let mut values = Vec::new();
94        let mut index = Vec::new();
95
96        for (i, row_ents) in self.index.iter().enumerate() {
97            if filters
98                .iter()
99                .all(|(k, v)| row_ents.get(k).is_none_or(|rv| rv == v))
100            {
101                values.push(self.str_values[i].clone());
102                index.push(row_ents.clone());
103            }
104        }
105
106        Self::new(&self.name, &self.source, values, index)
107    }
108
109    /// Convert to tabular rows.
110    pub fn to_rows(&self) -> Vec<StringEntities> {
111        self.str_values
112            .iter()
113            .enumerate()
114            .map(|(i, val)| {
115                let mut row = self.index.get(i).cloned().unwrap_or_default();
116                row.insert("amplitude".into(), val.clone());
117                row.insert("condition".into(), self.name.clone());
118                row
119            })
120            .collect()
121    }
122}
123
124/// A sparse run variable representing events with onset, duration, and amplitude.
125///
126/// Loaded from `_events.tsv` files, each event has a time point (onset),
127/// a duration, and an amplitude value. Multiple runs can be represented in
128/// a single variable via the `run_info` vector.
129///
130/// Can be converted to a [`DenseRunVariable`] via [`to_dense()`](Self::to_dense)
131/// for convolution with hemodynamic response functions or other time-domain
132/// operations.
133///
134/// Corresponds to PyBIDS' `SparseRunVariable` class.
135#[derive(Debug, Clone)]
136pub struct SparseRunVariable {
137    pub name: String,
138    pub source: String,
139    pub onset: Vec<f64>,
140    pub duration: Vec<f64>,
141    pub amplitude: Vec<f64>,
142    pub str_amplitude: Vec<String>,
143    pub index: Vec<StringEntities>,
144    pub entities: StringEntities,
145    pub run_info: Vec<super::node::RunInfo>,
146}
147
148impl SparseRunVariable {
149    pub fn new(
150        name: &str,
151        source: &str,
152        onset: Vec<f64>,
153        duration: Vec<f64>,
154        amplitude: Vec<String>,
155        index: Vec<StringEntities>,
156        run_info: Vec<super::node::RunInfo>,
157    ) -> Self {
158        let numeric_amp: Vec<f64> = amplitude
159            .iter()
160            .map(|v| v.parse().unwrap_or(f64::NAN))
161            .collect();
162        let mut entities = extract_common_entities(&index);
163        // Also include common entities from run_info
164        if let Some(first_run) = run_info.first() {
165            for (k, v) in &first_run.entities {
166                if run_info.iter().all(|r| r.entities.get(k) == Some(v)) {
167                    entities.entry(k.clone()).or_insert_with(|| v.clone());
168                }
169            }
170        }
171
172        Self {
173            name: name.to_string(),
174            source: source.to_string(),
175            onset,
176            duration,
177            amplitude: numeric_amp,
178            str_amplitude: amplitude,
179            index,
180            entities,
181            run_info,
182        }
183    }
184
185    pub fn len(&self) -> usize {
186        self.onset.len()
187    }
188    pub fn is_empty(&self) -> bool {
189        self.onset.is_empty()
190    }
191
192    /// Total duration of all runs.
193    pub fn get_duration(&self) -> f64 {
194        self.run_info.iter().map(|r| r.duration).sum()
195    }
196
197    /// Convert sparse to dense representation using GCD-based bin size.
198    pub fn to_dense(&self, sampling_rate: Option<f64>) -> DenseRunVariable {
199        let onsets_ms: Vec<i64> = self
200            .onset
201            .iter()
202            .map(|o| (o * 1000.0).round() as i64)
203            .collect();
204        let durations_ms: Vec<i64> = self
205            .duration
206            .iter()
207            .map(|d| (d * 1000.0).round() as i64)
208            .collect();
209
210        let all_vals: Vec<i64> = onsets_ms
211            .iter()
212            .chain(durations_ms.iter())
213            .copied()
214            .filter(|&v| v > 0)
215            .collect();
216        let gcd_val = all_vals
217            .iter()
218            .copied()
219            .reduce(gcd_pair)
220            .unwrap_or(1)
221            .max(1);
222
223        let bin_sr = 1000.0 / gcd_val as f64;
224        let sr = sampling_rate.map_or(bin_sr, |s| s.max(bin_sr));
225        let total_duration = self.get_duration();
226        let n_samples = (total_duration * sr).ceil() as usize;
227        let mut ts = vec![0.0f64; n_samples];
228
229        let mut run_offset = 0.0;
230        let mut last_onset = -1.0f64;
231        let mut run_i = 0;
232
233        for i in 0..self.onset.len() {
234            if self.onset[i] < last_onset && run_i + 1 < self.run_info.len() {
235                run_offset += self.run_info[run_i].duration;
236                run_i += 1;
237            }
238            let onset_sample = ((run_offset + self.onset[i]) * sr).round() as usize;
239            let dur_samples = (self.duration[i] * sr).round() as usize;
240            let offset_sample = (onset_sample + dur_samples).min(n_samples);
241            for ts_val in ts.iter_mut().take(offset_sample).skip(onset_sample) {
242                *ts_val = self.amplitude[i];
243            }
244            last_onset = self.onset[i];
245        }
246
247        let final_sr = sampling_rate.unwrap_or(sr);
248        if (final_sr - sr).abs() > 0.001 {
249            let new_n = (total_duration * final_sr).ceil() as usize;
250            ts = linear_resample(&ts, new_n);
251        }
252
253        DenseRunVariable::new(
254            &self.name,
255            &self.source,
256            ts,
257            final_sr,
258            self.run_info.clone(),
259        )
260    }
261
262    /// Filter events by entity criteria.
263    pub fn filter(&self, filters: &StringEntities) -> Self {
264        let mut onset = Vec::new();
265        let mut duration = Vec::new();
266        let mut amplitude = Vec::new();
267        let mut index = Vec::new();
268
269        for (i, row_ents) in self.index.iter().enumerate() {
270            if filters
271                .iter()
272                .all(|(k, v)| row_ents.get(k).is_none_or(|rv| rv == v))
273            {
274                onset.push(self.onset[i]);
275                duration.push(self.duration[i]);
276                amplitude.push(self.str_amplitude[i].clone());
277                index.push(row_ents.clone());
278            }
279        }
280
281        Self::new(
282            &self.name,
283            &self.source,
284            onset,
285            duration,
286            amplitude,
287            index,
288            self.run_info.clone(),
289        )
290    }
291
292    /// Convert to tabular rows.
293    pub fn to_rows(&self) -> Vec<StringEntities> {
294        (0..self.onset.len())
295            .map(|i| {
296                let mut row = self.index.get(i).cloned().unwrap_or_default();
297                row.insert("onset".into(), self.onset[i].to_string());
298                row.insert("duration".into(), self.duration[i].to_string());
299                row.insert("amplitude".into(), self.str_amplitude[i].clone());
300                row.insert("condition".into(), self.name.clone());
301                row
302            })
303            .collect()
304    }
305}
306
307/// A dense run variable with uniformly-sampled time series data.
308///
309/// Represents continuous signals such as physiological recordings (`_physio.tsv.gz`),
310/// stimulus waveforms (`_stim.tsv.gz`), or confound regressors (`_regressors.tsv`).
311/// Data is stored as a vector of f64 values at a fixed sampling rate.
312///
313/// Supports resampling to different rates via [`resample()`](Self::resample) and
314/// TR-based downsampling via [`resample_to_tr()`](Self::resample_to_tr).
315///
316/// Corresponds to PyBIDS' `DenseRunVariable` class.
317#[derive(Debug, Clone)]
318pub struct DenseRunVariable {
319    pub name: String,
320    pub source: String,
321    pub values: Vec<f64>,
322    pub sampling_rate: f64,
323    pub run_info: Vec<super::node::RunInfo>,
324    pub entities: StringEntities,
325}
326
327impl DenseRunVariable {
328    pub fn new(
329        name: &str,
330        source: &str,
331        values: Vec<f64>,
332        sampling_rate: f64,
333        run_info: Vec<super::node::RunInfo>,
334    ) -> Self {
335        let mut entities = StringEntities::new();
336        for ri in &run_info {
337            for (k, v) in &ri.entities {
338                entities.entry(k.clone()).or_insert_with(|| v.clone());
339            }
340        }
341        Self {
342            name: name.into(),
343            source: source.into(),
344            values,
345            sampling_rate,
346            run_info,
347            entities,
348        }
349    }
350
351    pub fn len(&self) -> usize {
352        self.values.len()
353    }
354    pub fn is_empty(&self) -> bool {
355        self.values.is_empty()
356    }
357
358    /// Resample to a different sampling rate.
359    pub fn resample(&self, new_sr: f64) -> Self {
360        if (new_sr - self.sampling_rate).abs() < 0.001 {
361            return self.clone();
362        }
363        let new_n = ((self.values.len() as f64) * new_sr / self.sampling_rate).ceil() as usize;
364        Self {
365            name: self.name.clone(),
366            source: self.source.clone(),
367            values: linear_resample(&self.values, new_n),
368            sampling_rate: new_sr,
369            run_info: self.run_info.clone(),
370            entities: self.entities.clone(),
371        }
372    }
373
374    /// Resample to TR-based sampling rate.
375    pub fn resample_to_tr(&self) -> Self {
376        self.run_info
377            .first()
378            .filter(|ri| ri.tr > 0.0)
379            .map(|ri| self.resample(1.0 / ri.tr))
380            .unwrap_or_else(|| self.clone())
381    }
382
383    /// Convert to tabular rows.
384    pub fn to_rows(&self) -> Vec<StringEntities> {
385        let interval = 1.0 / self.sampling_rate;
386        self.values
387            .iter()
388            .enumerate()
389            .map(|(i, val)| {
390                let mut row = self.entities.clone();
391                row.insert("onset".into(), (i as f64 * interval).to_string());
392                row.insert("duration".into(), interval.to_string());
393                row.insert("amplitude".into(), val.to_string());
394                row.insert("condition".into(), self.name.clone());
395                row
396            })
397            .collect()
398    }
399}
400
401impl SparseRunVariable {
402    /// Select specific row indices.
403    pub fn select_rows(&self, indices: &[usize]) -> Self {
404        Self::new(
405            &self.name,
406            &self.source,
407            indices
408                .iter()
409                .filter_map(|&i| self.onset.get(i).copied())
410                .collect(),
411            indices
412                .iter()
413                .filter_map(|&i| self.duration.get(i).copied())
414                .collect(),
415            indices
416                .iter()
417                .filter_map(|&i| self.str_amplitude.get(i).cloned())
418                .collect(),
419            indices
420                .iter()
421                .filter_map(|&i| self.index.get(i).cloned())
422                .collect(),
423            self.run_info.clone(),
424        )
425    }
426
427    /// Split into multiple variables based on a grouper.
428    pub fn split(&self, group_col: &str) -> Vec<Self> {
429        let mut groups: std::collections::HashMap<String, Vec<usize>> =
430            std::collections::HashMap::new();
431        for (i, row) in self.index.iter().enumerate() {
432            let key = row.get(group_col).cloned().unwrap_or_default();
433            groups.entry(key).or_default().push(i);
434        }
435        groups
436            .into_iter()
437            .map(|(key, indices)| {
438                let mut var = self.select_rows(&indices);
439                var.name = format!("{}.{}", self.name, key);
440                var
441            })
442            .collect()
443    }
444}
445
446impl DenseRunVariable {
447    /// Build entity index with timestamps for each sample.
448    pub fn build_entity_index(&self) -> Vec<(f64, StringEntities)> {
449        let interval = 1.0 / self.sampling_rate;
450        let mut result = Vec::with_capacity(self.values.len());
451        let mut offset = 0.0;
452        let mut run_i = 0;
453        for (i, _) in self.values.iter().enumerate() {
454            let t = i as f64 * interval;
455            // Advance run if we've passed the current run's duration
456            while run_i + 1 < self.run_info.len() && t >= offset + self.run_info[run_i].duration {
457                offset += self.run_info[run_i].duration;
458                run_i += 1;
459            }
460            let ents = self
461                .run_info
462                .get(run_i)
463                .map(|ri| ri.entities.clone())
464                .unwrap_or_default();
465            result.push((t, ents));
466        }
467        result
468    }
469}
470
471/// Get a grouper key for groupby operations.
472pub fn get_grouper(index: &[StringEntities], group_by: &[&str]) -> Vec<String> {
473    index
474        .iter()
475        .map(|row| {
476            group_by
477                .iter()
478                .map(|k| row.get(*k).cloned().unwrap_or_default())
479                .collect::<Vec<_>>()
480                .join("@@@")
481        })
482        .collect()
483}
484
/// Apply a function to groups defined by a grouper.
///
/// `values[i]` belongs to the group labelled `grouper[i]`; positions beyond
/// the shorter of the two slices are ignored. `func` receives each group's
/// values in their original order, and its outputs are written back to those
/// same positions. Positions without an output stay `0.0`. The order in
/// which groups are visited is unspecified.
pub fn apply_grouped<F>(values: &[f64], grouper: &[String], func: F) -> Vec<f64>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    // Bucket row positions by group label.
    let mut members: std::collections::HashMap<&str, Vec<usize>> =
        std::collections::HashMap::new();
    for (pos, key) in grouper.iter().take(values.len()).enumerate() {
        members.entry(key.as_str()).or_default().push(pos);
    }

    let mut out = vec![0.0; values.len()];
    for positions in members.values() {
        let subset: Vec<f64> = positions.iter().map(|&p| values[p]).collect();
        for (&p, transformed) in positions.iter().zip(func(&subset)) {
            out[p] = transformed;
        }
    }
    out
}
505
506// ──────────────────────── Merge functions ────────────────────────
507
508/// Merge a list of simple variables with the same name.
509pub fn merge_simple(variables: &[&SimpleVariable]) -> Option<SimpleVariable> {
510    let first = variables.first()?;
511    let mut all_values = Vec::new();
512    let mut all_index = Vec::new();
513    for v in variables {
514        all_values.extend(v.str_values.iter().cloned());
515        all_index.extend(v.index.iter().cloned());
516    }
517    Some(SimpleVariable::new(
518        &first.name,
519        &first.source,
520        all_values,
521        all_index,
522    ))
523}
524
525/// Merge sparse run variables.
526pub fn merge_sparse(variables: &[&SparseRunVariable]) -> Option<SparseRunVariable> {
527    let first = variables.first()?;
528    let mut onset = Vec::new();
529    let mut duration = Vec::new();
530    let mut amplitude = Vec::new();
531    let mut index = Vec::new();
532    let mut run_info = Vec::new();
533    for v in variables {
534        onset.extend(&v.onset);
535        duration.extend(&v.duration);
536        amplitude.extend(v.str_amplitude.iter().cloned());
537        index.extend(v.index.iter().cloned());
538        run_info.extend(v.run_info.iter().cloned());
539    }
540    Some(SparseRunVariable::new(
541        &first.name,
542        &first.source,
543        onset,
544        duration,
545        amplitude,
546        index,
547        run_info,
548    ))
549}
550
551// ──────────────────────── Helpers ────────────────────────
552
553fn extract_common_entities(index: &[StringEntities]) -> StringEntities {
554    let mut common = StringEntities::new();
555    if let Some(first) = index.first() {
556        for (k, v) in first {
557            if index.iter().all(|row| row.get(k) == Some(v)) {
558                common.insert(k.clone(), v.clone());
559            }
560        }
561    }
562    common
563}
564
/// Greatest common divisor of `|a|` and `|b|` via Euclid's algorithm.
/// `gcd_pair(x, 0) == |x|`, so `gcd_pair(0, 0) == 0`.
fn gcd_pair(a: i64, b: i64) -> i64 {
    let mut pair = (a.abs(), b.abs());
    while pair.1 != 0 {
        pair = (pair.1, pair.0 % pair.1);
    }
    pair.0
}
574
/// Linearly interpolate `values` onto `new_n` evenly spaced points.
///
/// The grid is endpoint-aligned: the first and last outputs coincide with the
/// first and last inputs. A single output point takes the first input.
/// Returns an empty vector when `values` is empty or `new_n` is 0, and a
/// plain copy when the length is unchanged.
fn linear_resample(values: &[f64], new_n: usize) -> Vec<f64> {
    let old_n = values.len();
    if old_n == 0 || new_n == 0 {
        return Vec::new();
    }
    if old_n == new_n {
        return values.to_vec();
    }

    let last = old_n - 1;
    let mut out = Vec::with_capacity(new_n);
    for i in 0..new_n {
        let pos = match new_n {
            1 => 0.0,
            _ => (i as f64) * (old_n as f64 - 1.0) / (new_n as f64 - 1.0),
        };
        let left = pos.floor() as usize;
        let right = last.min(left + 1);
        let weight = pos - left as f64;
        out.push(values[left] * (1.0 - weight) + values[right] * weight);
    }
    out
}
597
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::RunInfo;
    use std::collections::HashMap;

    /// Entity map holding only a `subject` label.
    fn subject(id: &str) -> StringEntities {
        HashMap::from([("subject".into(), id.into())])
    }

    #[test]
    fn test_sparse_to_dense() {
        let run = RunInfo {
            entities: StringEntities::new(),
            duration: 10.0,
            tr: 2.0,
            image: None,
            n_vols: 5,
        };
        let events = SparseRunVariable::new(
            "trial_type",
            "events",
            vec![1.0, 3.0],
            vec![1.0, 2.0],
            vec!["1".into(); 2],
            vec![StringEntities::new(); 2],
            vec![run],
        );

        let dense = events.to_dense(Some(10.0));

        // A 10 s run at 10 Hz gives 100 samples; the first event covers 10..20.
        assert_eq!(dense.sampling_rate, 10.0);
        assert_eq!(dense.values.len(), 100);
        assert_eq!(dense.values[0], 0.0);
        assert_eq!(dense.values[10], 1.0);
    }

    #[test]
    fn test_simple_filter() {
        let var = SimpleVariable::new(
            "age",
            "participants",
            vec!["25".into(), "30".into()],
            vec![subject("01"), subject("02")],
        );

        let filtered = var.filter(&subject("01"));

        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered.str_values[0], "25");
    }

    #[test]
    fn test_merge_simple() {
        let first = SimpleVariable::new(
            "age",
            "participants",
            vec!["25".into()],
            vec![subject("01")],
        );
        let second = SimpleVariable::new(
            "age",
            "participants",
            vec!["30".into()],
            vec![subject("02")],
        );

        let merged = merge_simple(&[&first, &second]).expect("non-empty input merges");

        assert_eq!(merged.len(), 2);
    }
}