use crate::error::{Error, Result};
use crate::format::safetensors_loader::SafeTensorsLoader;
use crate::nn::{BiLstm, Conv1d, Lstm, fuse_weight_norm};
use numr::dtype::DType;
use numr::ops::{BinaryOps, PaddingMode, ReduceOps, TensorOps, UnaryOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use std::path::Path;
/// Builds a [`Conv1d`] from un-normalised weights stored in `st`.
///
/// The kernel is read from `{prefix}.weight` and is mandatory; the bias is read
/// from `{prefix}.bias` and is optional. `stride`, `padding`, `dilation` and
/// `groups` are forwarded to [`Conv1d::new`] unchanged, followed by a trailing
/// `false` flag whose meaning is defined by `Conv1d::new`.
///
/// # Errors
///
/// Propagates the error from `load_tensor` when `{prefix}.weight` cannot be
/// loaded. A failure while loading the bias is *not* an error: it is discarded
/// and the layer is constructed without a bias.
pub fn load_plain_conv1d<R: Runtime<DType = DType>>(
    st: &mut super::weight_source::KokoroWeightSource,
    prefix: &str,
    stride: usize,
    padding: PaddingMode,
    dilation: usize,
    groups: usize,
    device: &R::Device,
) -> Result<Conv1d<R>> {
    let weight_key = format!("{prefix}.weight");
    let kernel = st.load_tensor::<R>(&weight_key, device)?;

    // Best-effort: any load failure for the bias is treated as "no bias".
    let bias_key = format!("{prefix}.bias");
    let maybe_bias = st.load_tensor::<R>(&bias_key, device).ok();

    let conv = Conv1d::new(kernel, maybe_bias, stride, padding, dilation, groups, false);
    Ok(conv)
}
/// Builds a [`Conv1d`] whose kernel is stored as a weight-norm `(g, v)` pair.
///
/// The pair is fetched with [`load_weight_norm_pair`] and collapsed into a
/// single weight tensor by `fuse_weight_norm(client, &v, &g, 0)`. The optional
/// bias is read from `{prefix}.bias`. The remaining hyper-parameters are
/// forwarded to [`Conv1d::new`] unchanged, followed by a trailing `false` flag
/// whose meaning is defined by `Conv1d::new`.
///
/// # Errors
///
/// Propagates any error from loading the `(g, v)` pair or from
/// `fuse_weight_norm`. A failure while loading the bias is discarded and the
/// layer is constructed without a bias.
// The argument list mirrors the convolution hyper-parameters one-to-one, so the
// lint is silenced rather than bundling them into a struct.
#[allow(clippy::too_many_arguments)]
pub fn load_weight_normed_conv1d<R, C>(
    client: &C,
    st: &mut super::weight_source::KokoroWeightSource,
    prefix: &str,
    stride: usize,
    padding: PaddingMode,
    dilation: usize,
    groups: usize,
    device: &R::Device,
) -> Result<Conv1d<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + ReduceOps<R> + UnaryOps<R> + BinaryOps<R> + TensorOps<R>,
{
    let (norm_g, norm_v) = load_weight_norm_pair::<R>(st, prefix, device)?;
    // NOTE(review): the literal `0` is presumably the dimension the norm is
    // taken over — confirm against `fuse_weight_norm`.
    let fused_kernel = fuse_weight_norm(client, &norm_v, &norm_g, 0)?;

    // Best-effort: any load failure for the bias is treated as "no bias".
    let bias_key = format!("{prefix}.bias");
    let maybe_bias = st.load_tensor::<R>(&bias_key, device).ok();

    let conv = Conv1d::new(
        fused_kernel,
        maybe_bias,
        stride,
        padding,
        dilation,
        groups,
        false,
    );
    Ok(conv)
}
/// Loads the `(g, v)` tensors of a weight-normalised parameter.
///
/// Two key layouts are tried, in this order:
///
/// 1. `{prefix}.parametrizations.weight.original0` (returned as `g`) and
///    `{prefix}.parametrizations.weight.original1` (returned as `v`);
/// 2. `{prefix}.weight_g` and `{prefix}.weight_v`.
///
/// Both tensors of the first layout are always requested, even if the first
/// request fails. The second layout is used whenever either request of the
/// first layout fails.
///
/// # Errors
///
/// Returns the error from loading `{prefix}.weight_g` or `{prefix}.weight_v`
/// when the first layout is unavailable and the second one fails too. Errors
/// from the first layout are discarded.
pub fn load_weight_norm_pair<R: Runtime<DType = DType>>(
    st: &mut super::weight_source::KokoroWeightSource,
    prefix: &str,
    device: &R::Device,
) -> Result<(Tensor<R>, Tensor<R>)> {
    let param_g_key = format!("{prefix}.parametrizations.weight.original0");
    let param_v_key = format!("{prefix}.parametrizations.weight.original1");

    // Attempt both parametrization-style loads up front; only a complete pair counts.
    let param_g = st.load_tensor::<R>(&param_g_key, device).ok();
    let param_v = st.load_tensor::<R>(&param_v_key, device).ok();
    if let Some(pair) = param_g.zip(param_v) {
        return Ok(pair);
    }

    // Fall back to the `weight_g` / `weight_v` layout; failures here are fatal.
    let fallback_g = st.load_tensor::<R>(&format!("{prefix}.weight_g"), device)?;
    let fallback_v = st.load_tensor::<R>(&format!("{prefix}.weight_v"), device)?;
    Ok((fallback_g, fallback_v))
}
/// Loads the four parameter tensors of one LSTM direction and builds an [`Lstm`].
///
/// The tensors are read, in order, from:
///
/// * `{prefix}.weight_ih_l0{suffix}`
/// * `{prefix}.weight_hh_l0{suffix}`
/// * `{prefix}.bias_ih_l0{suffix}`
/// * `{prefix}.bias_hh_l0{suffix}`
///
/// `suffix` is appended verbatim to every key.
///
/// # Errors
///
/// Propagates the first tensor-loading failure, or the error returned by
/// [`Lstm::new`].
pub fn load_lstm_direction<R: Runtime<DType = DType>>(
    st: &mut super::weight_source::KokoroWeightSource,
    prefix: &str,
    suffix: &str,
    device: &R::Device,
) -> Result<Lstm<R>> {
    // All four keys share the same `{prefix}.<param>_l0{suffix}` template.
    let key_for = |param: &str| format!("{prefix}.{param}_l0{suffix}");

    let w_input = st.load_tensor::<R>(&key_for("weight_ih"), device)?;
    let w_hidden = st.load_tensor::<R>(&key_for("weight_hh"), device)?;
    let b_input = st.load_tensor::<R>(&key_for("bias_ih"), device)?;
    let b_hidden = st.load_tensor::<R>(&key_for("bias_hh"), device)?;

    Lstm::new(w_input, w_hidden, b_input, b_hidden)
}
/// Loads both directions of a bidirectional LSTM stored under `prefix`.
///
/// The forward direction is loaded with an empty key suffix and the backward
/// direction with the `_reverse` suffix (see [`load_lstm_direction`] for the
/// exact key names); the two are then combined by [`BiLstm::new`].
///
/// # Errors
///
/// Propagates any failure from loading either direction or from
/// [`BiLstm::new`].
pub fn load_bilstm<R: Runtime<DType = DType>>(
    st: &mut super::weight_source::KokoroWeightSource,
    prefix: &str,
    device: &R::Device,
) -> Result<BiLstm<R>> {
    let fwd = load_lstm_direction::<R>(st, prefix, "", device)?;
    let bwd = load_lstm_direction::<R>(st, prefix, "_reverse", device)?;
    BiLstm::new(fwd, bwd)
}
/// Loads the raw tensors of a linear layer stored under `prefix`.
///
/// Returns `(weight, bias)`, where `weight` comes from `{prefix}.weight` and
/// `bias` from `{prefix}.bias`. The bias is `None` when loading it fails for
/// any reason.
///
/// # Errors
///
/// Propagates the error from `load_tensor` when `{prefix}.weight` cannot be
/// loaded.
pub fn load_linear_tensors<R: Runtime<DType = DType>>(
    st: &mut super::weight_source::KokoroWeightSource,
    prefix: &str,
    device: &R::Device,
) -> Result<(Tensor<R>, Option<Tensor<R>>)> {
    let weight_key = format!("{prefix}.weight");
    let weight = st.load_tensor::<R>(&weight_key, device)?;

    // Best-effort: any load failure for the bias is treated as "no bias".
    let bias_key = format!("{prefix}.bias");
    let maybe_bias = st.load_tensor::<R>(&bias_key, device).ok();

    Ok((weight, maybe_bias))
}
pub fn load_voice_pack<R: Runtime<DType = DType>>(
path: impl AsRef<Path>,
device: &R::Device,
) -> Result<Tensor<R>> {
let path = path.as_ref();
let ext = path
.extension()
.and_then(|e| e.to_str())
.unwrap_or("")
.to_ascii_lowercase();
let tensor = match ext.as_str() {
"safetensors" => {
let mut st = SafeTensorsLoader::open(path)?;
st.load_tensor::<R>("style", device)?
}
"pt" | "pth" => crate::format::load_voice_pt(path, device)?,
other => {
return Err(Error::ModelError {
reason: format!(
"unsupported voice file extension: .{other} (expected .safetensors, .pt, or .pth)"
),
});
}
};
let shape = tensor.shape();
let is_valid = matches!(shape.len(), 2 | 3)
&& *shape.last().unwrap() >= 2
&& shape.last().unwrap() % 2 == 0;
if !is_valid {
return Err(Error::ModelError {
reason: format!(
"voice pack has unexpected shape {shape:?}; expected [T, 1, 2*D] or [T, 2*D]"
),
});
}
Ok(tensor)
}
/// Deprecated entry point kept for backward compatibility.
///
/// Forwards directly to [`load_voice_pack`] and returns its result unchanged;
/// see the deprecation note for the replacement functions.
///
/// # Errors
///
/// Exactly the errors returned by [`load_voice_pack`].
#[deprecated(note = "Use load_voice_pack + select_voice_style + split_voice_style")]
pub fn load_voice_style<R>(path: impl AsRef<Path>, device: &R::Device) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
{
    load_voice_pack::<R>(path.as_ref(), device)
}