use std::path::PathBuf;
/// Describes where a model is obtained from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModelSource {
    /// A model identified by `id` in the registry, optionally for a specific `platform`.
    Registry {
        /// Registry identifier of the model (e.g. `"kokoro-82m"`).
        id: String,
        /// Target platform, or `None` when not specified.
        platform: Option<String>,
    },
    /// A registry entry addressed by an explicit `url`, `model_id` and `version`.
    #[deprecated(since = "0.0.17", note = "Use ModelSource::Registry instead")]
    LegacyRegistry {
        url: String,
        model_id: String,
        version: String,
        platform: Option<String>,
    },
    /// A model bundle located at `path`.
    Bundle { path: PathBuf },
    /// A model stored in the directory at `path`.
    Directory { path: PathBuf },
    /// A model in the Hugging Face repository `repo` (e.g. `"org/name"`).
    HuggingFace {
        /// Repository name.
        repo: String,
        /// Optional revision (e.g. `"v1.0"`).
        revision: Option<String>,
        /// Optional variant selector (e.g. `"Q8_0"`).
        variant: Option<String>,
    },
}

impl ModelSource {
    /// Creates a [`ModelSource::Registry`] for `id` with no platform.
    pub fn registry(id: impl Into<String>) -> Self {
        ModelSource::Registry {
            id: id.into(),
            platform: None,
        }
    }

    /// Creates a [`ModelSource::Registry`] for `id` targeting `platform`.
    pub fn registry_with_platform(id: impl Into<String>, platform: impl Into<String>) -> Self {
        ModelSource::Registry {
            id: id.into(),
            platform: Some(platform.into()),
        }
    }

    /// Creates a legacy registry source with no platform.
    #[deprecated(since = "0.0.17", note = "Use ModelSource::registry() instead")]
    // The body constructs the deprecated `LegacyRegistry` variant on purpose.
    #[allow(deprecated)]
    pub fn legacy_registry(
        url: impl Into<String>,
        model_id: impl Into<String>,
        version: impl Into<String>,
    ) -> Self {
        ModelSource::LegacyRegistry {
            url: url.into(),
            model_id: model_id.into(),
            version: version.into(),
            platform: None,
        }
    }

