1use super::storage::{ModelStorage, ModelMetadata};
7use anyhow::{Context, Result};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10use std::sync::Arc;
11use tracing::{debug, info, warn};
12
13use super::recommendation::{HardwareProfile, ModelRecommender};
14use reqwest::Client;
15
/// Lifecycle state of a model tracked by the registry.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ModelStatus {
    /// Model files are present in local storage.
    Installed,
    /// A download for this model is currently in progress.
    Downloading,
    /// Known to the catalog but not present locally.
    Available,
    /// Something went wrong; payload is a human-readable message.
    Error(String),
}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ModelPricing {
33 pub prompt: String,
35 pub completion: String,
37}
38
39impl ModelPricing {
40 pub fn is_free(&self) -> bool {
41 (self.prompt == "0" || self.prompt.is_empty())
42 && (self.completion == "0" || self.completion.is_empty())
43 }
44}
45
/// A single catalog entry: everything the UI and downloader need to know
/// about one model, whether installed locally or served through an API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Registry key: HF repo id ("author/name") or "openrouter:<api id>".
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    pub description: Option<String>,
    pub author: Option<String>,
    /// Current lifecycle state.
    pub status: ModelStatus,
    /// Size on disk in bytes; 0 for API-hosted models.
    pub size_bytes: u64,
    /// Artifact format: "gguf", "ggml", or "api" for cloud models.
    pub format: String,
    /// Origin of the model: "huggingface", "openrouter", …
    pub download_source: Option<String>,
    /// Concrete artifact filename within the repo (None for API models).
    #[serde(default)]
    pub filename: Option<String>,
    pub installed_version: Option<String>,
    pub last_updated: Option<chrono::DateTime<chrono::Utc>>,
    /// Free-form filter tags ("offline", "free", "provider:…", "context:…").
    pub tags: Vec<String>,
    /// Hardware compatibility score, when computed (local models only).
    pub compatibility_score: Option<f32>,
    /// Parameter-count string scraped from the name, e.g. "7B".
    #[serde(default)]
    pub parameters: Option<String>,
    #[serde(default)]
    pub context_length: Option<u64>,
    #[serde(default)]
    pub provider: Option<String>,
    /// Shard bookkeeping for multi-part GGUF downloads.
    #[serde(default)]
    pub total_shards: Option<u32>,
    #[serde(default)]
    pub shard_filenames: Vec<String>,
    /// Download counter from the source hub (0 when unknown).
    #[serde(default)]
    pub downloads: u64,
    /// True when the upstream repo requires accepting terms before download.
    #[serde(default)]
    pub is_gated: bool,
    /// Per-token pricing for API models (None for local models).
    #[serde(default)]
    pub pricing: Option<ModelPricing>,
}
89
/// In-memory catalog of all known models, persisted to `registry.json`
/// in the storage registry directory.
pub struct ModelRegistry {
    // Disk layout / metadata access for installed models.
    storage: Arc<ModelStorage>,
    // All known models, keyed by registry id.
    models: HashMap<String, ModelInfo>,
    // Catalog endpoints this registry knows how to query.
    known_sources: Vec<ModelSource>,
}
96
/// A remote catalog a registry can be refreshed from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelSource {
    /// Display name, e.g. "Hugging Face".
    pub name: String,
    /// Base URL of the source's website.
    pub url: String,
    /// Which API protocol the source speaks.
    pub api_type: SourceType,
}
104
/// Supported catalog API protocols.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SourceType {
    /// Hugging Face hub models API (local GGUF/GGML downloads).
    HuggingFace,
    /// OpenRouter cloud inference API.
    OpenRouter,
}
111
/// Top-level envelope of the OpenRouter `/api/v1/models` response.
#[derive(Debug, Deserialize)]
struct OpenRouterModelsResponse {
    // The API wraps the model list in a "data" field.
    data: Vec<OpenRouterModel>,
}
116
117#[derive(Debug, Deserialize, Default)]
120struct OpenRouterPricing {
121 #[serde(default)]
123 prompt: String,
124 #[serde(default)]
126 completion: String,
127}
128
129impl OpenRouterPricing {
130 fn is_free(&self) -> bool {
132 (self.prompt == "0" || self.prompt.is_empty())
133 && (self.completion == "0" || self.completion.is_empty())
134 }
135}
136
/// One model record from the OpenRouter `/models` payload. Only the fields
/// this module consumes are declared; serde ignores the rest.
#[derive(Debug, Deserialize)]
struct OpenRouterModel {
    /// API id in the form "<provider>/<model>[:variant]".
    id: String,
    name: Option<String>,
    description: Option<String>,
    context_length: Option<u64>,
    // Deserialized but not read anywhere in this module yet.
    #[serde(default)]
    architecture: Option<OpenRouterArchitecture>,
    #[serde(default)]
    pricing: Option<OpenRouterPricing>,
}
149
/// Architecture metadata from OpenRouter. Parsed for completeness; none of
/// these fields are currently read elsewhere in this module.
#[derive(Debug, Deserialize, Default)]
struct OpenRouterArchitecture {
    #[serde(default)]
    modality: Option<String>,
    #[serde(default)]
    tokenizer: Option<String>,
    #[serde(default)]
    instruct_type: Option<String>,
}
160
/// One repository record from the Hugging Face `/api/models` listing.
#[derive(Debug, Deserialize)]
struct HuggingFaceModel {
    /// Repo id, e.g. "TheBloke/Llama-2-7B-GGUF".
    id: String,
    /// Alternate id field; preferred over `id` when present.
    #[serde(rename = "modelId")]
    model_id: Option<String>,
    author: Option<String>,
    downloads: Option<u64>,
    /// Gating flag. Loosely typed because the API may return either a bool
    /// or another JSON value here; anything other than `false`/absent is
    /// treated as gated by the catalog refresh.
    #[serde(default)]
    gated: Option<serde_json::Value>,
    #[serde(default)]
    tags: Vec<String>,
    /// Files in the repo; only populated when `full=true` is requested.
    #[serde(default)]
    siblings: Vec<HuggingFaceSibling>,
}
177
/// A single file entry inside a Hugging Face repo listing.
#[derive(Debug, Deserialize)]
struct HuggingFaceSibling {
    /// Path of the file relative to the repo root.
    rfilename: String,
    /// File size in bytes, when the API provides it.
    #[serde(default)]
    size: Option<u64>,
}
184
185impl ModelRegistry {
186 pub fn new(storage: Arc<ModelStorage>) -> Result<Self> {
187 let mut registry = Self {
188 storage,
189 models: HashMap::new(),
190 known_sources: vec![
191 ModelSource {
192 name: "Hugging Face".to_string(),
193 url: "https://huggingface.co".to_string(),
194 api_type: SourceType::HuggingFace,
195 },
196 ModelSource {
197 name: "OpenRouter".to_string(),
198 url: "https://openrouter.ai".to_string(),
199 api_type: SourceType::OpenRouter,
200 },
201 ],
202 };
203
204 registry.load_registry()?;
206
207 registry.populate_default_catalog();
209
210 Ok(registry)
211 }
212
213 pub async fn refresh_openrouter_catalog_from_api(
217 &mut self,
218 api_key: &str,
219 ) -> Result<()> {
220 let client = Client::new();
221 let resp = client
222 .get("https://openrouter.ai/api/v1/models")
223 .header("Authorization", format!("Bearer {}", api_key))
224 .send()
225 .await
226 .context("Failed to call OpenRouter /models API")?;
227
228 let resp = resp.error_for_status().context("OpenRouter /models returned error status")?;
229 let body: OpenRouterModelsResponse = resp
230 .json()
231 .await
232 .context("Failed to parse OpenRouter models response")?;
233
234 let mut openrouter_ids: HashSet<String> = HashSet::new();
236
237 for m in body.data.into_iter() {
238 let plain_id = m.id.clone();
239
240 let registry_id = format!("openrouter:{}", plain_id);
241 openrouter_ids.insert(registry_id.clone());
242
243 let is_free = m.pricing.as_ref().map_or(true, |p| p.is_free())
245 || plain_id.ends_with(":free");
246
247 let provider = plain_id
249 .split('/')
250 .next()
251 .unwrap_or("")
252 .to_lowercase();
253
254 let mut tags = vec![
255 "api".to_string(),
256 "online".to_string(),
257 "cloud".to_string(),
258 ];
259 if is_free {
260 tags.push("free".to_string());
261 } else {
262 tags.push("paid".to_string());
263 }
264 if !provider.is_empty() {
265 tags.push(format!("provider:{}", provider));
266 }
267
268 if let Some(ctx) = m.context_length {
269 if ctx >= 128_000 {
270 tags.push("context:xl".to_string());
271 } else if ctx >= 32_000 {
272 tags.push("context:large".to_string());
273 } else if ctx >= 8_000 {
274 tags.push("context:medium".to_string());
275 } else {
276 tags.push("context:small".to_string());
277 }
278 }
279
280 let parameters = {
282 let name_str = m.name.as_deref().unwrap_or(&plain_id);
283 let re = regex::Regex::new(r"(\d+(?:\.\d+)?(?:x\d+)?[BMK])").ok();
285 re.and_then(|r| r.find(name_str).map(|m| m.as_str().to_string()))
286 };
287
288 let provider_display = if !provider.is_empty() {
290 let mut chars = provider.chars();
291 match chars.next() {
292 Some(c) => Some(c.to_uppercase().collect::<String>() + chars.as_str()),
293 None => None,
294 }
295 } else {
296 None
297 };
298
299 let pricing = m.pricing.as_ref().map(|p| ModelPricing {
300 prompt: p.prompt.clone(),
301 completion: p.completion.clone(),
302 });
303
304 let model_info = ModelInfo {
305 id: registry_id.clone(),
306 name: m.name.clone().unwrap_or_else(|| plain_id.clone()),
307 description: m.description.clone(),
308 author: provider_display.clone(),
309 status: ModelStatus::Available,
310 size_bytes: 0,
311 format: "api".to_string(),
312 download_source: Some("openrouter".to_string()),
313 filename: None,
314 installed_version: None,
315 last_updated: None,
316 tags,
317 compatibility_score: None,
318 parameters,
319 context_length: m.context_length,
320 provider: provider_display,
321 total_shards: None,
322 shard_filenames: vec![],
323 downloads: 0,
324 is_gated: false,
325 pricing,
326 };
327
328 self.models
330 .entry(registry_id.clone())
331 .and_modify(|existing| {
332 existing.name = model_info.name.clone();
333 existing.description = model_info.description.clone();
334 existing.status = model_info.status.clone();
335 existing.format = model_info.format.clone();
336 existing.download_source = model_info.download_source.clone();
337 existing.tags = model_info.tags.clone();
338 existing.pricing = model_info.pricing.clone();
339 })
340 .or_insert(model_info);
341 }
342
343 self.models.retain(|id, model| {
345 if model.download_source.as_deref() == Some("openrouter") {
346 openrouter_ids.contains(id)
347 } else {
348 true
349 }
350 });
351
352 info!("Refreshed OpenRouter catalog, now tracking {} models", openrouter_ids.len());
353
354 Ok(())
355 }
356
    /// Refreshes the Hugging Face portion of the catalog: queries the public
    /// models API for the most-downloaded GGUF repositories and upserts one
    /// registry entry per repo — either the full shard set when the repo
    /// ships a sharded model, or a single preferred quantization otherwise.
    ///
    /// `limit` caps how many repositories are requested from the API.
    ///
    /// # Errors
    /// Fails when the HTTP call, status check, or JSON decoding fails.
    pub async fn refresh_huggingface_catalog_from_api(
        &mut self,
        limit: usize,
    ) -> Result<()> {
        let client = Client::new();

        // `full=true` makes the API include the `siblings` file listing,
        // which is needed below to pick a concrete .gguf artifact.
        let url = format!(
            "https://huggingface.co/api/models?filter=gguf&sort=downloads&direction=-1&limit={}&full=true",
            limit
        );

        let resp = client
            .get(&url)
            .header("User-Agent", "Aud.io-Desktop/1.0")
            .send()
            .await
            .context("Failed to call Hugging Face models API")?;

        let resp = resp
            .error_for_status()
            .context("Hugging Face API returned error status")?;

        let models: Vec<HuggingFaceModel> = resp
            .json()
            .await
            .context("Failed to parse Hugging Face models response")?;

        // Ids seen in this refresh (currently only used for the final log line).
        let mut hf_ids: HashSet<String> = HashSet::new();

        for m in models.into_iter() {
            // Prefer the explicit modelId field, falling back to `id`.
            let repo_id = m.model_id.as_ref().unwrap_or(&m.id).clone();

            // Anything other than an explicit `false` (or absent) counts as gated.
            let is_gated = match &m.gated {
                Some(serde_json::Value::Bool(false)) | None => false,
                _ => true,
            };

            // Only GGUF/GGML artifacts are usable by the local runtime.
            let gguf_files: Vec<&HuggingFaceSibling> = m
                .siblings
                .iter()
                .filter(|s| {
                    s.rfilename.ends_with(".gguf") || s.rfilename.ends_with(".ggml")
                })
                .collect();

            if gguf_files.is_empty() {
                continue;
            }

            // First pass: if any file follows the "-NNNNN-of-MMMMM" shard
            // naming scheme, build a single entry covering the whole shard set.
            let mut sharded_model_info = None;
            for file in &gguf_files {
                if let Some(total_shards) = self.detect_shard_pattern_internal(&file.rfilename) {
                    let all_shards = self.collect_shards_internal(&gguf_files, total_shards);

                    // Total download size is the sum over all shards (0 when unknown).
                    let total_size = all_shards.iter()
                        .map(|s| s.size.unwrap_or(0))
                        .sum();

                    let registry_id = repo_id.clone();
                    hf_ids.insert(registry_id.clone());

                    let format = if file.rfilename.ends_with(".gguf") {
                        "gguf"
                    } else {
                        "ggml"
                    }
                    .to_string();

                    // Keep up to five non-format upstream tags, then add ours.
                    let mut tags: Vec<String> = m
                        .tags
                        .iter()
                        .filter(|t| !t.is_empty() && *t != "gguf" && *t != "ggml")
                        .take(5)
                        .cloned()
                        .collect();
                    tags.push("offline".to_string());
                    tags.push(format.clone());
                    tags.push("sharded".to_string());

                    // Display name: repo name with the GGUF suffix stripped.
                    let name = repo_id
                        .split('/')
                        .last()
                        .unwrap_or(&repo_id)
                        .replace("-GGUF", "")
                        .replace("-gguf", "");

                    // Best-effort parameter count (e.g. "7B", "8x7B") from the name.
                    let parameters = {
                        let name_str = &name;
                        let re = regex::Regex::new(r"(\d+(?:\.\d+)?(?:x\d+)?[BMK])").ok();
                        re.and_then(|r| r.find(name_str).map(|m| m.as_str().to_string()))
                    };

                    sharded_model_info = Some(ModelInfo {
                        id: registry_id.clone(),
                        name,
                        description: Some(format!("Sharded GGUF model from {} ({} parts)", repo_id, total_shards)),
                        author: m.author.clone(),
                        status: ModelStatus::Available,
                        size_bytes: total_size,
                        format,
                        download_source: Some("huggingface".to_string()),
                        // Filename of the shard that triggered detection.
                        filename: Some(file.rfilename.clone()),
                        installed_version: None,
                        last_updated: None,
                        tags,
                        compatibility_score: None,
                        parameters,
                        context_length: None,
                        provider: None,
                        total_shards: Some(total_shards),
                        shard_filenames: all_shards.iter().map(|s| s.rfilename.clone()).collect(),
                        downloads: m.downloads.unwrap_or(0),
                        is_gated,
                        pricing: None,
                    });
                    // One sharded entry per repo is enough.
                    break;
                }
            }

            let model_info = if let Some(sharded_info) = sharded_model_info {
                sharded_info
            } else {
                // Not sharded: pick one quantization to offer. The chain
                // encodes preference order — balanced 4/5/6-bit variants
                // first, then progressively more aggressive fallbacks, and
                // finally whatever file comes first in the listing.
                let preferred_file = gguf_files
                    .iter()
                    .find(|f| f.rfilename.contains("Q4_K_M"))
                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q5_K_M")))
                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q6_K")))
                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q8_0")))
                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("IQ3_XXS")))
                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("IQ3_S")))
                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("IQ4_NL")))
                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("IQ4_XS")))
                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q3_K_S")))
                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q3_K_M")))
                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q3_K_L")))
                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q5_K_S")))
                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q5_K_L")))
                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q2_K")))
                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q2_K_S")))
                    .or_else(|| gguf_files.first())
                    .copied();

                let Some(file) = preferred_file else {
                    continue;
                };

                let registry_id = repo_id.clone();
                hf_ids.insert(registry_id.clone());

                let format = if file.rfilename.ends_with(".gguf") {
                    "gguf"
                } else {
                    "ggml"
                }
                .to_string();

                // Keep up to five non-format upstream tags, then add ours.
                let mut tags: Vec<String> = m
                    .tags
                    .iter()
                    .filter(|t| !t.is_empty() && *t != "gguf" && *t != "ggml")
                    .take(5)
                    .cloned()
                    .collect();
                tags.push("offline".to_string());
                tags.push(format.clone());

                // Display name: repo name with the GGUF suffix stripped.
                let name = repo_id
                    .split('/')
                    .last()
                    .unwrap_or(&repo_id)
                    .replace("-GGUF", "")
                    .replace("-gguf", "");

                // Best-effort parameter count from the name.
                let parameters = {
                    let name_str = &name;
                    let re = regex::Regex::new(r"(\d+(?:\.\d+)?(?:x\d+)?[BMK])").ok();
                    re.and_then(|r| r.find(name_str).map(|m| m.as_str().to_string()))
                };

                ModelInfo {
                    id: registry_id.clone(),
                    name,
                    description: Some(format!("GGUF model from {}", repo_id)),
                    author: m.author.clone(),
                    status: ModelStatus::Available,
                    size_bytes: file.size.unwrap_or(0),
                    format,
                    download_source: Some("huggingface".to_string()),
                    filename: Some(file.rfilename.clone()),
                    installed_version: None,
                    last_updated: None,
                    tags,
                    compatibility_score: None,
                    parameters,
                    context_length: None,
                    provider: None,
                    total_shards: None,
                    shard_filenames: vec![],
                    downloads: m.downloads.unwrap_or(0),
                    is_gated,
                    pricing: None,
                }
            };

            // Upsert, but never clobber an already-installed model's state.
            self.models
                .entry(model_info.id.clone())
                .and_modify(|existing| {
                    if existing.status != ModelStatus::Installed {
                        existing.name = model_info.name.clone();
                        existing.description = model_info.description.clone();
                        existing.author = model_info.author.clone();
                        existing.size_bytes = model_info.size_bytes;
                        existing.format = model_info.format.clone();
                        existing.download_source = model_info.download_source.clone();
                        existing.filename = model_info.filename.clone();
                        existing.tags = model_info.tags.clone();
                        existing.total_shards = model_info.total_shards;
                        existing.shard_filenames = model_info.shard_filenames.clone();
                        existing.is_gated = model_info.is_gated;
                    }
                })
                .or_insert(model_info);
        }

        info!(
            "Refreshed Hugging Face catalog, now tracking {} GGUF/GGML models",
            hf_ids.len()
        );

        Ok(())
    }
