Skip to main content

piper_plus/
wasm.rs

1//! WASM-compatible synthesis API
2//!
3//! Provides an API that works without filesystem access, suitable for
4//! WebAssembly (wasm32) targets. Model and config data are passed as byte slices
5//! instead of file paths.
6//!
7//! This module is available on all platforms but is designed primarily for WASM.
8//! On native platforms, prefer `PiperVoice::load()` for convenience.
9
10use std::borrow::Cow;
11use std::time::Instant;
12
13use ort::session::Session;
14use ort::value::Tensor;
15
16use crate::audio::audio_float_to_int16;
17use crate::config::VoiceConfig;
18use crate::error::PiperError;
19
/// Outcome of an in-memory synthesis call (no file I/O is involved).
#[derive(Debug, Clone)]
pub struct WasmSynthesisResult {
    /// Raw PCM audio samples (16-bit signed, mono)
    pub audio_samples: Vec<i16>,
    /// Audio sample rate in Hz (e.g., 22050)
    pub sample_rate: u32,
    /// Time spent running inference, in seconds
    pub infer_seconds: f64,
    /// Duration of the generated audio, in seconds
    pub audio_seconds: f64,
}

impl WasmSynthesisResult {
    /// Real-time factor: `infer_seconds / audio_seconds`.
    ///
    /// Values below 1.0 mean synthesis ran faster than real time. When
    /// `audio_seconds` is not strictly positive (zero, negative or NaN) the
    /// ratio is undefined and 0.0 is returned instead.
    pub fn real_time_factor(&self) -> f64 {
        match self.audio_seconds {
            duration if duration > 0.0 => self.infer_seconds / duration,
            _ => 0.0,
        }
    }
}
44
/// Model capabilities detected from ONNX input/output node names.
///
/// `Default` describes a model exposing none of the optional inputs or outputs.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct WasmModelCapabilities {
    /// The model has a `sid` (speaker ID) input.
    pub has_sid: bool,
    /// The model has a `lid` (language ID) input.
    pub has_lid: bool,
    /// The model has a `prosody_features` input.
    pub has_prosody: bool,
    /// The model has a `durations` output.
    pub has_duration_output: bool,
}
53
54/// WASM-compatible voice synthesizer.
55/// Loads model from bytes rather than file paths.
56#[derive(Debug)]
57pub struct WasmVoice {
58    config: VoiceConfig,
59    session: Session,
60    capabilities: WasmModelCapabilities,
61}
62
63impl WasmVoice {
64    /// Load from in-memory model and config data.
65    ///
66    /// # Arguments
67    /// * `model_bytes` - ONNX model file contents
68    /// * `config_json` - config.json file contents as string
69    pub fn load_from_bytes(model_bytes: &[u8], config_json: &str) -> Result<Self, PiperError> {
70        let config: VoiceConfig = parse_config(config_json)?;
71
72        let session = Session::builder()
73            .map_err(|e| PiperError::ModelLoad(e.to_string()))?
74            .commit_from_memory(model_bytes)
75            .map_err(|e| PiperError::ModelLoad(e.to_string()))?;
76
77        // Detect capabilities from ONNX input/output node names
78        let input_names: Vec<String> = session
79            .inputs()
80            .iter()
81            .map(|i| i.name().to_string())
82            .collect();
83        let output_names: Vec<String> = session
84            .outputs()
85            .iter()
86            .map(|o| o.name().to_string())
87            .collect();
88
89        let has_input = |name: &str| input_names.iter().any(|n| n == name);
90        let has_output = |name: &str| output_names.iter().any(|n| n == name);
91
92        let capabilities = WasmModelCapabilities {
93            has_sid: has_input("sid"),
94            has_lid: has_input("lid"),
95            has_prosody: has_input("prosody_features"),
96            has_duration_output: has_output("durations"),
97        };
98
99        tracing::info!(
100            "WasmVoice loaded: inputs={:?}, outputs={:?}",
101            input_names,
102            output_names,
103        );
104        tracing::info!(
105            "Capabilities: sid={}, lid={}, prosody={}, durations={}",
106            capabilities.has_sid,
107            capabilities.has_lid,
108            capabilities.has_prosody,
109            capabilities.has_duration_output,
110        );
111
112        Ok(Self {
113            config,
114            session,
115            capabilities,
116        })
117    }
118
119    /// Synthesize from pre-computed phoneme IDs (no G2P needed).
120    /// This is the primary API for WASM since G2P may not be available.
121    pub fn synthesize_ids(
122        &mut self,
123        phoneme_ids: &[i64],
124        speaker_id: Option<i64>,
125        language_id: Option<i64>,
126        noise_scale: f32,
127        length_scale: f32,
128        noise_w: f32,
129    ) -> Result<WasmSynthesisResult, PiperError> {
130        let phoneme_len = phoneme_ids.len();
131        if phoneme_len == 0 {
132            return Err(PiperError::Inference("empty phoneme_ids".to_string()));
133        }
134
135        // --- Build input tensors ---
136
137        // 1. input: int64 [1, phoneme_len]
138        let input_tensor = Tensor::from_array((
139            [1_usize, phoneme_len],
140            phoneme_ids.to_vec().into_boxed_slice(),
141        ))
142        .map_err(|e| PiperError::Inference(format!("input tensor: {e}")))?;
143
144        // 2. input_lengths: int64 [1]
145        let lengths_tensor =
146            Tensor::from_array(([1_usize], vec![phoneme_len as i64].into_boxed_slice()))
147                .map_err(|e| PiperError::Inference(format!("input_lengths tensor: {e}")))?;
148
149        // 3. scales: float32 [3]
150        let scales_tensor = Tensor::from_array((
151            [3_usize],
152            vec![noise_scale, length_scale, noise_w].into_boxed_slice(),
153        ))
154        .map_err(|e| PiperError::Inference(format!("scales tensor: {e}")))?;
155
156        // 4. sid: int64 [1] (conditional)
157        let sid_val = speaker_id.unwrap_or(0);
158        let sid_tensor = if self.capabilities.has_sid {
159            Some(
160                Tensor::from_array(([1_usize], vec![sid_val].into_boxed_slice()))
161                    .map_err(|e| PiperError::Inference(format!("sid tensor: {e}")))?,
162            )
163        } else {
164            None
165        };
166
167        // 5. lid: int64 [1] (conditional)
168        let lid_val = language_id.unwrap_or(0);
169        let lid_tensor = if self.capabilities.has_lid {
170            Some(
171                Tensor::from_array(([1_usize], vec![lid_val].into_boxed_slice()))
172                    .map_err(|e| PiperError::Inference(format!("lid tensor: {e}")))?,
173            )
174        } else {
175            None
176        };
177
178        // 6. prosody_features: int64 [1, phoneme_len, 3] (conditional, zero-filled)
179        let prosody_tensor = if self.capabilities.has_prosody {
180            let flat = vec![0i64; phoneme_len * 3];
181            Some(
182                Tensor::from_array(([1_usize, phoneme_len, 3], flat.into_boxed_slice()))
183                    .map_err(|e| PiperError::Inference(format!("prosody tensor: {e}")))?,
184            )
185        } else {
186            None
187        };
188
189        // Build input map
190        let mut inputs: Vec<(Cow<str>, ort::session::SessionInputValue<'_>)> =
191            Vec::with_capacity(6);
192
193        inputs.push(("input".into(), (&input_tensor).into()));
194        inputs.push(("input_lengths".into(), (&lengths_tensor).into()));
195        inputs.push(("scales".into(), (&scales_tensor).into()));
196
197        if let Some(ref t) = sid_tensor {
198            inputs.push(("sid".into(), t.into()));
199        }
200        if let Some(ref t) = lid_tensor {
201            inputs.push(("lid".into(), t.into()));
202        }
203        if let Some(ref t) = prosody_tensor {
204            inputs.push(("prosody_features".into(), t.into()));
205        }
206
207        // --- Run inference ---
208        let start = Instant::now();
209
210        let outputs = self
211            .session
212            .run(inputs)
213            .map_err(|e| PiperError::Inference(e.to_string()))?;
214
215        let infer_seconds = start.elapsed().as_secs_f64();
216
217        // --- Extract output ---
218        // output: float32 [1, 1, audio_samples]
219        let (_shape, audio_slice) = outputs["output"]
220            .try_extract_tensor::<f32>()
221            .map_err(|e| PiperError::Inference(format!("extract output: {e}")))?;
222
223        let audio_f32: Vec<f32> = audio_slice.to_vec();
224
225        // float32 -> int16 peak normalization
226        let audio_i16 = audio_float_to_int16(&audio_f32);
227        let sample_rate = self.config.audio.sample_rate;
228        let audio_seconds = audio_i16.len() as f64 / sample_rate as f64;
229
230        Ok(WasmSynthesisResult {
231            audio_samples: audio_i16,
232            sample_rate,
233            infer_seconds,
234            audio_seconds,
235        })
236    }
237
238    /// Get the loaded config
239    pub fn config(&self) -> &VoiceConfig {
240        &self.config
241    }
242
243    /// Whether the model accepts a speaker ID input
244    pub fn has_speaker_id(&self) -> bool {
245        self.capabilities.has_sid
246    }
247
248    /// Whether the model accepts a language ID input
249    pub fn has_language_id(&self) -> bool {
250        self.capabilities.has_lid
251    }
252
253    /// Whether the model accepts prosody features input
254    pub fn has_prosody(&self) -> bool {
255        self.capabilities.has_prosody
256    }
257
258    /// Get model capabilities
259    pub fn capabilities(&self) -> &WasmModelCapabilities {
260        &self.capabilities
261    }
262}
263
/// Convert 16-bit PCM samples to `f32` audio scaled by 1/32768.
///
/// `i16::MIN` maps to exactly -1.0 and `i16::MAX` to just below 1.0, so every
/// output lies in `[-1.0, 1.0)`.
pub fn samples_i16_to_f32(samples: &[i16]) -> Vec<f32> {
    const FULL_SCALE: f32 = 32768.0;
    let mut normalized = Vec::with_capacity(samples.len());
    normalized.extend(samples.iter().map(|&sample| f32::from(sample) / FULL_SCALE));
    normalized
}
268
/// Encode 16-bit PCM samples as an in-memory WAV file (no filesystem access).
///
/// Writes a complete canonical WAV file into a `Vec<u8>`: a 12-byte RIFF header,
/// a 24-byte `fmt ` chunk and an 8-byte `data` chunk header (44 bytes in total),
/// followed by the samples. The format is 16-bit signed little-endian PCM, mono,
/// at `sample_rate` Hz. Useful for creating a downloadable Blob in WASM environments.
///
/// WAV size fields are 32-bit. If the data is too large for them, or
/// `sample_rate * 2` does not fit, those fields are clamped to `u32::MAX` rather
/// than silently wrapping. All samples are still written.
pub fn samples_to_wav_bytes(samples: &[i16], sample_rate: u32) -> Vec<u8> {
    const CHANNELS: u16 = 1;
    const BITS_PER_SAMPLE: u16 = 16;
    const BLOCK_ALIGN: u16 = CHANNELS * BITS_PER_SAMPLE / 8;
    const HEADER_LEN: usize = 44;
    // The RIFF size field excludes the leading "RIFF" tag and the field itself.
    const RIFF_HEADER_OVERHEAD: u32 = 36;

    let data_len = samples.len().saturating_mul(usize::from(BLOCK_ALIGN));
    let data_size = data_len.min(u32::MAX as usize) as u32;
    let riff_size = data_size.saturating_add(RIFF_HEADER_OVERHEAD);
    let byte_rate = sample_rate.saturating_mul(u32::from(BLOCK_ALIGN));

    let mut buf = Vec::with_capacity(HEADER_LEN.saturating_add(data_len));

    // RIFF header (12 bytes)
    buf.extend_from_slice(b"RIFF");
    buf.extend_from_slice(&riff_size.to_le_bytes());
    buf.extend_from_slice(b"WAVE");

    // fmt chunk (24 bytes)
    buf.extend_from_slice(b"fmt ");
    buf.extend_from_slice(&16u32.to_le_bytes()); // chunk size
    buf.extend_from_slice(&1u16.to_le_bytes()); // PCM format
    buf.extend_from_slice(&CHANNELS.to_le_bytes());
    buf.extend_from_slice(&sample_rate.to_le_bytes());
    buf.extend_from_slice(&byte_rate.to_le_bytes()); // sample_rate * block_align
    buf.extend_from_slice(&BLOCK_ALIGN.to_le_bytes()); // channels * bytes_per_sample
    buf.extend_from_slice(&BITS_PER_SAMPLE.to_le_bytes());

    // data chunk (8-byte header + sample data), appended without a temporary Vec
    buf.extend_from_slice(b"data");
    buf.extend_from_slice(&data_size.to_le_bytes());
    buf.extend(samples.iter().flat_map(|s| s.to_le_bytes()));

    buf
}
309
310/// Parse config JSON string into VoiceConfig
311pub fn parse_config(config_json: &str) -> Result<VoiceConfig, PiperError> {
312    let config: VoiceConfig = serde_json::from_str(config_json)?;
313    Ok(config)
314}
315
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // 1. parse_config: valid JSON
    // -----------------------------------------------------------------------
    #[test]
    fn test_parse_config_valid_minimal() {
        let json = r#"{"phoneme_id_map": {"a": [1]}, "audio": {"sample_rate": 22050}}"#;
        let config = parse_config(json).unwrap();
        assert_eq!(config.audio.sample_rate, 22050);
        assert_eq!(config.num_speakers, 1);
        assert_eq!(config.num_languages, 1);
        assert!(!config.is_multilingual());
    }

    #[test]
    fn test_parse_config_valid_multilingual() {
        let json = r#"{
            "num_speakers": 571,
            "num_languages": 6,
            "phoneme_type": "multilingual",
            "phoneme_id_map": {"^": [1], "_": [0]},
            "language_id_map": {"ja": 0, "en": 1, "zh": 2, "es": 3, "fr": 4, "pt": 5},
            "audio": {"sample_rate": 22050}
        }"#;
        let config = parse_config(json).unwrap();
        assert_eq!(config.num_speakers, 571);
        assert_eq!(config.num_languages, 6);
        assert!(config.is_multilingual());
        assert!(config.needs_lid());
        assert_eq!(config.language_id_map.len(), 6);
        assert_eq!(config.language_id_map.get("ja"), Some(&0));
        assert_eq!(config.language_id_map.get("pt"), Some(&5));
    }

    #[test]
    fn test_parse_config_valid_defaults() {
        // Empty JSON object should use all defaults
        let json = r#"{}"#;
        let config = parse_config(json).unwrap();
        assert_eq!(config.audio.sample_rate, 22050);
        assert_eq!(config.num_speakers, 1);
        assert_eq!(config.num_languages, 1);
        assert!(config.phoneme_id_map.is_empty());
    }

    // -----------------------------------------------------------------------
    // 2. parse_config: invalid JSON
    // -----------------------------------------------------------------------
    #[test]
    fn test_parse_config_invalid_json() {
        let json = r#"{ not valid json }"#;
        let result = parse_config(json);
        assert!(result.is_err());
        match result.unwrap_err() {
            PiperError::JsonParse(_) => {} // expected
            other => panic!("expected JsonParse, got: {other:?}"),
        }
    }

    #[test]
    fn test_parse_config_empty_string() {
        let result = parse_config("");
        assert!(result.is_err());
        match result.unwrap_err() {
            PiperError::JsonParse(_) => {} // expected
            other => panic!("expected JsonParse, got: {other:?}"),
        }
    }

    #[test]
    fn test_parse_config_wrong_type() {
        // num_speakers as string instead of number
        let json = r#"{"num_speakers": "not_a_number"}"#;
        let result = parse_config(json);
        assert!(result.is_err());
        match result.unwrap_err() {
            PiperError::JsonParse(_) => {} // expected
            other => panic!("expected JsonParse, got: {other:?}"),
        }
    }

    // -----------------------------------------------------------------------
    // 3. samples_i16_to_f32: conversion accuracy
    // -----------------------------------------------------------------------
    #[test]
    fn test_samples_i16_to_f32_basic() {
        let samples: Vec<i16> = vec![0, 32767, -32768, 16384, -16384];
        let result = samples_i16_to_f32(&samples);
        assert_eq!(result.len(), 5);
        // 0 -> 0.0
        assert!((result[0] - 0.0).abs() < 1e-6);
        // 32767 -> 32767/32768 ~ 0.999969
        assert!((result[1] - 32767.0 / 32768.0).abs() < 1e-4);
        // -32768 -> -32768/32768 = -1.0
        assert!((result[2] - (-1.0)).abs() < 1e-6);
        // 16384 -> 0.5
        assert!((result[3] - 0.5).abs() < 1e-4);
        // -16384 -> -0.5
        assert!((result[4] - (-0.5)).abs() < 1e-4);
    }

    #[test]
    fn test_samples_i16_to_f32_empty() {
        let result = samples_i16_to_f32(&[]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_samples_i16_to_f32_silence() {
        let samples = [0i16; 100];
        let result = samples_i16_to_f32(&samples);
        assert_eq!(result.len(), 100);
        assert!(result.iter().all(|&x| x == 0.0));
    }

    // -----------------------------------------------------------------------
    // 4. samples_to_wav_bytes: format validation
    // -----------------------------------------------------------------------
    #[test]
    fn test_wav_bytes_riff_header() {
        let samples = [0i16; 10];
        let wav = samples_to_wav_bytes(&samples, 22050);

        // Check RIFF magic
        assert_eq!(&wav[0..4], b"RIFF");

        // Check file size field (total - 8)
        let file_size = u32::from_le_bytes([wav[4], wav[5], wav[6], wav[7]]);
        assert_eq!(file_size, (wav.len() - 8) as u32);

        // Check WAVE magic
        assert_eq!(&wav[8..12], b"WAVE");
    }

    #[test]
    fn test_wav_bytes_fmt_chunk() {
        let samples = [100i16, -100, 200, -200];
        let wav = samples_to_wav_bytes(&samples, 44100);

        // fmt chunk starts at offset 12
        assert_eq!(&wav[12..16], b"fmt ");

        // fmt chunk size = 16
        let fmt_size = u32::from_le_bytes([wav[16], wav[17], wav[18], wav[19]]);
        assert_eq!(fmt_size, 16);

        // Audio format: PCM = 1
        let audio_format = u16::from_le_bytes([wav[20], wav[21]]);
        assert_eq!(audio_format, 1);

        // Channels: mono = 1
        let channels = u16::from_le_bytes([wav[22], wav[23]]);
        assert_eq!(channels, 1);

        // Sample rate
        let sample_rate = u32::from_le_bytes([wav[24], wav[25], wav[26], wav[27]]);
        assert_eq!(sample_rate, 44100);

        // Byte rate = sample_rate * channels * bytes_per_sample
        let byte_rate = u32::from_le_bytes([wav[28], wav[29], wav[30], wav[31]]);
        assert_eq!(byte_rate, 44100 * 2);

        // Block align = channels * bytes_per_sample
        let block_align = u16::from_le_bytes([wav[32], wav[33]]);
        assert_eq!(block_align, 2);

        // Bits per sample
        let bits_per_sample = u16::from_le_bytes([wav[34], wav[35]]);
        assert_eq!(bits_per_sample, 16);
    }

    #[test]
    fn test_wav_bytes_data_chunk() {
        let samples: Vec<i16> = vec![1000, -2000, 3000];
        let wav = samples_to_wav_bytes(&samples, 22050);

        // data chunk starts at offset 36
        assert_eq!(&wav[36..40], b"data");

        // data size = samples.len() * 2
        let data_size = u32::from_le_bytes([wav[40], wav[41], wav[42], wav[43]]);
        assert_eq!(data_size, 6); // 3 samples * 2 bytes each

        // Verify sample data (little-endian i16)
        let s0 = i16::from_le_bytes([wav[44], wav[45]]);
        let s1 = i16::from_le_bytes([wav[46], wav[47]]);
        let s2 = i16::from_le_bytes([wav[48], wav[49]]);
        assert_eq!(s0, 1000);
        assert_eq!(s1, -2000);
        assert_eq!(s2, 3000);
    }

    #[test]
    fn test_wav_bytes_total_length() {
        let samples = [0i16; 100];
        let wav = samples_to_wav_bytes(&samples, 22050);
        // Total = 44 header bytes + 100 samples * 2 bytes = 244
        assert_eq!(wav.len(), 244);
    }

    #[test]
    fn test_wav_bytes_empty_samples() {
        let wav = samples_to_wav_bytes(&[], 22050);
        // Total = 44 header bytes + 0 data bytes
        assert_eq!(wav.len(), 44);

        // RIFF header still valid
        assert_eq!(&wav[0..4], b"RIFF");
        assert_eq!(&wav[8..12], b"WAVE");

        // data size should be 0
        let data_size = u32::from_le_bytes([wav[40], wav[41], wav[42], wav[43]]);
        assert_eq!(data_size, 0);
    }

    // -----------------------------------------------------------------------
    // 5. WasmSynthesisResult construction and methods
    // -----------------------------------------------------------------------
    #[test]
    fn test_wasm_synthesis_result_construction() {
        let result = WasmSynthesisResult {
            audio_samples: vec![100i16, -200, 300],
            sample_rate: 22050,
            infer_seconds: 0.05,
            audio_seconds: 0.5,
        };
        assert_eq!(result.audio_samples.len(), 3);
        assert_eq!(result.sample_rate, 22050);
        assert!((result.infer_seconds - 0.05).abs() < 1e-9);
        assert!((result.audio_seconds - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_wasm_synthesis_result_rtf() {
        let result = WasmSynthesisResult {
            audio_samples: vec![0i16; 22050],
            sample_rate: 22050,
            infer_seconds: 0.5,
            audio_seconds: 1.0,
        };
        assert!((result.real_time_factor() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_wasm_synthesis_result_rtf_zero_audio() {
        let result = WasmSynthesisResult {
            audio_samples: Vec::new(),
            sample_rate: 22050,
            infer_seconds: 0.1,
            audio_seconds: 0.0,
        };
        assert!((result.real_time_factor()).abs() < 1e-6);
    }

    #[test]
    fn test_wasm_synthesis_result_clone() {
        let result = WasmSynthesisResult {
            audio_samples: vec![1, 2, 3],
            sample_rate: 44100,
            infer_seconds: 0.01,
            audio_seconds: 0.1,
        };
        let cloned = result.clone();
        assert_eq!(cloned.audio_samples, result.audio_samples);
        assert_eq!(cloned.sample_rate, result.sample_rate);
    }

    // -----------------------------------------------------------------------
    // 6. WasmModelCapabilities
    // -----------------------------------------------------------------------
    #[test]
    fn test_wasm_model_capabilities() {
        let caps = WasmModelCapabilities {
            has_sid: true,
            has_lid: true,
            has_prosody: false,
            has_duration_output: false,
        };
        assert!(caps.has_sid);
        assert!(caps.has_lid);
        assert!(!caps.has_prosody);
        assert!(!caps.has_duration_output);

        // Clone works
        let cloned = caps.clone();
        assert_eq!(cloned.has_sid, caps.has_sid);
        assert_eq!(cloned.has_lid, caps.has_lid);
    }

    // -----------------------------------------------------------------------
    // 7. WAV roundtrip: i16 -> wav bytes -> verify sample data
    // -----------------------------------------------------------------------
    #[test]
    fn test_wav_roundtrip_samples() {
        let original: Vec<i16> = vec![i16::MIN, -1000, 0, 1000, i16::MAX];
        let wav = samples_to_wav_bytes(&original, 16000);

        // Extract samples back from WAV bytes (data starts at offset 44)
        let recovered: Vec<i16> = wav[44..]
            .chunks_exact(2)
            .map(|pair| i16::from_le_bytes([pair[0], pair[1]]))
            .collect();
        assert_eq!(recovered, original);
    }

    // -----------------------------------------------------------------------
    // 8. samples_i16_to_f32 range boundaries
    // -----------------------------------------------------------------------
    #[test]
    fn test_samples_i16_to_f32_range() {
        let samples = [i16::MAX, i16::MIN, 0];
        let result = samples_i16_to_f32(&samples);

        // i16::MAX (32767) / 32768.0 should be just under 1.0
        assert!(result[0] > 0.999 && result[0] < 1.0);
        // i16::MIN (-32768) / 32768.0 should be exactly -1.0
        assert!((result[1] - (-1.0)).abs() < 1e-6);
        // 0 / 32768.0 should be exactly 0.0
        assert!((result[2] - 0.0).abs() < 1e-6);
    }

    // -----------------------------------------------------------------------
    // 9. WAV bytes with different sample rates
    // -----------------------------------------------------------------------
    #[test]
    fn test_wav_bytes_various_sample_rates() {
        for &rate in &[8000u32, 16000, 22050, 44100, 48000] {
            let wav = samples_to_wav_bytes(&[0i16; 10], rate);
            let sr = u32::from_le_bytes([wav[24], wav[25], wav[26], wav[27]]);
            assert_eq!(sr, rate, "sample rate mismatch for {rate}");
            let br = u32::from_le_bytes([wav[28], wav[29], wav[30], wav[31]]);
            assert_eq!(br, rate * 2, "byte rate mismatch for {rate}");
        }
    }

    // -----------------------------------------------------------------------
    // 10. WasmVoice::load_from_bytes with invalid model bytes
    // -----------------------------------------------------------------------
    #[test]
    fn test_load_from_bytes_invalid_model() {
        let config = r#"{
            "audio": {"sample_rate": 22050},
            "num_speakers": 1,
            "num_symbols": 10,
            "phoneme_type": "openjtalk",
            "phoneme_id_map": {},
            "num_languages": 1,
            "language_id_map": {},
            "speaker_id_map": {}
        }"#;
        let result = WasmVoice::load_from_bytes(b"not a model", config);
        assert!(result.is_err());
        match result.err().unwrap() {
            PiperError::ModelLoad(msg) => {
                assert!(!msg.is_empty(), "error message should be non-empty");
            }
            other => panic!("expected ModelLoad, got: {other:?}"),
        }
    }

    // -----------------------------------------------------------------------
    // 11. WasmVoice::load_from_bytes with invalid config JSON
    // -----------------------------------------------------------------------
    #[test]
    fn test_load_from_bytes_invalid_config() {
        let result = WasmVoice::load_from_bytes(b"fake", "not json");
        assert!(result.is_err());
        match result.err().unwrap() {
            PiperError::JsonParse(_) => {} // config parse fails before model load
            other => panic!("expected JsonParse, got: {other:?}"),
        }
    }

    #[test]
    fn test_load_from_bytes_empty_config() {
        // Empty string is not valid JSON
        let result = WasmVoice::load_from_bytes(b"fake model data", "");
        assert!(result.is_err());
        match result.err().unwrap() {
            PiperError::JsonParse(_) => {}
            other => panic!("expected JsonParse, got: {other:?}"),
        }
    }

    // -----------------------------------------------------------------------
    // 12. WasmSynthesisResult edge cases
    // -----------------------------------------------------------------------
    #[test]
    fn test_wasm_synthesis_result_large_audio() {
        // Simulate a large audio output (~60 seconds at 22050 Hz)
        let num_samples = 22050 * 60;
        let result = WasmSynthesisResult {
            audio_samples: vec![0i16; num_samples],
            sample_rate: 22050,
            infer_seconds: 2.5,
            audio_seconds: num_samples as f64 / 22050.0,
        };
        assert_eq!(result.audio_samples.len(), num_samples);
        assert!((result.audio_seconds - 60.0).abs() < 1e-6);
        // RTF < 1 means faster than real-time
        assert!(result.real_time_factor() < 1.0);
    }

    #[test]
    fn test_wasm_synthesis_result_negative_infer_seconds() {
        // Negative infer_seconds is unusual but should not panic
        let result = WasmSynthesisResult {
            audio_samples: vec![1, 2, 3],
            sample_rate: 22050,
            infer_seconds: -0.5,
            audio_seconds: 1.0,
        };
        // RTF will be negative, which is meaningless but should not crash
        let rtf = result.real_time_factor();
        assert!(rtf < 0.0);
    }

    // -----------------------------------------------------------------------
    // 13. samples_i16_to_f32 boundary values
    // -----------------------------------------------------------------------
    #[test]
    fn test_samples_i16_to_f32_boundaries() {
        let samples = [i16::MIN, i16::MAX, 0];
        let f32s = samples_i16_to_f32(&samples);
        // i16::MIN (-32768) / 32768.0 = exactly -1.0
        assert_eq!(f32s[0], -1.0);
        // i16::MAX (32767) / 32768.0 ~ 0.99997, strictly below 1.0
        assert!(f32s[1] >= 1.0 - 0.001 && f32s[1] < 1.0);
        // 0 / 32768.0 = exactly 0.0
        assert_eq!(f32s[2], 0.0);
    }

    #[test]
    fn test_samples_i16_to_f32_all_within_range() {
        // Every possible i16 value (all 65,536 of them) should produce f32 in [-1.0, 1.0)
        let samples: Vec<i16> = (i16::MIN..=i16::MAX).collect();
        let f32s = samples_i16_to_f32(&samples);
        assert_eq!(f32s.len(), samples.len());
        for (&s, &v) in samples.iter().zip(&f32s) {
            assert!(v >= -1.0, "sample {s} -> {v} below -1.0");
            assert!(v < 1.0, "sample {s} -> {v} >= 1.0 (i16::MAX / 32768 should be < 1)");
        }
    }

    // -----------------------------------------------------------------------
    // 14. samples_to_wav_bytes with large data (no overflow)
    // -----------------------------------------------------------------------
    #[test]
    fn test_wav_bytes_large_sample_count() {
        // 10 seconds of audio at 22050 Hz = 220,500 samples
        let num_samples = 220_500;
        let samples = vec![0i16; num_samples];
        let wav = samples_to_wav_bytes(&samples, 22050);

        // Total should be 44 header + num_samples * 2 bytes
        let expected_len = 44 + num_samples * 2;
        assert_eq!(wav.len(), expected_len);

        // RIFF file size = total - 8
        let file_size = u32::from_le_bytes([wav[4], wav[5], wav[6], wav[7]]);
        assert_eq!(file_size, (expected_len - 8) as u32);

        // data chunk size = num_samples * 2
        let data_size = u32::from_le_bytes([wav[40], wav[41], wav[42], wav[43]]);
        assert_eq!(data_size, (num_samples * 2) as u32);
    }

    // -----------------------------------------------------------------------
    // 15. parse_config with extra/unknown fields (should be ignored)
    // -----------------------------------------------------------------------
    #[test]
    fn test_parse_config_extra_fields_ignored() {
        let json = r#"{
            "audio": {"sample_rate": 44100},
            "num_speakers": 5,
            "some_unknown_field": "should be ignored",
            "another_unknown": 42,
            "nested_unknown": {"a": 1, "b": [2, 3]}
        }"#;
        let config = parse_config(json).unwrap();
        assert_eq!(config.audio.sample_rate, 44100);
        assert_eq!(config.num_speakers, 5);
        // The parse succeeded despite unknown fields
    }

    // -----------------------------------------------------------------------
    // 16. parse_config with speaker_id_map
    // -----------------------------------------------------------------------
    #[test]
    fn test_parse_config_speaker_id_map() {
        let json = r#"{
            "num_speakers": 3,
            "speaker_id_map": {"alice": 0, "bob": 1, "charlie": 2},
            "phoneme_id_map": {"a": [1], "b": [2]}
        }"#;
        let config = parse_config(json).unwrap();
        assert_eq!(config.num_speakers, 3);
        assert_eq!(config.speaker_id_map.len(), 3);
        assert_eq!(config.speaker_id_map.get("alice"), Some(&0));
        assert_eq!(config.speaker_id_map.get("charlie"), Some(&2));
    }

    // -----------------------------------------------------------------------
    // 17. WasmVoice::load_from_bytes with empty model bytes
    // -----------------------------------------------------------------------
    #[test]
    fn test_load_from_bytes_empty_model() {
        let config = r#"{"audio": {"sample_rate": 22050}}"#;
        let result = WasmVoice::load_from_bytes(b"", config);
        assert!(result.is_err());
        match result.err().unwrap() {
            PiperError::ModelLoad(_) => {} // ONNX runtime cannot load empty bytes
            other => panic!("expected ModelLoad, got: {other:?}"),
        }
    }

    // -----------------------------------------------------------------------
    // 18. samples_to_wav_bytes roundtrip with extreme values
    // -----------------------------------------------------------------------
    #[test]
    fn test_wav_bytes_extreme_sample_values() {
        let samples: Vec<i16> = vec![i16::MIN, i16::MAX, i16::MIN, i16::MAX];
        let wav = samples_to_wav_bytes(&samples, 22050);

        // Verify each extreme value survives the WAV encoding
        let encoded = wav[44..].chunks_exact(2);
        assert_eq!(encoded.len(), samples.len());
        for (i, (pair, &expected)) in encoded.zip(&samples).enumerate() {
            let recovered = i16::from_le_bytes([pair[0], pair[1]]);
            assert_eq!(
                recovered, expected,
                "sample {i}: expected {expected}, got {recovered}"
            );
        }
    }

    // -----------------------------------------------------------------------
    // 19. WasmSynthesisResult real_time_factor edge: both zero
    // -----------------------------------------------------------------------
    #[test]
    fn test_wasm_synthesis_result_rtf_both_zero() {
        let result = WasmSynthesisResult {
            audio_samples: Vec::new(),
            sample_rate: 22050,
            infer_seconds: 0.0,
            audio_seconds: 0.0,
        };
        // audio_seconds == 0 -> returns 0.0 (guarded division)
        assert!((result.real_time_factor()).abs() < 1e-6);
    }
}