Skip to main content

edf_rs/headers/
annotation_list.rs

1use std::str::FromStr;
2
3use crate::error::edf_error::EDFError;
4
// If there are multiple annotation signals, only the first one is required to contain time-keeping TALs (TALTKs), and it is the only one used as the time reference. Other annotation signals may contain them too, but they are ignored / treated as empty free text.
// If an annotation starts in, e.g., data record (DR) 12 and its duration extends to DR 16, it appears as an annotation only in DR 12, not in DR 13, DR 14, etc.
7
8#[derive(Debug, Default, Clone, PartialEq)]
9pub struct AnnotationList {
10    pub onset: f64, // relative to file_start time
11    pub duration: f64,
12    pub annotations: Vec<String>,
13}
14
15impl AnnotationList {
16    pub fn new(onset: f64, duration: f64, annotations: Vec<String>) -> Result<Self, EDFError> {
17        if !annotations.iter().all(is_valid_string) {
18            return Err(EDFError::IllegalCharacters);
19        }
20
21        Ok(Self {
22            onset,
23            duration,
24            annotations,
25        })
26    }
27
28    pub fn new_time_keeping(onset: f64) -> Self {
29        Self {
30            onset,
31            duration: 0.0,
32            annotations: vec![String::new()],
33        }
34    }
35
36    pub fn new_time_keeping_reasoned(onset: f64, reason: String) -> Self {
37        Self {
38            onset,
39            duration: 0.0,
40            annotations: vec![String::new(), reason],
41        }
42    }
43
44    pub fn add_annotation(&mut self, annotation: String) -> Result<(), EDFError> {
45        self.insert_annotation(self.annotations.len(), annotation)
46    }
47
48    pub fn insert_annotation(&mut self, index: usize, annotation: String) -> Result<(), EDFError> {
49        if !is_valid_string(&annotation) {
50            return Err(EDFError::IllegalCharacters);
51        }
52
53        self.annotations.insert(index, annotation);
54
55        Ok(())
56    }
57
58    pub fn remove_annotation(&mut self, index: usize) {
59        self.annotations.remove(index);
60    }
61
62    pub fn get_annotations(&self) -> &Vec<String> {
63        &self.annotations
64    }
65
66    pub fn deserialize(data: &[u8]) -> Result<Self, EDFError> {
67        // Trim padding NUL bytes from end and remove the last byte HEX 14 which follows the last annotation value.
68        // Therefore splitting at byte HEX 14 later returns the correct amount of annotations.
69        if !data.ends_with(&[b'\x14', b'\x00']) {
70            return Err(EDFError::InvalidHeaderTAL);
71        }
72        let data = &data[..data.len() - 2];
73
74        // Split the TAL header (ASCII) and annotations (UTF-8)
75        let header: String = data
76            .into_iter()
77            .take_while(|c| **c != b'\x14')
78            .map(|c| *c as char)
79            .collect();
80
81        // Get header values (separated by byte HEX 15)
82        let header_parts = header.split('\x15').collect::<Vec<_>>();
83        if header_parts.is_empty() {
84            return Err(EDFError::InvalidHeaderTAL);
85        }
86
87        // Parse onset and duration from header
88        let onset = f64::from_str(header_parts[0]).map_err(|_| EDFError::InvalidHeaderTAL)?;
89        let duration = header_parts
90            .iter()
91            .nth(1)
92            .map(|d| f64::from_str(*d))
93            .transpose()
94            .map_err(|_| EDFError::InvalidHeaderTAL)?
95            .unwrap_or(0.0);
96
97        // Parse annotations (skip header bytes)
98        let data = &data[header.len() + 1..];
99        let annotations = data
100            .split(|c| *c == b'\x14')
101            .map(|a| String::from_utf8_lossy(a).to_string())
102            .collect::<Vec<_>>();
103
104        Ok(AnnotationList {
105            onset,
106            duration,
107            annotations,
108        })
109    }
110
111    pub fn serialize(&self) -> String {
112        if self.annotations.is_empty() {
113            return String::new();
114        }
115
116        let onset_sign = if self.onset >= 0.0 { "+" } else { "-" };
117        let onset = format!("{}{}", onset_sign, self.onset);
118        let header = if self.duration <= 0.0 {
119            format!("{}\x14", onset)
120        } else {
121            format!("{}\x15{}\x14", onset, self.duration)
122        };
123
124        let annotations = self.annotations.join("\x14");
125
126        format!("{}{}\x14\x00", header, annotations)
127    }
128
129    pub fn is_time_keeping(&self) -> bool {
130        self.annotations
131            .first()
132            .map(String::is_empty)
133            .unwrap_or(false)
134    }
135
136    pub fn time_keeping_reason(&self) -> Option<String> {
137        if !self.is_time_keeping() {
138            return None;
139        }
140
141        self.annotations.iter().nth(1).cloned()
142    }
143}
144
/// Returns `true` if `s` may be stored as a TAL annotation text.
///
/// Every character below U+0020 (space) is rejected except TAB, LF and CR. This keeps
/// the TAL delimiters `\x14`/`\x15` and the `\0` terminator out of annotation text.
fn is_valid_string(s: &String) -> bool {
    s.chars()
        .all(|ch| ch >= ' ' || matches!(ch, '\t' | '\n' | '\r'))
}
149
#[cfg(test)]
mod tests {
    use super::*;