610
611 pub fn update_compatibility_scores(
615 &mut self,
616 recommender: &ModelRecommender,
617 hardware: &HardwareProfile,
618 ) {
619 for model in self.models.values_mut() {
620 let is_offline_format = model.format.eq_ignore_ascii_case("gguf")
623 || model.format.eq_ignore_ascii_case("ggml");
624 let is_api_model = model.download_source.as_deref() == Some("openrouter");
625
626 if is_offline_format && !is_api_model {
627 let score = recommender.score_model_compatibility(model, hardware);
628 model.compatibility_score = Some(score);
629 }
630 }
631 }
632
633 fn load_registry(&mut self) -> Result<()> {
635 let registry_path = self.storage.location.registry_dir.join("registry.json");
636 if registry_path.exists() {
637 match std::fs::read_to_string(®istry_path) {
638 Ok(content) if !content.trim().is_empty() => {
639 match serde_json::from_str::<HashMap<String, ModelInfo>>(&content) {
640 Ok(saved_models) => {
641 self.models = saved_models;
642 info!("Loaded {} models from registry", self.models.len());
643 }
644 Err(e) => {
645 warn!("Registry file corrupted, starting fresh: {}", e);
646 }
647 }
648 }
649 Ok(_) => {
650 debug!("Registry file is empty, starting fresh");
651 }
652 Err(e) => {
653 warn!("Failed to read registry file: {}", e);
654 }
655 }
656 }
657 Ok(())
658 }
659
    /// Scans local storage and registers every model that has readable
    /// persisted metadata as Installed, overwriting any existing entry with
    /// the same id. Models without a metadata file are silently skipped.
    pub async fn scan_storage(&mut self) -> Result<()> {
        let model_ids = self.storage.list_models()?;

        for model_id in model_ids {
            if let Some(metadata) = self.load_model_metadata(&model_id).await? {
                // Rebuild a registry entry from the persisted metadata;
                // fields metadata does not carry are reset to defaults.
                let model_info = ModelInfo {
                    id: model_id.clone(),
                    name: metadata.name,
                    description: metadata.description,
                    author: metadata.author,
                    status: ModelStatus::Installed,
                    size_bytes: metadata.size_bytes,
                    format: metadata.format,
                    download_source: Some(metadata.download_source),
                    filename: None,
                    installed_version: None,
                    last_updated: Some(metadata.download_date),
                    tags: metadata.tags,
                    compatibility_score: None,
                    parameters: None,
                    context_length: None,
                    provider: None,
                    total_shards: None,
                    shard_filenames: vec![],
                    downloads: 0,
                    is_gated: false,
                    pricing: None,
                };

                self.models.insert(model_id, model_info);
            }
        }

        // NOTE(review): this logs the total registry size, not the number of
        // models discovered by this particular scan.
        info!("Scanned storage and found {} models", self.models.len());
        Ok(())
    }
