use anyhow::{Context, Result};
use hf_hub::api::tokio::Api;
use indicatif::{ProgressBar, ProgressStyle};
use mecha10_core::model::{CustomLabelsConfig, ModelConfig, PreprocessingConfig};
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use tokio::fs;
/// One model listed in the embedded `model_catalog.toml`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCatalogEntry {
    /// Catalog name; also used as the model's directory name under the models directory.
    pub name: String,
    /// Free-text description from the catalog (not read by this module).
    pub description: String,
    /// Task identifier. `"object-detection"` and `"image-classification"` select how
    /// labels are obtained and how `num_classes` is derived.
    pub task: String,
    /// HuggingFace repository id the model is downloaded from.
    pub repo: String,
    /// Path of the ONNX file inside `repo`. For object detection, its directory part is
    /// also where a `labels.json` is looked for.
    pub filename: String,
    /// Optional preset name understood by `PreprocessingPreset::from_name`.
    #[serde(default)]
    pub preprocessing_preset: Option<String>,
    /// Inline class labels; when non-empty they are written to `labels.txt` verbatim.
    #[serde(default)]
    pub classes: Vec<String>,
    /// Optional quantization step to run after download.
    #[serde(default)]
    pub quantize: Option<QuantizeConfig>,
}
/// Quantization settings attached to a catalog entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizeConfig {
    /// When `false`, `ModelService::pull` skips quantization entirely.
    pub enabled: bool,
    /// Quantization method; only `"dynamic_int8"` is supported, any other value makes
    /// the pull fail.
    pub method: String,
}
/// Top-level shape of `model_catalog.toml`: a `models` array of catalog entries.
#[derive(Debug, Deserialize)]
struct ModelCatalog {
    models: Vec<ModelCatalogEntry>,
}
/// Named normalization presets a catalog entry can select via `preprocessing_preset`.
#[derive(Debug, Clone, Copy)]
pub enum PreprocessingPreset {
    /// ImageNet mean/std; selected by the name `"imagenet"`.
    ImageNet,
    /// Zero mean with a std of 255 per channel; selected by `"yolo"`.
    Yolo,
    /// Currently maps to the same values as `ImageNet`; selected by `"coco"`.
    Coco,
    /// Zero mean with unit std (no rescaling); selected by `"zero255"` or `"0-255"`.
    Zero255,
}
impl PreprocessingPreset {
    /// Parses a preset name case-insensitively.
    ///
    /// Accepted names: `imagenet`, `yolo`, `coco`, `zero255` / `0-255`.
    ///
    /// # Errors
    ///
    /// Returns an error naming the input when it matches no preset.
    pub fn from_name(name: &str) -> Result<Self> {
        let normalized = name.to_lowercase();
        let preset = match normalized.as_str() {
            "imagenet" => Self::ImageNet,
            "yolo" => Self::Yolo,
            "coco" => Self::Coco,
            "zero255" | "0-255" => Self::Zero255,
            _ => anyhow::bail!("Unknown preprocessing preset: {}", name),
        };
        Ok(preset)
    }

