1use std::io::{Seek, Write};
6use std::path::Path;
7
8use crate::error::PiperError;
9
/// Destination for 16-bit PCM audio that is delivered chunk by chunk.
///
/// A caller pushes audio with [`AudioSink::write_chunk`] and signals the end
/// of the stream with [`AudioSink::finalize`]. The trait is object safe, so
/// sinks can be passed around as `&mut dyn AudioSink`.
pub trait AudioSink {
    /// Accepts one chunk of `samples` together with the `sample_rate` they
    /// were produced at.
    ///
    /// # Errors
    ///
    /// Returns a [`PiperError`] when the sink cannot accept the chunk.
    fn write_chunk(&mut self, samples: &[i16], sample_rate: u32) -> Result<(), PiperError>;

    /// Signals that no more chunks will follow, letting the sink complete any
    /// pending bookkeeping.
    ///
    /// # Errors
    ///
    /// Returns a [`PiperError`] when the sink fails to complete its output.
    fn finalize(&mut self) -> Result<(), PiperError>;
}
25
/// Summary numbers describing one streaming run.
///
/// NOTE(review): nothing in this part of the file fills this struct in, so
/// the field descriptions below are inferred from the field names only —
/// confirm against the code that produces it.
#[derive(Debug, Clone)]
pub struct StreamingResult {
    /// Presumably the duration of the produced audio, in seconds.
    pub total_audio_seconds: f64,
    /// Presumably the time spent in inference, in seconds.
    pub total_infer_seconds: f64,
    /// Presumably the number of chunks that were emitted.
    pub chunk_count: usize,
}
40
/// An [`AudioSink`] that keeps every sample it receives in memory.
pub struct BufferSink {
    /// All samples received so far, in arrival order.
    samples: Vec<i16>,
    /// Sample rate passed with the most recent chunk; `None` until the first
    /// chunk has been written.
    sample_rate: Option<u32>,
}
50
51impl BufferSink {
52 pub fn new() -> Self {
54 Self {
55 samples: Vec::new(),
56 sample_rate: None,
57 }
58 }
59
60 pub fn get_samples(&self) -> &[i16] {
62 &self.samples
63 }
64
65 pub fn sample_rate(&self) -> Option<u32> {
67 self.sample_rate
68 }
69}
70
71impl Default for BufferSink {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77impl AudioSink for BufferSink {
78 fn write_chunk(&mut self, samples: &[i16], sample_rate: u32) -> Result<(), PiperError> {
79 self.sample_rate = Some(sample_rate);
80 self.samples.extend_from_slice(samples);
81 Ok(())
82 }
83
84 fn finalize(&mut self) -> Result<(), PiperError> {
85 Ok(())
86 }
87}
88
/// An [`AudioSink`] that streams samples into a 16-bit mono PCM WAV file.
///
/// The header is written with placeholder sizes when the first chunk arrives
/// and is patched with the real sizes on `finalize` (which also runs on drop).
pub struct WavFileSink {
    /// Output file the header and sample data are written to.
    file: std::fs::File,
    /// Sample rate recorded in the header; `0` until the header is written.
    sample_rate: u32,
    /// Number of samples written so far (not bytes).
    total_samples: usize,
    /// Whether the WAV header has already been emitted.
    header_written: bool,
}
104
105impl WavFileSink {
106 pub fn new(path: &Path) -> Result<Self, PiperError> {
111 let file = std::fs::File::create(path)?;
112 Ok(Self {
113 file,
114 sample_rate: 0,
115 total_samples: 0,
116 header_written: false,
117 })
118 }
119
120 fn write_header(&mut self, sample_rate: u32) -> Result<(), PiperError> {
122 let placeholder_data_size: u32 = 0;
123 let placeholder_file_size: u32 = 36; self.file.write_all(b"RIFF")?;
127 self.file.write_all(&placeholder_file_size.to_le_bytes())?;
128 self.file.write_all(b"WAVE")?;
129
130 self.file.write_all(b"fmt ")?;
132 self.file.write_all(&16u32.to_le_bytes())?; self.file.write_all(&1u16.to_le_bytes())?; self.file.write_all(&1u16.to_le_bytes())?; self.file.write_all(&sample_rate.to_le_bytes())?;
136 self.file.write_all(&(sample_rate * 2).to_le_bytes())?; self.file.write_all(&2u16.to_le_bytes())?; self.file.write_all(&16u16.to_le_bytes())?; self.file.write_all(b"data")?;
142 self.file.write_all(&placeholder_data_size.to_le_bytes())?;
143
144 self.sample_rate = sample_rate;
145 self.header_written = true;
146 Ok(())
147 }
148
149 fn update_sizes(&mut self) -> Result<(), PiperError> {
151 let data_size_u64 = (self.total_samples as u64) * 2;
152 if data_size_u64 > u32::MAX as u64 {
153 return Err(PiperError::Streaming(
154 "WAV file exceeds 4GB limit".to_string(),
155 ));
156 }
157 let data_size = data_size_u64 as u32;
158 let file_size = data_size + 36;
159
160 self.file.seek(std::io::SeekFrom::Start(4))?;
162 self.file.write_all(&file_size.to_le_bytes())?;
163
164 self.file.seek(std::io::SeekFrom::Start(40))?;
166 self.file.write_all(&data_size.to_le_bytes())?;
167
168 self.file.flush()?;
170 Ok(())
171 }
172}
173
174impl Drop for WavFileSink {
175 fn drop(&mut self) {
176 let _ = self.finalize();
179 }
180}
181
182impl AudioSink for WavFileSink {
183 fn write_chunk(&mut self, samples: &[i16], sample_rate: u32) -> Result<(), PiperError> {
184 if !self.header_written {
185 self.write_header(sample_rate)?;
186 }
187
188 if self.sample_rate != sample_rate {
190 return Err(PiperError::Streaming(format!(
191 "sample rate mismatch: expected {}, got {}",
192 self.sample_rate, sample_rate
193 )));
194 }
195
196 let mut buf = Vec::with_capacity(samples.len() * 2);
198 for &sample in samples {
199 buf.extend_from_slice(&sample.to_le_bytes());
200 }
201 self.file.write_all(&buf)?;
202 self.total_samples += samples.len();
203 Ok(())
204 }
205
206 fn finalize(&mut self) -> Result<(), PiperError> {
207 if self.header_written {
208 self.update_sizes()?;
209 }
210 Ok(())
211 }
212}
213
/// Linearly blends the end of `prev_tail` into the start of `next_head`.
///
/// The effective overlap is `overlap_samples` capped by the length of both
/// slices. The first output sample is entirely the previous signal and the
/// last is entirely the next one; a one-sample overlap yields the next
/// signal's sample. Blended values are clamped to the `i16` range and
/// truncated toward zero.
///
/// Returns an empty vector when the effective overlap is zero.
pub fn crossfade(prev_tail: &[i16], next_head: &[i16], overlap_samples: usize) -> Vec<i16> {
    let span = overlap_samples.min(prev_tail.len()).min(next_head.len());
    if span == 0 {
        return Vec::new();
    }

    let fade_out = &prev_tail[prev_tail.len() - span..];
    let fade_in = &next_head[..span];
    let last_index = (span - 1) as f64;

    fade_out
        .iter()
        .zip(fade_in)
        .enumerate()
        .map(|(idx, (&outgoing, &incoming))| {
            // Weight of the incoming signal, ramping 0.0 -> 1.0 over the span.
            let weight = if span <= 1 {
                1.0
            } else {
                (idx as f64) / last_index
            };
            let blended = f64::from(outgoing) * (1.0 - weight) + f64::from(incoming) * weight;
            blended.clamp(-32768.0, 32767.0) as i16
        })
        .collect()
}
249
250pub fn split_sentences(text: &str) -> Vec<String> {
263 if text.is_empty() {
264 return Vec::new();
265 }
266
267 let mut sentences = Vec::new();
268 let mut current = String::new();
269
270 let mut chars = text.chars().peekable();
271
272 while let Some(ch) = chars.next() {
273 current.push(ch);
274
275 if is_sentence_terminator(ch) {
277 while let Some(&next_ch) = chars.peek() {
280 if is_closing_punctuation(next_ch) {
281 current.push(chars.next().unwrap());
282 } else {
283 break;
284 }
285 }
286
287 let trimmed = current.trim().to_string();
289 if !trimmed.is_empty() {
290 sentences.push(trimmed);
291 }
292 current.clear();
293
294 while let Some(&next_ch) = chars.peek() {
296 if next_ch.is_whitespace() {
297 chars.next();
298 } else {
299 break;
300 }
301 }
302 }
303 }
304
305 let trimmed = current.trim().to_string();
307 if !trimmed.is_empty() {
308 sentences.push(trimmed);
309 }
310
311 sentences
312}
313
/// Returns `true` for characters that end a sentence: ASCII `.`, `!`, `?`
/// and their CJK/full-width counterparts U+3002, U+FF01 and U+FF1F.
fn is_sentence_terminator(ch: char) -> bool {
    const TERMINATORS: [char; 6] = ['.', '!', '?', '\u{3002}', '\u{FF01}', '\u{FF1F}'];
    TERMINATORS.contains(&ch)
}
323
/// Returns `true` for closing brackets and quotes that may trail a sentence
/// terminator and still belong to the same sentence.
fn is_closing_punctuation(ch: char) -> bool {
    const ASCII_CLOSERS: [char; 5] = [')', ']', '}', '"', '\''];
    const CJK_CLOSERS: [char; 6] = [
        '\u{300D}', '\u{300F}', '\u{FF09}', '\u{FF3D}', '\u{3011}', '\u{FF63}',
    ];
    ASCII_CLOSERS.contains(&ch) || CJK_CLOSERS.contains(&ch)
}
341
/// Unit tests for the sinks, `crossfade` and `split_sentences`.
///
/// NOTE(review): the tests that read files back with `hound` are gated on the
/// `onnx` feature, while several other tests use `tempfile` without that
/// gate — confirm `tempfile` is an unconditional dev-dependency.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- BufferSink ----

    /// Chunks are concatenated in the order they were written.
    #[test]
    fn test_buffer_sink_collects_samples() {
        let mut sink = BufferSink::new();
        sink.write_chunk(&[1, 2, 3], 22050).unwrap();
        sink.write_chunk(&[4, 5], 22050).unwrap();
        sink.finalize().unwrap();
        assert_eq!(sink.get_samples(), &[1, 2, 3, 4, 5]);
    }

    /// Finalizing without writing leaves the sink empty with no sample rate.
    #[test]
    fn test_buffer_sink_empty() {
        let mut sink = BufferSink::new();
        sink.finalize().unwrap();
        assert!(sink.get_samples().is_empty());
        assert_eq!(sink.sample_rate(), None);
    }

    /// The sample rate becomes available after the first chunk.
    #[test]
    fn test_buffer_sink_sample_rate() {
        let mut sink = BufferSink::new();
        assert_eq!(sink.sample_rate(), None);
        sink.write_chunk(&[100], 44100).unwrap();
        assert_eq!(sink.sample_rate(), Some(44100));
    }

    /// `Default` yields an empty sink.
    #[test]
    fn test_buffer_sink_default() {
        let sink = BufferSink::default();
        assert!(sink.get_samples().is_empty());
    }

    // ---- WavFileSink ----

    /// A written file can be read back by `hound` with the expected spec and samples.
    #[cfg(feature = "onnx")]
    #[test]
    fn test_wav_file_sink_writes_valid_wav() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("test.wav");

        {
            let mut sink = WavFileSink::new(&wav_path).unwrap();
            let samples: Vec<i16> = (0..100).collect();
            sink.write_chunk(&samples, 22050).unwrap();
            sink.finalize().unwrap();
        }

        let reader = hound::WavReader::open(&wav_path).unwrap();
        let spec = reader.spec();
        assert_eq!(spec.channels, 1);
        assert_eq!(spec.sample_rate, 22050);
        assert_eq!(spec.bits_per_sample, 16);
        let read_samples: Vec<i16> = reader.into_samples::<i16>().map(|s| s.unwrap()).collect();
        let expected: Vec<i16> = (0..100).collect();
        assert_eq!(read_samples, expected);
    }

    /// Several chunks end up as one contiguous sample stream.
    #[cfg(feature = "onnx")]
    #[test]
    fn test_wav_file_sink_multiple_chunks() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("multi.wav");

        {
            let mut sink = WavFileSink::new(&wav_path).unwrap();
            sink.write_chunk(&[10, 20, 30], 16000).unwrap();
            sink.write_chunk(&[40, 50], 16000).unwrap();
            sink.write_chunk(&[60], 16000).unwrap();
            sink.finalize().unwrap();
        }

        let reader = hound::WavReader::open(&wav_path).unwrap();
        assert_eq!(reader.spec().sample_rate, 16000);
        let read_samples: Vec<i16> = reader.into_samples::<i16>().map(|s| s.unwrap()).collect();
        assert_eq!(read_samples, vec![10, 20, 30, 40, 50, 60]);
    }

    /// Finalizing before any chunk was written succeeds (no header to patch).
    #[test]
    fn test_wav_file_sink_finalize_without_write() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("empty.wav");

        let mut sink = WavFileSink::new(&wav_path).unwrap();
        sink.finalize().unwrap();
    }

    // ---- crossfade ----

    /// The blend starts at the previous signal and ends at the next one.
    #[test]
    fn test_crossfade_basic() {
        let prev = vec![1000i16; 10];
        let next = vec![0i16; 10];
        let result = crossfade(&prev, &next, 4);
        assert_eq!(result.len(), 4);
        assert_eq!(result[0], 1000);
        assert_eq!(result[3], 0);
    }

    /// Intermediate values are truncated: 100/3 + 400/3 = 166.67 becomes 166.
    #[test]
    fn test_crossfade_equal_blend() {
        let prev = vec![100i16; 4];
        let next = vec![200i16; 4];
        let result = crossfade(&prev, &next, 4);
        assert_eq!(result.len(), 4);
        assert_eq!(result[0], 100);
        assert_eq!(result[2], 166);
    }

    /// A zero overlap produces no samples.
    #[test]
    fn test_crossfade_zero_overlap() {
        let prev = vec![100i16; 5];
        let next = vec![200i16; 5];
        let result = crossfade(&prev, &next, 0);
        assert!(result.is_empty());
    }

    /// The overlap is capped by the length of the previous slice.
    #[test]
    fn test_crossfade_overlap_exceeds_prev() {
        let prev = vec![500i16; 3];
        let next = vec![0i16; 10];
        let result = crossfade(&prev, &next, 100);
        assert_eq!(result.len(), 3);
    }

    /// The overlap is capped by the length of the next slice.
    #[test]
    fn test_crossfade_overlap_exceeds_next() {
        let prev = vec![500i16; 10];
        let next = vec![0i16; 2];
        let result = crossfade(&prev, &next, 100);
        assert_eq!(result.len(), 2);
    }

    /// Empty inputs yield an empty result regardless of the requested overlap.
    #[test]
    fn test_crossfade_empty_slices() {
        let result = crossfade(&[], &[], 10);
        assert!(result.is_empty());
    }

    /// A one-sample overlap takes the next signal's sample.
    #[test]
    fn test_crossfade_one_sample() {
        let prev = vec![1000i16];
        let next = vec![0i16];
        let result = crossfade(&prev, &next, 1);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], 0);
    }

    // ---- split_sentences ----

    /// Japanese text splits on the ideographic full stop.
    #[test]
    fn test_split_sentences_japanese() {
        let text = "こんにちは。今日は良い天気ですね。明日も晴れるでしょう。";
        let result = split_sentences(text);
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], "こんにちは。");
        assert_eq!(result[1], "今日は良い天気ですね。");
        assert_eq!(result[2], "明日も晴れるでしょう。");
    }

    /// English text splits on `.`, `?` and `!`, dropping the separating spaces.
    #[test]
    fn test_split_sentences_english() {
        let text = "Hello world. How are you? I am fine!";
        let result = split_sentences(text);
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], "Hello world.");
        assert_eq!(result[1], "How are you?");
        assert_eq!(result[2], "I am fine!");
    }

    /// ASCII and full-width terminators can be mixed in one input.
    #[test]
    fn test_split_sentences_mixed_punctuation() {
        let text = "日本語のテスト。English test! 混合テスト?";
        let result = split_sentences(text);
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], "日本語のテスト。");
        assert_eq!(result[1], "English test!");
        assert_eq!(result[2], "混合テスト?");
    }

    /// Full-width `!` and `?` are treated as terminators.
    #[test]
    fn test_split_sentences_fullwidth_punctuation() {
        let text = "すごい!本当ですか?はい。";
        let result = split_sentences(text);
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], "すごい!");
        assert_eq!(result[1], "本当ですか?");
        assert_eq!(result[2], "はい。");
    }

    /// Empty input yields no sentences.
    #[test]
    fn test_split_sentences_empty() {
        let result = split_sentences("");
        assert!(result.is_empty());
    }

    /// Text without any terminator is returned as a single sentence.
    #[test]
    fn test_split_sentences_no_terminator() {
        let text = "This has no ending punctuation";
        let result = split_sentences(text);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], "This has no ending punctuation");
    }

    /// Whitespace-only input yields no sentences.
    #[test]
    fn test_split_sentences_whitespace_only() {
        let result = split_sentences(" ");
        assert!(result.is_empty());
    }

    /// A closing bracket after the terminator stays with its sentence.
    #[test]
    fn test_split_sentences_with_closing_brackets() {
        let text = "「こんにちは。」次の文。";
        let result = split_sentences(text);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], "「こんにちは。」");
        assert_eq!(result[1], "次の文。");
    }

    /// A single terminated sentence is returned unchanged.
    #[test]
    fn test_split_sentences_single_sentence() {
        let text = "一つだけ。";
        let result = split_sentences(text);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], "一つだけ。");
    }

    // ---- StreamingResult ----

    /// Fields can be set and read back directly.
    #[test]
    fn test_streaming_result_construction() {
        let result = StreamingResult {
            total_audio_seconds: 5.0,
            total_infer_seconds: 1.5,
            chunk_count: 3,
        };
        assert!((result.total_audio_seconds - 5.0).abs() < 1e-9);
        assert!((result.total_infer_seconds - 1.5).abs() < 1e-9);
        assert_eq!(result.chunk_count, 3);
    }

    /// `Clone` copies the field values.
    #[test]
    fn test_streaming_result_clone() {
        let result = StreamingResult {
            total_audio_seconds: 2.0,
            total_infer_seconds: 0.8,
            chunk_count: 1,
        };
        let cloned = result.clone();
        assert_eq!(cloned.chunk_count, result.chunk_count);
        assert!((cloned.total_audio_seconds - result.total_audio_seconds).abs() < 1e-9);
    }

    /// The derived `Debug` output names the fields.
    #[test]
    fn test_streaming_result_debug() {
        let result = StreamingResult {
            total_audio_seconds: 3.14,
            total_infer_seconds: 1.0,
            chunk_count: 2,
        };
        let debug = format!("{:?}", result);
        assert!(debug.contains("total_audio_seconds"));
        assert!(debug.contains("chunk_count"));
    }

    // ---- AudioSink as a trait object ----

    /// The trait can be used through `&mut dyn AudioSink`.
    #[test]
    fn test_audio_sink_object_safety() {
        fn accept_sink(sink: &mut dyn AudioSink) -> Result<(), PiperError> {
            sink.write_chunk(&[1, 2, 3], 22050)?;
            sink.finalize()
        }
        let mut buffer = BufferSink::new();
        accept_sink(&mut buffer).unwrap();
        assert_eq!(buffer.get_samples(), &[1, 2, 3]);
    }

    // ---- WavFileSink: drop, rate checks, size limit ----

    /// Dropping the sink without calling `finalize` still yields a readable file.
    #[cfg(feature = "onnx")]
    #[test]
    fn test_wav_file_sink_drop_finalizes() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("drop_test.wav");

        {
            let mut sink = WavFileSink::new(&wav_path).unwrap();
            let samples: Vec<i16> = vec![100, 200, 300, -100, -200];
            sink.write_chunk(&samples, 22050).unwrap();
        }

        let reader = hound::WavReader::open(&wav_path).unwrap();
        let spec = reader.spec();
        assert_eq!(spec.channels, 1);
        assert_eq!(spec.sample_rate, 22050);
        assert_eq!(spec.bits_per_sample, 16);
        let read_samples: Vec<i16> = reader.into_samples::<i16>().map(|s| s.unwrap()).collect();
        assert_eq!(read_samples, vec![100, 200, 300, -100, -200]);
    }

    /// A chunk whose rate differs from the header's rate is rejected.
    #[test]
    fn test_wav_file_sink_sample_rate_mismatch_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("rate_mismatch.wav");

        let mut sink = WavFileSink::new(&wav_path).unwrap();
        sink.write_chunk(&[10, 20], 16000).unwrap();
        let err = sink.write_chunk(&[30, 40], 44100).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("sample rate mismatch"),
            "expected sample rate mismatch error, got: {}",
            msg
        );
    }

    /// Repeating the same rate is accepted and the samples are appended.
    #[cfg(feature = "onnx")]
    #[test]
    fn test_wav_file_sink_same_sample_rate_ok() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("same_rate.wav");

        {
            let mut sink = WavFileSink::new(&wav_path).unwrap();
            sink.write_chunk(&[10, 20], 16000).unwrap();
            sink.write_chunk(&[30, 40], 16000).unwrap();
            sink.finalize().unwrap();
        }

        let reader = hound::WavReader::open(&wav_path).unwrap();
        assert_eq!(reader.spec().sample_rate, 16000);
        let read_samples: Vec<i16> = reader.into_samples::<i16>().map(|s| s.unwrap()).collect();
        assert_eq!(read_samples, vec![10, 20, 30, 40]);
    }

    /// `finalize` fails once the data size no longer fits the 32-bit field.
    #[test]
    fn test_wav_file_sink_overflow_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let wav_path = dir.path().join("overflow.wav");

        let mut sink = WavFileSink::new(&wav_path).unwrap();
        sink.write_chunk(&[1], 22050).unwrap();
        // Fake the counter instead of writing 4 GB: (u32::MAX / 2 + 2) * 2 bytes
        // exceeds u32::MAX.
        sink.total_samples = (u32::MAX as usize) / 2 + 2;
        let err = sink.finalize().unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("4GB"),
            "expected 4GB limit error, got: {}",
            msg
        );
    }

    // ---- crossfade: extreme values ----

    /// Negative samples blend with the same end-point behaviour.
    #[test]
    fn test_crossfade_negative_samples() {
        let prev = vec![-10000i16, -5000];
        let next = vec![5000i16, 10000];
        let result = crossfade(&prev, &next, 2);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], -10000);
        assert_eq!(result[1], 10000);
    }

    /// The `i16` extremes survive the blend without overflow.
    #[test]
    fn test_crossfade_max_i16_values() {
        let prev = vec![i16::MAX, i16::MAX];
        let next = vec![i16::MIN, i16::MIN];
        let result = crossfade(&prev, &next, 2);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], i16::MAX);
        assert_eq!(result[1], i16::MIN);
    }

    // ---- split_sentences: edge cases ----

    /// Each terminator ends a sentence by itself, so "?!" produces a lone "!".
    #[test]
    fn test_split_sentences_consecutive_terminators() {
        let result = split_sentences("Really?! Yes.");
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], "Really?");
        assert_eq!(result[1], "!");
        assert_eq!(result[2], "Yes.");
    }

    /// One-letter sentences are kept.
    #[test]
    fn test_split_sentences_single_char_sentence() {
        let result = split_sentences("A. B.");
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], "A.");
        assert_eq!(result[1], "B.");
    }

    /// A newline between sentences is treated like any other whitespace.
    #[test]
    fn test_split_sentences_newline_separator() {
        let result = split_sentences("Hello.\nWorld.");
        assert_eq!(result.len(), 2);
        assert_eq!(result[0], "Hello.");
        assert_eq!(result[1], "World.");
    }

    // ---- BufferSink: volume ----

    /// 100 chunks of 10,000 samples accumulate to one million samples.
    #[test]
    fn test_buffer_sink_large_chunks() {
        let mut sink = BufferSink::new();
        let chunk: Vec<i16> = (0..10_000).map(|i| (i % 1000) as i16).collect();
        for _ in 0..100 {
            sink.write_chunk(&chunk, 22050).unwrap();
        }
        sink.finalize().unwrap();
        assert_eq!(sink.get_samples().len(), 1_000_000);
        assert_eq!(sink.sample_rate(), Some(22050));
    }
}