Skip to main content

bids_modeling/
transformations.rs

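//! Variable transformations from the `pybids-transforms-v1` specification.
//!
//! Applies Rename, Copy, Factor, Select, Delete, Replace, Scale, Threshold,
//! DropNA, Split, Concatenate, Orthogonalize, and Lag instructions to variable
//! collections within a BIDS-StatsModels pipeline. Group, Resample, ToDense,
//! Assign, and Convolve are recognised but treated as no-ops here; any other
//! instruction name (including Filter, And, Or, Not, Product, Sum, and Power)
//! is currently ignored by the dispatcher.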
6
7use bids_core::entities::StringEntities;
8use bids_variables::collections::VariableCollection;
9use bids_variables::variables::SimpleVariable;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// A single transformation instruction from the `pybids-transforms-v1` spec.
///
/// Each variant corresponds to a named transformation that can be applied
/// to variables in a [`VariableCollection`]. Transformations modify variable
/// data in place (rename, scale, threshold) or create new variables (factor,
/// copy, split).
///
/// The enum is internally tagged: serde selects the variant from the value of
/// the `"Name"` key.
///
/// NOTE(review): no rename attribute is present, so serde expects the field
/// keys exactly as declared (`input`, `output`, ...), whereas the `apply_*`
/// helpers in this file read PascalCase keys (`"Input"`, `"Output"`, ...) from
/// raw JSON. Confirm which casing is intended before relying on this type to
/// parse model files.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "Name")]
pub enum Instruction {
    /// Rename each `input` variable to the `output` name at the same position.
    Rename {
        input: Vec<String>,
        output: Vec<String>,
    },
    /// Copy each `input` variable under the `output` name at the same position.
    Copy {
        input: Vec<String>,
        output: Vec<String>,
    },
    /// Centre and/or rescale numeric variables.
    Scale {
        input: Vec<String>,
        /// Defaults to `false` when absent.
        #[serde(default)]
        demean: bool,
        /// Defaults to `false` when absent.
        #[serde(default)]
        rescale: bool,
        /// Defaults to `None` when absent. The JSON `Scale` handler in this
        /// file does not read a corresponding key.
        #[serde(default)]
        replace_na: Option<f64>,
    },
    /// Threshold and/or binarize numeric variables.
    Threshold {
        input: Vec<String>,
        /// Defaults to `0.0` via `default_threshold`.
        #[serde(default = "default_threshold")]
        threshold: f64,
        /// Defaults to `false` when absent.
        ///
        /// NOTE(review): the JSON `Threshold` handler in this file falls back
        /// to `true` when `"Above"` is missing — the two defaults disagree.
        #[serde(default)]
        above: bool,
        /// Defaults to `false` when absent.
        #[serde(default)]
        binarize: bool,
        /// Defaults to `false` when absent.
        #[serde(default)]
        signed: bool,
    },
    /// Logical AND. Not handled by the JSON dispatcher in this file.
    And {
        input: Vec<String>,
        output: Option<Vec<String>>,
    },
    /// Logical OR. Not handled by the JSON dispatcher in this file.
    Or {
        input: Vec<String>,
        output: Option<Vec<String>>,
    },
    /// Logical NOT. Not handled by the JSON dispatcher in this file.
    Not {
        input: Vec<String>,
        output: Option<Vec<String>>,
    },
    /// Product of the inputs. Not handled by the JSON dispatcher in this file.
    Product {
        input: Vec<String>,
        output: Option<String>,
    },
    /// (Weighted) sum of the inputs. Not handled by the JSON dispatcher in
    /// this file.
    Sum {
        input: Vec<String>,
        /// Defaults to an empty vector when absent.
        #[serde(default)]
        weights: Vec<f64>,
        output: Option<String>,
    },
    /// Raise inputs to the power `value`. Not handled by the JSON dispatcher
    /// in this file.
    Power {
        input: Vec<String>,
        value: f64,
        output: Option<Vec<String>>,
    },
    /// Expand categorical variables into one indicator variable per level.
    Factor {
        input: Vec<String>,
    },
    /// Filter by a query string. Not handled by the JSON dispatcher in this
    /// file.
    Filter {
        input: Vec<String>,
        query: String,
    },
    /// Replace values according to the `replace` map (old value -> new value).
    Replace {
        input: Vec<String>,
        replace: HashMap<String, String>,
        output: Option<Vec<String>>,
    },
    /// Keep only the listed variables.
    Select {
        input: Vec<String>,
    },
    /// Remove the listed variables.
    Delete {
        input: Vec<String>,
    },
    /// Treated as a no-op by the JSON dispatcher in this file.
    Group {
        input: Vec<String>,
        output: String,
    },
    /// Treated as a no-op by the JSON dispatcher in this file.
    Resample {
        input: Vec<String>,
        sampling_rate: f64,
    },
    /// Treated as a no-op by the JSON dispatcher in this file.
    ToDense {
        input: Vec<String>,
        sampling_rate: Option<f64>,
    },
    /// Treated as a no-op by the JSON dispatcher in this file.
    Convolve {
        input: Vec<String>,
        /// Defaults to `"spm"` via `default_hrf_model`.
        #[serde(default = "default_hrf_model")]
        model: String,
    },
}
113
/// Serde default for `Instruction::Threshold::threshold`: a cutoff of zero.
fn default_threshold() -> f64 {
    0.0_f64
}
/// Serde default for `Instruction::Convolve::model`: the `"spm"` model name.
fn default_hrf_model() -> String {
    String::from("spm")
}
120
/// A transformation specification containing a transformer name and instructions.
///
/// The `transformer` field identifies the transformation engine (typically
/// `"pybids-transforms-v1"`). The `instructions` field contains a list of
/// JSON transformation objects that are applied in order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformSpec {
    /// Name of the transformation engine the instructions are written for.
    pub transformer: String,
    /// Raw JSON instructions, kept untyped; each is dispatched on its
    /// `"Name"` key when the spec is applied.
    pub instructions: Vec<serde_json::Value>,
}
131
132/// Dispatch a single transformation instruction on a collection.
133fn dispatch_instruction(collection: &mut VariableCollection, instruction: &serde_json::Value) {
134    let name = instruction
135        .get("Name")
136        .and_then(|v| v.as_str())
137        .unwrap_or("");
138
139    match name {
140        "Rename" => apply_rename(collection, instruction),
141        "Copy" => apply_copy(collection, instruction),
142        "Factor" => apply_factor(collection, instruction),
143        "Select" => apply_select(collection, instruction),
144        "Delete" => apply_delete(collection, instruction),
145        "Replace" => apply_replace(collection, instruction),
146        "Scale" => apply_scale(collection, instruction),
147        "Threshold" => apply_threshold(collection, instruction),
148        "DropNA" => apply_dropna(collection, instruction),
149        "Split" => apply_split(collection, instruction),
150        "Concatenate" => apply_concatenate(collection, instruction),
151        "Orthogonalize" => apply_orthogonalize(collection, instruction),
152        "Lag" => apply_lag(collection, instruction),
153        // No-ops: metadata-only or run-level only transforms
154        "Group" | "Resample" | "ToDense" | "Assign" | "Convolve" => {}
155        _ => {}
156    }
157}
158
159/// Apply transformations to a VariableCollection.
160///
161/// This is a simplified version — PyBIDS supports a full transformer
162/// plugin system. We implement the core pybids-transforms-v1 instructions.
163pub fn apply_transformations(collection: &mut VariableCollection, spec: &TransformSpec) {
164    for instruction in &spec.instructions {
165        dispatch_instruction(collection, instruction);
166    }
167}
168
169fn get_inputs(instruction: &serde_json::Value) -> Vec<String> {
170    instruction
171        .get("Input")
172        .and_then(|v| v.as_array())
173        .map(|arr| {
174            arr.iter()
175                .filter_map(|v| v.as_str().map(String::from))
176                .collect()
177        })
178        .or_else(|| {
179            instruction
180                .get("Input")
181                .and_then(|v| v.as_str())
182                .map(|s| vec![s.into()])
183        })
184        .unwrap_or_default()
185}
186
187fn get_outputs(instruction: &serde_json::Value) -> Vec<String> {
188    instruction
189        .get("Output")
190        .and_then(|v| v.as_array())
191        .map(|arr| {
192            arr.iter()
193                .filter_map(|v| v.as_str().map(String::from))
194                .collect()
195        })
196        .or_else(|| {
197            instruction
198                .get("Output")
199                .and_then(|v| v.as_str())
200                .map(|s| vec![s.into()])
201        })
202        .unwrap_or_default()
203}
204
205fn apply_rename(collection: &mut VariableCollection, instruction: &serde_json::Value) {
206    let inputs = get_inputs(instruction);
207    let outputs = get_outputs(instruction);
208    for (old, new) in inputs.iter().zip(outputs.iter()) {
209        if let Some(mut var) = collection.variables.remove(old) {
210            var.name = new.clone();
211            collection.variables.insert(new.clone(), var);
212        }
213    }
214}
215
216fn apply_copy(collection: &mut VariableCollection, instruction: &serde_json::Value) {
217    let inputs = get_inputs(instruction);
218    let outputs = get_outputs(instruction);
219    for (src, dst) in inputs.iter().zip(outputs.iter()) {
220        if let Some(var) = collection.variables.get(src) {
221            let mut copy = var.clone();
222            copy.name = dst.clone();
223            collection.variables.insert(dst.clone(), copy);
224        }
225    }
226}
227
228fn apply_factor(collection: &mut VariableCollection, instruction: &serde_json::Value) {
229    let inputs = get_inputs(instruction);
230    let mut new_vars = Vec::new();
231
232    for input_name in &inputs {
233        if let Some(var) = collection.variables.get(input_name) {
234            let str_values = var.str_values.clone();
235            let source = var.source.clone();
236            let index = var.index.clone();
237
238            let mut seen = std::collections::HashSet::new();
239            let unique: Vec<String> = str_values
240                .iter()
241                .filter(|v| !v.is_empty() && seen.insert((*v).clone()))
242                .cloned()
243                .collect();
244
245            for level in &unique {
246                let new_name = format!("{input_name}.{level}");
247                let values: Vec<String> = str_values
248                    .iter()
249                    .map(|v| if v == level { "1".into() } else { "0".into() })
250                    .collect();
251                new_vars.push(SimpleVariable::new(
252                    &new_name,
253                    &source,
254                    values,
255                    index.clone(),
256                ));
257            }
258        }
259    }
260
261    for var in new_vars {
262        collection.variables.insert(var.name.clone(), var);
263    }
264}
265
266fn apply_select(collection: &mut VariableCollection, instruction: &serde_json::Value) {
267    let inputs = get_inputs(instruction);
268    let input_set: std::collections::HashSet<String> = inputs.into_iter().collect();
269    collection.variables.retain(|k, _| input_set.contains(k));
270}
271
272fn apply_delete(collection: &mut VariableCollection, instruction: &serde_json::Value) {
273    let inputs = get_inputs(instruction);
274    for name in &inputs {
275        collection.variables.remove(name);
276    }
277}
278
279fn apply_replace(collection: &mut VariableCollection, instruction: &serde_json::Value) {
280    let inputs = get_inputs(instruction);
281    let outputs = get_outputs(instruction);
282    let replace_map: HashMap<String, String> = instruction
283        .get("Replace")
284        .and_then(|v| serde_json::from_value(v.clone()).ok())
285        .unwrap_or_default();
286
287    for (i, input_name) in inputs.iter().enumerate() {
288        if let Some(var) = collection.variables.get(input_name) {
289            let new_values: Vec<String> = var
290                .str_values
291                .iter()
292                .map(|v| replace_map.get(v).cloned().unwrap_or_else(|| v.clone()))
293                .collect();
294            let out_name = outputs.get(i).unwrap_or(input_name);
295            let new_var = SimpleVariable::new(out_name, &var.source, new_values, var.index.clone());
296            collection.variables.insert(out_name.clone(), new_var);
297        }
298    }
299}
300
301fn apply_scale(collection: &mut VariableCollection, instruction: &serde_json::Value) {
302    let inputs = get_inputs(instruction);
303    let demean = instruction
304        .get("Demean")
305        .and_then(serde_json::Value::as_bool)
306        .unwrap_or(false);
307    let rescale = instruction
308        .get("Rescale")
309        .and_then(serde_json::Value::as_bool)
310        .unwrap_or(false);
311
312    for input_name in &inputs {
313        if let Some(var) = collection.variables.get_mut(input_name) {
314            if !var.is_numeric {
315                continue;
316            }
317            let vals = &var.values;
318            let finite: Vec<f64> = vals.iter().copied().filter(|v| v.is_finite()).collect();
319            if finite.is_empty() {
320                continue;
321            }
322
323            let mean = finite.iter().sum::<f64>() / finite.len() as f64;
324            let std = (finite.iter().map(|v| (v - mean).powi(2)).sum::<f64>()
325                / finite.len() as f64)
326                .sqrt();
327
328            for (i, v) in var.values.iter_mut().enumerate() {
329                if !v.is_finite() {
330                    continue;
331                }
332                if demean {
333                    *v -= mean;
334                }
335                if rescale && std > 1e-15 {
336                    *v /= std;
337                }
338                var.str_values[i] = v.to_string();
339            }
340        }
341    }
342}
343
344fn apply_threshold(collection: &mut VariableCollection, instruction: &serde_json::Value) {
345    let inputs = get_inputs(instruction);
346    let threshold = instruction
347        .get("Threshold")
348        .and_then(serde_json::Value::as_f64)
349        .unwrap_or(0.0);
350    let above = instruction
351        .get("Above")
352        .and_then(serde_json::Value::as_bool)
353        .unwrap_or(true);
354    let binarize = instruction
355        .get("Binarize")
356        .and_then(serde_json::Value::as_bool)
357        .unwrap_or(false);
358
359    for input_name in &inputs {
360        if let Some(var) = collection.variables.get_mut(input_name) {
361            if !var.is_numeric {
362                continue;
363            }
364            for (i, v) in var.values.iter_mut().enumerate() {
365                let passes = if above {
366                    *v >= threshold
367                } else {
368                    *v <= threshold
369                };
370                if binarize {
371                    *v = if passes { 1.0 } else { 0.0 };
372                } else if !passes {
373                    *v = 0.0;
374                }
375                var.str_values[i] = v.to_string();
376            }
377        }
378    }
379}
380
381fn apply_dropna(collection: &mut VariableCollection, instruction: &serde_json::Value) {
382    let inputs = get_inputs(instruction);
383    for input_name in &inputs {
384        if let Some(var) = collection.variables.get(input_name) {
385            let keep: Vec<usize> = var
386                .str_values
387                .iter()
388                .enumerate()
389                .filter(|(_, v)| !v.is_empty())
390                .map(|(i, _)| i)
391                .collect();
392            let new_values: Vec<String> = keep.iter().map(|&i| var.str_values[i].clone()).collect();
393            let new_index: Vec<StringEntities> = keep
394                .iter()
395                .filter_map(|&i| var.index.get(i).cloned())
396                .collect();
397            let new_var = SimpleVariable::new(&var.name, &var.source, new_values, new_index);
398            collection.variables.insert(input_name.clone(), new_var);
399        }
400    }
401}
402
403fn apply_split(collection: &mut VariableCollection, instruction: &serde_json::Value) {
404    let inputs = get_inputs(instruction);
405    let by = instruction.get("By").and_then(|v| v.as_str()).unwrap_or("");
406    if by.is_empty() {
407        return;
408    }
409
410    let mut new_vars = Vec::new();
411    for input_name in &inputs {
412        if let Some(var) = collection.variables.get(input_name) {
413            let by_var = collection.variables.get(by);
414            if let Some(group_var) = by_var {
415                let mut groups: std::collections::HashMap<String, Vec<usize>> =
416                    std::collections::HashMap::new();
417                for (i, val) in group_var.str_values.iter().enumerate() {
418                    groups.entry(val.clone()).or_default().push(i);
419                }
420                for (key, indices) in &groups {
421                    let name = format!("{input_name}.{key}");
422                    let values: Vec<String> = indices
423                        .iter()
424                        .map(|&i| var.str_values.get(i).cloned().unwrap_or_default())
425                        .collect();
426                    let index: Vec<StringEntities> = indices
427                        .iter()
428                        .filter_map(|&i| var.index.get(i).cloned())
429                        .collect();
430                    new_vars.push(SimpleVariable::new(&name, &var.source, values, index));
431                }
432            }
433        }
434    }
435    for v in new_vars {
436        collection.variables.insert(v.name.clone(), v);
437    }
438}
439
440fn apply_concatenate(collection: &mut VariableCollection, instruction: &serde_json::Value) {
441    let inputs = get_inputs(instruction);
442    let output = instruction
443        .get("Output")
444        .and_then(|v| v.as_str())
445        .unwrap_or("concatenated");
446    let mut all_values = Vec::new();
447    let mut all_index = Vec::new();
448    let mut source = String::new();
449    for input_name in &inputs {
450        if let Some(var) = collection.variables.get(input_name) {
451            if source.is_empty() {
452                source = var.source.clone();
453            }
454            all_values.extend(var.str_values.iter().cloned());
455            all_index.extend(var.index.iter().cloned());
456        }
457    }
458    if !all_values.is_empty() {
459        collection.variables.insert(
460            output.into(),
461            SimpleVariable::new(output, &source, all_values, all_index),
462        );
463    }
464}
465
466fn apply_orthogonalize(collection: &mut VariableCollection, instruction: &serde_json::Value) {
467    let inputs = get_inputs(instruction);
468    let other_names: Vec<String> = instruction
469        .get("Other")
470        .and_then(|v| v.as_array())
471        .map(|arr| {
472            arr.iter()
473                .filter_map(|v| v.as_str().map(String::from))
474                .collect()
475        })
476        .unwrap_or_default();
477
478    for input_name in &inputs {
479        if let Some(var) = collection.variables.get(input_name) {
480            if !var.is_numeric {
481                continue;
482            }
483            let mut x = var.values.clone();
484            // Gram-Schmidt: orthogonalize x against each other variable
485            for other_name in &other_names {
486                if let Some(other) = collection.variables.get(other_name) {
487                    if other.values.len() != x.len() {
488                        continue;
489                    }
490                    let dot_xo: f64 = x.iter().zip(&other.values).map(|(a, b)| a * b).sum();
491                    let dot_oo: f64 = other.values.iter().map(|v| v * v).sum();
492                    if dot_oo.abs() > 1e-15 {
493                        let proj = dot_xo / dot_oo;
494                        for (xi, oi) in x.iter_mut().zip(&other.values) {
495                            *xi -= proj * oi;
496                        }
497                    }
498                }
499            }
500            let new_values: Vec<String> = x.iter().map(std::string::ToString::to_string).collect();
501            let new_var =
502                SimpleVariable::new(&var.name, &var.source, new_values, var.index.clone());
503            collection.variables.insert(input_name.clone(), new_var);
504        }
505    }
506}
507
508fn apply_lag(collection: &mut VariableCollection, instruction: &serde_json::Value) {
509    let inputs = get_inputs(instruction);
510    let n_shift = instruction
511        .get("N")
512        .and_then(serde_json::Value::as_i64)
513        .unwrap_or(1);
514    let outputs = get_outputs(instruction);
515
516    for (i, input_name) in inputs.iter().enumerate() {
517        if let Some(var) = collection.variables.get(input_name) {
518            if !var.is_numeric {
519                continue;
520            }
521            let n = var.values.len();
522            let lagged: Vec<f64> = (0..n)
523                .map(|j| {
524                    let src = j as i64 - n_shift;
525                    if src >= 0 && (src as usize) < n {
526                        var.values[src as usize]
527                    } else {
528                        0.0
529                    }
530                })
531                .collect();
532            let new_values: Vec<String> = lagged
533                .iter()
534                .map(std::string::ToString::to_string)
535                .collect();
536            let out_name = outputs.get(i).unwrap_or(input_name);
537            let new_var = SimpleVariable::new(out_name, &var.source, new_values, var.index.clone());
538            collection.variables.insert(out_name.clone(), new_var);
539        }
540    }
541}
542
/// A TransformerManager that tracks transform history.
pub struct TransformerManager {
    /// Name of the transformation engine this manager was created for.
    pub transformer: String,
    /// When true, a snapshot of the collection is recorded after every
    /// instruction applied through `transform`.
    pub keep_history: bool,
    /// Recorded snapshots, oldest first; stays empty unless `keep_history`
    /// is set.
    pub history: Vec<VariableCollection>,
}
549
550impl TransformerManager {
551    pub fn new(transformer: &str, keep_history: bool) -> Self {
552        Self {
553            transformer: transformer.into(),
554            keep_history,
555            history: Vec::new(),
556        }
557    }
558
559    pub fn transform(
560        &mut self,
561        mut collection: VariableCollection,
562        spec: &TransformSpec,
563    ) -> VariableCollection {
564        for instruction in &spec.instructions {
565            dispatch_instruction(&mut collection, instruction);
566            if self.keep_history {
567                self.history.push(collection.clone());
568            }
569        }
570        collection
571    }
572}
573
/// Expand wildcard patterns in a list of variable names.
///
/// Selectors without any of `*`, `?` or `[` are passed through unchanged,
/// whether or not they occur in `pool`. Any other selector is treated as a
/// shell-style glob and replaced by every name in `pool` that matches it
/// completely, in pool order. Results are concatenated in selector order and
/// are not de-duplicated.
///
/// Glob syntax (modelled on Python's `fnmatch`):
/// - `*` matches any run of characters, including none;
/// - `?` matches exactly one character;
/// - `[abc]` / `[a-z]` match one character from the set or range, and
///   `[!...]` (or `[^...]`) match one character not in it;
/// - a `[` with no closing `]` is an ordinary character;
/// - every other character, including regex metacharacters such as `+`,
///   `(` or `$`, matches only itself.
pub fn expand_wildcards(selectors: &[String], pool: &[String]) -> Vec<String> {
    let mut expanded = Vec::new();
    for selector in selectors {
        let is_glob = selector.contains(|c: char| matches!(c, '*' | '?' | '['));
        if !is_glob {
            expanded.push(selector.clone());
            continue;
        }
        let tokens = parse_glob(selector);
        expanded.extend(
            pool.iter()
                .filter(|name| glob_matches(&tokens, name))
                .cloned(),
        );
    }
    expanded
}

/// One element of a parsed glob pattern.
enum GlobToken {
    /// `*`: any run of characters, possibly empty.
    AnyRun,
    /// `?`: exactly one character.
    AnyChar,
    /// A character that must match itself.
    Literal(char),
    /// `[...]`: one character inside (or, if `negated`, outside) the ranges.
    Class {
        negated: bool,
        ranges: Vec<(char, char)>,
    },
}

impl GlobToken {
    /// Whether this token consumes the single character `c`.
    fn accepts(&self, c: char) -> bool {
        match self {
            GlobToken::AnyRun => false,
            GlobToken::AnyChar => true,
            GlobToken::Literal(expected) => *expected == c,
            GlobToken::Class { negated, ranges } => {
                ranges.iter().any(|&(lo, hi)| lo <= c && c <= hi) != *negated
            }
        }
    }
}

/// Split a glob pattern into tokens.
fn parse_glob(pattern: &str) -> Vec<GlobToken> {
    let chars: Vec<char> = pattern.chars().collect();
    let mut tokens = Vec::with_capacity(chars.len());
    let mut i = 0;
    while i < chars.len() {
        match chars[i] {
            '*' => {
                tokens.push(GlobToken::AnyRun);
                i += 1;
            }
            '?' => {
                tokens.push(GlobToken::AnyChar);
                i += 1;
            }
            '[' => {
                // Locate the closing bracket; a `]` directly after the opening
                // bracket (or after the negation mark) is part of the set.
                let mut j = i + 1;
                if j < chars.len() && (chars[j] == '!' || chars[j] == '^') {
                    j += 1;
                }
                if j < chars.len() && chars[j] == ']' {
                    j += 1;
                }
                while j < chars.len() && chars[j] != ']' {
                    j += 1;
                }
                if j >= chars.len() {
                    // Unterminated set: the bracket is an ordinary character.
                    tokens.push(GlobToken::Literal('['));
                    i += 1;
                    continue;
                }

                let mut body = &chars[i + 1..j];
                let negated = matches!(body.first().copied(), Some('!') | Some('^'));
                if negated {
                    body = &body[1..];
                }
                let mut ranges = Vec::new();
                let mut k = 0;
                while k < body.len() {
                    if k + 2 < body.len() && body[k + 1] == '-' {
                        ranges.push((body[k], body[k + 2]));
                        k += 3;
                    } else {
                        ranges.push((body[k], body[k]));
                        k += 1;
                    }
                }
                tokens.push(GlobToken::Class { negated, ranges });
                i = j + 1;
            }
            other => {
                tokens.push(GlobToken::Literal(other));
                i += 1;
            }
        }
    }
    tokens
}

/// Whether the whole of `name` matches the token sequence.
fn glob_matches(tokens: &[GlobToken], name: &str) -> bool {
    let text: Vec<char> = name.chars().collect();
    let (mut ti, mut ci) = (0usize, 0usize);
    // Position to resume from when a match after the latest `*` fails:
    // (token index after the star, text index the star currently ends at).
    let mut resume: Option<(usize, usize)> = None;

    while ci < text.len() {
        match tokens.get(ti) {
            Some(GlobToken::AnyRun) => {
                resume = Some((ti + 1, ci));
                ti += 1;
            }
            Some(token) if token.accepts(text[ci]) => {
                ti += 1;
                ci += 1;
            }
            _ => match resume {
                // Let the latest `*` swallow one more character and retry.
                Some((after_star, start)) => {
                    ti = after_star;
                    ci = start + 1;
                    resume = Some((after_star, start + 1));
                }
                None => return false,
            },
        }
    }
    // Only trailing `*` tokens may remain once the text is exhausted.
    tokens[ti..]
        .iter()
        .all(|token| matches!(token, GlobToken::AnyRun))
}
598
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a three-row collection with one categorical variable
    /// (`trial_type`) and one numeric variable (`rt`).
    fn make_collection() -> VariableCollection {
        use bids_core::entities::StringEntities;
        let v1 = SimpleVariable::new(
            "trial_type",
            "events",
            vec!["face".into(), "house".into(), "face".into()],
            vec![StringEntities::new(); 3],
        );
        let v2 = SimpleVariable::new(
            "rt",
            "events",
            vec!["0.5".into(), "0.7".into(), "0.6".into()],
            vec![StringEntities::new(); 3],
        );
        VariableCollection::new(vec![v1, v2])
    }

    // Factor adds one indicator variable per level, named `input.level`.
    #[test]
    fn test_factor() {
        let mut col = make_collection();
        let instr = serde_json::json!({"Name": "Factor", "Input": ["trial_type"]});
        apply_factor(&mut col, &instr);
        assert!(col.variables.contains_key("trial_type.face"));
        assert!(col.variables.contains_key("trial_type.house"));
        assert_eq!(
            col.variables["trial_type.face"].str_values,
            vec!["1", "0", "1"]
        );
    }

    // Rename moves the variable to the new key and removes the old one.
    #[test]
    fn test_rename() {
        let mut col = make_collection();
        let instr =
            serde_json::json!({"Name": "Rename", "Input": ["rt"], "Output": ["reaction_time"]});
        apply_rename(&mut col, &instr);
        assert!(!col.variables.contains_key("rt"));
        assert!(col.variables.contains_key("reaction_time"));
    }

    // After demeaning, the values should average to (approximately) zero.
    #[test]
    fn test_scale() {
        let mut col = make_collection();
        let instr =
            serde_json::json!({"Name": "Scale", "Input": ["rt"], "Demean": true, "Rescale": true});
        apply_scale(&mut col, &instr);
        let vals = &col.variables["rt"].values;
        let mean: f64 = vals.iter().sum::<f64>() / vals.len() as f64;
        assert!(
            mean.abs() < 1e-10,
            "Mean should be ~0 after demean, got {}",
            mean
        );
    }

    // With Above + Binarize, values >= 0.6 become 1.0 and the rest 0.0.
    #[test]
    fn test_threshold() {
        let mut col = make_collection();
        let instr = serde_json::json!({"Name": "Threshold", "Input": ["rt"], "Threshold": 0.6, "Above": true, "Binarize": true});
        apply_threshold(&mut col, &instr);
        assert_eq!(col.variables["rt"].values, vec![0.0, 1.0, 1.0]);
    }
}
665}