1use crate::variables::{DenseRunVariable, SimpleVariable, SparseRunVariable, merge_simple};
8use bids_core::entities::StringEntities;
9use std::collections::HashMap;
10
/// Maps a variable `source` name to the level of the collection it belongs to.
///
/// `events`, `physio`, `stim` and `regressors` map to `"run"`, `scans` to `"session"`,
/// `sessions` to `"subject"` and `participants` to `"dataset"`; any other source is
/// returned unchanged.
fn source_to_level(source: &str) -> &str {
    const LEVEL_BY_SOURCE: [(&str, &str); 7] = [
        ("events", "run"),
        ("physio", "run"),
        ("stim", "run"),
        ("regressors", "run"),
        ("scans", "session"),
        ("sessions", "subject"),
        ("participants", "dataset"),
    ];
    LEVEL_BY_SOURCE
        .iter()
        .find(|(src, _)| *src == source)
        .map_or(source, |&(_, level)| level)
}
20
21#[derive(Debug, Default, Clone)]
32pub struct VariableCollection {
33 pub variables: HashMap<String, SimpleVariable>,
34 pub level: String,
35 pub entities: StringEntities,
36}
37
38impl VariableCollection {
39 pub fn new(vars: Vec<SimpleVariable>) -> Self {
40 let level = vars
41 .first()
42 .map(|v| source_to_level(&v.source).to_string())
43 .unwrap_or_default();
44
45 let mut by_name: HashMap<String, Vec<&SimpleVariable>> = HashMap::new();
47 for v in &vars {
48 by_name.entry(v.name.clone()).or_default().push(v);
49 }
50 let mut variables = HashMap::new();
51 for (name, var_list) in &by_name {
52 if let Some(m) = merge_simple(var_list) {
53 variables.insert(name.clone(), m);
54 }
55 }
56
57 let entities = index_common_entities(&variables);
58 Self {
59 variables,
60 level,
61 entities,
62 }
63 }
64
65 pub fn get(&self, name: &str) -> Option<&SimpleVariable> {
66 self.variables.get(name)
67 }
68
69 pub fn names(&self) -> Vec<&str> {
70 self.variables
71 .keys()
72 .map(std::string::String::as_str)
73 .collect()
74 }
75
76 pub fn match_variables(&self, pattern: &str, use_regex: bool) -> Vec<&str> {
78 if use_regex {
79 let re =
80 regex::Regex::new(pattern).unwrap_or_else(|_| regex::Regex::new("$^").unwrap());
81 self.variables
82 .keys()
83 .filter(|k| re.is_match(k))
84 .map(std::string::String::as_str)
85 .collect()
86 } else {
87 self.variables
88 .keys()
89 .filter(|k| glob_match(pattern, k))
90 .map(std::string::String::as_str)
91 .collect()
92 }
93 }
94
95 pub fn to_rows(&self) -> Vec<StringEntities> {
97 self.variables
98 .values()
99 .flat_map(super::variables::SimpleVariable::to_rows)
100 .collect()
101 }
102
103 pub fn to_wide(&self) -> (Vec<String>, Vec<Vec<String>>) {
106 let mut entity_cols: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
108 for var in self.variables.values() {
109 for row in &var.index {
110 entity_cols.extend(row.keys().cloned());
111 }
112 }
113 let var_names: Vec<String> = self.variables.keys().cloned().collect();
114 let mut col_names: Vec<String> = entity_cols.into_iter().collect();
115 col_names.extend(var_names.iter().cloned());
116
117 let n_rows = self
119 .variables
120 .values()
121 .next()
122 .map_or(0, super::variables::SimpleVariable::len);
123 let mut rows = Vec::with_capacity(n_rows);
124 for i in 0..n_rows {
125 let mut row = Vec::with_capacity(col_names.len());
126 let first_var = self.variables.values().next();
128 let ent_row = first_var.and_then(|v| v.index.get(i));
129 for col in &col_names {
130 if self.variables.contains_key(col) {
131 let val = self
133 .variables
134 .get(col)
135 .and_then(|v| v.str_values.get(i))
136 .cloned()
137 .unwrap_or_default();
138 row.push(val);
139 } else {
140 let val = ent_row
142 .and_then(|r| r.get(col))
143 .cloned()
144 .unwrap_or_default();
145 row.push(val);
146 }
147 }
148 rows.push(row);
149 }
150 (col_names, rows)
151 }
152
153 pub fn from_rows(rows: &[StringEntities], source: &str) -> Self {
155 let mut by_name: HashMap<String, (Vec<String>, Vec<StringEntities>)> = HashMap::new();
156 for row in rows {
157 let name = row
158 .get("condition")
159 .cloned()
160 .unwrap_or_else(|| "unknown".into());
161 let amp = row.get("amplitude").cloned().unwrap_or_default();
162 let ents: StringEntities = row
163 .iter()
164 .filter(|(k, _)| k.as_str() != "condition" && k.as_str() != "amplitude")
165 .map(|(k, v)| (k.clone(), v.clone()))
166 .collect();
167 let entry = by_name.entry(name).or_default();
168 entry.0.push(amp);
169 entry.1.push(ents);
170 }
171 let vars: Vec<SimpleVariable> = by_name
172 .into_iter()
173 .map(|(name, (values, index))| SimpleVariable::new(&name, source, values, index))
174 .collect();
175 Self::new(vars)
176 }
177}
178
179#[derive(Debug, Default, Clone)]
181pub struct RunVariableCollection {
182 pub sparse: Vec<SparseRunVariable>,
183 pub dense: Vec<DenseRunVariable>,
184 pub sampling_rate: f64,
185}
186
187impl RunVariableCollection {
188 pub fn new(
189 sparse: Vec<SparseRunVariable>,
190 dense: Vec<DenseRunVariable>,
191 sampling_rate: Option<f64>,
192 ) -> Self {
193 Self {
194 sparse,
195 dense,
196 sampling_rate: sampling_rate.unwrap_or(10.0),
197 }
198 }
199
200 pub fn get_sparse(&self, name: &str) -> Option<&SparseRunVariable> {
201 self.sparse.iter().find(|v| v.name == name)
202 }
203 pub fn get_dense(&self, name: &str) -> Option<&DenseRunVariable> {
204 self.dense.iter().find(|v| v.name == name)
205 }
206
207 pub fn sparse_names(&self) -> Vec<&str> {
208 self.sparse.iter().map(|v| v.name.as_str()).collect()
209 }
210 pub fn dense_names(&self) -> Vec<&str> {
211 self.dense.iter().map(|v| v.name.as_str()).collect()
212 }
213 pub fn all_sparse(&self) -> bool {
214 self.dense.is_empty()
215 }
216 pub fn all_dense(&self) -> bool {
217 self.sparse.is_empty()
218 }
219
220 pub fn names(&self) -> Vec<&str> {
221 let mut n: Vec<&str> = self
222 .sparse_names()
223 .into_iter()
224 .chain(self.dense_names())
225 .collect();
226 n.sort();
227 n.dedup();
228 n
229 }
230
231 pub fn to_dense(&mut self, sampling_rate: Option<f64>) {
233 let sr = sampling_rate.unwrap_or(self.sampling_rate);
234 let sparse = std::mem::take(&mut self.sparse);
235 for var in sparse {
236 if var.amplitude.iter().all(|v| v.is_finite()) {
237 self.dense.push(var.to_dense(Some(sr)));
238 }
239 }
240 self.sampling_rate = sr;
241 }
242
243 pub fn resample(&mut self, sampling_rate: f64) {
245 self.dense = self
246 .dense
247 .iter()
248 .map(|v| v.resample(sampling_rate))
249 .collect();
250 self.sampling_rate = sampling_rate;
251 }
252
253 pub fn densify_and_resample(
255 &mut self,
256 sampling_rate: Option<f64>,
257 force_dense: bool,
258 resample_dense: bool,
259 ) {
260 let sr = sampling_rate.unwrap_or(self.sampling_rate);
261 if force_dense {
262 self.to_dense(Some(sr));
263 }
264 if resample_dense {
265 self.resample(sr);
266 }
267 self.sampling_rate = sr;
268 }
269
270 pub fn to_rows_with_options(
272 &self,
273 include_sparse: bool,
274 include_dense: bool,
275 sampling_rate: Option<f64>,
276 ) -> Vec<StringEntities> {
277 let mut all_rows = Vec::new();
278 if include_sparse {
279 for var in &self.sparse {
280 all_rows.extend(var.to_rows());
281 }
282 }
283 if include_dense {
284 let dense_vars: Vec<_> = if let Some(sr) = sampling_rate {
285 self.dense.iter().map(|v| v.resample(sr)).collect()
286 } else {
287 self.dense.clone()
288 };
289 for var in &dense_vars {
290 all_rows.extend(var.to_rows());
291 }
292 }
293 all_rows
294 }
295
296 pub fn resolve_sampling_rate(&self, requested: Option<&str>) -> f64 {
298 match requested {
299 Some("TR") => {
300 let trs: std::collections::HashSet<i64> = self
301 .dense
302 .iter()
303 .flat_map(|v| v.run_info.iter())
304 .map(|r| (r.tr * 1_000_000.0).round() as i64)
305 .collect();
306 if trs.len() == 1 {
307 1.0 / (trs.into_iter().next().unwrap() as f64 / 1_000_000.0)
308 } else {
309 self.sampling_rate
310 }
311 }
312 Some("highest") => self
313 .dense
314 .iter()
315 .map(|v| v.sampling_rate)
316 .fold(self.sampling_rate, f64::max),
317 Some(s) => s.parse().unwrap_or(self.sampling_rate),
318 None => self.sampling_rate,
319 }
320 }
321}
322
323pub fn merge_collections(collections: &[VariableCollection]) -> Option<VariableCollection> {
325 if collections.is_empty() {
326 return None;
327 }
328 let all_vars: Vec<SimpleVariable> = collections
329 .iter()
330 .flat_map(|c| c.variables.values().cloned())
331 .collect();
332 Some(VariableCollection::new(all_vars))
333}
334
335fn index_common_entities(variables: &HashMap<String, SimpleVariable>) -> StringEntities {
338 let all_ents: Vec<&StringEntities> = variables.values().map(|v| &v.entities).collect();
339 let mut common = StringEntities::new();
340 if let Some(first) = all_ents.first() {
341 for (k, v) in *first {
342 if all_ents.iter().all(|e| e.get(k) == Some(v)) {
343 common.insert(k.clone(), v.clone());
344 }
345 }
346 }
347 common
348}
349
/// Returns `true` when `text` matches the glob `pattern` in full.
///
/// `*` matches any (possibly empty) run of characters and `?` matches exactly one
/// character; every other character, including regex metacharacters such as `.`, `+`,
/// `(` or `[`, matches itself literally. Matching works on Unicode scalar values.
///
/// Implemented as the classic greedy wildcard match with single-star backtracking (linear
/// in practice, O(len(pattern) * len(text)) worst case). It avoids the previous
/// translate-to-regex approach, which escaped only `.` (so `+`, `(`, `|`, … kept their
/// regex meaning) and compiled a fresh regex for every call.
fn glob_match(pattern: &str, text: &str) -> bool {
    let pat: Vec<char> = pattern.chars().collect();
    let txt: Vec<char> = text.chars().collect();
    let (mut pi, mut ti) = (0usize, 0usize);
    // Position just after the most recent `*` in the pattern, and the text index that
    // `*` is currently assumed to extend up to.
    let mut backtrack: Option<(usize, usize)> = None;

    while ti < txt.len() {
        if pi < pat.len() && pat[pi] == '*' {
            backtrack = Some((pi + 1, ti));
            pi += 1;
        } else if pi < pat.len() && (pat[pi] == '?' || pat[pi] == txt[ti]) {
            pi += 1;
            ti += 1;
        } else if let Some((star_next, star_text)) = backtrack {
            // Let the last `*` swallow one more character and retry from there.
            pi = star_next;
            ti = star_text + 1;
            backtrack = Some((star_next, star_text + 1));
        } else {
            return false;
        }
    }
    // Any remaining pattern must consist solely of `*`, which can match the empty string.
    pat[pi..].iter().all(|&c| c == '*')
}
360
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `participants`-sourced variable with no values and no index rows.
    fn empty_participant_var(name: &str) -> SimpleVariable {
        SimpleVariable::new(name, "participants", vec![], vec![])
    }

    #[test]
    fn test_variable_collection() {
        let idx = vec![
            HashMap::from([("subject".into(), "01".into())]),
            HashMap::from([("subject".into(), "02".into())]),
        ];
        let v = SimpleVariable::new("age", "participants", vec!["25".into(), "30".into()], idx);
        let col = VariableCollection::new(vec![v]);
        assert_eq!(col.level, "dataset");
        assert!(col.get("age").is_some());
    }

    #[test]
    fn test_source_to_level() {
        assert_eq!(source_to_level("events"), "run");
        assert_eq!(source_to_level("regressors"), "run");
        assert_eq!(source_to_level("scans"), "session");
        assert_eq!(source_to_level("sessions"), "subject");
        assert_eq!(source_to_level("participants"), "dataset");
        assert_eq!(source_to_level("other"), "other");
    }

    #[test]
    fn test_match_variables() {
        let col = VariableCollection::new(vec![
            empty_participant_var("age"),
            empty_participant_var("sex"),
            empty_participant_var("age_group"),
        ]);
        assert_eq!(col.match_variables("age*", false).len(), 2);
        assert_eq!(col.match_variables("se?", false), vec!["sex"]);
    }

    #[test]
    fn test_match_variables_regex() {
        let col = VariableCollection::new(vec![
            empty_participant_var("age"),
            empty_participant_var("sex"),
            empty_participant_var("age_group"),
        ]);
        assert_eq!(col.match_variables("^age", true).len(), 2);
        // An invalid regular expression matches nothing.
        assert!(col.match_variables("(", true).is_empty());
    }

    #[test]
    fn test_empty_collection() {
        let col = VariableCollection::new(vec![]);
        assert_eq!(col.level, "");
        assert!(col.names().is_empty());
        let (cols, rows) = col.to_wide();
        assert!(cols.is_empty());
        assert!(rows.is_empty());
    }

    #[test]
    fn test_merge_collections() {
        let v1 = SimpleVariable::new(
            "age",
            "participants",
            vec!["25".into()],
            vec![HashMap::from([("subject".into(), "01".into())])],
        );
        let v2 = SimpleVariable::new(
            "age",
            "participants",
            vec!["30".into()],
            vec![HashMap::from([("subject".into(), "02".into())])],
        );
        let c1 = VariableCollection::new(vec![v1]);
        let c2 = VariableCollection::new(vec![v2]);
        let merged = merge_collections(&[c1, c2]).unwrap();
        assert_eq!(merged.get("age").unwrap().len(), 2);
    }

    #[test]
    fn test_merge_collections_empty() {
        assert!(merge_collections(&[]).is_none());
    }

    #[test]
    fn test_run_collection_defaults() {
        let col = RunVariableCollection::new(vec![], vec![], None);
        assert_eq!(col.sampling_rate, 10.0);
        assert!(col.names().is_empty());
        assert!(col.all_sparse());
        assert!(col.all_dense());
        assert_eq!(col.resolve_sampling_rate(None), 10.0);
        assert_eq!(col.resolve_sampling_rate(Some("TR")), 10.0);
        assert_eq!(col.resolve_sampling_rate(Some("highest")), 10.0);
        assert_eq!(col.resolve_sampling_rate(Some("2.5")), 2.5);
        assert_eq!(col.resolve_sampling_rate(Some("abc")), 10.0);
    }
}