1use anyhow::{Context, Result};
34use hf_hub::api::tokio::Api;
35use indicatif::{ProgressBar, ProgressStyle};
36use mecha10_core::model::{CustomLabelsConfig, ModelConfig, PreprocessingConfig};
37use serde::{Deserialize, Serialize};
38use std::path::{Path, PathBuf};
39use tokio::fs;
40
/// One downloadable model as declared in the embedded `model_catalog.toml`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCatalogEntry {
    /// Unique model name; also used as the install directory name under
    /// the service's models directory.
    pub name: String,
    /// Human-readable description for listings.
    pub description: String,
    /// Task identifier, e.g. "object-detection" or "image-classification";
    /// drives label-file resolution and `num_classes` in the generated config.
    pub task: String,
    /// HuggingFace repository id ("org/repo").
    pub repo: String,
    /// Path of the ONNX file inside the repository (may contain '/').
    pub filename: String,

    /// Optional preset name understood by `PreprocessingPreset::from_name`;
    /// used as a fallback when HF auto-detection fails.
    #[serde(default)]
    pub preprocessing_preset: Option<String>,

    /// Inline class labels. When non-empty they are written directly to
    /// labels.txt instead of downloading a labels file.
    #[serde(default)]
    pub classes: Vec<String>,

    /// Optional post-download quantization settings.
    #[serde(default)]
    pub quantize: Option<QuantizeConfig>,
}
67
/// Post-download quantization settings for a catalog entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizeConfig {
    /// When false, quantization is skipped even if this section is present.
    pub enabled: bool,
    /// Quantization method name; only "dynamic_int8" is currently supported.
    pub method: String,
}
76
/// Root of the embedded `model_catalog.toml` document.
#[derive(Debug, Deserialize)]
struct ModelCatalog {
    /// The `[[models]]` array entries.
    models: Vec<ModelCatalogEntry>,
}
82
/// Named normalization presets that the catalog can reference instead of
/// relying on HuggingFace auto-detection. See `to_config` for the values.
#[derive(Debug, Clone, Copy)]
pub enum PreprocessingPreset {
    /// Standard ImageNet mean/std normalization.
    ImageNet,
    /// Divide raw 0-255 pixels by 255 (mean 0, std 255).
    Yolo,
    /// Same normalization values as `ImageNet`.
    Coco,
    /// Pass pixels through unchanged (mean 0, std 1).
    Zero255,
}
95
96impl PreprocessingPreset {
97 pub fn from_name(name: &str) -> Result<Self> {
99 match name.to_lowercase().as_str() {
100 "imagenet" => Ok(Self::ImageNet),
101 "yolo" => Ok(Self::Yolo),
102 "coco" => Ok(Self::Coco),
103 "zero255" | "0-255" => Ok(Self::Zero255),
104 _ => anyhow::bail!("Unknown preprocessing preset: {}", name),
105 }
106 }
107
108 pub fn to_config(self) -> PreprocessingConfig {
110 match self {
111 Self::ImageNet | Self::Coco => PreprocessingConfig {
112 mean: [0.485, 0.456, 0.406],
113 std: [0.229, 0.224, 0.225],
114 channel_order: "RGB".to_string(),
115 },
116 Self::Yolo => PreprocessingConfig {
117 mean: [0.0, 0.0, 0.0],
118 std: [255.0, 255.0, 255.0],
119 channel_order: "RGB".to_string(),
120 },
121 Self::Zero255 => PreprocessingConfig {
122 mean: [0.0, 0.0, 0.0],
123 std: [1.0, 1.0, 1.0],
124 channel_order: "RGB".to_string(),
125 },
126 }
127 }
128}
129
/// The subset of HuggingFace's `preprocessor_config.json` that we consume.
/// Every field is optional because the schema varies across processors.
#[derive(Debug, Clone, Deserialize)]
struct HFPreprocessorConfig {
    // Per-channel normalization mean (usually three values — not enforced).
    #[serde(default)]
    image_mean: Option<Vec<f32>>,
    // Per-channel normalization standard deviation.
    #[serde(default)]
    image_std: Option<Vec<f32>>,
    // Resize target; shape varies across repos (see HFSize).
    #[serde(default)]
    size: Option<HFSize>,
    // Center-crop target; preferred over `size` by input_size().
    #[serde(default)]
    crop_size: Option<HFSize>,
}
143
/// The polymorphic `size`/`crop_size` value found in HF preprocessor
/// configs.
///
/// `#[serde(untagged)]` tries variants in declaration order, so keep
/// `Dict` before `ShortestEdge` before `Single`.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum HFSize {
    // Explicit {"height": H, "width": W} object.
    Dict { height: u32, width: u32 },
    // {"shortest_edge": N} — interpreted as a square N x N input.
    ShortestEdge { shortest_edge: u32 },
    // Bare integer — square input.
    Single(u32),
}
151
152impl HFPreprocessorConfig {
153 fn to_preprocessing(&self) -> PreprocessingConfig {
154 PreprocessingConfig {
155 mean: [
156 self.image_mean.as_ref().and_then(|v| v.first()).copied().unwrap_or(0.0),
157 self.image_mean.as_ref().and_then(|v| v.get(1)).copied().unwrap_or(0.0),
158 self.image_mean.as_ref().and_then(|v| v.get(2)).copied().unwrap_or(0.0),
159 ],
160 std: [
161 self.image_std.as_ref().and_then(|v| v.first()).copied().unwrap_or(1.0),
162 self.image_std.as_ref().and_then(|v| v.get(1)).copied().unwrap_or(1.0),
163 self.image_std.as_ref().and_then(|v| v.get(2)).copied().unwrap_or(1.0),
164 ],
165 channel_order: "RGB".to_string(),
166 }
167 }
168
169 fn input_size(&self) -> Option<[u32; 2]> {
170 if let Some(crop_size) = &self.crop_size {
173 return match crop_size {
174 HFSize::Dict { height, width } => Some([*width, *height]),
175 HFSize::ShortestEdge { shortest_edge } => Some([*shortest_edge, *shortest_edge]),
176 HFSize::Single(s) => Some([*s, *s]),
177 };
178 }
179
180 match &self.size {
182 Some(HFSize::Dict { height, width }) => Some([*width, *height]),
183 Some(HFSize::ShortestEdge { shortest_edge }) => Some([*shortest_edge, *shortest_edge]),
184 Some(HFSize::Single(s)) => Some([*s, *s]),
185 None => None,
186 }
187 }
188}
189
/// A model found on disk under the models directory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstalledModel {
    /// Install directory name (matches the catalog name when known).
    pub name: String,
    /// Full path to the `model.onnx` file.
    pub path: PathBuf,
    /// Size of the ONNX file in bytes.
    pub size: u64,
    /// Matching catalog entry, when the directory name is in the catalog.
    pub catalog_entry: Option<ModelCatalogEntry>,
}
202
/// Downloads, installs, configures, quantizes, and removes ONNX models.
pub struct ModelService {
    // HuggingFace hub client (maintains its own download cache).
    api: Api,
    // Catalog entries parsed from the embedded model_catalog.toml.
    catalog: Vec<ModelCatalogEntry>,
    // Root directory for installed models, one subdirectory per model.
    models_dir: PathBuf,
}
212
213impl ModelService {
214 #[allow(dead_code)]
216 pub fn new() -> Result<Self> {
217 Self::with_models_dir(PathBuf::from("models"))
218 }
219
220 pub fn with_models_dir(models_dir: PathBuf) -> Result<Self> {
222 let api = Api::new().context("Failed to initialize HuggingFace API")?;
223
224 let catalog_toml = include_str!("../../model_catalog.toml");
226 let catalog: ModelCatalog = toml::from_str(catalog_toml).context("Failed to parse model_catalog.toml")?;
227
228 Ok(Self {
229 api,
230 catalog: catalog.models,
231 models_dir,
232 })
233 }
234
235 pub fn list_catalog(&self) -> Result<Vec<ModelCatalogEntry>> {
237 Ok(self.catalog.clone())
238 }
239
240 pub fn get_catalog_entry(&self, name: &str) -> Option<&ModelCatalogEntry> {
242 self.catalog.iter().find(|m| m.name == name)
243 }
244
245 pub async fn list_installed(&self) -> Result<Vec<InstalledModel>> {
247 if !self.models_dir.exists() {
249 return Ok(Vec::new());
250 }
251
252 let mut installed = Vec::new();
253 let mut entries = fs::read_dir(&self.models_dir).await?;
254
255 while let Some(entry) = entries.next_entry().await? {
256 let path = entry.path();
257
258 if !path.is_dir() {
260 continue;
261 }
262
263 let model_path = path.join("model.onnx");
265 if !model_path.exists() {
266 continue;
267 }
268
269 let metadata = fs::metadata(&model_path).await?;
270 let size = metadata.len();
271
272 let name = path
274 .file_name()
275 .and_then(|s| s.to_str())
276 .unwrap_or("unknown")
277 .to_string();
278
279 let catalog_entry = self.get_catalog_entry(&name).cloned();
281
282 installed.push(InstalledModel {
283 name,
284 path: model_path,
285 size,
286 catalog_entry,
287 });
288 }
289
290 Ok(installed)
291 }
292
293 pub async fn pull(&self, name: &str, progress: Option<&ProgressBar>) -> Result<PathBuf> {
295 let entry = self
297 .get_catalog_entry(name)
298 .context(format!("Model '{}' not found in catalog", name))?;
299
300 let model_dir = self.models_dir.join(name);
302 fs::create_dir_all(&model_dir).await?;
303
304 let model_path = self
306 .pull_from_repo(&entry.repo, &entry.filename, name, progress)
307 .await?;
308
309 if !entry.classes.is_empty() {
311 self.write_inline_labels(name, &entry.classes).await?;
313 } else if entry.task == "object-detection" {
314 self.pull_labels_from_repo(entry, name, progress).await?;
316 } else if entry.task == "image-classification" {
317 self.pull_labels_file(name, "imagenet-labels.txt", progress).await?;
319 }
320
321 self.generate_model_config(entry, &model_path, progress).await?;
323
324 if let Some(quantize_config) = &entry.quantize {
326 if quantize_config.enabled {
327 self.quantize_model(&model_path, quantize_config, progress).await?;
328 }
329 }
330
331 if let Some(pb) = progress {
332 pb.set_message(format!("✅ Model '{}' ready at {}", name, model_dir.display()));
333 }
334
335 Ok(model_path)
336 }
337
338 pub async fn pull_from_repo(
340 &self,
341 repo: &str,
342 filename: &str,
343 name: &str,
344 progress: Option<&ProgressBar>,
345 ) -> Result<PathBuf> {
346 let model_dir = self.models_dir.join(name);
348 fs::create_dir_all(&model_dir).await?;
349
350 let output_path = model_dir.join("model.onnx");
352
353 if output_path.exists() {
355 if let Some(pb) = progress {
356 pb.set_message(format!("Model '{}' already cached", name));
357 }
358 return Ok(output_path);
359 }
360
361 if let Some(pb) = progress {
363 pb.set_style(
364 ProgressStyle::default_spinner()
365 .template("{spinner:.green} {msg}")
366 .unwrap(),
367 );
368 pb.set_message(format!("Downloading {} from {}", name, repo));
369 }
370
371 let repo_api = self.api.model(repo.to_string());
373 let hf_cached_path = repo_api
374 .get(filename)
375 .await
376 .context(format!("Failed to download {} from {}", filename, repo))?;
377
378 fs::copy(&hf_cached_path, &output_path)
380 .await
381 .context("Failed to copy model to project directory")?;
382
383 if let Some(pb) = progress {
384 pb.set_message(format!("Downloaded {} successfully", name));
385 }
386
387 Ok(output_path)
388 }
389
390 async fn pull_labels_file(
392 &self,
393 model_name: &str,
394 filename: &str,
395 progress: Option<&ProgressBar>,
396 ) -> Result<PathBuf> {
397 let model_dir = self.models_dir.join(model_name);
399 let output_path = model_dir.join("labels.txt");
400
401 if output_path.exists() {
403 if let Some(pb) = progress {
404 pb.set_message("Labels file already cached".to_string());
405 }
406 return Ok(output_path);
407 }
408
409 let url = match filename {
411 "imagenet-labels.txt" => "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
412 _ => {
413 if let Some(pb) = progress {
415 pb.set_message(format!("⚠️ Unknown labels file: {}, skipping", filename));
416 }
417 return Ok(output_path); }
419 };
420
421 if let Some(pb) = progress {
422 pb.set_message(format!("Downloading labels: {}", filename));
423 }
424
425 let client = reqwest::Client::new();
427 let response = client
428 .get(url)
429 .send()
430 .await
431 .context(format!("Failed to download labels from {}", url))?;
432
433 if !response.status().is_success() {
434 anyhow::bail!("Failed to download labels: HTTP {}", response.status());
435 }
436
437 let content = response.text().await.context("Failed to read labels content")?;
438
439 fs::write(&output_path, content)
441 .await
442 .context("Failed to write labels file")?;
443
444 if let Some(pb) = progress {
445 pb.set_message(format!("Downloaded labels: {}", filename));
446 }
447
448 Ok(output_path)
449 }
450
451 async fn write_inline_labels(&self, model_name: &str, classes: &[String]) -> Result<()> {
453 let model_dir = self.models_dir.join(model_name);
454 let labels_path = model_dir.join("labels.txt");
455
456 let content = classes.join("\n");
457 fs::write(&labels_path, content)
458 .await
459 .context("Failed to write inline labels to labels.txt")?;
460
461 Ok(())
462 }
463
464 async fn pull_labels_from_repo(
466 &self,
467 entry: &ModelCatalogEntry,
468 model_name: &str,
469 progress: Option<&ProgressBar>,
470 ) -> Result<()> {
471 let model_dir = self.models_dir.join(model_name);
472 let labels_path = model_dir.join("labels.txt");
473
474 if labels_path.exists() {
476 if let Some(pb) = progress {
477 pb.set_message("Labels file already cached".to_string());
478 }
479 return Ok(());
480 }
481
482 let model_dir_in_repo = entry.filename.rsplit_once('/').map(|(dir, _)| dir).unwrap_or("");
484
485 if model_dir_in_repo.is_empty() {
486 return Ok(());
488 }
489
490 let labels_filename = format!("{}/labels.json", model_dir_in_repo);
491
492 if let Some(pb) = progress {
493 pb.set_message(format!("Downloading labels from {}", entry.repo));
494 }
495
496 let url = format!("https://huggingface.co/{}/raw/main/{}", entry.repo, labels_filename);
498
499 let client = reqwest::Client::new();
500 let response = match client.get(&url).send().await {
501 Ok(resp) if resp.status().is_success() => resp,
502 _ => {
503 return Ok(());
505 }
506 };
507
508 let json_content = response.text().await.context("Failed to read labels.json response")?;
509
510 let labels: Vec<String> = serde_json::from_str(&json_content).context("Failed to parse labels.json")?;
511
512 let content = labels.join("\n");
514 fs::write(&labels_path, content)
515 .await
516 .context("Failed to write labels.txt")?;
517
518 if let Some(pb) = progress {
519 pb.set_message(format!("Downloaded {} class labels", labels.len()));
520 }
521
522 Ok(())
523 }
524
525 async fn fetch_hf_preprocessor_config(&self, repo: &str) -> Result<HFPreprocessorConfig> {
527 let url = format!("https://huggingface.co/{}/raw/main/preprocessor_config.json", repo);
528
529 let client = reqwest::Client::new();
530 let response = client
531 .get(&url)
532 .send()
533 .await
534 .context(format!("Failed to fetch from {}", url))?;
535
536 if !response.status().is_success() {
537 anyhow::bail!(
538 "HuggingFace preprocessor_config.json not found for {} (HTTP {})",
539 repo,
540 response.status()
541 );
542 }
543
544 let config: HFPreprocessorConfig = response
545 .json()
546 .await
547 .context("Failed to parse preprocessor_config.json")?;
548
549 Ok(config)
550 }
551
552 fn extract_input_size_from_onnx(&self, model_path: &Path) -> Option<[u32; 2]> {
557 use ort::session::Session;
558
559 let session = Session::builder().ok()?.commit_from_file(model_path).ok()?;
561
562 let _input = session.inputs.first()?;
564
565 None
571 }
572
573 async fn auto_detect_preprocessing(
578 &self,
579 entry: &ModelCatalogEntry,
580 model_path: &Path,
581 progress: Option<&ProgressBar>,
582 ) -> Result<(PreprocessingConfig, [u32; 2])> {
583 if let Some(pb) = progress {
585 pb.set_message(format!("🔍 Auto-detecting preprocessing for {}", entry.name));
586 }
587
588 if let Ok(hf_config) = self.fetch_hf_preprocessor_config(&entry.repo).await {
589 tracing::debug!(
590 "HF config: size={:?}, crop_size={:?}",
591 hf_config.size,
592 hf_config.crop_size
593 );
594
595 let preprocessing = hf_config.to_preprocessing();
596 let input_size = hf_config.input_size().unwrap_or([224, 224]);
597
598 tracing::debug!(
599 "Detected preprocessing: mean={:?}, std={:?}, input_size={:?}",
600 preprocessing.mean,
601 preprocessing.std,
602 input_size
603 );
604
605 if let Some(pb) = progress {
606 pb.set_message(format!(
607 "✅ Auto-detected from HuggingFace (input_size={:?})",
608 input_size
609 ));
610 }
611
612 return Ok((preprocessing, input_size));
613 } else {
614 tracing::debug!("Failed to fetch HuggingFace preprocessor config, falling back to preset");
615 }
616
617 if let Some(preset_name) = &entry.preprocessing_preset {
619 if let Ok(preset) = PreprocessingPreset::from_name(preset_name) {
620 let preprocessing = preset.to_config();
621
622 let input_size = self.extract_input_size_from_onnx(model_path).unwrap_or([224, 224]);
624
625 if let Some(pb) = progress {
626 pb.set_message(format!(
627 "✅ Using preset '{}' (input_size={:?})",
628 preset_name, input_size
629 ));
630 }
631
632 return Ok((preprocessing, input_size));
633 }
634 }
635
636 let input_size = self.extract_input_size_from_onnx(model_path).unwrap_or([224, 224]);
638
639 let preprocessing = PreprocessingConfig {
640 mean: [0.0, 0.0, 0.0],
641 std: [1.0, 1.0, 1.0],
642 channel_order: "RGB".to_string(),
643 };
644
645 if let Some(pb) = progress {
646 pb.set_message(format!(
647 "⚠️ Using fallback preprocessing (input_size={:?}). Consider editing config.json",
648 input_size
649 ));
650 }
651
652 Ok((preprocessing, input_size))
653 }
654
655 async fn generate_model_config(
657 &self,
658 entry: &ModelCatalogEntry,
659 model_path: &Path,
660 progress: Option<&ProgressBar>,
661 ) -> Result<()> {
662 let model_dir = self.models_dir.join(&entry.name);
663 let config_path = model_dir.join("config.json");
664
665 let (preprocessing, input_size) = self.auto_detect_preprocessing(entry, model_path, progress).await?;
667
668 let num_classes = if entry.task == "object-detection" {
670 entry.classes.len().max(1)
671 } else {
672 1000 };
674
675 let config = ModelConfig {
677 name: entry.name.clone(),
678 task: entry.task.clone(),
679 repo: entry.repo.clone(),
680 filename: entry.filename.clone(),
681 input_size,
682 preprocessing,
683 num_classes,
684 labels_file: "labels.txt".to_string(),
685 custom_labels: CustomLabelsConfig::default(),
686 };
687
688 let json = serde_json::to_string_pretty(&config).context("Failed to serialize model config")?;
690
691 fs::write(&config_path, json)
692 .await
693 .context("Failed to write model config.json")?;
694
695 if let Some(pb) = progress {
696 pb.set_message(format!("📝 Wrote config to {}", config_path.display()));
697 }
698
699 Ok(())
700 }
701
702 async fn quantize_model(
704 &self,
705 model_path: &Path,
706 config: &QuantizeConfig,
707 progress: Option<&ProgressBar>,
708 ) -> Result<PathBuf> {
709 let int8_path = model_path.with_file_name("model-int8.onnx");
710
711 if int8_path.exists() {
713 if let Some(pb) = progress {
714 pb.set_message("INT8 model already cached");
715 }
716 return Ok(int8_path);
717 }
718
719 if let Some(pb) = progress {
720 pb.set_message("Quantizing model to INT8...");
721 }
722
723 match config.method.as_str() {
724 "dynamic_int8" => {
725 self.quantize_dynamic_int8(model_path, &int8_path).await?;
726 }
727 _ => {
728 anyhow::bail!("Unsupported quantization method: {}", config.method);
729 }
730 }
731
732 if let Some(pb) = progress {
733 pb.set_message("✅ INT8 model ready");
734 }
735
736 Ok(int8_path)
737 }
738
739 async fn quantize_dynamic_int8(&self, input: &Path, output: &Path) -> Result<()> {
741 let python = self.find_python()?;
743
744 let script = include_str!("../../scripts/quantize_int8.py");
746 let script_path = std::env::temp_dir().join("mecha10_quantize_int8.py");
747 fs::write(&script_path, script).await?;
748
749 let output_result = tokio::process::Command::new(&python)
751 .arg(&script_path)
752 .arg(input)
753 .arg(output)
754 .output()
755 .await?;
756
757 let _ = fs::remove_file(&script_path).await;
759
760 if !output_result.status.success() {
761 let stderr = String::from_utf8_lossy(&output_result.stderr);
762 anyhow::bail!(
763 "Quantization failed: {}\n\nTip: Install with 'pip install onnx onnxruntime'",
764 stderr
765 );
766 }
767
768 Ok(())
769 }
770
771 fn find_python(&self) -> Result<String> {
773 for candidate in &["python3", "python"] {
774 if which::which(candidate).is_ok() {
775 return Ok(candidate.to_string());
776 }
777 }
778 anyhow::bail!("Python 3 not found. Install with: brew install python3 (macOS) or apt install python3 (Linux)")
779 }
780
781 pub async fn remove(&self, name: &str) -> Result<()> {
783 let model_dir = self.models_dir.join(name);
784
785 if !model_dir.exists() {
786 anyhow::bail!("Model '{}' is not installed", name);
787 }
788
789 fs::remove_dir_all(&model_dir)
790 .await
791 .context(format!("Failed to remove model '{}'", name))?;
792
793 Ok(())
794 }
795
796 #[allow(dead_code)]
801 pub fn get_model_path(&self, name: &str) -> PathBuf {
802 self.models_dir.join(name).join("model.onnx")
803 }
804
805 #[allow(dead_code)]
810 pub async fn is_installed(&self, name: &str) -> bool {
811 let model_path = self.get_model_path(name);
812 model_path.exists()
813 }
814
815 pub async fn info(&self, name: &str) -> Result<ModelInfo> {
817 let catalog_entry = self.get_catalog_entry(name).cloned();
818 let installed = self.list_installed().await?;
819 let installed_info = installed.iter().find(|m| m.name == name).cloned();
820
821 Ok(ModelInfo {
822 name: name.to_string(),
823 catalog_entry,
824 installed_info,
825 })
826 }
827
828 #[allow(dead_code)]
830 pub async fn validate(&self, path: &Path) -> Result<bool> {
831 if !path.exists() {
833 return Ok(false);
834 }
835
836 if path.extension().and_then(|s| s.to_str()) != Some("onnx") {
837 return Ok(false);
838 }
839
840 let bytes = fs::read(path).await?;
843
844 Ok(bytes.len() > 4)
847 }
848}
849
/// Aggregated catalog and installation status for one model name.
#[derive(Debug, Clone, Serialize)]
pub struct ModelInfo {
    /// The queried model name.
    pub name: String,
    /// Catalog entry, if the name exists in the embedded catalog.
    pub catalog_entry: Option<ModelCatalogEntry>,
    /// On-disk details, if the model is installed.
    pub installed_info: Option<InstalledModel>,
}
857
impl ModelInfo {
    /// Whether the model is present on disk (installation info was found).
    #[allow(dead_code)]
    pub fn is_installed(&self) -> bool {
        self.installed_info.is_some()
    }

    /// Whether the model name exists in the embedded catalog.
    #[allow(dead_code)]
    pub fn is_in_catalog(&self) -> bool {
        self.catalog_entry.is_some()
    }
}