use std::path::{Path, PathBuf};
// Environment variable whose value, when set, is passed to the hub API builder as the endpoint.
const HF_ENDPOINT_ENV: &str = "HF_ENDPOINT";
// Environment variables consulted for an access token, declared in the order `hf_token` tries them.
// NOTE(review): upstream HuggingFace tooling spells the first one `HUGGING_FACE_HUB_TOKEN`
// (underscore between HUGGING and FACE) — confirm this spelling is intentional.
const HUGGINGFACE_HUB_TOKEN_ENV: &str = "HUGGINGFACE_HUB_TOKEN";
const HF_TOKEN_ENV: &str = "HF_TOKEN";
const HUGGINGFACE_TOKEN_ENV: &str = "HUGGINGFACE_TOKEN";
/// Describes where a model comes from: a local file, a single file in a
/// HuggingFace repository, or a directory inside a HuggingFace repository.
///
/// Build one with [`ModelSource::from_file`], [`ModelSource::from_hf`] or
/// [`ModelSource::from_hf_dir`], then call [`ModelSource::resolve`] to obtain
/// a local path.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelSource {
// Private: outside code reads it through the constructors and accessors on `ModelSource`.
kind: ModelSourceKind,
}
/// The concrete kinds of source that a [`ModelSource`] can wrap.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ModelSourceKind {
/// A file on the local filesystem.
File {
/// Path exactly as supplied to `ModelSource::from_file`.
path: PathBuf,
},
/// A single file stored in a HuggingFace repository.
HuggingFace {
/// Repository identifier handed to the hub client.
repo_id: String,
/// Name of the file to fetch from the repository.
filename: String,
/// Revision to fetch; when `None` the resolver falls back to `"main"`.
revision: Option<String>,
},
/// Every file below a directory of a HuggingFace repository.
HuggingFaceDir {
/// Repository identifier handed to the hub client.
repo_id: String,
/// Directory inside the repository whose files are fetched.
directory: String,
/// Revision to fetch; when `None` the resolver falls back to `"main"`.
revision: Option<String>,
},
}
impl ModelSource {
    /// Creates a source backed by a file on the local filesystem.
    ///
    /// The path is stored as given; whether it exists is only checked by
    /// [`ModelSource::resolve`].
    pub fn from_file(path: impl Into<PathBuf>) -> Self {
        Self {
            kind: ModelSourceKind::File { path: path.into() },
        }
    }

    /// Creates a source pointing at a single file of a HuggingFace repository.
    ///
    /// No revision is set; use [`ModelSource::with_revision`] to pin one.
    pub fn from_hf(repo_id: impl Into<String>, filename: impl Into<String>) -> Self {
        Self {
            kind: ModelSourceKind::HuggingFace {
                repo_id: repo_id.into(),
                filename: filename.into(),
                revision: None,
            },
        }
    }

    /// Creates a source pointing at a directory of a HuggingFace repository.
    ///
    /// No revision is set; use [`ModelSource::with_revision`] to pin one.
    pub fn from_hf_dir(repo_id: impl Into<String>, directory: impl Into<String>) -> Self {
        Self {
            kind: ModelSourceKind::HuggingFaceDir {
                repo_id: repo_id.into(),
                directory: directory.into(),
                revision: None,
            },
        }
    }

    /// Returns the source with its HuggingFace revision set to `revision`.
    ///
    /// Local file sources carry no revision, so for them the source is
    /// returned unchanged and `revision` is discarded.
    #[must_use = "with_revision returns the updated source instead of mutating in place"]
    pub fn with_revision(mut self, revision: impl Into<String>) -> Self {
        match &mut self.kind {
            ModelSourceKind::HuggingFace { revision: slot, .. }
            | ModelSourceKind::HuggingFaceDir { revision: slot, .. } => {
                *slot = Some(revision.into());
            }
            // Listed explicitly (no `_` arm) so that adding a variant forces a decision here.
            ModelSourceKind::File { .. } => {}
        }
        self
    }