    /// Expands the preset into concrete per-channel normalization values.
    ///
    /// Every preset uses RGB channel order; only `mean` and `std` vary.
    pub fn to_config(self) -> PreprocessingConfig {
        let (mean, std) = match self {
            // COCO-trained backbones share the ImageNet statistics here.
            Self::ImageNet | Self::Coco => ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            Self::Yolo => ([0.0; 3], [255.0; 3]),
            Self::Zero255 => ([0.0; 3], [1.0; 3]),
        };
        PreprocessingConfig {
            mean,
            std,
            channel_order: "RGB".to_string(),
        }
    }
}
/// Subset of a HuggingFace `preprocessor_config.json` used to infer preprocessing.
///
/// Every field is optional. Keys not listed here are ignored, because serde's default
/// behavior is to skip unknown fields (no `deny_unknown_fields` is set).
#[derive(Debug, Clone, Deserialize)]
struct HFPreprocessorConfig {
    /// Per-channel normalization mean, if the config provides one.
    #[serde(default)]
    image_mean: Option<Vec<f32>>,
    /// Per-channel normalization std, if the config provides one.
    #[serde(default)]
    image_std: Option<Vec<f32>>,
    /// Resize target; used for the input size only when `crop_size` is absent.
    #[serde(default)]
    size: Option<HFSize>,
    /// Crop size; preferred over `size` when deriving the model input size.
    #[serde(default)]
    crop_size: Option<HFSize>,
}
/// The encodings HuggingFace configs use for `size` / `crop_size`.
///
/// The enum is `untagged`, so serde tries the variants in declaration order and keeps
/// the first one that matches. Keep the order stable when editing.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum HFSize {
    /// `{"height": H, "width": W}`.
    Dict { height: u32, width: u32 },
    /// `{"shortest_edge": N}`; callers in this file treat it as an `N x N` square.
    ShortestEdge { shortest_edge: u32 },
    /// A bare integer used for both dimensions.
    Single(u32),
}
impl HFPreprocessorConfig {
fn to_preprocessing(&self) -> PreprocessingConfig {
PreprocessingConfig {
mean: [
self.image_mean.as_ref().and_then(|v| v.first()).copied().unwrap_or(0.0),
self.image_mean.as_ref().and_then(|v| v.get(1)).copied().unwrap_or(0.0),
self.image_mean.as_ref().and_then(|v| v.get(2)).copied().unwrap_or(0.0),
],
std: [
self.image_std.as_ref().and_then(|v| v.first()).copied().unwrap_or(1.0),
self.image_std.as_ref().and_then(|v| v.get(1)).copied().unwrap_or(1.0),
self.image_std.as_ref().and_then(|v| v.get(2)).copied().unwrap_or(1.0),
],
channel_order: "RGB".to_string(),
}
}
fn input_size(&self) -> Option<[u32; 2]> {
if let Some(crop_size) = &self.crop_size {
return match crop_size {
HFSize::Dict { height, width } => Some([*width, *height]),
HFSize::ShortestEdge { shortest_edge } => Some([*shortest_edge, *shortest_edge]),
HFSize::Single(s) => Some([*s, *s]),
};
}
match &self.size {
Some(HFSize::Dict { height, width }) => Some([*width, *height]),
Some(HFSize::ShortestEdge { shortest_edge }) => Some([*shortest_edge, *shortest_edge]),
Some(HFSize::Single(s)) => Some([*s, *s]),
None => None,
}
}
}
/// A model directory under the models directory that contains a `model.onnx`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstalledModel {
    /// Name of the model's directory.
    pub name: String,
    /// Path to the directory's `model.onnx`.
    pub path: PathBuf,
    /// Size of `model.onnx` in bytes.
    pub size: u64,
    /// Catalog entry whose name equals the directory name, if any.
    pub catalog_entry: Option<ModelCatalogEntry>,
}
/// Downloads, configures, lists and removes ONNX models kept under a local directory.
pub struct ModelService {
    /// HuggingFace Hub client used for model downloads.
    api: Api,
    /// Entries parsed from the embedded `model_catalog.toml`.
    catalog: Vec<ModelCatalogEntry>,
    /// Root directory; each model lives in `<models_dir>/<name>/`.
    models_dir: PathBuf,
}
impl ModelService {
    /// Creates a service rooted at `./models` (relative to the working directory).
    ///
    /// # Errors
    ///
    /// Same as [`ModelService::with_models_dir`].
    // `dead_code` is allowed because this convenience constructor may have no caller
    // inside the crate.
    #[allow(dead_code)]
    pub fn new() -> Result<Self> {
        Self::with_models_dir(PathBuf::from("models"))
    }

