1use std::borrow::Cow;
7use std::path::Path;
8use std::time::Instant;
9
10use ort::session::Session;
11use ort::value::Tensor;
12
13use crate::audio::audio_float_to_int16;
14use crate::config::VoiceConfig;
15use crate::error::PiperError;
16
/// Parameters for one synthesis pass over a phoneme sequence.
#[derive(Debug, Clone)]
pub struct SynthesisRequest {
    /// Phoneme id sequence fed to the model; must be non-empty.
    pub phoneme_ids: Vec<i64>,
    /// Optional per-phoneme prosody triples; length should match `phoneme_ids`.
    pub prosody_features: Option<Vec<[i32; 3]>>,
    /// Speaker id for multi-speaker models (engine defaults it to 0 when absent).
    pub speaker_id: Option<i64>,
    /// Language id for multi-language models (engine defaults it to 0 when absent).
    pub language_id: Option<i64>,
    /// Variance of the acoustic noise source.
    pub noise_scale: f32,
    /// Duration multiplier; larger values produce slower speech.
    pub length_scale: f32,
    /// Variance of the duration-predictor noise.
    pub noise_w: f32,
}

impl Default for SynthesisRequest {
    /// Empty phoneme list, no optional ids/prosody, and the customary
    /// Piper-style scale values (0.667 / 1.0 / 0.8).
    fn default() -> Self {
        SynthesisRequest {
            phoneme_ids: Vec::default(),
            prosody_features: None,
            speaker_id: None,
            language_id: None,
            noise_scale: 0.667,
            length_scale: 1.0,
            noise_w: 0.8,
        }
    }
}
42
/// Output of one synthesis pass: PCM audio plus timing metadata.
#[derive(Debug)]
pub struct SynthesisResult {
    /// Synthesized audio as 16-bit signed PCM samples.
    pub audio: Vec<i16>,
    /// Sample rate of `audio` in Hz.
    pub sample_rate: u32,
    /// Wall-clock seconds spent inside the model run.
    pub infer_seconds: f64,
    /// Duration of the produced audio in seconds.
    pub audio_seconds: f64,
    /// Per-phoneme durations when the model exposes them, else `None`.
    pub durations: Option<Vec<f32>>,
}