697
698 async fn load_model_metadata(&self, model_id: &str) -> Result<Option<ModelMetadata>> {
700 let metadata_path = self.storage.metadata_path(model_id);
701
702 if metadata_path.exists() {
703 let content = tokio::fs::read_to_string(&metadata_path).await?;
704 let metadata: ModelMetadata = serde_json::from_str(&content)?;
705 Ok(Some(metadata))
706 } else {
707 Ok(None)
708 }
709 }
710
711 pub async fn update_model_status_from_storage(&mut self, model_id: &str) -> Result<()> {
713 if let Some(model_info) = self.models.get_mut(model_id) {
714 let model_exists = self.storage.model_exists(model_id);
715
716 if model_exists {
717 model_info.status = ModelStatus::Installed;
718 } else {
719 if matches!(model_info.status, ModelStatus::Installed) {
721 model_info.status = ModelStatus::Available;
722 }
723 }
724 }
725
726 Ok(())
727 }
728
729 pub async fn update_all_model_statuses_from_storage(&mut self) -> Result<()> {
731 let model_ids: Vec<String> = self.models.keys().cloned().collect();
732
733 for model_id in model_ids {
734 self.update_model_status_from_storage(&model_id).await?;
735 }
736
737 Ok(())
738 }
739
    /// Resolves the on-disk path of an installed model's weight file.
    ///
    /// Returns `None` for unknown or not-installed models. When the registry
    /// recorded a concrete filename it is used directly; otherwise the model
    /// directory is scanned for the first file with a known weight extension.
    pub fn get_installed_model_path(&self, model_id: &str) -> Option<std::path::PathBuf> {
        let model_info = self.models.get(model_id)?;
        if model_info.status != ModelStatus::Installed {
            return None;
        }

        if let Some(filename) = &model_info.filename {
            return Some(self.storage.model_path(model_id, filename));
        }

        // No filename recorded: derive the model directory by building a
        // placeholder path and taking its parent. NOTE(review): this assumes
        // ModelStorage::model_path returns `<model_dir>/<filename>` — confirm.
        let temp_path = self.storage.model_path(model_id, "dummy");
        let model_dir = match temp_path.parent() {
            Some(dir) => dir.to_path_buf(),
            None => return None,
        };
        if !model_dir.exists() {
            return None;
        }

        // Fall back to the first regular file with a recognized weight extension.
        if let Ok(entries) = std::fs::read_dir(&model_dir) {
            for entry in entries.flatten() {
                if let Ok(file_type) = entry.file_type() {
                    if file_type.is_file() {
                        let path = entry.path();
                        let ext = path.extension().unwrap_or_default().to_string_lossy().to_lowercase();
                        if matches!(ext.as_str(), "gguf" | "bin" | "ggml" | "onnx" | "trt" | "engine" | "safetensors" | "mlmodel") {
                            return Some(path);
                        }
                    }
                }
            }
        }

        None
    }
