use serde::Deserialize;
use std::collections::HashMap;
/// Schema identifier that a manifest's top-level `schema` key must equal for
/// [`Manifest::from_toml_str`] to accept it.
pub const SCHEMA_V1: &str = "polyvoice-models-v1";
/// Parsed contents of a model manifest.
///
/// Build one with [`Manifest::from_toml_str`], which also validates the schema
/// version, profile-to-model references and model hashes. Deserializing this type
/// directly through serde performs none of those checks.
#[derive(Debug, Clone, Deserialize)]
pub struct Manifest {
/// Schema identifier; `from_toml_str` rejects anything other than [`SCHEMA_V1`].
pub schema: String,
/// Profiles keyed by profile id. Empty when the `profiles` table is absent.
#[serde(default)]
pub profiles: HashMap<String, ProfileEntry>,
/// Models keyed by model id. Empty when the `models` table is absent.
#[serde(default)]
pub models: HashMap<String, ModelEntry>,
}
/// A named profile: the pair of model ids it is built from.
///
/// Both ids are expected to be keys of [`Manifest::models`];
/// `Manifest::from_toml_str` returns [`ManifestError::DanglingModelRef`] otherwise.
#[derive(Debug, Clone, Deserialize)]
pub struct ProfileEntry {
/// Model id of the segmenter model.
pub segmenter: String,
/// Model id of the embedder model.
pub embedder: String,
}
/// One entry of the manifest's `models` table.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelEntry {
/// URL of the model artifact.
/// NOTE(review): presumably the download source; it is not used in this file.
pub url: String,
/// Expected SHA-256 digest; `Manifest::from_toml_str` requires exactly 64
/// lowercase hex characters.
/// NOTE(review): presumably the digest of the file at `url`; verify against the fetcher.
pub sha256: String,
/// Optional size. `None` when the key is absent.
/// NOTE(review): units are not stated here; the fixtures suggest bytes. Confirm.
#[serde(default)]
pub size: Option<u64>,
/// File name for the artifact.
/// NOTE(review): presumably the on-disk name once fetched; confirm with callers.
pub filename: String,
/// Optional calibration value. Its format and meaning are not defined in this file.
#[serde(default)]
pub calibration: Option<String>,
/// Optional signature string. It is not validated by `Manifest::from_toml_str`.
#[serde(default)]
pub signature: Option<String>,
}
/// Errors returned by [`Manifest::from_toml_str`].
#[derive(Debug, thiserror::Error)]
pub enum ManifestError {
/// The input is not valid TOML or does not deserialize into a [`Manifest`].
#[error("toml parse error: {0}")]
Toml(#[from] toml::de::Error),
/// The `schema` key is not [`SCHEMA_V1`]; carries the value that was found.
#[error("unsupported schema version: expected '{}', got '{0}'", SCHEMA_V1)]
UnsupportedSchema(String),
/// Profile `profile` names a model id (`model`) that has no entry in `models`.
#[error("profile '{profile}' references unknown model '{model}'")]
DanglingModelRef { profile: String, model: String },
/// Model `model` has a `sha256` that is not 64 lowercase hex characters. `sha` holds
/// the offending value, shortened for display when it is long.
#[error("model '{model}' has invalid sha256 '{sha}': expected 64 lowercase hex chars")]
InvalidSha256 { model: String, sha: String },
}
impl Manifest {
    /// Parses a manifest from TOML text and validates it.
    ///
    /// Checks run in this order and stop at the first failure:
    ///
    /// 1. `schema` must equal [`SCHEMA_V1`].
    /// 2. Each profile's `segmenter`, then its `embedder`, must be a key of `models`.
    /// 3. Each model's `sha256` must be 64 lowercase hex characters.
    ///
    /// Profiles and models are visited in ascending id order, so the reported error
    /// does not depend on `HashMap` iteration order.
    ///
    /// # Errors
    ///
    /// - [`ManifestError::Toml`] if `s` cannot be deserialized into a [`Manifest`].
    /// - [`ManifestError::UnsupportedSchema`] if the schema string is not [`SCHEMA_V1`].
    /// - [`ManifestError::DanglingModelRef`] if a profile references an unknown model id.
    /// - [`ManifestError::InvalidSha256`] if a model's hash is malformed.
    pub fn from_toml_str(s: &str) -> Result<Self, ManifestError> {
        let manifest: Manifest = toml::from_str(s)?;
        if manifest.schema != SCHEMA_V1 {
            return Err(ManifestError::UnsupportedSchema(manifest.schema));
        }
        manifest.check_profile_refs()?;
        manifest.check_model_hashes()?;
        Ok(manifest)
    }