    /// Creates a service rooted at `models_dir`.
    ///
    /// The catalog is compiled into the binary from `model_catalog.toml`.
    ///
    /// # Errors
    ///
    /// Fails if the HuggingFace API client cannot be created or the embedded catalog
    /// does not parse.
    pub fn with_models_dir(models_dir: PathBuf) -> Result<Self> {
        let api = Api::new().context("Failed to initialize HuggingFace API")?;
        let catalog_toml = include_str!("../../model_catalog.toml");
        let catalog: ModelCatalog =
            toml::from_str(catalog_toml).context("Failed to parse model_catalog.toml")?;
        Ok(Self {
            api,
            catalog: catalog.models,
            models_dir,
        })
    }

    /// Returns a copy of every catalog entry. Currently never returns an error.
    pub fn list_catalog(&self) -> Result<Vec<ModelCatalogEntry>> {
        Ok(self.catalog.clone())
    }

    /// Looks up a catalog entry by exact name.
    pub fn get_catalog_entry(&self, name: &str) -> Option<&ModelCatalogEntry> {
        self.catalog.iter().find(|m| m.name == name)
    }

    /// Lists installed models, sorted by name.
    ///
    /// A subdirectory of `models_dir` counts as installed when it contains
    /// `model.onnx`. Returns an empty list if `models_dir` does not exist.
    ///
    /// # Errors
    ///
    /// Fails if the directory or a model file's metadata cannot be read.
    pub async fn list_installed(&self) -> Result<Vec<InstalledModel>> {
        if !self.models_dir.exists() {
            return Ok(Vec::new());
        }
        let mut installed = Vec::new();
        let mut entries = fs::read_dir(&self.models_dir)
            .await
            .with_context(|| format!("Failed to read {}", self.models_dir.display()))?;
        while let Some(entry) = entries.next_entry().await? {
            let path = entry.path();
            let model_path = path.join("model.onnx");
            if !path.is_dir() || !model_path.exists() {
                continue;
            }
            let size = fs::metadata(&model_path).await?.len();
            // A lossy conversion keeps non-UTF-8 directory names distinguishable
            // instead of collapsing them all to a single placeholder.
            let name = entry.file_name().to_string_lossy().into_owned();
            let catalog_entry = self.get_catalog_entry(&name).cloned();
            installed.push(InstalledModel {
                name,
                path: model_path,
                size,
                catalog_entry,
            });
        }
        // `read_dir` order is platform-dependent; sort for stable output.
        installed.sort_by(|a, b| a.name.cmp(&b.name));
        Ok(installed)
    }

    /// Downloads catalog model `name` and prepares it for use.
    ///
    /// The steps are:
    /// 1. Fetch `model.onnx`.
    /// 2. Write `labels.txt`, from inline catalog classes, the repo's `labels.json`
    ///    (object detection), or ImageNet labels (image classification).
    /// 3. Write `config.json`.
    /// 4. Optionally produce `model-int8.onnx`.
    ///
    /// An existing `model.onnx`, downloaded `labels.txt`, or INT8 model is reused.
    ///
    /// Returns the path of the non-quantized `model.onnx`.
    ///
    /// # Errors
    ///
    /// Fails if `name` is not in the catalog or any step fails.
    pub async fn pull(&self, name: &str, progress: Option<&ProgressBar>) -> Result<PathBuf> {
        let entry = self
            .get_catalog_entry(name)
            .with_context(|| format!("Model '{}' not found in catalog", name))?;
        let model_dir = self.models_dir.join(name);
        fs::create_dir_all(&model_dir).await?;
        let model_path = self
            .pull_from_repo(&entry.repo, &entry.filename, name, progress)
            .await?;
        if !entry.classes.is_empty() {
            self.write_inline_labels(name, &entry.classes).await?;
        } else if entry.task == "object-detection" {
            self.pull_labels_from_repo(entry, name, progress).await?;
        } else if entry.task == "image-classification" {
            self.pull_labels_file(name, "imagenet-labels.txt", progress)
                .await?;
        }
        // Must run after the labels step: the config counts the labels written above.
        self.generate_model_config(entry, &model_path, progress)
            .await?;
        if let Some(quantize_config) = entry.quantize.as_ref().filter(|q| q.enabled) {
            self.quantize_model(&model_path, quantize_config, progress)
                .await?;
        }
        if let Some(pb) = progress {
            pb.set_message(format!("✅ Model '{}' ready at {}", name, model_dir.display()));
        }
        Ok(model_path)
    }

