// piper_plus/engine.rs
//! ONNX inference engine.
//!
//! Runs ONNX Runtime inference for VITS models.
//! Responsible for building input tensors, adding conditional tensors, and
//! converting model output to audio.
5
6use std::borrow::Cow;
7use std::path::Path;
8use std::time::Instant;
9
10use ort::session::Session;
11use ort::value::Tensor;
12
13use crate::audio::audio_float_to_int16;
14use crate::config::VoiceConfig;
15use crate::error::PiperError;
16
/// Parameters for a single synthesis (inference) request.
#[derive(Debug, Clone)]
pub struct SynthesisRequest {
    /// Phoneme IDs fed to the model's `input` tensor. Must be non-empty.
    pub phoneme_ids: Vec<i64>,
    /// Optional per-phoneme prosody features, one `[i32; 3]` triple per
    /// phoneme; used only when the model exposes a `prosody_features` input
    /// (zero-filled otherwise).
    pub prosody_features: Option<Vec<[i32; 3]>>,
    /// Speaker ID for multi-speaker models; defaults to 0 when the model
    /// has a `sid` input but no ID is given.
    pub speaker_id: Option<i64>,
    /// Language ID for multi-language models; defaults to 0 when the model
    /// has a `lid` input but no ID is given.
    pub language_id: Option<i64>,
    /// First element of the `scales` tensor (VITS noise scale).
    pub noise_scale: f32,
    /// Second element of the `scales` tensor (duration/length scale).
    pub length_scale: f32,
    /// Third element of the `scales` tensor (noise width).
    pub noise_w: f32,
}
28
29impl Default for SynthesisRequest {
30    fn default() -> Self {
31        Self {
32            phoneme_ids: Vec::new(),
33            prosody_features: None,
34            speaker_id: None,
35            language_id: None,
36            noise_scale: 0.667,
37            length_scale: 1.0,
38            noise_w: 0.8,
39        }
40    }
41}
42
/// Result of a synthesis run.
#[derive(Debug)]
pub struct SynthesisResult {
    /// Generated audio as 16-bit PCM samples (peak-normalized from the
    /// model's float output via `audio_float_to_int16`).
    pub audio: Vec<i16>,
    /// Sample rate in Hz, taken from the voice config at load time.
    pub sample_rate: u32,
    /// Wall-clock time spent inside the ONNX `run` call, in seconds.
    pub infer_seconds: f64,
    /// Duration of the generated audio, in seconds (`samples / sample_rate`).
    pub audio_seconds: f64,
    /// Phoneme durations from the model (if available).
    /// Shape: [phoneme_length], each value = number of frames.
    pub durations: Option<Vec<f32>>,
}
54
55impl SynthesisResult {
56    /// リアルタイムファクタ (推論時間 / 音声時間)。
57    /// 1.0 未満ならリアルタイムより高速。
58    pub fn real_time_factor(&self) -> f64 {
59        if self.audio_seconds > 0.0 {
60            self.infer_seconds / self.audio_seconds
61        } else {
62            0.0
63        }
64    }
65}
66
/// Capability flags auto-detected from the model's ONNX input/output node names.
#[derive(Debug, Clone)]
pub struct ModelCapabilities {
    /// Model has a `sid` (speaker ID) input node.
    pub has_sid: bool,
    /// Model has a `lid` (language ID) input node.
    pub has_lid: bool,
    /// Model has a `prosody_features` input node.
    pub has_prosody: bool,
    /// Model exposes a `durations` output node.
    pub has_duration_output: bool,
}
75
/// ONNX inference engine wrapping an ONNX Runtime session.
pub struct OnnxEngine {
    /// The ONNX Runtime session; `run` requires `&mut self`.
    session: Session,
    /// Capabilities detected from node names at load time.
    capabilities: ModelCapabilities,
    /// Output sample rate in Hz, copied from the voice config.
    sample_rate: u32,
}
82
83impl OnnxEngine {
84    /// ONNX モデルを読み込んでエンジンを初期化する。
85    ///
86    /// `device` は `"cpu"`, `"auto"`, `"cuda"`, `"cuda:0"`, `"coreml"`, `"directml"`, `"tensorrt"` のいずれか。
87    /// `"auto"` 指定時は CUDA を試行し、失敗すれば CPU にフォールバックする。
88    pub fn load(model_path: &Path, config: &VoiceConfig, device: &str) -> Result<Self, PiperError> {
89        // デバイス文字列をパースして GPU プロバイダを設定
90        // "auto" は parse_device_string 内でフォールバックするが、
91        // 明示的なデバイス指定 (e.g. "cuda:0") が不正な場合はエラーを返す。
92        let device_type = crate::gpu::parse_device_string(device)
93            .map_err(|e| PiperError::ModelLoad(format!("invalid device '{}': {}", device, e)))?;
94
95        let builder = Session::builder().map_err(|e| PiperError::ModelLoad(e.to_string()))?;
96
97        let (mut builder, actual_device) =
98            crate::gpu::configure_session_builder(builder, &device_type)
99                .map_err(|e| PiperError::ModelLoad(format!("device config: {e}")))?;
100
101        tracing::info!("Using device: {}", actual_device);
102
103        let session = builder
104            .commit_from_file(model_path)
105            .map_err(|e| PiperError::ModelLoad(e.to_string()))?;
106
107        // モデルの入出力ノード名から能力を自動検出
108        let input_names: Vec<String> = session
109            .inputs()
110            .iter()
111            .map(|i| i.name().to_string())
112            .collect();
113        let output_names: Vec<String> = session
114            .outputs()
115            .iter()
116            .map(|o| o.name().to_string())
117            .collect();
118
119        let has_input = |name: &str| input_names.iter().any(|n| n == name);
120        let has_output = |name: &str| output_names.iter().any(|n| n == name);
121
122        let capabilities = ModelCapabilities {
123            has_sid: has_input("sid"),
124            has_lid: has_input("lid"),
125            has_prosody: has_input("prosody_features"),
126            has_duration_output: has_output("durations"),
127        };
128
129        tracing::info!(
130            "Model loaded: inputs={:?}, outputs={:?}",
131            input_names,
132            output_names,
133        );
134        tracing::info!(
135            "Capabilities: sid={}, lid={}, prosody={}, durations={}",
136            capabilities.has_sid,
137            capabilities.has_lid,
138            capabilities.has_prosody,
139            capabilities.has_duration_output,
140        );
141
142        Ok(Self {
143            session,
144            capabilities,
145            sample_rate: config.audio.sample_rate,
146        })
147    }
148
149    /// モデルの能力情報を返す
150    pub fn capabilities(&self) -> &ModelCapabilities {
151        &self.capabilities
152    }
153
154    /// サンプルレートを返す
155    pub fn sample_rate(&self) -> u32 {
156        self.sample_rate
157    }
158
159    /// ONNX 推論を実行して音声を生成する。
160    ///
161    /// ONNX 入力テンソル順序:
162    /// 1. `input` (phoneme_ids): int64 \[1, phoneme_length\]
163    /// 2. `input_lengths`: int64 \[1\]
164    /// 3. `scales`: float32 \[3\] = \[noise_scale, length_scale, noise_w\]
165    /// 4. `sid` (条件付き): int64 \[1\] -- has_sid が true のとき
166    /// 5. `lid` (条件付き): int64 \[1\] -- has_lid が true のとき
167    /// 6. `prosody_features` (条件付き): int64 \[1, phoneme_length, 3\]
168    ///
169    /// ONNX 出力:
170    /// - `output`: float32 \[1, 1, audio_samples\]
171    /// - `durations` (オプション): float32 \[1, phoneme_length\]
172    pub fn synthesize(
173        &mut self,
174        request: &SynthesisRequest,
175    ) -> Result<SynthesisResult, PiperError> {
176        let phoneme_len = request.phoneme_ids.len();
177        if phoneme_len == 0 {
178            return Err(PiperError::Inference("empty phoneme_ids".to_string()));
179        }
180
181        // --- 入力テンソル構築 ---
182        // 条件付き入力があるため動的に ValueMap を構築する。
183        // テンソルは run() 完了まで生存する必要があるため、ここで全て確保する。
184
185        // 1. input: int64 [1, phoneme_len]
186        let input_tensor = Tensor::from_array((
187            [1_usize, phoneme_len],
188            request.phoneme_ids.to_vec().into_boxed_slice(),
189        ))
190        .map_err(|e| PiperError::Inference(format!("input tensor: {e}")))?;
191
192        // 2. input_lengths: int64 [1]
193        let lengths_tensor =
194            Tensor::from_array(([1_usize], vec![phoneme_len as i64].into_boxed_slice()))
195                .map_err(|e| PiperError::Inference(format!("input_lengths tensor: {e}")))?;
196
197        // 3. scales: float32 [3]
198        let scales_tensor = Tensor::from_array((
199            [3_usize],
200            vec![request.noise_scale, request.length_scale, request.noise_w].into_boxed_slice(),
201        ))
202        .map_err(|e| PiperError::Inference(format!("scales tensor: {e}")))?;
203
204        // 4. sid: int64 [1] (条件付き)
205        let sid_val = request.speaker_id.unwrap_or(0);
206        let sid_tensor = if self.capabilities.has_sid {
207            Some(
208                Tensor::from_array(([1_usize], vec![sid_val].into_boxed_slice()))
209                    .map_err(|e| PiperError::Inference(format!("sid tensor: {e}")))?,
210            )
211        } else {
212            None
213        };
214
215        // 5. lid: int64 [1] (条件付き)
216        let lid_val = request.language_id.unwrap_or(0);
217        let lid_tensor = if self.capabilities.has_lid {
218            Some(
219                Tensor::from_array(([1_usize], vec![lid_val].into_boxed_slice()))
220                    .map_err(|e| PiperError::Inference(format!("lid tensor: {e}")))?,
221            )
222        } else {
223            None
224        };
225
226        // 6. prosody_features: int64 [1, phoneme_len, 3] (条件付き)
227        let prosody_tensor = if self.capabilities.has_prosody {
228            let flat: Vec<i64> = if let Some(ref features) = request.prosody_features {
229                features
230                    .iter()
231                    .flat_map(|f| [f[0] as i64, f[1] as i64, f[2] as i64])
232                    .collect()
233            } else {
234                // prosody ノードは存在するがリクエストに特徴量がない場合はゼロ埋め
235                vec![0i64; phoneme_len * 3]
236            };
237            let pf_len = flat.len() / 3;
238            Some(
239                Tensor::from_array(([1_usize, pf_len, 3], flat.into_boxed_slice()))
240                    .map_err(|e| PiperError::Inference(format!("prosody tensor: {e}")))?,
241            )
242        } else {
243            None
244        };
245
246        // ValueMap を構築
247        let mut inputs: Vec<(Cow<str>, ort::session::SessionInputValue<'_>)> =
248            Vec::with_capacity(6);
249
250        inputs.push(("input".into(), (&input_tensor).into()));
251        inputs.push(("input_lengths".into(), (&lengths_tensor).into()));
252        inputs.push(("scales".into(), (&scales_tensor).into()));
253
254        if let Some(ref t) = sid_tensor {
255            inputs.push(("sid".into(), t.into()));
256        }
257        if let Some(ref t) = lid_tensor {
258            inputs.push(("lid".into(), t.into()));
259        }
260        if let Some(ref t) = prosody_tensor {
261            inputs.push(("prosody_features".into(), t.into()));
262        }
263
264        // --- 推論実行 ---
265        let start = Instant::now();
266
267        let outputs = self
268            .session
269            .run(inputs)
270            .map_err(|e| PiperError::Inference(e.to_string()))?;
271
272        let infer_seconds = start.elapsed().as_secs_f64();
273
274        // --- 出力テンソル処理 ---
275        // output: float32 [1, 1, audio_samples]
276        let (_shape, audio_slice) = outputs["output"]
277            .try_extract_tensor::<f32>()
278            .map_err(|e| PiperError::Inference(format!("extract output: {e}")))?;
279
280        // float32 -> int16 ピーク正規化
281        let audio_i16 = audio_float_to_int16(audio_slice);
282        let audio_seconds = audio_i16.len() as f64 / self.sample_rate as f64;
283
284        // --- duration テンソル抽出 (オプション) ---
285        let durations = if self.capabilities.has_duration_output {
286            match outputs.get("durations") {
287                Some(d) => match d.try_extract_tensor::<f32>() {
288                    Ok((_shape, data)) => {
289                        let vec = data.to_vec();
290                        tracing::debug!("Duration tensor extracted: {} values", vec.len());
291                        Some(vec)
292                    }
293                    Err(e) => {
294                        tracing::warn!(
295                            "Duration tensor extraction failed (shape/type mismatch): {}. \
296                             Expected f32 tensor with shape [1, phoneme_length].",
297                            e
298                        );
299                        None
300                    }
301                },
302                None => {
303                    tracing::warn!(
304                        "Model declares 'durations' output but tensor was not found in results"
305                    );
306                    None
307                }
308            }
309        } else {
310            None
311        };
312
313        Ok(SynthesisResult {
314            audio: audio_i16,
315            sample_rate: self.sample_rate,
316            infer_seconds,
317            audio_seconds,
318            durations,
319        })
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_synthesis_request_default() {
        let request = SynthesisRequest::default();
        assert_eq!(request.phoneme_ids.len(), 0);
        assert_eq!(request.speaker_id, None);
        assert_eq!(request.language_id, None);
        assert!(request.prosody_features.is_none());
        assert!((request.noise_scale - 0.667).abs() < 1e-6);
        assert!((request.length_scale - 1.0).abs() < 1e-6);
        assert!((request.noise_w - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_synthesis_result_rtf() {
        let res = SynthesisResult {
            audio: vec![0i16; 22050],
            sample_rate: 22050,
            infer_seconds: 0.5,
            audio_seconds: 1.0,
            durations: None,
        };
        let rtf = res.real_time_factor();
        assert!((rtf - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_synthesis_result_rtf_zero_audio() {
        // Zero-length audio must yield RTF 0.0, not a division by zero.
        let res = SynthesisResult {
            audio: Vec::new(),
            sample_rate: 22050,
            infer_seconds: 0.1,
            audio_seconds: 0.0,
            durations: None,
        };
        assert!(res.real_time_factor().abs() < 1e-6);
    }

    #[test]
    fn test_model_capabilities_debug() {
        let capabilities = ModelCapabilities {
            has_sid: true,
            has_lid: false,
            has_prosody: true,
            has_duration_output: false,
        };
        let rendered = format!("{:?}", capabilities);
        for fragment in [
            "has_sid: true",
            "has_lid: false",
            "has_prosody: true",
            "has_duration_output: false",
        ] {
            assert!(rendered.contains(fragment));
        }
    }

    // -----------------------------------------------------------------------
    // Additional TDD tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_synthesis_result_with_durations() {
        let res = SynthesisResult {
            audio: vec![0i16; 22050],
            sample_rate: 22050,
            infer_seconds: 0.3,
            audio_seconds: 1.0,
            durations: Some(vec![1.0, 2.0, 3.0]),
        };
        let frames = res.durations.as_ref().unwrap();
        assert_eq!(frames.len(), 3);
        for (i, expected) in [1.0f32, 2.0, 3.0].iter().enumerate() {
            assert!((frames[i] - expected).abs() < 1e-6);
        }
    }

    #[test]
    fn test_synthesis_result_rtf_infinity() {
        // infer_seconds > 0 with audio_seconds == 0: the guard keeps RTF at 0.0.
        let res = SynthesisResult {
            audio: Vec::new(),
            sample_rate: 22050,
            infer_seconds: 1.5,
            audio_seconds: 0.0,
            durations: None,
        };
        assert!(res.real_time_factor().abs() < 1e-6);
    }

    #[test]
    fn test_synthesis_request_custom_values() {
        // Prosody triples 1..=15, three values per phoneme.
        let triples: Vec<[i32; 3]> = (0..5)
            .map(|i| [3 * i + 1, 3 * i + 2, 3 * i + 3])
            .collect();
        let request = SynthesisRequest {
            phoneme_ids: vec![1, 2, 3, 4, 5],
            prosody_features: Some(triples),
            speaker_id: Some(42),
            language_id: Some(3),
            noise_scale: 0.333,
            length_scale: 1.5,
            noise_w: 0.5,
        };
        assert_eq!(request.phoneme_ids.len(), 5);
        assert_eq!(request.speaker_id, Some(42));
        assert_eq!(request.language_id, Some(3));
        assert!((request.noise_scale - 0.333).abs() < 1e-6);
        assert!((request.length_scale - 1.5).abs() < 1e-6);
        assert!((request.noise_w - 0.5).abs() < 1e-6);
        let features = request.prosody_features.as_ref().unwrap();
        assert_eq!(features.len(), 5);
        assert_eq!(features[0], [1, 2, 3]);
    }

    #[test]
    fn test_model_capabilities_all_true() {
        let capabilities = ModelCapabilities {
            has_sid: true,
            has_lid: true,
            has_prosody: true,
            has_duration_output: true,
        };
        assert!(
            capabilities.has_sid
                && capabilities.has_lid
                && capabilities.has_prosody
                && capabilities.has_duration_output
        );
    }

    #[test]
    fn test_model_capabilities_all_false() {
        let capabilities = ModelCapabilities {
            has_sid: false,
            has_lid: false,
            has_prosody: false,
            has_duration_output: false,
        };
        assert!(
            !(capabilities.has_sid
                || capabilities.has_lid
                || capabilities.has_prosody
                || capabilities.has_duration_output)
        );
    }
}