1use bids_core::entities::StringEntities;
8use bids_variables::collections::VariableCollection;
9use bids_variables::variables::SimpleVariable;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
/// A typed transformation instruction, discriminated by its JSON `"Name"` field
/// (`#[serde(tag = "Name")]`).
///
/// NOTE(review): only the variant names are matched against `"Name"`; the field names
/// are not renamed, so serde expects snake_case keys (`input`, `output`, `demean`, ...).
/// The JSON-driven `apply_*` functions in this module instead read PascalCase keys
/// (`Input`, `Output`, `Demean`, ...). Confirm which spelling callers actually produce
/// before deserializing model files into this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "Name")]
pub enum Instruction {
    /// `"Rename"`; the JSON form is handled by `apply_rename`.
    Rename {
        input: Vec<String>,
        output: Vec<String>,
    },
    /// `"Copy"`; the JSON form is handled by `apply_copy`.
    Copy {
        input: Vec<String>,
        output: Vec<String>,
    },
    /// `"Scale"`; the JSON form is handled by `apply_scale`.
    Scale {
        input: Vec<String>,
        // Absent => `false` (`#[serde(default)]`).
        // NOTE(review): check this default agrees with `apply_scale`.
        #[serde(default)]
        demean: bool,
        // Absent => `false` (`#[serde(default)]`).
        #[serde(default)]
        rescale: bool,
        // Absent => `None`.
        #[serde(default)]
        replace_na: Option<f64>,
    },
    /// `"Threshold"`; the JSON form is handled by `apply_threshold`.
    Threshold {
        input: Vec<String>,
        // Absent => `default_threshold()`.
        #[serde(default = "default_threshold")]
        threshold: f64,
        // Absent => `false`.
        // NOTE(review): `apply_threshold` treats a missing `Above` as `true` — confirm
        // which default is intended.
        #[serde(default)]
        above: bool,
        #[serde(default)]
        binarize: bool,
        // NOTE(review): no `Signed` key is read by `apply_threshold`.
        #[serde(default)]
        signed: bool,
    },
    /// `"And"`. NOTE(review): not matched by `dispatch_instruction` (falls to its catch-all).
    And {
        input: Vec<String>,
        output: Option<Vec<String>>,
    },
    /// `"Or"`. NOTE(review): not matched by `dispatch_instruction`.
    Or {
        input: Vec<String>,
        output: Option<Vec<String>>,
    },
    /// `"Not"`. NOTE(review): not matched by `dispatch_instruction`.
    Not {
        input: Vec<String>,
        output: Option<Vec<String>>,
    },
    /// `"Product"`; single output name. NOTE(review): not matched by `dispatch_instruction`.
    Product {
        input: Vec<String>,
        output: Option<String>,
    },
    /// `"Sum"`; `weights` is empty when absent. NOTE(review): not matched by
    /// `dispatch_instruction`.
    Sum {
        input: Vec<String>,
        #[serde(default)]
        weights: Vec<f64>,
        output: Option<String>,
    },
    /// `"Power"`. NOTE(review): not matched by `dispatch_instruction`.
    Power {
        input: Vec<String>,
        value: f64,
        output: Option<Vec<String>>,
    },
    /// `"Factor"`; the JSON form is handled by `apply_factor`.
    Factor {
        input: Vec<String>,
    },
    /// `"Filter"`. NOTE(review): not matched by `dispatch_instruction`.
    Filter {
        input: Vec<String>,
        query: String,
    },
    /// `"Replace"`; the JSON form is handled by `apply_replace`.
    Replace {
        input: Vec<String>,
        replace: HashMap<String, String>,
        output: Option<Vec<String>>,
    },
    /// `"Select"`; the JSON form is handled by `apply_select`.
    Select {
        input: Vec<String>,
    },
    /// `"Delete"`; the JSON form is handled by `apply_delete`.
    Delete {
        input: Vec<String>,
    },
    /// `"Group"`; explicitly treated as a no-op by `dispatch_instruction`.
    Group {
        input: Vec<String>,
        output: String,
    },
    /// `"Resample"`; explicitly treated as a no-op by `dispatch_instruction`.
    Resample {
        input: Vec<String>,
        sampling_rate: f64,
    },
    /// `"ToDense"`; explicitly treated as a no-op by `dispatch_instruction`.
    ToDense {
        input: Vec<String>,
        sampling_rate: Option<f64>,
    },
    /// `"Convolve"`; explicitly treated as a no-op by `dispatch_instruction`.
    Convolve {
        input: Vec<String>,
        // Absent => `default_hrf_model()`.
        #[serde(default = "default_hrf_model")]
        model: String,
    },
}
113
/// Serde default for `Instruction::Threshold::threshold` when the key is absent.
fn default_threshold() -> f64 {
    0_f64
}

