use crate::error::{Error, Result};
use crate::model::audio::kokoro::loader::load_voice_pack;
use numr::dtype::DType;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
use std::path::{Path, PathBuf};
/// Resolves a voice spec (a file path or a bare voice id) to a voice-pack file.
///
/// Bare voice ids are looked up in a list of candidate directories; when
/// `asset_dir` is set it is the first directory searched.
#[derive(Clone, Debug, Default)]
pub struct VoiceResolver {
    /// Directory searched first when resolving a bare voice id.
    pub asset_dir: Option<PathBuf>,
}
impl VoiceResolver {
pub fn new() -> Self {
Self::default()
}
pub fn with_asset_dir(mut self, dir: impl Into<PathBuf>) -> Self {
self.asset_dir = Some(dir.into());
self
}
pub fn resolve_path(&self, spec: &str) -> Result<PathBuf> {
if spec.is_empty() {
return Err(Error::ModelError {
reason: "voice SPEC is empty".into(),
});
}
let looks_like_path = spec.contains('/')
|| spec.contains('\\')
|| spec.ends_with(".safetensors")
|| spec.ends_with(".pt")
|| spec.ends_with(".pth");
if looks_like_path {
let p = PathBuf::from(spec);
if !p.exists() {
return Err(Error::ModelError {
reason: format!("voice file not found: {}", p.display()),
});
}
return Ok(p);
}
let dirs = self.candidate_dirs();
for dir in &dirs {
for ext in ["safetensors", "pt", "pth"] {
let candidate = dir.join(format!("{spec}.{ext}"));
if candidate.is_file() {
return Ok(candidate);
}
}
}
let dirs_pretty: Vec<String> = dirs.iter().map(|d| d.display().to_string()).collect();
Err(Error::ModelError {
reason: format!(
"voice id {spec:?} not found in any of: {dirs_pretty:?} (looked for \
.safetensors, .pt, .pth)"
),
})
}
pub fn load<R: Runtime<DType = DType>>(
&self,
spec: &str,
device: &R::Device,
) -> Result<Tensor<R>> {
let path = self.resolve_path(spec)?;
load_voice_pack::<R>(&path, device)
}
fn candidate_dirs(&self) -> Vec<PathBuf> {
let mut out = Vec::new();
if let Some(dir) = &self.asset_dir {
out.push(dir.clone());
}
if let Ok(env) = std::env::var("BLAZR_VOICE_DIR") {
let p = PathBuf::from(env);
if !out.contains(&p) {
out.push(p);
}
}
for rel in ["./assets/kokoro_voices", "../blazr/assets/kokoro_voices"] {
let p = PathBuf::from(rel);
if !out.contains(&p) {
out.push(p);
}
}
out
}
}
/// Resolves `spec` and loads the matching voice pack onto `device` in one call.
///
/// When `asset_dir` is given, it is searched before the default voice
/// directories.
///
/// # Errors
///
/// Propagates any resolution or loading error from [`VoiceResolver::load`].
pub fn resolve_and_load<R: Runtime<DType = DType>>(
    spec: &str,
    asset_dir: Option<&Path>,
    device: &R::Device,
) -> Result<Tensor<R>> {
    let resolver = match asset_dir {
        Some(dir) => VoiceResolver::new().with_asset_dir(dir),
        None => VoiceResolver::new(),
    };
    resolver.load::<R>(spec, device)
}
pub fn select_voice_style<R: Runtime<DType = DType>>(
voice_pack: &Tensor<R>,
phoneme_count: usize,
) -> Result<Tensor<R>> {
let shape = voice_pack.shape();
let (rows, style_width) = match shape.len() {
3 => {
if shape[1] != 1 {
return Err(Error::ModelError {
reason: format!("voice pack middle dim must be 1, got shape {shape:?}"),
});
}
(shape[0], shape[2])
}
2 => (shape[0], shape[1]),
_ => {
return Err(Error::ModelError {
reason: format!("voice pack rank must be 2 or 3, got shape {shape:?}"),
});
}
};
if rows == 0 {
return Err(Error::ModelError {
reason: "voice pack is empty".into(),
});
}
let idx = phoneme_count.saturating_sub(1).min(rows - 1);
let flat = match shape.len() {
3 => voice_pack
.reshape(&[rows, style_width])
.map_err(|e| Error::ModelError {
reason: format!("reshape voice pack: {e}"),
})?,
_ => voice_pack.clone(),
};
flat.narrow(0, idx, 1).map_err(|e| Error::ModelError {
reason: format!("narrow voice pack: {e}"),
})
}
pub fn split_voice_style<R: Runtime<DType = DType>>(
style_row: &Tensor<R>,
style_dim: usize,
) -> Result<(Tensor<R>, Tensor<R>)> {
let shape = style_row.shape();
if shape.len() != 2 || shape[1] != 2 * style_dim {
return Err(Error::ModelError {
reason: format!(
"style row shape must be [B, {}], got {shape:?}",
2 * style_dim
),
});
}
let decoder = style_row
.narrow(1, 0, style_dim)
.map_err(|e| Error::ModelError {
reason: format!("narrow decoder style: {e}"),
})?
.contiguous()?;
let predictor = style_row
.narrow(1, style_dim, style_dim)
.map_err(|e| Error::ModelError {
reason: format!("narrow predictor style: {e}"),
})?
.contiguous()?;
Ok((decoder, predictor))
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuDevice, CpuRuntime};

    /// Returns `<tmp>/<name>`, wiped and recreated so each test starts empty.
    fn scratch_dir(name: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(name);
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// Creates an empty placeholder file; `resolve_path` never reads contents.
    fn touch(path: &Path) {
        std::fs::write(path, b"").unwrap();
    }

    /// Best-effort removal of a scratch directory.
    fn cleanup(dir: &Path) {
        let _ = std::fs::remove_dir_all(dir);
    }

    #[test]
    fn rejects_empty_spec() {
        let outcome = VoiceResolver::new().resolve_path("");
        assert!(outcome.is_err());
    }

    #[test]
    fn path_form_requires_existing_file() {
        let outcome = VoiceResolver::new().resolve_path("/nonexistent-voice-xyz.safetensors");
        assert!(outcome.is_err());
    }

    #[test]
    fn id_form_searches_asset_dir() {
        let dir = scratch_dir("boostr_voice_resolver_test");
        let expected = dir.join("af_alloy.safetensors");
        touch(&expected);
        let found = VoiceResolver::new()
            .with_asset_dir(&dir)
            .resolve_path("af_alloy")
            .unwrap();
        assert_eq!(found, expected);
        cleanup(&dir);
    }

    #[test]
    fn id_form_prefers_safetensors_over_pt() {
        let dir = scratch_dir("boostr_voice_resolver_pref_test");
        touch(&dir.join("af.safetensors"));
        touch(&dir.join("af.pt"));
        let found = VoiceResolver::new()
            .with_asset_dir(&dir)
            .resolve_path("af")
            .unwrap();
        assert!(found.to_string_lossy().ends_with("af.safetensors"));
        cleanup(&dir);
    }

    #[test]
    fn id_form_falls_back_to_pt() {
        let dir = scratch_dir("boostr_voice_resolver_pt_fallback");
        touch(&dir.join("af.pt"));
        let found = VoiceResolver::new()
            .with_asset_dir(&dir)
            .resolve_path("af")
            .unwrap();
        assert!(found.to_string_lossy().ends_with("af.pt"));
        cleanup(&dir);
    }

    #[test]
    fn unknown_id_reports_searched_dirs() {
        let err = VoiceResolver::new().resolve_path("af_nope").unwrap_err();
        match err {
            Error::ModelError { reason } => assert!(reason.contains("af_nope")),
            _ => panic!("wrong error variant"),
        }
    }

    #[test]
    fn select_voice_style_clamps_to_last_row() {
        let device = CpuDevice::new();
        // Three rows of width 4, filled with 1.0, 2.0 and 3.0 respectively.
        let data: Vec<f32> = [1.0f32, 2.0, 3.0]
            .iter()
            .flat_map(|&v| std::iter::repeat(v).take(4))
            .collect();
        let pack = Tensor::<CpuRuntime>::from_slice(&data, &[3, 1, 4], &device);

        let second = select_voice_style(&pack, 2).unwrap();
        assert_eq!(second.shape(), &[1, 4]);
        let second_vals: Vec<f32> = second.to_vec();
        assert_eq!(second_vals, vec![2.0; 4]);

        let clamped = select_voice_style(&pack, 100).unwrap();
        let clamped_vals: Vec<f32> = clamped.to_vec();
        assert_eq!(clamped_vals, vec![3.0; 4]);
    }

    #[test]
    fn split_voice_style_halves_match() {
        let device = CpuDevice::new();
        let values: Vec<f32> = (1..=8).map(|i| i as f32).collect();
        let row = Tensor::<CpuRuntime>::from_slice(&values, &[1, 8], &device);

        let (decoder, predictor) = split_voice_style(&row, 4).unwrap();
        assert_eq!(decoder.shape(), &[1, 4]);
        assert_eq!(predictor.shape(), &[1, 4]);
        let dec_vals: Vec<f32> = decoder.to_vec();
        let pred_vals: Vec<f32> = predictor.to_vec();
        assert_eq!(dec_vals, vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(pred_vals, vec![5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn select_voice_style_rejects_bad_rank() {
        let device = CpuDevice::new();
        let rank1 = Tensor::<CpuRuntime>::from_slice(&[1.0f32; 4], &[4], &device);
        assert!(select_voice_style(&rank1, 0).is_err());
    }

    #[test]
    fn split_voice_style_rejects_wrong_width() {
        let device = CpuDevice::new();
        let too_narrow = Tensor::<CpuRuntime>::from_slice(&[1.0f32; 6], &[1, 6], &device);
        assert!(split_voice_style(&too_narrow, 4).is_err());
    }

    #[test]
    fn path_with_slash_is_treated_as_path() {
        let dir = scratch_dir("boostr_voice_resolver_path_form");
        let file = dir.join("custom.safetensors");
        touch(&file);
        let spec = file.to_str().unwrap();
        let found = VoiceResolver::new().resolve_path(spec).unwrap();
        assert_eq!(found, file);
        cleanup(&dir);
    }
}