1use indexmap::IndexMap;
14use regex::Regex;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::sync::OnceLock;
18
19use crate::padded_int::PaddedInt;
20
/// A named path entity (e.g. `subject`, `run`) whose value is extracted from file
/// paths by a regex.
///
/// The value comes from capture group 1 of `pattern` (see `match_path`) and is
/// converted according to `dtype` (see `coerce_value`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
    /// Entity name; used as the key when results are collected into [`Entities`].
    pub name: String,
    /// Regex source; capture group 1 holds the entity's value.
    pub pattern: String,
    /// Presumably marks entities every file must carry — TODO confirm; in this file
    /// it is only read by the `Display` impl.
    #[serde(default)]
    pub mandatory: bool,
    /// Optional directory associated with the entity; not interpreted in this file.
    #[serde(default)]
    pub directory: Option<String>,
    /// Value type name: `"int"`, `"float"`, `"bool"`; anything else is treated as a
    /// string. Defaults to `"str"` when absent from serialized input.
    #[serde(default = "default_dtype")]
    pub dtype: String,

    // Compiled form of `pattern`. Skipped by serde, so a deserialized entity starts
    // with an empty cell that `regex()` fills on first use. Holds `None` when the
    // pattern failed to compile, so an invalid pattern is not recompiled on every call.
    #[serde(skip)]
    compiled_regex: OnceLock<Option<Regex>>,
}
50
/// Serde default for `Entity::dtype`: an entity without an explicit dtype holds strings.
fn default_dtype() -> String {
    String::from("str")
}
54
55impl std::fmt::Display for Entity {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 write!(f, "Entity('{}', dtype={})", self.name, self.dtype)?;
58 if self.mandatory {
59 write!(f, " [mandatory]")?;
60 }
61 Ok(())
62 }
63}
64
65impl Entity {
66 pub fn new(name: &str, pattern: &str) -> Self {
72 let lock = OnceLock::new();
73 let compiled = Regex::new(pattern).ok();
74 #[cfg(debug_assertions)]
75 if compiled.is_none() {
76 log::warn!("invalid regex pattern for entity '{name}': {pattern}");
77 }
78 let _ = lock.set(compiled);
79 Self {
80 name: name.to_string(),
81 pattern: pattern.to_string(),
82 mandatory: false,
83 directory: None,
84 dtype: "str".to_string(),
85 compiled_regex: lock,
86 }
87 }
88
89 pub fn try_new(name: &str, pattern: &str) -> Result<Self, regex::Error> {
95 let compiled = Regex::new(pattern)?;
96 let lock = OnceLock::new();
97 let _ = lock.set(Some(compiled));
98 Ok(Self {
99 name: name.to_string(),
100 pattern: pattern.to_string(),
101 mandatory: false,
102 directory: None,
103 dtype: "str".to_string(),
104 compiled_regex: lock,
105 })
106 }
107
108 #[must_use]
110 pub fn with_dtype(mut self, dtype: &str) -> Self {
111 self.dtype = dtype.to_string();
112 self
113 }
114
115 #[must_use]
117 pub fn with_directory(mut self, directory: &str) -> Self {
118 self.directory = Some(directory.to_string());
119 self
120 }
121
122 #[must_use]
124 pub fn with_mandatory(mut self, mandatory: bool) -> Self {
125 self.mandatory = mandatory;
126 self
127 }
128
129 pub fn regex(&self) -> Option<&Regex> {
133 self.compiled_regex
134 .get_or_init(|| Regex::new(&self.pattern).ok())
135 .as_ref()
136 }
137
138 pub fn match_path(&self, path: &str) -> Option<EntityValue> {
143 let regex = self.regex()?;
144 let caps = regex.captures(path)?;
145 let val_str = caps.get(1)?.as_str();
146 Some(self.coerce_value(val_str))
147 }
148
149 pub fn coerce_value(&self, val: &str) -> EntityValue {
151 match self.dtype.as_str() {
152 "int" => EntityValue::Int(PaddedInt::new(val)),
153 "float" => EntityValue::Float(val.parse().unwrap_or(0.0)),
154 "bool" => EntityValue::Bool(val.parse().unwrap_or(false)),
155 _ => EntityValue::Str(val.to_string()),
156 }
157 }
158}
159
/// A typed entity value, produced by [`Entity::coerce_value`] or the `From`
/// conversions.
///
/// Equality and hashing are defined on the string form returned by
/// `as_str_lossy`, not on the variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
// Untagged: when deserializing, serde tries the variants below in declaration order,
// so the order is behavior — e.g. any JSON string becomes `Str` before `Int` is tried.
#[serde(untagged)]
pub enum EntityValue {
    /// Plain string value (the result for dtype `"str"` and unknown dtypes).
    Str(String),
    /// Integer that keeps its textual zero padding (`"02"` displays as `"02"`).
    Int(PaddedInt),
    /// Floating-point value (dtype `"float"`).
    Float(f64),
    /// Boolean value (dtype `"bool"`).
    Bool(bool),
    /// Arbitrary JSON value; never produced by `coerce_value`.
    Json(serde_json::Value),
}
182
183impl From<&str> for EntityValue {
184 fn from(s: &str) -> Self {
185 EntityValue::Str(s.to_string())
186 }
187}
188
189impl From<String> for EntityValue {
190 fn from(s: String) -> Self {
191 EntityValue::Str(s)
192 }
193}
194
195impl From<i32> for EntityValue {
196 fn from(v: i32) -> Self {
197 EntityValue::Int(PaddedInt::from(v))
198 }
199}
200
201impl From<i64> for EntityValue {
202 fn from(v: i64) -> Self {
203 EntityValue::Int(PaddedInt::from(v))
204 }
205}
206
207impl From<f64> for EntityValue {
208 fn from(v: f64) -> Self {
209 EntityValue::Float(v)
210 }
211}
212
213impl From<bool> for EntityValue {
214 fn from(v: bool) -> Self {
215 EntityValue::Bool(v)
216 }
217}
218
219impl EntityValue {
220 #[must_use]
225 pub fn as_str_lossy(&self) -> std::borrow::Cow<'_, str> {
226 match self {
227 EntityValue::Str(s) => std::borrow::Cow::Borrowed(s),
228 EntityValue::Int(i) => std::borrow::Cow::Owned(i.to_string()),
229 EntityValue::Float(f) => std::borrow::Cow::Owned(f.to_string()),
230 EntityValue::Bool(b) => std::borrow::Cow::Owned(b.to_string()),
231 EntityValue::Json(v) => std::borrow::Cow::Owned(v.to_string()),
232 }
233 }
234
235 #[must_use]
239 pub fn as_i64(&self) -> Option<i64> {
240 match self {
241 EntityValue::Int(p) => Some(p.value()),
242 EntityValue::Float(f) => Some(*f as i64),
243 EntityValue::Str(s) => s.parse().ok(),
244 _ => None,
245 }
246 }
247
248 #[must_use]
250 pub fn as_f64(&self) -> Option<f64> {
251 match self {
252 EntityValue::Float(f) => Some(*f),
253 EntityValue::Int(p) => Some(p.value() as f64),
254 EntityValue::Str(s) => s.parse().ok(),
255 _ => None,
256 }
257 }
258
259 #[must_use]
261 pub fn as_bool(&self) -> Option<bool> {
262 match self {
263 EntityValue::Bool(b) => Some(*b),
264 EntityValue::Str(s) => s.parse().ok(),
265 _ => None,
266 }
267 }
268
269 #[must_use]
271 pub fn is_str(&self) -> bool {
272 matches!(self, EntityValue::Str(_))
273 }
274
275 #[must_use]
277 pub fn is_int(&self) -> bool {
278 matches!(self, EntityValue::Int(_))
279 }
280}
281
282impl std::fmt::Display for EntityValue {
283 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284 write!(f, "{}", self.as_str_lossy())
285 }
286}
287
288impl PartialEq for EntityValue {
289 fn eq(&self, other: &Self) -> bool {
290 *self.as_str_lossy() == *other.as_str_lossy()
295 }
296}
297
298impl Eq for EntityValue {}
299
300impl std::hash::Hash for EntityValue {
301 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
302 self.as_str_lossy().hash(state);
303 }
304}
305
/// Entity name → typed value, iterated in insertion order (the order entities were
/// matched by [`parse_file_entities`]).
pub type Entities = IndexMap<String, EntityValue>;

