use std::borrow::Cow;
use std::path::Path;
use std::time::Instant;
use ort::session::Session;
use ort::session::builder::GraphOptimizationLevel;
use ort::value::Tensor;
use crate::audio::audio_float_to_int16;
use crate::config::VoiceConfig;
use crate::error::PiperError;
/// Hard cap on the number of ORT intra-op threads per session.
const MAX_INTRA_THREADS: usize = 4;
/// Default number of warmup inference passes (see `OnnxEngine::warmup`).
pub const DEFAULT_WARMUP_RUNS: usize = 2;
/// Length of the dummy phoneme sequence used for warmup requests.
const WARMUP_PHONEME_LENGTH: usize = 100;
/// A single synthesis request: phoneme IDs plus optional per-model inputs
/// and the three VITS-style scale parameters.
#[derive(Debug, Clone)]
pub struct SynthesisRequest {
    /// Phoneme IDs in the model's vocabulary; must be non-empty to synthesize.
    pub phoneme_ids: Vec<i64>,
    /// Optional per-phoneme prosody triples, used only when the model
    /// declares a `prosody_features` input.
    pub prosody_features: Option<Vec<[i32; 3]>>,
    /// Speaker ID for multi-speaker models (`sid` input); defaults to 0.
    pub speaker_id: Option<i64>,
    /// Language ID for multi-language models (`lid` input); defaults to 0.
    pub language_id: Option<i64>,
    /// Noise scale passed to the model (first element of `scales`).
    pub noise_scale: f32,
    /// Duration/length scale passed to the model (second element of `scales`).
    pub length_scale: f32,
    /// Duration-predictor noise scale (third element of `scales`).
    pub noise_w: f32,
}

impl Default for SynthesisRequest {
    /// Empty request carrying Piper's conventional default scales
    /// (0.667 / 1.0 / 0.8).
    fn default() -> Self {
        Self {
            noise_scale: 0.667,
            length_scale: 1.0,
            noise_w: 0.8,
            phoneme_ids: Vec::default(),
            prosody_features: None,
            speaker_id: None,
            language_id: None,
        }
    }
}
/// Output of one synthesis pass: audio samples plus timing metadata.
#[derive(Debug)]
pub struct SynthesisResult {
    /// Synthesized 16-bit PCM samples.
    pub audio: Vec<i16>,
    /// Sample rate of `audio`, in Hz.
    pub sample_rate: u32,
    /// Wall-clock seconds spent inside the ORT session run.
    pub infer_seconds: f64,
    /// Duration of the produced audio, in seconds.
    pub audio_seconds: f64,
    /// Per-phoneme durations, when the model exposes a `durations` output.
    pub durations: Option<Vec<f32>>,
}

