1use std::collections::HashMap;
9use std::path::{Path, PathBuf};
10
11use serde::{Deserialize, Serialize};
12use tracing::info;
13
14use crate::schema::*;
15use crate::InferenceError;
16
17#[derive(Debug, Clone, Default)]
19pub struct ModelFilter {
20 pub capabilities: Vec<ModelCapability>,
22 pub max_size_mb: Option<u64>,
24 pub max_latency_ms: Option<u64>,
26 pub max_cost_per_mtok: Option<f64>,
28 pub tags: Vec<String>,
30 pub provider: Option<String>,
32 pub local_only: bool,
34 pub available_only: bool,
36}
37
38pub struct UnifiedRegistry {
40 models_dir: PathBuf,
41 models: HashMap<String, ModelSchema>,
43 user_config_path: PathBuf,
45}
46
47impl UnifiedRegistry {
48 pub fn new(models_dir: PathBuf) -> Self {
49 let user_config_path = models_dir
50 .parent()
51 .unwrap_or(&models_dir)
52 .join("models.json");
53
54 let mut registry = Self {
55 models_dir,
56 models: HashMap::new(),
57 user_config_path,
58 };
59 registry.load_builtin_catalog();
60 registry.refresh_availability();
61 let _ = registry.load_user_config();
63 registry
64 }
65
66 pub fn register(&mut self, mut schema: ModelSchema) {
68 if schema.is_local() {
70 let local_path = self.models_dir.join(&schema.name).join("model.gguf");
71 schema.available = local_path.exists();
72 } else if schema.is_remote() {
73 if let ModelSource::RemoteApi { ref api_key_env, .. } = schema.source {
75 schema.available = std::env::var(api_key_env).is_ok();
76 }
77 }
78 info!(id = %schema.id, name = %schema.name, available = schema.available, "registered model");
79 self.models.insert(schema.id.clone(), schema);
80 }
81
82 pub fn unregister(&mut self, id: &str) -> Option<ModelSchema> {
84 let removed = self.models.remove(id);
85 if let Some(ref m) = removed {
86 info!(id = %m.id, "unregistered model");
87 }
88 removed
89 }
90
91 pub fn list(&self) -> Vec<&ModelSchema> {
93 let mut models: Vec<&ModelSchema> = self.models.values().collect();
94 models.sort_by(|a, b| a.id.cmp(&b.id));
95 models
96 }
97
98 pub fn query(&self, filter: &ModelFilter) -> Vec<&ModelSchema> {
100 self.models
101 .values()
102 .filter(|m| {
103 if !filter.capabilities.iter().all(|c| m.has_capability(*c)) {
105 return false;
106 }
107 if let Some(max) = filter.max_size_mb {
109 if m.size_mb() > max && m.is_local() {
110 return false;
111 }
112 }
113 if let Some(max) = filter.max_latency_ms {
115 if let Some(p50) = m.performance.latency_p50_ms {
116 if p50 > max {
117 return false;
118 }
119 }
120 }
121 if let Some(max) = filter.max_cost_per_mtok {
123 if let Some(cost) = m.cost.output_per_mtok {
124 if cost > max {
125 return false;
126 }
127 }
128 }
129 if !filter.tags.iter().all(|t| m.tags.contains(t)) {
131 return false;
132 }
133 if let Some(ref p) = filter.provider {
135 if &m.provider != p {
136 return false;
137 }
138 }
139 if filter.local_only && !m.is_local() {
141 return false;
142 }
143 if filter.available_only && !m.available {
145 return false;
146 }
147 true
148 })
149 .collect()
150 }
151
152 pub fn query_by_capability(&self, cap: ModelCapability) -> Vec<&ModelSchema> {
154 self.query(&ModelFilter {
155 capabilities: vec![cap],
156 ..Default::default()
157 })
158 }
159
160 pub fn get(&self, id: &str) -> Option<&ModelSchema> {
162 self.models.get(id)
163 }
164
165 pub fn find_by_name(&self, name: &str) -> Option<&ModelSchema> {
168 self.models.values().find(|m| m.name.eq_ignore_ascii_case(name))
169 }
170
171 pub async fn ensure_local(&self, id: &str) -> Result<PathBuf, InferenceError> {
173 let schema = self.get(id)
174 .or_else(|| self.find_by_name(id))
175 .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
176
177 match &schema.source {
178 ModelSource::Local { hf_repo, hf_filename, tokenizer_repo } => {
179 let model_dir = self.models_dir.join(&schema.name);
180 let model_path = model_dir.join("model.gguf");
181 let tokenizer_path = model_dir.join("tokenizer.json");
182
183 if model_path.exists() && tokenizer_path.exists() {
184 return Ok(model_dir);
185 }
186
187 std::fs::create_dir_all(&model_dir)?;
188
189 if !model_path.exists() {
190 info!(model = %schema.name, repo = %hf_repo, "downloading model weights");
191 download_file(hf_repo, hf_filename, &model_path).await?;
192 }
193 if !tokenizer_path.exists() {
194 info!(model = %schema.name, repo = %tokenizer_repo, "downloading tokenizer");
195 download_file(tokenizer_repo, "tokenizer.json", &tokenizer_path).await?;
196 }
197
198 Ok(model_dir)
199 }
200 _ => Err(InferenceError::InferenceFailed(format!(
201 "model {} is not local", id
202 ))),
203 }
204 }
205
206 pub fn remove_local(&mut self, id: &str) -> Result<(), InferenceError> {
208 let schema = self.get(id)
209 .or_else(|| self.find_by_name(id))
210 .ok_or_else(|| InferenceError::ModelNotFound(id.to_string()))?;
211
212 let model_dir = self.models_dir.join(&schema.name);
213 if model_dir.exists() {
214 std::fs::remove_dir_all(&model_dir)?;
215 info!(model = %schema.name, "removed model");
216 }
217
218 let id = schema.id.clone();
220 if let Some(m) = self.models.get_mut(&id) {
221 m.available = false;
222 }
223 Ok(())
224 }
225
226 pub fn refresh_availability(&mut self) {
228 let models_dir = self.models_dir.clone();
229 for m in self.models.values_mut() {
230 match &m.source {
231 ModelSource::Local { .. } => {
232 let local_path = models_dir.join(&m.name).join("model.gguf");
233 m.available = local_path.exists();
234 }
235 ModelSource::RemoteApi { api_key_env, .. } => {
236 m.available = std::env::var(api_key_env).is_ok();
237 }
238 ModelSource::Ollama { .. } => {
239 m.available = true;
241 }
242 ModelSource::Proprietary { auth, .. } => {
243 m.available = match auth {
245 crate::schema::ProprietaryAuth::ApiKeyEnv { env_var } => {
246 std::env::var(env_var).is_ok()
247 }
248 crate::schema::ProprietaryAuth::BearerTokenEnv { env_var } => {
249 std::env::var(env_var).is_ok()
250 }
251 crate::schema::ProprietaryAuth::OAuth2Pkce { .. } => {
252 true
254 }
255 };
256 }
257 }
258 }
259 }
260
261 pub fn save_user_config(&self) -> Result<(), InferenceError> {
263 let user_models: Vec<&ModelSchema> = self.models.values()
264 .filter(|m| !m.tags.contains(&"builtin".to_string()))
265 .collect();
266
267 if user_models.is_empty() {
268 return Ok(());
269 }
270
271 let json = serde_json::to_string_pretty(&user_models)
272 .map_err(|e| InferenceError::InferenceFailed(format!("serialize: {e}")))?;
273 std::fs::write(&self.user_config_path, json)?;
274 Ok(())
275 }
276
277 pub fn load_user_config(&mut self) -> Result<(), InferenceError> {
279 if !self.user_config_path.exists() {
280 return Ok(());
281 }
282
283 let json = std::fs::read_to_string(&self.user_config_path)?;
284 let models: Vec<ModelSchema> = serde_json::from_str(&json)
285 .map_err(|e| InferenceError::InferenceFailed(format!("parse models.json: {e}")))?;
286
287 for m in models {
288 self.register(m);
289 }
290 Ok(())
291 }
292
293 pub fn models_dir(&self) -> &Path {
295 &self.models_dir
296 }
297
298 fn load_builtin_catalog(&mut self) {
300 for schema in builtin_catalog() {
301 self.models.insert(schema.id.clone(), schema);
302 }
303 }
304}
305
306#[derive(Debug, Clone, Serialize, Deserialize)]
308pub struct ModelInfo {
309 pub id: String,
310 pub name: String,
311 pub provider: String,
312 pub capabilities: Vec<ModelCapability>,
313 pub param_count: String,
314 pub size_mb: u64,
315 pub context_length: usize,
316 pub available: bool,
317 pub is_local: bool,
318}
319
320impl From<&ModelSchema> for ModelInfo {
321 fn from(s: &ModelSchema) -> Self {
322 ModelInfo {
323 id: s.id.clone(),
324 name: s.name.clone(),
325 provider: s.provider.clone(),
326 capabilities: s.capabilities.clone(),
327 param_count: s.param_count.clone(),
328 size_mb: s.size_mb(),
329 context_length: s.context_length,
330 available: s.available,
331 is_local: s.is_local(),
332 }
333 }
334}
335
336async fn download_file(repo: &str, filename: &str, dest: &Path) -> Result<(), InferenceError> {
338 let api = hf_hub::api::tokio::Api::new()
339 .map_err(|e| InferenceError::DownloadFailed(e.to_string()))?;
340
341 let repo = api.model(repo.to_string());
342 let path = repo
343 .get(filename)
344 .await
345 .map_err(|e| InferenceError::DownloadFailed(format!("{filename}: {e}")))?;
346
347 if dest.exists() {
348 return Ok(());
349 }
350
351 #[cfg(unix)]
353 {
354 if std::os::unix::fs::symlink(&path, dest).is_ok() {
355 return Ok(());
356 }
357 }
358
359 std::fs::copy(&path, dest)
360 .map_err(|e| InferenceError::DownloadFailed(format!("copy to {}: {e}", dest.display())))?;
361 Ok(())
362}
363
/// Returns the models that ship with the registry.
///
/// Every entry is tagged `builtin`. `save_user_config` uses that tag to keep
/// these entries out of `models.json`. All entries start with
/// `available: false`; the registry recomputes availability after loading them.
fn builtin_catalog() -> Vec<ModelSchema> {
    vec![
        // ---- Local Qwen3 GGUF models, fetched from Hugging Face by `ensure_local` ----
        // Embedding-only model.
        ModelSchema {
            id: "qwen/qwen3-embedding-0.6b:q8_0".into(),
            name: "Qwen3-Embedding-0.6B".into(),
            provider: "qwen".into(),
            family: "qwen3".into(),
            version: "1.0".into(),
            capabilities: vec![ModelCapability::Embed],
            context_length: 8192,
            param_count: "0.6B".into(),
            quantization: Some("Q8_0".into()),
            performance: PerformanceEnvelope::default(),
            cost: CostModel {
                size_mb: Some(639),
                ram_mb: Some(639),
                ..Default::default()
            },
            source: ModelSource::Local {
                hf_repo: "Qwen/Qwen3-Embedding-0.6B-GGUF".into(),
                hf_filename: "Qwen3-Embedding-0.6B-Q8_0.gguf".into(),
                // Tokenizer comes from the non-GGUF repo of the same model.
                tokenizer_repo: "Qwen/Qwen3-Embedding-0.6B".into(),
            },
            tags: vec!["builtin".into(), "embedding".into()],
            available: false,
        },
        ModelSchema {
            id: "qwen/qwen3-0.6b:q8_0".into(),
            name: "Qwen3-0.6B".into(),
            provider: "qwen".into(),
            family: "qwen3".into(),
            version: "1.0".into(),
            capabilities: vec![ModelCapability::Generate, ModelCapability::Classify],
            context_length: 32768,
            param_count: "0.6B".into(),
            quantization: Some("Q8_0".into()),
            performance: PerformanceEnvelope {
                tokens_per_second: Some(100.0),
                ..Default::default()
            },
            cost: CostModel {
                size_mb: Some(650),
                ram_mb: Some(650),
                ..Default::default()
            },
            source: ModelSource::Local {
                hf_repo: "Qwen/Qwen3-0.6B-GGUF".into(),
                hf_filename: "Qwen3-0.6B-Q8_0.gguf".into(),
                tokenizer_repo: "Qwen/Qwen3-0.6B".into(),
            },
            tags: vec!["builtin".into(), "fast".into()],
            available: false,
        },
        ModelSchema {
            id: "qwen/qwen3-1.7b:q8_0".into(),
            name: "Qwen3-1.7B".into(),
            provider: "qwen".into(),
            family: "qwen3".into(),
            version: "1.0".into(),
            capabilities: vec![ModelCapability::Generate, ModelCapability::Code],
            context_length: 32768,
            param_count: "1.7B".into(),
            quantization: Some("Q8_0".into()),
            performance: PerformanceEnvelope {
                tokens_per_second: Some(70.0),
                ..Default::default()
            },
            cost: CostModel {
                size_mb: Some(1800),
                ram_mb: Some(1800),
                ..Default::default()
            },
            source: ModelSource::Local {
                hf_repo: "Qwen/Qwen3-1.7B-GGUF".into(),
                hf_filename: "Qwen3-1.7B-Q8_0.gguf".into(),
                tokenizer_repo: "Qwen/Qwen3-1.7B".into(),
            },
            tags: vec!["builtin".into()],
            available: false,
        },
        ModelSchema {
            id: "qwen/qwen3-4b:q4_k_m".into(),
            name: "Qwen3-4B".into(),
            provider: "qwen".into(),
            family: "qwen3".into(),
            version: "1.0".into(),
            capabilities: vec![ModelCapability::Generate, ModelCapability::Code, ModelCapability::Reasoning],
            context_length: 32768,
            param_count: "4B".into(),
            quantization: Some("Q4_K_M".into()),
            performance: PerformanceEnvelope {
                tokens_per_second: Some(45.0),
                ..Default::default()
            },
            cost: CostModel {
                size_mb: Some(2500),
                ram_mb: Some(2500),
                ..Default::default()
            },
            source: ModelSource::Local {
                hf_repo: "Qwen/Qwen3-4B-GGUF".into(),
                hf_filename: "Qwen3-4B-Q4_K_M.gguf".into(),
                tokenizer_repo: "Qwen/Qwen3-4B".into(),
            },
            tags: vec!["builtin".into(), "code".into()],
            available: false,
        },
        ModelSchema {
            id: "qwen/qwen3-8b:q4_k_m".into(),
            name: "Qwen3-8B".into(),
            provider: "qwen".into(),
            family: "qwen3".into(),
            version: "1.0".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Code,
                ModelCapability::Reasoning,
                ModelCapability::Summarize,
            ],
            context_length: 131072,
            param_count: "8B".into(),
            quantization: Some("Q4_K_M".into()),
            performance: PerformanceEnvelope {
                tokens_per_second: Some(25.0),
                ..Default::default()
            },
            cost: CostModel {
                size_mb: Some(4900),
                ram_mb: Some(4900),
                ..Default::default()
            },
            source: ModelSource::Local {
                hf_repo: "Qwen/Qwen3-8B-GGUF".into(),
                hf_filename: "Qwen3-8B-Q4_K_M.gguf".into(),
                tokenizer_repo: "Qwen/Qwen3-8B".into(),
            },
            tags: vec!["builtin".into(), "reasoning".into()],
            available: false,
        },
        // Mixture-of-experts model: `param_count` records total vs. active parameters.
        ModelSchema {
            id: "qwen/qwen3-30b-a3b:q4_k_m".into(),
            name: "Qwen3-30B-A3B".into(),
            provider: "qwen".into(),
            family: "qwen3".into(),
            version: "1.0".into(),
            capabilities: vec![
                ModelCapability::Generate,
                ModelCapability::Code,
                ModelCapability::Reasoning,
                ModelCapability::Summarize,
                ModelCapability::ToolUse,
            ],
            context_length: 131072,
            param_count: "30B (3B active)".into(),
            quantization: Some("Q4_K_M".into()),
            performance: PerformanceEnvelope {
                tokens_per_second: Some(35.0), ..Default::default()
            },
            cost: CostModel {
                size_mb: Some(17000),
                ram_mb: Some(17000),
                ..Default::default()
            },
            source: ModelSource::Local {
                hf_repo: "Qwen/Qwen3-30B-A3B-GGUF".into(),
                hf_filename: "Qwen3-30B-A3B-Q4_K_M.gguf".into(),
                tokenizer_repo: "Qwen/Qwen3-30B-A3B".into(),
            },
            tags: vec!["builtin".into(), "moe".into(), "reasoning".into()],
            available: false,
        },
        // ---- Remote API models ----
        // Available when the environment variable named in `api_key_env` is set.
        // Latency figures feed `ModelFilter::max_latency_ms`.
        // `output_per_mtok` feeds `ModelFilter::max_cost_per_mtok`.
        ModelSchema {
            id: "anthropic/claude-opus-4-6:latest".into(),
            name: "claude-opus-4-6".into(),
            provider: "anthropic".into(),
            family: "claude-4".into(),
            version: "latest".into(),
            capabilities: vec![ModelCapability::Generate, ModelCapability::Code, ModelCapability::Reasoning],
            context_length: 200000,
            param_count: String::new(),
            quantization: None,
            performance: PerformanceEnvelope {
                latency_p50_ms: Some(3000),
                ..Default::default()
            },
            cost: CostModel {
                input_per_mtok: Some(15.0),
                output_per_mtok: Some(75.0),
                ..Default::default()
            },
            source: ModelSource::RemoteApi {
                endpoint: "https://api.anthropic.com/v1/messages".into(),
                api_key_env: "ANTHROPIC_API_KEY".into(),
                api_version: Some("2023-06-01".into()),
                protocol: ApiProtocol::Anthropic,
            },
            tags: vec!["builtin".into(), "frontier".into()],
            available: false,
        },
        ModelSchema {
            id: "anthropic/claude-sonnet-4-6:latest".into(),
            name: "claude-sonnet-4-6".into(),
            provider: "anthropic".into(),
            family: "claude-4".into(),
            version: "latest".into(),
            capabilities: vec![ModelCapability::Generate, ModelCapability::Code, ModelCapability::Reasoning],
            context_length: 200000,
            param_count: String::new(),
            quantization: None,
            performance: PerformanceEnvelope {
                latency_p50_ms: Some(1500),
                ..Default::default()
            },
            cost: CostModel {
                input_per_mtok: Some(3.0),
                output_per_mtok: Some(15.0),
                ..Default::default()
            },
            source: ModelSource::RemoteApi {
                endpoint: "https://api.anthropic.com/v1/messages".into(),
                api_key_env: "ANTHROPIC_API_KEY".into(),
                api_version: Some("2023-06-01".into()),
                protocol: ApiProtocol::Anthropic,
            },
            tags: vec!["builtin".into(), "fast".into()],
            available: false,
        },
        // OpenAI models use the OpenAI-compatible chat-completions protocol.
        ModelSchema {
            id: "openai/gpt-5.3-codex:latest".into(),
            name: "gpt-5.3-codex".into(),
            provider: "openai".into(),
            family: "gpt-5".into(),
            version: "latest".into(),
            capabilities: vec![ModelCapability::Generate, ModelCapability::Code, ModelCapability::Reasoning],
            context_length: 128000,
            param_count: String::new(),
            quantization: None,
            performance: PerformanceEnvelope {
                latency_p50_ms: Some(2000),
                ..Default::default()
            },
            cost: CostModel {
                input_per_mtok: Some(2.0),
                output_per_mtok: Some(10.0),
                ..Default::default()
            },
            source: ModelSource::RemoteApi {
                endpoint: "https://api.openai.com/v1/chat/completions".into(),
                api_key_env: "OPENAI_API_KEY".into(),
                api_version: None,
                protocol: ApiProtocol::OpenAiCompat,
            },
            tags: vec!["builtin".into(), "frontier".into(), "code".into()],
            available: false,
        },
        ModelSchema {
            id: "openai/gpt-4o-mini:latest".into(),
            name: "gpt-4o-mini".into(),
            provider: "openai".into(),
            family: "gpt-4".into(),
            version: "latest".into(),
            capabilities: vec![ModelCapability::Generate, ModelCapability::Code],
            context_length: 128000,
            param_count: String::new(),
            quantization: None,
            performance: PerformanceEnvelope {
                latency_p50_ms: Some(800),
                ..Default::default()
            },
            cost: CostModel {
                input_per_mtok: Some(0.15),
                output_per_mtok: Some(0.6),
                ..Default::default()
            },
            source: ModelSource::RemoteApi {
                endpoint: "https://api.openai.com/v1/chat/completions".into(),
                api_key_env: "OPENAI_API_KEY".into(),
                api_version: None,
                protocol: ApiProtocol::OpenAiCompat,
            },
            tags: vec!["builtin".into(), "fast".into()],
            available: false,
        },
    ]
}
652
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a registry whose models dir lives inside a fresh temp dir.
    ///
    /// Keep the returned `TempDir` alive for the whole test: dropping it
    /// deletes the directory (including any `models.json` written there).
    fn test_registry() -> (UnifiedRegistry, TempDir) {
        let tmp = TempDir::new().unwrap();
        let reg = UnifiedRegistry::new(tmp.path().join("models"));
        (reg, tmp)
    }

    /// A user-defined remote Anthropic model with no tags, so it is not treated as builtin.
    fn custom_remote(id: &str, name: &str) -> ModelSchema {
        ModelSchema {
            id: id.into(),
            name: name.into(),
            provider: "anthropic".into(),
            family: "claude-4".into(),
            version: "latest".into(),
            capabilities: vec![ModelCapability::Generate, ModelCapability::Code, ModelCapability::Reasoning],
            context_length: 200000,
            param_count: String::new(),
            quantization: None,
            performance: PerformanceEnvelope {
                latency_p50_ms: Some(2000),
                ..Default::default()
            },
            cost: CostModel {
                input_per_mtok: Some(3.0),
                output_per_mtok: Some(15.0),
                ..Default::default()
            },
            source: ModelSource::RemoteApi {
                endpoint: "https://api.anthropic.com/v1/messages".into(),
                api_key_env: "ANTHROPIC_API_KEY".into(),
                api_version: Some("2023-06-01".into()),
                protocol: ApiProtocol::Anthropic,
            },
            tags: vec![],
            available: false,
        }
    }

    #[test]
    fn builtin_catalog_loads() {
        let (reg, _tmp) = test_registry();
        let all = reg.list();
        assert_eq!(all.len(), 10);
    }

    #[test]
    fn list_is_sorted_by_id() {
        let (reg, _tmp) = test_registry();
        let ids: Vec<&str> = reg.list().iter().map(|m| m.id.as_str()).collect();
        let mut sorted = ids.clone();
        sorted.sort_unstable();
        assert_eq!(ids, sorted);
    }

    #[test]
    fn find_by_name() {
        let (reg, _tmp) = test_registry();
        let m = reg.find_by_name("Qwen3-4B").unwrap();
        assert_eq!(m.id, "qwen/qwen3-4b:q4_k_m");
        assert!(m.has_capability(ModelCapability::Code));
    }

    #[test]
    fn find_by_name_ignores_ascii_case() {
        let (reg, _tmp) = test_registry();
        assert_eq!(reg.find_by_name("qwen3-4b").unwrap().id, "qwen/qwen3-4b:q4_k_m");
        assert!(reg.find_by_name("no-such-model").is_none());
    }

    #[test]
    fn query_by_capability() {
        let (reg, _tmp) = test_registry();
        let embed_models = reg.query_by_capability(ModelCapability::Embed);
        assert_eq!(embed_models.len(), 1);
        assert_eq!(embed_models[0].name, "Qwen3-Embedding-0.6B");
    }

    #[test]
    fn query_with_filter() {
        let (reg, _tmp) = test_registry();
        let code_small = reg.query(&ModelFilter {
            capabilities: vec![ModelCapability::Code],
            max_size_mb: Some(3000),
            local_only: true,
            ..Default::default()
        });
        // Qwen3-1.7B (1800 MB) and Qwen3-4B (2500 MB).
        assert_eq!(code_small.len(), 2);
    }

    #[test]
    fn query_by_provider() {
        let (reg, _tmp) = test_registry();
        let by = |p: &str| {
            reg.query(&ModelFilter {
                provider: Some(p.into()),
                ..Default::default()
            })
            .len()
        };
        assert_eq!(by("anthropic"), 2);
        assert_eq!(by("openai"), 2);
        assert_eq!(by("qwen"), 6);
        assert_eq!(by("nobody"), 0);
    }

    #[test]
    fn query_requires_all_tags() {
        let (reg, _tmp) = test_registry();
        let fast = reg.query(&ModelFilter {
            tags: vec!["builtin".into(), "fast".into()],
            ..Default::default()
        });
        // Qwen3-0.6B, claude-sonnet-4-6, gpt-4o-mini.
        assert_eq!(fast.len(), 3);
    }

    #[test]
    fn query_by_cost_keeps_unpriced_models() {
        let (reg, _tmp) = test_registry();
        let cheap = reg.query(&ModelFilter {
            max_cost_per_mtok: Some(20.0),
            ..Default::default()
        });
        // Only claude-opus-4-6 (75 per Mtok output) is excluded; local models have no price.
        assert_eq!(cheap.len(), 9);
    }

    #[test]
    fn query_by_latency_keeps_unmeasured_models() {
        let (reg, _tmp) = test_registry();
        let quick = reg.query(&ModelFilter {
            max_latency_ms: Some(1000),
            ..Default::default()
        });
        // Six local models without a p50 plus gpt-4o-mini (800 ms).
        assert_eq!(quick.len(), 7);
    }

    #[test]
    fn register_remote() {
        let (mut reg, _tmp) = test_registry();
        // Same id as a builtin entry: replaces it rather than adding a new one.
        reg.register(custom_remote("anthropic/claude-sonnet-4-6:latest", "Claude Sonnet 4.6"));
        assert_eq!(reg.list().len(), 10);

        let reasoning = reg.query(&ModelFilter {
            capabilities: vec![ModelCapability::Reasoning],
            ..Default::default()
        });
        assert_eq!(reasoning.len(), 6);
    }

    #[test]
    fn unregister() {
        let (mut reg, _tmp) = test_registry();
        let removed = reg.unregister("qwen/qwen3-0.6b:q8_0");
        assert!(removed.is_some());
        assert_eq!(reg.list().len(), 9);
    }

    #[test]
    fn unregister_unknown_is_none() {
        let (mut reg, _tmp) = test_registry();
        assert!(reg.unregister("nope/nope:latest").is_none());
        assert_eq!(reg.list().len(), 10);
    }

    #[test]
    fn model_info_from_schema() {
        let (reg, _tmp) = test_registry();
        let info = ModelInfo::from(reg.get("openai/gpt-4o-mini:latest").unwrap());
        assert_eq!(info.id, "openai/gpt-4o-mini:latest");
        assert_eq!(info.name, "gpt-4o-mini");
        assert_eq!(info.provider, "openai");
        assert_eq!(info.context_length, 128000);
        assert!(!info.is_local);
    }

    #[test]
    fn remove_local_of_missing_download_is_ok() {
        let (mut reg, _tmp) = test_registry();
        reg.remove_local("Qwen3-0.6B").unwrap();
        assert!(!reg.get("qwen/qwen3-0.6b:q8_0").unwrap().available);
        assert!(matches!(
            reg.remove_local("does-not-exist"),
            Err(InferenceError::ModelNotFound(_))
        ));
    }

    #[test]
    fn user_config_round_trip() {
        let (mut reg, tmp) = test_registry();
        reg.register(custom_remote("custom/my-model:latest", "my-model"));
        reg.save_user_config().unwrap();

        let reloaded = UnifiedRegistry::new(tmp.path().join("models"));
        assert_eq!(reloaded.list().len(), 11);
        assert_eq!(reloaded.get("custom/my-model:latest").unwrap().name, "my-model");
    }
}