use std::path::{Path, PathBuf};
/// Where to load a model from: a Hugging Face repository or a local path.
///
/// Usually built from user input with [`ModelSource::parse`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModelSource {
    /// A Hugging Face repository.
    HuggingFace {
        /// Repository id in `owner/name` form, e.g. `meta-llama/Llama-3.1-8B-Instruct`.
        repo_id: String,
        /// Optional revision to pin (presumably a branch, tag or commit — confirm
        /// against the download code). [`ModelSource::parse`] always sets `None`.
        revision: Option<String>,
    },
    /// A model location on the local filesystem.
    Local(PathBuf),
}

impl ModelSource {
    /// Classifies a user-supplied model specifier.
    ///
    /// The input is trimmed, then resolved in this order:
    ///
    /// 1. Anything that looks like a filesystem path becomes [`ModelSource::Local`]:
    ///    rooted paths (`/tmp/model`, and `C:\...` on Windows), `.` / `..`, anything
    ///    starting with `./`, `../` or `~`, and anything containing a backslash.
    ///    A leading `~` / `~/` is expanded to the home directory taken from `HOME`
    ///    (falling back to `USERPROFILE`); `~user/...` is kept literally.
    /// 2. Exactly one `/` (`owner/name`) becomes [`ModelSource::HuggingFace`] with
    ///    that repo id unchanged.
    /// 3. No `/` at all is treated as an alias under the `abyo-software` organisation.
    /// 4. Two or more `/` (e.g. `runs/exp1/final`) is a relative local path, because
    ///    a repo id of the form `owner/name` can never contain a second `/`.
    ///
    /// `revision` is always `None`. NOTE(review): an empty or all-whitespace input
    /// currently maps to the repo id `abyo-software/`; consider rejecting it.
    pub fn parse(s: &str) -> Self {
        let s = s.trim();
        if Self::looks_like_path(s) {
            return ModelSource::Local(Self::expand_home(s));
        }
        match s.matches('/').count() {
            0 => ModelSource::HuggingFace {
                repo_id: format!("abyo-software/{s}"),
                revision: None,
            },
            1 => ModelSource::HuggingFace {
                repo_id: s.to_string(),
                revision: None,
            },
            // More than one separator cannot be an `owner/name` id: treat it as a path.
            _ => ModelSource::Local(PathBuf::from(s)),
        }
    }

    /// Returns the filesystem path for [`ModelSource::Local`], `None` for repositories.
    pub fn local_path(&self) -> Option<&Path> {
        match self {
            ModelSource::Local(p) => Some(p.as_path()),
            ModelSource::HuggingFace { .. } => None,
        }
    }

    /// Returns `true` when `s` is unambiguously a filesystem path rather than a
    /// repo id or alias.
    fn looks_like_path(s: &str) -> bool {
        // `has_root` covers `/...` on every platform, plus `C:\...` and
        // `\\server\share` on Windows.
        Path::new(s).has_root()
            || s == "."
            || s == ".."
            || s.starts_with("./")
            || s.starts_with("../")
            || s.starts_with('~')
            // Repo ids never contain a backslash; Windows-style relative paths do.
            || s.contains('\\')
    }

    /// Expands a leading `~`, `~/` or `~\` to the user's home directory.
    ///
    /// Anything else (including `~user/...`) is returned unchanged, as is every
    /// input when neither `HOME` nor `USERPROFILE` holds a non-empty value.
    fn expand_home(s: &str) -> PathBuf {
        let rest = if s == "~" {
            Some("")
        } else {
            s.strip_prefix("~/").or_else(|| s.strip_prefix("~\\"))
        };
        let rest = match rest {
            Some(r) => r.trim_start_matches(|c| c == '/' || c == '\\'),
            None => return PathBuf::from(s),
        };
        let home = std::env::var_os("HOME")
            .filter(|h| !h.is_empty())
            .or_else(|| std::env::var_os("USERPROFILE").filter(|h| !h.is_empty()));
        match home {
            Some(home) if rest.is_empty() => PathBuf::from(home),
            Some(home) => PathBuf::from(home).join(rest),
            None => PathBuf::from(s),
        }
    }
}
// Unit tests for `ModelSource::parse`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Pulls the repo id out of a `HuggingFace` source, panicking with `msg`
    /// when the source turned out to be local.
    fn expect_repo_id(source: ModelSource, msg: &str) -> String {
        match source {
            ModelSource::HuggingFace { repo_id, .. } => repo_id,
            ModelSource::Local(_) => panic!("{msg}"),
        }
    }

    #[test]
    fn parse_local_paths() {
        for input in ["/tmp/model", "./checkpoints/foo"] {
            let parsed = ModelSource::parse(input);
            assert!(matches!(parsed, ModelSource::Local(_)));
        }
    }

    #[test]
    fn parse_hf_repo() {
        let id = expect_repo_id(
            ModelSource::parse("meta-llama/Llama-3.1-8B-Instruct"),
            "expected HuggingFace",
        );
        assert_eq!(id, "meta-llama/Llama-3.1-8B-Instruct");
    }

    #[test]
    fn parse_alias_falls_back_to_org() {
        let id = expect_repo_id(
            ModelSource::parse("llama-3.1-8b-instruct"),
            "expected HuggingFace alias",
        );
        assert_eq!(id, "abyo-software/llama-3.1-8b-instruct");
    }
}