1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
//! Unified weight-source adapter for the Kokoro loader.
//!
//! The tier-3 loader (`load_kokoro_v2`) needs to read tensors by name from
//! either:
//!
//! * a safetensors directory (if the user pre-converted the upstream
//! checkpoint), or
//! * the upstream `kokoro-v1_0.pth` directly (via [`TorchStateDict`]).
//!
//! Both source types expose the same two operations the tier-1 helpers need:
//! "load this tensor by name" and "does this tensor exist". This enum wraps
//! them so the helpers need neither a generic parameter nor a trait object.
use crate::error::Result;
use crate::format::TorchStateDict;
use crate::format::safetensors_loader::SafeTensorsLoader;
use numr::dtype::DType;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
use std::path::Path;
/// A source of Kokoro model weights, addressed by tensor name.
///
/// Each variant wraps one concrete loader so that callers can work with a
/// single concrete type regardless of the on-disk format.
pub enum KokoroWeightSource {
/// Weights served by a [`SafeTensorsLoader`].
SafeTensors(SafeTensorsLoader),
/// Weights served by a [`TorchStateDict`] (a `.pth` / `.pt` checkpoint).
Pickle(TorchStateDict),
}
impl KokoroWeightSource {
    /// Auto-detect the weight source from a model directory.
    ///
    /// Resolution order:
    ///
    /// 1. If any directory entry has a `.safetensors` extension, the directory
    ///    is handed to [`SafeTensorsLoader::open`] (faster loads, zero pickle
    ///    surface).
    /// 2. Otherwise the well-known checkpoint names `kokoro-v1_0.pth`,
    ///    `kokoro.pth`, `model.pth` and `model.pt` are probed in that order.
    /// 3. Otherwise the regular file with a `.pth` / `.pt` extension whose
    ///    path sorts first is used, so the choice does not depend on the
    ///    order in which the filesystem happens to enumerate entries.
    ///
    /// # Errors
    ///
    /// Returns `Error::ModelError` when none of the steps above finds a
    /// weight file, and propagates any error raised by the selected loader.
    pub fn open(model_dir: impl AsRef<Path>) -> Result<Self> {
        let dir = model_dir.as_ref();

        // Snapshot the directory once. A listing failure (missing or
        // unreadable directory) is treated as an empty listing so the
        // well-known names below can still be probed directly.
        let listing: Vec<_> = match std::fs::read_dir(dir) {
            Ok(entries) => entries.flatten().map(|entry| entry.path()).collect(),
            Err(_) => Vec::new(),
        };

        // Safetensors first: either `model.safetensors` or any `.safetensors`
        // that `SafeTensorsLoader::open` picks up. The loader does its own
        // file resolution, so this is only a presence check.
        let has_safetensors = listing
            .iter()
            .any(|path| path.extension().is_some_and(|ext| ext == "safetensors"));
        if has_safetensors {
            let loader = SafeTensorsLoader::open(dir)?;
            return Ok(Self::SafeTensors(loader));
        }

        // Otherwise look for one of the well-known checkpoint file names.
        for name in ["kokoro-v1_0.pth", "kokoro.pth", "model.pth", "model.pt"] {
            let candidate = dir.join(name);
            if candidate.is_file() {
                return Ok(Self::Pickle(TorchStateDict::open(&candidate)?));
            }
        }

        // Last resort: any `.pth` / `.pt` file. Only regular files qualify
        // (`is_file` follows symlinks), and the smallest path wins because
        // `read_dir` makes no ordering guarantee.
        let fallback = listing
            .into_iter()
            .filter(|path| {
                path.extension().is_some_and(|ext| ext == "pth" || ext == "pt") && path.is_file()
            })
            .min();
        if let Some(path) = fallback {
            return Ok(Self::Pickle(TorchStateDict::open(&path)?));
        }

        Err(crate::error::Error::ModelError {
            reason: format!("no .safetensors or .pth weights found in {}", dir.display()),
        })
    }

    /// Load a tensor by flattened name onto `device`.
    ///
    /// Dispatches to the wrapped loader's `load_tensor`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the wrapped loader returns.
    pub fn load_tensor<R: Runtime<DType = DType>>(
        &mut self,
        name: &str,
        device: &R::Device,
    ) -> Result<Tensor<R>> {
        match self {
            Self::SafeTensors(s) => s.load_tensor::<R>(name, device),
            Self::Pickle(p) => p.load_tensor::<R>(name, device),
        }
    }

    /// Whether a tensor with this name is present.
    ///
    /// For safetensors this is true exactly when `tensor_info(name)`
    /// succeeds; for a pickle checkpoint it delegates to `has`.
    pub fn has_tensor(&self, name: &str) -> bool {
        match self {
            Self::SafeTensors(s) => s.tensor_info(name).is_ok(),
            Self::Pickle(p) => p.has(name),
        }
    }
}