    /// Downloads `filename` from HuggingFace repo `repo` into `<models_dir>/<name>/model.onnx`.
    ///
    /// If `model.onnx` already exists, it is returned without any network access. The
    /// download is first copied to `model.onnx.partial` and then renamed. An interrupted
    /// copy therefore never leaves a truncated `model.onnx` that a later run would
    /// mistake for a cached model.
    ///
    /// # Errors
    ///
    /// Fails if the directory cannot be created, the download fails, or the file cannot
    /// be copied into place.
    pub async fn pull_from_repo(
        &self,
        repo: &str,
        filename: &str,
        name: &str,
        progress: Option<&ProgressBar>,
    ) -> Result<PathBuf> {
        let model_dir = self.models_dir.join(name);
        fs::create_dir_all(&model_dir).await?;
        let output_path = model_dir.join("model.onnx");
        if output_path.exists() {
            if let Some(pb) = progress {
                pb.set_message(format!("Model '{}' already cached", name));
            }
            return Ok(output_path);
        }
        if let Some(pb) = progress {
            pb.set_style(
                ProgressStyle::default_spinner()
                    .template("{spinner:.green} {msg}")
                    .expect("static progress template is valid"),
            );
            pb.set_message(format!("Downloading {} from {}", name, repo));
        }
        let repo_api = self.api.model(repo.to_string());
        let hf_cached_path = repo_api
            .get(filename)
            .await
            .with_context(|| format!("Failed to download {} from {}", filename, repo))?;
        let partial_path = model_dir.join("model.onnx.partial");
        if let Err(err) = fs::copy(&hf_cached_path, &partial_path).await {
            // Best-effort cleanup; the copy error is the one worth reporting.
            let _ = fs::remove_file(&partial_path).await;
            return Err(
                anyhow::Error::new(err).context("Failed to copy model to project directory")
            );
        }
        fs::rename(&partial_path, &output_path)
            .await
            .context("Failed to move model into place")?;
        if let Some(pb) = progress {
            pb.set_message(format!("Downloaded {} successfully", name));
        }
        Ok(output_path)
    }

    /// Downloads a well-known labels file to `<models_dir>/<model_name>/labels.txt`.
    ///
    /// Only `"imagenet-labels.txt"` is known; it is fetched from the PyTorch Hub
    /// repository. For any other `filename`, nothing is downloaded and the target path
    /// is still returned, even though it may not exist. An existing `labels.txt` is
    /// reused.
    ///
    /// # Errors
    ///
    /// Fails on network errors, a non-success HTTP status, or write errors.
    async fn pull_labels_file(
        &self,
        model_name: &str,
        filename: &str,
        progress: Option<&ProgressBar>,
    ) -> Result<PathBuf> {
        let model_dir = self.models_dir.join(model_name);
        let output_path = model_dir.join("labels.txt");
        if output_path.exists() {
            if let Some(pb) = progress {
                pb.set_message("Labels file already cached".to_string());
            }
            return Ok(output_path);
        }
        let url = match filename {
            "imagenet-labels.txt" => {
                "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
            }
            _ => {
                if let Some(pb) = progress {
                    pb.set_message(format!("⚠️ Unknown labels file: {}, skipping", filename));
                }
                return Ok(output_path);
            }
        };
        if let Some(pb) = progress {
            pb.set_message(format!("Downloading labels: {}", filename));
        }
        let client = reqwest::Client::new();
        let response = client
            .get(url)
            .send()
            .await
            .with_context(|| format!("Failed to download labels from {}", url))?;
        if !response.status().is_success() {
            anyhow::bail!("Failed to download labels: HTTP {}", response.status());
        }
        let content = response
            .text()
            .await
            .context("Failed to read labels content")?;
        fs::write(&output_path, content)
            .await
            .context("Failed to write labels file")?;
        if let Some(pb) = progress {
            pb.set_message(format!("Downloaded labels: {}", filename));
        }
        Ok(output_path)
    }