/// Serde default for `Instruction::Convolve::model` when the key is absent.
fn default_hrf_model() -> String {
    String::from("spm")
}
120
/// A transformation specification: a transformer label plus raw JSON instructions.
///
/// Each element of `instructions` is routed by its `"Name"` field through
/// `dispatch_instruction`. `transformer` is stored but is not read by
/// `apply_transformations`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformSpec {
    pub transformer: String,
    pub instructions: Vec<serde_json::Value>,
}
131
132fn dispatch_instruction(collection: &mut VariableCollection, instruction: &serde_json::Value) {
134 let name = instruction
135 .get("Name")
136 .and_then(|v| v.as_str())
137 .unwrap_or("");
138
139 match name {
140 "Rename" => apply_rename(collection, instruction),
141 "Copy" => apply_copy(collection, instruction),
142 "Factor" => apply_factor(collection, instruction),
143 "Select" => apply_select(collection, instruction),
144 "Delete" => apply_delete(collection, instruction),
145 "Replace" => apply_replace(collection, instruction),
146 "Scale" => apply_scale(collection, instruction),
147 "Threshold" => apply_threshold(collection, instruction),
148 "DropNA" => apply_dropna(collection, instruction),
149 "Split" => apply_split(collection, instruction),
150 "Concatenate" => apply_concatenate(collection, instruction),
151 "Orthogonalize" => apply_orthogonalize(collection, instruction),
152 "Lag" => apply_lag(collection, instruction),
153 "Group" | "Resample" | "ToDense" | "Assign" | "Convolve" => {}
155 _ => {}
156 }
157}
158
159pub fn apply_transformations(collection: &mut VariableCollection, spec: &TransformSpec) {
164 for instruction in &spec.instructions {
165 dispatch_instruction(collection, instruction);
166 }
167}
168
169fn get_inputs(instruction: &serde_json::Value) -> Vec<String> {
170 instruction
171 .get("Input")
172 .and_then(|v| v.as_array())
173 .map(|arr| {
174 arr.iter()
175 .filter_map(|v| v.as_str().map(String::from))
176 .collect()
177 })
178 .or_else(|| {
179 instruction
180 .get("Input")
181 .and_then(|v| v.as_str())
182 .map(|s| vec![s.into()])
183 })
184 .unwrap_or_default()
185}
186
187fn get_outputs(instruction: &serde_json::Value) -> Vec<String> {
188 instruction
189 .get("Output")
190 .and_then(|v| v.as_array())
191 .map(|arr| {
192 arr.iter()
193 .filter_map(|v| v.as_str().map(String::from))
194 .collect()
195 })
196 .or_else(|| {
197 instruction
198 .get("Output")
199 .and_then(|v| v.as_str())
200 .map(|s| vec![s.into()])
201 })
202 .unwrap_or_default()
203}
204
205fn apply_rename(collection: &mut VariableCollection, instruction: &serde_json::Value) {
206 let inputs = get_inputs(instruction);
207 let outputs = get_outputs(instruction);
208 for (old, new) in inputs.iter().zip(outputs.iter()) {
209 if let Some(mut var) = collection.variables.remove(old) {
210 var.name = new.clone();
211 collection.variables.insert(new.clone(), var);
212 }
213 }
214}
215
216fn apply_copy(collection: &mut VariableCollection, instruction: &serde_json::Value) {
217 let inputs = get_inputs(instruction);
218 let outputs = get_outputs(instruction);
219 for (src, dst) in inputs.iter().zip(outputs.iter()) {
220 if let Some(var) = collection.variables.get(src) {
221 let mut copy = var.clone();
222 copy.name = dst.clone();
223 collection.variables.insert(dst.clone(), copy);
224 }
225 }
226}
227
228fn apply_factor(collection: &mut VariableCollection, instruction: &serde_json::Value) {
229 let inputs = get_inputs(instruction);
230 let mut new_vars = Vec::new();
231
232 for input_name in &inputs {
233 if let Some(var) = collection.variables.get(input_name) {
234 let str_values = var.str_values.clone();
235 let source = var.source.clone();
236 let index = var.index.clone();
237
238 let mut seen = std::collections::HashSet::new();
239 let unique: Vec<String> = str_values
240 .iter()
241 .filter(|v| !v.is_empty() && seen.insert((*v).clone()))
242 .cloned()
243 .collect();
244
245 for level in &unique {
246 let new_name = format!("{input_name}.{level}");
247 let values: Vec<String> = str_values
248 .iter()
249 .map(|v| if v == level { "1".into() } else { "0".into() })
250 .collect();
251 new_vars.push(SimpleVariable::new(
252 &new_name,
253 &source,
254 values,
255 index.clone(),
256 ));
257 }
258 }
259 }
260
261 for var in new_vars {
262 collection.variables.insert(var.name.clone(), var);
263 }
264}
265
266fn apply_select(collection: &mut VariableCollection, instruction: &serde_json::Value) {
267 let inputs = get_inputs(instruction);
268 let input_set: std::collections::HashSet<String> = inputs.into_iter().collect();
269 collection.variables.retain(|k, _| input_set.contains(k));
270}
271
272fn apply_delete(collection: &mut VariableCollection, instruction: &serde_json::Value) {
273 let inputs = get_inputs(instruction);
274 for name in &inputs {
275 collection.variables.remove(name);
276 }
277}
278
279fn apply_replace(collection: &mut VariableCollection, instruction: &serde_json::Value) {
280 let inputs = get_inputs(instruction);
281 let outputs = get_outputs(instruction);
282 let replace_map: HashMap<String, String> = instruction
283 .get("Replace")
284 .and_then(|v| serde_json::from_value(v.clone()).ok())
285 .unwrap_or_default();
286
287 for (i, input_name) in inputs.iter().enumerate() {
288 if let Some(var) = collection.variables.get(input_name) {
289 let new_values: Vec<String> = var
290 .str_values
291 .iter()
292 .map(|v| replace_map.get(v).cloned().unwrap_or_else(|| v.clone()))
293 .collect();
294 let out_name = outputs.get(i).unwrap_or(input_name);
295 let new_var = SimpleVariable::new(out_name, &var.source, new_values, var.index.clone());
296 collection.variables.insert(out_name.clone(), new_var);
297 }
298 }
299}
300
301fn apply_scale(collection: &mut VariableCollection, instruction: &serde_json::Value) {
302 let inputs = get_inputs(instruction);
303 let demean = instruction
304 .get("Demean")
305 .and_then(serde_json::Value::as_bool)
306 .unwrap_or(false);
307 let rescale = instruction
308 .get("Rescale")
309 .and_then(serde_json::Value::as_bool)
310 .unwrap_or(false);
311
312 for input_name in &inputs {
313 if let Some(var) = collection.variables.get_mut(input_name) {
314 if !var.is_numeric {
315 continue;
316 }
317 let vals = &var.values;
318 let finite: Vec<f64> = vals.iter().copied().filter(|v| v.is_finite()).collect();
319 if finite.is_empty() {
320 continue;
321 }
322
323 let mean = finite.iter().sum::<f64>() / finite.len() as f64;
324 let std = (finite.iter().map(|v| (v - mean).powi(2)).sum::<f64>()
325 / finite.len() as f64)
326 .sqrt();
327
328 for (i, v) in var.values.iter_mut().enumerate() {
329 if !v.is_finite() {
330 continue;
331 }
332 if demean {
333 *v -= mean;
334 }
335 if rescale && std > 1e-15 {
336 *v /= std;
337 }
338 var.str_values[i] = v.to_string();
339 }
340 }
341 }
342}
343
344fn apply_threshold(collection: &mut VariableCollection, instruction: &serde_json::Value) {
345 let inputs = get_inputs(instruction);
346 let threshold = instruction
347 .get("Threshold")
348 .and_then(serde_json::Value::as_f64)
349 .unwrap_or(0.0);
350 let above = instruction
351 .get("Above")
352 .and_then(serde_json::Value::as_bool)
353 .unwrap_or(true);
354 let binarize = instruction
355 .get("Binarize")
356 .and_then(serde_json::Value::as_bool)
357 .unwrap_or(false);
358
359 for input_name in &inputs {
360 if let Some(var) = collection.variables.get_mut(input_name) {
361 if !var.is_numeric {
362 continue;
363 }
364 for (i, v) in var.values.iter_mut().enumerate() {
365 let passes = if above {
366 *v >= threshold
367 } else {
368 *v <= threshold
369 };
370 if binarize {
371 *v = if passes { 1.0 } else { 0.0 };
372 } else if !passes {
373 *v = 0.0;
374 }
375 var.str_values[i] = v.to_string();
376 }
377 }
378 }
379}
380
381fn apply_dropna(collection: &mut VariableCollection, instruction: &serde_json::Value) {
382 let inputs = get_inputs(instruction);
383 for input_name in &inputs {
384 if let Some(var) = collection.variables.get(input_name) {
385 let keep: Vec<usize> = var
386 .str_values
387 .iter()
388 .enumerate()
389 .filter(|(_, v)| !v.is_empty())
390 .map(|(i, _)| i)
391 .collect();
392 let new_values: Vec<String> = keep.iter().map(|&i| var.str_values[i].clone()).collect();
393 let new_index: Vec<StringEntities> = keep
394 .iter()
395 .filter_map(|&i| var.index.get(i).cloned())
396 .collect();
397 let new_var = SimpleVariable::new(&var.name, &var.source, new_values, new_index);
398 collection.variables.insert(input_name.clone(), new_var);
399 }
400 }
401}
402
403fn apply_split(collection: &mut VariableCollection, instruction: &serde_json::Value) {
404 let inputs = get_inputs(instruction);
405 let by = instruction.get("By").and_then(|v| v.as_str()).unwrap_or("");
406 if by.is_empty() {
407 return;
408 }
409
410 let mut new_vars = Vec::new();
411 for input_name in &inputs {
412 if let Some(var) = collection.variables.get(input_name) {
413 let by_var = collection.variables.get(by);
414 if let Some(group_var) = by_var {
415 let mut groups: std::collections::HashMap<String, Vec<usize>> =
416 std::collections::HashMap::new();
417 for (i, val) in group_var.str_values.iter().enumerate() {
418 groups.entry(val.clone()).or_default().push(i);
419 }
420 for (key, indices) in &groups {
421 let name = format!("{input_name}.{key}");
422 let values: Vec<String> = indices
423 .iter()
424 .map(|&i| var.str_values.get(i).cloned().unwrap_or_default())
425 .collect();
426 let index: Vec<StringEntities> = indices
427 .iter()
428 .filter_map(|&i| var.index.get(i).cloned())
429 .collect();
430 new_vars.push(SimpleVariable::new(&name, &var.source, values, index));
431 }
432 }
433 }
434 }
435 for v in new_vars {
436 collection.variables.insert(v.name.clone(), v);
437 }
438}
439
440fn apply_concatenate(collection: &mut VariableCollection, instruction: &serde_json::Value) {
441 let inputs = get_inputs(instruction);
442 let output = instruction
443 .get("Output")
444 .and_then(|v| v.as_str())
445 .unwrap_or("concatenated");
446 let mut all_values = Vec::new();
447 let mut all_index = Vec::new();
448 let mut source = String::new();
449 for input_name in &inputs {
450 if let Some(var) = collection.variables.get(input_name) {
451 if source.is_empty() {
452 source = var.source.clone();
453 }
454 all_values.extend(var.str_values.iter().cloned());
455 all_index.extend(var.index.iter().cloned());
456 }
457 }
458 if !all_values.is_empty() {
459 collection.variables.insert(
460 output.into(),
461 SimpleVariable::new(output, &source, all_values, all_index),
462 );
463 }
464}
465
466fn apply_orthogonalize(collection: &mut VariableCollection, instruction: &serde_json::Value) {
467 let inputs = get_inputs(instruction);
468 let other_names: Vec<String> = instruction
469 .get("Other")
470 .and_then(|v| v.as_array())
471 .map(|arr| {
472 arr.iter()
473 .filter_map(|v| v.as_str().map(String::from))
474 .collect()
475 })
476 .unwrap_or_default();
477
478 for input_name in &inputs {
479 if let Some(var) = collection.variables.get(input_name) {
480 if !var.is_numeric {
481 continue;
482 }
483 let mut x = var.values.clone();
484 for other_name in &other_names {
486 if let Some(other) = collection.variables.get(other_name) {
487 if other.values.len() != x.len() {
488 continue;
489 }
490 let dot_xo: f64 = x.iter().zip(&other.values).map(|(a, b)| a * b).sum();
491 let dot_oo: f64 = other.values.iter().map(|v| v * v).sum();
492 if dot_oo.abs() > 1e-15 {
493 let proj = dot_xo / dot_oo;
494 for (xi, oi) in x.iter_mut().zip(&other.values) {
495 *xi -= proj * oi;
496 }
497 }
498 }
499 }
500 let new_values: Vec<String> = x.iter().map(std::string::ToString::to_string).collect();
501 let new_var =
502 SimpleVariable::new(&var.name, &var.source, new_values, var.index.clone());
503 collection.variables.insert(input_name.clone(), new_var);
504 }
505 }
506}
507
508fn apply_lag(collection: &mut VariableCollection, instruction: &serde_json::Value) {
509 let inputs = get_inputs(instruction);
510 let n_shift = instruction
511 .get("N")
512 .and_then(serde_json::Value::as_i64)
513 .unwrap_or(1);
514 let outputs = get_outputs(instruction);
515
516 for (i, input_name) in inputs.iter().enumerate() {
517 if let Some(var) = collection.variables.get(input_name) {
518 if !var.is_numeric {
519 continue;
520 }
521 let n = var.values.len();
522 let lagged: Vec<f64> = (0..n)
523 .map(|j| {
524 let src = j as i64 - n_shift;
525 if src >= 0 && (src as usize) < n {
526 var.values[src as usize]
527 } else {
528 0.0
529 }
530 })
531 .collect();
532 let new_values: Vec<String> = lagged
533 .iter()
534 .map(std::string::ToString::to_string)
535 .collect();
536 let out_name = outputs.get(i).unwrap_or(input_name);
537 let new_var = SimpleVariable::new(out_name, &var.source, new_values, var.index.clone());
538 collection.variables.insert(out_name.clone(), new_var);
539 }
540 }
541}
542
543pub struct TransformerManager {
545 pub transformer: String,
546 pub keep_history: bool,
547 pub history: Vec<VariableCollection>,
548}
549
550impl TransformerManager {
551 pub fn new(transformer: &str, keep_history: bool) -> Self {
552 Self {
553 transformer: transformer.into(),
554 keep_history,
555 history: Vec::new(),
556 }
557 }
558
559 pub fn transform(
560 &mut self,
561 mut collection: VariableCollection,
562 spec: &TransformSpec,
563 ) -> VariableCollection {
564 for instruction in &spec.instructions {
565 dispatch_instruction(&mut collection, instruction);
566 if self.keep_history {
567 self.history.push(collection.clone());
568 }
569 }
570 collection
571 }
572}
573
/// Expands shell-style wildcard selectors against a pool of variable names.
///
/// A selector containing `*`, `?` or `[` is a glob pattern and contributes every pool
/// name it matches, in pool order. Any other selector is passed through verbatim, even
/// if it is not in the pool. Glob syntax (anchored at both ends):
/// - `*` matches any run of characters;
/// - `?` matches exactly one character;
/// - `[...]` matches one character from a set or range (`[a-z]`), negated by a leading
///   `!` or `^`;
/// - an unterminated `[` is a literal.
///
/// This replaces a glob-to-regex translation that escaped only `.`. That version let
/// other regex metacharacters (`+`, `(`, `|`, `$`, ...) change the meaning of a pattern,
/// misread `[!...]`, and silently dropped selectors that produced an invalid regex.
pub fn expand_wildcards(selectors: &[String], pool: &[String]) -> Vec<String> {
    let mut out = Vec::new();
    for spec in selectors {
        if spec.contains('*') || spec.contains('?') || spec.contains('[') {
            let pattern: Vec<char> = spec.chars().collect();
            out.extend(
                pool.iter()
                    .filter(|name| {
                        let name: Vec<char> = name.chars().collect();
                        glob_matches(&pattern, &name)
                    })
                    .cloned(),
            );
        } else {
            out.push(spec.clone());
        }
    }
    out
}

