Skip to main content

piper_plus/
timing.rs

//! Phoneme timing extraction from ONNX model duration output.
//!
//! VITS models optionally output a `durations` tensor `[1, phoneme_length]`
//! containing the number of frames (hop_length-sized) each phoneme occupies.
//! This module converts frame counts to millisecond timestamps.
6
use std::fmt::Write as _;

use serde::Serialize;

use crate::error::PiperError;
10
/// Default hop length for VITS models, in audio samples per duration frame.
pub const DEFAULT_HOP_LENGTH: usize = 256;
13
14/// Timing information for a single phoneme
15#[derive(Debug, Clone, Serialize)]
16pub struct PhonemeTimingInfo {
17    pub phoneme: String,
18    pub start_ms: f64,
19    pub end_ms: f64,
20    pub duration_ms: f64,
21}
22
23/// Complete timing result for a synthesized utterance
24#[derive(Debug, Clone, Serialize)]
25pub struct TimingResult {
26    pub phonemes: Vec<PhonemeTimingInfo>,
27    pub total_duration_ms: f64,
28    pub sample_rate: u32,
29}
30
31impl TimingResult {
32    /// Serialize to JSON string (pretty-printed)
33    pub fn to_json(&self) -> Result<String, PiperError> {
34        serde_json::to_string_pretty(self).map_err(PiperError::from)
35    }
36
37    /// Serialize to JSON string (compact, one line per phoneme)
38    pub fn to_json_compact(&self) -> Result<String, PiperError> {
39        serde_json::to_string(self).map_err(PiperError::from)
40    }
41
42    /// Serialize to TSV string (tab-separated: start_ms, end_ms, duration_ms, phoneme)
43    pub fn to_tsv(&self) -> String {
44        let mut buf = String::from("start_ms\tend_ms\tduration_ms\tphoneme\n");
45        for p in &self.phonemes {
46            buf.push_str(&format!(
47                "{:.3}\t{:.3}\t{:.3}\t{}\n",
48                p.start_ms, p.end_ms, p.duration_ms, p.phoneme
49            ));
50        }
51        buf
52    }
53
54    /// Serialize to SRT-like subtitle format
55    pub fn to_srt(&self) -> String {
56        let mut buf = String::new();
57        for (i, p) in self.phonemes.iter().enumerate() {
58            let idx = i + 1;
59            let start = format_srt_timestamp(p.start_ms);
60            let end = format_srt_timestamp(p.end_ms);
61            buf.push_str(&format!("{idx}\n{start} --> {end}\n{}\n\n", p.phoneme));
62        }
63        buf
64    }
65}
66
/// Format a millisecond offset as an SRT timestamp: `HH:MM:SS,mmm`.
///
/// The offset is rounded to the nearest whole millisecond first. The
/// float-to-integer cast saturates, so negative and NaN inputs format as
/// `00:00:00,000` and positive infinity formats as the largest `u64`
/// millisecond count. The hours field is zero-padded to two digits but not
/// capped, so offsets of 100 hours or more produce a wider field.
fn format_srt_timestamp(ms: f64) -> String {
    const MS_PER_SEC: u64 = 1_000;
    const MS_PER_MIN: u64 = 60 * MS_PER_SEC;
    const MS_PER_HOUR: u64 = 60 * MS_PER_MIN;

    // Saturating cast: see the doc comment for out-of-range inputs.
    let total_ms = ms.round() as u64;

    let hours = total_ms / MS_PER_HOUR;
    let mins = total_ms / MS_PER_MIN % 60;
    let secs = total_ms / MS_PER_SEC % 60;
    let millis = total_ms % MS_PER_SEC;
    format!("{hours:02}:{mins:02}:{secs:02},{millis:03}")
}
78
79/// Convert duration tensor output to timing information.
80///
81/// # Arguments
82/// * `durations` - Duration values from ONNX output tensor [phoneme_length]
83/// * `phoneme_tokens` - Corresponding phoneme token strings
84/// * `sample_rate` - Audio sample rate (e.g., 22050)
85/// * `hop_length` - STFT hop length (typically 256 for VITS)
86///
87/// # Returns
88/// TimingResult with start/end timestamps for each phoneme
89pub fn durations_to_timing(
90    durations: &[f32],
91    phoneme_tokens: &[String],
92    sample_rate: u32,
93    hop_length: usize,
94) -> Result<TimingResult, PiperError> {
95    if durations.len() != phoneme_tokens.len() {
96        return Err(PiperError::Inference(format!(
97            "durations length ({}) != phoneme_tokens length ({})",
98            durations.len(),
99            phoneme_tokens.len()
100        )));
101    }
102
103    if sample_rate == 0 {
104        return Err(PiperError::Inference("sample_rate must be > 0".to_string()));
105    }
106
107    if hop_length == 0 {
108        return Err(PiperError::Inference("hop_length must be > 0".to_string()));
109    }
110
111    // Time in seconds for one frame
112    let frame_time_s = hop_length as f64 / sample_rate as f64;
113    let frame_time_ms = frame_time_s * 1000.0;
114
115    let mut phonemes = Vec::with_capacity(durations.len());
116    let mut cursor_ms: f64 = 0.0;
117
118    for (dur, token) in durations.iter().zip(phoneme_tokens.iter()) {
119        let dur_frames = (*dur).max(0.0) as f64;
120        let duration_ms = dur_frames * frame_time_ms;
121        let start_ms = cursor_ms;
122        let end_ms = cursor_ms + duration_ms;
123
124        phonemes.push(PhonemeTimingInfo {
125            phoneme: token.clone(),
126            start_ms,
127            end_ms,
128            duration_ms,
129        });
130
131        cursor_ms = end_ms;
132    }
133
134    Ok(TimingResult {
135        total_duration_ms: cursor_ms,
136        phonemes,
137        sample_rate,
138    })
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------
    // Helpers
    // ---------------------------------------------------------------

    /// Build an owned token list from string literals.
    fn tokens(names: &[&str]) -> Vec<String> {
        names.iter().map(|s| s.to_string()).collect()
    }

    /// Assert that `actual` lies strictly within `tol` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} (tolerance {tol}), got {actual}"
        );
    }

    // ---------------------------------------------------------------
    // 1. Basic duration conversion (known values)
    // ---------------------------------------------------------------

    #[test]
    fn test_basic_conversion_22050() {
        // sample_rate=22050, hop=256 => frame_time = 256/22050 s ~ 11.6099 ms
        let durations = vec![10.0, 20.0, 5.0];
        let toks = tokens(&["a", "b", "c"]);
        let result = durations_to_timing(&durations, &toks, 22050, 256).unwrap();

        let frame_ms = 256.0 / 22050.0 * 1000.0;

        assert_eq!(result.phonemes.len(), 3);
        assert_eq!(result.sample_rate, 22050);

        // phoneme "a": start=0, dur=10*frame_ms
        assert_close(result.phonemes[0].start_ms, 0.0, 1e-6);
        assert_close(result.phonemes[0].duration_ms, 10.0 * frame_ms, 1e-6);
        assert_close(result.phonemes[0].end_ms, 10.0 * frame_ms, 1e-6);

        // phoneme "b": start=10*frame_ms, dur=20*frame_ms
        assert_close(result.phonemes[1].start_ms, 10.0 * frame_ms, 1e-6);
        assert_close(result.phonemes[1].duration_ms, 20.0 * frame_ms, 1e-6);

        // phoneme "c": start=30*frame_ms, dur=5*frame_ms
        assert_close(result.phonemes[2].start_ms, 30.0 * frame_ms, 1e-6);

        // total = 35 * frame_ms
        assert_close(result.total_duration_ms, 35.0 * frame_ms, 1e-6);
    }

    // ---------------------------------------------------------------
    // 2. Empty durations
    // ---------------------------------------------------------------

    #[test]
    fn test_empty_durations() {
        let result = durations_to_timing(&[], &[], 22050, 256).unwrap();
        assert!(result.phonemes.is_empty());
        assert_close(result.total_duration_ms, 0.0, 1e-6);
    }

    // ---------------------------------------------------------------
    // 3. Single phoneme
    // ---------------------------------------------------------------

    #[test]
    fn test_single_phoneme() {
        let durations = vec![8.0];
        let toks = tokens(&["k"]);
        let result = durations_to_timing(&durations, &toks, 22050, 256).unwrap();

        assert_eq!(result.phonemes.len(), 1);
        assert_close(result.phonemes[0].start_ms, 0.0, 1e-6);

        let frame_ms = 256.0 / 22050.0 * 1000.0;
        assert_close(result.phonemes[0].duration_ms, 8.0 * frame_ms, 1e-6);
        assert_close(result.total_duration_ms, 8.0 * frame_ms, 1e-6);
    }

    // ---------------------------------------------------------------
    // 4. Zero durations
    // ---------------------------------------------------------------

    #[test]
    fn test_zero_durations() {
        let durations = vec![0.0, 10.0, 0.0];
        let toks = tokens(&["^", "a", "_"]);
        let result = durations_to_timing(&durations, &toks, 22050, 256).unwrap();

        assert_eq!(result.phonemes.len(), 3);

        // First phoneme has zero duration
        assert_close(result.phonemes[0].duration_ms, 0.0, 1e-6);
        assert_close(result.phonemes[0].start_ms, result.phonemes[0].end_ms, 1e-6);

        // Second phoneme starts at 0 too
        assert_close(result.phonemes[1].start_ms, 0.0, 1e-6);

        // Third phoneme starts at 10*frame_ms, zero duration
        let frame_ms = 256.0 / 22050.0 * 1000.0;
        assert_close(result.phonemes[2].start_ms, 10.0 * frame_ms, 1e-6);
        assert_close(result.phonemes[2].duration_ms, 0.0, 1e-6);
    }

    // ---------------------------------------------------------------
    // 5. Mismatched lengths error
    // ---------------------------------------------------------------

    #[test]
    fn test_mismatched_lengths() {
        let durations = vec![1.0, 2.0, 3.0];
        let toks = tokens(&["a", "b"]);
        let err = durations_to_timing(&durations, &toks, 22050, 256).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("3"));
        assert!(msg.contains("2"));
    }

    // ---------------------------------------------------------------
    // 6. JSON pretty-print serialization roundtrip
    // ---------------------------------------------------------------

    #[test]
    fn test_json_roundtrip() {
        let durations = vec![5.0, 15.0];
        let toks = tokens(&["h", "i"]);
        let result = durations_to_timing(&durations, &toks, 22050, 256).unwrap();

        let json = result.to_json().unwrap();

        // Deserialize back to verify structure
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(parsed.is_object());
        assert!(parsed["phonemes"].is_array());
        assert_eq!(parsed["phonemes"].as_array().unwrap().len(), 2);
        assert_eq!(parsed["sample_rate"].as_u64().unwrap(), 22050);

        let first = &parsed["phonemes"][0];
        assert_eq!(first["phoneme"].as_str().unwrap(), "h");
        assert_close(first["start_ms"].as_f64().unwrap(), 0.0, 1e-6);
    }

    // ---------------------------------------------------------------
    // 7. JSON compact serialization
    // ---------------------------------------------------------------

    #[test]
    fn test_json_compact() {
        let durations = vec![3.0];
        let toks = tokens(&["x"]);
        let result = durations_to_timing(&durations, &toks, 16000, 256).unwrap();

        let json_compact = result.to_json_compact().unwrap();
        // Compact should not contain newlines (single-line JSON)
        assert!(!json_compact.contains('\n'));
        assert!(json_compact.contains("\"phoneme\":\"x\""));
    }

    // ---------------------------------------------------------------
    // 8. TSV format correctness
    // ---------------------------------------------------------------

    #[test]
    fn test_tsv_format() {
        let durations = vec![10.0, 20.0];
        let toks = tokens(&["p", "q"]);
        let result = durations_to_timing(&durations, &toks, 22050, 256).unwrap();

        let tsv = result.to_tsv();
        let lines: Vec<&str> = tsv.lines().collect();

        // Header line
        assert_eq!(lines[0], "start_ms\tend_ms\tduration_ms\tphoneme");

        // Data rows
        assert_eq!(lines.len(), 3); // header + 2 phonemes

        // First row should start at 0.000
        assert!(lines[1].starts_with("0.000\t"));
        assert!(lines[1].ends_with("\tp"));

        // Second row phoneme should be "q"
        assert!(lines[2].ends_with("\tq"));
    }

    // ---------------------------------------------------------------
    // 9. TSV empty input
    // ---------------------------------------------------------------

    #[test]
    fn test_tsv_empty() {
        let result = durations_to_timing(&[], &[], 22050, 256).unwrap();
        let tsv = result.to_tsv();
        let lines: Vec<&str> = tsv.lines().collect();
        assert_eq!(lines.len(), 1); // header only
    }

    // ---------------------------------------------------------------
    // 10. SRT format correctness
    // ---------------------------------------------------------------

    #[test]
    fn test_srt_format() {
        // Use easy numbers: sample_rate=1000, hop_length=1 => 1 frame = 1 ms
        let durations = vec![500.0, 1500.0, 3000.0];
        let toks = tokens(&["a", "bb", "c"]);
        let result = durations_to_timing(&durations, &toks, 1000, 1).unwrap();

        let srt = result.to_srt();
        let blocks: Vec<&str> = srt.split("\n\n").filter(|b| !b.is_empty()).collect();
        assert_eq!(blocks.len(), 3);

        // First block: index 1, 00:00:00,000 --> 00:00:00,500, phoneme "a"
        let lines0: Vec<&str> = blocks[0].lines().collect();
        assert_eq!(lines0[0], "1");
        assert_eq!(lines0[1], "00:00:00,000 --> 00:00:00,500");
        assert_eq!(lines0[2], "a");

        // Second block: index 2, 00:00:00,500 --> 00:00:02,000, phoneme "bb"
        let lines1: Vec<&str> = blocks[1].lines().collect();
        assert_eq!(lines1[0], "2");
        assert_eq!(lines1[1], "00:00:00,500 --> 00:00:02,000");
        assert_eq!(lines1[2], "bb");

        // Third block: 00:00:02,000 --> 00:00:05,000
        let lines2: Vec<&str> = blocks[2].lines().collect();
        assert_eq!(lines2[0], "3");
        assert_eq!(lines2[1], "00:00:02,000 --> 00:00:05,000");
        assert_eq!(lines2[2], "c");
    }

    // ---------------------------------------------------------------
    // 11. SRT timestamp with hours/minutes
    // ---------------------------------------------------------------

    #[test]
    fn test_srt_large_timestamps() {
        // 90 minutes + 5 seconds + 123 ms = 5,405,123 ms
        // Use sample_rate=1000, hop=1 so frames = ms directly
        let dur_ms = 5_405_123.0_f32;
        let durations = vec![dur_ms];
        let toks = tokens(&["long"]);
        let result = durations_to_timing(&durations, &toks, 1000, 1).unwrap();

        let srt = result.to_srt();
        assert!(srt.contains("00:00:00,000 --> 01:30:05,123"));
    }

    // ---------------------------------------------------------------
    // 12. Sample rate 16000
    // ---------------------------------------------------------------

    #[test]
    fn test_sample_rate_16000() {
        let durations = vec![16.0];
        let toks = tokens(&["z"]);
        let result = durations_to_timing(&durations, &toks, 16000, 256).unwrap();

        // frame_ms = 256/16000*1000 = 16.0 ms
        // duration_ms = 16 frames * 16.0 ms = 256.0 ms
        let expected_ms = 16.0 * (256.0 / 16000.0 * 1000.0);
        assert_close(result.phonemes[0].duration_ms, expected_ms, 1e-6);
        assert_close(result.total_duration_ms, expected_ms, 1e-6);
    }

    // ---------------------------------------------------------------
    // 13. Sample rate 44100
    // ---------------------------------------------------------------

    #[test]
    fn test_sample_rate_44100() {
        let durations = vec![100.0];
        let toks = tokens(&["w"]);
        let result = durations_to_timing(&durations, &toks, 44100, 256).unwrap();

        let frame_ms = 256.0 / 44100.0 * 1000.0;
        let expected_ms = 100.0 * frame_ms;
        assert_close(result.phonemes[0].duration_ms, expected_ms, 1e-6);
        assert_eq!(result.sample_rate, 44100);
    }

    // ---------------------------------------------------------------
    // 14. Large duration values
    // ---------------------------------------------------------------

    #[test]
    fn test_large_duration_values() {
        let durations = vec![100_000.0, 200_000.0];
        let toks = tokens(&["aa", "bb"]);
        let result = durations_to_timing(&durations, &toks, 22050, 256).unwrap();

        let frame_ms = 256.0 / 22050.0 * 1000.0;
        let expected_total = 300_000.0 * frame_ms;
        assert_close(result.total_duration_ms, expected_total, 1e-3);

        // Second phoneme starts after the first
        assert_close(result.phonemes[1].start_ms, 100_000.0 * frame_ms, 1e-3);
    }

    // ---------------------------------------------------------------
    // 15. Floating point precision -- cumulative sum stays accurate
    // ---------------------------------------------------------------

    #[test]
    fn test_floating_point_precision() {
        // Many small durations to test accumulation
        let n = 1000;
        let durations: Vec<f32> = vec![1.0; n];
        let toks: Vec<String> = (0..n).map(|i| format!("p{i}")).collect();
        let result = durations_to_timing(&durations, &toks, 22050, 256).unwrap();

        let frame_ms = 256.0 / 22050.0 * 1000.0;
        let expected_total = n as f64 * frame_ms;

        // Total should be very close despite 1000 additions
        assert_close(result.total_duration_ms, expected_total, 0.01);

        // Last phoneme end should equal total
        let last = result.phonemes.last().unwrap();
        assert_close(last.end_ms, result.total_duration_ms, 1e-9);
    }

    // ---------------------------------------------------------------
    // 16. Negative duration values are clamped to zero
    // ---------------------------------------------------------------

    #[test]
    fn test_negative_durations_clamped() {
        let durations = vec![-5.0, 10.0, -1.0];
        let toks = tokens(&["a", "b", "c"]);
        let result = durations_to_timing(&durations, &toks, 22050, 256).unwrap();

        // Negative durations should be treated as 0
        assert_close(result.phonemes[0].duration_ms, 0.0, 1e-6);
        assert_close(result.phonemes[2].duration_ms, 0.0, 1e-6);

        // Only phoneme "b" contributes to total
        let frame_ms = 256.0 / 22050.0 * 1000.0;
        assert_close(result.total_duration_ms, 10.0 * frame_ms, 1e-6);
    }

    // ---------------------------------------------------------------
    // 17. Zero sample_rate error
    // ---------------------------------------------------------------

    #[test]
    fn test_zero_sample_rate_error() {
        let durations = vec![1.0];
        let toks = tokens(&["a"]);
        let err = durations_to_timing(&durations, &toks, 0, 256).unwrap_err();
        assert!(err.to_string().contains("sample_rate"));
    }

    // ---------------------------------------------------------------
    // 18. Zero hop_length error
    // ---------------------------------------------------------------

    #[test]
    fn test_zero_hop_length_error() {
        let durations = vec![1.0];
        let toks = tokens(&["a"]);
        let err = durations_to_timing(&durations, &toks, 22050, 0).unwrap_err();
        assert!(err.to_string().contains("hop_length"));
    }

    // ---------------------------------------------------------------
    // 19. DEFAULT_HOP_LENGTH constant value
    // ---------------------------------------------------------------

    #[test]
    fn test_default_hop_length() {
        assert_eq!(DEFAULT_HOP_LENGTH, 256);
    }

    // ---------------------------------------------------------------
    // 20. Phoneme ordering preserved
    // ---------------------------------------------------------------

    #[test]
    fn test_phoneme_ordering_preserved() {
        let durations = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let toks = tokens(&["^", "k", "o", "N", "_"]);
        let result = durations_to_timing(&durations, &toks, 22050, 256).unwrap();

        let names: Vec<&str> = result.phonemes.iter().map(|p| p.phoneme.as_str()).collect();
        assert_eq!(names, vec!["^", "k", "o", "N", "_"]);

        // Each start equals previous end
        for i in 1..result.phonemes.len() {
            assert!(
                (result.phonemes[i].start_ms - result.phonemes[i - 1].end_ms).abs() < 1e-9,
                "gap between phoneme {} and {}",
                i - 1,
                i
            );
        }
    }

    // ---------------------------------------------------------------
    // 21. TSV field values match JSON values
    // ---------------------------------------------------------------

    #[test]
    fn test_tsv_and_json_consistency() {
        let durations = vec![7.0, 13.0];
        let toks = tokens(&["s", "t"]);
        let result = durations_to_timing(&durations, &toks, 22050, 256).unwrap();

        let json_str = result.to_json().unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();

        let tsv = result.to_tsv();
        let data_lines: Vec<&str> = tsv.lines().skip(1).collect();
        assert_eq!(data_lines.len(), 2);

        for (i, line) in data_lines.iter().enumerate() {
            let fields: Vec<&str> = line.split('\t').collect();
            assert_eq!(fields.len(), 4);

            let tsv_start: f64 = fields[0].parse().unwrap();
            let tsv_end: f64 = fields[1].parse().unwrap();
            let tsv_dur: f64 = fields[2].parse().unwrap();
            let tsv_phoneme = fields[3];

            let json_ph = &parsed["phonemes"][i];
            let json_start = json_ph["start_ms"].as_f64().unwrap();
            let json_end = json_ph["end_ms"].as_f64().unwrap();
            let json_dur = json_ph["duration_ms"].as_f64().unwrap();
            let json_phoneme = json_ph["phoneme"].as_str().unwrap();

            // TSV values are rounded to 3 decimals, so allow a small tolerance.
            assert_close(tsv_start, json_start, 0.01);
            assert_close(tsv_end, json_end, 0.01);
            assert_close(tsv_dur, json_dur, 0.01);
            assert_eq!(tsv_phoneme, json_phoneme);
        }
    }

    // ---------------------------------------------------------------
    // 22. Phoneme name containing tab in TSV output
    // ---------------------------------------------------------------

    #[test]
    fn test_tsv_phoneme_with_tab() {
        // A phoneme token that contains a literal tab character.
        // The current TSV writer does not escape it, so the tab will
        // appear as an extra column, producing 5 fields instead of 4.
        let durations = vec![5.0];
        let toks = vec!["a\tb".to_string()];
        let result = durations_to_timing(&durations, &toks, 1000, 1).unwrap();

        let tsv = result.to_tsv();
        let data_line = tsv.lines().nth(1).expect("expected a data line");
        let fields: Vec<&str> = data_line.split('\t').collect();

        // Tab inside the phoneme name splits the field, yielding 5 columns.
        assert_eq!(
            fields.len(),
            5,
            "tab inside phoneme name produces an extra TSV column"
        );
    }

    // ---------------------------------------------------------------
    // 23. Phoneme name containing newline in SRT output
    // ---------------------------------------------------------------

    #[test]
    fn test_srt_phoneme_with_newline() {
        // A phoneme token with an embedded newline will split the
        // subtitle text across two visual lines inside the SRT entry.
        let durations = vec![10.0];
        let toks = vec!["line1\nline2".to_string()];
        let result = durations_to_timing(&durations, &toks, 1000, 1).unwrap();

        let srt = result.to_srt();

        // Verify that the index "1", the arrow marker and the unescaped
        // phoneme text are all still present.
        assert!(srt.contains("1\n"));
        assert!(srt.contains(" --> "));
        assert!(srt.contains("line1\nline2"));
    }

    // ---------------------------------------------------------------
    // 24. Duration with NaN -- clamped to 0 by f32::max(0.0)
    // ---------------------------------------------------------------

    #[test]
    fn test_nan_duration() {
        let durations = vec![f32::NAN, 10.0];
        let toks = tokens(&["nan_ph", "ok"]);
        let result = durations_to_timing(&durations, &toks, 1000, 1).unwrap();

        // Rust's f32::max returns the non-NaN argument when one operand
        // is NaN, so NAN.max(0.0) == 0.0. The NaN is effectively clamped.
        assert_close(result.phonemes[0].duration_ms, 0.0, 1e-9);

        // start == end for the zero-duration phoneme
        assert_close(result.phonemes[0].start_ms, result.phonemes[0].end_ms, 1e-9);

        // The second phoneme should still have a valid duration.
        assert_close(result.phonemes[1].duration_ms, 10.0, 1e-6);

        // Total should only reflect the valid phoneme
        assert_close(result.total_duration_ms, 10.0, 1e-6);
    }

    // ---------------------------------------------------------------
    // 25. Duration with Infinity -- propagates as infinite ms
    // ---------------------------------------------------------------

    #[test]
    fn test_infinity_duration() {
        let durations = vec![f32::INFINITY];
        let toks = tokens(&["inf_ph"]);
        let result = durations_to_timing(&durations, &toks, 1000, 1).unwrap();

        assert!(
            result.phonemes[0].duration_ms.is_infinite(),
            "Infinity duration propagates"
        );
        assert!(
            result.total_duration_ms.is_infinite(),
            "total also becomes infinite"
        );
    }

    // ---------------------------------------------------------------
    // 26. Unicode / IPA phoneme names in all formats
    // ---------------------------------------------------------------

    #[test]
    fn test_unicode_phoneme_names() {
        let ipa_tokens = vec![
            "\u{0251}\u{02D0}".to_string(), // ɑː
            "\u{0283}".to_string(),         // ʃ
            "\u{014B}".to_string(),         // ŋ
        ];
        let durations = vec![5.0, 3.0, 7.0];
        let result = durations_to_timing(&durations, &ipa_tokens, 1000, 1).unwrap();

        // Verify phoneme names are preserved
        assert_eq!(result.phonemes[0].phoneme, "\u{0251}\u{02D0}");
        assert_eq!(result.phonemes[1].phoneme, "\u{0283}");
        assert_eq!(result.phonemes[2].phoneme, "\u{014B}");

        // JSON roundtrip preserves Unicode
        let json = result.to_json().unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(
            parsed["phonemes"][0]["phoneme"].as_str().unwrap(),
            "\u{0251}\u{02D0}"
        );

        // TSV contains the Unicode characters
        let tsv = result.to_tsv();
        assert!(tsv.contains("\u{0251}\u{02D0}"));
        assert!(tsv.contains("\u{0283}"));
        assert!(tsv.contains("\u{014B}"));

        // SRT contains the Unicode characters
        let srt = result.to_srt();
        assert!(srt.contains("\u{0251}\u{02D0}"));
        assert!(srt.contains("\u{0283}"));
        assert!(srt.contains("\u{014B}"));
    }

    // ---------------------------------------------------------------
    // 27. Very small durations preserve precision
    // ---------------------------------------------------------------

    #[test]
    fn test_very_small_durations_precision() {
        // sample_rate=1000, hop=1 => 1 frame = 1 ms, so 0.001 frames = 0.001 ms
        let durations = vec![0.001_f32];
        let toks = tokens(&["tiny"]);
        let result = durations_to_timing(&durations, &toks, 1000, 1).unwrap();

        assert_close(result.phonemes[0].duration_ms, 0.001_f64, 1e-9);

        // TSV should render with sub-millisecond precision via {:.3}
        let tsv = result.to_tsv();
        let data_line = tsv.lines().nth(1).unwrap();
        // The duration field (3rd column) should be "0.001"
        let fields: Vec<&str> = data_line.split('\t').collect();
        assert_eq!(fields[2], "0.001");
    }

    // ---------------------------------------------------------------
    // 28. TimingResult direct construction and field access
    // ---------------------------------------------------------------

    #[test]
    fn test_timing_result_direct_construction() {
        let timing = TimingResult {
            phonemes: vec![
                PhonemeTimingInfo {
                    phoneme: "hello".to_string(),
                    start_ms: 0.0,
                    end_ms: 100.5,
                    duration_ms: 100.5,
                },
                PhonemeTimingInfo {
                    phoneme: "world".to_string(),
                    start_ms: 100.5,
                    end_ms: 250.0,
                    duration_ms: 149.5,
                },
            ],
            total_duration_ms: 250.0,
            sample_rate: 48000,
        };

        // Field access
        assert_eq!(timing.phonemes.len(), 2);
        assert_eq!(timing.phonemes[0].phoneme, "hello");
        assert_eq!(timing.phonemes[1].phoneme, "world");
        assert_close(timing.phonemes[0].start_ms, 0.0, 1e-9);
        assert_close(timing.phonemes[0].end_ms, 100.5, 1e-9);
        assert_close(timing.phonemes[0].duration_ms, 100.5, 1e-9);
        assert_close(timing.phonemes[1].start_ms, 100.5, 1e-9);
        assert_close(timing.phonemes[1].end_ms, 250.0, 1e-9);
        assert_close(timing.phonemes[1].duration_ms, 149.5, 1e-9);
        assert_close(timing.total_duration_ms, 250.0, 1e-9);
        assert_eq!(timing.sample_rate, 48000);

        // Clone trait works
        let cloned = timing.clone();
        assert_eq!(cloned.phonemes.len(), timing.phonemes.len());
        assert_eq!(cloned.sample_rate, timing.sample_rate);

        // Serialization works on directly constructed structs
        let json = timing.to_json().unwrap();
        assert!(json.contains("\"hello\""));
        assert!(json.contains("\"world\""));
        assert!(json.contains("48000"));
    }

    // ---------------------------------------------------------------
    // 29. JSON serializes non-finite f64 as null
    // ---------------------------------------------------------------

    #[test]
    fn test_json_nonfinite_serialized_as_null() {
        // This test pins the expectation that serde_json writes Infinity
        // as JSON null instead of returning an error, so callers know what
        // a TimingResult with non-finite values serializes to.
        // NOTE(review): the original comment tied this to serde_json
        // >= 1.0.128 -- confirm against the pinned serde_json version.
        let timing = TimingResult {
            phonemes: vec![PhonemeTimingInfo {
                phoneme: "inf".to_string(),
                start_ms: 0.0,
                end_ms: f64::INFINITY,
                duration_ms: f64::INFINITY,
            }],
            total_duration_ms: f64::INFINITY,
            sample_rate: 22050,
        };

        // to_json should succeed (not error)
        let json = timing.to_json().expect("to_json should succeed");
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();

        // Non-finite values become null in JSON
        assert!(
            parsed["total_duration_ms"].is_null(),
            "Infinity total_duration_ms serialized as null"
        );
        assert!(
            parsed["phonemes"][0]["end_ms"].is_null(),
            "Infinity end_ms serialized as null"
        );
        assert!(
            parsed["phonemes"][0]["duration_ms"].is_null(),
            "Infinity duration_ms serialized as null"
        );

        // Finite values remain as numbers
        assert!(
            parsed["phonemes"][0]["start_ms"].is_number(),
            "finite start_ms remains a number"
        );

        // Compact format also succeeds
        let compact = timing.to_json_compact().expect("compact should succeed");
        assert!(
            compact.contains("null"),
            "compact JSON contains null for Infinity"
        );

        // to_json / to_json_compact map serde_json errors through
        // PiperError::from. Exercise that conversion directly using an
        // error obtained from parsing invalid JSON.
        let bad_json = "{ not valid json }";
        let serde_err: Result<serde_json::Value, _> = serde_json::from_str(bad_json);
        let piper_err: PiperError = serde_err.unwrap_err().into();
        let msg = piper_err.to_string();
        assert!(
            msg.contains("JSON"),
            "PiperError from serde_json mentions JSON: {}",
            msg
        );
    }
}
861}