Skip to main content

offline_intelligence/model_management/
registry.rs

1//! Model Registry
2//!
3//! Manages model metadata, tracks installed models, and provides
4//! querying capabilities for available models.
5
6use super::storage::{ModelStorage, ModelMetadata};
7use anyhow::{Context, Result};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10use std::sync::Arc;
11use tracing::{debug, info, warn};
12
13use super::recommendation::{HardwareProfile, ModelRecommender};
14use reqwest::Client;
15
/// Status of a model in the registry.
///
/// Serialized into the persisted registry file, so variant names are part of
/// the on-disk format.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ModelStatus {
    /// Model is available locally
    Installed,
    /// Model is being downloaded
    Downloading,
    /// Model is available for download
    Available,
    /// Model had an error during download/installation; the payload holds a
    /// human-readable error message.
    Error(String),
}
28
/// Public pricing information for an API model (e.g., OpenRouter).
/// Both fields are decimal strings as returned by the API; "0" means the
/// model is free.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
    /// Cost per prompt token — "0" means free
    pub prompt: String,
    /// Cost per completion token — "0" means free
    pub completion: String,
}
38
39impl ModelPricing {
40    pub fn is_free(&self) -> bool {
41        (self.prompt == "0" || self.prompt.is_empty())
42            && (self.completion == "0" || self.completion.is_empty())
43    }
44}
45
/// Information about a model in the registry.
///
/// This is the unit persisted in the registry file; `#[serde(default)]` on
/// newer fields keeps older registry files deserializable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Registry ID: the HF repo id for local models, or
    /// "openrouter:<model-id>" for API models.
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Optional free-form description.
    pub description: Option<String>,
    /// Model author (HF) or capitalized provider name (OpenRouter), if known.
    pub author: Option<String>,
    /// Current lifecycle status (installed / downloading / available / error).
    pub status: ModelStatus,
    /// Size in bytes (total across shards for sharded models; 0 for API models).
    pub size_bytes: u64,
    /// Model format: "gguf", "ggml", or "api" for remote models.
    pub format: String,
    /// Source identifier, e.g. "huggingface" or "openrouter".
    pub download_source: Option<String>,
    /// Specific filename to download (for HuggingFace models with non-standard naming)
    #[serde(default)]
    pub filename: Option<String>,
    /// Installed version string, when tracked.
    pub installed_version: Option<String>,
    /// Timestamp of the last install/update, when known.
    pub last_updated: Option<chrono::DateTime<chrono::Utc>>,
    /// Search/filter tags (e.g. "offline", "api", "free", "provider:…").
    pub tags: Vec<String>,
    pub compatibility_score: Option<f32>, // 0.0 to 1.0 based on hardware match
    /// Parameter count string (e.g., "7B", "70B", "671B")
    #[serde(default)]
    pub parameters: Option<String>,
    /// Context length in tokens
    #[serde(default)]
    pub context_length: Option<u64>,
    /// Provider name (for OpenRouter models)
    #[serde(default)]
    pub provider: Option<String>,
    /// Total number of shards for sharded models (None for single-file models)
    #[serde(default)]
    pub total_shards: Option<u32>,
    /// List of all shard filenames for sharded models
    #[serde(default)]
    pub shard_filenames: Vec<String>,
    /// Download count (for HuggingFace models)
    #[serde(default)]
    pub downloads: u64,
    /// Whether this HuggingFace model requires access approval from the repo owner
    #[serde(default)]
    pub is_gated: bool,
    /// Pricing info for OpenRouter API models (None for offline/HF models)
    #[serde(default)]
    pub pricing: Option<ModelPricing>,
}
89
/// Model registry manager.
///
/// Holds the in-memory model catalog backed by persistent storage, plus the
/// list of remote sources models can be discovered from.
pub struct ModelRegistry {
    /// Persistent storage backend (shared with the rest of the app).
    storage: Arc<ModelStorage>,
    /// All known models, keyed by registry ID.
    models: HashMap<String, ModelInfo>,
    /// Remote catalogs models can be fetched from (HF, OpenRouter).
    known_sources: Vec<ModelSource>,
}
96
/// Source where models can be found.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelSource {
    /// Display name of the source, e.g. "Hugging Face".
    pub name: String,
    /// Base URL of the source's website.
    pub url: String,
    /// Which API protocol the source speaks.
    pub api_type: SourceType,
}
104
/// Type of model source API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SourceType {
    /// Hugging Face Hub — locally downloadable GGUF/GGML models.
    HuggingFace,
    /// OpenRouter — hosted models accessed over the network API.
    OpenRouter,
}
111
/// Top-level envelope of the OpenRouter `/models` response; the model list
/// is wrapped in a `data` array.
#[derive(Debug, Deserialize)]
struct OpenRouterModelsResponse {
    data: Vec<OpenRouterModel>,
}
116
/// Pricing block returned by the OpenRouter /models API.
/// Free models have both fields as the string "0".
/// `#[serde(default)]` tolerates responses that omit either field.
#[derive(Debug, Deserialize, Default)]
struct OpenRouterPricing {
    /// Cost per prompt token as a string, e.g. "0" for free.
    #[serde(default)]
    prompt: String,
    /// Cost per completion token as a string, e.g. "0" for free.
    #[serde(default)]
    completion: String,
}
128
129impl OpenRouterPricing {
130    /// Returns true when both prompt and completion are zero-cost.
131    fn is_free(&self) -> bool {
132        (self.prompt == "0" || self.prompt.is_empty())
133            && (self.completion == "0" || self.completion.is_empty())
134    }
135}
136
/// Single model entry from the OpenRouter `/models` API.
/// Only the fields this registry consumes are deserialized.
#[derive(Debug, Deserialize)]
struct OpenRouterModel {
    /// Model identifier, e.g. "openai/gpt-4o" (provider prefix before '/').
    id: String,
    /// Display name, when provided by the API.
    name: Option<String>,
    /// Optional free-form description.
    description: Option<String>,
    /// Maximum context window in tokens.
    context_length: Option<u64>,
    #[serde(default)]
    architecture: Option<OpenRouterArchitecture>,
    /// Pricing info — present in the live API but optional for backwards compat.
    #[serde(default)]
    pricing: Option<OpenRouterPricing>,
}
149
/// Architecture metadata from the OpenRouter `/models` API.
/// NOTE(review): deserialized but not consumed by the visible registry code —
/// kept for future use / forward compatibility; confirm before removing.
#[derive(Debug, Deserialize, Default)]
struct OpenRouterArchitecture {
    #[serde(default)]
    modality: Option<String>,
    #[serde(default)]
    tokenizer: Option<String>,
    /// Instruction type like "none" or "vicuna" etc.
    #[serde(default)]
    instruct_type: Option<String>,
}
160
/// Hugging Face API model response (one repo entry from `/api/models`).
#[derive(Debug, Deserialize)]
struct HuggingFaceModel {
    /// Repo identifier, e.g. "TheBloke/Llama-2-7B-GGUF".
    id: String,
    /// Alternate repo id field; preferred over `id` when present.
    #[serde(rename = "modelId")]
    model_id: Option<String>,
    /// Repo owner, when provided.
    author: Option<String>,
    /// Total download count, when provided.
    downloads: Option<u64>,
    /// Gated status: false, "auto", or "manual". Gated models require user approval.
    /// Kept as a raw JSON value because the API mixes booleans and strings.
    #[serde(default)]
    gated: Option<serde_json::Value>,
    /// Repo tags as reported by the Hub.
    #[serde(default)]
    tags: Vec<String>,
    /// Files contained in the repo (used to find GGUF/GGML artifacts).
    #[serde(default)]
    siblings: Vec<HuggingFaceSibling>,
}
177
/// A single file ("sibling") listed inside a Hugging Face repo.
#[derive(Debug, Deserialize)]
struct HuggingFaceSibling {
    /// Repo-relative filename.
    rfilename: String,
    /// File size in bytes; not always populated by the API.
    #[serde(default)]
    size: Option<u64>,
}
184
185impl ModelRegistry {
186    pub fn new(storage: Arc<ModelStorage>) -> Result<Self> {
187        let mut registry = Self {
188            storage,
189            models: HashMap::new(),
190            known_sources: vec![
191                ModelSource {
192                    name: "Hugging Face".to_string(),
193                    url: "https://huggingface.co".to_string(),
194                    api_type: SourceType::HuggingFace,
195                },
196                ModelSource {
197                    name: "OpenRouter".to_string(),
198                    url: "https://openrouter.ai".to_string(),
199                    api_type: SourceType::OpenRouter,
200                },
201            ],
202        };
203
204        // Load existing registry data
205        registry.load_registry()?;
206
207        // Populate default catalog (only adds models not already present)
208        registry.populate_default_catalog();
209
210        Ok(registry)
211    }
212
213    /// Refresh the OpenRouter model catalog from the live OpenRouter API.
214    /// This replaces any existing OpenRouter entries with the up-to-date list
215    /// of models visible to the provided API key.
216    pub async fn refresh_openrouter_catalog_from_api(
217        &mut self,
218        api_key: &str,
219    ) -> Result<()> {
220        let client = Client::new();
221        let resp = client
222            .get("https://openrouter.ai/api/v1/models")
223            .header("Authorization", format!("Bearer {}", api_key))
224            .send()
225            .await
226            .context("Failed to call OpenRouter /models API")?;
227
228        let resp = resp.error_for_status().context("OpenRouter /models returned error status")?;
229        let body: OpenRouterModelsResponse = resp
230            .json()
231            .await
232            .context("Failed to parse OpenRouter models response")?;
233
234        // Track which OpenRouter model IDs are present in the latest response
235        let mut openrouter_ids: HashSet<String> = HashSet::new();
236
237        for m in body.data.into_iter() {
238            let plain_id = m.id.clone();
239
240            let registry_id = format!("openrouter:{}", plain_id);
241            openrouter_ids.insert(registry_id.clone());
242
243            // Determine free vs paid for tagging purposes
244            let is_free = m.pricing.as_ref().map_or(true, |p| p.is_free())
245                || plain_id.ends_with(":free");
246
247            // Derive provider tag from the model id prefix (e.g. "openai/gpt-4o")
248            let provider = plain_id
249                .split('/')
250                .next()
251                .unwrap_or("")
252                .to_lowercase();
253
254            let mut tags = vec![
255                "api".to_string(),
256                "online".to_string(),
257                "cloud".to_string(),
258            ];
259            if is_free {
260                tags.push("free".to_string());
261            } else {
262                tags.push("paid".to_string());
263            }
264            if !provider.is_empty() {
265                tags.push(format!("provider:{}", provider));
266            }
267
268            if let Some(ctx) = m.context_length {
269                if ctx >= 128_000 {
270                    tags.push("context:xl".to_string());
271                } else if ctx >= 32_000 {
272                    tags.push("context:large".to_string());
273                } else if ctx >= 8_000 {
274                    tags.push("context:medium".to_string());
275                } else {
276                    tags.push("context:small".to_string());
277                }
278            }
279
280            // Extract parameter count from model name (e.g., "70B", "7B", "671B")
281            let parameters = {
282                let name_str = m.name.as_deref().unwrap_or(&plain_id);
283                // Match patterns like "70B", "7B", "671B", "1.5B", "8x7B"
284                let re = regex::Regex::new(r"(\d+(?:\.\d+)?(?:x\d+)?[BMK])").ok();
285                re.and_then(|r| r.find(name_str).map(|m| m.as_str().to_string()))
286            };
287
288            // Capitalize provider name for display
289            let provider_display = if !provider.is_empty() {
290                let mut chars = provider.chars();
291                match chars.next() {
292                    Some(c) => Some(c.to_uppercase().collect::<String>() + chars.as_str()),
293                    None => None,
294                }
295            } else {
296                None
297            };
298
299            let pricing = m.pricing.as_ref().map(|p| ModelPricing {
300                prompt: p.prompt.clone(),
301                completion: p.completion.clone(),
302            });
303
304            let model_info = ModelInfo {
305                id: registry_id.clone(),
306                name: m.name.clone().unwrap_or_else(|| plain_id.clone()),
307                description: m.description.clone(),
308                author: provider_display.clone(),
309                status: ModelStatus::Available,
310                size_bytes: 0,
311                format: "api".to_string(),
312                download_source: Some("openrouter".to_string()),
313                filename: None,
314                installed_version: None,
315                last_updated: None,
316                tags,
317                compatibility_score: None,
318                parameters,
319                context_length: m.context_length,
320                provider: provider_display,
321                total_shards: None,
322                shard_filenames: vec![],
323                downloads: 0,
324                is_gated: false,
325                pricing,
326            };
327
328            // Insert or update existing OpenRouter entry
329            self.models
330                .entry(registry_id.clone())
331                .and_modify(|existing| {
332                    existing.name = model_info.name.clone();
333                    existing.description = model_info.description.clone();
334                    existing.status = model_info.status.clone();
335                    existing.format = model_info.format.clone();
336                    existing.download_source = model_info.download_source.clone();
337                    existing.tags = model_info.tags.clone();
338                    existing.pricing = model_info.pricing.clone();
339                })
340                .or_insert(model_info);
341        }
342
343        // Remove any stale OpenRouter models that are no longer returned
344        self.models.retain(|id, model| {
345            if model.download_source.as_deref() == Some("openrouter") {
346                openrouter_ids.contains(id)
347            } else {
348                true
349            }
350        });
351
352        info!("Refreshed OpenRouter catalog, now tracking {} models", openrouter_ids.len());
353
354        Ok(())
355    }
356
357    /// Refresh the Hugging Face GGUF/GGML model catalog from the HF Hub API.
358    /// Fetches top models by downloads and extracts available quantized files.
359    pub async fn refresh_huggingface_catalog_from_api(
360        &mut self,
361        limit: usize,
362    ) -> Result<()> {
363        let client = Client::new();
364
365        // Fetch GGUF models sorted by downloads
366        let url = format!(
367            "https://huggingface.co/api/models?filter=gguf&sort=downloads&direction=-1&limit={}&full=true",
368            limit
369        );
370
371        let resp = client
372            .get(&url)
373            .header("User-Agent", "Aud.io-Desktop/1.0")
374            .send()
375            .await
376            .context("Failed to call Hugging Face models API")?;
377
378        let resp = resp
379            .error_for_status()
380            .context("Hugging Face API returned error status")?;
381
382        let models: Vec<HuggingFaceModel> = resp
383            .json()
384            .await
385            .context("Failed to parse Hugging Face models response")?;
386
387        let mut hf_ids: HashSet<String> = HashSet::new();
388
389        for m in models.into_iter() {
390            let repo_id = m.model_id.as_ref().unwrap_or(&m.id).clone();
391
392            // Flag gated models — they require HuggingFace access approval.
393            // We include them in the catalog so the UI can show a
394            // "Request Access" button instead of a direct download button.
395            let is_gated = match &m.gated {
396                Some(serde_json::Value::Bool(false)) | None => false,
397                _ => true, // "auto", "manual", true all count as gated
398            };
399
400            // Find GGUF files in siblings
401            let gguf_files: Vec<&HuggingFaceSibling> = m
402                .siblings
403                .iter()
404                .filter(|s| {
405                    s.rfilename.ends_with(".gguf") || s.rfilename.ends_with(".ggml")
406                })
407                .collect();
408
409            if gguf_files.is_empty() {
410                continue;
411            }
412
413            // Check if this is a sharded model by looking for shard patterns
414            let mut sharded_model_info = None;
415            for file in &gguf_files {
416                if let Some(total_shards) = self.detect_shard_pattern_internal(&file.rfilename) {
417                    // This is a sharded model, collect all shards
418                    let all_shards = self.collect_shards_internal(&gguf_files, total_shards);
419                    
420                    // Calculate total size
421                    let total_size = all_shards.iter()
422                        .map(|s| s.size.unwrap_or(0))
423                        .sum();
424                    
425                    let registry_id = repo_id.clone();
426                    hf_ids.insert(registry_id.clone());
427
428                    // Determine format from first shard
429                    let format = if file.rfilename.ends_with(".gguf") {
430                        "gguf"
431                    } else {
432                        "ggml"
433                    }
434                    .to_string();
435
436                    // Build tags from HF tags + our own
437                    let mut tags: Vec<String> = m
438                        .tags
439                        .iter()
440                        .filter(|t| !t.is_empty() && *t != "gguf" && *t != "ggml")
441                        .take(5)
442                        .cloned()
443                        .collect();
444                    tags.push("offline".to_string());
445                    tags.push(format.clone());
446                    tags.push("sharded".to_string()); // Add sharded tag
447
448                    // Derive a friendly name from repo_id
449                    let name = repo_id
450                        .split('/')
451                        .last()
452                        .unwrap_or(&repo_id)
453                        .replace("-GGUF", "")
454                        .replace("-gguf", "");
455
456                    // Extract parameter count from name
457                    let parameters = {
458                        let name_str = &name;
459                        let re = regex::Regex::new(r"(\d+(?:\.\d+)?(?:x\d+)?[BMK])").ok();
460                        re.and_then(|r| r.find(name_str).map(|m| m.as_str().to_string()))
461                    };
462
463                    sharded_model_info = Some(ModelInfo {
464                        id: registry_id.clone(),
465                        name,
466                        description: Some(format!("Sharded GGUF model from {} ({} parts)", repo_id, total_shards)),
467                        author: m.author.clone(),
468                        status: ModelStatus::Available,
469                        size_bytes: total_size,
470                        format,
471                        download_source: Some("huggingface".to_string()),
472                        filename: Some(file.rfilename.clone()), // Store the first shard as the primary filename
473                        installed_version: None,
474                        last_updated: None,
475                        tags,
476                        compatibility_score: None,
477                        parameters,
478                        context_length: None,
479                        provider: None,
480                        total_shards: Some(total_shards),
481                        shard_filenames: all_shards.iter().map(|s| s.rfilename.clone()).collect(),
482                        downloads: m.downloads.unwrap_or(0),
483                        is_gated,
484                        pricing: None,
485                    });
486                    break; // Found sharded model, no need to check other files
487                }
488            }
489
490            // If we found a sharded model, use that info; otherwise use the preferred single file
491            let model_info = if let Some(sharded_info) = sharded_model_info {
492                sharded_info
493            } else {
494                // Prefer Q4_K_M, Q5_K_M, Q6_K, Q8_0, IQ_X_X quantizations (good balance of size/quality)
495                let preferred_file = gguf_files
496                    .iter()
497                    .find(|f| f.rfilename.contains("Q4_K_M"))
498                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q5_K_M")))
499                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q6_K")))
500                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q8_0")))
501                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("IQ3_XXS")))
502                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("IQ3_S")))
503                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("IQ4_NL")))
504                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("IQ4_XS")))
505                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q3_K_S")))
506                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q3_K_M")))
507                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q3_K_L")))
508                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q5_K_S")))
509                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q5_K_L")))
510                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q2_K")))
511                    .or_else(|| gguf_files.iter().find(|f| f.rfilename.contains("Q2_K_S")))
512                    .or_else(|| gguf_files.first())
513                    .copied();
514
515                let Some(file) = preferred_file else {
516                    continue;
517                };
518
519                let registry_id = repo_id.clone();
520                hf_ids.insert(registry_id.clone());
521
522                // Determine format from filename
523                let format = if file.rfilename.ends_with(".gguf") {
524                    "gguf"
525                } else {
526                    "ggml"
527                }
528                .to_string();
529
530                // Build tags from HF tags + our own
531                let mut tags: Vec<String> = m
532                    .tags
533                    .iter()
534                    .filter(|t| !t.is_empty() && *t != "gguf" && *t != "ggml")
535                    .take(5)
536                    .cloned()
537                    .collect();
538                tags.push("offline".to_string());
539                tags.push(format.clone());
540
541                // Derive a friendly name from repo_id
542                let name = repo_id
543                    .split('/')
544                    .last()
545                    .unwrap_or(&repo_id)
546                    .replace("-GGUF", "")
547                    .replace("-gguf", "");
548
549                // Extract parameter count from name
550                let parameters = {
551                    let name_str = &name;
552                    let re = regex::Regex::new(r"(\d+(?:\.\d+)?(?:x\d+)?[BMK])").ok();
553                    re.and_then(|r| r.find(name_str).map(|m| m.as_str().to_string()))
554                };
555
556                ModelInfo {
557                    id: registry_id.clone(),
558                    name,
559                    description: Some(format!("GGUF model from {}", repo_id)),
560                    author: m.author.clone(),
561                    status: ModelStatus::Available,
562                    size_bytes: file.size.unwrap_or(0),
563                    format,
564                    download_source: Some("huggingface".to_string()),
565                    filename: Some(file.rfilename.clone()),
566                    installed_version: None,
567                    last_updated: None,
568                    tags,
569                    compatibility_score: None,
570                    parameters,
571                    context_length: None,
572                    provider: None,
573                    total_shards: None,
574                    shard_filenames: vec![],
575                    downloads: m.downloads.unwrap_or(0),
576                    is_gated,
577                    pricing: None,
578                }
579            };
580
581            // Insert or update
582            self.models
583                .entry(model_info.id.clone())
584                .and_modify(|existing| {
585                    // Don't overwrite installed models
586                    if existing.status != ModelStatus::Installed {
587                        existing.name = model_info.name.clone();
588                        existing.description = model_info.description.clone();
589                        existing.author = model_info.author.clone();
590                        existing.size_bytes = model_info.size_bytes;
591                        existing.format = model_info.format.clone();
592                        existing.download_source = model_info.download_source.clone();
593                        existing.filename = model_info.filename.clone();
594                        existing.tags = model_info.tags.clone();
595                        existing.total_shards = model_info.total_shards;
596                        existing.shard_filenames = model_info.shard_filenames.clone();
597                        existing.is_gated = model_info.is_gated;
598                    }
599                })
600                .or_insert(model_info);
601        }
602
603        info!(
604            "Refreshed Hugging Face catalog, now tracking {} GGUF/GGML models",
605            hf_ids.len()
606        );
607
608        Ok(())
609    }
610
611    /// Recompute compatibility scores for all known local models based on
612    /// the current hardware profile and user preferences. This is used by
613    /// the model manager to support "Best Match" sorting in the UI.
614    pub fn update_compatibility_scores(
615        &mut self,
616        recommender: &ModelRecommender,
617        hardware: &HardwareProfile,
618    ) {
619        for model in self.models.values_mut() {
620            // Only score offline models (GGUF/GGML) – API models and other
621            // formats are not constrained by local hardware.
622            let is_offline_format = model.format.eq_ignore_ascii_case("gguf")
623                || model.format.eq_ignore_ascii_case("ggml");
624            let is_api_model = model.download_source.as_deref() == Some("openrouter");
625
626            if is_offline_format && !is_api_model {
627                let score = recommender.score_model_compatibility(model, hardware);
628                model.compatibility_score = Some(score);
629            }
630        }
631    }
632
633    /// Load registry data from persistent storage
634    fn load_registry(&mut self) -> Result<()> {
635        let registry_path = self.storage.location.registry_dir.join("registry.json");
636        if registry_path.exists() {
637            match std::fs::read_to_string(&registry_path) {
638                Ok(content) if !content.trim().is_empty() => {
639                    match serde_json::from_str::<HashMap<String, ModelInfo>>(&content) {
640                        Ok(saved_models) => {
641                            self.models = saved_models;
642                            info!("Loaded {} models from registry", self.models.len());
643                        }
644                        Err(e) => {
645                            warn!("Registry file corrupted, starting fresh: {}", e);
646                        }
647                    }
648                }
649                Ok(_) => {
650                    debug!("Registry file is empty, starting fresh");
651                }
652                Err(e) => {
653                    warn!("Failed to read registry file: {}", e);
654                }
655            }
656        }
657        Ok(())
658    }
659
660    /// Scan local storage for existing models and populate registry
661    pub async fn scan_storage(&mut self) -> Result<()> {
662        let model_ids = self.storage.list_models()?;
663        
664        for model_id in model_ids {
665            if let Some(metadata) = self.load_model_metadata(&model_id).await? {
666                let model_info = ModelInfo {
667                    id: model_id.clone(),
668                    name: metadata.name,
669                    description: metadata.description,
670                    author: metadata.author,
671                    status: ModelStatus::Installed,
672                    size_bytes: metadata.size_bytes,
673                    format: metadata.format,
674                    download_source: Some(metadata.download_source),
675                    filename: None, // Already downloaded, filename not needed
676                    installed_version: None, // Version extracted from model metadata
677                    last_updated: Some(metadata.download_date),
678                    tags: metadata.tags,
679                    compatibility_score: None, // Will be calculated on demand
680                    parameters: None,
681                    context_length: None,
682                    provider: None,
683                    total_shards: None,
684                    shard_filenames: vec![],
685                    downloads: 0,
686                    is_gated: false,
687                    pricing: None,
688                };
689                
690                self.models.insert(model_id, model_info);
691            }
692        }
693        
694        info!("Scanned storage and found {} models", self.models.len());
695        Ok(())
696    }
697
698    /// Load metadata for a specific model
699    async fn load_model_metadata(&self, model_id: &str) -> Result<Option<ModelMetadata>> {
700        let metadata_path = self.storage.metadata_path(model_id);
701        
702        if metadata_path.exists() {
703            let content = tokio::fs::read_to_string(&metadata_path).await?;
704            let metadata: ModelMetadata = serde_json::from_str(&content)?;
705            Ok(Some(metadata))
706        } else {
707            Ok(None)
708        }
709    }
710
711    /// Update model status based on file existence
712    pub async fn update_model_status_from_storage(&mut self, model_id: &str) -> Result<()> {
713        if let Some(model_info) = self.models.get_mut(model_id) {
714            let model_exists = self.storage.model_exists(model_id);
715            
716            if model_exists {
717                model_info.status = ModelStatus::Installed;
718            } else {
719                // If it was installed but no longer exists, mark as available (downloadable)
720                if matches!(model_info.status, ModelStatus::Installed) {
721                    model_info.status = ModelStatus::Available;
722                }
723            }
724        }
725        
726        Ok(())
727    }
728
729    /// Update all model statuses based on file existence in storage
730    pub async fn update_all_model_statuses_from_storage(&mut self) -> Result<()> {
731        let model_ids: Vec<String> = self.models.keys().cloned().collect();
732        
733        for model_id in model_ids {
734            self.update_model_status_from_storage(&model_id).await?;
735        }
736        
737        Ok(())
738    }
739
740    /// Get the path of an installed model by ID
741    pub fn get_installed_model_path(&self, model_id: &str) -> Option<std::path::PathBuf> {
742        let model_info = self.models.get(model_id)?;
743        if model_info.status != ModelStatus::Installed {
744            return None;
745        }
746        
747        // Try to get the filename from model_info, otherwise look for any model file in the directory
748        if let Some(filename) = &model_info.filename {
749            return Some(self.storage.model_path(model_id, filename));
750        }
751        
752        // Look for any model file in the directory
753        let temp_path = self.storage.model_path(model_id, "dummy");
754        let model_dir = match temp_path.parent() {
755            Some(dir) => dir.to_path_buf(),
756            None => return None,
757        };
758        if !model_dir.exists() {
759            return None;
760        }
761        
762        if let Ok(entries) = std::fs::read_dir(&model_dir) {
763            for entry in entries.flatten() {
764                if let Ok(file_type) = entry.file_type() {
765                    if file_type.is_file() {
766                        let path = entry.path();
767                        let ext = path.extension().unwrap_or_default().to_string_lossy().to_lowercase();
768                        if matches!(ext.as_str(), "gguf" | "bin" | "ggml" | "onnx" | "trt" | "engine" | "safetensors" | "mlmodel") {
769                            return Some(path);
770                        }
771                    }
772                }
773            }
774        }
775        
776        None
777    }
778    
779    /// Get the complete model metadata including runtime binaries information
780    pub async fn get_model_metadata(&self, model_id: &str) -> Option<ModelMetadata> {
781        match self.load_model_metadata(model_id).await {
782            Ok(Some(metadata)) => Some(metadata),
783            _ => None,
784        }
785    }
786
787    /// Add a model to the registry
788    pub fn add_model(&mut self, model_info: ModelInfo) {
789        self.models.insert(model_info.id.clone(), model_info);
790    }
791
    /// Get model information by ID.
    ///
    /// Returns `None` if no model with `model_id` is registered.
    pub fn get_model(&self, model_id: &str) -> Option<&ModelInfo> {
        self.models.get(model_id)
    }