/// Returns whether the whole of `name` matches the glob `pattern`.
///
/// Iterative matcher that backtracks to the most recent `*` on mismatch.
fn glob_matches(pattern: &[char], name: &[char]) -> bool {
    let (mut p, mut n) = (0usize, 0usize);
    // Most recent `*`: (pattern index just after it, name index it currently extends to).
    let mut backtrack: Option<(usize, usize)> = None;

    while n < name.len() {
        if pattern.get(p).copied() == Some('*') {
            backtrack = Some((p + 1, n));
            p += 1;
            continue;
        }
        let advanced = match pattern.get(p).copied() {
            Some('?') => Some(p + 1),
            Some('[') => match match_bracket(pattern, p, name[n]) {
                Some((true, after)) => Some(after),
                Some((false, _)) => None,
                // Unterminated bracket: treat `[` as an ordinary character.
                None => {
                    if name[n] == '[' {
                        Some(p + 1)
                    } else {
                        None
                    }
                }
            },
            Some(c) => {
                if c == name[n] {
                    Some(p + 1)
                } else {
                    None
                }
            }
            None => None,
        };
        match (advanced, backtrack) {
            (Some(next), _) => {
                p = next;
                n += 1;
            }
            (None, Some((star_p, star_n))) => {
                // Let the last `*` swallow one more character and retry.
                backtrack = Some((star_p, star_n + 1));
                p = star_p;
                n = star_n + 1;
            }
            (None, None) => return false,
        }
    }
    pattern[p..].iter().all(|&c| c == '*')
}

