1use std::io::{Seek, Write};
6use std::path::Path;
7
8use crate::error::PiperError;
9
10pub trait AudioSink {
19 fn write_chunk(&mut self, samples: &[i16], sample_rate: u32) -> Result<(), PiperError>;
21
22 fn finalize(&mut self) -> Result<(), PiperError>;
24}
25
26#[derive(Debug, Clone)]
32pub struct StreamingResult {
33 pub total_audio_seconds: f64,
35 pub total_infer_seconds: f64,
37 pub chunk_count: usize,
39}
40
41pub struct BufferSink {
47 samples: Vec<i16>,
48 sample_rate: Option<u32>,
49}
50
51impl BufferSink {
52 pub fn new() -> Self {
54 Self {
55 samples: Vec::new(),
56 sample_rate: None,
57 }
58 }
59
60 pub fn get_samples(&self) -> &[i16] {
62 &self.samples
63 }
64
65 pub fn sample_rate(&self) -> Option<u32> {
67 self.sample_rate
68 }
69}
70
71impl Default for BufferSink {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77impl AudioSink for BufferSink {
78 fn write_chunk(&mut self, samples: &[i16], sample_rate: u32) -> Result<(), PiperError> {
79 self.sample_rate = Some(sample_rate);
80 self.samples.extend_from_slice(samples);
81 Ok(())
82 }
83
84 fn finalize(&mut self) -> Result<(), PiperError> {
85 Ok(())
86 }
87}
88
89pub struct WavFileSink {
99 file: std::fs::File,
100 sample_rate: u32,
101 total_samples: usize,
102 header_written: bool,
103}
104
105impl WavFileSink {
106 pub fn new(path: &Path) -> Result<Self, PiperError> {
111 let file = std::fs::File::create(path)?;
112 Ok(Self {
113 file,
114 sample_rate: 0,
115 total_samples: 0,
116 header_written: false,
117 })
118 }
119
120 fn write_header(&mut self, sample_rate: u32) -> Result<(), PiperError> {
122 let placeholder_data_size: u32 = 0;
123 let placeholder_file_size: u32 = 36; self.file.write_all(b"RIFF")?;
127 self.file.write_all(&placeholder_file_size.to_le_bytes())?;
128 self.file.write_all(b"WAVE")?;
129
130 self.file.write_all(b"fmt ")?;
132 self.file.write_all(&16u32.to_le_bytes())?; self.file.write_all(&1u16.to_le_bytes())?; self.file.write_all(&1u16.to_le_bytes())?; self.file.write_all(&sample_rate.to_le_bytes())?;
136 self.file.write_all(&(sample_rate * 2).to_le_bytes())?; self.file.write_all(&2u16.to_le_bytes())?; self.file.write_all(&16u16.to_le_bytes())?; self.file.write_all(b"data")?;
142 self.file.write_all(&placeholder_data_size.to_le_bytes())?;
143
144 self.sample_rate = sample_rate;
145 self.header_written = true;
146 Ok(())
147 }
148
149 fn update_sizes(&mut self) -> Result<(), PiperError> {
151 let data_size_u64 = (self.total_samples as u64) * 2;
152 if data_size_u64 > u32::MAX as u64 {
153 return Err(PiperError::Streaming(
154 "WAV file exceeds 4GB limit".to_string(),
155 ));
156 }
157 let data_size = data_size_u64 as u32;
158 let file_size = data_size + 36;
159
160 self.file.seek(std::io::SeekFrom::Start(4))?;
162 self.file.write_all(&file_size.to_le_bytes())?;
163
164 self.file.seek(std::io::SeekFrom::Start(40))?;
166 self.file.write_all(&data_size.to_le_bytes())?;
167
168 self.file.flush()?;
170 Ok(())
171 }
172}
173
174impl Drop for WavFileSink {
175 fn drop(&mut self) {
176 let _ = self.finalize();
179 }
180}
181
182impl AudioSink for WavFileSink {
183 fn write_chunk(&mut self, samples: &[i16], sample_rate: u32) -> Result<(), PiperError> {
184 if !self.header_written {
185 self.write_header(sample_rate)?;
186 }
187
188 if self.sample_rate != sample_rate {
190 return Err(PiperError::Streaming(format!(
191 "sample rate mismatch: expected {}, got {}",
192 self.sample_rate, sample_rate
193 )));
194 }
195
196 let mut buf = Vec::with_capacity(samples.len() * 2);
198 for &sample in samples {
199 buf.extend_from_slice(&sample.to_le_bytes());
200 }
201 self.file.write_all(&buf)?;
202 self.total_samples += samples.len();
203 Ok(())
204 }
205
206 fn finalize(&mut self) -> Result<(), PiperError> {
207 if self.header_written {
208 self.update_sizes()?;
209 }
210 Ok(())
211 }
212}
213
214pub fn crossfade(prev_tail: &[i16], next_head: &[i16], overlap_samples: usize) -> Vec<i16> {
228 let actual_overlap = overlap_samples.min(prev_tail.len()).min(next_head.len());
229
230 if actual_overlap == 0 {
231 return Vec::new();
232 }
233
234 let mut blended = Vec::with_capacity(actual_overlap);
235 for i in 0..actual_overlap {
236 let alpha = if actual_overlap <= 1 {
238 1.0
239 } else {
240 (i as f64) / ((actual_overlap - 1) as f64)
241 };
242 let prev_sample = prev_tail[prev_tail.len() - actual_overlap + i] as f64;
243 let next_sample = next_head[i] as f64;
244 let mixed = prev_sample * (1.0 - alpha) + next_sample * alpha;
245 blended.push(mixed.clamp(-32768.0, 32767.0) as i16);
246 }
247 blended
248}
249
250pub fn split_sentences(text: &str) -> Vec<String> {
263 if text.is_empty() {
264 return Vec::new();
265 }
266
267 let mut sentences = Vec::new();
268 let mut current = String::new();
269
270 let mut chars = text.chars().peekable();
271
272 while let Some(ch) = chars.next() {
273 current.push(ch);
274
275 if is_sentence_terminator(ch) {
277 while let Some(&next_ch) = chars.peek() {
280 if is_closing_punctuation(next_ch) {
281 current.push(chars.next().unwrap());
282 } else {
283 break;
284 }
285 }
286
287 let trimmed = current.trim().to_string();
289 if !trimmed.is_empty() {
290 sentences.push(trimmed);
291 }
292 current.clear();
293
294 while let Some(&next_ch) = chars.peek() {
296 if next_ch.is_whitespace() {
297 chars.next();
298 } else {
299 break;
300 }
301 }
302 }
303 }
304
305 let trimmed = current.trim().to_string();
307 if !trimmed.is_empty() {
308 sentences.push(trimmed);
309 }
310
311 sentences
312}
313
314fn is_sentence_terminator(ch: char) -> bool {
316 matches!(
317 ch,
318 '.' | '!' | '?' | '\u{3002}' | '\u{FF01}' | '\u{FF1F}' )
322}
323
324fn is_closing_punctuation(ch: char) -> bool {
327 matches!(
328 ch,
329 ')' | ']'
330 | '}'
331 | '"'
332 | '\''
333 | '\u{300D}' | '\u{300F}' | '\u{FF09}' | '\u{FF3D}' | '\u{3011}' | '\u{FF63}' )
340}
341
342#[cfg(test)]
347mod tests {
348 use super::*;
349
350 #[test]
355 fn test_buffer_sink_collects_samples() {
356 let mut sink = BufferSink::new();
357 sink.write_chunk(&[1, 2, 3], 22050).unwrap();
358 sink.write_chunk(&[4, 5], 22050).unwrap();
359 sink.finalize().unwrap();
360 assert_eq!(sink.get_samples(), &[1, 2, 3, 4, 5]);
361 }
362
363 #[test]
364 fn test_buffer_sink_empty() {
365 let mut sink = BufferSink::new();
366 sink.finalize().unwrap();
367 assert!(sink.get_samples().is_empty());
368 assert_eq!(sink.sample_rate(), None);
369 }
370
371 #[test]
372 fn test_buffer_sink_sample_rate() {
373 let mut sink = BufferSink::new();
374 assert_eq!(sink.sample_rate(), None);
375 sink.write_chunk(&[100], 44100).unwrap();
376 assert_eq!(sink.sample_rate(), Some(44100));
377 }
378
379 #[test]
380 fn test_buffer_sink_default() {
381 let sink = BufferSink::default();
382 assert!(sink.get_samples().is_empty());
383 }
384
385 #[test]
390 fn test_wav_file_sink_writes_valid_wav() {
391 let dir = tempfile::tempdir().unwrap();
392 let wav_path = dir.path().join("test.wav");
393
394 {
395 let mut sink = WavFileSink::new(&wav_path).unwrap();
396 let samples: Vec<i16> = (0..100).collect();
397 sink.write_chunk(&samples, 22050).unwrap();
398 sink.finalize().unwrap();
399 }
400
401 let reader = hound::WavReader::open(&wav_path).unwrap();
403 let spec = reader.spec();
404 assert_eq!(spec.channels, 1);
405 assert_eq!(spec.sample_rate, 22050);
406 assert_eq!(spec.bits_per_sample, 16);
407 let read_samples: Vec<i16> = reader.into_samples::<i16>().map(|s| s.unwrap()).collect();
408 let expected: Vec<i16> = (0..100).collect();
409 assert_eq!(read_samples, expected);
410 }
411
412 #[test]
413 fn test_wav_file_sink_multiple_chunks() {
414 let dir = tempfile::tempdir().unwrap();
415 let wav_path = dir.path().join("multi.wav");
416
417 {
418 let mut sink = WavFileSink::new(&wav_path).unwrap();
419 sink.write_chunk(&[10, 20, 30], 16000).unwrap();
420 sink.write_chunk(&[40, 50], 16000).unwrap();
421 sink.write_chunk(&[60], 16000).unwrap();
422 sink.finalize().unwrap();
423 }
424
425 let reader = hound::WavReader::open(&wav_path).unwrap();
426 assert_eq!(reader.spec().sample_rate, 16000);
427 let read_samples: Vec<i16> = reader.into_samples::<i16>().map(|s| s.unwrap()).collect();
428 assert_eq!(read_samples, vec![10, 20, 30, 40, 50, 60]);
429 }
430
431 #[test]
432 fn test_wav_file_sink_finalize_without_write() {
433 let dir = tempfile::tempdir().unwrap();
434 let wav_path = dir.path().join("empty.wav");
435
436 let mut sink = WavFileSink::new(&wav_path).unwrap();
437 sink.finalize().unwrap();
439 }
440
441 #[test]
446 fn test_crossfade_basic() {
447 let prev = vec![1000i16; 10];
449 let next = vec![0i16; 10];
450 let result = crossfade(&prev, &next, 4);
451 assert_eq!(result.len(), 4);
452 assert_eq!(result[0], 1000);
454 assert_eq!(result[3], 0);
456 }
457
458 #[test]
459 fn test_crossfade_equal_blend() {
460 let prev = vec![100i16; 4];
461 let next = vec![200i16; 4];
462 let result = crossfade(&prev, &next, 4);
463 assert_eq!(result.len(), 4);
464 assert_eq!(result[0], 100);
466 assert_eq!(result[2], 166);
468 }
469
470 #[test]
471 fn test_crossfade_zero_overlap() {
472 let prev = vec![100i16; 5];
473 let next = vec![200i16; 5];
474 let result = crossfade(&prev, &next, 0);
475 assert!(result.is_empty());
476 }
477
478 #[test]
479 fn test_crossfade_overlap_exceeds_prev() {
480 let prev = vec![500i16; 3];
481 let next = vec![0i16; 10];
482 let result = crossfade(&prev, &next, 100);
483 assert_eq!(result.len(), 3);
485 }
486
487 #[test]
488 fn test_crossfade_overlap_exceeds_next() {
489 let prev = vec![500i16; 10];
490 let next = vec![0i16; 2];
491 let result = crossfade(&prev, &next, 100);
492 assert_eq!(result.len(), 2);
494 }
495
496 #[test]
497 fn test_crossfade_empty_slices() {
498 let result = crossfade(&[], &[], 10);
499 assert!(result.is_empty());
500 }
501
502 #[test]
503 fn test_crossfade_one_sample() {
504 let prev = vec![1000i16];
505 let next = vec![0i16];
506 let result = crossfade(&prev, &next, 1);
507 assert_eq!(result.len(), 1);
508 assert_eq!(result[0], 0);
510 }
511
512 #[test]
517 fn test_split_sentences_japanese() {
518 let text = "こんにちは。今日は良い天気ですね。明日も晴れるでしょう。";
519 let result = split_sentences(text);
520 assert_eq!(result.len(), 3);
521 assert_eq!(result[0], "こんにちは。");
522 assert_eq!(result[1], "今日は良い天気ですね。");
523 assert_eq!(result[2], "明日も晴れるでしょう。");
524 }
525
526 #[test]
527 fn test_split_sentences_english() {
528 let text = "Hello world. How are you? I am fine!";
529 let result = split_sentences(text);
530 assert_eq!(result.len(), 3);
531 assert_eq!(result[0], "Hello world.");
532 assert_eq!(result[1], "How are you?");
533 assert_eq!(result[2], "I am fine!");
534 }
535
536 #[test]
537 fn test_split_sentences_mixed_punctuation() {
538 let text = "日本語のテスト。English test! 混合テスト?";
539 let result = split_sentences(text);
540 assert_eq!(result.len(), 3);
541 assert_eq!(result[0], "日本語のテスト。");
542 assert_eq!(result[1], "English test!");
543 assert_eq!(result[2], "混合テスト?");
544 }
545
546 #[test]
547 fn test_split_sentences_fullwidth_punctuation() {
548 let text = "すごい!本当ですか?はい。";
549 let result = split_sentences(text);
550 assert_eq!(result.len(), 3);
551 assert_eq!(result[0], "すごい!");
552 assert_eq!(result[1], "本当ですか?");
553 assert_eq!(result[2], "はい。");
554 }
555
556 #[test]
557 fn test_split_sentences_empty() {
558 let result = split_sentences("");
559 assert!(result.is_empty());
560 }
561
562 #[test]
563 fn test_split_sentences_no_terminator() {
564 let text = "This has no ending punctuation";
565 let result = split_sentences(text);
566 assert_eq!(result.len(), 1);
567 assert_eq!(result[0], "This has no ending punctuation");
568 }
569
570 #[test]
571 fn test_split_sentences_whitespace_only() {
572 let result = split_sentences(" ");
573 assert!(result.is_empty());
574 }
575
576 #[test]
577 fn test_split_sentences_with_closing_brackets() {
578 let text = "「こんにちは。」次の文。";
579 let result = split_sentences(text);
580 assert_eq!(result.len(), 2);
581 assert_eq!(result[0], "「こんにちは。」");
582 assert_eq!(result[1], "次の文。");
583 }
584
585 #[test]
586 fn test_split_sentences_single_sentence() {
587 let text = "一つだけ。";
588 let result = split_sentences(text);
589 assert_eq!(result.len(), 1);
590 assert_eq!(result[0], "一つだけ。");
591 }
592
593 #[test]
598 fn test_streaming_result_construction() {
599 let result = StreamingResult {
600 total_audio_seconds: 5.0,
601 total_infer_seconds: 1.5,
602 chunk_count: 3,
603 };
604 assert!((result.total_audio_seconds - 5.0).abs() < 1e-9);
605 assert!((result.total_infer_seconds - 1.5).abs() < 1e-9);
606 assert_eq!(result.chunk_count, 3);
607 }
608
609 #[test]
610 fn test_streaming_result_clone() {
611 let result = StreamingResult {
612 total_audio_seconds: 2.0,
613 total_infer_seconds: 0.8,
614 chunk_count: 1,
615 };
616 let cloned = result.clone();
617 assert_eq!(cloned.chunk_count, result.chunk_count);
618 assert!((cloned.total_audio_seconds - result.total_audio_seconds).abs() < 1e-9);
619 }
620
621 #[test]
622 fn test_streaming_result_debug() {
623 let result = StreamingResult {
624 total_audio_seconds: 3.14,
625 total_infer_seconds: 1.0,
626 chunk_count: 2,
627 };
628 let debug = format!("{:?}", result);
629 assert!(debug.contains("total_audio_seconds"));
630 assert!(debug.contains("chunk_count"));
631 }
632
633 #[test]
638 fn test_audio_sink_object_safety() {
639 fn accept_sink(sink: &mut dyn AudioSink) -> Result<(), PiperError> {
641 sink.write_chunk(&[1, 2, 3], 22050)?;
642 sink.finalize()
643 }
644 let mut buffer = BufferSink::new();
645 accept_sink(&mut buffer).unwrap();
646 assert_eq!(buffer.get_samples(), &[1, 2, 3]);
647 }
648
649 #[test]
654 fn test_wav_file_sink_drop_finalizes() {
655 let dir = tempfile::tempdir().unwrap();
657 let wav_path = dir.path().join("drop_test.wav");
658
659 {
660 let mut sink = WavFileSink::new(&wav_path).unwrap();
661 let samples: Vec<i16> = vec![100, 200, 300, -100, -200];
662 sink.write_chunk(&samples, 22050).unwrap();
663 }
665
666 let reader = hound::WavReader::open(&wav_path).unwrap();
668 let spec = reader.spec();
669 assert_eq!(spec.channels, 1);
670 assert_eq!(spec.sample_rate, 22050);
671 assert_eq!(spec.bits_per_sample, 16);
672 let read_samples: Vec<i16> = reader.into_samples::<i16>().map(|s| s.unwrap()).collect();
673 assert_eq!(read_samples, vec![100, 200, 300, -100, -200]);
674 }
675
676 #[test]
677 fn test_wav_file_sink_sample_rate_mismatch_rejected() {
678 let dir = tempfile::tempdir().unwrap();
680 let wav_path = dir.path().join("rate_mismatch.wav");
681
682 let mut sink = WavFileSink::new(&wav_path).unwrap();
683 sink.write_chunk(&[10, 20], 16000).unwrap();
684 let err = sink.write_chunk(&[30, 40], 44100).unwrap_err();
685 let msg = err.to_string();
686 assert!(
687 msg.contains("sample rate mismatch"),
688 "expected sample rate mismatch error, got: {}",
689 msg
690 );
691 }
692
693 #[test]
694 fn test_wav_file_sink_same_sample_rate_ok() {
695 let dir = tempfile::tempdir().unwrap();
697 let wav_path = dir.path().join("same_rate.wav");
698
699 {
700 let mut sink = WavFileSink::new(&wav_path).unwrap();
701 sink.write_chunk(&[10, 20], 16000).unwrap();
702 sink.write_chunk(&[30, 40], 16000).unwrap();
703 sink.finalize().unwrap();
704 }
705
706 let reader = hound::WavReader::open(&wav_path).unwrap();
707 assert_eq!(reader.spec().sample_rate, 16000);
708 let read_samples: Vec<i16> = reader.into_samples::<i16>().map(|s| s.unwrap()).collect();
709 assert_eq!(read_samples, vec![10, 20, 30, 40]);
710 }
711
712 #[test]
713 fn test_wav_file_sink_overflow_rejected() {
714 let dir = tempfile::tempdir().unwrap();
718 let wav_path = dir.path().join("overflow.wav");
719
720 let mut sink = WavFileSink::new(&wav_path).unwrap();
721 sink.write_chunk(&[1], 22050).unwrap();
722 sink.total_samples = (u32::MAX as usize) / 2 + 2;
724 let err = sink.finalize().unwrap_err();
725 let msg = err.to_string();
726 assert!(
727 msg.contains("4GB"),
728 "expected 4GB limit error, got: {}",
729 msg
730 );
731 }
732
733 #[test]
738 fn test_crossfade_negative_samples() {
739 let prev = vec![-10000i16, -5000];
741 let next = vec![5000i16, 10000];
742 let result = crossfade(&prev, &next, 2);
743 assert_eq!(result.len(), 2);
744 assert_eq!(result[0], -10000);
746 assert_eq!(result[1], 10000);
748 }
749
750 #[test]
751 fn test_crossfade_max_i16_values() {
752 let prev = vec![i16::MAX, i16::MAX];
755 let next = vec![i16::MIN, i16::MIN];
756 let result = crossfade(&prev, &next, 2);
757 assert_eq!(result.len(), 2);
758 assert_eq!(result[0], i16::MAX);
760 assert_eq!(result[1], i16::MIN);
762 }
763
764 #[test]
769 fn test_split_sentences_consecutive_terminators() {
770 let result = split_sentences("Really?! Yes.");
775 assert_eq!(result.len(), 3);
776 assert_eq!(result[0], "Really?");
777 assert_eq!(result[1], "!");
778 assert_eq!(result[2], "Yes.");
779 }
780
781 #[test]
782 fn test_split_sentences_single_char_sentence() {
783 let result = split_sentences("A. B.");
785 assert_eq!(result.len(), 2);
786 assert_eq!(result[0], "A.");
787 assert_eq!(result[1], "B.");
788 }
789
790 #[test]
791 fn test_split_sentences_newline_separator() {
792 let result = split_sentences("Hello.\nWorld.");
794 assert_eq!(result.len(), 2);
795 assert_eq!(result[0], "Hello.");
796 assert_eq!(result[1], "World.");
797 }
798
799 #[test]
804 fn test_buffer_sink_large_chunks() {
805 let mut sink = BufferSink::new();
807 let chunk: Vec<i16> = (0..10_000).map(|i| (i % 1000) as i16).collect();
808 for _ in 0..100 {
809 sink.write_chunk(&chunk, 22050).unwrap();
810 }
811 sink.finalize().unwrap();
812 assert_eq!(sink.get_samples().len(), 1_000_000);
813 assert_eq!(sink.sample_rate(), Some(22050));
814 }
815}