/// Entity name → plain string value; unordered.
pub type StringEntities = HashMap<String, String>;
315
/// Canonical entity ordering used by [`sort_entities`]; keys not listed here sort
/// after all listed ones.
///
/// The order is behavior: reordering entries changes `sort_entities` output.
/// NOTE(review): appears to follow the BIDS filename entity order — confirm against
/// the spec before editing.
pub const ENTITY_ORDER: &[&str] = &[
    "subject",
    "session",
    "sample",
    "task",
    "tracksys",
    "acquisition",
    "ceagent",
    "staining",
    "tracer",
    "reconstruction",
    "direction",
    "run",
    "modality",
    "echo",
    "flip",
    "inversion",
    "mtransfer",
    "part",
    "processing",
    "hemisphere",
    "space",
    "split",
    "recording",
    "chunk",
    "atlas",
    "resolution",
    "density",
    "label",
    "description",
    "suffix",
    "extension",
    "datatype",
];
351
352#[must_use]
357pub fn parse_file_entities(path: &str, entities: &[Entity]) -> Entities {
358 let mut result = Entities::new();
359 for entity in entities.iter() {
360 if let Some(val) = entity.match_path(path) {
361 result.insert(entity.name.clone(), val);
362 }
363 }
364 result
365}
366
367#[must_use]
369pub fn sort_entities(entities: &Entities) -> Vec<(String, EntityValue)> {
370 let mut pairs: Vec<_> = entities
371 .iter()
372 .map(|(k, v)| (k.clone(), v.clone()))
373 .collect();
374
375 pairs.sort_by_key(|(k, _)| {
376 ENTITY_ORDER
377 .iter()
378 .position(|&e| e == k.as_str())
379 .unwrap_or(ENTITY_ORDER.len())
380 });
381
382 pairs
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_entity_matching() {
        let ent = Entity::new("subject", r"[/\\]+sub-([a-zA-Z0-9]+)");
        let val = ent.match_path("/sub-01/anat/sub-01_T1w.nii.gz");
        assert!(val.is_some());
        assert_eq!(val.unwrap().as_str_lossy(), "01");
    }

    #[test]
    fn test_int_entity() {
        let ent = Entity::new("run", r"[_/\\]+run-(\d+)").with_dtype("int");
        let val = ent.match_path("sub-01_task-rest_run-02_bold.nii.gz");
        match val {
            Some(EntityValue::Int(p)) => {
                assert_eq!(p.value(), 2);
                // Zero padding from the path is preserved.
                assert_eq!(p.to_string(), "02");
            }
            other => panic!("Expected Int, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_file_entities() {
        let entities = vec![
            Entity::new("subject", r"[/\\]+sub-([a-zA-Z0-9]+)"),
            Entity::new("session", r"[_/\\]+ses-([a-zA-Z0-9]+)"),
            Entity::new("task", r"[_/\\]+task-([a-zA-Z0-9]+)"),
            Entity::new("suffix", r"[_/\\]([a-zA-Z0-9]+)\.[^/\\]+$"),
            Entity::new("run", r"[_/\\]+run-(\d+)"),
        ];
        let result = parse_file_entities(
            "/sub-01/ses-02/eeg/sub-01_ses-02_task-rest_eeg.edf",
            &entities,
        );
        assert_eq!(result.get("subject").unwrap().as_str_lossy(), "01");
        assert_eq!(result.get("session").unwrap().as_str_lossy(), "02");
        assert_eq!(result.get("task").unwrap().as_str_lossy(), "rest");
        assert_eq!(result.get("suffix").unwrap().as_str_lossy(), "eeg");
        // Entities that do not match are omitted rather than stored as empty values.
        assert!(result.get("run").is_none());
    }

    #[test]
    fn test_try_new_rejects_invalid_pattern() {
        assert!(Entity::try_new("broken", "(").is_err());
        // `new` still builds the entity, but it can never match.
        let ent = Entity::new("broken", "(");
        assert!(ent.regex().is_none());
        assert!(ent.match_path("sub-01_T1w.nii.gz").is_none());
    }

    #[test]
    fn test_deserialized_entity_compiles_lazily() {
        let ent: Entity = serde_json::from_str(
            r#"{"name": "run", "pattern": "[_/\\\\]+run-(\\d+)", "dtype": "int"}"#,
        )
        .expect("valid entity JSON");
        assert!(!ent.mandatory);
        assert!(ent.directory.is_none());
        match ent.match_path("sub-01_run-03_bold.nii.gz") {
            Some(EntityValue::Int(p)) => assert_eq!(p.value(), 3),
            other => panic!("Expected Int, got {other:?}"),
        }
    }

    #[test]
    fn test_deserialized_entity_defaults_dtype_to_str() {
        let ent: Entity =
            serde_json::from_str(r#"{"name": "subject", "pattern": "sub-([a-zA-Z0-9]+)"}"#)
                .expect("valid entity JSON");
        assert_eq!(ent.dtype, "str");
        assert_eq!(
            ent.match_path("sub-01_T1w.nii").unwrap().as_str_lossy(),
            "01"
        );
    }

    #[test]
    fn test_entity_display() {
        let ent = Entity::new("subject", "sub-(.+)");
        assert_eq!(ent.to_string(), "Entity('subject', dtype=str)");
        let ent = ent.with_dtype("int").with_mandatory(true);
        assert_eq!(ent.to_string(), "Entity('subject', dtype=int) [mandatory]");
    }

    #[test]
    fn test_entity_value_equality_uses_string_form() {
        assert_eq!(EntityValue::from(2.0), EntityValue::from("2"));
        assert_ne!(EntityValue::from(true), EntityValue::from("false"));
    }

    #[test]
    fn test_sort_entities_uses_canonical_order() {
        let mut entities = Entities::new();
        entities.insert("suffix".to_string(), EntityValue::from("bold"));
        entities.insert("custom".to_string(), EntityValue::from("x"));
        entities.insert("subject".to_string(), EntityValue::from("01"));
        entities.insert("another".to_string(), EntityValue::from("y"));
        entities.insert("run".to_string(), EntityValue::from("1"));

        let keys: Vec<String> = sort_entities(&entities)
            .into_iter()
            .map(|(key, _)| key)
            .collect();
        // Unknown keys come last, in their original insertion order.
        assert_eq!(keys, ["subject", "run", "suffix", "custom", "another"]);
    }
}