use serde::{Deserialize, Serialize};
/// Where a model's weights are obtained from.
///
/// Values are normally built with the constructors on this type, e.g.
/// [`ModelSource::gguf`], [`ModelSource::huggingface`],
/// [`ModelSource::huggingface_with_filename`] and
/// [`ModelSource::huggingface_with_mmproj`].
///
/// NOTE(review): the derived serde impls may encode variant indices and field
/// positions for non-self-describing formats, so reordering variants or fields
/// can change the serialized form. Append rather than reorder.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelSource {
    /// A GGUF model identified directly by `model_path`.
    Gguf {
        /// Location of the GGUF file. Presumably a filesystem path; TODO confirm
        /// against the loader that consumes it.
        model_path: String,
    },
    /// A model identified by a Hugging Face repository.
    HuggingFace {
        /// Repository identifier, e.g. `"org/model"`.
        repo_id: String,
        /// Specific file to use from the repository; `None` when not specified
        /// (as produced by [`ModelSource::huggingface`]).
        filename: Option<String>,
        /// Optional "mmproj" file from the repository; `None` when not specified.
        /// Presumably a multimodal projector paired with `filename`; confirm
        /// with the consumer of this value.
        mmproj_filename: Option<String>,
    },
}
impl ModelSource {
pub fn gguf(model_path: impl Into<String>) -> Self {
Self::Gguf {
model_path: model_path.into(),
}
}
pub fn huggingface(repo_id: impl Into<String>) -> Self {
Self::HuggingFace {
repo_id: repo_id.into(),
filename: None,
mmproj_filename: None,
}
}
pub fn huggingface_with_filename(
repo_id: impl Into<String>,
filename: impl Into<String>,
) -> Self {
Self::HuggingFace {
repo_id: repo_id.into(),
filename: Some(filename.into()),
mmproj_filename: None,
}
}
pub fn huggingface_with_mmproj(
repo_id: impl Into<String>,
filename: impl Into<String>,
mmproj_filename: impl Into<String>,
) -> Self {
Self::HuggingFace {
repo_id: repo_id.into(),
filename: Some(filename.into()),
mmproj_filename: Some(mmproj_filename.into()),
}
}
pub fn model_path(&self) -> Option<&str> {
match self {
ModelSource::Gguf { model_path } => Some(model_path),
ModelSource::HuggingFace { .. } => None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_source_path() {
        let source = ModelSource::gguf("test.gguf");
        assert_eq!(source.model_path(), Some("test.gguf"));
    }

    #[test]
    fn test_model_source_hf() {
        let source = ModelSource::huggingface("org/model");
        assert!(source.model_path().is_none());
        assert_eq!(
            source,
            ModelSource::HuggingFace {
                repo_id: "org/model".to_string(),
                filename: None,
                mmproj_filename: None,
            }
        );
    }

    /// `huggingface_with_filename` must set `filename` but leave `mmproj_filename` unset.
    #[test]
    fn test_model_source_hf_with_filename() {
        let source = ModelSource::huggingface_with_filename("org/model", "model-q4.gguf");
        assert!(source.model_path().is_none());
        assert_eq!(
            source,
            ModelSource::HuggingFace {
                repo_id: "org/model".to_string(),
                filename: Some("model-q4.gguf".to_string()),
                mmproj_filename: None,
            }
        );
    }

    /// `huggingface_with_mmproj` must populate both optional filenames.
    #[test]
    fn test_model_source_hf_with_mmproj() {
        let source =
            ModelSource::huggingface_with_mmproj("org/model", "model-q4.gguf", "mmproj.gguf");
        assert!(source.model_path().is_none());
        assert_eq!(
            source,
            ModelSource::HuggingFace {
                repo_id: "org/model".to_string(),
                filename: Some("model-q4.gguf".to_string()),
                mmproj_filename: Some("mmproj.gguf".to_string()),
            }
        );
    }

    /// Constructors take `impl Into<String>`, so owned `String`s must work as well as `&str`.
    #[test]
    fn test_model_source_accepts_owned_strings() {
        let path = String::from("owned.gguf");
        let source = ModelSource::gguf(path);
        assert_eq!(source.model_path(), Some("owned.gguf"));

        let repo = String::from("org/model");
        assert_eq!(
            ModelSource::huggingface(repo),
            ModelSource::huggingface("org/model")
        );
    }
}