796
    /// Get a mutable reference to a model's information by ID.
    ///
    /// Returns `None` if no model with `model_id` is registered.
    pub fn get_model_mut(&mut self, model_id: &str) -> Option<&mut ModelInfo> {
        self.models.get_mut(model_id)
    }
801
    /// List all models in the registry.
    ///
    /// Order is unspecified (backed by a `HashMap`).
    pub fn list_models(&self) -> Vec<&ModelInfo> {
        self.models.values().collect()
    }
806
807    /// List models by status
808    pub fn list_models_by_status(&self, status: ModelStatus) -> Vec<&ModelInfo> {
809        self.models.values()
810            .filter(|model| model.status == status)
811            .collect()
812    }
813
814    /// Search models by name or tags
815    pub fn search_models(&self, query: &str) -> Vec<&ModelInfo> {
816        let query_lower = query.to_lowercase();
817        self.models.values()
818            .filter(|model| {
819                model.name.to_lowercase().contains(&query_lower) ||
820                model.description.as_ref().map_or(false, |desc| desc.to_lowercase().contains(&query_lower)) ||
821                model.tags.iter().any(|tag| tag.to_lowercase().contains(&query_lower))
822            })
823            .collect()
824    }
825
826    /// Get models sorted by compatibility score for current hardware
827    pub fn get_recommended_models(&self, max_results: usize) -> Vec<&ModelInfo> {
828        let mut models: Vec<_> = self.models.values().collect();
829        models.sort_by(|a, b| {
830            b.compatibility_score.unwrap_or(0.0)
831                .partial_cmp(&a.compatibility_score.unwrap_or(0.0))
832                .unwrap_or(std::cmp::Ordering::Equal)
833        });
834        models.truncate(max_results);
835        models
836    }
837
    /// Set the status of the model with `model_id`.
    ///
    /// Silently does nothing if the model is not registered.
    pub fn update_model_status(&mut self, model_id: &str, status: ModelStatus) {
        if let Some(model) = self.models.get_mut(model_id) {
            model.status = status;
        }
    }