    /// Resolves the source to a path on the local filesystem.
    ///
    /// A local file source yields its own path. HuggingFace sources are
    /// delegated to `resolve_hf` / `resolve_hf_dir`.
    ///
    /// # Errors
    ///
    /// Returns [`ModelSourceError::MissingLocalFile`] when a local source does
    /// not point at an existing file, and otherwise whatever error the
    /// HuggingFace resolver reports.
    pub fn resolve(&self) -> Result<PathBuf, ModelSourceError> {
        match &self.kind {
            ModelSourceKind::File { path } => {
                if path.is_file() {
                    Ok(path.clone())
                } else {
                    Err(ModelSourceError::MissingLocalFile(path.clone()))
                }
            }
            ModelSourceKind::HuggingFace {
                repo_id,
                filename,
                revision,
            } => resolve_hf(repo_id, filename, revision.as_deref()),
            ModelSourceKind::HuggingFaceDir {
                repo_id,
                directory,
                revision,
            } => resolve_hf_dir(repo_id, directory, revision.as_deref()),
        }
    }

    /// Returns the local path for file sources, `None` for HuggingFace sources.
    pub fn local_path(&self) -> Option<&Path> {
        match &self.kind {
            ModelSourceKind::File { path } => Some(path.as_path()),
            ModelSourceKind::HuggingFace { .. } | ModelSourceKind::HuggingFaceDir { .. } => None,
        }
    }

    /// Returns the repository id for HuggingFace sources, `None` for local files.
    pub fn repo_id(&self) -> Option<&str> {
        match &self.kind {
            ModelSourceKind::HuggingFace { repo_id, .. }
            | ModelSourceKind::HuggingFaceDir { repo_id, .. } => Some(repo_id.as_str()),
            ModelSourceKind::File { .. } => None,
        }
    }

    /// Returns the file name for single-file HuggingFace sources, `None` otherwise.
    pub fn filename(&self) -> Option<&str> {
        match &self.kind {
            ModelSourceKind::HuggingFace { filename, .. } => Some(filename.as_str()),
            ModelSourceKind::File { .. } | ModelSourceKind::HuggingFaceDir { .. } => None,
        }
    }