    /// Writes `classes` one per line to `<models_dir>/<model_name>/labels.txt`,
    /// overwriting any existing file.
    async fn write_inline_labels(&self, model_name: &str, classes: &[String]) -> Result<()> {
        let labels_path = self.models_dir.join(model_name).join("labels.txt");
        fs::write(&labels_path, classes.join("\n"))
            .await
            .context("Failed to write inline labels to labels.txt")?;
        Ok(())
    }

    /// Fetches `labels.json` from the directory of `entry.filename` in `entry.repo` and
    /// writes it to `labels.txt`.
    ///
    /// The JSON must be an array of strings. Nothing happens if `labels.txt` already
    /// exists, if `entry.filename` has no directory part, or if the request fails or
    /// returns a non-success status.
    ///
    /// # Errors
    ///
    /// Fails if the body cannot be read, is not a JSON string array, or `labels.txt`
    /// cannot be written.
    async fn pull_labels_from_repo(
        &self,
        entry: &ModelCatalogEntry,
        model_name: &str,
        progress: Option<&ProgressBar>,
    ) -> Result<()> {
        let labels_path = self.models_dir.join(model_name).join("labels.txt");
        if labels_path.exists() {
            if let Some(pb) = progress {
                pb.set_message("Labels file already cached".to_string());
            }
            return Ok(());
        }
        let model_dir_in_repo = entry
            .filename
            .rsplit_once('/')
            .map(|(dir, _)| dir)
            .unwrap_or("");
        if model_dir_in_repo.is_empty() {
            return Ok(());
        }
        let labels_filename = format!("{}/labels.json", model_dir_in_repo);
        if let Some(pb) = progress {
            pb.set_message(format!("Downloading labels from {}", entry.repo));
        }
        let url = format!(
            "https://huggingface.co/{}/raw/main/{}",
            entry.repo, labels_filename
        );
        let client = reqwest::Client::new();
        let response = match client.get(&url).send().await {
            Ok(resp) if resp.status().is_success() => resp,
            // Labels are optional for detection models; skip quietly.
            _ => {
                tracing::debug!("No labels.json available at {}", url);
                return Ok(());
            }
        };
        let json_content = response
            .text()
            .await
            .context("Failed to read labels.json response")?;
        let labels: Vec<String> =
            serde_json::from_str(&json_content).context("Failed to parse labels.json")?;
        fs::write(&labels_path, labels.join("\n"))
            .await
            .context("Failed to write labels.txt")?;
        if let Some(pb) = progress {
            pb.set_message(format!("Downloaded {} class labels", labels.len()));
        }
        Ok(())
    }

    /// Downloads and parses `preprocessor_config.json` from the `main` branch of `repo`.
    ///
    /// # Errors
    ///
    /// Fails on network errors, a non-success status, or unparseable JSON.
    async fn fetch_hf_preprocessor_config(&self, repo: &str) -> Result<HFPreprocessorConfig> {
        let url = format!(
            "https://huggingface.co/{}/raw/main/preprocessor_config.json",
            repo
        );
        let client = reqwest::Client::new();
        let response = client
            .get(&url)
            .send()
            .await
            .with_context(|| format!("Failed to fetch from {}", url))?;
        if !response.status().is_success() {
            anyhow::bail!(
                "HuggingFace preprocessor_config.json not found for {} (HTTP {})",
                repo,
                response.status()
            );
        }
        let config: HFPreprocessorConfig = response
            .json()
            .await
            .context("Failed to parse preprocessor_config.json")?;
        Ok(config)
    }

    /// Reads the model's spatial input size from the ONNX graph.
    ///
    /// Shape inspection is not implemented yet, so this always returns `None` and callers
    /// use their default. It deliberately does not build an ONNX Runtime session: that
    /// would load the entire model synchronously only to discard it.
    fn extract_input_size_from_onnx(&self, _model_path: &Path) -> Option<[u32; 2]> {
        None
    }