778
779 pub async fn get_model_metadata(&self, model_id: &str) -> Option<ModelMetadata> {
781 match self.load_model_metadata(model_id).await {
782 Ok(Some(metadata)) => Some(metadata),
783 _ => None,
784 }
785 }
786
    /// Inserts `model_info` into the registry, replacing any entry with the
    /// same id.
    pub fn add_model(&mut self, model_info: ModelInfo) {
        self.models.insert(model_info.id.clone(), model_info);
    }
791
    /// Borrows the model with the given id, if tracked.
    pub fn get_model(&self, model_id: &str) -> Option<&ModelInfo> {
        self.models.get(model_id)
    }
796
    /// Mutably borrows the model with the given id, if tracked.
    pub fn get_model_mut(&mut self, model_id: &str) -> Option<&mut ModelInfo> {
        self.models.get_mut(model_id)
    }
801
    /// All tracked models, in arbitrary (HashMap) order.
    pub fn list_models(&self) -> Vec<&ModelInfo> {
        self.models.values().collect()
    }
806
    /// Models whose status equals `status` exactly (for `Error` this includes
    /// the message payload in the comparison).
    pub fn list_models_by_status(&self, status: ModelStatus) -> Vec<&ModelInfo> {
        self.models.values()
            .filter(|model| model.status == status)
            .collect()
    }
