pub mod download;
pub mod manifest;
pub mod verify;
pub use download::{
DownloadError, download_with_checksum, download_with_checksum_and_signature, verify_sha256,
};
pub use manifest::{Manifest, ManifestError, ModelEntry, ProfileEntry, SCHEMA_V1};
use crate::types::Profile;
use std::path::{Path, PathBuf};
/// Default model manifest (TOML), embedded into the binary at compile time from
/// `manifest.toml`, which `include_str!` resolves relative to this source file.
pub const DEFAULT_MANIFEST_TOML: &str = include_str!("manifest.toml");
/// Parses the embedded [`DEFAULT_MANIFEST_TOML`] into a [`Manifest`].
///
/// # Panics
///
/// Panics if the embedded TOML fails to parse. The manifest is compiled into the
/// binary, so a failure here is a broken static asset, not a condition a caller
/// could recover from at runtime.
// `expect` is intentional here; the only failure mode is a build-time asset bug.
#[allow(clippy::expect_used)]
pub fn default_manifest() -> Manifest {
    let parsed = Manifest::from_toml_str(DEFAULT_MANIFEST_TOML);
    parsed.expect("embedded manifest.toml must parse — this is a static-asset bug")
}
/// Errors produced by [`ModelRegistry`] when resolving, downloading, or locating models.
#[derive(Debug, thiserror::Error)]
pub enum RegistryError {
    /// The requested model id has no entry in the registry's manifest.
    #[error("model '{model_id}' not found in manifest")]
    ModelNotFound { model_id: String },
    /// The profile's manifest id has no entry in the registry's manifest.
    #[error("profile '{profile}' not found in manifest")]
    ProfileNotFound { profile: String },
    /// A custom profile was passed to a profile-resolving method. The registry
    /// refuses to resolve it; the caller has to supply model paths itself.
    #[error("custom profile cannot be resolved by registry — caller must supply models")]
    CustomProfileUnresolvable,
    /// Reported as an unusable cache directory. Within this file it is only
    /// constructed by `ModelRegistry::default` when `dirs::cache_dir()` returns
    /// `None`; `path` is then the placeholder `(unresolved-cache-dir)`.
    #[error("cache directory {path} is not writable")]
    CacheNotWritable { path: PathBuf },
    /// Cache-only lookup found no cached file for a model listed in the manifest.
    #[error("model '{model_id}' is not present in cache and offline mode is requested")]
    OfflineMissing { model_id: String },
    /// Wraps a [`ManifestError`] (converted via `From`).
    #[error("manifest error: {0}")]
    Manifest(#[from] ManifestError),
    /// Wraps a [`DownloadError`] (converted via `From`), e.g. one propagated from
    /// `download_with_checksum_and_signature` in `ModelRegistry::ensure`.
    #[error("download error: {0}")]
    Download(#[from] DownloadError),
    /// A filesystem operation on `path` failed. In this file: creating the cache directory.
    #[error("io error on {path}: {source}")]
    Io {
        /// Path the failing operation was applied to.
        path: PathBuf,
        /// Underlying I/O error, exposed as this error's `source()`.
        #[source]
        source: std::io::Error,
    },
}
/// Cached file paths of the two models a profile needs, as returned by
/// `ModelRegistry::ensure_for_profile`.
#[derive(Debug, Clone)]
pub struct ProfileModels {
    /// Cache path of the profile's `segmenter` model.
    pub segmenter_path: PathBuf,
    /// Cache path of the profile's `embedder` model.
    pub embedder_path: PathBuf,
}
/// Resolves model ids and profiles through a [`Manifest`] to files in a local
/// cache directory, fetching and verifying them on demand via `ensure`.
#[derive(Debug, Clone)]
pub struct ModelRegistry {
    /// Manifest describing known models (URL, checksum, filename) and profiles.
    manifest: Manifest,
    /// Directory cached model files are stored in; created on construction.
    cache_dir: PathBuf,
}
impl ModelRegistry {
#[allow(clippy::should_implement_trait)]
pub fn default() -> Result<Self, RegistryError> {
let cache = dirs::cache_dir()
.ok_or_else(|| RegistryError::CacheNotWritable {
path: PathBuf::from("(unresolved-cache-dir)"),
})?
.join("polyvoice")
.join("models");
Self::with_cache_dir(cache)
}
pub fn with_cache_dir(path: impl AsRef<Path>) -> Result<Self, RegistryError> {
let path = path.as_ref().to_path_buf();
std::fs::create_dir_all(&path).map_err(|e| RegistryError::Io {
path: path.clone(),
source: e,
})?;
Ok(Self {
manifest: default_manifest(),
cache_dir: path,
})
}
#[cfg(test)]
pub fn with_manifest_override(mut self, manifest: Manifest) -> Self {
self.manifest = manifest;
self
}
#[cfg(test)]
pub fn with_manifest(
manifest: Manifest,
cache_dir: impl AsRef<Path>,
) -> Result<Self, RegistryError> {
let path = cache_dir.as_ref().to_path_buf();
std::fs::create_dir_all(&path).map_err(|e| RegistryError::Io {
path: path.clone(),
source: e,
})?;
Ok(Self {
manifest,
cache_dir: path,
})
}
pub fn cache_dir(&self) -> &Path {
&self.cache_dir
}
pub fn manifest(&self) -> &Manifest {
&self.manifest
}
pub fn ensure(&self, model_id: &str) -> Result<PathBuf, RegistryError> {
let entry = self
.manifest
.model(model_id)
.ok_or_else(|| RegistryError::ModelNotFound {
model_id: model_id.to_owned(),
})?;
let dest = self.cache_dir.join(&entry.filename);
download_with_checksum_and_signature(
&entry.url,
&entry.sha256,
entry.signature.as_deref(),
&dest,
)?;
Ok(dest)
}
#[doc(hidden)]
pub fn ensure_in_cache_only(&self, model_id: &str) -> Result<PathBuf, RegistryError> {
let entry = self
.manifest
.model(model_id)
.ok_or_else(|| RegistryError::ModelNotFound {
model_id: model_id.to_owned(),
})?;
let dest = self.cache_dir.join(&entry.filename);
if !dest.exists() {
return Err(RegistryError::OfflineMissing {
model_id: model_id.to_owned(),
});
}
Ok(dest)
}
pub fn ensure_for_profile(&self, profile: Profile) -> Result<ProfileModels, RegistryError> {
if profile == Profile::Custom {
return Err(RegistryError::CustomProfileUnresolvable);
}
let prof = self
.manifest
.profile(profile.manifest_id())
.ok_or_else(|| RegistryError::ProfileNotFound {
profile: profile.manifest_id().to_owned(),
})?;
let segmenter_path = self.ensure(&prof.segmenter)?;
let embedder_path = self.ensure(&prof.embedder)?;
Ok(ProfileModels {
segmenter_path,
embedder_path,
})
}
#[cfg(test)]
pub fn ensure_in_cache_only_for_profile(
&self,
profile: Profile,
) -> Result<ProfileModels, RegistryError> {
if profile == Profile::Custom {
return Err(RegistryError::CustomProfileUnresolvable);
}
let prof = self
.manifest
.profile(profile.manifest_id())
.ok_or_else(|| RegistryError::ProfileNotFound {
profile: profile.manifest_id().to_owned(),
})?;
let segmenter_path = self.ensure_in_cache_only(&prof.segmenter)?;
let embedder_path = self.ensure_in_cache_only(&prof.embedder)?;
Ok(ProfileModels {
segmenter_path,
embedder_path,
})
}
}
/// Fixtures shared by this crate's tests.
#[cfg(test)]
pub(crate) mod tests_helpers {
    /// Minimal manifest with `mobile` and `balanced` profiles. Both use
    /// `hello_model` as segmenter and embedder.
    ///
    /// `hello_model` is cached as `hello.bin`. Its `sha256` and `size` describe
    /// the 5-byte content `hello`, which is what the tests write to that file.
    /// NOTE(review): the `file:///dev/null` URL looks like a placeholder; the
    /// visible tests only use cache-only lookups, so it is not expected to be
    /// fetched. Confirm before using this fixture with `ensure`.
    pub const TINY_MANIFEST: &str = r#"
schema = "polyvoice-models-v1"
[profiles.mobile]
segmenter = "hello_model"
embedder = "hello_model"
[profiles.balanced]
segmenter = "hello_model"
embedder = "hello_model"
[models.hello_model]
url = "file:///dev/null"
sha256 = "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
size = 5
filename = "hello.bin"
"#;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Profile;
    use tempfile::TempDir;

    /// Registry over `dir` whose manifest is the tiny fixture
    /// (`hello_model` cached as `hello.bin`).
    fn tiny_registry(dir: &TempDir) -> ModelRegistry {
        let fixture =
            Manifest::from_toml_str(crate::models::tests_helpers::TINY_MANIFEST).unwrap();
        ModelRegistry::with_cache_dir(dir.path())
            .unwrap()
            .with_manifest_override(fixture)
    }

    #[test]
    fn embedded_manifest_parses() {
        let manifest = default_manifest();
        assert_eq!(manifest.schema, SCHEMA_V1);
        for profile_id in ["mobile", "balanced"] {
            assert!(manifest.profiles.contains_key(profile_id));
        }
    }

    #[test]
    fn embedded_manifest_lists_legacy_models() {
        let manifest = default_manifest();
        for model_id in ["silero_vad", "wespeaker_resnet34"] {
            assert!(manifest.models.contains_key(model_id));
        }
    }

    #[test]
    fn profiles_share_segmenter_diverge_on_embedder_in_m5() {
        let manifest = default_manifest();
        let mobile = manifest.profile("mobile").unwrap();
        let balanced = manifest.profile("balanced").unwrap();
        assert_eq!(mobile.segmenter, balanced.segmenter);
        assert_ne!(mobile.embedder, balanced.embedder);
    }

    #[test]
    fn registry_default_uses_user_cache() {
        // Resolves (and creates) the real per-user cache directory.
        let registry = ModelRegistry::default().expect("default cache dir resolvable");
        let cache = registry.cache_dir();
        let posix_style = cache.ends_with("polyvoice/models");
        let windows_style = cache.ends_with("polyvoice\\models");
        assert!(posix_style || windows_style);
    }

    #[test]
    fn registry_with_cache_dir_creates_dir() {
        let tmp = TempDir::new().unwrap();
        let nested = tmp.path().join("nested/models");
        let registry = ModelRegistry::with_cache_dir(&nested).unwrap();
        assert!(nested.exists());
        assert_eq!(registry.cache_dir(), nested.as_path());
    }

    #[test]
    fn ensure_returns_err_for_unknown_model_id() {
        let tmp = TempDir::new().unwrap();
        let registry = ModelRegistry::with_cache_dir(tmp.path()).unwrap();
        // Cache-only lookup keeps the test offline.
        let err = registry
            .ensure_in_cache_only("ghost")
            .expect_err("must be missing");
        assert!(matches!(err, RegistryError::ModelNotFound { .. }));
    }

    #[test]
    fn ensure_in_cache_only_succeeds_when_file_present() {
        let tmp = TempDir::new().unwrap();
        let registry = tiny_registry(&tmp);
        let cached = tmp.path().join("hello.bin");
        std::fs::write(&cached, b"hello").unwrap();
        let resolved = registry.ensure_in_cache_only("hello_model").unwrap();
        assert_eq!(resolved, cached);
    }

    #[test]
    fn ensure_for_profile_uses_manifest_lookup() {
        let tmp = TempDir::new().unwrap();
        let registry = tiny_registry(&tmp);
        let cached = tmp.path().join("hello.bin");
        std::fs::write(&cached, b"hello").unwrap();
        let bundle = registry
            .ensure_in_cache_only_for_profile(Profile::Mobile)
            .unwrap();
        assert_eq!(bundle.segmenter_path, cached);
        assert_eq!(bundle.embedder_path, cached);
    }
}