    /// Chooses preprocessing and input size (`[width, height]`) for `entry`.
    ///
    /// Sources are tried in order:
    /// 1. The repo's HuggingFace `preprocessor_config.json`.
    /// 2. The catalog's `preprocessing_preset`.
    /// 3. Identity normalization (mean 0, std 1).
    ///
    /// Input size falls back to `[224, 224]`. Currently never returns an error.
    async fn auto_detect_preprocessing(
        &self,
        entry: &ModelCatalogEntry,
        model_path: &Path,
        progress: Option<&ProgressBar>,
    ) -> Result<(PreprocessingConfig, [u32; 2])> {
        if let Some(pb) = progress {
            pb.set_message(format!("🔍 Auto-detecting preprocessing for {}", entry.name));
        }
        match self.fetch_hf_preprocessor_config(&entry.repo).await {
            Ok(hf_config) => {
                tracing::debug!(
                    "HF config: size={:?}, crop_size={:?}",
                    hf_config.size,
                    hf_config.crop_size
                );
                let preprocessing = hf_config.to_preprocessing();
                let input_size = hf_config.input_size().unwrap_or([224, 224]);
                tracing::debug!(
                    "Detected preprocessing: mean={:?}, std={:?}, input_size={:?}",
                    preprocessing.mean,
                    preprocessing.std,
                    input_size
                );
                if let Some(pb) = progress {
                    pb.set_message(format!(
                        "✅ Auto-detected from HuggingFace (input_size={:?})",
                        input_size
                    ));
                }
                return Ok((preprocessing, input_size));
            }
            Err(err) => tracing::debug!(
                "Failed to fetch HuggingFace preprocessor config ({:#}), falling back to preset",
                err
            ),
        }
        if let Some(preset_name) = &entry.preprocessing_preset {
            match PreprocessingPreset::from_name(preset_name) {
                Ok(preset) => {
                    let preprocessing = preset.to_config();
                    let input_size = self
                        .extract_input_size_from_onnx(model_path)
                        .unwrap_or([224, 224]);
                    if let Some(pb) = progress {
                        pb.set_message(format!(
                            "✅ Using preset '{}' (input_size={:?})",
                            preset_name, input_size
                        ));
                    }
                    return Ok((preprocessing, input_size));
                }
                // A typo in the catalog should be visible rather than silently ignored.
                Err(err) => tracing::warn!("{}; using fallback preprocessing", err),
            }
        }
        let input_size = self
            .extract_input_size_from_onnx(model_path)
            .unwrap_or([224, 224]);
        let preprocessing = PreprocessingConfig {
            mean: [0.0, 0.0, 0.0],
            std: [1.0, 1.0, 1.0],
            channel_order: "RGB".to_string(),
        };
        if let Some(pb) = progress {
            pb.set_message(format!(
                "⚠️ Using fallback preprocessing (input_size={:?}). Consider editing config.json",
                input_size
            ));
        }
        Ok((preprocessing, input_size))
    }

    /// Writes `<models_dir>/<entry.name>/config.json` describing the model.
    ///
    /// `num_classes` is resolved in this order:
    /// 1. The number of inline catalog classes, if any.
    /// 2. The number of non-empty lines in `labels.txt`, if present.
    /// 3. 1 for object detection, 1000 otherwise.
    async fn generate_model_config(
        &self,
        entry: &ModelCatalogEntry,
        model_path: &Path,
        progress: Option<&ProgressBar>,
    ) -> Result<()> {
        let model_dir = self.models_dir.join(&entry.name);
        let config_path = model_dir.join("config.json");
        let (preprocessing, input_size) = self
            .auto_detect_preprocessing(entry, model_path, progress)
            .await?;
        let num_classes = if !entry.classes.is_empty() {
            entry.classes.len()
        } else if let Some(count) = Self::count_labels(&model_dir.join("labels.txt")).await {
            count
        } else if entry.task == "object-detection" {
            1
        } else {
            1000
        };
        let config = ModelConfig {
            name: entry.name.clone(),
            task: entry.task.clone(),
            repo: entry.repo.clone(),
            filename: entry.filename.clone(),
            input_size,
            preprocessing,
            num_classes,
            labels_file: "labels.txt".to_string(),
            custom_labels: CustomLabelsConfig::default(),
        };
        let json =
            serde_json::to_string_pretty(&config).context("Failed to serialize model config")?;
        fs::write(&config_path, json)
            .await
            .context("Failed to write model config.json")?;
        if let Some(pb) = progress {
            pb.set_message(format!("📝 Wrote config to {}", config_path.display()));
        }
        Ok(())
    }