813
814 pub fn search_models(&self, query: &str) -> Vec<&ModelInfo> {
816 let query_lower = query.to_lowercase();
817 self.models.values()
818 .filter(|model| {
819 model.name.to_lowercase().contains(&query_lower) ||
820 model.description.as_ref().map_or(false, |desc| desc.to_lowercase().contains(&query_lower)) ||
821 model.tags.iter().any(|tag| tag.to_lowercase().contains(&query_lower))
822 })
823 .collect()
824 }
825
826 pub fn get_recommended_models(&self, max_results: usize) -> Vec<&ModelInfo> {
828 let mut models: Vec<_> = self.models.values().collect();
829 models.sort_by(|a, b| {
830 b.compatibility_score.unwrap_or(0.0)
831 .partial_cmp(&a.compatibility_score.unwrap_or(0.0))
832 .unwrap_or(std::cmp::Ordering::Equal)
833 });
834 models.truncate(max_results);
835 models
836 }
837
    /// Sets the status of `model_id`; a no-op when the id is unknown.
    pub fn update_model_status(&mut self, model_id: &str, status: ModelStatus) {
        if let Some(model) = self.models.get_mut(model_id) {
            model.status = status;
        }
    }
844
    /// Removes `model_id` from the registry; returns whether it was present.
    pub fn remove_model(&mut self, model_id: &str) -> bool {
        self.models.remove(model_id).is_some()
    }