    /// Returns the directory for HuggingFace directory sources, `None` otherwise.
    pub fn directory(&self) -> Option<&str> {
        match &self.kind {
            ModelSourceKind::HuggingFaceDir { directory, .. } => Some(directory.as_str()),
            ModelSourceKind::File { .. } | ModelSourceKind::HuggingFace { .. } => None,
        }
    }
}
/// Errors produced while resolving a [`ModelSource`] to a local path.
#[derive(Debug, thiserror::Error)]
pub enum ModelSourceError {
/// A local file source does not point at an existing file; carries the offending path.
#[error("Model file not found: {0}")]
MissingLocalFile(PathBuf),
/// A HuggingFace source was resolved in a build without the `model-hf` feature.
#[error("HuggingFace support is not enabled; enable the `model-hf` feature")]
HuggingFaceDisabled,
/// Building the hub client or fetching from the hub failed; carries the
/// underlying error rendered as text.
#[error("HuggingFace download failed: {0}")]
HuggingFaceDownload(String),
/// The repository id of a HuggingFace source is empty.
#[error("HuggingFace repo id is required")]
MissingRepoId,
/// The file name of a single-file HuggingFace source is empty.
#[error("HuggingFace filename is required")]
MissingFilename,
/// The directory of a HuggingFace directory source is empty, or no file in
/// the repository lies below it.
#[error("HuggingFace directory is required")]
MissingDirectory,
}
/// Fetches a single file from a HuggingFace model repository and returns the
/// local path reported by the hub client.
///
/// The client is built from the environment-configured cache, uses the
/// endpoint from `HF_ENDPOINT` when that variable is set, and uses the token
/// from `hf_token` when one is available. A missing `revision` means `"main"`.
///
/// # Errors
///
/// * [`ModelSourceError::MissingRepoId`] when `repo_id` is empty.
/// * [`ModelSourceError::MissingFilename`] when `filename` is empty.
/// * [`ModelSourceError::HuggingFaceDownload`] when building the client or
///   fetching the file fails.
#[cfg(feature = "model-hf")]
fn resolve_hf(
    repo_id: &str,
    filename: &str,
    revision: Option<&str>,
) -> Result<PathBuf, ModelSourceError> {
    use hf_hub::api::sync::ApiBuilder;
    use hf_hub::{Cache, Repo, RepoType};

    /// Wraps any displayable hub error into the download error variant.
    fn download_error(err: impl std::fmt::Display) -> ModelSourceError {
        ModelSourceError::HuggingFaceDownload(err.to_string())
    }

    if repo_id.is_empty() {
        return Err(ModelSourceError::MissingRepoId);
    }
    if filename.is_empty() {
        return Err(ModelSourceError::MissingFilename);
    }

    // Configure the client: cache first, then the optional endpoint and token.
    let builder = ApiBuilder::from_cache(Cache::from_env());
    let builder = match std::env::var(HF_ENDPOINT_ENV) {
        Ok(endpoint) => builder.with_endpoint(endpoint),
        Err(_) => builder,
    };
    let builder = match hf_token() {
        Some(token) => builder.with_token(Some(token)),
        None => builder,
    };
    let api = builder.build().map_err(download_error)?;

    let repo = Repo::with_revision(
        repo_id.to_string(),
        RepoType::Model,
        revision.unwrap_or("main").to_string(),
    );
    api.repo(repo).get(filename).map_err(download_error)
}
/// Fetches every file below `directory` of a HuggingFace model repository and
/// returns the local directory that corresponds to `directory`.
///
/// The repository listing is read from the hub, each file whose name starts
/// with `directory` followed by `/` is fetched, and the local directory is
/// derived from the path of the first fetched file via `derive_directory`.
/// A missing `revision` means `"main"`.
///
/// # Errors
///
/// * [`ModelSourceError::MissingRepoId`] when `repo_id` is empty.
/// * [`ModelSourceError::MissingDirectory`] when `directory` is empty or no
///   file in the repository lies below it.
/// * [`ModelSourceError::HuggingFaceDownload`] when building the client,
///   reading the listing, or fetching a file fails.
#[cfg(feature = "model-hf")]
fn resolve_hf_dir(
    repo_id: &str,
    directory: &str,
    revision: Option<&str>,
) -> Result<PathBuf, ModelSourceError> {
    use hf_hub::api::sync::ApiBuilder;
    use hf_hub::{Cache, Repo, RepoType};

    /// Wraps any displayable hub error into the download error variant.
    fn download_error(err: impl std::fmt::Display) -> ModelSourceError {
        ModelSourceError::HuggingFaceDownload(err.to_string())
    }

    if repo_id.is_empty() {
        return Err(ModelSourceError::MissingRepoId);
    }
    if directory.is_empty() {
        return Err(ModelSourceError::MissingDirectory);
    }

    // Configure the client: cache first, then the optional endpoint and token.
    let builder = ApiBuilder::from_cache(Cache::from_env());
    let builder = match std::env::var(HF_ENDPOINT_ENV) {
        Ok(endpoint) => builder.with_endpoint(endpoint),
        Err(_) => builder,
    };
    let builder = match hf_token() {
        Some(token) => builder.with_token(Some(token)),
        None => builder,
    };
    let api = builder.build().map_err(download_error)?;

    let repo = Repo::with_revision(
        repo_id.to_string(),
        RepoType::Model,
        revision.unwrap_or("main").to_string(),
    );
    let api_repo = api.repo(repo);
    let info = api_repo.info().map_err(download_error)?;

    // Always match on "<directory>/" so that a sibling such as "weights2/x"
    // is not picked up when "weights" was requested.
    let prefix = if directory.ends_with('/') {
        directory.to_string()
    } else {
        format!("{directory}/")
    };

    // Stays `None` until the first matching file has been fetched.
    let mut local_dir: Option<PathBuf> = None;
    let wanted = info
        .siblings
        .into_iter()
        .map(|sibling| sibling.rfilename)
        .filter(|remote_name| remote_name.starts_with(&prefix));
    for remote_name in wanted {
        let cached = api_repo.get(&remote_name).map_err(download_error)?;
        if local_dir.is_none() {
            local_dir = Some(derive_directory(&cached, &prefix, &remote_name));
        }
    }

    local_dir.ok_or(ModelSourceError::MissingDirectory)
}
/// Stand-in used when the `model-hf` feature is disabled: ignores its
/// arguments and always fails with [`ModelSourceError::HuggingFaceDisabled`].
#[cfg(not(feature = "model-hf"))]
fn resolve_hf_dir(
_repo_id: &str,
_directory: &str,
_revision: Option<&str>,
) -> Result<PathBuf, ModelSourceError> {
Err(ModelSourceError::HuggingFaceDisabled)
}
#[cfg(feature = "model-hf")]
/// Walks up from `path` to the local directory that corresponds to `directory`.
///
/// `rfilename` is the repository-relative name of the file stored at `path`
/// and `directory` is the repository-relative directory being resolved. The
/// number of levels to climb is how many more path components `rfilename` has
/// than `directory` (never negative). Climbing stops early once `path` has no
/// parent left.
fn derive_directory(path: &Path, directory: &str, rfilename: &str) -> PathBuf {
    let depth_of = |relative: &str| Path::new(relative).components().count();
    let levels_up = depth_of(rfilename).saturating_sub(depth_of(directory));

    // `ancestors` yields `path` itself first, so the wanted entry is at index
    // `levels_up`; `take(..).last()` falls back to the top-most ancestor when
    // the path is shallower than that.
    path.ancestors()
        .take(levels_up.saturating_add(1))
        .last()
        .unwrap_or(path)
        .to_path_buf()
}
/// Stand-in used when the `model-hf` feature is disabled: ignores its
/// arguments and always fails with [`ModelSourceError::HuggingFaceDisabled`].
#[cfg(not(feature = "model-hf"))]
fn resolve_hf(
_repo_id: &str,
_filename: &str,
_revision: Option<&str>,
) -> Result<PathBuf, ModelSourceError> {
Err(ModelSourceError::HuggingFaceDisabled)
}
/// Looks up a HuggingFace access token in the environment.
///
/// The variables are tried in order: `HUGGINGFACE_HUB_TOKEN`, `HF_TOKEN`,
/// `HUGGINGFACE_TOKEN`. Surrounding whitespace is trimmed, and a variable that
/// is unset, not valid Unicode, or blank after trimming is skipped, so an
/// empty `export HF_TOKEN=` neither hides a later variable nor produces an
/// empty token. Returns `None` when no variable yields a usable token.
#[cfg(feature = "model-hf")]
fn hf_token() -> Option<String> {
    [HUGGINGFACE_HUB_TOKEN_ENV, HF_TOKEN_ENV, HUGGINGFACE_TOKEN_ENV]
        .iter()
        // Lazy: later variables are only read when earlier ones gave nothing usable.
        .filter_map(|name| std::env::var(name).ok())
        .map(|raw| raw.trim().to_owned())
        .find(|token| !token.is_empty())
}
/// Unit tests for constructing and resolving [`ModelSource`] values.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// A file-backed source exposes its path and has no repository id.
    #[test]
    fn from_file_tracks_path() {
        let local = ModelSource::from_file("model.onnx");
        assert_eq!(local.local_path(), Some(Path::new("model.onnx")));
        assert!(local.repo_id().is_none());
    }

    /// Resolving a path that does not exist reports that path back.
    #[test]
    fn resolve_missing_file_returns_error() {
        let err = ModelSource::from_file("missing.onnx")
            .resolve()
            .unwrap_err();
        if let ModelSourceError::MissingLocalFile(path) = &err {
            assert_eq!(path, &PathBuf::from("missing.onnx"));
        } else {
            panic!("unexpected error: {err:?}");
        }
    }

    /// Resolving an existing file returns the very same path.
    #[test]
    fn resolve_existing_file() {
        let mut scratch = tempfile::NamedTempFile::new().unwrap();
        writeln!(scratch, "test").unwrap();
        let expected = scratch.path().to_path_buf();
        let resolved = ModelSource::from_file(&expected).resolve().unwrap();
        assert_eq!(resolved, expected);
    }

    /// A single-file hub source exposes its repository id and file name.
    #[test]
    fn from_hf_tracks_repo_and_filename() {
        let remote = ModelSource::from_hf("org/model", "model.onnx");
        assert_eq!(remote.repo_id(), Some("org/model"));
        assert_eq!(remote.filename(), Some("model.onnx"));
    }

    /// A directory hub source exposes repository id and directory, but no file name.
    #[test]
    fn from_hf_dir_tracks_repo_and_directory() {
        let remote = ModelSource::from_hf_dir("org/model", "weights");
        assert_eq!(remote.repo_id(), Some("org/model"));
        assert_eq!(remote.directory(), Some("weights"));
        assert!(remote.filename().is_none());
    }

    /// Without the `model-hf` feature a hub file cannot be resolved.
    #[test]
    #[cfg(not(feature = "model-hf"))]
    fn resolve_hf_requires_feature() {
        let err = ModelSource::from_hf("org/model", "model.onnx")
            .resolve()
            .unwrap_err();
        assert!(
            matches!(err, ModelSourceError::HuggingFaceDisabled),
            "unexpected error: {err:?}"
        );
    }

    /// Without the `model-hf` feature a hub directory cannot be resolved.
    #[test]
    #[cfg(not(feature = "model-hf"))]
    fn resolve_hf_dir_requires_feature() {
        let err = ModelSource::from_hf_dir("org/model", "weights")
            .resolve()
            .unwrap_err();
        assert!(
            matches!(err, ModelSourceError::HuggingFaceDisabled),
            "unexpected error: {err:?}"
        );
    }
}