1use bids_core::entities::StringEntities;
9
/// A variable holding one value per row of an entity index.
///
/// Values are kept both as the raw strings they were read from and as parsed
/// `f64`s (see `SimpleVariable::new`).
#[derive(Debug, Clone)]
pub struct SimpleVariable {
    /// Variable (column) name.
    pub name: String,
    /// Label of where the variable came from (the tests use e.g. `"participants"`).
    pub source: String,
    /// `str_values` parsed as `f64`; entries that fail to parse are `NaN`.
    pub values: Vec<f64>,
    /// The raw string values, one per row.
    pub str_values: Vec<String>,
    /// Per-row entities, parallel to `str_values`.
    pub index: Vec<StringEntities>,
    /// Entity key/value pairs shared by every row of `index`.
    pub entities: StringEntities,
    /// True when every raw value parses as `f64` or is empty.
    pub is_numeric: bool,
}
48
49impl SimpleVariable {
50 pub fn new(name: &str, source: &str, values: Vec<String>, index: Vec<StringEntities>) -> Self {
51 let numeric_values: Vec<f64> = values
52 .iter()
53 .map(|v| v.parse().unwrap_or(f64::NAN))
54 .collect();
55 let is_numeric = values
56 .iter()
57 .all(|v| v.parse::<f64>().is_ok() || v.is_empty());
58 let entities = extract_common_entities(&index);
59
60 Self {
61 name: name.to_string(),
62 source: source.to_string(),
63 values: numeric_values,
64 str_values: values,
65 index,
66 entities,
67 is_numeric,
68 }
69 }
70
71 pub fn len(&self) -> usize {
72 self.str_values.len()
73 }
74 pub fn is_empty(&self) -> bool {
75 self.str_values.is_empty()
76 }
77
78 pub fn clone_with(&self, data: Option<Vec<String>>, name: Option<&str>) -> Self {
80 let mut cloned = self.clone();
81 if let Some(d) = data {
82 cloned.values = d.iter().map(|v| v.parse().unwrap_or(f64::NAN)).collect();
83 cloned.str_values = d;
84 }
85 if let Some(n) = name {
86 cloned.name = n.to_string();
87 }
88 cloned
89 }
90
91 pub fn filter(&self, filters: &StringEntities) -> Self {
93 let mut values = Vec::new();
94 let mut index = Vec::new();
95
96 for (i, row_ents) in self.index.iter().enumerate() {
97 if filters
98 .iter()
99 .all(|(k, v)| row_ents.get(k).is_none_or(|rv| rv == v))
100 {
101 values.push(self.str_values[i].clone());
102 index.push(row_ents.clone());
103 }
104 }
105
106 Self::new(&self.name, &self.source, values, index)
107 }
108
109 pub fn to_rows(&self) -> Vec<StringEntities> {
111 self.str_values
112 .iter()
113 .enumerate()
114 .map(|(i, val)| {
115 let mut row = self.index.get(i).cloned().unwrap_or_default();
116 row.insert("amplitude".into(), val.clone());
117 row.insert("condition".into(), self.name.clone());
118 row
119 })
120 .collect()
121 }
122}
123
/// An event-style variable: one entry per event, each with an onset, a duration
/// and an amplitude, spanning the runs listed in `run_info`.
#[derive(Debug, Clone)]
pub struct SparseRunVariable {
    /// Variable (condition) name.
    pub name: String,
    /// Label of where the variable came from (the tests use e.g. `"events"`).
    pub source: String,
    /// Event onsets; `to_dense` treats them as seconds (it converts via `* 1000.0` to ms).
    pub onset: Vec<f64>,
    /// Event durations, in the same unit as `onset`.
    pub duration: Vec<f64>,
    /// `str_amplitude` parsed as `f64`; unparseable entries are `NaN`.
    pub amplitude: Vec<f64>,
    /// The raw amplitude strings, one per event.
    pub str_amplitude: Vec<String>,
    /// Per-event entities, parallel to the event columns.
    pub index: Vec<StringEntities>,
    /// Entities common to all events, plus those common to all runs.
    pub entities: StringEntities,
    /// The runs these events belong to, in concatenation order.
    pub run_info: Vec<super::node::RunInfo>,
}
147
148impl SparseRunVariable {
149 pub fn new(
150 name: &str,
151 source: &str,
152 onset: Vec<f64>,
153 duration: Vec<f64>,
154 amplitude: Vec<String>,
155 index: Vec<StringEntities>,
156 run_info: Vec<super::node::RunInfo>,
157 ) -> Self {
158 let numeric_amp: Vec<f64> = amplitude
159 .iter()
160 .map(|v| v.parse().unwrap_or(f64::NAN))
161 .collect();
162 let mut entities = extract_common_entities(&index);
163 if let Some(first_run) = run_info.first() {
165 for (k, v) in &first_run.entities {
166 if run_info.iter().all(|r| r.entities.get(k) == Some(v)) {
167 entities.entry(k.clone()).or_insert_with(|| v.clone());
168 }
169 }
170 }
171
172 Self {
173 name: name.to_string(),
174 source: source.to_string(),
175 onset,
176 duration,
177 amplitude: numeric_amp,
178 str_amplitude: amplitude,
179 index,
180 entities,
181 run_info,
182 }
183 }
184
185 pub fn len(&self) -> usize {
186 self.onset.len()
187 }
188 pub fn is_empty(&self) -> bool {
189 self.onset.is_empty()
190 }
191
192 pub fn get_duration(&self) -> f64 {
194 self.run_info.iter().map(|r| r.duration).sum()
195 }
196
197 pub fn to_dense(&self, sampling_rate: Option<f64>) -> DenseRunVariable {
199 let onsets_ms: Vec<i64> = self
200 .onset
201 .iter()
202 .map(|o| (o * 1000.0).round() as i64)
203 .collect();
204 let durations_ms: Vec<i64> = self
205 .duration
206 .iter()
207 .map(|d| (d * 1000.0).round() as i64)
208 .collect();
209
210 let all_vals: Vec<i64> = onsets_ms
211 .iter()
212 .chain(durations_ms.iter())
213 .copied()
214 .filter(|&v| v > 0)
215 .collect();
216 let gcd_val = all_vals
217 .iter()
218 .copied()
219 .reduce(gcd_pair)
220 .unwrap_or(1)
221 .max(1);
222
223 let bin_sr = 1000.0 / gcd_val as f64;
224 let sr = sampling_rate.map_or(bin_sr, |s| s.max(bin_sr));
225 let total_duration = self.get_duration();
226 let n_samples = (total_duration * sr).ceil() as usize;
227 let mut ts = vec![0.0f64; n_samples];
228
229 let mut run_offset = 0.0;
230 let mut last_onset = -1.0f64;
231 let mut run_i = 0;
232
233 for i in 0..self.onset.len() {
234 if self.onset[i] < last_onset && run_i + 1 < self.run_info.len() {
235 run_offset += self.run_info[run_i].duration;
236 run_i += 1;
237 }
238 let onset_sample = ((run_offset + self.onset[i]) * sr).round() as usize;
239 let dur_samples = (self.duration[i] * sr).round() as usize;
240 let offset_sample = (onset_sample + dur_samples).min(n_samples);
241 for ts_val in ts.iter_mut().take(offset_sample).skip(onset_sample) {
242 *ts_val = self.amplitude[i];
243 }
244 last_onset = self.onset[i];
245 }
246
247 let final_sr = sampling_rate.unwrap_or(sr);
248 if (final_sr - sr).abs() > 0.001 {
249 let new_n = (total_duration * final_sr).ceil() as usize;
250 ts = linear_resample(&ts, new_n);
251 }
252
253 DenseRunVariable::new(
254 &self.name,
255 &self.source,
256 ts,
257 final_sr,
258 self.run_info.clone(),
259 )
260 }
261
262 pub fn filter(&self, filters: &StringEntities) -> Self {
264 let mut onset = Vec::new();
265 let mut duration = Vec::new();
266 let mut amplitude = Vec::new();
267 let mut index = Vec::new();
268
269 for (i, row_ents) in self.index.iter().enumerate() {
270 if filters
271 .iter()
272 .all(|(k, v)| row_ents.get(k).is_none_or(|rv| rv == v))
273 {
274 onset.push(self.onset[i]);
275 duration.push(self.duration[i]);
276 amplitude.push(self.str_amplitude[i].clone());
277 index.push(row_ents.clone());
278 }
279 }
280
281 Self::new(
282 &self.name,
283 &self.source,
284 onset,
285 duration,
286 amplitude,
287 index,
288 self.run_info.clone(),
289 )
290 }
291
292 pub fn to_rows(&self) -> Vec<StringEntities> {
294 (0..self.onset.len())
295 .map(|i| {
296 let mut row = self.index.get(i).cloned().unwrap_or_default();
297 row.insert("onset".into(), self.onset[i].to_string());
298 row.insert("duration".into(), self.duration[i].to_string());
299 row.insert("amplitude".into(), self.str_amplitude[i].clone());
300 row.insert("condition".into(), self.name.clone());
301 row
302 })
303 .collect()
304 }
305}
306
/// A regularly sampled time series spanning the runs in `run_info`.
#[derive(Debug, Clone)]
pub struct DenseRunVariable {
    /// Variable (condition) name.
    pub name: String,
    /// Label of where the variable came from.
    pub source: String,
    /// One value per sample.
    pub values: Vec<f64>,
    /// Samples per unit time; the sample interval is `1.0 / sampling_rate`.
    pub sampling_rate: f64,
    /// The runs covered by `values`, laid end to end.
    pub run_info: Vec<super::node::RunInfo>,
    /// Entities describing the variable, derived from `run_info` in `DenseRunVariable::new`.
    pub entities: StringEntities,
}
326
327impl DenseRunVariable {
328 pub fn new(
329 name: &str,
330 source: &str,
331 values: Vec<f64>,
332 sampling_rate: f64,
333 run_info: Vec<super::node::RunInfo>,
334 ) -> Self {
335 let mut entities = StringEntities::new();
336 for ri in &run_info {
337 for (k, v) in &ri.entities {
338 entities.entry(k.clone()).or_insert_with(|| v.clone());
339 }
340 }
341 Self {
342 name: name.into(),
343 source: source.into(),
344 values,
345 sampling_rate,
346 run_info,
347 entities,
348 }
349 }
350
351 pub fn len(&self) -> usize {
352 self.values.len()
353 }
354 pub fn is_empty(&self) -> bool {
355 self.values.is_empty()
356 }
357
358 pub fn resample(&self, new_sr: f64) -> Self {
360 if (new_sr - self.sampling_rate).abs() < 0.001 {
361 return self.clone();
362 }
363 let new_n = ((self.values.len() as f64) * new_sr / self.sampling_rate).ceil() as usize;
364 Self {
365 name: self.name.clone(),
366 source: self.source.clone(),
367 values: linear_resample(&self.values, new_n),
368 sampling_rate: new_sr,
369 run_info: self.run_info.clone(),
370 entities: self.entities.clone(),
371 }
372 }
373
374 pub fn resample_to_tr(&self) -> Self {
376 self.run_info
377 .first()
378 .filter(|ri| ri.tr > 0.0)
379 .map(|ri| self.resample(1.0 / ri.tr))
380 .unwrap_or_else(|| self.clone())
381 }
382
383 pub fn to_rows(&self) -> Vec<StringEntities> {
385 let interval = 1.0 / self.sampling_rate;
386 self.values
387 .iter()
388 .enumerate()
389 .map(|(i, val)| {
390 let mut row = self.entities.clone();
391 row.insert("onset".into(), (i as f64 * interval).to_string());
392 row.insert("duration".into(), interval.to_string());
393 row.insert("amplitude".into(), val.to_string());
394 row.insert("condition".into(), self.name.clone());
395 row
396 })
397 .collect()
398 }
399}
400
401impl SparseRunVariable {
402 pub fn select_rows(&self, indices: &[usize]) -> Self {
404 Self::new(
405 &self.name,
406 &self.source,
407 indices
408 .iter()
409 .filter_map(|&i| self.onset.get(i).copied())
410 .collect(),
411 indices
412 .iter()
413 .filter_map(|&i| self.duration.get(i).copied())
414 .collect(),
415 indices
416 .iter()
417 .filter_map(|&i| self.str_amplitude.get(i).cloned())
418 .collect(),
419 indices
420 .iter()
421 .filter_map(|&i| self.index.get(i).cloned())
422 .collect(),
423 self.run_info.clone(),
424 )
425 }
426
427 pub fn split(&self, group_col: &str) -> Vec<Self> {
429 let mut groups: std::collections::HashMap<String, Vec<usize>> =
430 std::collections::HashMap::new();
431 for (i, row) in self.index.iter().enumerate() {
432 let key = row.get(group_col).cloned().unwrap_or_default();
433 groups.entry(key).or_default().push(i);
434 }
435 groups
436 .into_iter()
437 .map(|(key, indices)| {
438 let mut var = self.select_rows(&indices);
439 var.name = format!("{}.{}", self.name, key);
440 var
441 })
442 .collect()
443 }
444}
445
446impl DenseRunVariable {
447 pub fn build_entity_index(&self) -> Vec<(f64, StringEntities)> {
449 let interval = 1.0 / self.sampling_rate;
450 let mut result = Vec::with_capacity(self.values.len());
451 let mut offset = 0.0;
452 let mut run_i = 0;
453 for (i, _) in self.values.iter().enumerate() {
454 let t = i as f64 * interval;
455 while run_i + 1 < self.run_info.len() && t >= offset + self.run_info[run_i].duration {
457 offset += self.run_info[run_i].duration;
458 run_i += 1;
459 }
460 let ents = self
461 .run_info
462 .get(run_i)
463 .map(|ri| ri.entities.clone())
464 .unwrap_or_default();
465 result.push((t, ents));
466 }
467 result
468 }
469}
470
471pub fn get_grouper(index: &[StringEntities], group_by: &[&str]) -> Vec<String> {
473 index
474 .iter()
475 .map(|row| {
476 group_by
477 .iter()
478 .map(|k| row.get(*k).cloned().unwrap_or_default())
479 .collect::<Vec<_>>()
480 .join("@@@")
481 })
482 .collect()
483}
484
/// Applies `func` separately to each group of `values` and scatters the results
/// back to the original positions.
///
/// `grouper[i]` is the group key of `values[i]`. Each group is passed to `func` in
/// input order, and its i-th output is written to the group's i-th position.
/// Positions not covered by `grouper`, or for which `func` returns too few values,
/// are left at `0.0`.
pub fn apply_grouped<F>(values: &[f64], grouper: &[String], func: F) -> Vec<f64>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    let mut positions: std::collections::HashMap<&str, Vec<usize>> =
        std::collections::HashMap::new();
    for (pos, key) in grouper.iter().enumerate().take(values.len()) {
        positions.entry(key.as_str()).or_default().push(pos);
    }

    let mut out = vec![0.0; values.len()];
    for members in positions.values() {
        let group_vals: Vec<f64> = members.iter().map(|&p| values[p]).collect();
        for (&p, transformed) in members.iter().zip(func(&group_vals)) {
            out[p] = transformed;
        }
    }
    out
}
505
506pub fn merge_simple(variables: &[&SimpleVariable]) -> Option<SimpleVariable> {
510 let first = variables.first()?;
511 let mut all_values = Vec::new();
512 let mut all_index = Vec::new();
513 for v in variables {
514 all_values.extend(v.str_values.iter().cloned());
515 all_index.extend(v.index.iter().cloned());
516 }
517 Some(SimpleVariable::new(
518 &first.name,
519 &first.source,
520 all_values,
521 all_index,
522 ))
523}
524
525pub fn merge_sparse(variables: &[&SparseRunVariable]) -> Option<SparseRunVariable> {
527 let first = variables.first()?;
528 let mut onset = Vec::new();
529 let mut duration = Vec::new();
530 let mut amplitude = Vec::new();
531 let mut index = Vec::new();
532 let mut run_info = Vec::new();
533 for v in variables {
534 onset.extend(&v.onset);
535 duration.extend(&v.duration);
536 amplitude.extend(v.str_amplitude.iter().cloned());
537 index.extend(v.index.iter().cloned());
538 run_info.extend(v.run_info.iter().cloned());
539 }
540 Some(SparseRunVariable::new(
541 &first.name,
542 &first.source,
543 onset,
544 duration,
545 amplitude,
546 index,
547 run_info,
548 ))
549}
550
551fn extract_common_entities(index: &[StringEntities]) -> StringEntities {
554 let mut common = StringEntities::new();
555 if let Some(first) = index.first() {
556 for (k, v) in first {
557 if index.iter().all(|row| row.get(k) == Some(v)) {
558 common.insert(k.clone(), v.clone());
559 }
560 }
561 }
562 common
563}
564
/// Greatest common divisor of `|a|` and `|b|` (Euclid's algorithm). `gcd(x, 0)`
/// is `|x|`.
fn gcd_pair(a: i64, b: i64) -> i64 {
    match (a.abs(), b.abs()) {
        (x, 0) => x,
        (x, y) => gcd_pair(y, x % y),
    }
}
574
/// Linearly interpolates `values` onto `new_n` evenly spaced points.
///
/// The first and last output samples coincide with the first and last input
/// samples. Empty input or `new_n == 0` gives an empty vector, an unchanged
/// length returns a copy, and `new_n == 1` returns the first value.
fn linear_resample(values: &[f64], new_n: usize) -> Vec<f64> {
    let old_n = values.len();
    if old_n == 0 || new_n == 0 {
        return Vec::new();
    }
    if old_n == new_n {
        return values.to_vec();
    }

    let mut out = Vec::with_capacity(new_n);
    for i in 0..new_n {
        // Fractional position of output sample `i` on the input grid.
        let pos = if new_n > 1 {
            (i as f64) * (old_n as f64 - 1.0) / (new_n as f64 - 1.0)
        } else {
            0.0
        };
        let left = pos.floor() as usize;
        let right = (left + 1).min(old_n - 1);
        let weight = pos - left as f64;
        out.push(values[left] * (1.0 - weight) + values[right] * weight);
    }
    out
}
597
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::RunInfo;
    use std::collections::HashMap;

    #[test]
    fn test_sparse_to_dense() {
        let ri = RunInfo {
            entities: StringEntities::new(),
            duration: 10.0,
            tr: 2.0,
            image: None,
            n_vols: 5,
        };
        let sparse = SparseRunVariable::new(
            "trial_type",
            "events",
            vec![1.0, 3.0],
            vec![1.0, 2.0],
            vec!["1".into(), "1".into()],
            vec![StringEntities::new(), StringEntities::new()],
            vec![ri],
        );
        let dense = sparse.to_dense(Some(10.0));
        assert_eq!(dense.sampling_rate, 10.0);
        assert_eq!(dense.values.len(), 100);
        assert_eq!(dense.values[10], 1.0);
        assert_eq!(dense.values[0], 0.0);
    }

    /// Without a requested rate, events are rasterized at the native bin rate
    /// (here 1 Hz, since every onset/duration is a whole number of seconds).
    #[test]
    fn test_sparse_to_dense_native_rate() {
        let ri = RunInfo {
            entities: StringEntities::new(),
            duration: 10.0,
            tr: 2.0,
            image: None,
            n_vols: 5,
        };
        let sparse = SparseRunVariable::new(
            "trial_type",
            "events",
            vec![1.0, 3.0],
            vec![1.0, 2.0],
            vec!["1".into(), "2".into()],
            vec![StringEntities::new(), StringEntities::new()],
            vec![ri],
        );
        let dense = sparse.to_dense(None);
        assert_eq!(dense.sampling_rate, 1.0);
        assert_eq!(
            dense.values,
            vec![0.0, 1.0, 0.0, 2.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        );
    }

    #[test]
    fn test_simple_filter() {
        let idx = vec![
            HashMap::from([("subject".into(), "01".into())]),
            HashMap::from([("subject".into(), "02".into())]),
        ];
        let var = SimpleVariable::new("age", "participants", vec!["25".into(), "30".into()], idx);
        let filtered = var.filter(&HashMap::from([("subject".into(), "01".into())]));
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered.str_values[0], "25");
    }

    /// Rows that do not carry a filter key are kept; only explicit mismatches drop.
    #[test]
    fn test_simple_filter_keeps_rows_without_key() {
        let idx = vec![
            HashMap::from([("subject".into(), "01".into())]),
            StringEntities::new(),
        ];
        let var = SimpleVariable::new("age", "participants", vec!["25".into(), "30".into()], idx);
        let filtered = var.filter(&HashMap::from([("subject".into(), "02".into())]));
        assert_eq!(filtered.str_values, vec!["30".to_string()]);
    }

    #[test]
    fn test_merge_simple() {
        let v1 = SimpleVariable::new(
            "age",
            "participants",
            vec!["25".into()],
            vec![HashMap::from([("subject".into(), "01".into())])],
        );
        let v2 = SimpleVariable::new(
            "age",
            "participants",
            vec!["30".into()],
            vec![HashMap::from([("subject".into(), "02".into())])],
        );
        let merged = merge_simple(&[&v1, &v2]).unwrap();
        assert_eq!(merged.len(), 2);
    }

    #[test]
    fn test_gcd_pair() {
        assert_eq!(gcd_pair(12, 18), 6);
        assert_eq!(gcd_pair(0, 5), 5);
        assert_eq!(gcd_pair(-4, 6), 2);
    }

    #[test]
    fn test_linear_resample() {
        assert_eq!(
            linear_resample(&[0.0, 2.0, 4.0], 5),
            vec![0.0, 1.0, 2.0, 3.0, 4.0]
        );
        assert_eq!(linear_resample(&[1.0, 2.0], 2), vec![1.0, 2.0]);
        assert!(linear_resample(&[], 3).is_empty());
        assert!(linear_resample(&[1.0], 0).is_empty());
    }

    #[test]
    fn test_apply_grouped_demean() {
        let grouper: Vec<String> = vec!["a".into(), "b".into(), "a".into(), "b".into()];
        let out = apply_grouped(&[1.0, 10.0, 3.0, 30.0], &grouper, |g| {
            let mean = g.iter().sum::<f64>() / g.len() as f64;
            g.iter().map(|v| v - mean).collect()
        });
        assert_eq!(out, vec![-1.0, -10.0, 1.0, 10.0]);
    }

    #[test]
    fn test_get_grouper() {
        let idx: Vec<StringEntities> = vec![
            HashMap::from([("subject".into(), "01".into()), ("run".into(), "1".into())]),
            HashMap::from([("subject".into(), "02".into())]),
        ];
        assert_eq!(
            get_grouper(&idx, &["subject", "run"]),
            vec!["01@@@1", "02@@@"]
        );
    }

    #[test]
    fn test_extract_common_entities() {
        let idx: Vec<StringEntities> = vec![
            HashMap::from([("subject".into(), "01".into()), ("run".into(), "1".into())]),
            HashMap::from([("subject".into(), "01".into()), ("run".into(), "2".into())]),
        ];
        let common = extract_common_entities(&idx);
        assert_eq!(common.len(), 1);
        assert_eq!(common.get("subject").map(String::as_str), Some("01"));
        assert!(extract_common_entities(&[]).is_empty());
    }
}