849
850 pub fn get_statistics(&self) -> RegistryStats {
852 let mut stats = RegistryStats::default();
853
854 for model in self.models.values() {
855 match model.status {
856 ModelStatus::Installed => stats.installed_count += 1,
857 ModelStatus::Downloading => stats.downloading_count += 1,
858 ModelStatus::Available => stats.available_count += 1,
859 ModelStatus::Error(_) => stats.error_count += 1,
860 }
861 stats.total_size_bytes += model.size_bytes;
862 }
863
864 stats
865 }
866
867 pub fn get_models_by_category(&self, category: &str) -> Vec<&ModelInfo> {
869 self.models.values()
870 .filter(|model| {
871 model.tags.iter().any(|tag|
872 tag.to_lowercase().contains(&category.to_lowercase())
873 )
874 })
875 .collect()
876 }
877
878 pub fn get_trending_models(&self, limit: usize) -> Vec<&ModelInfo> {
880 let mut models: Vec<&ModelInfo> = self.models.values()
881 .filter(|model| {
882 model.tags.iter().any(|tag|
884 tag == "popular" || tag == "trending" || tag == "featured"
885 )
886 })
887 .collect();
888
889 models.sort_by(|a, b| b.size_bytes.cmp(&a.size_bytes));
891 models.truncate(limit);
892 models
893 }
894
895 pub fn get_models_by_task(&self, task: &str) -> Vec<&ModelInfo> {
897 self.models.values()
898 .filter(|model| {
899 model.name.to_lowercase().contains(&task.to_lowercase()) ||
901 model.description.as_ref().map_or(false, |desc|
902 desc.to_lowercase().contains(&task.to_lowercase())) ||
903 model.tags.iter().any(|tag|
904 tag.to_lowercase().contains(&task.to_lowercase()))
905 })
906 .collect()
907 }
908
    /// Serializes the full model map to `registry.json` (pretty-printed) in
    /// the storage registry directory, overwriting any previous file.
    ///
    /// # Errors
    /// Fails when serialization or the file write fails.
    pub async fn save_registry(&self) -> Result<()> {
        let registry_path = self.storage.location.registry_dir.join("registry.json");
        let content = serde_json::to_string_pretty(&self.models)
            .context("Failed to serialize registry")?;
        tokio::fs::write(&registry_path, content).await
            .context("Failed to write registry file")?;
        debug!("Saved {} models to registry", self.models.len());
        Ok(())
    }
