use axum::{
    extract::{Query, State},
    http::StatusCode,
    response::IntoResponse,
    Json,
};
use serde::{Deserialize, Serialize};
use tracing::{error, info, warn};

use crate::{
    model_management::{
        downloader::DownloadSource,
        registry::{ModelInfo, ModelStatus},
        recommendation::{
            ModelRecommender, UseCase, QualityPreference, SpeedPreference, CostSensitivity,
        },
        ModelManager,
    },
    model_runtime::RuntimeManager,
    shared_state::UnifiedAppState,
    memory_db::ApiKeyType,
};

#[derive(Debug, Deserialize)]
pub struct InstallModelRequest {
    pub model_id: String,
    pub model_name: String,
    pub source: ModelSourceSpecifier,
    pub description: Option<String>,
    pub size_bytes: u64,
    pub format: String,
    pub hf_token: Option<String>,
}

#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub enum ModelSourceSpecifier {
    HuggingFace { repo_id: String, filename: String },
    OpenRouter { model_id: String },
}

#[derive(Debug, Serialize)]
pub struct InstallModelResponse {
    pub download_id: String,
    pub message: String,
}

#[derive(Debug, Serialize)]
pub struct ActiveModelResponse {
    pub model_path: String,
    pub model_name: String,
    pub format: String,
    pub context_size: u32,
    pub gpu_layers: u32,
    pub backend_url: String,
    pub status: String,
}

#[derive(Debug, Deserialize)]
pub struct SearchModelsRequest {
    pub query: String,
    pub limit: Option<usize>,
}

#[derive(Debug, Serialize)]
pub struct SearchModelsResponse {
    pub models: Vec<ModelInfo>,
    pub total_found: usize,
}

#[derive(Debug, Deserialize)]
pub struct RefreshModelsRequest {
    pub source: Option<String>,
    pub openrouter_api_key: Option<String>,
    pub hf_token: Option<String>,
}

#[derive(Debug, Serialize)]
pub struct RefreshModelsResponse {
    pub updated_sources: Vec<String>,
    pub total_models: usize,
}

#[derive(Debug, Deserialize)]
pub struct UpdatePreferencesRequest {
    pub primary_use_case: Option<String>,
    pub quality_preference: Option<String>,
    pub speed_preference: Option<String>,
    pub cost_sensitivity: Option<String>,
}

#[derive(Debug, Serialize)]
pub struct HardwareRecommendationsResponse {
    pub recommendations: Vec<String>,
    pub message: String,
}

#[derive(Debug, Deserialize)]
pub struct SwitchModelRequest {
    pub model_id: String,
}

#[derive(Debug, Serialize)]
pub struct SwitchModelResponse {
    pub message: String,
    pub model_id: String,
    pub model_path: String,
}

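/// Clones the full model list out of the registry so the read lock is
/// released before callers do any further async work with the data.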
async fn get_cloned_models(model_manager: &ModelManager) -> Vec<ModelInfo> {
    let registry = model_manager.registry.read().await;
    registry.list_models().into_iter().cloned().collect()
}

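/// Returns true when a plaintext API key of the given type can be read
/// from the key store.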
async fn has_verified_key(state: &UnifiedAppState, key_type: ApiKeyType) -> bool {
    matches!(
        state.shared_state.database_pool.api_keys.get_key_plaintext(&key_type),
        Ok(Some(_))
    )
}

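/// Lists every model in the registry, regardless of source or status.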
pub async fn list_models(
    State(state): State<UnifiedAppState>,
) -> Result<impl IntoResponse, StatusCode> {
    let model_manager = state.shared_state.model_manager.as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let models = get_cloned_models(model_manager).await;

    Ok(Json(models))
}

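/// Lists models filtered by the `mode` query parameter (default `offline`):
/// - `online`: OpenRouter models from premium providers or with a large
///   context window, ranked premium-first then by context length,
///   deduplicated by base model name, capped at 50.
/// - `offline`: HuggingFace models (excluding errored ones), ranked with
///   well-known authors first then by download count, capped at 100.
/// Any other mode yields 400 Bad Request.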
pub async fn list_models_by_mode(
    State(state): State<UnifiedAppState>,
    Query(params): Query<std::collections::HashMap<String, String>>,
) -> Result<impl IntoResponse, StatusCode> {
    let model_manager = state.shared_state.model_manager.as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let mode = params.get("mode")
        .map(|s| s.as_str())
        .unwrap_or("offline");

    let all_models = get_cloned_models(model_manager).await;

    let _has_openrouter_key = has_verified_key(&state, ApiKeyType::OpenRouter).await;
    let _has_huggingface_key = has_verified_key(&state, ApiKeyType::HuggingFace).await;

    match mode {
        "online" => {
            let premium_providers = vec![
                "openai", "google", "anthropic", "deepseek", "moonshot", "qwen",
                "zhipuai", "minimax", "microsoft", "meta-llama", "mistral", "cohere",
                "sarvam", "x-ai", "nvidia", "amazon", "replicate",
            ];

            const MIN_PREMIUM_CTX: u64 = 32000;
            const MIN_STANDARD_CTX: u64 = 16000;

            let or_models: Vec<ModelInfo> = all_models
                .into_iter()
                .filter(|m| {
                    if m.download_source.as_deref() != Some("openrouter") {
                        return false;
                    }
                    let id_lower = m.id.to_lowercase();

                    let is_premium = premium_providers.iter().any(|p| id_lower.contains(p));

                    let ctx_len = m.context_length.unwrap_or(0);
                    let has_good_ctx = ctx_len >= MIN_STANDARD_CTX;

                    is_premium || has_good_ctx
                })
                .collect();

            let mut sorted_models: Vec<ModelInfo> = or_models;
            sorted_models.sort_by(|a, b| {
                let a_lower = a.id.to_lowercase();
                let b_lower = b.id.to_lowercase();

                let a_is_premium = premium_providers.iter().any(|p| a_lower.contains(p));
                let b_is_premium = premium_providers.iter().any(|p| b_lower.contains(p));

                let a_ctx = a.context_length.unwrap_or(0);
                let b_ctx = b.context_length.unwrap_or(0);

                match (a_is_premium, b_is_premium) {
                    (true, false) => std::cmp::Ordering::Less,
                    (false, true) => std::cmp::Ordering::Greater,
                    _ => b_ctx.cmp(&a_ctx),
                }
            });

            // Deduplicate by base model name, keeping the first (highest-ranked) variant.
            let mut seen = std::collections::HashSet::new();
            sorted_models.retain(|m| {
                let base_name = m.id
                    .replace("openrouter:", "")
                    .split(':')
                    .next()
                    .unwrap_or(&m.id)
                    .to_string();
                seen.insert(base_name)
            });

            sorted_models.truncate(50);

            Ok(Json(sorted_models))
        }
        "offline" => {
            // Compared against lowercased author names, so entries must be lowercase.
            let big_tech_authors = vec![
                "google",
                "meta",
                "microsoft",
                "openai",
                "anthropic",
                "deepseek-ai",
                "bigscience",
                "eleutherai",
                "tiiuae",
                "mistralai",
                "01-ai",
                "qwen",
                "thudm",
                "baai",
            ];

            let mut hf_models: Vec<ModelInfo> = all_models
                .into_iter()
                .filter(|m| {
                    m.download_source.as_deref() == Some("huggingface")
                        && !matches!(m.status, ModelStatus::Error(_))
                })
                .collect();

            hf_models.sort_by(|a, b| {
                let a_lower = a.author.as_deref().unwrap_or("").to_lowercase();
                let b_lower = b.author.as_deref().unwrap_or("").to_lowercase();
                let a_is_big = big_tech_authors.iter().any(|p| a_lower.contains(p));
                let b_is_big = big_tech_authors.iter().any(|p| b_lower.contains(p));

                match (a_is_big, b_is_big) {
                    (true, false) => std::cmp::Ordering::Less,
                    (false, true) => std::cmp::Ordering::Greater,
                    _ => b.downloads.cmp(&a.downloads),
                }
            });

            hf_models.truncate(100);

            Ok(Json(hf_models))
        }
        _ => Err(StatusCode::BAD_REQUEST),
    }
}

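/// Reports the currently active model. Prefers the live runtime
/// configuration when one is loaded; otherwise falls back to the static
/// model path from the application config.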
pub async fn get_active_model(
    State(state): State<UnifiedAppState>,
) -> Json<ActiveModelResponse> {
    let config = &state.shared_state.config;

    let runtime_arc = state.shared_state.runtime_manager
        .read()
        .ok()
        .and_then(|g| g.clone());

    let runtime_model_path: Option<String> = if let Some(rm) = runtime_arc {
        rm.get_current_config().await
            .map(|c| c.model_path.to_string_lossy().to_string())
            .filter(|p| !p.is_empty())
    } else {
        None
    };

    let model_path = runtime_model_path
        .as_deref()
        .unwrap_or(&config.model_path);

    let model_name = std::path::Path::new(model_path)
        .file_stem()
        .and_then(|s| s.to_str())
        .unwrap_or("unknown")
        .to_string();

    let format = std::path::Path::new(model_path)
        .extension()
        .and_then(|s| s.to_str())
        .unwrap_or("unknown")
        .to_uppercase();

    let file_exists = std::path::Path::new(model_path).exists();
    let status = if file_exists { "loaded" } else { "not_found" }.to_string();

    Json(ActiveModelResponse {
        model_path: model_path.to_string(),
        model_name,
        format,
        context_size: config.ctx_size,
        gpu_layers: config.gpu_layers,
        backend_url: config.backend_url.clone(),
        status,
    })
}

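/// Case-insensitive substring search over model names, descriptions, and
/// tags, truncated to `limit` results (default 20).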
pub async fn search_models(
    State(state): State<UnifiedAppState>,
    Query(params): Query<SearchModelsRequest>,
) -> Result<impl IntoResponse, StatusCode> {
    let model_manager = state.shared_state.model_manager.as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let all_models = get_cloned_models(model_manager).await;
    let query_lower = params.query.to_lowercase();

    let mut filtered_models: Vec<ModelInfo> = all_models
        .into_iter()
        .filter(|model| {
            model.name.to_lowercase().contains(&query_lower)
                || model.description.as_ref().map_or(false, |desc| desc.to_lowercase().contains(&query_lower))
                || model.tags.iter().any(|tag| tag.to_lowercase().contains(&query_lower))
        })
        .collect();

    let total_found = filtered_models.len();
    let limit = params.limit.unwrap_or(20).min(total_found);

    filtered_models.truncate(limit);

    Ok(Json(SearchModelsResponse {
        models: filtered_models,
        total_found,
    }))
}

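/// Refreshes the model catalog from the requested source ("openrouter",
/// "huggingface", or "all"). After any successful refresh, compatibility
/// scores are recomputed against the detected hardware profile and the
/// registry is persisted.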
pub async fn refresh_models(
    State(state): State<UnifiedAppState>,
    Json(payload): Json<RefreshModelsRequest>,
) -> Result<impl IntoResponse, StatusCode> {
    let model_manager = state
        .shared_state
        .model_manager
        .as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let source = payload.source.as_deref().unwrap_or("all");
    let mut updated_sources = Vec::new();

    if source == "openrouter" || source == "all" {
        // Prefer the key supplied in the request; fall back to the stored key.
        let api_key = match &payload.openrouter_api_key {
            Some(key) if !key.is_empty() => Some(key.clone()),
            _ => state.get_openrouter_api_key().await,
        };

        if let Some(key) = api_key {
            info!("Refreshing OpenRouter catalog with API key");
            let mut registry = model_manager.registry.write().await;
            if let Err(e) = registry.refresh_openrouter_catalog_from_api(&key).await {
                error!("Failed to refresh OpenRouter catalog: {}", e);
            } else {
                updated_sources.push("openrouter".to_string());
            }
            if let Err(e) = registry.save_registry().await {
                error!("Failed to save model registry after OpenRouter refresh: {}", e);
            }
        } else {
            info!("No OpenRouter API key available - loading default OpenRouter catalog");
            let mut registry = model_manager.registry.write().await;
            registry.populate_default_openrouter_models().await;
            if let Err(e) = registry.save_registry().await {
                error!("Failed to save model registry after populating default OpenRouter models: {}", e);
            } else {
                updated_sources.push("openrouter".to_string());
            }
        }
    }

    if source == "huggingface" || source == "all" {
        let mut registry = model_manager.registry.write().await;
        if let Err(e) = registry.refresh_huggingface_catalog_from_api(100).await {
            error!("Failed to refresh HuggingFace catalog: {}", e);
        } else {
            updated_sources.push("huggingface".to_string());
        }
        if let Err(e) = registry.save_registry().await {
            error!("Failed to save model registry after HuggingFace refresh: {}", e);
        }
    }

    if !updated_sources.is_empty() {
        let cfg = &state.shared_state.config;
        let hardware = crate::model_management::ModelRecommender::detect_hardware_profile(cfg);
        let mut registry = model_manager.registry.write().await;
        registry.update_compatibility_scores(&*model_manager.recommender, &hardware);
        if let Err(e) = registry.save_registry().await {
            error!("Failed to save registry after compatibility scoring: {}", e);
        }
    }

    let total_models = {
        let registry = model_manager.registry.read().await;
        registry.list_models().len()
    };

    Ok(Json(RefreshModelsResponse {
        updated_sources,
        total_models,
    }))
}

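/// Starts a model download in a background task and returns 202 Accepted
/// immediately with a `download_id` that clients can poll via the
/// progress endpoints.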
pub async fn install_model(
    State(state): State<UnifiedAppState>,
    Json(payload): Json<InstallModelRequest>,
) -> Result<impl IntoResponse, StatusCode> {
    let model_manager = state.shared_state.model_manager.as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    info!("Installing model: {} ({})", payload.model_name, payload.model_id);

    let model_info = ModelInfo {
        id: payload.model_id.clone(),
        name: payload.model_name.clone(),
        description: payload.description,
        author: None,
        status: ModelStatus::Available,
        size_bytes: payload.size_bytes,
        format: payload.format,
        download_source: None,
        filename: None,
        installed_version: None,
        last_updated: None,
        tags: vec![],
        compatibility_score: None,
        parameters: None,
        context_length: None,
        provider: None,
        total_shards: None,
        shard_filenames: vec![],
        downloads: 0,
        is_gated: false,
        pricing: None,
    };

    let download_source = match payload.source {
        ModelSourceSpecifier::HuggingFace { repo_id, filename } => {
            DownloadSource::HuggingFace { repo_id, filename }
        }
        ModelSourceSpecifier::OpenRouter { model_id } => {
            DownloadSource::OpenRouter { model_id }
        }
    };

    let download_source_clone = download_source.clone();

    let pre_download_id = model_manager.downloader.progress_tracker()
        .start_download(
            payload.model_id.clone(),
            payload.model_name.clone(),
            Some(payload.size_bytes),
        )
        .await;

    let return_download_id = pre_download_id.clone();

    {
        let mut reg = model_manager.registry.write().await;
        reg.update_model_status(&payload.model_id, ModelStatus::Downloading);
    }

    let registry = model_manager.registry.clone();
    let downloader = model_manager.downloader.clone();
    let existing_download_id = pre_download_id.clone();

    tokio::spawn(async move {
        let result = downloader
            .download_model(
                model_info.clone(),
                download_source_clone.clone(),
                Some(existing_download_id),
                payload.hf_token,
            )
            .await;

        match result {
            Ok(_download_id) => {
                let filename = match &download_source_clone {
                    DownloadSource::HuggingFace { filename, .. } => Some(filename.clone()),
                    DownloadSource::OpenRouter { .. } => None,
                };

                let mut reg = registry.write().await;
                reg.update_model_status(&model_info.id, ModelStatus::Installed);

                if let Some(fname) = filename {
                    if let Some(model) = reg.get_model_mut(&model_info.id) {
                        model.filename = Some(fname);
                        info!("Updated registry with filename for model: {}", model_info.id);
                    }
                }

                if let Err(e) = reg.save_registry().await {
                    error!("Failed to persist registry: {}", e);
                }
                drop(reg);

                if let Err(e) = downloader.save_model_metadata(&model_info, &download_source_clone).await {
                    error!("Failed to save model metadata: {}", e);
                }
                info!("Model installation completed: {}", model_info.name);
            }
            Err(e) => {
                error!("Model installation failed: {} - {}", model_info.name, e);
                let mut reg = registry.write().await;
                reg.update_model_status(&model_info.id, ModelStatus::Error(e.to_string()));
            }
        }
    });

    Ok((
        StatusCode::ACCEPTED,
        Json(InstallModelResponse {
            download_id: return_download_id,
            message: format!("Started downloading model: {}", payload.model_name),
        })
    ))
}

pub async fn get_download_progress(
    State(state): State<UnifiedAppState>,
    Query(params): Query<std::collections::HashMap<String, String>>,
) -> Result<impl IntoResponse, StatusCode> {
    let model_manager = state.shared_state.model_manager.as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let download_id = params.get("download_id")
        .ok_or(StatusCode::BAD_REQUEST)?;

    let progress = model_manager.downloader.progress_tracker()
        .get_progress(download_id)
        .await
        .ok_or(StatusCode::NOT_FOUND)?;

    Ok(Json(progress))
}

pub async fn get_active_downloads(
    State(state): State<UnifiedAppState>,
) -> Result<impl IntoResponse, StatusCode> {
    let model_manager = state.shared_state.model_manager.as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let downloads = model_manager.downloader.progress_tracker()
        .get_all_downloads()
        .await;

    Ok(Json(downloads))
}

pub async fn cancel_download(
    State(state): State<UnifiedAppState>,
    Query(params): Query<std::collections::HashMap<String, String>>,
) -> Result<impl IntoResponse, StatusCode> {
    let model_manager = state.shared_state.model_manager.as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let download_id = params.get("download_id")
        .ok_or(StatusCode::BAD_REQUEST)?;

    let success = model_manager.downloader.cancel_download(download_id).await;

    if success {
        Ok(Json(serde_json::json!({
            "message": "Download cancelled successfully"
        })))
    } else {
        Err(StatusCode::BAD_REQUEST)
    }
}

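/// Marks a running download as paused in the progress tracker. Note this
/// only flips the tracked status; the downloader is expected to observe
/// the change and stop fetching.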
pub async fn pause_download(
    State(state): State<UnifiedAppState>,
    Query(params): Query<std::collections::HashMap<String, String>>,
) -> Result<impl IntoResponse, StatusCode> {
    let model_manager = state.shared_state.model_manager.as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let download_id = params.get("download_id")
        .ok_or(StatusCode::BAD_REQUEST)?;

    let tracker = model_manager.downloader.progress_tracker();
    if let Some(progress) = tracker.get_progress(download_id).await {
        if progress.status == crate::model_management::progress::DownloadStatus::Downloading {
            tracker.update_progress(
                download_id,
                progress.bytes_downloaded,
                crate::model_management::progress::DownloadStatus::Paused,
                None,
            ).await;
            Ok(Json(serde_json::json!({ "message": "Download paused" })))
        } else {
            Err(StatusCode::BAD_REQUEST)
        }
    } else {
        Err(StatusCode::NOT_FOUND)
    }
}

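/// Flips a paused download back to the downloading status in the
/// progress tracker, mirroring `pause_download`.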
pub async fn resume_download(
    State(state): State<UnifiedAppState>,
    Query(params): Query<std::collections::HashMap<String, String>>,
) -> Result<impl IntoResponse, StatusCode> {
    let model_manager = state.shared_state.model_manager.as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let download_id = params.get("download_id")
        .ok_or(StatusCode::BAD_REQUEST)?;

    let tracker = model_manager.downloader.progress_tracker();
    if let Some(progress) = tracker.get_progress(download_id).await {
        if progress.status == crate::model_management::progress::DownloadStatus::Paused {
            tracker.update_progress(
                download_id,
                progress.bytes_downloaded,
                crate::model_management::progress::DownloadStatus::Downloading,
                None,
            ).await;
            Ok(Json(serde_json::json!({ "message": "Download resumed" })))
        } else {
            Err(StatusCode::BAD_REQUEST)
        }
    } else {
        Err(StatusCode::NOT_FOUND)
    }
}

pub async fn remove_model(
    State(state): State<UnifiedAppState>,
    Query(params): Query<std::collections::HashMap<String, String>>,
) -> Result<impl IntoResponse, StatusCode> {
    let model_manager = state.shared_state.model_manager.as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let model_id = params.get("model_id")
        .ok_or(StatusCode::BAD_REQUEST)?;

    info!("Removing model: {}", model_id);

    if let Err(e) = model_manager.storage.remove_model(model_id) {
        error!("Failed to remove model from storage: {}", e);
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    }

    let mut registry = model_manager.registry.write().await;
    registry.remove_model(model_id);
    if let Err(e) = registry.save_registry().await {
        error!("Failed to persist registry after removal: {}", e);
    }

    Ok(Json(serde_json::json!({
        "message": format!("Model {} removed successfully", model_id)
    })))
}

pub async fn get_hardware_recommendations(
    State(state): State<UnifiedAppState>,
) -> Result<impl IntoResponse, StatusCode> {
    let model_manager = state.shared_state.model_manager.as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let hardware = ModelRecommender::detect_hardware_profile(&state.shared_state.config);
    let message = model_manager.recommender.get_hardware_recommendation_message(&hardware);

    let recommendations = message.lines().map(|s| s.to_string()).collect::<Vec<String>>();

    Ok(Json(HardwareRecommendationsResponse {
        recommendations,
        message,
    }))
}

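/// Maps snake_case preference strings from the request onto the
/// recommender's enums; unrecognized values fall back to the balanced /
/// general-purpose defaults.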
pub async fn update_preferences(
    State(state): State<UnifiedAppState>,
    Json(payload): Json<UpdatePreferencesRequest>,
) -> Result<impl IntoResponse, StatusCode> {
    let model_manager = state.shared_state.model_manager.as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let mut preferences = model_manager.recommender.get_preferences().clone();

    if let Some(use_case) = payload.primary_use_case {
        preferences.primary_use_case = match use_case.as_str() {
            "chat_assistant" => UseCase::ChatAssistant,
            "code_generation" => UseCase::CodeGeneration,
            "creative_writing" => UseCase::CreativeWriting,
            "research_analysis" => UseCase::ResearchAnalysis,
            "translation" => UseCase::Translation,
            _ => UseCase::GeneralPurpose,
        };
    }

    if let Some(quality) = payload.quality_preference {
        preferences.quality_preference = match quality.as_str() {
            "high_quality" => QualityPreference::HighQuality,
            "fast_response" => QualityPreference::FastResponse,
            _ => QualityPreference::Balanced,
        };
    }

    if let Some(speed) = payload.speed_preference {
        preferences.speed_preference = match speed.as_str() {
            "fastest" => SpeedPreference::Fastest,
            "highest_quality" => SpeedPreference::HighestQuality,
            _ => SpeedPreference::Balanced,
        };
    }

    if let Some(cost) = payload.cost_sensitivity {
        preferences.cost_sensitivity = match cost.as_str() {
            "budget" => CostSensitivity::Budget,
            "premium" => CostSensitivity::Premium,
            _ => CostSensitivity::Moderate,
        };
    }

    info!("User preferences updated: {:?}", preferences);

    Ok(Json(serde_json::json!({
        "message": "Preferences updated successfully"
    })))
}

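/// Returns up to `limit` (default 10) models ranked by the recommender
/// against the detected hardware profile.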
pub async fn get_recommended_models(
    State(state): State<UnifiedAppState>,
    Query(params): Query<std::collections::HashMap<String, String>>,
) -> Result<impl IntoResponse, StatusCode> {
    let model_manager = state.shared_state.model_manager.as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let max_results = params.get("limit")
        .and_then(|s| s.parse().ok())
        .unwrap_or(10);

    let hardware = ModelRecommender::detect_hardware_profile(&state.shared_state.config);
    let all_models = get_cloned_models(model_manager).await;

    let recommendations = model_manager.recommender.get_recommendations(
        all_models.iter().collect(),
        &hardware,
        max_results,
    );

    let recommended_models: Vec<ModelInfo> = recommendations
        .into_iter()
        .filter_map(|(model_id, _)| {
            all_models.iter().find(|m| m.id == model_id).cloned()
        })
        .collect();

    Ok(Json(recommended_models))
}

#[derive(Debug, Serialize)]
pub struct HardwareInfoResponse {
    pub total_ram_gb: f32,
    pub available_ram_gb: f32,
    pub cpu_cores: u32,
    pub gpu_available: bool,
    pub gpu_vram_gb: Option<f32>,
    pub storage_used_bytes: u64,
    pub storage_available_bytes: u64,
}

pub async fn get_hardware_info(
    State(state): State<UnifiedAppState>,
) -> Result<impl IntoResponse, StatusCode> {
    let hardware = ModelRecommender::detect_hardware_profile(&state.shared_state.config);

    let (storage_used, storage_available) = if let Some(mm) = state.shared_state.model_manager.as_ref() {
        let used = mm.storage.get_storage_usage().unwrap_or(0);
        let available = mm.storage.get_available_space().unwrap_or(0);
        (used, available)
    } else {
        (0, 0)
    };

    Ok(Json(HardwareInfoResponse {
        total_ram_gb: hardware.total_ram_gb,
        available_ram_gb: hardware.available_ram_gb,
        cpu_cores: hardware.cpu_cores,
        gpu_available: hardware.gpu_available,
        gpu_vram_gb: hardware.gpu_vram_gb,
        storage_used_bytes: storage_used,
        storage_available_bytes: storage_available,
    }))
}

#[derive(Serialize)]
pub struct SystemMetricsResponse {
    pub cpu_usage_percent: f32,
    pub per_core_usage: Vec<f32>,
    pub cpu_model_name: String,
    pub cpu_frequency_mhz: u64,
    pub gpu_available: bool,
    pub gpu_name: String,
    pub gpu_usage_percent: f32,
    pub gpu_vram_total_gb: f32,
    pub gpu_vram_used_gb: f32,
    pub gpu_temperature_c: f32,
    pub memory_total_gb: f32,
    pub memory_used_gb: f32,
    pub memory_available_gb: f32,
    pub gpu_layers_offloaded: u32,
    pub inference_device: String,
}

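/// Switches the active model: resolves the model file on disk (from the
/// registry filename, or by scanning the model directory), builds a
/// runtime configuration, skips re-initialization if the same model is
/// already loaded and healthy, then starts the runtime and health-checks
/// it. On a missing engine binary it attempts an automatic engine
/// download and asks the client to retry.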
pub async fn switch_model(
    State(state): State<UnifiedAppState>,
    Json(payload): Json<SwitchModelRequest>,
) -> Result<impl IntoResponse, StatusCode> {
    let model_manager = state.shared_state.model_manager.as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let model_info = {
        let registry = model_manager.registry.read().await;
        registry.get_model(&payload.model_id)
            .ok_or(StatusCode::NOT_FOUND)?
            .clone()
    };

    let model_metadata = {
        let registry = model_manager.registry.read().await;
        registry.get_model_metadata(&payload.model_id).await
    };

    if model_info.status != ModelStatus::Installed {
        return Err(StatusCode::BAD_REQUEST);
    }

    let model_path = if let Some(ref filename) = model_info.filename {
        let path = model_manager.storage.model_path(&payload.model_id, filename);
        info!("🔍 Resolving model path from registry filename: {}", path.display());
        path
    } else {
        warn!("⚠️ Model {} has no filename in registry, scanning directory...", payload.model_id);

        let model_dir = model_manager.storage.model_path(&payload.model_id, "dummy")
            .parent()
            .map(|p| p.to_path_buf())
            .ok_or_else(|| {
                error!("❌ Failed to get parent directory for model: {}", payload.model_id);
                StatusCode::NOT_FOUND
            })?;

        info!("📂 Scanning model directory: {}", model_dir.display());

        if !model_dir.exists() {
            error!("❌ Model directory does not exist: {}", model_dir.display());
            return Err(StatusCode::NOT_FOUND);
        }

        let mut found_path = None;
        if let Ok(entries) = std::fs::read_dir(&model_dir) {
            for entry in entries.flatten() {
                if let Ok(file_type) = entry.file_type() {
                    if file_type.is_file() {
                        let path = entry.path();
                        let ext = path.extension().unwrap_or_default().to_string_lossy().to_lowercase();
                        if matches!(ext.as_str(), "gguf" | "bin" | "ggml" | "onnx" | "trt" | "engine" | "safetensors" | "mlmodel") {
                            info!("✅ Found model file: {}", path.display());
                            found_path = Some(path);
                            break;
                        }
                    }
                }
            }
        }

        match found_path {
            Some(path) => path,
            None => {
                error!("❌ No valid model file found in directory: {}", model_dir.display());
                return Err(StatusCode::NOT_FOUND);
            }
        }
    };

    if !model_path.exists() {
        error!("❌ Model file does not exist: {}", model_path.display());
        error!("   Please check if the model was downloaded correctly to AppData");
        return Err(StatusCode::NOT_FOUND);
    }

    info!("✅ Model file verified at: {}", model_path.display());

    let model_path_str = model_path.to_string_lossy().to_string();

    let runtime_manager = {
        let guard = state.shared_state.runtime_manager.read()
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
        guard.clone().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?
    };

    let model_format = match model_info.format.to_lowercase().as_str() {
        "gguf" => crate::model_runtime::ModelFormat::GGUF,
        "ggml" => crate::model_runtime::ModelFormat::GGML,
        "onnx" => crate::model_runtime::ModelFormat::ONNX,
        "tensorrt" => crate::model_runtime::ModelFormat::TensorRT,
        "safetensors" => crate::model_runtime::ModelFormat::Safetensors,
        "coreml" => crate::model_runtime::ModelFormat::CoreML,
        _ => crate::model_runtime::ModelFormat::GGUF,
    };

    let runtime_binary = if let Some(ref metadata) = model_metadata {
        use crate::model_runtime::platform_detector::HardwareCapabilities;

        let hw_caps = HardwareCapabilities::default();
        let platform_name = match hw_caps.platform {
            crate::model_runtime::platform_detector::Platform::Windows => "windows",
            crate::model_runtime::platform_detector::Platform::Linux => "linux",
            crate::model_runtime::platform_detector::Platform::MacOS => "macos",
        };

        if let Some(bin_path) = metadata.runtime_binaries.get(platform_name) {
            Some(bin_path.clone())
        } else if let Some(ref engine_manager) = state.shared_state.engine_manager {
            let registry = engine_manager.registry.read().await;
            registry.get_default_engine_binary_path()
                .or_else(|| if !state.shared_state.config.llama_bin.is_empty() {
                    Some(std::path::PathBuf::from(&state.shared_state.config.llama_bin))
                } else { None })
        } else {
            Some(std::path::PathBuf::from(&state.shared_state.config.llama_bin))
        }
    } else if let Some(ref engine_manager) = state.shared_state.engine_manager {
        let registry = engine_manager.registry.read().await;
        registry.get_default_engine_binary_path()
            .or_else(|| if !state.shared_state.config.llama_bin.is_empty() {
                Some(std::path::PathBuf::from(&state.shared_state.config.llama_bin))
            } else { None })
    } else {
        Some(std::path::PathBuf::from(&state.shared_state.config.llama_bin))
    };

    let runtime_config = crate::model_runtime::RuntimeConfig {
        model_path: model_path.clone(),
        format: model_format,
        host: state.shared_state.config.llama_host.clone(),
        port: state.shared_state.config.llama_port,
        context_size: state.shared_state.config.ctx_size,
        batch_size: state.shared_state.config.batch_size,
        threads: state.shared_state.config.threads,
        gpu_layers: state.shared_state.config.gpu_layers,
        parallel_slots: state.shared_state.config.parallel_slots,
        ubatch_size: state.shared_state.config.ubatch_size,
        runtime_binary: runtime_binary.clone(),
        draft_model_path: {
            let p = &state.shared_state.config.draft_model_path;
            if p == "none" || p.is_empty() { None } else { Some(std::path::PathBuf::from(p)) }
        },
        speculative_draft_max: state.shared_state.config.speculative_draft_max,
        speculative_draft_p_min: state.shared_state.config.speculative_draft_p_min,
        extra_config: serde_json::json!({}),
    };

    info!("🚀 Initializing runtime with config:");
    info!("   Model Path: {}", runtime_config.model_path.display());
    info!("   Runtime Binary: {}", runtime_binary.as_ref().map(|p| p.display().to_string()).unwrap_or_else(|| "None".to_string()));
    info!("   Format: {:?}", runtime_config.format);
    info!("   Host: {}:{}", runtime_config.host, runtime_config.port);
    info!("   Context Size: {}", runtime_config.context_size);
    info!("   GPU Layers: {}", runtime_config.gpu_layers);

    // Short-circuit if the requested model is already loaded and healthy.
    if let Some(current_config) = runtime_manager.get_current_config().await {
        if current_config.model_path == runtime_config.model_path
            && runtime_manager.is_ready().await
        {
            info!("✅ Model {} is already loaded and ready — skipping re-initialization", model_info.name);
            if let Some(data_dir) = dirs::data_dir() {
                let last_model_path = data_dir.join("Aud.io").join("last_model.txt");
                let _ = std::fs::write(last_model_path, &payload.model_id);
            }
            return Ok(Json(SwitchModelResponse {
                message: format!("Model {} is already loaded and ready for inference", model_info.name),
                model_id: payload.model_id.clone(),
                model_path: model_path_str,
            }));
        }
    }

    match runtime_manager.initialize_auto(runtime_config).await {
        Ok(base_url) => {
            info!("Runtime initialized at {}, performing health check...", base_url);

            match runtime_manager.health_check().await {
                Ok(_) => {
                    info!("✅ Model {} activated successfully and health check passed", model_info.name);

                    if let Some(data_dir) = dirs::data_dir() {
                        let last_model_path = data_dir.join("Aud.io").join("last_model.txt");
                        let _ = std::fs::write(last_model_path, &payload.model_id);
                    }

                    Ok(Json(SwitchModelResponse {
                        message: format!("Model {} loaded and ready for inference", model_info.name),
                        model_id: payload.model_id.clone(),
                        model_path: model_path_str,
                    }))
                }
                Err(e) => {
                    error!("❌ Model activation health check failed: {}", e);
                    error!("   The model may be too large or incompatible with your hardware");
                    Err(StatusCode::INTERNAL_SERVER_ERROR)
                }
            }
        }
        Err(e) => {
            let error_msg = e.to_string();

            if error_msg.contains("binary not found")
                || error_msg.contains("not found at")
                || error_msg.contains("No such file")
            {
                error!("Engine binary not found - attempting automatic download and retry");

                if let Some(ref engine_manager) = state.shared_state.engine_manager {
                    match engine_manager.ensure_engine_available().await {
                        Ok(true) => {
                            info!("Engine downloaded successfully, retrying model switch...");

                            return Ok(Json(SwitchModelResponse {
                                message: "Engine was downloaded. Please retry switching models.".to_string(),
                                model_id: payload.model_id.clone(),
                                model_path: "retry_required".to_string(),
                            }));
                        }
                        Ok(false) => {
                            error!("Failed to auto-download engine - download returned false");
                            return Ok(Json(SwitchModelResponse {
                                message: "Engine download failed. Please check your internet connection and try again.".to_string(),
                                model_id: payload.model_id.clone(),
                                model_path: "engine_download_failed".to_string(),
                            }));
                        }
                        Err(e) => {
                            error!("Failed to auto-download engine: {}", e);
                            return Ok(Json(SwitchModelResponse {
                                message: format!("Engine download error: {}", e),
                                model_id: payload.model_id.clone(),
                                model_path: "engine_download_error".to_string(),
                            }));
                        }
                    }
                } else {
                    error!("No engine manager available");
                    return Ok(Json(SwitchModelResponse {
                        message: "Engine manager not initialized. Please restart the application.".to_string(),
                        model_id: payload.model_id.clone(),
                        model_path: "no_engine_manager".to_string(),
                    }));
                }
            }

            error!("Failed to switch model: {}", error_msg);
            Err(StatusCode::INTERNAL_SERVER_ERROR)
        }
    }
}

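/// Samples live CPU, memory, and (with the `nvidia` feature) GPU metrics.
/// CPU usage is measured across two refreshes 200 ms apart, as sysinfo
/// needs an interval between samples to compute usage.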
pub async fn get_system_metrics(
    State(state): State<UnifiedAppState>,
) -> Result<impl IntoResponse, StatusCode> {
    use sysinfo::System;

    // Two CPU refreshes with a short sleep in between give sysinfo a
    // measurement window to compute usage percentages.
    let mut system = System::new_all();
    system.refresh_cpu();
    tokio::time::sleep(std::time::Duration::from_millis(200)).await;
    system.refresh_cpu();
    system.refresh_memory();

    let cpu_usage = system.global_cpu_info().cpu_usage();
    let per_core: Vec<f32> = system.cpus().iter().map(|cpu| cpu.cpu_usage()).collect();
    let cpu_model = system.cpus().first().map(|c| c.brand().to_string()).unwrap_or_else(|| "Unknown CPU".into());
    let cpu_freq = system.cpus().first().map(|c| c.frequency()).unwrap_or(0);

    let total_mem = system.total_memory() as f32 / (1024.0 * 1024.0 * 1024.0);
    let used_mem = (system.total_memory() - system.available_memory()) as f32 / (1024.0 * 1024.0 * 1024.0);
    let available_mem = system.available_memory() as f32 / (1024.0 * 1024.0 * 1024.0);

    let (gpu_available, gpu_name, gpu_usage, gpu_vram_total, gpu_vram_used, gpu_temp) = {
        #[cfg(feature = "nvidia")]
        {
            match nvml_wrapper::Nvml::init() {
                Ok(nvml) => {
                    match nvml.device_by_index(0) {
                        Ok(device) => {
                            let name = device.name().unwrap_or_else(|_| "GPU".into());
                            let utilization = device.utilization_rates().map(|u| u.gpu as f32).unwrap_or(0.0);
                            let mem_info = device.memory_info();
                            let vram_total = mem_info.as_ref().map(|m| m.total as f32 / (1024.0 * 1024.0 * 1024.0)).unwrap_or(0.0);
                            let vram_used = mem_info.as_ref().map(|m| m.used as f32 / (1024.0 * 1024.0 * 1024.0)).unwrap_or(0.0);
                            let temp = device.temperature(nvml_wrapper::enum_wrappers::device::TemperatureSensor::Gpu).unwrap_or(0) as f32;
                            tracing::debug!("GPU detected: {}, usage: {}%, VRAM: {}/{} GB", name, utilization, vram_used, vram_total);
                            (true, name, utilization, vram_total, vram_used, temp)
                        }
                        Err(e) => {
                            tracing::warn!("NVML initialized but failed to get device: {}", e);
                            (false, String::from("Not detected"), 0.0_f32, 0.0_f32, 0.0_f32, 0.0_f32)
                        }
                    }
                }
                Err(e) => {
                    tracing::debug!("NVML not available: {}", e);
                    (false, String::from("Not detected"), 0.0_f32, 0.0_f32, 0.0_f32, 0.0_f32)
                }
            }
        }
        #[cfg(not(feature = "nvidia"))]
        {
            (false, String::from("Not detected"), 0.0_f32, 0.0_f32, 0.0_f32, 0.0_f32)
        }
    };

    let gpu_layers = state.shared_state.config.gpu_layers;
    let inference_device = if !gpu_available || gpu_layers == 0 {
        "CPU".to_string()
    } else if gpu_layers >= 50 {
        "GPU".to_string()
    } else {
        "CPU+GPU".to_string()
    };

    Ok(Json(SystemMetricsResponse {
        cpu_usage_percent: cpu_usage,
        per_core_usage: per_core,
        cpu_model_name: cpu_model,
        cpu_frequency_mhz: cpu_freq,
        gpu_available,
        gpu_name,
        gpu_usage_percent: gpu_usage,
        gpu_vram_total_gb: gpu_vram_total,
        gpu_vram_used_gb: gpu_vram_used,
        gpu_temperature_c: gpu_temp,
        memory_total_gb: total_mem,
        memory_used_gb: used_mem,
        memory_available_gb: available_mem,
        gpu_layers_offloaded: gpu_layers,
        inference_device,
    }))
}

#[derive(Debug, Serialize)]
pub struct StorageMetadataResponse {
    pub paths: StoragePaths,
    pub models: Vec<DownloadedModelInfo>,
    pub storage_stats: StorageStats,
    pub database_info: DatabaseInfo,
    pub engines: Vec<InstalledEngineInfo>,
}

#[derive(Debug, Serialize)]
pub struct StoragePaths {
    pub app_data_dir: String,
    pub models_dir: String,
    pub registry_dir: String,
    pub database_path: String,
    pub engines_dir: String,
}

#[derive(Debug, Serialize)]
pub struct InstalledEngineInfo {
    pub id: String,
    pub name: String,
    pub version: String,
    pub platform: String,
    pub acceleration: String,
    pub file_size: u64,
    pub size_human: String,
    pub install_path: String,
    pub binary_name: String,
    pub is_default: bool,
}

#[derive(Debug, Serialize)]
pub struct DownloadedModelInfo {
    pub id: String,
    pub name: String,
    pub format: String,
    pub size_bytes: u64,
    pub size_human: String,
    pub download_date: String,
    pub download_source: String,
    pub file_path: String,
    pub metadata_path: Option<String>,
}

#[derive(Debug, Serialize)]
pub struct StorageStats {
    pub models_total_bytes: u64,
    pub models_total_human: String,
    pub available_space_bytes: u64,
    pub available_space_human: String,
    pub model_count: usize,
}

#[derive(Debug, Serialize)]
pub struct DatabaseInfo {
    pub path: String,
    pub size_bytes: u64,
    pub size_human: String,
}

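/// Aggregates storage paths, per-model on-disk sizes (walking each model
/// directory), database size, and installed engine details into a single
/// diagnostic response.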
pub async fn get_storage_metadata(
    State(state): State<UnifiedAppState>,
) -> Result<impl IntoResponse, StatusCode> {
    use crate::model_management::storage::ModelMetadata;

    let model_manager = state.shared_state.model_manager.as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let app_data_dir = dirs::data_dir()
        .map(|d| {
            if cfg!(target_os = "windows") || cfg!(target_os = "macos") {
                d.join("Aud.io")
            } else {
                d.join("aud.io")
            }
        })
        .unwrap_or_else(|| std::path::PathBuf::from("./aud.io-data"));

    let engines_dir = app_data_dir.join("engines");

    let paths = StoragePaths {
        app_data_dir: app_data_dir.to_string_lossy().to_string(),
        models_dir: model_manager.storage.location.models_dir.to_string_lossy().to_string(),
        registry_dir: model_manager.storage.location.registry_dir.to_string_lossy().to_string(),
        database_path: app_data_dir.join("memory.db").to_string_lossy().to_string(),
        engines_dir: engines_dir.to_string_lossy().to_string(),
    };

    let mut models = Vec::new();
    let installed_models: Vec<crate::model_management::registry::ModelInfo> = {
        let registry = model_manager.registry.read().await;
        registry.list_models().into_iter()
            .filter(|m| matches!(m.status, crate::model_management::registry::ModelStatus::Installed))
            .cloned()
            .collect()
    };

    for model in installed_models {
        let metadata_path = model_manager.storage.metadata_path(&model.id);

        let download_date = if metadata_path.exists() {
            std::fs::metadata(&metadata_path)
                .and_then(|m| m.modified())
                .ok()
                .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
                .and_then(|d| chrono::DateTime::from_timestamp(d.as_secs() as i64, 0))
                .map(|d| d.format("%Y-%m-%d %H:%M:%S UTC").to_string())
                .unwrap_or_else(|| "Unknown".to_string())
        } else {
            "Unknown".to_string()
        };

        let download_source = if metadata_path.exists() {
            std::fs::read_to_string(&metadata_path)
                .ok()
                .and_then(|content| serde_json::from_str::<ModelMetadata>(&content).ok())
                .map(|m| m.download_source)
                .unwrap_or_else(|| "unknown".to_string())
        } else {
            "unknown".to_string()
        };

        let model_dir = model_manager.storage.location.models_dir.join(
            model.id.replace(':', "_").replace('/', "_").replace('\\', "_")
        );
        let mut actual_size = model.size_bytes;
        if model_dir.exists() {
            actual_size = walkdir::WalkDir::new(&model_dir)
                .into_iter()
                .filter_map(|e| e.ok())
                .filter(|e| e.file_type().is_file())
                .filter_map(|e| e.metadata().ok())
                .map(|m| m.len())
                .sum();
        }

        models.push(DownloadedModelInfo {
            id: model.id.clone(),
            name: model.name.clone(),
            format: model.format.clone(),
            size_bytes: actual_size,
            size_human: format_bytes(actual_size),
            download_date,
            download_source,
            file_path: model_dir.to_string_lossy().to_string(),
            metadata_path: if metadata_path.exists() {
                Some(metadata_path.to_string_lossy().to_string())
            } else {
                None
            },
        });
    }

    let models_total_bytes = model_manager.storage.get_storage_usage().unwrap_or(0);
    let available_space_bytes = model_manager.storage.get_available_space().unwrap_or(0);

    let storage_stats = StorageStats {
        models_total_bytes,
        models_total_human: format_bytes(models_total_bytes),
        available_space_bytes,
        available_space_human: format_bytes(available_space_bytes),
        model_count: models.len(),
    };

    let db_path = app_data_dir.join("memory.db");
    let db_size = std::fs::metadata(&db_path).map(|m| m.len()).unwrap_or(0);

    let database_info = DatabaseInfo {
        path: db_path.to_string_lossy().to_string(),
        size_bytes: db_size,
        size_human: format_bytes(db_size),
    };

    let mut engines = Vec::new();
    if let Some(ref engine_manager) = state.shared_state.engine_manager {
        let registry = engine_manager.registry.read().await;
        let default_engine_id = registry.default_engine.clone();

        for (engine_id, engine_info) in &registry.installed_engines {
            if let Some(install_path) = &engine_info.install_path {
                engines.push(InstalledEngineInfo {
                    id: engine_info.id.clone(),
                    name: engine_info.name.clone(),
                    version: engine_info.version.clone(),
                    platform: format!("{:?}", engine_info.platform),
                    acceleration: format!("{:?}", engine_info.acceleration),
                    file_size: engine_info.file_size,
                    size_human: format_bytes(engine_info.file_size),
                    install_path: install_path.to_string_lossy().to_string(),
                    binary_name: engine_info.binary_name.clone(),
                    is_default: default_engine_id.as_ref() == Some(engine_id),
                });
            }
        }
    }

    Ok(Json(StorageMetadataResponse {
        paths,
        models,
        storage_stats,
        database_info,
        engines,
    }))
}

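/// Formats a byte count with binary (1024-based) units,
/// e.g. `format_bytes(1_536_000)` returns "1.46 MB".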
fn format_bytes(bytes: u64) -> String {
    const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB"];
    let mut size = bytes as f64;
    let mut unit_index = 0;

    while size >= 1024.0 && unit_index < UNITS.len() - 1 {
        size /= 1024.0;
        unit_index += 1;
    }

    format!("{:.2} {}", size, UNITS[unit_index])
}

#[derive(Debug, Deserialize)]
pub struct HfAccessParams {
    pub repo_id: String,
    pub filename: String,
    pub hf_token: Option<String>,
}

#[derive(Debug, Serialize)]
pub struct HfAccessResponse {
    pub status: String,
    pub can_download: bool,
    pub message: String,
}

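/// Checks whether a (possibly gated) HuggingFace file is downloadable
/// with the supplied token and translates the access status into a
/// user-facing message.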
pub async fn check_hf_access(
    Query(params): Query<HfAccessParams>,
) -> Result<impl IntoResponse, StatusCode> {
    use crate::model_management::{check_hf_gated_access, HfAccessStatus};

    let status = check_hf_gated_access(
        &params.repo_id,
        &params.filename,
        params.hf_token.as_deref(),
    )
    .await;

    let (status_str, can_download, message) = match &status {
        HfAccessStatus::Accessible => (
            "accessible",
            true,
            "Access granted — download can proceed.".to_string(),
        ),
        HfAccessStatus::NotApproved => (
            "not_approved",
            false,
            "Your token is valid but you have not been approved to access this \
             model yet. Visit the model page on HuggingFace to request access."
                .to_string(),
        ),
        HfAccessStatus::Unauthorized => (
            "unauthorized",
            false,
            "No HuggingFace token provided or the token is invalid. \
             Please add your HF token in Settings."
                .to_string(),
        ),
        HfAccessStatus::NotFound => (
            "not_found",
            false,
            "The model or file was not found on HuggingFace.".to_string(),
        ),
        HfAccessStatus::Error(e) => (
            "error",
            false,
            format!("Network or server error: {}", e),
        ),
    };

    Ok(Json(HfAccessResponse {
        status: status_str.to_string(),
        can_download,
        message,
    }))
}

#[derive(Debug, Deserialize)]
pub struct OpenRouterCatalogParams {
    pub page: Option<usize>,
    pub per_page: Option<usize>,
    pub search: Option<String>,
    pub provider: Option<String>,
    pub free_only: Option<bool>,
    pub min_context: Option<u64>,
}

#[derive(Debug, Serialize)]
pub struct OpenRouterCatalogResponse {
    pub models: Vec<ModelInfo>,
    pub total: usize,
    pub page: usize,
    pub per_page: usize,
    pub total_pages: usize,
}

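/// Paginated view of the cached OpenRouter catalog with optional search,
/// provider, free-only, and minimum-context filters, sorted by name.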
pub async fn openrouter_catalog(
    State(state): State<UnifiedAppState>,
    Query(params): Query<OpenRouterCatalogParams>,
) -> Result<impl IntoResponse, StatusCode> {
    let model_manager = state
        .shared_state
        .model_manager
        .as_ref()
        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

    let page = params.page.unwrap_or(1).max(1);
    let per_page = params.per_page.unwrap_or(50).clamp(1, 200);

    let mut models: Vec<ModelInfo> = {
        let registry = model_manager.registry.read().await;
        registry
            .list_models()
            .into_iter()
            .filter(|m| m.download_source.as_deref() == Some("openrouter"))
            .cloned()
            .collect()
    };

    if let Some(ref search) = params.search {
        let q = search.to_lowercase();
        models.retain(|m| {
            m.name.to_lowercase().contains(&q)
                || m.id.to_lowercase().contains(&q)
                || m.description
                    .as_ref()
                    .map_or(false, |d| d.to_lowercase().contains(&q))
                || m.provider
                    .as_ref()
                    .map_or(false, |p| p.to_lowercase().contains(&q))
        });
    }

    if let Some(ref provider) = params.provider {
        let prov = provider.to_lowercase();
        models.retain(|m| {
            m.provider
                .as_ref()
                .map_or(false, |p| p.to_lowercase().contains(&prov))
                || m.id.to_lowercase().starts_with(&prov)
        });
    }

    if params.free_only.unwrap_or(false) {
        models.retain(|m| {
            m.pricing.as_ref().map_or(false, |p| p.is_free())
                || m.tags.iter().any(|t| t == "free")
        });
    }

    if let Some(min_ctx) = params.min_context {
        models.retain(|m| m.context_length.map_or(false, |ctx| ctx >= min_ctx));
    }

    models.sort_by(|a, b| a.name.cmp(&b.name));

    let total = models.len();
    let total_pages = total.div_ceil(per_page);
    let start = (page - 1) * per_page;
    let page_models: Vec<ModelInfo> = models.into_iter().skip(start).take(per_page).collect();

    Ok(Json(OpenRouterCatalogResponse {
        models: page_models,
        total,
        page,
        per_page,
        total_pages,
    }))
}

#[derive(Debug, Serialize)]
pub struct OpenRouterQuotaResponse {
    pub usage_usd: f64,
    pub limit_usd: Option<f64>,
    pub is_free_tier: bool,
    pub remaining_usd: Option<f64>,
}

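/// Queries `GET https://openrouter.ai/api/v1/auth/key` with the stored
/// OpenRouter key and reports usage, limit, and remaining credit in USD.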
pub async fn openrouter_quota(
    State(state): State<UnifiedAppState>,
) -> Result<impl IntoResponse, StatusCode> {
    let api_key = state
        .shared_state
        .database_pool
        .api_keys
        .get_key_plaintext(&ApiKeyType::OpenRouter)
        .ok()
        .flatten()
        .ok_or(StatusCode::UNAUTHORIZED)?;

    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(10))
        .build()
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    let resp = client
        .get("https://openrouter.ai/api/v1/auth/key")
        .header("Authorization", format!("Bearer {}", api_key))
        .send()
        .await
        .map_err(|e| {
            warn!("OpenRouter quota request failed: {}", e);
            StatusCode::BAD_GATEWAY
        })?;

    if !resp.status().is_success() {
        warn!("OpenRouter /auth/key returned {}", resp.status());
        return Err(StatusCode::BAD_GATEWAY);
    }

    #[derive(serde::Deserialize)]
    struct OrKeyData {
        usage: Option<f64>,
        limit: Option<f64>,
        is_free_tier: Option<bool>,
    }

    #[derive(serde::Deserialize)]
    struct OrKeyResp {
        data: OrKeyData,
    }

    let body: OrKeyResp = resp.json().await.map_err(|e| {
        warn!("Failed to parse OpenRouter quota response: {}", e);
        StatusCode::BAD_GATEWAY
    })?;

    let usage = body.data.usage.unwrap_or(0.0);
    let limit = body.data.limit;
    let is_free_tier = body.data.is_free_tier.unwrap_or(false);
    let remaining = limit.map(|l| (l - usage).max(0.0));

    Ok(Json(OpenRouterQuotaResponse {
        usage_usd: usage,
        limit_usd: limit,
        is_free_tier,
        remaining_usd: remaining,
    }))
}