impl SynthesisResult {
    /// Real-time factor: inference time divided by audio duration
    /// (values below 1.0 mean faster than real time).
    /// Yields 0.0 when no audio was produced, avoiding a division by zero.
    pub fn real_time_factor(&self) -> f64 {
        match self.audio_seconds {
            secs if secs > 0.0 => self.infer_seconds / secs,
            _ => 0.0,
        }
    }
}
/// Optional model inputs/outputs, detected by name from the ONNX graph
/// (see `OnnxEngine::finish_load`).
#[derive(Debug, Clone)]
pub struct ModelCapabilities {
    /// Model declares a `sid` (speaker ID) input.
    pub has_sid: bool,
    /// Model declares a `lid` (language ID) input.
    pub has_lid: bool,
    /// Model declares a `prosody_features` input.
    pub has_prosody: bool,
    /// Model declares a `durations` output.
    pub has_duration_output: bool,
}
/// ONNX Runtime inference engine wrapping a single loaded voice model.
pub struct OnnxEngine {
    // ORT session; `synthesize` takes `&mut self` because running it
    // requires mutable access to the session.
    session: Session,
    // Optional inputs/outputs detected at load time.
    capabilities: ModelCapabilities,
    // Output sample rate, taken from the voice config at load time.
    sample_rate: u32,
}
impl OnnxEngine {
    /// Load a voice model, using (and maintaining) an on-disk optimized-model
    /// cache keyed by device.
    ///
    /// The cache is `<model>.<device>.opt.onnx` plus a `<...>.ok` sentinel
    /// written only after a successful build; both files must exist for a
    /// cache hit. A cache file without its sentinel is treated as an
    /// interrupted save and removed.
    ///
    /// # Errors
    /// Returns `PiperError::ModelLoad` if the device string is invalid or the
    /// session cannot be built.
    pub fn load(model_path: &Path, config: &VoiceConfig, device: &str) -> Result<Self, PiperError> {
        let device_type = crate::gpu::parse_device_string(device)
            .map_err(|e| PiperError::ModelLoad(format!("invalid device '{}': {}", device, e)))?;
        // Half the available cores (at least 1), capped at MAX_INTRA_THREADS.
        let num_intra_threads = std::thread::available_parallelism()
            .map(|n| (n.get() / 2).max(1))
            .unwrap_or(1)
            .min(MAX_INTRA_THREADS);
        // Strip ':' so e.g. "cuda:0" becomes "cuda0" and the cache extension
        // stays a single path component.
        let device_label = device_type.to_string().replace(':', "");
        let cache_ext = format!("{}.opt.onnx", device_label);
        let optimized_path = model_path.with_extension(&cache_ext);
        // Sentinel path is the optimized path with ".ok" appended.
        let sentinel_path = {
            let mut s = optimized_path.as_os_str().to_owned();
            s.push(".ok");
            std::path::PathBuf::from(s)
        };
        let cache_hit = optimized_path.exists() && sentinel_path.exists();
        // Optimized file without sentinel: previous save was interrupted;
        // delete it so we rebuild from the original model below.
        if !cache_hit && optimized_path.exists() && !sentinel_path.exists() {
            tracing::warn!(
                "Removing incomplete cache {:?} (missing sentinel)",
                optimized_path
            );
            let _ = std::fs::remove_file(&optimized_path);
        }
        if cache_hit {
            tracing::info!("Loading pre-optimized model from {:?}", optimized_path);
            // cached=true: skip graph optimization, it was already applied
            // when the cache file was produced.
            match Self::build_session(&optimized_path, num_intra_threads, &device_type, true, None)
            {
                Ok((session, actual_device)) => {
                    tracing::info!("Using device: {}", actual_device);
                    return Self::finish_load(session, config);
                }
                Err(e) => {
                    // Corrupt or stale cache: remove both files and fall
                    // through to a fresh build.
                    tracing::warn!(
                        "Failed to load cached model {:?}, rebuilding: {}",
                        optimized_path,
                        e
                    );
                    let _ = std::fs::remove_file(&optimized_path);
                    let _ = std::fs::remove_file(&sentinel_path);
                }
            }
        }
        // Cache miss: build from the original model and ask ORT to save the
        // optimized graph to `optimized_path` (best effort).
        let (session, actual_device) = Self::build_session(
            model_path,
            num_intra_threads,
            &device_type,
            false,
            Some(&optimized_path),
        )?;
        tracing::info!("Using device: {}", actual_device);
        // Write the sentinel only if ORT actually produced the optimized
        // file, marking the cache complete for the next load.
        if optimized_path.exists() {
            if let Err(e) = std::fs::write(&sentinel_path, b"ok") {
                tracing::warn!("Failed to write sentinel {:?}: {}", sentinel_path, e);
            } else {
                tracing::info!("Cache sentinel written: {:?}", sentinel_path);
            }
        }
        Self::finish_load(session, config)
    }
fn build_session(
model_path: &Path,
num_intra_threads: usize,
device_type: &crate::gpu::DeviceType,
cached: bool,
cache_save_path: Option<&std::path::Path>,
) -> Result<(Session, crate::gpu::DeviceType), PiperError> {
let mut builder = Session::builder()
.map_err(|e| PiperError::ModelLoad(e.to_string()))?
.with_intra_threads(num_intra_threads)
.map_err(|e| PiperError::ModelLoad(format!("intra_threads: {e}")))?
.with_inter_threads(1)
.map_err(|e| PiperError::ModelLoad(format!("inter_threads: {e}")))?
.with_parallel_execution(false)
.map_err(|e| PiperError::ModelLoad(format!("execution_mode: {e}")))?
.with_memory_pattern(true)
.map_err(|e| PiperError::ModelLoad(format!("memory_pattern: {e}")))?
.with_dynamic_block_base(4)
.map_err(|e| PiperError::ModelLoad(format!("dynamic_block_base: {e}")))?;
if cached {
builder = builder
.with_optimization_level(GraphOptimizationLevel::Disable)
.map_err(|e| PiperError::ModelLoad(format!("optimization_level: {e}")))?;
} else if let Some(save_path) = cache_save_path {
match builder.with_optimized_model_path(save_path) {
Ok(b) => {
builder = b;
tracing::info!("ORT will save optimized model to {:?}", save_path);
}
Err(e) => {
let msg = e.to_string();
builder = e.recover();
tracing::warn!(
"Could not set optimized model path {:?}: {} (continuing without cache)",
save_path,
msg
);
}
}
}
let (mut builder, actual_device) =
crate::gpu::configure_session_builder(builder, device_type)
.map_err(|e| PiperError::ModelLoad(format!("device config: {e}")))?;
let session = builder
.commit_from_file(model_path)
.map_err(|e| PiperError::ModelLoad(e.to_string()))?;
Ok((session, actual_device))
}
    /// Inspect the session's declared inputs/outputs to derive
    /// `ModelCapabilities`, then assemble the engine.
    ///
    /// Optional features are detected purely by tensor name ("sid", "lid",
    /// "prosody_features", "durations").
    fn finish_load(session: Session, config: &VoiceConfig) -> Result<Self, PiperError> {
        let input_names: Vec<String> = session
            .inputs()
            .iter()
            .map(|i| i.name().to_string())
            .collect();
        let output_names: Vec<String> = session
            .outputs()
            .iter()
            .map(|o| o.name().to_string())
            .collect();
        let has_input = |name: &str| input_names.iter().any(|n| n == name);
        let has_output = |name: &str| output_names.iter().any(|n| n == name);
        let capabilities = ModelCapabilities {
            has_sid: has_input("sid"),
            has_lid: has_input("lid"),
            has_prosody: has_input("prosody_features"),
            has_duration_output: has_output("durations"),
        };
        tracing::info!(
            "Model loaded: inputs={:?}, outputs={:?}",
            input_names,
            output_names,
        );
        tracing::info!(
            "Capabilities: sid={}, lid={}, prosody={}, durations={}",
            capabilities.has_sid,
            capabilities.has_lid,
            capabilities.has_prosody,
            capabilities.has_duration_output,
        );
        // Sample rate comes from the voice config, not the model itself.
        Ok(Self {
            session,
            capabilities,
            sample_rate: config.audio.sample_rate,
        })
    }
    /// Optional model inputs/outputs detected at load time.
    pub fn capabilities(&self) -> &ModelCapabilities {
        &self.capabilities
    }
    /// Output sample rate in Hz (from the voice config).
    pub fn sample_rate(&self) -> u32 {
        self.sample_rate
    }
    /// Run one inference pass over `request.phoneme_ids` and return 16-bit
    /// PCM audio plus timing metadata.
    ///
    /// Always feeds the three mandatory inputs (`input`, `input_lengths`,
    /// `scales`) and adds `sid`/`lid`/`prosody_features` only when the model
    /// declares them. Duration extraction is best effort: failures are logged
    /// and `durations` is set to `None` rather than failing the whole call.
    ///
    /// # Errors
    /// Returns `PiperError::Inference` on empty phoneme input, tensor
    /// construction failure, session-run failure, or a missing/invalid
    /// "output" tensor.
    pub fn synthesize(
        &mut self,
        request: &SynthesisRequest,
    ) -> Result<SynthesisResult, PiperError> {
        let phoneme_len = request.phoneme_ids.len();
        if phoneme_len == 0 {
            return Err(PiperError::Inference("empty phoneme_ids".to_string()));
        }
        // Phoneme IDs as a [1, len] batch of one.
        let input_tensor = Tensor::from_array((
            [1_usize, phoneme_len],
            request.phoneme_ids.to_vec().into_boxed_slice(),
        ))
        .map_err(|e| PiperError::Inference(format!("input tensor: {e}")))?;
        let lengths_tensor =
            Tensor::from_array(([1_usize], vec![phoneme_len as i64].into_boxed_slice()))
                .map_err(|e| PiperError::Inference(format!("input_lengths tensor: {e}")))?;
        // Scales are packed in the order [noise_scale, length_scale, noise_w].
        let scales_tensor = Tensor::from_array((
            [3_usize],
            vec![request.noise_scale, request.length_scale, request.noise_w].into_boxed_slice(),
        ))
        .map_err(|e| PiperError::Inference(format!("scales tensor: {e}")))?;
        // Optional speaker ID; defaults to 0 when the model wants one but the
        // request did not specify it.
        let sid_val = request.speaker_id.unwrap_or(0);
        let sid_tensor = if self.capabilities.has_sid {
            Some(
                Tensor::from_array(([1_usize], vec![sid_val].into_boxed_slice()))
                    .map_err(|e| PiperError::Inference(format!("sid tensor: {e}")))?,
            )
        } else {
            None
        };
        // Optional language ID, same defaulting rule as sid.
        let lid_val = request.language_id.unwrap_or(0);
        let lid_tensor = if self.capabilities.has_lid {
            Some(
                Tensor::from_array(([1_usize], vec![lid_val].into_boxed_slice()))
                    .map_err(|e| PiperError::Inference(format!("lid tensor: {e}")))?,
            )
        } else {
            None
        };
        // Optional prosody features, flattened to a [1, n, 3] i64 tensor;
        // absent features become all zeros of length phoneme_len.
        // NOTE(review): a caller-supplied feature list whose length differs
        // from phoneme_len is passed through unchecked — presumably the model
        // expects them to match; confirm with callers.
        let prosody_tensor = if self.capabilities.has_prosody {
            let flat: Vec<i64> = if let Some(ref features) = request.prosody_features {
                features
                    .iter()
                    .flat_map(|f| [f[0] as i64, f[1] as i64, f[2] as i64])
                    .collect()
            } else {
                vec![0i64; phoneme_len * 3]
            };
            let pf_len = flat.len() / 3;
            Some(
                Tensor::from_array(([1_usize, pf_len, 3], flat.into_boxed_slice()))
                    .map_err(|e| PiperError::Inference(format!("prosody tensor: {e}")))?,
            )
        } else {
            None
        };
        // Assemble named inputs; optional tensors are appended only when the
        // model declared the corresponding input.
        let mut inputs: Vec<(Cow<str>, ort::session::SessionInputValue<'_>)> =
            Vec::with_capacity(6);
        inputs.push(("input".into(), (&input_tensor).into()));
        inputs.push(("input_lengths".into(), (&lengths_tensor).into()));
        inputs.push(("scales".into(), (&scales_tensor).into()));
        if let Some(ref t) = sid_tensor {
            inputs.push(("sid".into(), t.into()));
        }
        if let Some(ref t) = lid_tensor {
            inputs.push(("lid".into(), t.into()));
        }
        if let Some(ref t) = prosody_tensor {
            inputs.push(("prosody_features".into(), t.into()));
        }
        // Time only the session run, not tensor construction.
        let start = Instant::now();
        let outputs = self
            .session
            .run(inputs)
            .map_err(|e| PiperError::Inference(e.to_string()))?;
        let infer_seconds = start.elapsed().as_secs_f64();
        let (_shape, audio_slice) = outputs["output"]
            .try_extract_tensor::<f32>()
            .map_err(|e| PiperError::Inference(format!("extract output: {e}")))?;
        let audio_i16 = audio_float_to_int16(audio_slice);
        let audio_seconds = audio_i16.len() as f64 / self.sample_rate as f64;
        // Duration extraction is best effort: log and continue on failure.
        let durations = if self.capabilities.has_duration_output {
            match outputs.get("durations") {
                Some(d) => match d.try_extract_tensor::<f32>() {
                    Ok((_shape, data)) => {
                        let vec = data.to_vec();
                        tracing::debug!("Duration tensor extracted: {} values", vec.len());
                        Some(vec)
                    }
                    Err(e) => {
                        tracing::warn!(
                            "Duration tensor extraction failed (shape/type mismatch): {}. \
                             Expected f32 tensor with shape [1, phoneme_length].",
                            e
                        );
                        None
                    }
                },
                None => {
                    tracing::warn!(
                        "Model declares 'durations' output but tensor was not found in results"
                    );
                    None
                }
            }
        } else {
            None
        };
        Ok(SynthesisResult {
            audio: audio_i16,
            sample_rate: self.sample_rate,
            infer_seconds,
            audio_seconds,
            durations,
        })
    }
pub fn warmup(&mut self, runs: usize) -> Result<(), PiperError> {
let mut dummy_ids = vec![8i64; WARMUP_PHONEME_LENGTH]; dummy_ids[0] = 1; dummy_ids[WARMUP_PHONEME_LENGTH - 1] = 2; let dummy_request = SynthesisRequest {
phoneme_ids: dummy_ids,
..SynthesisRequest::default()
};
for i in 0..runs {
let start = std::time::Instant::now();
let _ = self.synthesize(&dummy_request)?;
tracing::debug!("warmup run {}/{}: {:?}", i + 1, runs, start.elapsed());
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for thread-count capping, request/result types, and the
    //! optimized-model cache path/sentinel scheme.
    use super::*;
    use std::path::PathBuf;
    #[test]
    fn test_intra_threads_capped_at_max() {
        let available = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(2);
        let num_intra_threads = available.min(MAX_INTRA_THREADS);
        assert!(num_intra_threads >= 1);
        assert!(num_intra_threads <= MAX_INTRA_THREADS);
    }
    #[test]
    fn test_thread_count_low_cpu() {
        assert_eq!(2_usize.min(MAX_INTRA_THREADS), 2);
    }
    #[test]
    fn test_thread_count_high_cpu() {
        assert_eq!(32_usize.min(MAX_INTRA_THREADS), MAX_INTRA_THREADS);
    }
    #[test]
    fn test_synthesis_request_default() {
        let req = SynthesisRequest::default();
        assert!(req.phoneme_ids.is_empty());
        assert!(req.prosody_features.is_none());
        assert!(req.speaker_id.is_none());
        assert!(req.language_id.is_none());
        assert!((req.noise_scale - 0.667).abs() < 1e-6);
        assert!((req.length_scale - 1.0).abs() < 1e-6);
        assert!((req.noise_w - 0.8).abs() < 1e-6);
    }
    #[test]
    fn test_synthesis_result_rtf() {
        let result = SynthesisResult {
            audio: vec![0i16; 22050],
            sample_rate: 22050,
            infer_seconds: 0.5,
            audio_seconds: 1.0,
            durations: None,
        };
        assert!((result.real_time_factor() - 0.5).abs() < 1e-6);
    }
    #[test]
    fn test_synthesis_result_rtf_zero_audio() {
        let result = SynthesisResult {
            audio: Vec::new(),
            sample_rate: 22050,
            infer_seconds: 0.1,
            audio_seconds: 0.0,
            durations: None,
        };
        assert!((result.real_time_factor()).abs() < 1e-6);
    }
    #[test]
    fn test_model_capabilities_debug() {
        let caps = ModelCapabilities {
            has_sid: true,
            has_lid: false,
            has_prosody: true,
            has_duration_output: false,
        };
        let debug = format!("{:?}", caps);
        assert!(debug.contains("has_sid: true"));
        assert!(debug.contains("has_lid: false"));
        assert!(debug.contains("has_prosody: true"));
        assert!(debug.contains("has_duration_output: false"));
    }
    #[test]
    fn test_synthesis_result_with_durations() {
        let result = SynthesisResult {
            audio: vec![0i16; 22050],
            sample_rate: 22050,
            infer_seconds: 0.3,
            audio_seconds: 1.0,
            durations: Some(vec![1.0, 2.0, 3.0]),
        };
        let durations = result.durations.as_ref().unwrap();
        assert_eq!(durations.len(), 3);
        assert!((durations[0] - 1.0).abs() < 1e-6);
        assert!((durations[1] - 2.0).abs() < 1e-6);
        assert!((durations[2] - 3.0).abs() < 1e-6);
    }
    // Renamed from `test_synthesis_result_rtf_infinity`: real_time_factor()
    // is defined to return 0.0 (not infinity) when audio_seconds is zero,
    // so the old name was misleading.
    #[test]
    fn test_synthesis_result_rtf_zero_audio_nonzero_infer() {
        let result = SynthesisResult {
            audio: Vec::new(),
            sample_rate: 22050,
            infer_seconds: 1.5,
            audio_seconds: 0.0,
            durations: None,
        };
        assert!((result.real_time_factor() - 0.0).abs() < 1e-6);
    }
    #[test]
    fn test_synthesis_request_custom_values() {
        let req = SynthesisRequest {
            phoneme_ids: vec![1, 2, 3, 4, 5],
            prosody_features: Some(vec![
                [1, 2, 3],
                [4, 5, 6],
                [7, 8, 9],
                [10, 11, 12],
                [13, 14, 15],
            ]),
            speaker_id: Some(42),
            language_id: Some(3),
            noise_scale: 0.333,
            length_scale: 1.5,
            noise_w: 0.5,
        };
        assert_eq!(req.phoneme_ids.len(), 5);
        assert_eq!(req.speaker_id, Some(42));
        assert_eq!(req.language_id, Some(3));
        assert!((req.noise_scale - 0.333).abs() < 1e-6);
        assert!((req.length_scale - 1.5).abs() < 1e-6);
        assert!((req.noise_w - 0.5).abs() < 1e-6);
        let pf = req.prosody_features.as_ref().unwrap();
        assert_eq!(pf.len(), 5);
        assert_eq!(pf[0], [1, 2, 3]);
    }
    #[test]
    fn test_model_capabilities_all_true() {
        let caps = ModelCapabilities {
            has_sid: true,
            has_lid: true,
            has_prosody: true,
            has_duration_output: true,
        };
        assert!(caps.has_sid);
        assert!(caps.has_lid);
        assert!(caps.has_prosody);
        assert!(caps.has_duration_output);
    }
    #[test]
    fn test_model_capabilities_all_false() {
        let caps = ModelCapabilities {
            has_sid: false,
            has_lid: false,
            has_prosody: false,
            has_duration_output: false,
        };
        assert!(!caps.has_sid);
        assert!(!caps.has_lid);
        assert!(!caps.has_prosody);
        assert!(!caps.has_duration_output);
    }
    #[test]
    fn test_warmup_request_is_valid() {
        // Mirrors the dummy request built in `OnnxEngine::warmup`.
        let mut dummy_ids = vec![8i64; WARMUP_PHONEME_LENGTH];
        dummy_ids[0] = 1;
        dummy_ids[WARMUP_PHONEME_LENGTH - 1] = 2;
        let req = SynthesisRequest {
            phoneme_ids: dummy_ids,
            ..SynthesisRequest::default()
        };
        assert!(!req.phoneme_ids.is_empty());
        assert_eq!(req.phoneme_ids.len(), WARMUP_PHONEME_LENGTH);
        assert_eq!(req.phoneme_ids[0], 1);
        assert_eq!(req.phoneme_ids[WARMUP_PHONEME_LENGTH - 1], 2);
        assert_eq!(req.phoneme_ids[1], 8);
    }
    // Test-local reimplementation of the cache-path logic in `load`.
    fn build_cache_path(model_path: &Path, device_label: &str) -> PathBuf {
        let cache_ext = format!("{}.opt.onnx", device_label);
        model_path.with_extension(&cache_ext)
    }
    // Test-local reimplementation of the sentinel-path logic in `load`.
    fn build_sentinel_path(optimized_path: &Path) -> PathBuf {
        let mut s = optimized_path.as_os_str().to_owned();
        s.push(".ok");
        PathBuf::from(s)
    }
    #[test]
    fn test_optimized_model_path_construction_cpu() {
        let model_path = PathBuf::from("/data/models/test.onnx");
        let opt_path = build_cache_path(&model_path, "cpu");
        assert_eq!(opt_path.to_str().unwrap(), "/data/models/test.cpu.opt.onnx");
    }
    #[test]
    fn test_optimized_model_path_construction_cuda() {
        let model_path = PathBuf::from("/data/models/test.onnx");
        let device_label = "cuda:0".replace(':', "");
        let opt_path = build_cache_path(&model_path, &device_label);
        assert_eq!(
            opt_path.to_str().unwrap(),
            "/data/models/test.cuda0.opt.onnx"
        );
    }
    #[test]
    fn test_optimized_model_path_from_nested_dir() {
        let model_path = PathBuf::from("/home/user/models/tsukuyomi/model.onnx");
        let opt_path = build_cache_path(&model_path, "cpu");
        assert_eq!(
            opt_path.to_str().unwrap(),
            "/home/user/models/tsukuyomi/model.cpu.opt.onnx"
        );
    }
    #[test]
    fn test_optimized_model_path_preserves_parent() {
        let model_path = PathBuf::from("/data/models/test.onnx");
        let opt_path = build_cache_path(&model_path, "cpu");
        assert_eq!(opt_path.parent(), model_path.parent());
    }
    #[test]
    fn test_sentinel_path_construction() {
        let model_path = PathBuf::from("/data/models/test.onnx");
        let opt_path = build_cache_path(&model_path, "cpu");
        let sentinel = build_sentinel_path(&opt_path);
        assert_eq!(
            sentinel.to_str().unwrap(),
            "/data/models/test.cpu.opt.onnx.ok"
        );
    }
    #[test]
    fn test_use_cached_requires_both_files() {
        let opt_exists = true;
        let sentinel_exists = true;
        let use_cached = opt_exists && sentinel_exists;
        assert!(use_cached);
    }
    #[test]
    fn test_no_cache_when_sentinel_missing() {
        let opt_exists = true;
        let sentinel_exists = false;
        let use_cached = opt_exists && sentinel_exists;
        assert!(!use_cached);
    }
    #[test]
    fn test_no_cache_when_opt_missing() {
        let opt_exists = false;
        let sentinel_exists = false;
        let use_cached = opt_exists && sentinel_exists;
        assert!(!use_cached);
    }
    #[test]
    fn test_device_label_colon_removal() {
        let label = "cuda:0".replace(':', "");
        assert_eq!(label, "cuda0");
        assert!(!label.contains(':'));
        assert!(!label.contains('.'));
    }
    #[test]
    fn test_session_builder_with_all_options() {
        let builder = Session::builder()
            .expect("session builder")
            .with_intra_threads(1)
            .expect("intra_threads")
            .with_inter_threads(1)
            .expect("inter_threads")
            .with_parallel_execution(false)
            .expect("parallel_execution")
            .with_memory_pattern(true)
            .expect("memory_pattern")
            .with_dynamic_block_base(4)
            .expect("dynamic_block_base");
        let _ = builder;
    }
    #[test]
    fn test_device_label_cpu_no_colon() {
        // Fix: use the same replacement as the production code in `load`
        // (':' stripped to ""); the previous "." replacement was a
        // copy-paste inconsistency that only passed because "cpu" has no
        // colon.
        let label = "cpu".replace(':', "");
        assert_eq!(label, "cpu");
    }
    #[test]
    fn test_device_label_directml() {
        let label = "directml:1".replace(':', "");
        assert_eq!(label, "directml1");
        let model_path = PathBuf::from("/data/models/test.onnx");
        let opt_path = build_cache_path(&model_path, &label);
        assert_eq!(
            opt_path.to_str().unwrap(),
            "/data/models/test.directml1.opt.onnx"
        );
    }
    #[test]
    fn test_sentinel_file_io_roundtrip() {
        let dir = std::env::temp_dir().join("piper_test_sentinel");
        let _ = std::fs::create_dir_all(&dir);
        let sentinel = dir.join("test.cpu.opt.onnx.ok");
        std::fs::write(&sentinel, b"ok").unwrap();
        assert!(sentinel.exists());
        let content = std::fs::read(&sentinel).unwrap();
        assert_eq!(content, b"ok");
        let _ = std::fs::remove_file(&sentinel);
        let _ = std::fs::remove_dir(&dir);
    }
}