/// Evaluates the bracket expression opening at `pattern[open]` (a `[`) against `c`.
///
/// Returns `Some((matched, index_after_closing_bracket))`, or `None` if the bracket is
/// never closed. Rules:
/// - a leading `!` or `^` negates the set;
/// - a `]` immediately after the opening (or after the negation) is a literal;
/// - `a-z` is an inclusive range, but a `-` right before the closing `]` is a literal.
fn match_bracket(pattern: &[char], open: usize, c: char) -> Option<(bool, usize)> {
    let mut i = open + 1;
    let negated = matches!(pattern.get(i).copied(), Some('!') | Some('^'));
    if negated {
        i += 1;
    }
    let first = i;
    let mut matched = false;
    loop {
        let lo = *pattern.get(i)?;
        if lo == ']' && i > first {
            return Some((matched != negated, i + 1));
        }
        match (pattern.get(i + 1).copied(), pattern.get(i + 2).copied()) {
            (Some('-'), Some(hi)) if hi != ']' => {
                matched |= lo <= c && c <= hi;
                i += 3;
            }
            _ => {
                matched |= lo == c;
                i += 1;
            }
        }
    }
}
598
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a two-variable collection: categorical `trial_type` and numeric `rt`.
    fn make_collection() -> VariableCollection {
        use bids_core::entities::StringEntities;
        let v1 = SimpleVariable::new(
            "trial_type",
            "events",
            vec!["face".into(), "house".into(), "face".into()],
            vec![StringEntities::new(); 3],
        );
        let v2 = SimpleVariable::new(
            "rt",
            "events",
            vec!["0.5".into(), "0.7".into(), "0.6".into()],
            vec![StringEntities::new(); 3],
        );
        VariableCollection::new(vec![v1, v2])
    }

    #[test]
    fn test_factor() {
        let mut col = make_collection();
        let instr = serde_json::json!({"Name": "Factor", "Input": ["trial_type"]});
        apply_factor(&mut col, &instr);
        assert!(col.variables.contains_key("trial_type.face"));
        assert!(col.variables.contains_key("trial_type.house"));
        assert_eq!(
            col.variables["trial_type.face"].str_values,
            vec!["1", "0", "1"]
        );
    }

    #[test]
    fn test_rename() {
        let mut col = make_collection();
        let instr =
            serde_json::json!({"Name": "Rename", "Input": ["rt"], "Output": ["reaction_time"]});
        apply_rename(&mut col, &instr);
        assert!(!col.variables.contains_key("rt"));
        assert!(col.variables.contains_key("reaction_time"));
    }

    #[test]
    fn test_copy() {
        let mut col = make_collection();
        let instr = serde_json::json!({"Name": "Copy", "Input": ["rt"], "Output": ["rt_copy"]});
        apply_copy(&mut col, &instr);
        assert!(col.variables.contains_key("rt"));
        assert_eq!(col.variables["rt_copy"].name, "rt_copy");
        assert_eq!(
            col.variables["rt_copy"].str_values,
            col.variables["rt"].str_values
        );
    }

    #[test]
    fn test_select() {
        let mut col = make_collection();
        let instr = serde_json::json!({"Name": "Select", "Input": ["rt"]});
        apply_select(&mut col, &instr);
        assert_eq!(col.variables.len(), 1);
        assert!(col.variables.contains_key("rt"));
    }

    #[test]
    fn test_delete() {
        let mut col = make_collection();
        let instr = serde_json::json!({"Name": "Delete", "Input": ["rt", "not_there"]});
        apply_delete(&mut col, &instr);
        assert!(!col.variables.contains_key("rt"));
        assert!(col.variables.contains_key("trial_type"));
    }

    #[test]
    fn test_replace_strings() {
        let mut col = make_collection();
        let instr = serde_json::json!({
            "Name": "Replace",
            "Input": ["trial_type"],
            "Replace": {"face": "F"}
        });
        apply_replace(&mut col, &instr);
        assert_eq!(
            col.variables["trial_type"].str_values,
            vec!["F", "house", "F"]
        );
    }

    #[test]
    fn test_scale() {
        let mut col = make_collection();
        let instr =
            serde_json::json!({"Name": "Scale", "Input": ["rt"], "Demean": true, "Rescale": true});
        apply_scale(&mut col, &instr);
        let vals = &col.variables["rt"].values;
        let mean: f64 = vals.iter().sum::<f64>() / vals.len() as f64;
        assert!(
            mean.abs() < 1e-10,
            "Mean should be ~0 after demean, got {}",
            mean
        );
    }

    #[test]
    fn test_threshold() {
        let mut col = make_collection();
        let instr = serde_json::json!({"Name": "Threshold", "Input": ["rt"], "Threshold": 0.6, "Above": true, "Binarize": true});
        apply_threshold(&mut col, &instr);
        assert_eq!(col.variables["rt"].values, vec![0.0, 1.0, 1.0]);
    }

    #[test]
    fn test_lag_default_shift() {
        let mut col = make_collection();
        let instr = serde_json::json!({"Name": "Lag", "Input": ["rt"]});
        apply_lag(&mut col, &instr);
        assert_eq!(col.variables["rt"].values, vec![0.0, 0.5, 0.7]);
    }

    #[test]
    fn test_apply_transformations_pipeline() {
        let mut col = make_collection();
        let spec = TransformSpec {
            transformer: "pybids".into(),
            instructions: vec![
                serde_json::json!({"Name": "Rename", "Input": ["rt"], "Output": ["reaction_time"]}),
                serde_json::json!({"Name": "Delete", "Input": ["trial_type"]}),
                // Unknown names are ignored rather than failing the pipeline.
                serde_json::json!({"Name": "Frobnicate", "Input": ["reaction_time"]}),
            ],
        };
        apply_transformations(&mut col, &spec);
        assert_eq!(col.variables.len(), 1);
        assert!(col.variables.contains_key("reaction_time"));
    }

    #[test]
    fn test_expand_wildcards_basic() {
        let pool = vec!["trial_type".to_string(), "rt".to_string()];
        let selectors = vec!["trial*".to_string(), "rt".to_string()];
        assert_eq!(expand_wildcards(&selectors, &pool), vec!["trial_type", "rt"]);
    }
}