    /// Counts non-empty lines in a labels file; `None` if it is unreadable or empty.
    async fn count_labels(labels_path: &Path) -> Option<usize> {
        let content = fs::read_to_string(labels_path).await.ok()?;
        let count = content
            .lines()
            .filter(|line| !line.trim().is_empty())
            .count();
        (count > 0).then_some(count)
    }

    /// Produces `model-int8.onnx` next to `model_path` using `config.method`.
    ///
    /// Returns the existing file without re-running if it is already present.
    ///
    /// # Errors
    ///
    /// Fails for methods other than `"dynamic_int8"` or when quantization fails.
    async fn quantize_model(
        &self,
        model_path: &Path,
        config: &QuantizeConfig,
        progress: Option<&ProgressBar>,
    ) -> Result<PathBuf> {
        let int8_path = model_path.with_file_name("model-int8.onnx");
        if int8_path.exists() {
            if let Some(pb) = progress {
                pb.set_message("INT8 model already cached");
            }
            return Ok(int8_path);
        }
        if let Some(pb) = progress {
            pb.set_message("Quantizing model to INT8...");
        }
        match config.method.as_str() {
            "dynamic_int8" => self.quantize_dynamic_int8(model_path, &int8_path).await?,
            _ => anyhow::bail!("Unsupported quantization method: {}", config.method),
        }
        if let Some(pb) = progress {
            pb.set_message("✅ INT8 model ready");
        }
        Ok(int8_path)
    }

