textgrid/
binary.rs

//! Binary format support for `.TextGrid` files, modeled on Praat's binary layout.
//!
//! This module reads and writes TextGrids in a compact binary encoding. The encoding follows
//! the overall structure of Praat's binary files (header, object class, time domain, tiers) and
//! supports both IntervalTiers and PointTiers (TextTiers in Praat terminology).
//!
//! ## Binary Format Overview
//! - All numeric values use little-endian byte order.
//! - The file starts with the bytes `"ooBinaryFile"`, followed by the object class `"TextGrid"`
//!   preceded by its length as a `u16`.
//! - Time values are stored as 64-bit floats (`f64`), tier and item counts as `u32`, and class
//!   names, tier names, and labels as UTF-8 bytes preceded by their length as a `u16`.
//! - NOTE(review): Praat's own binary TextGrid writer appears to use big-endian numbers and a
//!   one-byte class-name length, so files written here may not open in Praat (and vice versa);
//!   confirm before relying on interchange with Praat itself.
11//!
12//! ## Usage
13//! ```rust
14//! use textgrid::{TextGrid, Tier, TierType, Interval, write_binary, read_binary};
15//!
16//! fn main() -> Result<(), textgrid::TextGridError> {
17//!     // Create a simple TextGrid
18//!     let mut tg = TextGrid::new(0.0, 10.0)?;
19//!     let tier = Tier {
20//!         name: "words".to_string(),
21//!         tier_type: TierType::IntervalTier,
22//!         xmin: 0.0,
23//!         xmax: 10.0,
24//!         intervals: vec![Interval {
25//!             xmin: 1.0,
26//!             xmax: 2.0,
27//!             text: "hello".to_string(),
28//!         }],
29//!         points: vec![],
30//!     };
31//!     tg.add_tier(tier)?;
32//!
33//!     // Write to a binary file
34//!     write_binary(&tg, "output.TextGrid")?;
35//!
36//!     // Read it back
37//!     let read_tg = read_binary("output.TextGrid")?;
38//!     assert_eq!(read_tg.tiers[0].intervals[0].text, "hello");
39//!     Ok(())
40//! }
41//! ```
42
43use crate::types::{TextGrid, TextGridError, Tier, TierType, Interval, Point};
44use std::fs::File;
45use std::io::{Read, Write, BufReader, BufWriter};
46use std::path::Path;
47
48/// Reads a Praat `.TextGrid` file from the binary format.
49///
50/// # Arguments
51/// * `path` - Path to the binary `.TextGrid` file, implementing `AsRef<Path>`.
52///
53/// # Returns
54/// Returns a `Result` containing the parsed `TextGrid` or a `TextGridError`.
55///
56/// # Errors
57/// - `TextGridError::IO` if the file cannot be opened or read.
58/// - `TextGridError::Format` if the file does not match the Praat binary format (e.g., wrong header, invalid class, or malformed data).
59///
60/// # Examples
61/// ```rust
62/// let tg = textgrid::read_binary("test.TextGrid").unwrap();
63/// assert_eq!(tg.tiers.len(), 1); // Assuming test.TextGrid has one tier
64/// ```
65pub fn read_binary<P: AsRef<Path>>(path: P) -> Result<TextGrid, TextGridError> {
66    let file = File::open(path)?;
67    let mut reader = BufReader::new(file);
68    let mut buffer = Vec::new();
69    reader.read_to_end(&mut buffer)?;
70
71    let mut cursor = 0;
72    if &buffer[cursor..cursor + 12] != b"ooBinaryFile" {
73        return Err(TextGridError::Format("Not a Praat binary TextGrid".into()));
74    }
75    cursor += 12;
76
77    let obj_len = u16::from_le_bytes(buffer[cursor..cursor + 2].try_into().unwrap()) as usize;
78    cursor += 2;
79    if &buffer[cursor..cursor + obj_len] != b"TextGrid" {
80        return Err(TextGridError::Format("Invalid object class".into()));
81    }
82    cursor += obj_len;
83
84    let xmin = f64::from_le_bytes(buffer[cursor..cursor + 8].try_into().unwrap());
85    cursor += 8;
86    let xmax = f64::from_le_bytes(buffer[cursor..cursor + 8].try_into().unwrap());
87    cursor += 8;
88    let size = u32::from_le_bytes(buffer[cursor..cursor + 4].try_into().unwrap()) as usize;
89    cursor += 4;
90
91    let mut tiers = Vec::with_capacity(size);
92    for _ in 0..size {
93        let class_len = u16::from_le_bytes(buffer[cursor..cursor + 2].try_into().unwrap()) as usize;
94        cursor += 2;
95        let class = String::from_utf8(buffer[cursor..cursor + class_len].to_vec())?;
96        let tier_type = if class == "IntervalTier" {
97            TierType::IntervalTier
98        } else if class == "TextTier" {
99            TierType::PointTier
100        } else {
101            return Err(TextGridError::Format("Unknown tier type".into()));
102        };
103        cursor += class_len;
104
105        let name_len = u16::from_le_bytes(buffer[cursor..cursor + 2].try_into().unwrap()) as usize;
106        cursor += 2;
107        let name = String::from_utf8(buffer[cursor..cursor + name_len].to_vec())?;
108        cursor += name_len;
109
110        let tier_xmin = f64::from_le_bytes(buffer[cursor..cursor + 8].try_into().unwrap());
111        cursor += 8;
112        let tier_xmax = f64::from_le_bytes(buffer[cursor..cursor + 8].try_into().unwrap());
113        cursor += 8;
114
115        let count = u32::from_le_bytes(buffer[cursor..cursor + 4].try_into().unwrap()) as usize;
116        cursor += 4;
117
118        let mut intervals = Vec::new();
119        let mut points = Vec::new();
120        match tier_type {
121            TierType::IntervalTier => {
122                for _ in 0..count {
123                    let xmin = f64::from_le_bytes(buffer[cursor..cursor + 8].try_into().unwrap());
124                    cursor += 8;
125                    let xmax = f64::from_le_bytes(buffer[cursor..cursor + 8].try_into().unwrap());
126                    cursor += 8;
127                    let text_len = u16::from_le_bytes(buffer[cursor..cursor + 2].try_into().unwrap()) as usize;
128                    cursor += 2;
129                    let text = String::from_utf8(buffer[cursor..cursor + text_len].to_vec())?;
130                    cursor += text_len;
131                    intervals.push(Interval { xmin, xmax, text });
132                }
133            }
134            TierType::PointTier => {
135                for _ in 0..count {
136                    let time = f64::from_le_bytes(buffer[cursor..cursor + 8].try_into().unwrap());
137                    cursor += 8;
138                    let mark_len = u16::from_le_bytes(buffer[cursor..cursor + 2].try_into().unwrap()) as usize;
139                    cursor += 2;
140                    let mark = String::from_utf8(buffer[cursor..cursor + mark_len].to_vec())?;
141                    cursor += mark_len;
142                    points.push(Point { time, mark });
143                }
144            }
145        }
146
147        tiers.push(Tier { name, tier_type, xmin: tier_xmin, xmax: tier_xmax, intervals, points });
148    }
149
150    Ok(TextGrid::new(xmin, xmax)?.with_tiers(tiers))
151}
152
153/// Writes a `TextGrid` to a Praat `.TextGrid` file in binary format.
154///
155/// # Arguments
156/// * `textgrid` - The `TextGrid` to write.
157/// * `path` - Path to the output file, implementing `AsRef<Path>`.
158///
159/// # Returns
160/// Returns a `Result` indicating success (`Ok(())`) or a `TextGridError`.
161///
162/// # Errors
163/// - `TextGridError::IO` if the file cannot be created or written to.
164///
165/// # Examples
166/// ```rust
167/// let tg = TextGrid::new(0.0, 5.0).unwrap(); // Assume tiers are added
168/// textgrid::write_binary(&tg, "test.TextGrid").unwrap();
169/// ```
170pub fn write_binary<P: AsRef<Path>>(textgrid: &TextGrid, path: P) -> Result<(), TextGridError> {
171    let file = File::create(path)?;
172    let mut writer = BufWriter::new(file);
173
174    writer.write_all(b"ooBinaryFile")?;
175    let class = b"TextGrid";
176    writer.write_all(&(class.len() as u16).to_le_bytes())?;
177    writer.write_all(class)?;
178    writer.write_all(&textgrid.xmin.to_le_bytes())?;
179    writer.write_all(&textgrid.xmax.to_le_bytes())?;
180    writer.write_all(&(textgrid.tiers.len() as u32).to_le_bytes())?;
181
182    for tier in &textgrid.tiers {
183        let class = match tier.tier_type {
184            TierType::IntervalTier => b"IntervalTier" as &[u8],
185            TierType::PointTier => b"TextTier" as &[u8],
186        };
187        writer.write_all(&(class.len() as u16).to_le_bytes())?;
188        writer.write_all(class)?;
189
190        let name_bytes = tier.name.as_bytes();
191        writer.write_all(&(name_bytes.len() as u16).to_le_bytes())?;
192        writer.write_all(name_bytes)?;
193
194        writer.write_all(&tier.xmin.to_le_bytes())?;
195        writer.write_all(&tier.xmax.to_le_bytes())?;
196
197        match tier.tier_type {
198            TierType::IntervalTier => {
199                writer.write_all(&(tier.intervals.len() as u32).to_le_bytes())?;
200                for interval in &tier.intervals {
201                    writer.write_all(&interval.xmin.to_le_bytes())?;
202                    writer.write_all(&interval.xmax.to_le_bytes())?;
203                    let text_bytes = interval.text.as_bytes();
204                    writer.write_all(&(text_bytes.len() as u16).to_le_bytes())?;
205                    writer.write_all(text_bytes)?;
206                }
207            }
208            TierType::PointTier => {
209                writer.write_all(&(tier.points.len() as u32).to_le_bytes())?;
210                for point in &tier.points {
211                    writer.write_all(&point.time.to_le_bytes())?;
212                    let mark_bytes = point.mark.as_bytes();
213                    writer.write_all(&(mark_bytes.len() as u16).to_le_bytes())?;
214                    writer.write_all(mark_bytes)?;
215                }
216            }
217        }
218    }
219    writer.flush()?;
220    Ok(())
221}