    // Each case from the former single `deserialize` test is now its own test, so a failure
    // names the exact format that broke. The verbatim duplicate of the first case is removed.

    #[test]
    fn deserialize_time_keeping_with_duration() {
        let tal = AnnotationList::deserialize(b"+30\x1520\x14\x14\x00").unwrap();
        assert_eq!(tal.onset, 30.0);
        assert_eq!(tal.duration, 20.0);
        assert!(tal.is_time_keeping());
        assert_eq!(tal.annotations.len(), 1);
    }

    #[test]
    fn deserialize_time_keeping_without_duration() {
        let tal = AnnotationList::deserialize(b"+30\x14\x14\x00").unwrap();
        assert_eq!(tal.onset, 30.0);
        assert_eq!(tal.duration, 0.0);
        assert!(tal.is_time_keeping());
        assert_eq!(tal.annotations.len(), 1);
    }

    #[test]
    fn deserialize_rejects_missing_nul_terminator() {
        assert!(AnnotationList::deserialize(b"+30\x14\x14").is_err());
    }

    #[test]
    fn deserialize_time_keeping_with_reason() {
        let tal =
            AnnotationList::deserialize(b"-0.489\x158.123\x14\x14Some reason\x14\x00").unwrap();
        assert_eq!(tal.onset, -0.489);
        assert_eq!(tal.duration, 8.123);
        assert!(tal.is_time_keeping());
        assert_eq!(tal.annotations.len(), 2);
        assert_eq!(tal.annotations[1], "Some reason".to_string());
        assert_eq!(tal.time_keeping_reason(), Some("Some reason".to_string()));
    }

    #[test]
    fn deserialize_free_text() {
        let tal = AnnotationList::deserialize(b"+0\x14Free text\x14\x00").unwrap();
        assert_eq!(tal.onset, 0.0);
        assert_eq!(tal.duration, 0.0);
        assert!(!tal.is_time_keeping());
        assert_eq!(tal.time_keeping_reason(), None);
        assert_eq!(tal.annotations.len(), 1);
        assert_eq!(tal.annotations[0], "Free text".to_string());
    }

    #[test]
    fn serialize_round_trips_positive_onset() {
        let tal = AnnotationList::new(30.0, 20.0, vec!["Event".to_string()]).unwrap();
        let encoded = tal.serialize();
        assert_eq!(encoded, "+30\x1520\x14Event\x14\x00");
        assert_eq!(AnnotationList::deserialize(encoded.as_bytes()).unwrap(), tal);
    }

    #[test]
    fn rejects_illegal_characters() {
        assert!(AnnotationList::new(0.0, 0.0, vec!["a\x14b".to_string()]).is_err());

        let mut tal = AnnotationList::new_time_keeping(0.0);
        assert!(tal.add_annotation("a\x15b".to_string()).is_err());
        assert_eq!(tal.annotations.len(), 1);
    }
}