844
    /// Remove a model from the registry.
    ///
    /// Returns `true` if an entry was actually removed, `false` if the ID
    /// was unknown.
    pub fn remove_model(&mut self, model_id: &str) -> bool {
        self.models.remove(model_id).is_some()
    }
849
850    /// Get registry statistics
851    pub fn get_statistics(&self) -> RegistryStats {
852        let mut stats = RegistryStats::default();
853        
854        for model in self.models.values() {
855            match model.status {
856                ModelStatus::Installed => stats.installed_count += 1,
857                ModelStatus::Downloading => stats.downloading_count += 1,
858                ModelStatus::Available => stats.available_count += 1,
859                ModelStatus::Error(_) => stats.error_count += 1,
860            }
861            stats.total_size_bytes += model.size_bytes;
862        }
863        
864        stats
865    }
866
867    /// Get models by category/tags
868    pub fn get_models_by_category(&self, category: &str) -> Vec<&ModelInfo> {
869        self.models.values()
870            .filter(|model| {
871                model.tags.iter().any(|tag| 
872                    tag.to_lowercase().contains(&category.to_lowercase())
873                )
874            })
875            .collect()
876    }
877
878    /// Get trending models (recently added or popular tags)
879    pub fn get_trending_models(&self, limit: usize) -> Vec<&ModelInfo> {
880        let mut models: Vec<&ModelInfo> = self.models.values()
881            .filter(|model| {
882                // Filter for popular models (based on certain tags)
883                model.tags.iter().any(|tag| 
884                    tag == "popular" || tag == "trending" || tag == "featured"
885                )
886            })
887            .collect();
888        
889        // Sort by some criteria (e.g., size as proxy for popularity, or by name)
890        models.sort_by(|a, b| b.size_bytes.cmp(&a.size_bytes));
891        models.truncate(limit);
892        models
893    }
894
895    /// Get models by task type (e.g., "chat", "coding", "text-generation")
896    pub fn get_models_by_task(&self, task: &str) -> Vec<&ModelInfo> {
897        self.models.values()
898            .filter(|model| {
899                // Look in name, description and tags for the task
900                model.name.to_lowercase().contains(&task.to_lowercase()) ||
901                model.description.as_ref().map_or(false, |desc| 
902                    desc.to_lowercase().contains(&task.to_lowercase())) ||
903                model.tags.iter().any(|tag| 
904                    tag.to_lowercase().contains(&task.to_lowercase()))
905            })
906            .collect()
907    }
908
909    /// Save registry to persistent storage
910    pub async fn save_registry(&self) -> Result<()> {
911        let registry_path = self.storage.location.registry_dir.join("registry.json");
912        let content = serde_json::to_string_pretty(&self.models)
913            .context("Failed to serialize registry")?;
914        tokio::fs::write(&registry_path, content).await
915            .context("Failed to write registry file")?;
916        debug!("Saved {} models to registry", self.models.len());
917        Ok(())
918    }
919
920    /// Populate the registry with well-known models from all sources.
921    /// Only adds models that are not already in the registry.
922    /// Also removes stale models that are no longer available.
923    pub fn populate_default_catalog(&mut self) {
924        let catalog = Self::get_default_catalog();
925        let catalog_ids: std::collections::HashSet<String> = catalog.iter().map(|m| m.id.clone()).collect();
926
927        // Remove stale models: Ollama models (functionality removed) and any other obsolete models
928        let stale_ids: Vec<String> = self.models.iter()
929            .filter(|(id, m)| {
930                // Remove all Ollama models (functionality removed)
931                m.download_source.as_deref() == Some("ollama")
932                // Note: We no longer remove OpenRouter models based on static catalog since they come from live API
933            })
934            .map(|(id, _)| id.clone())
935            .collect();
936        for id in &stale_ids {
937            self.models.remove(id);
938        }
939        if !stale_ids.is_empty() {
940            info!("Removed {} stale/obsolete models from registry", stale_ids.len());
941        }
942
943        let mut added = 0;
944        for model in catalog {
945            if !self.models.contains_key(&model.id) {
946                self.models.insert(model.id.clone(), model);
947                added += 1;
948            }
949        }
950        if added > 0 {
951            info!("Populated catalog with {} new available models", added);
952        }
953    }
954
955    /// Returns the built-in catalog of well-known models
956    /// Currently returns an empty vector as models are loaded dynamically from APIs
957    fn get_default_catalog() -> Vec<ModelInfo> {
958        vec![]
959    }
960
961    /// Populate registry with OpenRouter models fetched from the public API.
962    /// This is called when no API key is available to show users what models are available.
963    pub async fn populate_default_openrouter_models(&mut self) {
964        if let Err(e) = self.fetch_public_openrouter_models().await {
965            warn!("Failed to fetch public OpenRouter models: {}", e);
966            // Optionally, could fall back to a minimal static list if API fails
967        }
968    }
969
970    /// Fetch public models from OpenRouter API without authentication.
971    /// This method fetches the publicly available models that anyone can see.
972    async fn fetch_public_openrouter_models(&mut self) -> Result<()> {
973        let client = Client::new();
974        let resp = client
975            .get("https://openrouter.ai/api/v1/models")
976            .send()
977            .await
978            .context("Failed to call OpenRouter public /models API")?;
979
980        let resp = resp.error_for_status().context("OpenRouter public /models returned error status")?;
981        let body: OpenRouterModelsResponse = resp
982            .json()
983            .await
984            .context("Failed to parse OpenRouter public models response")?;
985
986        let mut added = 0;
987
988        for m in body.data.into_iter() {
989            let plain_id = m.id.clone();
990            
991            // Validate and skip known deprecated/invalid models
992            if self.is_invalid_openrouter_model(&plain_id) {
993                debug!("Skipping invalid model: {}", plain_id);
994                continue;
995            }
996            
997            let registry_id = format!("openrouter:{}", plain_id);
998
999            // Derive provider tag from the model id prefix (e.g. "openai/gpt-4o")
1000            let provider = plain_id
1001                .split('/')
1002                .next()
1003                .unwrap_or("")
1004                .to_lowercase();
1005
1006            let mut tags = vec![
1007                "api".to_string(),
1008                "online".to_string(),
1009                "cloud".to_string(),
1010            ];
1011            if !provider.is_empty() {
1012                tags.push(format!("provider:{}", provider));
1013            }
1014
1015            if let Some(ctx) = m.context_length {
1016                if ctx >= 128_000 {
1017                    tags.push("context:xl".to_string());
1018                } else if ctx >= 32_000 {
1019                    tags.push("context:large".to_string());
1020                } else if ctx >= 8_000 {
1021                    tags.push("context:medium".to_string());
1022                } else {
1023                    tags.push("context:small".to_string());
1024                }
1025            }
1026
1027            // Extract parameter count from model name (e.g., "70B", "7B", "671B")
1028            let parameters = {
1029                let name_str = m.name.as_deref().unwrap_or(&plain_id);
1030                // Match patterns like "70B", "7B", "671B", "1.5B", "8x7B"
1031                let re = regex::Regex::new(r"(\d+(?:\.\d+)?(?:x\d+)?[BMK])").ok();
1032                re.and_then(|r| r.find(name_str).map(|m| m.as_str().to_string()))
1033            };
1034
1035            // Capitalize provider name for display
1036            let provider_display = if !provider.is_empty() {
1037                let mut chars = provider.chars();
1038                match chars.next() {
1039                    Some(c) => Some(c.to_uppercase().collect::<String>() + chars.as_str()),
1040                    None => None,
1041                }
1042            } else {
1043                None
1044            };
1045
1046            let is_free = m.pricing.as_ref().map_or(true, |p| p.is_free())
1047                || plain_id.ends_with(":free");
1048            if is_free {
1049                tags.push("free".to_string());
1050            } else {
1051                tags.push("paid".to_string());
1052            }
1053
1054            let pricing = m.pricing.as_ref().map(|p| ModelPricing {
1055                prompt: p.prompt.clone(),
1056                completion: p.completion.clone(),
1057            });
1058
1059            let model_info = ModelInfo {
1060                id: registry_id.clone(),
1061                name: m.name.clone().unwrap_or_else(|| plain_id.clone()),
1062                description: m.description.clone(),
1063                author: provider_display.clone(),
1064                status: ModelStatus::Available,
1065                size_bytes: 0,
1066                format: "api".to_string(),
1067                download_source: Some("openrouter".to_string()),
1068                filename: None,
1069                installed_version: None,
1070                last_updated: Some(chrono::Utc::now()),
1071                tags,
1072                compatibility_score: None,
1073                parameters,
1074                context_length: m.context_length,
1075                provider: provider_display,
1076                total_shards: None,
1077                shard_filenames: vec![],
1078                downloads: 0,
1079                is_gated: false,
1080                pricing,
1081            };
1082
1083            // Insert or update - this will replace any existing entry
1084            self.models.insert(registry_id, model_info);
1085            added += 1;
1086        }
1087
1088        info!("Fetched {} public OpenRouter models from API", added);
1089        Ok(())
1090    }
1091
1092    /// Detect if the filename follows a shard pattern (e.g., model-00001-of-00003.gguf)
1093    fn detect_shard_pattern_internal(&self, filename: &str) -> Option<u32> {
1094        // Pattern: some-name-00001-of-00003.ext
1095        let re = regex::Regex::new(r".*-(\d{5})-of-(\d{5})\.[^.]+$").ok()?;
1096        if let Some(caps) = re.captures(filename) {
1097            if let Some(total_str) = caps.get(2) {
1098                if let Ok(total) = total_str.as_str().parse::<u32>() {
1099                    return Some(total);
1100                }
1101            }
1102        }
1103        None
1104    }
1105
1106    /// Collect all shards for a given total_shards number
1107    fn collect_shards_internal<'a>(&self, gguf_files: &[&'a HuggingFaceSibling], total_shards: u32) -> Vec<&'a HuggingFaceSibling> {
1108        let mut shards = Vec::new();
1109        
1110        // Find the pattern from one of the shard files
1111        if let Some(first_file) = gguf_files.iter().find(|f| self.detect_shard_pattern_internal(&f.rfilename).is_some()) {
1112            // Extract the pattern from the first file to find other shards
1113            if let Some(caps) = regex::Regex::new(r"(.*-)(\d{5})(-of-\d{5}\.[^.]+)$")
1114                .ok()
1115                .and_then(|re| re.captures(&first_file.rfilename)) {
1116                
1117                let prefix = caps[1].to_string();  // Owned string to avoid lifetime issues
1118                let suffix = caps[3].to_string();  // Owned string to avoid lifetime issues
1119                
1120                // Collect all expected shard files
1121                for i in 1..=total_shards {
1122                    let expected_filename = format!("{}{:05}{}", prefix, i, suffix);
1123                    if let Some(file) = gguf_files.iter().find(|f| f.rfilename == expected_filename) {
1124                        shards.push(*file);
1125                    }
1126                }
1127            }
1128        }
1129        
1130        shards
1131    }
1132
1133    /// Check if an OpenRouter model ID is invalid/deprecated
1134    fn is_invalid_openrouter_model(&self, model_id: &str) -> bool {
1135        // Known deprecated/invalid models
1136        model_id == "google/gemini-pro" || 
1137        model_id == "google/palm-2-chat-bison" ||
1138        model_id.starts_with("google/palm") ||
1139        model_id.starts_with("google/gemini-pro") ||
1140        // Additional checks can be added here as needed
1141        false
1142    }
1143}
1144
/// Registry statistics
///
/// Aggregated per-status model counts plus the total declared size of all
/// registered models, as produced by `get_statistics`.
#[derive(Debug, Default)]
pub struct RegistryStats {
    /// Number of models with status `Installed`.
    pub installed_count: usize,
    /// Number of models with status `Downloading`.
    pub downloading_count: usize,
    /// Number of models with status `Available`.
    pub available_count: usize,
    /// Number of models with status `Error(_)`.
    pub error_count: usize,
    /// Sum of `size_bytes` across all models, regardless of status.
    pub total_size_bytes: u64,
}
1154
1155impl RegistryStats {
1156    pub fn total_models(&self) -> usize {
1157        self.installed_count + self.downloading_count + self.available_count + self.error_count
1158    }
1159}
1160
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a `ModelStorage` rooted in the given temp directory.
    ///
    /// Extracted because both tests previously duplicated this fixture
    /// construction verbatim.
    fn test_storage(temp_dir: &TempDir) -> Arc<ModelStorage> {
        Arc::new(ModelStorage {
            location: super::super::storage::StorageLocation {
                app_data_dir: temp_dir.path().to_path_buf(),
                models_dir: temp_dir.path().join("models"),
                registry_dir: temp_dir.path().join("registry"),
            },
        })
    }

    /// A freshly created registry starts empty.
    #[tokio::test]
    async fn test_registry_creation() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let registry = ModelRegistry::new(test_storage(&temp_dir))?;
        assert_eq!(registry.models.len(), 0);

        Ok(())
    }

    /// Adding a model makes it retrievable by ID with its fields intact.
    #[tokio::test]
    async fn test_model_addition_and_lookup() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut registry = ModelRegistry::new(test_storage(&temp_dir))?;

        let model_info = ModelInfo {
            id: "test-model".to_string(),
            name: "Test Model".to_string(),
            description: Some("A test model".to_string()),
            author: Some("Test Author".to_string()),
            status: ModelStatus::Available,
            size_bytes: 1024,
            format: "gguf".to_string(),
            download_source: Some("huggingface".to_string()),
            filename: None,
            installed_version: None,
            last_updated: None,
            tags: vec!["test".to_string()],
            compatibility_score: Some(0.8),
            parameters: None,
            context_length: None,
            provider: None,
            total_shards: None,
            shard_filenames: vec![],
            downloads: 0,
            is_gated: false,
            pricing: None,
        };

        registry.add_model(model_info);
        assert_eq!(registry.models.len(), 1);

        let retrieved = registry.get_model("test-model");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name, "Test Model");

        Ok(())
    }
}
1229}