    /// Runs the embedded `quantize_int8.py` script with `input` and `output` as arguments.
    ///
    /// # Errors
    ///
    /// Fails if no Python interpreter is found, the script cannot be written or run, or
    /// the script exits unsuccessfully (its stderr is included).
    async fn quantize_dynamic_int8(&self, input: &Path, output: &Path) -> Result<()> {
        let python = self.find_python()?;
        let script = include_str!("../../scripts/quantize_int8.py");
        // A per-invocation name keeps concurrent runs from overwriting or deleting each
        // other's script in the shared temp directory.
        let nonce = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|elapsed| elapsed.as_nanos())
            .unwrap_or_default();
        let script_path = std::env::temp_dir().join(format!(
            "mecha10_quantize_int8_{}_{}.py",
            std::process::id(),
            nonce
        ));
        fs::write(&script_path, script)
            .await
            .context("Failed to write quantization script")?;
        let run = tokio::process::Command::new(&python)
            .arg(&script_path)
            .arg(input)
            .arg(output)
            .output()
            .await;
        // Clean up before inspecting the result, so the script is removed even when
        // the interpreter could not be started.
        let _ = fs::remove_file(&script_path).await;
        let output_result = run.with_context(|| format!("Failed to run {}", python))?;
        if !output_result.status.success() {
            let stderr = String::from_utf8_lossy(&output_result.stderr);
            anyhow::bail!(
                "Quantization failed: {}\n\nTip: Install with 'pip install onnx onnxruntime'",
                stderr
            );
        }
        Ok(())
    }

    /// Returns the first of `python3`, `python` that `which` can locate.
    fn find_python(&self) -> Result<String> {
        ["python3", "python"]
            .into_iter()
            .find(|candidate| which::which(candidate).is_ok())
            .map(str::to_string)
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "Python 3 not found. Install with: brew install python3 (macOS) or apt install python3 (Linux)"
                )
            })
    }

    /// Deletes the installed model directory `<models_dir>/<name>`.
    ///
    /// # Errors
    ///
    /// Fails if `name` could resolve outside `models_dir`, the model is not installed,
    /// or the directory cannot be removed.
    pub async fn remove(&self, name: &str) -> Result<()> {
        let model_dir = self.checked_model_dir(name)?;
        if !model_dir.exists() {
            anyhow::bail!("Model '{}' is not installed", name);
        }
        fs::remove_dir_all(&model_dir)
            .await
            .with_context(|| format!("Failed to remove model '{}'", name))?;
        Ok(())
    }

    /// Resolves `<models_dir>/<name>`, refusing names that could escape `models_dir`.
    ///
    /// `name` must be non-empty and consist only of normal path components. Absolute
    /// paths are rejected because `Path::join` would replace `models_dir` with them.
    /// `..`, root and drive prefixes are rejected as well.
    fn checked_model_dir(&self, name: &str) -> Result<PathBuf> {
        let relative = Path::new(name);
        let stays_inside = !name.is_empty()
            && relative
                .components()
                .all(|component| matches!(component, std::path::Component::Normal(_)));
        if !stays_inside {
            anyhow::bail!("Invalid model name '{}'", name);
        }
        Ok(self.models_dir.join(relative))
    }

    /// Returns `<models_dir>/<name>/model.onnx` without checking that it exists.
    // `dead_code` is allowed because this helper may have no caller inside the crate.
    #[allow(dead_code)]
    pub fn get_model_path(&self, name: &str) -> PathBuf {
        self.models_dir.join(name).join("model.onnx")
    }

    /// Reports whether `<models_dir>/<name>/model.onnx` exists.
    // `dead_code` is allowed because this helper may have no caller inside the crate.
    #[allow(dead_code)]
    pub async fn is_installed(&self, name: &str) -> bool {
        self.get_model_path(name).exists()
    }

    /// Combines the catalog entry and installation details for `name`.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`ModelService::list_installed`].
    pub async fn info(&self, name: &str) -> Result<ModelInfo> {
        let catalog_entry = self.get_catalog_entry(name).cloned();
        let installed = self.list_installed().await?;
        let installed_info = installed.into_iter().find(|m| m.name == name);
        Ok(ModelInfo {
            name: name.to_string(),
            catalog_entry,
            installed_info,
        })
    }

    /// Cheap sanity check that `path` is an existing `.onnx` file longer than 4 bytes.
    ///
    /// Only file metadata is read; the model itself is never loaded into memory.
    /// Non-file paths, such as directories, yield `Ok(false)`.
    ///
    /// # Errors
    ///
    /// Fails if the file's metadata cannot be read.
    // `dead_code` is allowed because this helper may have no caller inside the crate.
    #[allow(dead_code)]
    pub async fn validate(&self, path: &Path) -> Result<bool> {
        if !path.exists() || path.extension().and_then(|s| s.to_str()) != Some("onnx") {
            return Ok(false);
        }
        let metadata = fs::metadata(path).await?;
        Ok(metadata.is_file() && metadata.len() > 4)
    }
}
/// Catalog and installation information for one model name, as built by
/// `ModelService::info`.
#[derive(Debug, Clone, Serialize)]
pub struct ModelInfo {
    /// The model name that was queried.
    pub name: String,
    /// Catalog entry with this name, if the catalog has one.
    pub catalog_entry: Option<ModelCatalogEntry>,
    /// Installed model with this directory name, if one is installed.
    pub installed_info: Option<InstalledModel>,
}
impl ModelInfo {
    /// True when an installed model with this name was found.
    // `dead_code` is allowed because this accessor may have no caller inside the crate.
    #[allow(dead_code)]
    pub fn is_installed(&self) -> bool {
        let Self { installed_info, .. } = self;
        installed_info.is_some()
    }

    /// True when the name appears in the model catalog.
    // `dead_code` is allowed because this accessor may have no caller inside the crate.
    #[allow(dead_code)]
    pub fn is_in_catalog(&self) -> bool {
        let Self { catalog_entry, .. } = self;
        catalog_entry.is_some()
    }
}