919
920 pub fn populate_default_catalog(&mut self) {
924 let catalog = Self::get_default_catalog();
925 let catalog_ids: std::collections::HashSet<String> = catalog.iter().map(|m| m.id.clone()).collect();
926
927 let stale_ids: Vec<String> = self.models.iter()
929 .filter(|(id, m)| {
930 m.download_source.as_deref() == Some("ollama")
932 })
934 .map(|(id, _)| id.clone())
935 .collect();
936 for id in &stale_ids {
937 self.models.remove(id);
938 }
939 if !stale_ids.is_empty() {
940 info!("Removed {} stale/obsolete models from registry", stale_ids.len());
941 }
942
943 let mut added = 0;
944 for model in catalog {
945 if !self.models.contains_key(&model.id) {
946 self.models.insert(model.id.clone(), model);
947 added += 1;
948 }
949 }
950 if added > 0 {
951 info!("Populated catalog with {} new available models", added);
952 }
953 }
954
    /// Built-in seed catalog. Currently empty; entries added here appear as
    /// Available models on first run.
    fn get_default_catalog() -> Vec<ModelInfo> {
        vec![]
    }
960
    /// Best-effort refresh of the OpenRouter catalog without an API key;
    /// failures are logged and swallowed so startup never blocks on network.
    pub async fn populate_default_openrouter_models(&mut self) {
        if let Err(e) = self.fetch_public_openrouter_models().await {
            warn!("Failed to fetch public OpenRouter models: {}", e);
        }
    }
969
970 async fn fetch_public_openrouter_models(&mut self) -> Result<()> {
973 let client = Client::new();
974 let resp = client
975 .get("https://openrouter.ai/api/v1/models")
976 .send()
977 .await
978 .context("Failed to call OpenRouter public /models API")?;
979
980 let resp = resp.error_for_status().context("OpenRouter public /models returned error status")?;
981 let body: OpenRouterModelsResponse = resp
982 .json()
983 .await
984 .context("Failed to parse OpenRouter public models response")?;
985
986 let mut added = 0;
987
988 for m in body.data.into_iter() {
989 let plain_id = m.id.clone();
990
991 if self.is_invalid_openrouter_model(&plain_id) {
993 debug!("Skipping invalid model: {}", plain_id);
994 continue;
995 }
996
997 let registry_id = format!("openrouter:{}", plain_id);
998
999 let provider = plain_id
1001 .split('/')
1002 .next()
1003 .unwrap_or("")
1004 .to_lowercase();
1005
1006 let mut tags = vec![
1007 "api".to_string(),
1008 "online".to_string(),
1009 "cloud".to_string(),
1010 ];
1011 if !provider.is_empty() {
1012 tags.push(format!("provider:{}", provider));
1013 }
1014
1015 if let Some(ctx) = m.context_length {
1016 if ctx >= 128_000 {
1017 tags.push("context:xl".to_string());
1018 } else if ctx >= 32_000 {
1019 tags.push("context:large".to_string());
1020 } else if ctx >= 8_000 {
1021 tags.push("context:medium".to_string());
1022 } else {
1023 tags.push("context:small".to_string());
1024 }
1025 }
1026
1027 let parameters = {
1029 let name_str = m.name.as_deref().unwrap_or(&plain_id);
1030 let re = regex::Regex::new(r"(\d+(?:\.\d+)?(?:x\d+)?[BMK])").ok();
1032 re.and_then(|r| r.find(name_str).map(|m| m.as_str().to_string()))
1033 };
1034
1035 let provider_display = if !provider.is_empty() {
1037 let mut chars = provider.chars();
1038 match chars.next() {
1039 Some(c) => Some(c.to_uppercase().collect::<String>() + chars.as_str()),
1040 None => None,
1041 }
1042 } else {
1043 None
1044 };
1045
1046 let is_free = m.pricing.as_ref().map_or(true, |p| p.is_free())
1047 || plain_id.ends_with(":free");
1048 if is_free {
1049 tags.push("free".to_string());
1050 } else {
1051 tags.push("paid".to_string());
1052 }
1053
1054 let pricing = m.pricing.as_ref().map(|p| ModelPricing {
1055 prompt: p.prompt.clone(),
1056 completion: p.completion.clone(),
1057 });
1058
1059 let model_info = ModelInfo {
1060 id: registry_id.clone(),
1061 name: m.name.clone().unwrap_or_else(|| plain_id.clone()),
1062 description: m.description.clone(),
1063 author: provider_display.clone(),
1064 status: ModelStatus::Available,
1065 size_bytes: 0,
1066 format: "api".to_string(),
1067 download_source: Some("openrouter".to_string()),
1068 filename: None,
1069 installed_version: None,
1070 last_updated: Some(chrono::Utc::now()),
1071 tags,
1072 compatibility_score: None,
1073 parameters,
1074 context_length: m.context_length,
1075 provider: provider_display,
1076 total_shards: None,
1077 shard_filenames: vec![],
1078 downloads: 0,
1079 is_gated: false,
1080 pricing,
1081 };
1082
1083 self.models.insert(registry_id, model_info);
1085 added += 1;
1086 }
1087
1088 info!("Fetched {} public OpenRouter models from API", added);
1089 Ok(())
1090 }
1091
1092 fn detect_shard_pattern_internal(&self, filename: &str) -> Option<u32> {
1094 let re = regex::Regex::new(r".*-(\d{5})-of-(\d{5})\.[^.]+$").ok()?;
1096 if let Some(caps) = re.captures(filename) {
1097 if let Some(total_str) = caps.get(2) {
1098 if let Ok(total) = total_str.as_str().parse::<u32>() {
1099 return Some(total);
1100 }
1101 }
1102 }
1103 None
1104 }
1105
    /// Given the repo's GGUF file list and an expected shard count, returns
    /// the shard files in order (1-of-N .. N-of-N).
    ///
    /// Shards whose expected filename is missing from the listing are simply
    /// skipped, so the result can contain fewer than `total_shards` entries.
    fn collect_shards_internal<'a>(&self, gguf_files: &[&'a HuggingFaceSibling], total_shards: u32) -> Vec<&'a HuggingFaceSibling> {
        let mut shards = Vec::new();

        // Use the first sharded filename as the template for prefix/suffix.
        if let Some(first_file) = gguf_files.iter().find(|f| self.detect_shard_pattern_internal(&f.rfilename).is_some()) {
            if let Some(caps) = regex::Regex::new(r"(.*-)(\d{5})(-of-\d{5}\.[^.]+)$")
                .ok()
                .and_then(|re| re.captures(&first_file.rfilename)) {

                let prefix = caps[1].to_string();
                let suffix = caps[3].to_string();
                // Reconstruct each expected shard name ("...-00001-of-0000N…")
                // and look it up in the listing.
                for i in 1..=total_shards {
                    let expected_filename = format!("{}{:05}{}", prefix, i, suffix);
                    if let Some(file) = gguf_files.iter().find(|f| f.rfilename == expected_filename) {
                        shards.push(*file);
                    }
                }
            }
        }

        shards
    }