    /// Returns the profile registered under `id`, or `None` if there is none.
    pub fn profile(&self, id: &str) -> Option<&ProfileEntry> {
        self.profiles.get(id)
    }

    /// Returns the model registered under `id`, or `None` if there is none.
    pub fn model(&self, id: &str) -> Option<&ModelEntry> {
        self.models.get(id)
    }

    /// Reports the first profile (by ascending id) that names a model id missing
    /// from `models`.
    fn check_profile_refs(&self) -> Result<(), ManifestError> {
        let mut by_id: Vec<_> = self.profiles.iter().collect();
        // Map keys are unique, so an unstable sort yields one deterministic order.
        by_id.sort_unstable_by(|a, b| a.0.cmp(b.0));
        for (profile_id, entry) in by_id {
            // Segmenter first: when both ids are missing, the segmenter is reported.
            for model_id in [&entry.segmenter, &entry.embedder] {
                if !self.models.contains_key(model_id) {
                    return Err(ManifestError::DanglingModelRef {
                        profile: profile_id.clone(),
                        model: model_id.clone(),
                    });
                }
            }
        }
        Ok(())
    }

    /// Reports the first model (by ascending id) whose `sha256` is not 64 lowercase
    /// hex characters.
    fn check_model_hashes(&self) -> Result<(), ManifestError> {
        let mut by_id: Vec<_> = self.models.iter().collect();
        by_id.sort_unstable_by(|a, b| a.0.cmp(b.0));
        let first_bad = by_id
            .into_iter()
            .find(|(_, entry)| !is_valid_sha256_hex(&entry.sha256));
        match first_bad {
            Some((model_id, entry)) => Err(ManifestError::InvalidSha256 {
                model: model_id.clone(),
                sha: truncate_for_display(&entry.sha256),
            }),
            None => Ok(()),
        }
    }
}
/// Returns `true` when `s` is exactly 64 characters, each one of `0-9` or `a-f`.
///
/// Uppercase hex digits are rejected, matching the "64 lowercase hex chars"
/// requirement stated in [`ManifestError::InvalidSha256`].
fn is_valid_sha256_hex(s: &str) -> bool {
    // Working on bytes is safe: any non-ASCII character encodes as bytes >= 0x80,
    // which fall outside both accepted ranges.
    let bytes = s.as_bytes();
    bytes.len() == 64
        && bytes
            .iter()
            .all(|&b| matches!(b, b'0'..=b'9' | b'a'..=b'f'))
}
/// Shortens `s` for inclusion in an error message.
///
/// Strings of at most 80 characters are returned unchanged. Longer strings are cut
/// to their first 72 characters, followed by a `…[N more chars]` marker, where `N`
/// is the number of characters omitted.
///
/// Lengths are measured in `char`s, not bytes. The input comes from untrusted
/// manifest text, and slicing at a fixed byte offset would panic whenever that
/// offset fell inside a multi-byte UTF-8 character.
fn truncate_for_display(s: &str) -> String {
    const MAX_CHARS: usize = 80;
    const KEEP_CHARS: usize = 72;

    let total_chars = s.chars().count();
    if total_chars <= MAX_CHARS {
        return s.to_owned();
    }
    // `total_chars > KEEP_CHARS` here, so `nth` always finds a char; the fallback is
    // only defensive. The resulting byte index is always a char boundary.
    let cut = s
        .char_indices()
        .nth(KEEP_CHARS)
        .map_or(s.len(), |(byte_idx, _)| byte_idx);
    format!("{}…[{} more chars]", &s[..cut], total_chars - KEEP_CHARS)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed manifest: two profiles that both reference the same two models.
    const SAMPLE: &str = r#"
schema = "polyvoice-models-v1"
[profiles.mobile]
segmenter = "silero_vad"
embedder = "wespeaker_resnet34"
[profiles.balanced]
segmenter = "silero_vad"
embedder = "wespeaker_resnet34"
[models.silero_vad]
url = "https://example.com/silero_vad.onnx"
sha256 = "abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234"
size = 2300000
filename = "silero_vad.onnx"
[models.wespeaker_resnet34]
url = "https://example.com/wespeaker.onnx"
sha256 = "11112222333344445555666677778888aaaabbbbccccddddeeeeffff00001111"
size = 25000000
filename = "wespeaker_resnet34.onnx"
"#;

    /// The `sha256` of `silero_vad` in [`SAMPLE`]; tests replace it to inject bad hashes.
    const SILERO_SHA: &str = "abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234";

    #[test]
    fn parse_known_good_manifest() {
        let m = Manifest::from_toml_str(SAMPLE).expect("must parse");
        assert_eq!(m.schema, "polyvoice-models-v1");
        assert_eq!(m.profiles.len(), 2);
        assert_eq!(m.models.len(), 2);
        assert_eq!(m.profiles["mobile"].segmenter, "silero_vad");
        assert_eq!(m.models["silero_vad"].size, Some(2300000));
        assert_eq!(m.models["silero_vad"].filename, "silero_vad.onnx");
    }

    #[test]
    fn optional_model_fields_default_to_none() {
        let m = Manifest::from_toml_str(SAMPLE).expect("must parse");
        let seg = &m.models["silero_vad"];
        assert!(seg.calibration.is_none());
        assert!(seg.signature.is_none());
    }

    #[test]
    fn rejects_unknown_schema_version() {
        let bad = SAMPLE.replace("polyvoice-models-v1", "polyvoice-models-v999");
        let err = Manifest::from_toml_str(&bad).expect_err("must fail");
        // Pin the variant and its payload: a bare substring check on "schema" would
        // also pass for unrelated errors that mention the word.
        assert!(
            matches!(&err, ManifestError::UnsupportedSchema(got) if got == "polyvoice-models-v999"),
            "unexpected error: {err}"
        );
        assert!(format!("{err}").contains("v999"));
    }

    #[test]
    fn profile_lookup_resolves_to_models() {
        let m = Manifest::from_toml_str(SAMPLE).unwrap();
        let prof = m.profile("mobile").expect("mobile profile present");
        let seg = m.model(&prof.segmenter).expect("segmenter resolved");
        assert_eq!(seg.filename, "silero_vad.onnx");
    }

    #[test]
    fn missing_profile_returns_none() {
        let m = Manifest::from_toml_str(SAMPLE).unwrap();
        assert!(m.profile("nope").is_none());
    }

    #[test]
    fn rejects_profile_with_missing_model_reference() {
        // The hash "abc" is also invalid. Reporting the dangling reference shows that
        // profile references are checked before model hashes.
        let bad = r#"
schema = "polyvoice-models-v1"
[profiles.mobile]
segmenter = "ghost_model"
embedder = "ghost_model"
[models.silero_vad]
url = "https://example.com/x"
sha256 = "abc"
filename = "silero_vad.onnx"
"#;
        let err = Manifest::from_toml_str(bad).expect_err("must fail");
        assert!(
            matches!(
                &err,
                ManifestError::DanglingModelRef { profile, model }
                    if profile == "mobile" && model == "ghost_model"
            ),
            "unexpected error: {err}"
        );
        assert!(format!("{err}").to_lowercase().contains("ghost_model"));
    }

    #[test]
    fn rejects_dangling_embedder_reference() {
        let bad = r#"
schema = "polyvoice-models-v1"
[profiles.mobile]
segmenter = "silero_vad"
embedder = "ghost_embedder"
[models.silero_vad]
url = "https://example.com/silero_vad.onnx"
sha256 = "abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234abcd1234"
filename = "silero_vad.onnx"
"#;
        let err = Manifest::from_toml_str(bad).expect_err("must fail");
        assert!(
            matches!(
                &err,
                ManifestError::DanglingModelRef { profile, model }
                    if profile == "mobile" && model == "ghost_embedder"
            ),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn rejects_invalid_sha256_length() {
        let bad = SAMPLE.replace(SILERO_SHA, "tooshort");
        let err = Manifest::from_toml_str(&bad).expect_err("must fail");
        assert!(
            matches!(
                &err,
                ManifestError::InvalidSha256 { model, sha }
                    if model == "silero_vad" && sha == "tooshort"
            ),
            "unexpected error: {err}"
        );
        assert!(format!("{err}").to_lowercase().contains("sha256"));
    }

    #[test]
    fn rejects_uppercase_sha256() {
        let bad = SAMPLE.replace(SILERO_SHA, &SILERO_SHA.to_ascii_uppercase());
        let err = Manifest::from_toml_str(&bad).expect_err("uppercase hex must be rejected");
        assert!(
            matches!(&err, ManifestError::InvalidSha256 { model, .. } if model == "silero_vad"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn long_invalid_sha256_is_truncated_in_error() {
        let bad = SAMPLE.replace(SILERO_SHA, &"x".repeat(100));
        let err = Manifest::from_toml_str(&bad).expect_err("must fail");
        let expected = format!("{}…[28 more chars]", "x".repeat(72));
        assert!(
            matches!(
                &err,
                ManifestError::InvalidSha256 { model, sha }
                    if model == "silero_vad" && *sha == expected
            ),
            "unexpected error: {err}"
        );
    }
}