1use crate::types::{TextGrid, TextGridError, Tier, TierType, Interval, Point};
44use std::fs::File;
45use std::io::{Read, Write, BufReader, BufWriter};
46use std::path::Path;
47
48pub fn read_binary<P: AsRef<Path>>(path: P) -> Result<TextGrid, TextGridError> {
66 let file = File::open(path)?;
67 let mut reader = BufReader::new(file);
68 let mut buffer = Vec::new();
69 reader.read_to_end(&mut buffer)?;
70
71 let mut cursor = 0;
72 if &buffer[cursor..cursor + 12] != b"ooBinaryFile" {
73 return Err(TextGridError::Format("Not a Praat binary TextGrid".into()));
74 }
75 cursor += 12;
76
77 let obj_len = u16::from_le_bytes(buffer[cursor..cursor + 2].try_into().unwrap()) as usize;
78 cursor += 2;
79 if &buffer[cursor..cursor + obj_len] != b"TextGrid" {
80 return Err(TextGridError::Format("Invalid object class".into()));
81 }
82 cursor += obj_len;
83
84 let xmin = f64::from_le_bytes(buffer[cursor..cursor + 8].try_into().unwrap());
85 cursor += 8;
86 let xmax = f64::from_le_bytes(buffer[cursor..cursor + 8].try_into().unwrap());
87 cursor += 8;
88 let size = u32::from_le_bytes(buffer[cursor..cursor + 4].try_into().unwrap()) as usize;
89 cursor += 4;
90
91 let mut tiers = Vec::with_capacity(size);
92 for _ in 0..size {
93 let class_len = u16::from_le_bytes(buffer[cursor..cursor + 2].try_into().unwrap()) as usize;
94 cursor += 2;
95 let class = String::from_utf8(buffer[cursor..cursor + class_len].to_vec())?;
96 let tier_type = if class == "IntervalTier" {
97 TierType::IntervalTier
98 } else if class == "TextTier" {
99 TierType::PointTier
100 } else {
101 return Err(TextGridError::Format("Unknown tier type".into()));
102 };
103 cursor += class_len;
104
105 let name_len = u16::from_le_bytes(buffer[cursor..cursor + 2].try_into().unwrap()) as usize;
106 cursor += 2;
107 let name = String::from_utf8(buffer[cursor..cursor + name_len].to_vec())?;
108 cursor += name_len;
109
110 let tier_xmin = f64::from_le_bytes(buffer[cursor..cursor + 8].try_into().unwrap());
111 cursor += 8;
112 let tier_xmax = f64::from_le_bytes(buffer[cursor..cursor + 8].try_into().unwrap());
113 cursor += 8;
114
115 let count = u32::from_le_bytes(buffer[cursor..cursor + 4].try_into().unwrap()) as usize;
116 cursor += 4;
117
118 let mut intervals = Vec::new();
119 let mut points = Vec::new();
120 match tier_type {
121 TierType::IntervalTier => {
122 for _ in 0..count {
123 let xmin = f64::from_le_bytes(buffer[cursor..cursor + 8].try_into().unwrap());
124 cursor += 8;
125 let xmax = f64::from_le_bytes(buffer[cursor..cursor + 8].try_into().unwrap());
126 cursor += 8;
127 let text_len = u16::from_le_bytes(buffer[cursor..cursor + 2].try_into().unwrap()) as usize;
128 cursor += 2;
129 let text = String::from_utf8(buffer[cursor..cursor + text_len].to_vec())?;
130 cursor += text_len;
131 intervals.push(Interval { xmin, xmax, text });
132 }
133 }
134 TierType::PointTier => {
135 for _ in 0..count {
136 let time = f64::from_le_bytes(buffer[cursor..cursor + 8].try_into().unwrap());
137 cursor += 8;
138 let mark_len = u16::from_le_bytes(buffer[cursor..cursor + 2].try_into().unwrap()) as usize;
139 cursor += 2;
140 let mark = String::from_utf8(buffer[cursor..cursor + mark_len].to_vec())?;
141 cursor += mark_len;
142 points.push(Point { time, mark });
143 }
144 }
145 }
146
147 tiers.push(Tier { name, tier_type, xmin: tier_xmin, xmax: tier_xmax, intervals, points });
148 }
149
150 Ok(TextGrid::new(xmin, xmax)?.with_tiers(tiers))
151}
152
153pub fn write_binary<P: AsRef<Path>>(textgrid: &TextGrid, path: P) -> Result<(), TextGridError> {
171 let file = File::create(path)?;
172 let mut writer = BufWriter::new(file);
173
174 writer.write_all(b"ooBinaryFile")?;
175 let class = b"TextGrid";
176 writer.write_all(&(class.len() as u16).to_le_bytes())?;
177 writer.write_all(class)?;
178 writer.write_all(&textgrid.xmin.to_le_bytes())?;
179 writer.write_all(&textgrid.xmax.to_le_bytes())?;
180 writer.write_all(&(textgrid.tiers.len() as u32).to_le_bytes())?;
181
182 for tier in &textgrid.tiers {
183 let class = match tier.tier_type {
184 TierType::IntervalTier => b"IntervalTier" as &[u8],
185 TierType::PointTier => b"TextTier" as &[u8],
186 };
187 writer.write_all(&(class.len() as u16).to_le_bytes())?;
188 writer.write_all(class)?;
189
190 let name_bytes = tier.name.as_bytes();
191 writer.write_all(&(name_bytes.len() as u16).to_le_bytes())?;
192 writer.write_all(name_bytes)?;
193
194 writer.write_all(&tier.xmin.to_le_bytes())?;
195 writer.write_all(&tier.xmax.to_le_bytes())?;
196
197 match tier.tier_type {
198 TierType::IntervalTier => {
199 writer.write_all(&(tier.intervals.len() as u32).to_le_bytes())?;
200 for interval in &tier.intervals {
201 writer.write_all(&interval.xmin.to_le_bytes())?;
202 writer.write_all(&interval.xmax.to_le_bytes())?;
203 let text_bytes = interval.text.as_bytes();
204 writer.write_all(&(text_bytes.len() as u16).to_le_bytes())?;
205 writer.write_all(text_bytes)?;
206 }
207 }
208 TierType::PointTier => {
209 writer.write_all(&(tier.points.len() as u32).to_le_bytes())?;
210 for point in &tier.points {
211 writer.write_all(&point.time.to_le_bytes())?;
212 let mark_bytes = point.mark.as_bytes();
213 writer.write_all(&(mark_bytes.len() as u16).to_le_bytes())?;
214 writer.write_all(mark_bytes)?;
215 }
216 }
217 }
218 }
219 writer.flush()?;
220 Ok(())
221}