1132
1133 fn is_invalid_openrouter_model(&self, model_id: &str) -> bool {
1135 model_id == "google/gemini-pro" ||
1137 model_id == "google/palm-2-chat-bison" ||
1138 model_id.starts_with("google/palm") ||
1139 model_id.starts_with("google/gemini-pro") ||
1140 false
1142 }
1143}
1144
/// Aggregate counts over the registry, grouped by status, plus the combined
/// size of all tracked models.
#[derive(Debug, Default)]
pub struct RegistryStats {
    /// Models with status `Installed`.
    pub installed_count: usize,
    /// Models currently downloading.
    pub downloading_count: usize,
    /// Models available but not installed.
    pub available_count: usize,
    /// Models in an error state.
    pub error_count: usize,
    /// Sum of `size_bytes` across all tracked models.
    pub total_size_bytes: u64,
}

impl RegistryStats {
    /// Total number of models across every status bucket.
    pub fn total_models(&self) -> usize {
        [
            self.installed_count,
            self.downloading_count,
            self.available_count,
            self.error_count,
        ]
        .iter()
        .sum()
    }
}
1160
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A freshly created registry over empty storage starts with no models
    /// (the built-in default catalog is currently empty).
    #[tokio::test]
    async fn test_registry_creation() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let storage = Arc::new(ModelStorage {
            location: super::super::storage::StorageLocation {
                app_data_dir: temp_dir.path().to_path_buf(),
                models_dir: temp_dir.path().join("models"),
                registry_dir: temp_dir.path().join("registry"),
            },
        });

        let registry = ModelRegistry::new(storage)?;
        assert_eq!(registry.models.len(), 0);

        Ok(())
    }

    /// `add_model` stores the entry and `get_model` retrieves it by id.
    #[tokio::test]
    async fn test_model_addition_and_lookup() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let storage = Arc::new(ModelStorage {
            location: super::super::storage::StorageLocation {
                app_data_dir: temp_dir.path().to_path_buf(),
                models_dir: temp_dir.path().join("models"),
                registry_dir: temp_dir.path().join("registry"),
            },
        });

        let mut registry = ModelRegistry::new(storage)?;

        // Minimal Available entry exercising the round trip.
        let model_info = ModelInfo {
            id: "test-model".to_string(),
            name: "Test Model".to_string(),
            description: Some("A test model".to_string()),
            author: Some("Test Author".to_string()),
            status: ModelStatus::Available,
            size_bytes: 1024,
            format: "gguf".to_string(),
            download_source: Some("huggingface".to_string()),
            filename: None,
            installed_version: None,
            last_updated: None,
            tags: vec!["test".to_string()],
            compatibility_score: Some(0.8),
            parameters: None,
            context_length: None,
            provider: None,
            total_shards: None,
            shard_filenames: vec![],
            downloads: 0,
            is_gated: false,
            pricing: None,
        };

        registry.add_model(model_info);
        assert_eq!(registry.models.len(), 1);

        let retrieved = registry.get_model("test-model");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name, "Test Model");

        Ok(())
    }
}