Skip to main content

offline_intelligence/api/
model_api.rs

1//! Model Management API Endpoints
2//!
3//! Provides RESTful API endpoints for:
4//! - Listing available and installed models
5//! - Searching for models
6//! - Downloading/installing models
7//! - Removing/uninstalling models
8//! - Getting download progress
9//! - Hardware recommendations
10
11use axum::{
12    extract::{Query, State},
13    http::StatusCode,
14    response::IntoResponse,
15    Json,
16};
17use serde::{Deserialize, Serialize};
18use tracing::{error, info, warn};
19use std::env;
20
21use crate::{
22    model_management::{
23        downloader::DownloadSource,
24        registry::{ModelInfo, ModelStatus},
25        recommendation::{ModelRecommender, UseCase, QualityPreference, SpeedPreference, CostSensitivity},
26        ModelManager,
27    },
28    model_runtime::RuntimeManager,
29    shared_state::UnifiedAppState,
30    memory_db::ApiKeyType,
31};
32
33/// Request to install/download a model
34#[derive(Debug, Deserialize)]
35pub struct InstallModelRequest {
36    pub model_id: String,
37    pub model_name: String,
38    pub source: ModelSourceSpecifier,
39    pub description: Option<String>,
40    pub size_bytes: u64,
41    pub format: String,
42    /// Optional HuggingFace token for gated/private models
43    pub hf_token: Option<String>,
44}
45
46/// Specify where to download a model from
47#[derive(Debug, Deserialize)]
48#[serde(tag = "type")]
49pub enum ModelSourceSpecifier {
50    HuggingFace { repo_id: String, filename: String },
51    OpenRouter { model_id: String },
52}
53
54/// Response for model installation
55#[derive(Debug, Serialize)]
56pub struct InstallModelResponse {
57    pub download_id: String,
58    pub message: String,
59}
60
61/// Response for the currently active/loaded model
62#[derive(Debug, Serialize)]
63pub struct ActiveModelResponse {
64    pub model_path: String,
65    pub model_name: String,
66    pub format: String,
67    pub context_size: u32,
68    pub gpu_layers: u32,
69    pub backend_url: String,
70    pub status: String,
71}
72
73/// Request to search for models
74#[derive(Debug, Deserialize)]
75pub struct SearchModelsRequest {
76    pub query: String,
77    pub limit: Option<usize>,
78}
79
80/// Response containing search results
81#[derive(Debug, Serialize)]
82pub struct SearchModelsResponse {
83    pub models: Vec<ModelInfo>,
84    pub total_found: usize,
85}
86
87/// Request to refresh the dynamic model catalog
88#[derive(Debug, Deserialize)]
89pub struct RefreshModelsRequest {
90    /// Which source to refresh: "openrouter", "huggingface", or "all" (default)
91    pub source: Option<String>,
92    /// Optional OpenRouter API key supplied by the frontend. If not provided,
93    /// the backend will fall back to the OPENROUTER_API_KEY environment
94    /// variable when refreshing OpenRouter models.
95    pub openrouter_api_key: Option<String>,
96    /// Optional HuggingFace token for gated/private models
97    pub hf_token: Option<String>,
98}
99
100/// Response after refreshing the model catalog
101#[derive(Debug, Serialize)]
102pub struct RefreshModelsResponse {
103    pub updated_sources: Vec<String>,
104    pub total_models: usize,
105}
106
107/// Request to update user preferences
108#[derive(Debug, Deserialize)]
109pub struct UpdatePreferencesRequest {
110    pub primary_use_case: Option<String>,
111    pub quality_preference: Option<String>,
112    pub speed_preference: Option<String>,
113    pub cost_sensitivity: Option<String>,
114}
115
116/// Response with hardware recommendations
117#[derive(Debug, Serialize)]
118pub struct HardwareRecommendationsResponse {
119    pub recommendations: Vec<String>,
120    pub message: String,
121}
122
123/// Request to switch to a different model
124#[derive(Debug, Deserialize)]
125pub struct SwitchModelRequest {
126    pub model_id: String,
127}
128
129/// Response after switching model
130#[derive(Debug, Serialize)]
131pub struct SwitchModelResponse {
132    pub message: String,
133    pub model_id: String,
134    pub model_path: String,
135}
136
137/// Helper function to clone models from registry
138async fn get_cloned_models(model_manager: &ModelManager) -> Vec<ModelInfo> {
139    let registry = model_manager.registry.read().await;
140    registry.list_models().iter().map(|m| (*m).clone()).collect()
141}
142
143/// Helper function to check if a verified API key exists in the database
144async fn has_verified_key(state: &UnifiedAppState, key_type: ApiKeyType) -> bool {
145    match state.shared_state.database_pool.api_keys.get_key_plaintext(&key_type) {
146        Ok(Some(_)) => true,
147        _ => false,
148    }
149}
150
151/// Get list of all models (installed and available)
152/// Filters out OpenRouter and HuggingFace models if no verified key exists
153pub async fn list_models(
154    State(state): State<UnifiedAppState>,
155) -> Result<impl IntoResponse, StatusCode> {
156    let model_manager = state.shared_state.model_manager.as_ref()
157        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
158
159    // Return ALL models regardless of API key presence
160    // Users should see available models before adding API keys
161    let models = get_cloned_models(model_manager).await;
162
163    Ok(Json(models))
164}
165
166/// Get models filtered by mode (online/offline)
167pub async fn list_models_by_mode(
168    State(state): State<UnifiedAppState>,
169    Query(params): Query<std::collections::HashMap<String, String>>,
170) -> Result<impl IntoResponse, StatusCode> {
171    let model_manager = state.shared_state.model_manager.as_ref()
172        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
173
174    let mode = params.get("mode")
175        .map(|s| s.as_str())
176        .unwrap_or("offline");
177
178    let all_models = get_cloned_models(model_manager).await;
179
180    // Check for verified keys (for display purposes - show all models regardless)
181    let has_openrouter_key = has_verified_key(&state, ApiKeyType::OpenRouter).await;
182    let has_huggingface_key = has_verified_key(&state, ApiKeyType::HuggingFace).await;
183
184    match mode {
185        "online" => {
186            // Show filtered models (free, context > 16k from premium providers)
187            // API key is only required when actually using the model for chat, not for viewing
188            let premium_providers = vec![
189                "openai",       // GPT-4, GPT-4o, GPT-4 Turbo
190                "google",       // Gemini
191                "anthropic",    // Claude
192                "deepseek",     // DeepSeek
193                "moonshot",    // Kimi
194                "qwen",        // Qwen
195                "zhipuai",     // GLM/智谱AI
196                "minimax",     // Minimax
197                "microsoft",   // Phi, Copilot
198                "meta-llama",  // Llama
199                "mistral",     // Mistral
200                "cohere",      // Command R
201                "sarvam",     // Sarvam AI
202                "x-ai",       // Grok
203                "nvidia",     // Nvidia
204                "amazon",     // Claude on Bedrock
205                "replicate",  // Various models
206            ];
207
208            // Context length thresholds (in tokens)
209            const MIN_PREMIUM_CTX: u64 = 32000;  // 32k minimum for premium
210            const MIN_STANDARD_CTX: u64 = 16000; // 16k minimum for standard
211
212            // Filter OpenRouter models: premium providers + good context length
213            let or_models: Vec<ModelInfo> = all_models
214                .into_iter()
215                .filter(|m| {
216                    if m.download_source.as_deref() != Some("openrouter") {
217                        return false;
218                    }
219                    let id_lower = m.id.to_lowercase();
220                    
221                    // Check if it's from a premium provider
222                    let is_premium = premium_providers.iter().any(|p| id_lower.contains(p));
223                    
224                    // Check context length
225                    let ctx_len = m.context_length.unwrap_or(0);
226                    let has_good_ctx = ctx_len >= MIN_STANDARD_CTX;
227                    
228                    // Include if premium provider OR has good context length
229                    is_premium || has_good_ctx
230                })
231                .collect();
232
233            // Sort: Premium providers first (by context length), then by provider name
234            let mut sorted_models: Vec<ModelInfo> = or_models;
235            sorted_models.sort_by(|a, b| {
236                let a_lower = a.id.to_lowercase();
237                let b_lower = b.id.to_lowercase();
238                
239                // Check premium status
240                let a_is_premium = premium_providers.iter().any(|p| a_lower.contains(p));
241                let b_is_premium = premium_providers.iter().any(|p| b_lower.contains(p));
242                
243                // Get context lengths (default to 0 if not set)
244                let a_ctx = a.context_length.unwrap_or(0);
245                let b_ctx = b.context_length.unwrap_or(0);
246                
247                match (a_is_premium, b_is_premium) {
248                    (true, false) => std::cmp::Ordering::Less,
249                    (false, true) => std::cmp::Ordering::Greater,
250                    _ => {
251                        // Both are premium - sort by context length (higher first)
252                        b_ctx.cmp(&a_ctx)
253                    }
254                }
255            });
256
257            // Deduplicate by keeping first occurrence of each base model
258            let mut seen = std::collections::HashSet::new();
259            sorted_models.retain(|m| {
260                // Extract base model name (remove openrouter: prefix and version suffixes)
261                let base_name = m.id
262                    .replace("openrouter:", "")
263                    .split(':')
264                    .next()
265                    .unwrap_or(&m.id)
266                    .to_string();
267                seen.insert(base_name).then(|| {
268                    // Keep this model
269                    true
270                }).unwrap_or(false)
271            });
272
273            // Limit to top 50 models
274            sorted_models.truncate(50);
275
276            Ok(Json(sorted_models))
277        }
278        "offline" => {
279            // Show ALL HuggingFace models (not just installed ones)
280            // API key is only required when downloading gated/private models, not for viewing
281            // Prioritize big tech authors at top
282            let big_tech_authors = vec![
283                "google",
284                "meta",
285                "microsoft",
286                "openai",
287                "anthropic",
288                "deepseek-ai",
289                "bigscience",
290                "EleutherAI",
291                "tiiuae",
292                "mistralai",
293                "01-ai",
294                "Qwen",
295                "THUDM",
296                "baai",
297            ];
298
299            // All non-gated HF models
300            let mut hf_models: Vec<ModelInfo> = all_models
301                .into_iter()
302                .filter(|m| {
303                    m.download_source.as_deref() == Some("huggingface") &&
304                    !matches!(m.status, ModelStatus::Error(_))
305                })
306                .collect();
307
308            // Sort: big tech authors first
309            hf_models.sort_by(|a, b| {
310                let a_lower = a.author.as_deref().unwrap_or("").to_lowercase();
311                let b_lower = b.author.as_deref().unwrap_or("").to_lowercase();
312                let a_is_big = big_tech_authors.iter().any(|p| a_lower.contains(p));
313                let b_is_big = big_tech_authors.iter().any(|p| b_lower.contains(p));
314                
315                match (a_is_big, b_is_big) {
316                    (true, false) => std::cmp::Ordering::Less,
317                    (false, true) => std::cmp::Ordering::Greater,
318                    _ => b.downloads.cmp(&a.downloads), // Then by downloads
319                }
320            });
321
322            // Limit to top 100 models
323            hf_models.truncate(100);
324
325            Ok(Json(hf_models))
326        }
327        _ => Err(StatusCode::BAD_REQUEST),
328    }
329}
330
331/// Get the currently active/loaded model info from the running server config
332pub async fn get_active_model(
333    State(state): State<UnifiedAppState>,
334) -> Json<ActiveModelResponse> {
335    let config = &state.shared_state.config;
336
337    // Prefer the runtime's live config (populated when a model is auto-loaded or
338    // activated via the UI) over the static startup config, which may be empty.
339    //
340    // The lock guard (RwLockReadGuard) is NOT Send, so it must be fully dropped
341    // before any .await call.  We clone the Arc inside a synchronous block, then
342    // call the async method outside that block.
343    let runtime_arc = state.shared_state.runtime_manager
344        .read()
345        .ok()
346        .and_then(|g| g.clone()); // guard dropped at end of this expression
347
348    let runtime_model_path: Option<String> = if let Some(rm) = runtime_arc {
349        rm.get_current_config().await
350            .map(|c| c.model_path.to_string_lossy().to_string())
351            .filter(|p| !p.is_empty())
352    } else {
353        None
354    };
355
356    let model_path = runtime_model_path
357        .as_deref()
358        .unwrap_or(&config.model_path);
359
360    let model_name = std::path::Path::new(model_path)
361        .file_stem()
362        .and_then(|s| s.to_str())
363        .unwrap_or("unknown")
364        .to_string();
365
366    let format = std::path::Path::new(model_path)
367        .extension()
368        .and_then(|s| s.to_str())
369        .unwrap_or("unknown")
370        .to_uppercase();
371
372    let file_exists = std::path::Path::new(model_path).exists();
373    let status = if file_exists { "loaded" } else { "not_found" }.to_string();
374
375    Json(ActiveModelResponse {
376        model_path: model_path.to_string(),
377        model_name,
378        format,
379        context_size: config.ctx_size,
380        gpu_layers: config.gpu_layers,
381        backend_url: config.backend_url.clone(),
382        status,
383    })
384}
385
386/// Search for models by name, description, or tags
387pub async fn search_models(
388    State(state): State<UnifiedAppState>,
389    Query(params): Query<SearchModelsRequest>,
390) -> Result<impl IntoResponse, StatusCode> {
391    let model_manager = state.shared_state.model_manager.as_ref()
392        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
393
394    let all_models = get_cloned_models(model_manager).await;
395    let query_lower = params.query.to_lowercase();
396    
397    let mut filtered_models: Vec<ModelInfo> = all_models
398        .into_iter()
399        .filter(|model| {
400            model.name.to_lowercase().contains(&query_lower) ||
401            model.description.as_ref().map_or(false, |desc| desc.to_lowercase().contains(&query_lower)) ||
402            model.tags.iter().any(|tag| tag.to_lowercase().contains(&query_lower))
403        })
404        .collect();
405
406    let total_found = filtered_models.len();
407    let limit = params.limit.unwrap_or(20).min(total_found);
408    
409    // Truncate to limit if needed
410    filtered_models.truncate(limit);
411
412    Ok(Json(SearchModelsResponse {
413        models: filtered_models,
414        total_found,
415    }))
416}
417
418/// Refresh the dynamic model catalog from remote sources (OpenRouter, Hugging Face).
419pub async fn refresh_models(
420    State(state): State<UnifiedAppState>,
421    Json(payload): Json<RefreshModelsRequest>,
422) -> Result<impl IntoResponse, StatusCode> {
423    let model_manager = state
424        .shared_state
425        .model_manager
426        .as_ref()
427        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
428
429    let source = payload.source.as_deref().unwrap_or("all");
430    let mut updated_sources = Vec::new();
431
432    // Refresh OpenRouter catalog if requested
433    if source == "openrouter" || source == "all" {
434        // Priority: 1) Frontend-passed key, 2) Database stored key, 3) Env var, 4) Config
435        let api_key = if let Some(key) = &payload.openrouter_api_key {
436            if !key.is_empty() { Some(key.clone()) } else { state.get_openrouter_api_key().await }
437        } else {
438            state.get_openrouter_api_key().await
439        };
440
441        if let Some(key) = api_key {
442            info!("Refreshing OpenRouter catalog with stored API key");
443            let mut registry = model_manager.registry.write().await;
444            if let Err(e) = registry.refresh_openrouter_catalog_from_api(&key).await {
445                error!("Failed to refresh OpenRouter catalog: {}", e);
446            } else {
447                updated_sources.push("openrouter".to_string());
448            }
449            if let Err(e) = registry.save_registry().await {
450                error!("Failed to save model registry after OpenRouter refresh: {}", e);
451            }
452        } else {
453            info!("No OpenRouter API key available - loading default OpenRouter catalog");
454            // Load default popular OpenRouter models so users can see what's available
455            let mut registry = model_manager.registry.write().await;
456            registry.populate_default_openrouter_models().await;
457            if let Err(e) = registry.save_registry().await {
458                error!("Failed to save model registry after populating default OpenRouter models: {}", e);
459            } else {
460                updated_sources.push("openrouter".to_string());
461            }
462        }
463    }
464
465    // Refresh Hugging Face GGUF/GGML catalog if requested
466    // Token is mainly used for downloading gated models, not for catalog refresh
467    if source == "huggingface" || source == "all" {
468        let mut registry = model_manager.registry.write().await;
469        // Fetch top 100 GGUF models by downloads
470        if let Err(e) = registry.refresh_huggingface_catalog_from_api(100).await {
471            error!("Failed to refresh HuggingFace catalog: {}", e);
472            // Continue - don't fail the entire refresh
473        } else {
474            updated_sources.push("huggingface".to_string());
475        }
476        if let Err(e) = registry.save_registry().await {
477            error!("Failed to save model registry after HuggingFace refresh: {}", e);
478        }
479    }
480
481    // Recompute compatibility scores for newly fetched offline models
482    if !updated_sources.is_empty() {
483        let cfg = &state.shared_state.config;
484        let hardware = crate::model_management::ModelRecommender::detect_hardware_profile(cfg);
485        let mut registry = model_manager.registry.write().await;
486        registry.update_compatibility_scores(&*model_manager.recommender, &hardware);
487        if let Err(e) = registry.save_registry().await {
488            error!("Failed to save registry after compatibility scoring: {}", e);
489        }
490    }
491
492    let total_models = {
493        let registry = model_manager.registry.read().await;
494        registry.list_models().len()
495    };
496
497    Ok(Json(RefreshModelsResponse {
498        updated_sources,
499        total_models,
500    }))
501}
502
503/// Install/download a model
504pub async fn install_model(
505    State(state): State<UnifiedAppState>,
506    Json(payload): Json<InstallModelRequest>,
507) -> Result<impl IntoResponse, StatusCode> {
508    let model_manager = state.shared_state.model_manager.as_ref()
509        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
510
511    info!("Installing model: {} ({})", payload.model_name, payload.model_id);
512
513    // Create model info
514    let model_info = ModelInfo {
515        id: payload.model_id.clone(),
516        name: payload.model_name.clone(),
517        description: payload.description,
518        author: None,
519        status: ModelStatus::Available,
520        size_bytes: payload.size_bytes,
521        format: payload.format,
522        download_source: None,
523        filename: None, // Will be determined by download source
524        installed_version: None,
525        last_updated: None,
526        tags: vec![], // Tags extracted from model metadata or source
527        compatibility_score: None,
528        parameters: None,
529        context_length: None,
530        provider: None,
531        total_shards: None,
532        shard_filenames: vec![],
533        downloads: 0,
534        is_gated: false,
535        pricing: None,
536    };
537
538    // Convert source specifier to download source
539    let download_source = match payload.source {
540        ModelSourceSpecifier::HuggingFace { repo_id, filename } => {
541            DownloadSource::HuggingFace { repo_id, filename }
542        }
543        ModelSourceSpecifier::OpenRouter { model_id } => {
544            DownloadSource::OpenRouter { model_id }
545        }
546    };
547
548    // Clone for use in async block
549    let download_source_clone = download_source.clone();
550
551    // Pre-create the download tracking entry so the frontend can poll immediately
552    let pre_download_id = model_manager.downloader.progress_tracker()
553        .start_download(
554            payload.model_id.clone(),
555            payload.model_name.clone(),
556            Some(payload.size_bytes),
557        )
558        .await;
559
560    let return_download_id = pre_download_id.clone();
561
562    // Update registry status to Downloading
563    {
564        let mut reg = model_manager.registry.write().await;
565        reg.update_model_status(&payload.model_id, ModelStatus::Downloading);
566    }
567
568    // Start download in background
569    let registry = model_manager.registry.clone();
570    let downloader = model_manager.downloader.clone();
571    let existing_download_id = pre_download_id.clone();
572
573    tokio::spawn(async move {
574        // Pass the pre-created download ID so progress is tracked correctly
575        // Also pass the HF token from the request for authentication
576        match downloader.download_model(model_info.clone(), download_source_clone.clone(), Some(existing_download_id), payload.hf_token).await {
577            Ok(_download_id) => {
578                // Extract the filename from the download source
579                let filename = match &download_source_clone {
580                    DownloadSource::HuggingFace { filename, .. } => Some(filename.clone()),
581                    DownloadSource::OpenRouter { .. } => None, // OpenRouter models are API-based, no file
582                };
583
584                // Update registry status AND filename, then persist
585                let mut reg = registry.write().await;
586                reg.update_model_status(&model_info.id, ModelStatus::Installed);
587
588                // CRITICAL: Update the filename in the registry so we know which file to load
589                if let Some(fname) = filename {
590                    if let Some(model) = reg.get_model_mut(&model_info.id) {
591                        model.filename = Some(fname);
592                        info!("Updated registry with filename for model: {}", model_info.id);
593                    }
594                }
595
596                if let Err(e) = reg.save_registry().await {
597                    error!("Failed to persist registry: {}", e);
598                }
599                drop(reg);
600
601                if let Err(e) = downloader.save_model_metadata(&model_info, &download_source_clone).await {
602                    error!("Failed to save model metadata: {}", e);
603                }
604                info!("Model installation completed: {}", model_info.name);
605            }
606            Err(e) => {
607                error!("Model installation failed: {} - {}", model_info.name, e);
608                let mut reg = registry.write().await;
609                reg.update_model_status(&model_info.id, ModelStatus::Error(e.to_string()));
610            }
611        }
612    });
613
614    Ok((
615        StatusCode::ACCEPTED,
616        Json(InstallModelResponse {
617            download_id: return_download_id,
618            message: format!("Started downloading model: {}", payload.model_name),
619        })
620    ))
621}
622
623/// Get download progress for a specific download
624pub async fn get_download_progress(
625    State(state): State<UnifiedAppState>,
626    Query(params): Query<std::collections::HashMap<String, String>>,
627) -> Result<impl IntoResponse, StatusCode> {
628    let model_manager = state.shared_state.model_manager.as_ref()
629        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
630
631    let download_id = params.get("download_id")
632        .ok_or(StatusCode::BAD_REQUEST)?;
633
634    let progress = model_manager.downloader.progress_tracker()
635        .get_progress(download_id)
636        .await
637        .ok_or(StatusCode::NOT_FOUND)?;
638
639    Ok(Json(progress))
640}
641
642/// Get all downloads (active and completed)
643pub async fn get_active_downloads(
644    State(state): State<UnifiedAppState>,
645) -> Result<impl IntoResponse, StatusCode> {
646    let model_manager = state.shared_state.model_manager.as_ref()
647        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
648
649    let downloads = model_manager.downloader.progress_tracker()
650        .get_all_downloads()
651        .await;
652
653    Ok(Json(downloads))
654}
655
656/// Cancel an ongoing download
657pub async fn cancel_download(
658    State(state): State<UnifiedAppState>,
659    Query(params): Query<std::collections::HashMap<String, String>>,
660) -> Result<impl IntoResponse, StatusCode> {
661    let model_manager = state.shared_state.model_manager.as_ref()
662        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
663
664    let download_id = params.get("download_id")
665        .ok_or(StatusCode::BAD_REQUEST)?;
666
667    let success = model_manager.downloader.cancel_download(download_id).await;
668    
669    if success {
670        Ok(Json(serde_json::json!({
671            "message": "Download cancelled successfully"
672        })))
673    } else {
674        Err(StatusCode::BAD_REQUEST)
675    }
676}
677
678/// Pause an ongoing download
679pub async fn pause_download(
680    State(state): State<UnifiedAppState>,
681    Query(params): Query<std::collections::HashMap<String, String>>,
682) -> Result<impl IntoResponse, StatusCode> {
683    let model_manager = state.shared_state.model_manager.as_ref()
684        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
685
686    let download_id = params.get("download_id")
687        .ok_or(StatusCode::BAD_REQUEST)?;
688
689    let tracker = model_manager.downloader.progress_tracker();
690    if let Some(progress) = tracker.get_progress(download_id).await {
691        // Only allow pausing if the download is currently downloading
692        if progress.status == crate::model_management::progress::DownloadStatus::Downloading {
693            tracker.update_progress(
694                download_id,
695                progress.bytes_downloaded,
696                crate::model_management::progress::DownloadStatus::Paused,
697                None,
698            ).await;
699            Ok(Json(serde_json::json!({ "message": "Download paused" })))
700        } else {
701            // Return an error if trying to pause a download that isn't downloading
702            Err(StatusCode::BAD_REQUEST)
703        }
704    } else {
705        Err(StatusCode::NOT_FOUND)
706    }
707}
708
709/// Resume a paused download
710pub async fn resume_download(
711    State(state): State<UnifiedAppState>,
712    Query(params): Query<std::collections::HashMap<String, String>>,
713) -> Result<impl IntoResponse, StatusCode> {
714    let model_manager = state.shared_state.model_manager.as_ref()
715        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
716
717    let download_id = params.get("download_id")
718        .ok_or(StatusCode::BAD_REQUEST)?;
719
720    let tracker = model_manager.downloader.progress_tracker();
721    if let Some(progress) = tracker.get_progress(download_id).await {
722        // Only allow resuming if the download is currently paused
723        if progress.status == crate::model_management::progress::DownloadStatus::Paused {
724            tracker.update_progress(
725                download_id,
726                progress.bytes_downloaded,
727                crate::model_management::progress::DownloadStatus::Downloading,
728                None,
729            ).await;
730            Ok(Json(serde_json::json!({ "message": "Download resumed" })))
731        } else {
732            // Return an error if trying to resume a download that isn't paused
733            Err(StatusCode::BAD_REQUEST)
734        }
735    } else {
736        Err(StatusCode::NOT_FOUND)
737    }
738}
739
740/// Remove/uninstall a model
741pub async fn remove_model(
742    State(state): State<UnifiedAppState>,
743    Query(params): Query<std::collections::HashMap<String, String>>,
744) -> Result<impl IntoResponse, StatusCode> {
745    let model_manager = state.shared_state.model_manager.as_ref()
746        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
747
748    let model_id = params.get("model_id")
749        .ok_or(StatusCode::BAD_REQUEST)?;
750
751    info!("Removing model: {}", model_id);
752
753    // Remove from storage
754    if let Err(e) = model_manager.storage.remove_model(model_id) {
755        error!("Failed to remove model from storage: {}", e);
756        return Err(StatusCode::INTERNAL_SERVER_ERROR);
757    }
758
759    // Remove from registry and persist
760    let mut registry = model_manager.registry.write().await;
761    registry.remove_model(model_id);
762    if let Err(e) = registry.save_registry().await {
763        error!("Failed to persist registry after removal: {}", e);
764    }
765
766    Ok(Json(serde_json::json!({
767        "message": format!("Model {} removed successfully", model_id)
768    })))
769}
770
771/// Get hardware recommendations
772pub async fn get_hardware_recommendations(
773    State(state): State<UnifiedAppState>,
774) -> Result<impl IntoResponse, StatusCode> {
775    let model_manager = state.shared_state.model_manager.as_ref()
776        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
777
778    let hardware = ModelRecommender::detect_hardware_profile(&state.shared_state.config);
779    let message = model_manager.recommender.get_hardware_recommendation_message(&hardware);
780    
781    let recommendations = message.lines().map(|s| s.to_string()).collect::<Vec<String>>();
782
783    Ok(Json(HardwareRecommendationsResponse {
784        recommendations,
785        message,
786    }))
787}
788
789/// Update user preferences for model recommendations
790pub async fn update_preferences(
791    State(state): State<UnifiedAppState>,
792    Json(payload): Json<UpdatePreferencesRequest>,
793) -> Result<impl IntoResponse, StatusCode> {
794    let model_manager = state.shared_state.model_manager.as_ref()
795        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
796
797    let mut preferences = model_manager.recommender.get_preferences().clone();
798
799    if let Some(use_case) = payload.primary_use_case {
800        preferences.primary_use_case = match use_case.as_str() {
801            "chat_assistant" => UseCase::ChatAssistant,
802            "code_generation" => UseCase::CodeGeneration,
803            "creative_writing" => UseCase::CreativeWriting,
804            "research_analysis" => UseCase::ResearchAnalysis,
805            "translation" => UseCase::Translation,
806            _ => UseCase::GeneralPurpose,
807        };
808    }
809
810    if let Some(quality) = payload.quality_preference {
811        preferences.quality_preference = match quality.as_str() {
812            "high_quality" => QualityPreference::HighQuality,
813            "fast_response" => QualityPreference::FastResponse,
814            _ => QualityPreference::Balanced,
815        };
816    }
817
818    if let Some(speed) = payload.speed_preference {
819        preferences.speed_preference = match speed.as_str() {
820            "fastest" => SpeedPreference::Fastest,
821            "highest_quality" => SpeedPreference::HighestQuality,
822            _ => SpeedPreference::Balanced,
823        };
824    }
825
826    if let Some(cost) = payload.cost_sensitivity {
827        preferences.cost_sensitivity = match cost.as_str() {
828            "budget" => CostSensitivity::Budget,
829            "premium" => CostSensitivity::Premium,
830            _ => CostSensitivity::Moderate,
831        };
832    }
833
834    // We can't mutate the recommender through Arc, so we'll need to restructure this
835    // For now, let's just acknowledge the preferences were set
836    info!("User preferences updated: {:?}", preferences);
837
838    Ok(Json(serde_json::json!({
839        "message": "Preferences updated successfully"
840    })))
841}
842
843/// Get recommended models based on current hardware and preferences
844pub async fn get_recommended_models(
845    State(state): State<UnifiedAppState>,
846    Query(params): Query<std::collections::HashMap<String, String>>,
847) -> Result<impl IntoResponse, StatusCode> {
848    let model_manager = state.shared_state.model_manager.as_ref()
849        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
850
851    let max_results = params.get("limit")
852        .and_then(|s| s.parse().ok())
853        .unwrap_or(10);
854
855    let hardware = ModelRecommender::detect_hardware_profile(&state.shared_state.config);
856    let all_models = get_cloned_models(model_manager).await;
857    
858    let recommendations = model_manager.recommender.get_recommendations(
859        all_models.iter().collect(),
860        &hardware,
861        max_results
862    );
863
864    // Get full model info for recommended models
865    let recommended_models: Vec<ModelInfo> = recommendations
866        .into_iter()
867        .filter_map(|(model_id, _)| {
868            all_models.iter().find(|m| m.id == model_id).cloned()
869        })
870        .collect();
871
872    Ok(Json(recommended_models))
873}
874
/// Hardware information response
///
/// Returned by `get_hardware_info`. It combines a snapshot of the detected hardware
/// profile with model-storage figures from the model manager.
#[derive(Debug, Serialize)]
pub struct HardwareInfoResponse {
    /// Total RAM, copied from the detected hardware profile.
    pub total_ram_gb: f32,
    /// Available RAM, copied from the detected hardware profile.
    pub available_ram_gb: f32,
    /// CPU core count, copied from the detected hardware profile.
    pub cpu_cores: u32,
    /// Whether the hardware profile reports a GPU.
    pub gpu_available: bool,
    /// GPU VRAM from the hardware profile; `None` when the profile has no value.
    pub gpu_vram_gb: Option<f32>,
    /// Model storage usage; 0 when unavailable or when no model manager exists.
    pub storage_used_bytes: u64,
    /// Free space reported by model storage; 0 when unavailable or when no model manager exists.
    pub storage_available_bytes: u64,
}
886
887/// Get current hardware info and storage usage
888pub async fn get_hardware_info(
889    State(state): State<UnifiedAppState>,
890) -> Result<impl IntoResponse, StatusCode> {
891    let hardware = ModelRecommender::detect_hardware_profile(&state.shared_state.config);
892
893    let (storage_used, storage_available) = if let Some(mm) = state.shared_state.model_manager.as_ref() {
894        let used = mm.storage.get_storage_usage().unwrap_or(0);
895        let available = mm.storage.get_available_space().unwrap_or(0);
896        (used, available)
897    } else {
898        (0, 0)
899    };
900
901    Ok(Json(HardwareInfoResponse {
902        total_ram_gb: hardware.total_ram_gb,
903        available_ram_gb: hardware.available_ram_gb,
904        cpu_cores: hardware.cpu_cores,
905        gpu_available: hardware.gpu_available,
906        gpu_vram_gb: hardware.gpu_vram_gb,
907        storage_used_bytes: storage_used,
908        storage_available_bytes: storage_available,
909    }))
910}
911
912/// Live system metrics response
913#[derive(Serialize)]
914pub struct SystemMetricsResponse {
915    pub cpu_usage_percent: f32,
916    pub per_core_usage: Vec<f32>,
917    pub cpu_model_name: String,
918    pub cpu_frequency_mhz: u64,
919    pub gpu_available: bool,
920    pub gpu_name: String,
921    pub gpu_usage_percent: f32,
922    pub gpu_vram_total_gb: f32,
923    pub gpu_vram_used_gb: f32,
924    pub gpu_temperature_c: f32,
925    pub memory_total_gb: f32,
926    pub memory_used_gb: f32,
927    pub memory_available_gb: f32,
928    pub gpu_layers_offloaded: u32,
929    pub inference_device: String, // "GPU", "CPU", "CPU+GPU"
930}
931
932/// Switch to a different model by ID
933pub async fn switch_model(
934    State(state): State<UnifiedAppState>,
935    Json(payload): Json<SwitchModelRequest>,
936) -> Result<impl IntoResponse, StatusCode> {
937    let model_manager = state.shared_state.model_manager.as_ref()
938        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
939    
940    // Look up the model in the registry by ID
941    let model_info = {
942        let registry = model_manager.registry.read().await;
943        registry.get_model(&payload.model_id)
944            .ok_or(StatusCode::NOT_FOUND)?
945            .clone()
946    };
947    
948    // Get the complete model metadata including runtime binaries
949    let model_metadata = {
950        let registry = model_manager.registry.read().await;
951        registry.get_model_metadata(&payload.model_id).await
952    };
953    
954    // Verify that the model is installed and the file exists
955    if model_info.status != ModelStatus::Installed {
956        return Err(StatusCode::BAD_REQUEST);
957    }
958    
959    // Get the model's path using the storage module
960    let model_path = if let Some(ref filename) = model_info.filename {
961        // Use the stored filename if available
962        let path = model_manager.storage.model_path(&payload.model_id, filename);
963        info!("🔍 Resolving model path from registry filename: {}", path.display());
964        path
965    } else {
966        // If no filename is stored, look for model files in the model directory
967        warn!("⚠️  Model {} has no filename in registry, scanning directory...", payload.model_id);
968
969        // We can get the directory by using a dummy filename with model_path, then taking the parent
970        let model_dir = model_manager.storage.model_path(&payload.model_id, "dummy").parent()
971            .map(|p| p.to_path_buf())
972            .ok_or_else(|| {
973                error!("❌ Failed to get parent directory for model: {}", payload.model_id);
974                StatusCode::NOT_FOUND
975            })?;
976
977        info!("📂 Scanning model directory: {}", model_dir.display());
978
979        if !model_dir.exists() {
980            error!("❌ Model directory does not exist: {}", model_dir.display());
981            return Err(StatusCode::NOT_FOUND);
982        }
983
984        // Look for model files in the directory
985        let mut found_path = None;
986        if let Ok(entries) = std::fs::read_dir(&model_dir) {
987            for entry in entries.flatten() {
988                if let Ok(file_type) = entry.file_type() {
989                    if file_type.is_file() {
990                        let path = entry.path();
991                        let ext = path.extension().unwrap_or_default().to_string_lossy().to_lowercase();
992                        if matches!(ext.as_str(), "gguf" | "bin" | "ggml" | "onnx" | "trt" | "engine" | "safetensors" | "mlmodel") {
993                            // Found a valid model file
994                            info!("✅ Found model file: {}", path.display());
995                            found_path = Some(path);
996                            break; // Take the first valid file found
997                        }
998                    }
999                }
1000            }
1001        }
1002
1003        match found_path {
1004            Some(path) => path,
1005            None => {
1006                error!("❌ No valid model file found in directory: {}", model_dir.display());
1007                return Err(StatusCode::NOT_FOUND);
1008            }
1009        }
1010    };
1011
1012    if !model_path.exists() {
1013        error!("❌ Model file does not exist: {}", model_path.display());
1014        error!("   Please check if the model was downloaded correctly to AppData");
1015        return Err(StatusCode::NOT_FOUND);
1016    }
1017
1018    info!("✅ Model file verified at: {}", model_path.display());
1019    
1020    // Convert model_path to string for later use since it will be moved
1021    let model_path_str = model_path.to_string_lossy().to_string();
1022    
1023    // Get the runtime manager from shared state
1024    let runtime_manager = {
1025        let guard = state.shared_state.runtime_manager.read()
1026            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
1027        guard.clone().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?
1028    };
1029    
1030    // Prepare runtime config for the new model
1031    // Try to get the model's format from the registry info, default to GGUF if not recognized
1032    let model_format = match model_info.format.as_str().to_lowercase().as_str() {
1033        "gguf" => crate::model_runtime::ModelFormat::GGUF,
1034        "ggml" => crate::model_runtime::ModelFormat::GGML,
1035        "onnx" => crate::model_runtime::ModelFormat::ONNX,
1036        "tensorrt" => crate::model_runtime::ModelFormat::TensorRT,
1037        "safetensors" => crate::model_runtime::ModelFormat::Safetensors,
1038        "coreml" => crate::model_runtime::ModelFormat::CoreML,
1039        _ => crate::model_runtime::ModelFormat::GGUF, // Default fallback
1040    };
1041    
1042    // Determine the appropriate runtime binary based on platform and model metadata
1043    // Priority: 1) Model metadata, 2) Installed engine from registry, 3) Config fallback
1044    let runtime_binary = if let Some(ref metadata) = model_metadata {
1045        // If we have model metadata with runtime binaries, try to use the platform-appropriate one
1046        use crate::model_runtime::platform_detector::HardwareCapabilities;
1047        let hw_caps = HardwareCapabilities::default();
1048        let platform_name = match hw_caps.platform {
1049            crate::model_runtime::platform_detector::Platform::Windows => "windows",
1050            crate::model_runtime::platform_detector::Platform::Linux => "linux",
1051            crate::model_runtime::platform_detector::Platform::MacOS => "macos",
1052        };
1053
1054        // First try to get platform-specific binary from metadata
1055        if let Some(bin_path) = metadata.runtime_binaries.get(platform_name) {
1056            Some(bin_path.clone())
1057        } else {
1058            // Fallback to installed engine from engine registry, then to config llama_bin
1059            if let Some(ref engine_manager) = state.shared_state.engine_manager {
1060                let registry = engine_manager.registry.read().await;
1061                registry.get_default_engine_binary_path()
1062                    .or_else(|| if !state.shared_state.config.llama_bin.is_empty() {
1063                        Some(std::path::PathBuf::from(&state.shared_state.config.llama_bin))
1064                    } else { None })
1065            } else {
1066                Some(std::path::PathBuf::from(&state.shared_state.config.llama_bin))
1067            }
1068        }
1069    } else {
1070        // No metadata available, use installed engine from registry, then config llama_bin
1071        if let Some(ref engine_manager) = state.shared_state.engine_manager {
1072            let registry = engine_manager.registry.read().await;
1073            registry.get_default_engine_binary_path()
1074                .or_else(|| if !state.shared_state.config.llama_bin.is_empty() {
1075                    Some(std::path::PathBuf::from(&state.shared_state.config.llama_bin))
1076                } else { None })
1077        } else {
1078            Some(std::path::PathBuf::from(&state.shared_state.config.llama_bin))
1079        }
1080    };
1081    
1082    let runtime_config = crate::model_runtime::RuntimeConfig {
1083        model_path: model_path.clone(),
1084        format: model_format, // Use the detected format from model info
1085        host: state.shared_state.config.llama_host.clone(),
1086        port: state.shared_state.config.llama_port,
1087        context_size: state.shared_state.config.ctx_size,
1088        batch_size: state.shared_state.config.batch_size,
1089        threads: state.shared_state.config.threads,
1090        gpu_layers: state.shared_state.config.gpu_layers,
1091        runtime_binary: runtime_binary.clone(), // Use platform-appropriate binary
1092        extra_config: serde_json::json!({}),
1093    };
1094
1095    info!("🚀 Initializing runtime with config:");
1096    info!("   Model Path: {}", runtime_config.model_path.display());
1097    info!("   Runtime Binary: {}", runtime_binary.as_ref().map(|p| p.display().to_string()).unwrap_or_else(|| "None".to_string()));
1098    info!("   Format: {:?}", runtime_config.format);
1099    info!("   Host: {}:{}", runtime_config.host, runtime_config.port);
1100    info!("   Context Size: {}", runtime_config.context_size);
1101    info!("   GPU Layers: {}", runtime_config.gpu_layers);
1102
1103    // Skip re-initialization if this exact model is already loaded and healthy.
1104    // This avoids a shutdown→restart race where the model becomes briefly unavailable.
1105    if let Some(current_config) = runtime_manager.get_current_config().await {
1106        if current_config.model_path == runtime_config.model_path
1107            && runtime_manager.is_ready().await
1108        {
1109            info!("✅ Model {} is already loaded and ready — skipping re-initialization", model_info.name);
1110            if let Some(data_dir) = dirs::data_dir() {
1111                let last_model_path = data_dir.join("Aud.io").join("last_model.txt");
1112                let _ = std::fs::write(last_model_path, &payload.model_id);
1113            }
1114            return Ok(Json(SwitchModelResponse {
1115                message: format!("Model {} is already loaded and ready for inference", model_info.name),
1116                model_id: payload.model_id.clone(),
1117                model_path: model_path_str,
1118            }));
1119        }
1120    }
1121
1122    // Use initialize_auto to automatically detect the model format
1123    match runtime_manager.initialize_auto(runtime_config).await {
1124        Ok(base_url) => {
1125            info!("Runtime initialized at {}, performing health check...", base_url);
1126
1127            // CRITICAL: Verify runtime is actually ready before returning success
1128            match runtime_manager.health_check().await {
1129                Ok(_) => {
1130                    info!("✅ Model {} activated successfully and health check passed", model_info.name);
1131
1132                    // Save last used model for auto-load on next startup
1133                    if let Some(data_dir) = dirs::data_dir() {
1134                        let last_model_path = data_dir.join("Aud.io").join("last_model.txt");
1135                        let _ = std::fs::write(last_model_path, &payload.model_id);
1136                    }
1137
1138                    Ok(Json(SwitchModelResponse {
1139                        message: format!("Model {} loaded and ready for inference", model_info.name),
1140                        model_id: payload.model_id.clone(),
1141                        model_path: model_path_str,
1142                    }))
1143                }
1144                Err(e) => {
1145                    error!("❌ Model activation health check failed: {}", e);
1146                    error!("   The model may be too large or incompatible with your hardware");
1147                    Err(StatusCode::INTERNAL_SERVER_ERROR)
1148                }
1149            }
1150        }
1151        Err(e) => {
1152            let error_msg = e.to_string();
1153
1154            // Check if the error is due to a missing binary
1155            if error_msg.contains("binary not found")
1156                || error_msg.contains("not found at")
1157                || error_msg.contains("No such file")
1158            {
1159                error!("Engine binary not found - attempting automatic download and retry");
1160
1161                // Attempt to download engine automatically
1162                if let Some(ref engine_manager) = state.shared_state.engine_manager {
1163                    match engine_manager.ensure_engine_available().await {
1164                        Ok(true) => {
1165                            info!("Engine downloaded successfully, retrying model switch...");
1166
1167                            // Retry the initialization with updated engine path
1168                            // Return special response to signal retry required
1169                            // Note: We return Ok() with the SwitchModelResponse type
1170                            // The frontend will check the message for retry_required
1171                            return Ok(Json(SwitchModelResponse {
1172                                message: "Engine was downloaded. Please retry switching models.".to_string(),
1173                                model_id: payload.model_id.clone(),
1174                                model_path: "retry_required".to_string(),
1175                            }));
1176                        }
1177                        Ok(false) => {
1178                            error!("Failed to auto-download engine - download returned false");
1179                            // Return detailed error message
1180                            return Ok(Json(SwitchModelResponse {
1181                                message: "Engine download failed. Please check your internet connection and try again.".to_string(),
1182                                model_id: payload.model_id.clone(),
1183                                model_path: "engine_download_failed".to_string(),
1184                            }));
1185                        }
1186                        Err(e) => {
1187                            error!("Failed to auto-download engine: {}", e);
1188                            // Return detailed error message
1189                            return Ok(Json(SwitchModelResponse {
1190                                message: format!("Engine download error: {}", e),
1191                                model_id: payload.model_id.clone(),
1192                                model_path: "engine_download_error".to_string(),
1193                            }));
1194                        }
1195                    }
1196                } else {
1197                    error!("No engine manager available");
1198                    return Ok(Json(SwitchModelResponse {
1199                        message: "Engine manager not initialized. Please restart the application.".to_string(),
1200                        model_id: payload.model_id.clone(),
1201                        model_path: "no_engine_manager".to_string(),
1202                    }));
1203                }
1204            }
1205
1206            error!("Failed to switch model: {}", error_msg);
1207            Err(StatusCode::INTERNAL_SERVER_ERROR)
1208        }
1209    }
1210}
1211
1212/// Get live system metrics including CPU/GPU usage
1213pub async fn get_system_metrics(
1214    State(state): State<UnifiedAppState>,
1215) -> Result<impl IntoResponse, StatusCode> {
1216    use sysinfo::System;
1217
1218    // CPU metrics - need two refreshes with delay for accurate usage
1219    let mut system = System::new_all();
1220    system.refresh_cpu();
1221    // Small delay for accurate CPU measurement
1222    tokio::time::sleep(std::time::Duration::from_millis(200)).await;
1223    system.refresh_cpu();
1224    system.refresh_memory();
1225
1226    let cpu_usage = system.global_cpu_info().cpu_usage();
1227    let per_core: Vec<f32> = system.cpus().iter().map(|cpu| cpu.cpu_usage()).collect();
1228    let cpu_model = system.cpus().first().map(|c| c.brand().to_string()).unwrap_or_else(|| "Unknown CPU".into());
1229    let cpu_freq = system.cpus().first().map(|c| c.frequency()).unwrap_or(0);
1230
1231    let total_mem = system.total_memory() as f32 / (1024.0 * 1024.0 * 1024.0);
1232    let used_mem = (system.total_memory() - system.available_memory()) as f32 / (1024.0 * 1024.0 * 1024.0);
1233    let available_mem = system.available_memory() as f32 / (1024.0 * 1024.0 * 1024.0);
1234
1235    // GPU metrics
1236    let (gpu_available, gpu_name, gpu_usage, gpu_vram_total, gpu_vram_used, gpu_temp) = {
1237        #[cfg(feature = "nvidia")]
1238        {
1239            match nvml_wrapper::Nvml::init() {
1240                Ok(nvml) => {
1241                    match nvml.device_by_index(0) {
1242                        Ok(device) => {
1243                            let name = device.name().unwrap_or_else(|_| "GPU".into());
1244                            let utilization = device.utilization_rates().map(|u| u.gpu as f32).unwrap_or(0.0);
1245                            let mem_info = device.memory_info();
1246                            let vram_total = mem_info.as_ref().map(|m| m.total as f32 / (1024.0 * 1024.0 * 1024.0)).unwrap_or(0.0);
1247                            let vram_used = mem_info.as_ref().map(|m| m.used as f32 / (1024.0 * 1024.0 * 1024.0)).unwrap_or(0.0);
1248                            let temp = device.temperature(nvml_wrapper::enum_wrappers::device::TemperatureSensor::Gpu).unwrap_or(0) as f32;
1249                            tracing::debug!("GPU detected: {}, usage: {}%, VRAM: {}/{} GB", name, utilization, vram_used, vram_total);
1250                            (true, name, utilization, vram_total, vram_used, temp)
1251                        }
1252                        Err(e) => {
1253                            tracing::warn!("NVML initialized but failed to get device: {}", e);
1254                            (false, String::from("Not detected"), 0.0_f32, 0.0_f32, 0.0_f32, 0.0_f32)
1255                        }
1256                    }
1257                }
1258                Err(e) => {
1259                    tracing::debug!("NVML not available: {}", e);
1260                    (false, String::from("Not detected"), 0.0_f32, 0.0_f32, 0.0_f32, 0.0_f32)
1261                }
1262            }
1263        }
1264        #[cfg(not(feature = "nvidia"))]
1265        {
1266            // Without NVML, report no GPU metrics - GPU detection happens at config level
1267            (false, String::from("Not detected"), 0.0_f32, 0.0_f32, 0.0_f32, 0.0_f32)
1268        }
1269    };
1270
1271    // Determine inference device from config
1272    let gpu_layers = state.shared_state.config.gpu_layers;
1273    let inference_device = if !gpu_available {
1274        "CPU".to_string()
1275    } else if gpu_layers == 0 {
1276        "CPU".to_string()
1277    } else if gpu_layers >= 50 {
1278        "GPU".to_string()
1279    } else {
1280        "CPU+GPU".to_string()
1281    };
1282
1283    Ok(Json(SystemMetricsResponse {
1284        cpu_usage_percent: cpu_usage,
1285        per_core_usage: per_core,
1286        cpu_model_name: cpu_model,
1287        cpu_frequency_mhz: cpu_freq,
1288        gpu_available,
1289        gpu_name,
1290        gpu_usage_percent: gpu_usage,
1291        gpu_vram_total_gb: gpu_vram_total,
1292        gpu_vram_used_gb: gpu_vram_used,
1293        gpu_temperature_c: gpu_temp,
1294        memory_total_gb: total_mem,
1295        memory_used_gb: used_mem,
1296        memory_available_gb: available_mem,
1297        gpu_layers_offloaded: gpu_layers,
1298        inference_device,
1299    }))
1300}
1301
/// Storage metadata response
///
/// Body returned by `get_storage_metadata`.
#[derive(Debug, Serialize)]
pub struct StorageMetadataResponse {
    /// System paths (app data, models, registry, database, engines)
    pub paths: StoragePaths,
    /// Downloaded models with metadata: one entry per registry model whose status is `Installed`
    pub models: Vec<DownloadedModelInfo>,
    /// Storage usage statistics
    pub storage_stats: StorageStats,
    /// Database information (`memory.db` in the app data directory)
    pub database_info: DatabaseInfo,
    /// Installed engines information: only registry engines that have an install path
    pub engines: Vec<InstalledEngineInfo>,
}
1316
/// Storage paths on the system
///
/// Paths are rendered with `to_string_lossy`, so non-UTF-8 components are replaced
/// rather than rejected.
#[derive(Debug, Serialize)]
pub struct StoragePaths {
    /// App data directory:
    /// - `<data_dir>/Aud.io` on Windows/macOS;
    /// - `<data_dir>/aud.io` on other targets;
    /// - `./aud.io-data` when no data dir is known.
    pub app_data_dir: String,
    /// `models_dir` of the model storage location
    pub models_dir: String,
    /// `registry_dir` of the model storage location
    pub registry_dir: String,
    /// `memory.db` inside the app data directory
    pub database_path: String,
    /// `engines` subdirectory of the app data directory
    pub engines_dir: String,
}
1326
/// Information about an installed engine
///
/// Built from an engine-registry entry that has an install path.
#[derive(Debug, Serialize)]
pub struct InstalledEngineInfo {
    /// Engine identifier from the registry entry
    pub id: String,
    pub name: String,
    pub version: String,
    /// `Debug` rendering of the entry's platform value
    pub platform: String,
    /// `Debug` rendering of the entry's acceleration value
    pub acceleration: String,
    /// File size in bytes as recorded in the registry
    pub file_size: u64,
    /// `file_size` formatted by `format_bytes`
    pub size_human: String,
    /// Install location recorded in the registry
    pub install_path: String,
    pub binary_name: String,
    /// `true` when this engine's key equals the registry's `default_engine`
    pub is_default: bool,
}
1341
/// Information about a downloaded model
///
/// One entry per installed registry model, as built by `get_storage_metadata`.
#[derive(Debug, Serialize)]
pub struct DownloadedModelInfo {
    pub id: String,
    pub name: String,
    /// Format string copied from the registry entry
    pub format: String,
    /// Total size of all files under the model directory. Falls back to the registry's
    /// `size_bytes` when the directory does not exist.
    pub size_bytes: u64,
    /// `size_bytes` formatted by `format_bytes`
    pub size_human: String,
    /// Modification time of the model's metadata file as `"%Y-%m-%d %H:%M:%S UTC"`,
    /// or `"Unknown"`
    pub download_date: String,
    /// `download_source` from the metadata file, or `"unknown"`
    pub download_source: String,
    /// Path of the model's directory (not of an individual file)
    pub file_path: String,
    /// Path of the metadata file, when it exists
    pub metadata_path: Option<String>,
}
1355
/// Storage usage statistics
#[derive(Debug, Serialize)]
pub struct StorageStats {
    /// Value of `storage.get_storage_usage()`, or 0 when it fails
    pub models_total_bytes: u64,
    /// `models_total_bytes` formatted by `format_bytes`
    pub models_total_human: String,
    /// Value of `storage.get_available_space()`, or 0 when it fails
    pub available_space_bytes: u64,
    /// `available_space_bytes` formatted by `format_bytes`
    pub available_space_human: String,
    /// Number of installed models listed in the response
    pub model_count: usize,
}
1365
/// Database information
#[derive(Debug, Serialize)]
pub struct DatabaseInfo {
    /// Path of `memory.db` inside the app data directory
    pub path: String,
    /// File size in bytes; 0 when the file is missing or its metadata cannot be read
    pub size_bytes: u64,
    /// `size_bytes` formatted by `format_bytes`
    pub size_human: String,
}
1373
1374/// Get comprehensive local storage metadata
1375pub async fn get_storage_metadata(
1376    State(state): State<UnifiedAppState>,
1377) -> Result<impl IntoResponse, StatusCode> {
1378    use crate::model_management::storage::ModelMetadata;
1379    
1380    let model_manager = state.shared_state.model_manager.as_ref()
1381        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
1382    
1383    // Get app data directory using dirs crate
1384    let app_data_dir = dirs::data_dir()
1385        .map(|d| {
1386            if cfg!(target_os = "windows") || cfg!(target_os = "macos") {
1387                d.join("Aud.io")
1388            } else {
1389                d.join("aud.io")
1390            }
1391        })
1392        .unwrap_or_else(|| std::path::PathBuf::from("./aud.io-data"));
1393    
1394    // Get engines directory
1395    let engines_dir = app_data_dir.join("engines");
1396
1397    // Get storage paths
1398    let paths = StoragePaths {
1399        app_data_dir: app_data_dir.to_string_lossy().to_string(),
1400        models_dir: model_manager.storage.location.models_dir.to_string_lossy().to_string(),
1401        registry_dir: model_manager.storage.location.registry_dir.to_string_lossy().to_string(),
1402        database_path: app_data_dir.join("memory.db").to_string_lossy().to_string(),
1403        engines_dir: engines_dir.to_string_lossy().to_string(),
1404    };
1405    
1406    // Get downloaded models with metadata
1407    let mut models = Vec::new();
1408    let installed_models: Vec<crate::model_management::registry::ModelInfo> = {
1409        let registry = model_manager.registry.read().await;
1410        registry.list_models().into_iter()
1411            .filter(|m| matches!(m.status, crate::model_management::registry::ModelStatus::Installed))
1412            .cloned()
1413            .collect()
1414    };
1415    
1416    for model in installed_models {
1417        // Try to load metadata
1418        let metadata_path = model_manager.storage.metadata_path(&model.id);
1419        let download_date = if metadata_path.exists() {
1420            std::fs::metadata(&metadata_path)
1421                .and_then(|m| m.modified())
1422                .ok()
1423                .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
1424                .map(|d| chrono::DateTime::from_timestamp(d.as_secs() as i64, 0))
1425                .flatten()
1426                .map(|d| d.format("%Y-%m-%d %H:%M:%S UTC").to_string())
1427                .unwrap_or_else(|| "Unknown".to_string())
1428        } else {
1429            "Unknown".to_string()
1430        };
1431        
1432        let download_source = if metadata_path.exists() {
1433            std::fs::read_to_string(&metadata_path)
1434                .ok()
1435                .and_then(|content| serde_json::from_str::<ModelMetadata>(&content).ok())
1436                .map(|m| m.download_source)
1437                .unwrap_or_else(|| "unknown".to_string())
1438        } else {
1439            "unknown".to_string()
1440        };
1441        
1442        // Get actual file size from disk
1443        let model_dir = model_manager.storage.location.models_dir.join(
1444            model.id.replace(':', "_").replace('/', "_").replace('\\', "_")
1445        );
1446        let mut actual_size = model.size_bytes;
1447        if model_dir.exists() {
1448            actual_size = walkdir::WalkDir::new(&model_dir)
1449                .into_iter()
1450                .filter_map(|e| e.ok())
1451                .filter(|e| e.file_type().is_file())
1452                .filter_map(|e| e.metadata().ok())
1453                .map(|m| m.len())
1454                .sum();
1455        }
1456        
1457        models.push(DownloadedModelInfo {
1458            id: model.id.clone(),
1459            name: model.name.clone(),
1460            format: model.format.clone(),
1461            size_bytes: actual_size,
1462            size_human: format_bytes(actual_size),
1463            download_date,
1464            download_source,
1465            file_path: model_dir.to_string_lossy().to_string(),
1466            metadata_path: if metadata_path.exists() {
1467                Some(metadata_path.to_string_lossy().to_string())
1468            } else {
1469                None
1470            },
1471        });
1472    }
1473    
1474    // Get storage stats
1475    let models_total_bytes = model_manager.storage.get_storage_usage().unwrap_or(0);
1476    let available_space_bytes = model_manager.storage.get_available_space().unwrap_or(0);
1477    
1478    let storage_stats = StorageStats {
1479        models_total_bytes,
1480        models_total_human: format_bytes(models_total_bytes),
1481        available_space_bytes,
1482        available_space_human: format_bytes(available_space_bytes),
1483        model_count: models.len(),
1484    };
1485    
1486    // Get database info
1487    let db_path = app_data_dir.join("memory.db");
1488    let db_size = std::fs::metadata(&db_path).map(|m| m.len()).unwrap_or(0);
1489    
1490    let database_info = DatabaseInfo {
1491        path: db_path.to_string_lossy().to_string(),
1492        size_bytes: db_size,
1493        size_human: format_bytes(db_size),
1494    };
1495
1496    // Get installed engines info
1497    let mut engines = Vec::new();
1498    if let Some(ref engine_manager) = state.shared_state.engine_manager {
1499        let registry = engine_manager.registry.read().await;
1500        let default_engine_id = registry.default_engine.clone();
1501
1502        for (engine_id, engine_info) in &registry.installed_engines {
1503            if let Some(install_path) = &engine_info.install_path {
1504                engines.push(InstalledEngineInfo {
1505                    id: engine_info.id.clone(),
1506                    name: engine_info.name.clone(),
1507                    version: engine_info.version.clone(),
1508                    platform: format!("{:?}", engine_info.platform),
1509                    acceleration: format!("{:?}", engine_info.acceleration),
1510                    file_size: engine_info.file_size,
1511                    size_human: format_bytes(engine_info.file_size),
1512                    install_path: install_path.to_string_lossy().to_string(),
1513                    binary_name: engine_info.binary_name.clone(),
1514                    is_default: default_engine_id.as_ref() == Some(engine_id),
1515                });
1516            }
1517        }
1518    }
1519
1520    Ok(Json(StorageMetadataResponse {
1521        paths,
1522        models,
1523        storage_stats,
1524        database_info,
1525        engines,
1526    }))
1527}
1528
/// Turn a raw byte count into a short human-readable label.
///
/// The value is divided by 1024 for as long as it is at least 1024, stepping
/// through `B`, `KB`, `MB`, `GB` and finally `TB` (it never goes beyond `TB`).
/// The result is printed with two decimal places, e.g. `1536` → `"1.50 KB"`.
fn format_bytes(bytes: u64) -> String {
    const UNITS: [&str; 5] = ["B", "KB", "MB", "GB", "TB"];

    let mut scaled = bytes as f64;
    let mut unit = UNITS[0];
    for &larger in &UNITS[1..] {
        if scaled < 1024.0 {
            break;
        }
        scaled /= 1024.0;
        unit = larger;
    }

    format!("{:.2} {}", scaled, unit)
}
1542
1543// ─────────────────────────────────────────────────────────────────────────────
1544// Phase A: HuggingFace Gated Model Access Check
1545// ─────────────────────────────────────────────────────────────────────────────
1546
1547/// Query parameters for the HF access-check endpoint
1548#[derive(Debug, Deserialize)]
1549pub struct HfAccessParams {
1550    pub repo_id: String,
1551    pub filename: String,
1552    pub hf_token: Option<String>,
1553}
1554
1555/// Response body for the HF access-check endpoint
1556#[derive(Debug, Serialize)]
1557pub struct HfAccessResponse {
1558    /// One of: "accessible", "not_approved", "unauthorized", "not_found", "error"
1559    pub status: String,
1560    /// `true` when the user may start a download immediately
1561    pub can_download: bool,
1562    /// Human-readable explanation
1563    pub message: String,
1564}
1565
1566/// `GET /models/hf/access?repo_id=…&filename=…&hf_token=…`
1567///
1568/// Performs a HEAD request against the HuggingFace CDN to determine whether
1569/// the supplied token grants download access to a gated repository.
1570pub async fn check_hf_access(
1571    Query(params): Query<HfAccessParams>,
1572) -> Result<impl IntoResponse, StatusCode> {
1573    use crate::model_management::{check_hf_gated_access, HfAccessStatus};
1574
1575    let status = check_hf_gated_access(
1576        &params.repo_id,
1577        &params.filename,
1578        params.hf_token.as_deref(),
1579    )
1580    .await;
1581
1582    let (status_str, can_download, message) = match &status {
1583        HfAccessStatus::Accessible => (
1584            "accessible",
1585            true,
1586            "Access granted — download can proceed.".to_string(),
1587        ),
1588        HfAccessStatus::NotApproved => (
1589            "not_approved",
1590            false,
1591            "Your token is valid but you have not been approved to access this \
1592             model yet. Visit the model page on HuggingFace to request access."
1593                .to_string(),
1594        ),
1595        HfAccessStatus::Unauthorized => (
1596            "unauthorized",
1597            false,
1598            "No HuggingFace token provided or the token is invalid. \
1599             Please add your HF token in Settings."
1600                .to_string(),
1601        ),
1602        HfAccessStatus::NotFound => (
1603            "not_found",
1604            false,
1605            "The model or file was not found on HuggingFace.".to_string(),
1606        ),
1607        HfAccessStatus::Error(e) => (
1608            "error",
1609            false,
1610            format!("Network or server error: {}", e),
1611        ),
1612    };
1613
1614    Ok(Json(HfAccessResponse {
1615        status: status_str.to_string(),
1616        can_download,
1617        message,
1618    }))
1619}
1620
1621// ─────────────────────────────────────────────────────────────────────────────
1622// Phase B: OpenRouter Full Catalog (paginated + filtered)
1623// ─────────────────────────────────────────────────────────────────────────────
1624
1625/// Query parameters for the OpenRouter catalog endpoint
1626#[derive(Debug, Deserialize)]
1627pub struct OpenRouterCatalogParams {
1628    /// 1-based page number (default: 1)
1629    pub page: Option<usize>,
1630    /// Results per page, clamped to [1, 200] (default: 50)
1631    pub per_page: Option<usize>,
1632    /// Free-text search across name, id, description, and provider
1633    pub search: Option<String>,
1634    /// Filter by provider prefix, e.g. "openai", "meta", "google"
1635    pub provider: Option<String>,
1636    /// When `true`, only return models with zero-cost pricing
1637    pub free_only: Option<bool>,
1638    /// Minimum context-length in tokens, e.g. 32000
1639    pub min_context: Option<u64>,
1640}
1641
1642/// Paginated response for the OpenRouter catalog
1643#[derive(Debug, Serialize)]
1644pub struct OpenRouterCatalogResponse {
1645    pub models: Vec<ModelInfo>,
1646    pub total: usize,
1647    pub page: usize,
1648    pub per_page: usize,
1649    pub total_pages: usize,
1650}
1651
1652/// `GET /models/openrouter/catalog`
1653///
1654/// Returns the full set of OpenRouter models stored in the registry,
1655/// with optional server-side filtering and pagination.
1656pub async fn openrouter_catalog(
1657    State(state): State<UnifiedAppState>,
1658    Query(params): Query<OpenRouterCatalogParams>,
1659) -> Result<impl IntoResponse, StatusCode> {
1660    let model_manager = state
1661        .shared_state
1662        .model_manager
1663        .as_ref()
1664        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
1665
1666    let page = params.page.unwrap_or(1).max(1);
1667    let per_page = params.per_page.unwrap_or(50).clamp(1, 200);
1668
1669    // Collect all OpenRouter models from the registry
1670    let mut models: Vec<ModelInfo> = {
1671        let registry = model_manager.registry.read().await;
1672        registry
1673            .list_models()
1674            .into_iter()
1675            .filter(|m| m.download_source.as_deref() == Some("openrouter"))
1676            .cloned()
1677            .collect()
1678    };
1679
1680    // ── Filters ────────────────────────────────────────────────────────────
1681    if let Some(ref search) = params.search {
1682        let q = search.to_lowercase();
1683        models.retain(|m| {
1684            m.name.to_lowercase().contains(&q)
1685                || m.id.to_lowercase().contains(&q)
1686                || m.description
1687                    .as_ref()
1688                    .map_or(false, |d| d.to_lowercase().contains(&q))
1689                || m.provider
1690                    .as_ref()
1691                    .map_or(false, |p| p.to_lowercase().contains(&q))
1692        });
1693    }
1694
1695    if let Some(ref provider) = params.provider {
1696        let prov = provider.to_lowercase();
1697        models.retain(|m| {
1698            m.provider
1699                .as_ref()
1700                .map_or(false, |p| p.to_lowercase().contains(&prov))
1701                || m.id.to_lowercase().starts_with(&prov)
1702        });
1703    }
1704
1705    if params.free_only.unwrap_or(false) {
1706        models.retain(|m| {
1707            m.pricing.as_ref().map_or(false, |p| p.is_free())
1708                || m.tags.iter().any(|t| t == "free")
1709        });
1710    }
1711
1712    if let Some(min_ctx) = params.min_context {
1713        models.retain(|m| m.context_length.map_or(false, |ctx| ctx >= min_ctx));
1714    }
1715
1716    // Sort by name for stable ordering
1717    models.sort_by(|a, b| a.name.cmp(&b.name));
1718
1719    let total = models.len();
1720    let total_pages = total.div_ceil(per_page);
1721    let start = (page - 1) * per_page;
1722    let page_models: Vec<ModelInfo> = models.into_iter().skip(start).take(per_page).collect();
1723
1724    Ok(Json(OpenRouterCatalogResponse {
1725        models: page_models,
1726        total,
1727        page,
1728        per_page,
1729        total_pages,
1730    }))
1731}
1732
1733// ─────────────────────────────────────────────────────────────────────────────
1734// Phase B: OpenRouter Account Quota
1735// ─────────────────────────────────────────────────────────────────────────────
1736
1737/// Response body for the OpenRouter quota endpoint
1738#[derive(Debug, Serialize)]
1739pub struct OpenRouterQuotaResponse {
1740    /// Total credits spent so far in USD
1741    pub usage_usd: f64,
1742    /// Hard spending cap in USD (`null` = unlimited)
1743    pub limit_usd: Option<f64>,
1744    /// `true` when the account is on the free tier
1745    pub is_free_tier: bool,
1746    /// Remaining credits in USD (`null` when limit is unlimited)
1747    pub remaining_usd: Option<f64>,
1748}
1749
1750/// `GET /models/openrouter/quota`
1751///
1752/// Queries `https://openrouter.ai/api/v1/auth/key` with the stored API key
1753/// and returns the current usage / limit information.
1754pub async fn openrouter_quota(
1755    State(state): State<UnifiedAppState>,
1756) -> Result<impl IntoResponse, StatusCode> {
1757    // Retrieve stored OpenRouter key
1758    let api_key = state
1759        .shared_state
1760        .database_pool
1761        .api_keys
1762        .get_key_plaintext(&ApiKeyType::OpenRouter)
1763        .ok()
1764        .flatten()
1765        .ok_or(StatusCode::UNAUTHORIZED)?;
1766
1767    let client = reqwest::Client::builder()
1768        .timeout(std::time::Duration::from_secs(10))
1769        .build()
1770        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
1771
1772    let resp = client
1773        .get("https://openrouter.ai/api/v1/auth/key")
1774        .header("Authorization", format!("Bearer {}", api_key))
1775        .send()
1776        .await
1777        .map_err(|e| {
1778            warn!("OpenRouter quota request failed: {}", e);
1779            StatusCode::BAD_GATEWAY
1780        })?;
1781
1782    if !resp.status().is_success() {
1783        warn!("OpenRouter /auth/key returned {}", resp.status());
1784        return Err(StatusCode::BAD_GATEWAY);
1785    }
1786
1787    #[derive(serde::Deserialize)]
1788    struct OrKeyData {
1789        usage: Option<f64>,
1790        limit: Option<f64>,
1791        is_free_tier: Option<bool>,
1792    }
1793    #[derive(serde::Deserialize)]
1794    struct OrKeyResp {
1795        data: OrKeyData,
1796    }
1797
1798    let body: OrKeyResp = resp.json().await.map_err(|e| {
1799        warn!("Failed to parse OpenRouter quota response: {}", e);
1800        StatusCode::BAD_GATEWAY
1801    })?;
1802
1803    let usage = body.data.usage.unwrap_or(0.0);
1804    let limit = body.data.limit;
1805    let is_free_tier = body.data.is_free_tier.unwrap_or(false);
1806    let remaining = limit.map(|l| (l - usage).max(0.0));
1807
1808    Ok(Json(OpenRouterQuotaResponse {
1809        usage_usd: usage,
1810        limit_usd: limit,
1811        is_free_tier,
1812        remaining_usd: remaining,
1813    }))
1814}