impl SynthesisResult {
    /// Real-time factor: inference time divided by audio time.
    /// Values below 1.0 mean faster-than-real-time synthesis.
    /// Returns 0.0 when no audio was produced, to avoid a division by zero.
    pub fn real_time_factor(&self) -> f64 {
        match self.audio_seconds {
            secs if secs > 0.0 => self.infer_seconds / secs,
            _ => 0.0,
        }
    }
}
66
/// Optional model features, probed from the ONNX graph's declared
/// input/output names at load time (see `OnnxEngine::load`).
#[derive(Debug, Clone)]
pub struct ModelCapabilities {
    /// Model declares a "sid" (speaker id) input.
    pub has_sid: bool,
    /// Model declares a "lid" (language id) input.
    pub has_lid: bool,
    /// Model declares a "prosody_features" input.
    pub has_prosody: bool,
    /// Model declares a "durations" output.
    pub has_duration_output: bool,
}
75
/// ONNX Runtime inference engine wrapping a single loaded voice model.
pub struct OnnxEngine {
    // Owned ORT session; `synthesize` takes `&mut self` because it calls
    // `session.run`.
    session: Session,
    // Features probed from the model's input/output names at load time.
    capabilities: ModelCapabilities,
    // Output sample rate in Hz, copied from the voice config's audio section.
    sample_rate: u32,
}
82
83impl OnnxEngine {
84 pub fn load(model_path: &Path, config: &VoiceConfig, device: &str) -> Result<Self, PiperError> {
89 let device_type = crate::gpu::parse_device_string(device)
93 .map_err(|e| PiperError::ModelLoad(format!("invalid device '{}': {}", device, e)))?;
94
95 let builder = Session::builder().map_err(|e| PiperError::ModelLoad(e.to_string()))?;
96
97 let (mut builder, actual_device) =
98 crate::gpu::configure_session_builder(builder, &device_type)
99 .map_err(|e| PiperError::ModelLoad(format!("device config: {e}")))?;
100
101 tracing::info!("Using device: {}", actual_device);
102
103 let session = builder
104 .commit_from_file(model_path)
105 .map_err(|e| PiperError::ModelLoad(e.to_string()))?;
106
107 let input_names: Vec<String> = session
109 .inputs()
110 .iter()
111 .map(|i| i.name().to_string())
112 .collect();
113 let output_names: Vec<String> = session
114 .outputs()
115 .iter()
116 .map(|o| o.name().to_string())
117 .collect();
118
119 let has_input = |name: &str| input_names.iter().any(|n| n == name);
120 let has_output = |name: &str| output_names.iter().any(|n| n == name);
121
122 let capabilities = ModelCapabilities {
123 has_sid: has_input("sid"),
124 has_lid: has_input("lid"),
125 has_prosody: has_input("prosody_features"),
126 has_duration_output: has_output("durations"),
127 };
128
129 tracing::info!(
130 "Model loaded: inputs={:?}, outputs={:?}",
131 input_names,
132 output_names,
133 );
134 tracing::info!(
135 "Capabilities: sid={}, lid={}, prosody={}, durations={}",
136 capabilities.has_sid,
137 capabilities.has_lid,
138 capabilities.has_prosody,
139 capabilities.has_duration_output,
140 );
141
142 Ok(Self {
143 session,
144 capabilities,
145 sample_rate: config.audio.sample_rate,
146 })
147 }
148
149 pub fn capabilities(&self) -> &ModelCapabilities {
151 &self.capabilities
152 }
153
154 pub fn sample_rate(&self) -> u32 {
156 self.sample_rate
157 }
158
159 pub fn synthesize(
173 &mut self,
174 request: &SynthesisRequest,
175 ) -> Result<SynthesisResult, PiperError> {
176 let phoneme_len = request.phoneme_ids.len();
177 if phoneme_len == 0 {
178 return Err(PiperError::Inference("empty phoneme_ids".to_string()));
179 }
180
181 let input_tensor = Tensor::from_array((
187 [1_usize, phoneme_len],
188 request.phoneme_ids.to_vec().into_boxed_slice(),
189 ))
190 .map_err(|e| PiperError::Inference(format!("input tensor: {e}")))?;
191
192 let lengths_tensor =
194 Tensor::from_array(([1_usize], vec![phoneme_len as i64].into_boxed_slice()))
195 .map_err(|e| PiperError::Inference(format!("input_lengths tensor: {e}")))?;
196
197 let scales_tensor = Tensor::from_array((
199 [3_usize],
200 vec![request.noise_scale, request.length_scale, request.noise_w].into_boxed_slice(),
201 ))
202 .map_err(|e| PiperError::Inference(format!("scales tensor: {e}")))?;
203
204 let sid_val = request.speaker_id.unwrap_or(0);
206 let sid_tensor = if self.capabilities.has_sid {
207 Some(
208 Tensor::from_array(([1_usize], vec![sid_val].into_boxed_slice()))
209 .map_err(|e| PiperError::Inference(format!("sid tensor: {e}")))?,
210 )
211 } else {
212 None
213 };
214
215 let lid_val = request.language_id.unwrap_or(0);
217 let lid_tensor = if self.capabilities.has_lid {
218 Some(
219 Tensor::from_array(([1_usize], vec![lid_val].into_boxed_slice()))
220 .map_err(|e| PiperError::Inference(format!("lid tensor: {e}")))?,
221 )
222 } else {
223 None
224 };
225
226 let prosody_tensor = if self.capabilities.has_prosody {
228 let flat: Vec<i64> = if let Some(ref features) = request.prosody_features {
229 features
230 .iter()
231 .flat_map(|f| [f[0] as i64, f[1] as i64, f[2] as i64])
232 .collect()
233 } else {
234 vec![0i64; phoneme_len * 3]
236 };
237 let pf_len = flat.len() / 3;
238 Some(
239 Tensor::from_array(([1_usize, pf_len, 3], flat.into_boxed_slice()))
240 .map_err(|e| PiperError::Inference(format!("prosody tensor: {e}")))?,
241 )
242 } else {
243 None
244 };
245
246 let mut inputs: Vec<(Cow<str>, ort::session::SessionInputValue<'_>)> =
248 Vec::with_capacity(6);
249
250 inputs.push(("input".into(), (&input_tensor).into()));
251 inputs.push(("input_lengths".into(), (&lengths_tensor).into()));
252 inputs.push(("scales".into(), (&scales_tensor).into()));
253
254 if let Some(ref t) = sid_tensor {
255 inputs.push(("sid".into(), t.into()));
256 }
257 if let Some(ref t) = lid_tensor {
258 inputs.push(("lid".into(), t.into()));
259 }
260 if let Some(ref t) = prosody_tensor {
261 inputs.push(("prosody_features".into(), t.into()));
262 }
263
264 let start = Instant::now();
266
267 let outputs = self
268 .session
269 .run(inputs)
270 .map_err(|e| PiperError::Inference(e.to_string()))?;
271
272 let infer_seconds = start.elapsed().as_secs_f64();
273
274 let (_shape, audio_slice) = outputs["output"]
277 .try_extract_tensor::<f32>()
278 .map_err(|e| PiperError::Inference(format!("extract output: {e}")))?;
279
280 let audio_i16 = audio_float_to_int16(audio_slice);
282 let audio_seconds = audio_i16.len() as f64 / self.sample_rate as f64;
283
284 let durations = if self.capabilities.has_duration_output {
286 match outputs.get("durations") {
287 Some(d) => match d.try_extract_tensor::<f32>() {
288 Ok((_shape, data)) => {
289 let vec = data.to_vec();
290 tracing::debug!("Duration tensor extracted: {} values", vec.len());
291 Some(vec)
292 }
293 Err(e) => {
294 tracing::warn!(
295 "Duration tensor extraction failed (shape/type mismatch): {}. \
296 Expected f32 tensor with shape [1, phoneme_length].",
297 e
298 );
299 None
300 }
301 },
302 None => {
303 tracing::warn!(
304 "Model declares 'durations' output but tensor was not found in results"
305 );
306 None
307 }
308 }
309 } else {
310 None
311 };
312
313 Ok(SynthesisResult {
314 audio: audio_i16,
315 sample_rate: self.sample_rate,
316 infer_seconds,
317 audio_seconds,
318 durations,
319 })
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_synthesis_request_default() {
        let req = SynthesisRequest::default();
        assert!(req.phoneme_ids.is_empty());
        assert!(req.prosody_features.is_none());
        assert!(req.speaker_id.is_none());
        assert!(req.language_id.is_none());
        assert!((req.noise_scale - 0.667).abs() < 1e-6);
        assert!((req.length_scale - 1.0).abs() < 1e-6);
        assert!((req.noise_w - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_synthesis_result_rtf() {
        let result = SynthesisResult {
            audio: vec![0i16; 22050],
            sample_rate: 22050,
            infer_seconds: 0.5,
            audio_seconds: 1.0,
            durations: None,
        };
        assert!((result.real_time_factor() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_synthesis_result_rtf_zero_audio() {
        let result = SynthesisResult {
            audio: Vec::new(),
            sample_rate: 22050,
            infer_seconds: 0.1,
            audio_seconds: 0.0,
            durations: None,
        };
        assert!((result.real_time_factor()).abs() < 1e-6);
    }

    #[test]
    fn test_model_capabilities_debug() {
        let caps = ModelCapabilities {
            has_sid: true,
            has_lid: false,
            has_prosody: true,
            has_duration_output: false,
        };
        let debug = format!("{:?}", caps);
        assert!(debug.contains("has_sid: true"));
        assert!(debug.contains("has_lid: false"));
        assert!(debug.contains("has_prosody: true"));
        assert!(debug.contains("has_duration_output: false"));
    }

    #[test]
    fn test_synthesis_result_with_durations() {
        let result = SynthesisResult {
            audio: vec![0i16; 22050],
            sample_rate: 22050,
            infer_seconds: 0.3,
            audio_seconds: 1.0,
            durations: Some(vec![1.0, 2.0, 3.0]),
        };
        let durations = result.durations.as_ref().unwrap();
        assert_eq!(durations.len(), 3);
        assert!((durations[0] - 1.0).abs() < 1e-6);
        assert!((durations[1] - 2.0).abs() < 1e-6);
        assert!((durations[2] - 3.0).abs() < 1e-6);
    }

    // Renamed from `test_synthesis_result_rtf_infinity`: the implementation
    // deliberately returns 0.0 (not infinity) for zero-length audio, even when
    // inference time is non-zero, so the old name described behavior that does
    // not exist.
    #[test]
    fn test_synthesis_result_rtf_zero_audio_nonzero_infer() {
        let result = SynthesisResult {
            audio: Vec::new(),
            sample_rate: 22050,
            infer_seconds: 1.5,
            audio_seconds: 0.0,
            durations: None,
        };
        assert!((result.real_time_factor() - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_synthesis_request_custom_values() {
        let req = SynthesisRequest {
            phoneme_ids: vec![1, 2, 3, 4, 5],
            prosody_features: Some(vec![
                [1, 2, 3],
                [4, 5, 6],
                [7, 8, 9],
                [10, 11, 12],
                [13, 14, 15],
            ]),
            speaker_id: Some(42),
            language_id: Some(3),
            noise_scale: 0.333,
            length_scale: 1.5,
            noise_w: 0.5,
        };
        assert_eq!(req.phoneme_ids.len(), 5);
        assert_eq!(req.speaker_id, Some(42));
        assert_eq!(req.language_id, Some(3));
        assert!((req.noise_scale - 0.333).abs() < 1e-6);
        assert!((req.length_scale - 1.5).abs() < 1e-6);
        assert!((req.noise_w - 0.5).abs() < 1e-6);
        let pf = req.prosody_features.as_ref().unwrap();
        assert_eq!(pf.len(), 5);
        assert_eq!(pf[0], [1, 2, 3]);
    }

    #[test]
    fn test_model_capabilities_all_true() {
        let caps = ModelCapabilities {
            has_sid: true,
            has_lid: true,
            has_prosody: true,
            has_duration_output: true,
        };
        assert!(caps.has_sid);
        assert!(caps.has_lid);
        assert!(caps.has_prosody);
        assert!(caps.has_duration_output);
    }

    #[test]
    fn test_model_capabilities_all_false() {
        let caps = ModelCapabilities {
            has_sid: false,
            has_lid: false,
            has_prosody: false,
            has_duration_output: false,
        };
        assert!(!caps.has_sid);
        assert!(!caps.has_lid);
        assert!(!caps.has_prosody);
        assert!(!caps.has_duration_output);
    }
}