    /// Creates a legacy registry source targeting `platform`.
    #[deprecated(
        since = "0.0.17",
        note = "Use ModelSource::registry_with_platform() instead"
    )]
    // The body constructs the deprecated `LegacyRegistry` variant on purpose.
    #[allow(deprecated)]
    pub fn legacy_registry_with_platform(
        url: impl Into<String>,
        model_id: impl Into<String>,
        version: impl Into<String>,
        platform: impl Into<String>,
    ) -> Self {
        ModelSource::LegacyRegistry {
            url: url.into(),
            model_id: model_id.into(),
            version: version.into(),
            platform: Some(platform.into()),
        }
    }

    /// Creates a [`ModelSource::Bundle`] for `path`.
    pub fn bundle(path: impl Into<PathBuf>) -> Self {
        ModelSource::Bundle { path: path.into() }
    }

    /// Creates a [`ModelSource::Directory`] for `path`.
    pub fn directory(path: impl Into<PathBuf>) -> Self {
        ModelSource::Directory { path: path.into() }
    }

    /// Creates a [`ModelSource::HuggingFace`] for `repo` with no revision or variant.
    pub fn huggingface(repo: impl Into<String>) -> Self {
        ModelSource::HuggingFace {
            repo: repo.into(),
            revision: None,
            variant: None,
        }
    }

    /// Creates a [`ModelSource::HuggingFace`] for `repo` pinned to `revision`.
    pub fn huggingface_with_revision(repo: impl Into<String>, revision: impl Into<String>) -> Self {
        ModelSource::HuggingFace {
            repo: repo.into(),
            revision: Some(revision.into()),
            variant: None,
        }
    }

    /// Creates a [`ModelSource::HuggingFace`] for `repo` selecting `variant`.
    pub fn huggingface_with_variant(repo: impl Into<String>, variant: impl Into<String>) -> Self {
        ModelSource::HuggingFace {
            repo: repo.into(),
            revision: None,
            variant: Some(variant.into()),
        }
    }

    /// Parses `org/repo` or `org/repo:variant` into a [`ModelSource::HuggingFace`].
    ///
    /// The text after the last `:` is taken as the variant only when the text
    /// before it contains a `/` and no `://`. This keeps URL-like input and
    /// slash-less input whole as the repo. A trailing `:` with nothing after it
    /// is dropped instead of producing an empty variant.
    pub fn parse_huggingface(input: &str) -> Self {
        match input.rsplit_once(':') {
            Some((repo, variant)) if repo.contains('/') && !repo.contains("://") => {
                if variant.is_empty() {
                    // "org/repo:" — an empty variant selects nothing.
                    ModelSource::huggingface(repo)
                } else {
                    ModelSource::huggingface_with_variant(repo, variant)
                }
            }
            _ => ModelSource::huggingface(input),
        }
    }

    /// Returns a short label naming the variant (e.g. `"registry"`, `"huggingface"`).
    // Naming the deprecated `LegacyRegistry` variant is intentional.
    #[allow(deprecated)]
    pub fn source_type(&self) -> &'static str {
        match self {
            ModelSource::Registry { .. } => "registry",
            ModelSource::LegacyRegistry { .. } => "legacy_registry",
            ModelSource::Bundle { .. } => "bundle",
            ModelSource::Directory { .. } => "directory",
            ModelSource::HuggingFace { .. } => "huggingface",
        }
    }

    /// Returns the model identifier, if the source carries one.
    ///
    /// This is `id` for registry sources, `model_id` for legacy sources and
    /// `repo` for Hugging Face sources. Bundles and directories return `None`.
    // Naming the deprecated `LegacyRegistry` variant is intentional.
    #[allow(deprecated)]
    pub fn model_id(&self) -> Option<&str> {
        match self {
            ModelSource::Registry { id, .. } => Some(id),
            ModelSource::LegacyRegistry { model_id, .. } => Some(model_id),
            ModelSource::HuggingFace { repo, .. } => Some(repo),
            ModelSource::Bundle { .. } | ModelSource::Directory { .. } => None,
        }
    }

    /// Returns the version, if the source carries one.
    ///
    /// This is `version` for legacy sources and `revision` for Hugging Face
    /// sources. Every other source returns `None`.
    // Naming the deprecated `LegacyRegistry` variant is intentional.
    #[allow(deprecated)]
    pub fn version(&self) -> Option<&str> {
        match self {
            ModelSource::LegacyRegistry { version, .. } => Some(version),
            ModelSource::HuggingFace { revision, .. } => revision.as_deref(),
            ModelSource::Registry { .. }
            | ModelSource::Bundle { .. }
            | ModelSource::Directory { .. } => None,
        }
    }

    /// Returns the Hugging Face variant, or `None` for every other source.
    // Naming the deprecated `LegacyRegistry` variant is intentional.
    #[allow(deprecated)]
    pub fn variant(&self) -> Option<&str> {
        match self {
            ModelSource::HuggingFace { variant, .. } => variant.as_deref(),
            ModelSource::Registry { .. }
            | ModelSource::LegacyRegistry { .. }
            | ModelSource::Bundle { .. }
            | ModelSource::Directory { .. } => None,
        }
    }
}
/// Returns the current platform identifier as an owned `String`.
///
/// Delegates to `crate::platform::current_platform()` and converts its result
/// with `to_string()`.
pub fn detect_platform() -> String {
    let current = crate::platform::current_platform();
    current.to_string()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_registry_source() {
        let source = ModelSource::registry("kokoro-82m");
        assert_eq!(source.source_type(), "registry");
        assert_eq!(source.model_id(), Some("kokoro-82m"));
        assert_eq!(source.version(), None);
        assert_eq!(source.variant(), None);
    }

    #[test]
    fn test_registry_source_with_platform() {
        let source = ModelSource::registry_with_platform("whisper-tiny", "macos-arm64");
        assert_eq!(source.source_type(), "registry");
        assert_eq!(source.model_id(), Some("whisper-tiny"));
    }

    #[test]
    // Exercising the deprecated legacy constructors is the point of this test.
    #[allow(deprecated)]
    fn test_legacy_registry_source() {
        let source = ModelSource::legacy_registry("http://localhost:8080", "whisper", "1.0");
        assert_eq!(source.source_type(), "legacy_registry");
        assert_eq!(source.model_id(), Some("whisper"));
        assert_eq!(source.version(), Some("1.0"));
    }

    #[test]
    // Exercising the deprecated legacy constructors is the point of this test.
    #[allow(deprecated)]
    fn test_legacy_registry_source_with_platform() {
        let source = ModelSource::legacy_registry_with_platform(
            "http://localhost:8080",
            "whisper",
            "1.0",
            "linux-x86_64",
        );
        assert_eq!(source.source_type(), "legacy_registry");
        assert_eq!(source.model_id(), Some("whisper"));
        assert_eq!(source.version(), Some("1.0"));
        assert_eq!(source.variant(), None);
    }

    #[test]
    fn test_bundle_source() {
        let source = ModelSource::bundle("models/test.xyb");
        assert_eq!(source.source_type(), "bundle");
        assert_eq!(source.model_id(), None);
        assert_eq!(source.version(), None);
        assert_eq!(source.variant(), None);
    }

    #[test]
    fn test_directory_source() {
        let source = ModelSource::directory("/tmp/test-model");
        assert_eq!(source.source_type(), "directory");
        assert_eq!(source.model_id(), None);
        assert_eq!(source.version(), None);
    }

    #[test]
    fn test_huggingface_source() {
        let source = ModelSource::huggingface("xybrid-ai/kokoro-82m");
        assert_eq!(source.source_type(), "huggingface");
        assert_eq!(source.model_id(), Some("xybrid-ai/kokoro-82m"));
        assert_eq!(source.version(), None);
        assert_eq!(source.variant(), None);
    }

    #[test]
    fn test_huggingface_source_with_revision() {
        let source = ModelSource::huggingface_with_revision("xybrid-ai/kokoro-82m", "v1.0");
        assert_eq!(source.source_type(), "huggingface");
        assert_eq!(source.model_id(), Some("xybrid-ai/kokoro-82m"));
        assert_eq!(source.version(), Some("v1.0"));
    }

    #[test]
    fn test_huggingface_source_with_variant() {
        let source = ModelSource::huggingface_with_variant("LiquidAI/LFM2.5-350M-GGUF", "Q8_0");
        assert_eq!(source.model_id(), Some("LiquidAI/LFM2.5-350M-GGUF"));
        assert_eq!(source.variant(), Some("Q8_0"));
    }

    #[test]
    fn test_parse_huggingface_with_variant() {
        let source = ModelSource::parse_huggingface("LiquidAI/LFM2.5-350M-GGUF:Q8_0");
        assert_eq!(source.model_id(), Some("LiquidAI/LFM2.5-350M-GGUF"));
        assert_eq!(source.variant(), Some("Q8_0"));
    }

    #[test]
    fn test_parse_huggingface_without_variant() {
        let source = ModelSource::parse_huggingface("xybrid-ai/kokoro-82m");
        assert_eq!(source.model_id(), Some("xybrid-ai/kokoro-82m"));
        assert_eq!(source.variant(), None);
    }

    #[test]
    fn test_parse_huggingface_keeps_url_like_input_whole() {
        // A `://` before the last colon means the colon is not a variant separator.
        let source = ModelSource::parse_huggingface("https://hf.co/org/repo:Q8_0");
        assert_eq!(source.model_id(), Some("https://hf.co/org/repo:Q8_0"));
        assert_eq!(source.variant(), None);
    }

    #[test]
    fn test_parse_huggingface_without_slash_keeps_input_whole() {
        let source = ModelSource::parse_huggingface("repo:variant");
        assert_eq!(source.model_id(), Some("repo:variant"));
        assert_eq!(source.variant(), None);
    }

    #[test]
    fn test_detect_platform() {
        let platform = detect_platform();
        assert!(!platform.is_empty());
    }
}