1use std::collections::HashMap;
32use std::path::{Path, PathBuf};
33use std::sync::OnceLock;
34
35use anyhow::{anyhow, Result};
36use serde::{Deserialize, Serialize};
37
/// Embedded model catalog, compiled into the binary from
/// `../templates/models.yaml` (path relative to this source file).
pub(crate) const MODELS_YAML: &str = include_str!("../templates/models.yaml");

/// Schema version this build understands. Non-embedded layers that declare a
/// different (or no) `version:` only trigger a warning; loading continues.
pub const MODELS_SCHEMA_VERSION: &str = "1";

/// Name of the environment variable that, when set to a non-empty value, is
/// used by [`ModelRegistry::load`] as the path of an override catalog file.
pub const OMNI_DEV_MODELS_YAML_ENV: &str = "OMNI_DEV_MODELS_YAML";

/// Output-token limit returned when neither a model spec nor provider
/// defaults can be found for an identifier.
const FALLBACK_MAX_OUTPUT_TOKENS: usize = 4096;

/// Input-context size returned when neither a model spec nor provider
/// defaults can be found for an identifier.
const FALLBACK_INPUT_CONTEXT: usize = 100_000;
57
/// Identifies which configuration layer last contributed a model or provider
/// entry.
///
/// Serialized in lowercase (`embedded`, `user`, `project`, `override`) and
/// defaults to [`ModelSource::Embedded`].
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Default,
)]
#[serde(rename_all = "lowercase")]
pub enum ModelSource {
    /// The catalog compiled into the binary; always the first layer applied.
    #[default]
    Embedded,
    /// The per-user catalog file; applied after the embedded layer.
    User,
    /// The per-project catalog file; applied after the user layer, so it wins
    /// over it on conflicting entries.
    Project,
    /// The explicit override file; when an override path is supplied, the
    /// user and project layers are not consulted at all.
    Override,
}
74
75impl std::fmt::Display for ModelSource {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.write_str(match self {
78 Self::Embedded => "embedded",
79 Self::User => "user",
80 Self::Project => "project",
81 Self::Override => "override",
82 })
83 }
84}
85
/// A beta header a model supports, together with the limits that apply when
/// that header is in use.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct BetaHeader {
    /// Header name (for example `anthropic-beta`).
    pub key: String,
    /// Header value; this is what the `*_with_beta` lookups match against.
    pub value: String,
    /// Output-token limit that replaces the model's own limit when this
    /// header is selected. `None` means the model's limit is unchanged.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<usize>,
    /// Input-context size that replaces the model's own size when this
    /// header is selected. `None` means the model's size is unchanged.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_context: Option<usize>,
}
112
/// One model entry of the catalog.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ModelSpec {
    /// Provider key this model belongs to (for example `claude`); used to
    /// group models per provider.
    pub provider: String,
    /// Display name of the model.
    pub model: String,
    /// Identifier used to look the model up; also the key on which layered
    /// catalog entries are merged.
    pub api_identifier: String,
    /// Maximum number of output tokens for this model.
    pub max_output_tokens: usize,
    /// Input context size for this model.
    pub input_context: usize,
    /// Model generation number (for example `4.6`).
    pub generation: f32,
    /// Tier name (for example `flagship` or `fast`); compared verbatim when
    /// filtering by tier.
    pub tier: String,
    /// Whether the model is flagged as legacy. Defaults to `false` when the
    /// field is absent.
    #[serde(default)]
    pub legacy: bool,
    /// Beta headers supported by this model. Defaults to empty and is omitted
    /// from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub beta_headers: Vec<BetaHeader>,
    /// Layer that last wrote this entry. Never read from YAML
    /// (`skip_deserializing`); filled in by the registry after merging.
    #[serde(default, skip_deserializing)]
    pub source: ModelSource,
}
166
/// Descriptive metadata for one tier of a provider.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct TierInfo {
    /// Free-text description of the tier.
    pub description: String,
    /// Free-text list of use cases associated with the tier.
    pub use_cases: Vec<String>,
}
183
/// Provider-level limits used when a model identifier is not in the catalog
/// but its provider can be inferred.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct DefaultConfig {
    /// Default maximum number of output tokens.
    pub max_output_tokens: usize,
    /// Default input context size.
    pub input_context: usize,
}
201
/// Configuration of a single provider, keyed by provider name in
/// [`ModelConfiguration::providers`].
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ProviderConfig {
    /// Display name of the provider (for example `Anthropic Claude`).
    pub name: String,
    /// Base URL of the provider's API (for example
    /// `https://api.anthropic.com/v1`).
    pub api_base: String,
    /// Identifier of the model used by default for this provider.
    pub default_model: String,
    /// Tier metadata, keyed by tier name.
    pub tiers: HashMap<String, TierInfo>,
    /// Limits applied to models of this provider that are not in the catalog.
    pub defaults: DefaultConfig,
    /// Layer that last wrote this entry. Never read from YAML
    /// (`skip_deserializing`); filled in by the registry after merging.
    #[serde(default, skip_deserializing)]
    pub source: ModelSource,
}
229
/// Typed form of a (merged) `models.yaml` document.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ModelConfiguration {
    /// Declared schema version, if any. Omitted from serialized output when
    /// absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub version: Option<String>,
    /// All known models.
    pub models: Vec<ModelSpec>,
    /// Provider configurations, keyed by provider name.
    pub providers: HashMap<String, ProviderConfig>,
}
252
/// In-memory model catalog with lookup indexes.
pub struct ModelRegistry {
    /// The merged configuration the indexes below were built from.
    config: ModelConfiguration,
    /// Model specs keyed by `api_identifier` (cloned out of `config.models`).
    by_identifier: HashMap<String, ModelSpec>,
    /// Model specs grouped by `provider` (cloned out of `config.models`).
    by_provider: HashMap<String, Vec<ModelSpec>>,
}
267
268impl ModelRegistry {
269 pub fn load() -> Result<Self> {
282 let override_path = std::env::var(OMNI_DEV_MODELS_YAML_ENV)
283 .ok()
284 .filter(|s| !s.is_empty())
285 .map(PathBuf::from);
286 let project_path = default_project_path();
287 let user_path = default_user_path();
288 Self::load_layered_from_paths(
289 project_path.as_deref(),
290 user_path.as_deref(),
291 override_path.as_deref(),
292 )
293 }
294
295 pub fn load_layered_from_paths(
299 project_path: Option<&Path>,
300 user_path: Option<&Path>,
301 override_path: Option<&Path>,
302 ) -> Result<Self> {
303 let mut layers: Vec<(ModelSource, String)> = Vec::new();
304 layers.push((ModelSource::Embedded, MODELS_YAML.to_string()));
305
306 if let Some(path) = override_path {
307 match read_optional_yaml(path) {
308 Some(yaml) => layers.push((ModelSource::Override, yaml)),
309 None => {
310 tracing::warn!(
311 "{OMNI_DEV_MODELS_YAML_ENV} points at {} but the file is missing or unreadable; falling back to embedded catalog",
312 path.display()
313 );
314 }
315 }
316 } else {
317 if let Some(path) = user_path {
318 if let Some(yaml) = read_optional_yaml(path) {
319 layers.push((ModelSource::User, yaml));
320 }
321 }
322 if let Some(path) = project_path {
323 if let Some(yaml) = read_optional_yaml(path) {
324 layers.push((ModelSource::Project, yaml));
325 }
326 }
327 }
328
329 Self::from_layers(&layers)
330 }
331
332 pub(crate) fn from_layers(layers: &[(ModelSource, String)]) -> Result<Self> {
338 let mut merged: serde_yaml::Value =
339 serde_yaml::Value::Mapping(serde_yaml::Mapping::default());
340 let mut model_sources: HashMap<String, ModelSource> = HashMap::new();
341 let mut provider_sources: HashMap<String, ModelSource> = HashMap::new();
342 let mut declared_versions: Vec<(ModelSource, Option<String>)> = Vec::new();
343
344 for (source, yaml) in layers {
345 let value: serde_yaml::Value = match serde_yaml::from_str(yaml) {
346 Ok(v) => v,
347 Err(e) => {
348 if matches!(source, ModelSource::Embedded) {
349 return Err(anyhow!(
350 "Embedded models.yaml is malformed at compile time: {e}"
351 ));
352 }
353 tracing::error!(
354 "Malformed {source} models.yaml: {e}. Falling through to lower-precedence layers."
355 );
356 continue;
357 }
358 };
359
360 let version = value
362 .get("version")
363 .and_then(|v| v.as_str())
364 .map(String::from);
365 declared_versions.push((*source, version));
366
367 merge_layer_into(
368 &mut merged,
369 value,
370 *source,
371 &mut model_sources,
372 &mut provider_sources,
373 );
374 }
375
376 warn_on_version_mismatch(&declared_versions);
377
378 let mut config: ModelConfiguration = serde_yaml::from_value(merged)
379 .map_err(|e| anyhow!("Failed to deserialize merged model configuration: {e}"))?;
380
381 for spec in &mut config.models {
382 spec.source = model_sources
383 .get(&spec.api_identifier)
384 .copied()
385 .unwrap_or_default();
386 }
387 for (name, prov) in &mut config.providers {
388 prov.source = provider_sources.get(name).copied().unwrap_or_default();
389 }
390
391 let mut by_identifier = HashMap::new();
392 let mut by_provider: HashMap<String, Vec<ModelSpec>> = HashMap::new();
393 for model in &config.models {
394 by_identifier.insert(model.api_identifier.clone(), model.clone());
395 by_provider
396 .entry(model.provider.clone())
397 .or_default()
398 .push(model.clone());
399 }
400
401 Ok(Self {
402 config,
403 by_identifier,
404 by_provider,
405 })
406 }
407
408 #[must_use]
410 pub fn config(&self) -> &ModelConfiguration {
411 &self.config
412 }
413
414 #[must_use]
416 pub fn get_model_spec(&self, api_identifier: &str) -> Option<&ModelSpec> {
417 if let Some(spec) = self.by_identifier.get(api_identifier) {
419 return Some(spec);
420 }
421
422 self.find_model_by_normalized_id(api_identifier)
424 }
425
426 #[must_use]
428 pub fn get_max_output_tokens(&self, api_identifier: &str) -> usize {
429 if let Some(spec) = self.get_model_spec(api_identifier) {
430 return spec.max_output_tokens;
431 }
432
433 if let Some(provider) = self.infer_provider(api_identifier) {
435 if let Some(provider_config) = self.config.providers.get(&provider) {
436 return provider_config.defaults.max_output_tokens;
437 }
438 }
439
440 FALLBACK_MAX_OUTPUT_TOKENS
442 }
443
444 #[must_use]
446 pub fn get_input_context(&self, api_identifier: &str) -> usize {
447 if let Some(spec) = self.get_model_spec(api_identifier) {
448 return spec.input_context;
449 }
450
451 if let Some(provider) = self.infer_provider(api_identifier) {
453 if let Some(provider_config) = self.config.providers.get(&provider) {
454 return provider_config.defaults.input_context;
455 }
456 }
457
458 FALLBACK_INPUT_CONTEXT
460 }
461
462 fn infer_provider(&self, api_identifier: &str) -> Option<String> {
464 if api_identifier.starts_with("claude") || api_identifier.contains("anthropic") {
465 Some("claude".to_string())
466 } else {
467 None
468 }
469 }
470
471 fn find_model_by_normalized_id(&self, api_identifier: &str) -> Option<&ModelSpec> {
476 let core_identifier = self.extract_core_model_identifier(api_identifier);
477 self.by_identifier.get(&core_identifier)
478 }
479
480 fn extract_core_model_identifier(&self, api_identifier: &str) -> String {
482 let mut identifier = api_identifier.to_string();
483
484 if let Some(dot_pos) = identifier.find('.') {
486 if identifier[..dot_pos].len() <= 3 {
487 identifier = identifier[dot_pos + 1..].to_string();
489 }
490 }
491
492 if identifier.starts_with("anthropic.") {
494 identifier = identifier["anthropic.".len()..].to_string();
495 }
496
497 if let Some(version_pos) = identifier.rfind("-v") {
499 if identifier[version_pos..].contains(':') {
500 identifier = identifier[..version_pos].to_string();
501 }
502 }
503
504 identifier
505 }
506
507 #[must_use]
509 pub fn is_legacy_model(&self, api_identifier: &str) -> bool {
510 self.get_model_spec(api_identifier)
511 .is_some_and(|spec| spec.legacy)
512 }
513
514 #[must_use]
516 pub fn get_all_models(&self) -> &[ModelSpec] {
517 &self.config.models
518 }
519
520 #[must_use]
522 pub fn get_models_by_provider(&self, provider: &str) -> Vec<&ModelSpec> {
523 self.by_provider
524 .get(provider)
525 .map(|models| models.iter().collect())
526 .unwrap_or_default()
527 }
528
529 #[must_use]
531 pub fn get_models_by_provider_and_tier(&self, provider: &str, tier: &str) -> Vec<&ModelSpec> {
532 self.get_models_by_provider(provider)
533 .into_iter()
534 .filter(|model| model.tier == tier)
535 .collect()
536 }
537
538 #[must_use]
540 pub fn get_default_model(&self, provider: &str) -> Option<&str> {
541 self.config
542 .providers
543 .get(provider)
544 .map(|p| p.default_model.as_str())
545 }
546
547 #[must_use]
549 pub fn get_provider_config(&self, provider: &str) -> Option<&ProviderConfig> {
550 self.config.providers.get(provider)
551 }
552
553 #[must_use]
555 pub fn get_tier_info(&self, provider: &str, tier: &str) -> Option<&TierInfo> {
556 self.config.providers.get(provider)?.tiers.get(tier)
557 }
558
559 #[must_use]
561 pub fn get_beta_headers(&self, api_identifier: &str) -> &[BetaHeader] {
562 self.get_model_spec(api_identifier)
563 .map(|spec| spec.beta_headers.as_slice())
564 .unwrap_or_default()
565 }
566
567 #[must_use]
569 pub fn get_max_output_tokens_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
570 if let Some(spec) = self.get_model_spec(api_identifier) {
571 if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
572 if let Some(max) = bh.max_output_tokens {
573 return max;
574 }
575 }
576 return spec.max_output_tokens;
577 }
578 self.get_max_output_tokens(api_identifier)
579 }
580
581 #[must_use]
583 pub fn get_input_context_with_beta(&self, api_identifier: &str, beta_value: &str) -> usize {
584 if let Some(spec) = self.get_model_spec(api_identifier) {
585 if let Some(bh) = spec.beta_headers.iter().find(|b| b.value == beta_value) {
586 if let Some(ctx) = bh.input_context {
587 return ctx;
588 }
589 }
590 return spec.input_context;
591 }
592 self.get_input_context(api_identifier)
593 }
594}
595
/// Returns `<current dir>/.omni-dev/models.yaml`, or `None` when the current
/// directory cannot be determined. The file is not checked for existence.
fn default_project_path() -> Option<PathBuf> {
    let cwd = std::env::current_dir().ok()?;
    Some(cwd.join(".omni-dev").join("models.yaml"))
}
602
603fn default_user_path() -> Option<PathBuf> {
605 dirs::home_dir().map(|h| h.join(".omni-dev").join("models.yaml"))
606}
607
608fn read_optional_yaml(path: &Path) -> Option<String> {
611 if !path.exists() {
612 return None;
613 }
614 match std::fs::read_to_string(path) {
615 Ok(s) => Some(s),
616 Err(e) => {
617 tracing::error!(
618 "Failed to read {}: {e}. Falling through to lower-precedence layers.",
619 path.display()
620 );
621 None
622 }
623 }
624}
625
626fn merge_layer_into(
637 dest: &mut serde_yaml::Value,
638 src: serde_yaml::Value,
639 source: ModelSource,
640 model_sources: &mut HashMap<String, ModelSource>,
641 provider_sources: &mut HashMap<String, ModelSource>,
642) {
643 use serde_yaml::Value;
644
645 let Value::Mapping(src_map) = src else {
646 *dest = src;
650 return;
651 };
652
653 if !matches!(dest, Value::Mapping(_)) {
654 *dest = Value::Mapping(serde_yaml::Mapping::new());
655 }
656 let Value::Mapping(dest_map) = dest else {
657 unreachable!("dest is a mapping after the check above");
658 };
659
660 for (k, v) in src_map {
661 match k.as_str() {
662 Some("models") => merge_models_into(dest_map, k, v, source, model_sources),
663 Some("providers") => merge_providers_into(dest_map, k, v, source, provider_sources),
664 _ => {
665 dest_map.insert(k, v);
666 }
667 }
668 }
669}
670
671fn merge_models_into(
672 dest_map: &mut serde_yaml::Mapping,
673 key: serde_yaml::Value,
674 incoming: serde_yaml::Value,
675 source: ModelSource,
676 model_sources: &mut HashMap<String, ModelSource>,
677) {
678 use serde_yaml::Value;
679
680 let Value::Sequence(incoming_seq) = incoming else {
681 dest_map.insert(key, incoming);
683 return;
684 };
685
686 let dest_value = dest_map
687 .entry(key)
688 .or_insert_with(|| Value::Sequence(Vec::new()));
689 if !matches!(dest_value, Value::Sequence(_)) {
690 *dest_value = Value::Sequence(Vec::new());
691 }
692 let Value::Sequence(dest_seq) = dest_value else {
693 unreachable!("dest is a sequence after the check above");
694 };
695
696 for entry in incoming_seq {
697 let api_id = entry
698 .get("api_identifier")
699 .and_then(|v| v.as_str())
700 .map(String::from);
701
702 let Some(api_id) = api_id else {
703 tracing::warn!(
704 "Skipping model entry without `api_identifier` from {source} models.yaml"
705 );
706 continue;
707 };
708
709 if let Some(existing) = dest_seq
710 .iter_mut()
711 .find(|e| e.get("api_identifier").and_then(serde_yaml::Value::as_str) == Some(&api_id))
712 {
713 deep_merge(existing, entry);
714 } else {
715 dest_seq.push(entry);
716 }
717
718 model_sources.insert(api_id, source);
719 }
720}
721
722fn merge_providers_into(
723 dest_map: &mut serde_yaml::Mapping,
724 key: serde_yaml::Value,
725 incoming: serde_yaml::Value,
726 source: ModelSource,
727 provider_sources: &mut HashMap<String, ModelSource>,
728) {
729 use serde_yaml::Value;
730
731 let Value::Mapping(incoming_providers) = incoming else {
732 dest_map.insert(key, incoming);
733 return;
734 };
735
736 let dest_value = dest_map
737 .entry(key)
738 .or_insert_with(|| Value::Mapping(serde_yaml::Mapping::new()));
739 if !matches!(dest_value, Value::Mapping(_)) {
740 *dest_value = Value::Mapping(serde_yaml::Mapping::new());
741 }
742 let Value::Mapping(dest_providers) = dest_value else {
743 unreachable!("dest is a mapping after the check above");
744 };
745
746 for (pname, pvalue) in incoming_providers {
747 let pname_str = pname.as_str().map(String::from);
748
749 if let Some(existing) = dest_providers.get_mut(&pname) {
750 deep_merge(existing, pvalue);
751 } else {
752 dest_providers.insert(pname.clone(), pvalue);
753 }
754
755 if let Some(name) = pname_str {
756 provider_sources.insert(name, source);
757 }
758 }
759}
760
761fn deep_merge(dest: &mut serde_yaml::Value, src: serde_yaml::Value) {
764 use serde_yaml::Value;
765 match (dest, src) {
766 (Value::Mapping(d), Value::Mapping(s)) => {
767 for (k, v) in s {
768 if let Some(existing) = d.get_mut(&k) {
769 deep_merge(existing, v);
770 } else {
771 d.insert(k, v);
772 }
773 }
774 }
775 (d, s) => *d = s,
776 }
777}
778
779fn warn_on_version_mismatch(declared: &[(ModelSource, Option<String>)]) {
782 for (source, version) in declared {
783 if matches!(source, ModelSource::Embedded) {
784 continue;
785 }
786 match version {
787 None => {
788 tracing::warn!(
789 "{source} models.yaml has no `version:` field; assuming compatibility with schema version {MODELS_SCHEMA_VERSION}. Add `version: \"{MODELS_SCHEMA_VERSION}\"` to silence this warning."
790 );
791 }
792 Some(v) if v == MODELS_SCHEMA_VERSION => {}
793 Some(v) => {
794 tracing::warn!(
795 "{source} models.yaml declares schema version {v}; this build understands {MODELS_SCHEMA_VERSION}. Continuing — unrecognised fields may be ignored."
796 );
797 }
798 }
799 }
800}
801
/// Process-wide registry instance, initialised on first use by
/// [`get_model_registry`].
static MODEL_REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();

/// Returns the process-wide model registry, loading it via
/// [`ModelRegistry::load`] on the first call.
///
/// # Panics
///
/// Panics if the first call's [`ModelRegistry::load`] returns an error.
#[must_use]
pub fn get_model_registry() -> &'static ModelRegistry {
    // The `expect` is deliberate here, hence the lint allowance.
    #[allow(clippy::expect_used)]
    MODEL_REGISTRY.get_or_init(|| ModelRegistry::load().expect("Failed to load model registry"))
}
811
812#[cfg(test)]
813#[allow(clippy::unwrap_used, clippy::expect_used)]
814mod tests {
815 use super::*;
816 use std::io::Write;
817
818 fn embedded_only() -> ModelRegistry {
819 ModelRegistry::load_layered_from_paths(None, None, None).unwrap()
820 }
821
822 fn write_yaml(dir: &Path, name: &str, contents: &str) -> PathBuf {
823 let path = dir.join(name);
824 let mut f = std::fs::File::create(&path).unwrap();
825 f.write_all(contents.as_bytes()).unwrap();
826 path
827 }
828
829 #[test]
830 fn load_model_registry() {
831 let registry = embedded_only();
832 assert!(!registry.config.models.is_empty());
833 assert!(registry.config.providers.contains_key("claude"));
834 assert_eq!(
835 registry.config.version.as_deref(),
836 Some(MODELS_SCHEMA_VERSION)
837 );
838 }
839
840 #[test]
841 fn claude_model_lookup() {
842 let registry = embedded_only();
843
844 let opus_spec = registry.get_model_spec("claude-3-opus-20240229");
846 assert!(opus_spec.is_some());
847 assert_eq!(opus_spec.unwrap().max_output_tokens, 4096);
848 assert_eq!(opus_spec.unwrap().provider, "claude");
849 assert!(registry.is_legacy_model("claude-3-opus-20240229"));
850
851 let sonnet45_tokens = registry.get_max_output_tokens("claude-sonnet-4-5-20250929");
853 assert_eq!(sonnet45_tokens, 64000);
854
855 let sonnet4_tokens = registry.get_max_output_tokens("claude-sonnet-4-20250514");
857 assert_eq!(sonnet4_tokens, 64000);
858 assert!(registry.is_legacy_model("claude-sonnet-4-20250514"));
859
860 let unknown_tokens = registry.get_max_output_tokens("claude-unknown-model");
862 assert_eq!(unknown_tokens, 4096); }
864
865 #[test]
866 fn unknown_provider_uses_ultimate_fallback() {
867 let registry = embedded_only();
868
869 assert_eq!(
871 registry.get_max_output_tokens("totally-unknown-vendor-x"),
872 FALLBACK_MAX_OUTPUT_TOKENS
873 );
874 assert_eq!(
875 registry.get_input_context("totally-unknown-vendor-x"),
876 FALLBACK_INPUT_CONTEXT
877 );
878 }
879
880 #[test]
881 fn provider_filtering() {
882 let registry = embedded_only();
883
884 let claude_models = registry.get_models_by_provider("claude");
885 assert!(!claude_models.is_empty());
886
887 let fast_claude_models = registry.get_models_by_provider_and_tier("claude", "fast");
888 assert!(!fast_claude_models.is_empty());
889
890 let tier_info = registry.get_tier_info("claude", "fast");
891 assert!(tier_info.is_some());
892 }
893
894 #[test]
895 fn provider_config() {
896 let registry = embedded_only();
897
898 let claude_config = registry.get_provider_config("claude");
899 assert!(claude_config.is_some());
900 assert_eq!(claude_config.unwrap().name, "Anthropic Claude");
901 }
902
903 #[test]
904 fn default_model_per_provider() {
905 let registry = embedded_only();
906
907 assert_eq!(
908 registry.get_default_model("claude"),
909 Some("claude-sonnet-4-6")
910 );
911 assert_eq!(registry.get_default_model("openai"), Some("gpt-5-mini"));
912 assert_eq!(
913 registry.get_default_model("gemini"),
914 Some("gemini-2.5-flash")
915 );
916 assert_eq!(registry.get_default_model("nonexistent"), None);
917 }
918
919 #[test]
920 fn normalized_id_matching() {
921 let registry = embedded_only();
922
923 let bedrock_3_7_sonnet = "us.anthropic.claude-3-7-sonnet-20250219-v1:0";
925 let spec = registry.get_model_spec(bedrock_3_7_sonnet);
926 assert!(spec.is_some());
927 assert_eq!(spec.unwrap().api_identifier, "claude-3-7-sonnet-20250219");
928 assert_eq!(spec.unwrap().max_output_tokens, 64000);
929
930 let aws_haiku = "anthropic.claude-3-haiku-20240307-v1:0";
932 let spec = registry.get_model_spec(aws_haiku);
933 assert!(spec.is_some());
934 assert_eq!(spec.unwrap().api_identifier, "claude-3-haiku-20240307");
935 assert_eq!(spec.unwrap().max_output_tokens, 4096);
936
937 let eu_opus = "eu.anthropic.claude-3-opus-20240229-v2:1";
939 let spec = registry.get_model_spec(eu_opus);
940 assert!(spec.is_some());
941 assert_eq!(spec.unwrap().api_identifier, "claude-3-opus-20240229");
942 assert_eq!(spec.unwrap().max_output_tokens, 4096);
943
944 let exact_sonnet45 = "claude-sonnet-4-5-20250929";
946 let spec = registry.get_model_spec(exact_sonnet45);
947 assert!(spec.is_some());
948 assert_eq!(spec.unwrap().max_output_tokens, 64000);
949
950 let exact_sonnet4 = "claude-sonnet-4-20250514";
952 let spec = registry.get_model_spec(exact_sonnet4);
953 assert!(spec.is_some());
954 assert_eq!(spec.unwrap().max_output_tokens, 64000);
955 }
956
957 #[test]
958 fn extract_core_model_identifier() {
959 let registry = embedded_only();
960
961 assert_eq!(
963 registry.extract_core_model_identifier("us.anthropic.claude-3-7-sonnet-20250219-v1:0"),
964 "claude-3-7-sonnet-20250219"
965 );
966
967 assert_eq!(
968 registry.extract_core_model_identifier("anthropic.claude-3-haiku-20240307-v1:0"),
969 "claude-3-haiku-20240307"
970 );
971
972 assert_eq!(
973 registry.extract_core_model_identifier("claude-3-opus-20240229"),
974 "claude-3-opus-20240229"
975 );
976
977 assert_eq!(
978 registry.extract_core_model_identifier("eu.anthropic.claude-sonnet-4-20250514-v2:1"),
979 "claude-sonnet-4-20250514"
980 );
981 }
982
983 #[test]
984 fn beta_header_lookups() {
985 let registry = embedded_only();
986
987 assert_eq!(registry.get_max_output_tokens("claude-opus-4-6"), 128_000);
989 assert_eq!(registry.get_input_context("claude-opus-4-6"), 200_000);
990
991 assert_eq!(
993 registry.get_input_context_with_beta("claude-opus-4-6", "context-1m-2025-08-07"),
994 1_000_000
995 );
996 assert_eq!(
998 registry.get_max_output_tokens_with_beta("claude-opus-4-6", "context-1m-2025-08-07"),
999 128_000
1000 );
1001
1002 assert_eq!(
1004 registry.get_max_output_tokens_with_beta(
1005 "claude-3-7-sonnet-20250219",
1006 "output-128k-2025-02-19"
1007 ),
1008 128_000
1009 );
1010
1011 assert_eq!(
1013 registry.get_max_output_tokens("claude-3-7-sonnet-20250219"),
1014 64000
1015 );
1016
1017 let headers = registry.get_beta_headers("claude-opus-4-6");
1019 assert_eq!(headers.len(), 1);
1020 assert_eq!(headers[0].key, "anthropic-beta");
1021 assert_eq!(headers[0].value, "context-1m-2025-08-07");
1022
1023 let headers = registry.get_beta_headers("claude-3-7-sonnet-20250219");
1025 assert_eq!(headers.len(), 2);
1026
1027 let headers = registry.get_beta_headers("claude-3-haiku-20240307");
1029 assert!(headers.is_empty());
1030
1031 let headers = registry.get_beta_headers("unknown-model");
1033 assert!(headers.is_empty());
1034 }
1035
1036 #[test]
1037 fn beta_lookups_for_unknown_model_fall_through_to_provider_defaults() {
1038 let registry = embedded_only();
1039
1040 assert_eq!(
1044 registry
1045 .get_max_output_tokens_with_beta("claude-unknown-model", "context-1m-2025-08-07"),
1046 4096
1047 );
1048 assert_eq!(
1049 registry.get_input_context_with_beta("claude-unknown-model", "context-1m-2025-08-07"),
1050 200_000
1051 );
1052 }
1053
1054 #[test]
1055 fn embedded_models_default_to_embedded_source() {
1056 let registry = embedded_only();
1057 let spec = registry.get_model_spec("claude-opus-4-6").unwrap();
1058 assert_eq!(spec.source, ModelSource::Embedded);
1059
1060 let provider = registry.get_provider_config("claude").unwrap();
1061 assert_eq!(provider.source, ModelSource::Embedded);
1062 }
1063
1064 #[test]
1065 fn missing_user_and_project_files_fall_through_silently() {
1066 let dir = tempfile::tempdir().unwrap();
1067 let project_path = dir.path().join("missing-project.yaml");
1068 let user_path = dir.path().join("missing-user.yaml");
1069 let registry =
1070 ModelRegistry::load_layered_from_paths(Some(&project_path), Some(&user_path), None)
1071 .unwrap();
1072
1073 let spec = registry.get_model_spec("claude-opus-4-6").unwrap();
1075 assert_eq!(spec.source, ModelSource::Embedded);
1076 assert_eq!(spec.max_output_tokens, 128_000);
1077 }
1078
1079 #[test]
1080 fn user_layer_overrides_embedded_entry() {
1081 let dir = tempfile::tempdir().unwrap();
1082 let user = write_yaml(
1083 dir.path(),
1084 "user.yaml",
1085 r#"
1086version: "1"
1087models:
1088 - provider: "claude"
1089 model: "Claude Opus 4.6 (custom)"
1090 api_identifier: "claude-opus-4-6"
1091 max_output_tokens: 999999
1092 input_context: 200000
1093 generation: 4.6
1094 tier: "flagship"
1095"#,
1096 );
1097
1098 let registry = ModelRegistry::load_layered_from_paths(None, Some(&user), None).unwrap();
1099 let spec = registry.get_model_spec("claude-opus-4-6").unwrap();
1100 assert_eq!(spec.max_output_tokens, 999_999);
1101 assert_eq!(spec.model, "Claude Opus 4.6 (custom)");
1102 assert_eq!(spec.source, ModelSource::User);
1103 }
1104
1105 #[test]
1106 fn project_layer_takes_precedence_over_user_layer() {
1107 let dir = tempfile::tempdir().unwrap();
1108 let user = write_yaml(
1109 dir.path(),
1110 "user.yaml",
1111 r#"
1112version: "1"
1113models:
1114 - provider: "claude"
1115 model: "From User"
1116 api_identifier: "claude-opus-4-6"
1117 max_output_tokens: 1
1118 input_context: 1
1119 generation: 4.6
1120 tier: "flagship"
1121"#,
1122 );
1123 let project = write_yaml(
1124 dir.path(),
1125 "project.yaml",
1126 r#"
1127version: "1"
1128models:
1129 - provider: "claude"
1130 model: "From Project"
1131 api_identifier: "claude-opus-4-6"
1132 max_output_tokens: 2
1133 input_context: 2
1134 generation: 4.6
1135 tier: "flagship"
1136"#,
1137 );
1138
1139 let registry =
1140 ModelRegistry::load_layered_from_paths(Some(&project), Some(&user), None).unwrap();
1141 let spec = registry.get_model_spec("claude-opus-4-6").unwrap();
1142 assert_eq!(spec.model, "From Project");
1143 assert_eq!(spec.max_output_tokens, 2);
1144 assert_eq!(spec.source, ModelSource::Project);
1145 }
1146
1147 #[test]
1148 fn additive_user_entry_is_appended() {
1149 let dir = tempfile::tempdir().unwrap();
1150 let user = write_yaml(
1151 dir.path(),
1152 "user.yaml",
1153 r#"
1154version: "1"
1155models:
1156 - provider: "claude"
1157 model: "Claude Custom Future"
1158 api_identifier: "claude-future-9000"
1159 max_output_tokens: 250000
1160 input_context: 5000000
1161 generation: 9.0
1162 tier: "flagship"
1163"#,
1164 );
1165
1166 let registry = ModelRegistry::load_layered_from_paths(None, Some(&user), None).unwrap();
1167 let spec = registry.get_model_spec("claude-future-9000").unwrap();
1168 assert_eq!(spec.max_output_tokens, 250_000);
1169 assert_eq!(spec.input_context, 5_000_000);
1170 assert_eq!(spec.source, ModelSource::User);
1171
1172 let opus = registry.get_model_spec("claude-opus-4-6").unwrap();
1174 assert_eq!(opus.source, ModelSource::Embedded);
1175 }
1176
1177 #[test]
1178 fn provider_fields_can_be_partially_overridden() {
1179 let dir = tempfile::tempdir().unwrap();
1180 let user = write_yaml(
1183 dir.path(),
1184 "user.yaml",
1185 r#"
1186version: "1"
1187providers:
1188 claude:
1189 default_model: "claude-opus-4-6"
1190"#,
1191 );
1192
1193 let registry = ModelRegistry::load_layered_from_paths(None, Some(&user), None).unwrap();
1194 let claude = registry.get_provider_config("claude").unwrap();
1195 assert_eq!(claude.default_model, "claude-opus-4-6");
1196 assert_eq!(claude.name, "Anthropic Claude");
1198 assert_eq!(claude.api_base, "https://api.anthropic.com/v1");
1199 assert!(claude.tiers.contains_key("flagship"));
1200 assert_eq!(claude.source, ModelSource::User);
1202 }
1203
1204 #[test]
1205 fn malformed_user_yaml_logs_and_falls_through() {
1206 let dir = tempfile::tempdir().unwrap();
1207 let user = write_yaml(
1208 dir.path(),
1209 "user.yaml",
1210 "this: is: definitely: not: valid: yaml: [unbalanced",
1211 );
1212
1213 let registry = ModelRegistry::load_layered_from_paths(None, Some(&user), None).unwrap();
1214 let spec = registry.get_model_spec("claude-opus-4-6").unwrap();
1216 assert_eq!(spec.source, ModelSource::Embedded);
1217 assert_eq!(spec.max_output_tokens, 128_000);
1218 }
1219
1220 #[test]
1221 fn override_path_short_circuits_user_and_project() {
1222 let dir = tempfile::tempdir().unwrap();
1223 let user = write_yaml(
1224 dir.path(),
1225 "user.yaml",
1226 r#"
1227version: "1"
1228models:
1229 - provider: "claude"
1230 model: "From User"
1231 api_identifier: "claude-opus-4-6"
1232 max_output_tokens: 1
1233 input_context: 1
1234 generation: 4.6
1235 tier: "flagship"
1236"#,
1237 );
1238 let project = write_yaml(
1239 dir.path(),
1240 "project.yaml",
1241 r#"
1242version: "1"
1243models:
1244 - provider: "claude"
1245 model: "From Project"
1246 api_identifier: "claude-opus-4-6"
1247 max_output_tokens: 2
1248 input_context: 2
1249 generation: 4.6
1250 tier: "flagship"
1251"#,
1252 );
1253 let override_file = write_yaml(
1254 dir.path(),
1255 "override.yaml",
1256 r#"
1257version: "1"
1258models:
1259 - provider: "claude"
1260 model: "From Override"
1261 api_identifier: "claude-opus-4-6"
1262 max_output_tokens: 3
1263 input_context: 3
1264 generation: 4.6
1265 tier: "flagship"
1266"#,
1267 );
1268
1269 let registry = ModelRegistry::load_layered_from_paths(
1270 Some(&project),
1271 Some(&user),
1272 Some(&override_file),
1273 )
1274 .unwrap();
1275 let spec = registry.get_model_spec("claude-opus-4-6").unwrap();
1276 assert_eq!(spec.model, "From Override");
1277 assert_eq!(spec.max_output_tokens, 3);
1278 assert_eq!(spec.source, ModelSource::Override);
1279 }
1280
1281 #[test]
1282 fn missing_override_path_falls_back_to_embedded() {
1283 let dir = tempfile::tempdir().unwrap();
1284 let missing = dir.path().join("does-not-exist.yaml");
1285 let registry = ModelRegistry::load_layered_from_paths(None, None, Some(&missing)).unwrap();
1286 let spec = registry.get_model_spec("claude-opus-4-6").unwrap();
1287 assert_eq!(spec.source, ModelSource::Embedded);
1288 }
1289
1290 #[test]
1291 fn version_mismatch_is_warned_not_fatal() {
1292 let dir = tempfile::tempdir().unwrap();
1293 let user = write_yaml(
1294 dir.path(),
1295 "user.yaml",
1296 r#"
1297version: "9999"
1298models:
1299 - provider: "claude"
1300 model: "From Future"
1301 api_identifier: "claude-future-9000"
1302 max_output_tokens: 1
1303 input_context: 1
1304 generation: 9.0
1305 tier: "flagship"
1306"#,
1307 );
1308 let registry = ModelRegistry::load_layered_from_paths(None, Some(&user), None).unwrap();
1309 assert!(registry.get_model_spec("claude-future-9000").is_some());
1311 }
1312
1313 #[test]
1314 fn missing_version_is_accepted() {
1315 let dir = tempfile::tempdir().unwrap();
1316 let user = write_yaml(
1317 dir.path(),
1318 "user.yaml",
1319 r#"
1320models:
1321 - provider: "claude"
1322 model: "Versionless"
1323 api_identifier: "claude-versionless"
1324 max_output_tokens: 1
1325 input_context: 1
1326 generation: 1.0
1327 tier: "flagship"
1328"#,
1329 );
1330 let registry = ModelRegistry::load_layered_from_paths(None, Some(&user), None).unwrap();
1331 assert!(registry.get_model_spec("claude-versionless").is_some());
1332 }
1333
1334 #[test]
1335 fn model_entry_without_api_identifier_is_skipped() {
1336 let dir = tempfile::tempdir().unwrap();
1337 let user = write_yaml(
1338 dir.path(),
1339 "user.yaml",
1340 r#"
1341version: "1"
1342models:
1343 - provider: "claude"
1344 model: "No Id"
1345 max_output_tokens: 1
1346 input_context: 1
1347 generation: 1.0
1348 tier: "flagship"
1349"#,
1350 );
1351 let registry = ModelRegistry::load_layered_from_paths(None, Some(&user), None).unwrap();
1352 let opus = registry.get_model_spec("claude-opus-4-6").unwrap();
1354 assert_eq!(opus.source, ModelSource::Embedded);
1355 }
1356
1357 #[test]
1358 fn model_source_display() {
1359 assert_eq!(ModelSource::Embedded.to_string(), "embedded");
1360 assert_eq!(ModelSource::User.to_string(), "user");
1361 assert_eq!(ModelSource::Project.to_string(), "project");
1362 assert_eq!(ModelSource::Override.to_string(), "override");
1363 }
1364
1365 #[test]
1366 fn embedded_yaml_must_not_be_malformed() {
1367 let layers = [(ModelSource::Embedded, "::: not yaml :::".to_string())];
1369 let result = ModelRegistry::from_layers(&layers);
1370 assert!(result.is_err());
1371 }
1372
/// When the user layer's whole document is a bare scalar string and no other
/// file layer is supplied, loading must fail.
#[test]
fn user_layer_with_scalar_top_level_returns_error() {
    let scratch = tempfile::tempdir().unwrap();
    // The entire YAML document is the quoted string "just a string".
    let user_layer = write_yaml(scratch.path(), "user.yaml", "\"just a string\"\n");

    let outcome = ModelRegistry::load_layered_from_paths(None, Some(&user_layer), None);

    assert!(outcome.is_err());
}
1383
/// When the user layer sets `models` to a scalar (`42`) instead of a sequence
/// and no later layer is supplied, loading must fail.
#[test]
fn user_layer_with_non_sequence_models_returns_error() {
    let scratch = tempfile::tempdir().unwrap();
    let user_layer = write_yaml(
        scratch.path(),
        "user.yaml",
        r#"
version: "1"
models: 42
"#,
    );

    let outcome = ModelRegistry::load_layered_from_paths(None, Some(&user_layer), None);

    assert!(outcome.is_err());
}
1401
/// When the user layer sets `providers` to a scalar (`42`) instead of a
/// mapping and no later layer is supplied, loading must fail.
#[test]
fn user_layer_with_non_mapping_providers_returns_error() {
    let scratch = tempfile::tempdir().unwrap();
    let user_layer = write_yaml(
        scratch.path(),
        "user.yaml",
        r#"
version: "1"
providers: 42
"#,
    );

    let outcome = ModelRegistry::load_layered_from_paths(None, Some(&user_layer), None);

    assert!(outcome.is_err());
}
1418
/// Adding a new tier under an existing provider in the user layer must extend
/// that provider's tier map instead of replacing it.
///
/// After loading, the `claude` provider is expected to still expose the
/// `flagship`, `balanced` and `fast` tiers alongside the new `experimental`
/// tier carrying the description and use cases given in the user file.
#[test]
fn deep_merge_inserts_new_keys_into_existing_mapping() {
    let scratch = tempfile::tempdir().unwrap();
    let user_layer = write_yaml(
        scratch.path(),
        "user.yaml",
        r#"
version: "1"
providers:
  claude:
    tiers:
      experimental:
        description: "Experimental tier"
        use_cases: ["bleeding edge"]
"#,
    );

    let loaded = ModelRegistry::load_layered_from_paths(None, Some(&user_layer), None).unwrap();
    let claude = loaded.get_provider_config("claude").unwrap();

    // Tiers that were not mentioned by the user layer must survive the merge.
    for preexisting in ["flagship", "balanced", "fast"].iter() {
        assert!(claude.tiers.contains_key(*preexisting));
    }

    // The tier introduced by the user layer must be present with its payload.
    let added = claude.tiers.get("experimental").unwrap();
    assert_eq!(added.description, "Experimental tier");
    assert_eq!(added.use_cases, vec!["bleeding edge".to_string()]);
}
1449
/// A user-layer path that is actually a directory must not abort loading.
///
/// The registry is expected to load anyway, leaving `claude-opus-4-6`
/// attributed to [`ModelSource::Embedded`]. Whether anything is logged (as the
/// test name suggests) is not observed here. Compiled on Unix targets only.
#[test]
#[cfg(unix)]
fn user_path_pointing_at_a_directory_logs_and_falls_through() {
    let scratch = tempfile::tempdir().unwrap();
    // Create a *directory* where a YAML file would normally live.
    let dir_as_file = scratch.path().join("models.yaml");
    std::fs::create_dir(&dir_as_file).unwrap();

    let loaded = ModelRegistry::load_layered_from_paths(None, Some(&dir_as_file), None).unwrap();

    let opus = loaded.get_model_spec("claude-opus-4-6").unwrap();
    assert_eq!(opus.source, ModelSource::Embedded);
}
1462
/// An override-layer path that is actually a directory must not abort loading.
///
/// The registry is expected to load anyway, leaving `claude-opus-4-6`
/// attributed to [`ModelSource::Embedded`]. Whether a warning is emitted (as
/// the test name suggests) is not observed here. Compiled on Unix targets only.
#[test]
#[cfg(unix)]
fn override_path_pointing_at_a_directory_warns_and_falls_through() {
    let scratch = tempfile::tempdir().unwrap();
    // Create a *directory* at the location passed as the override file.
    let dir_as_file = scratch.path().join("override.yaml");
    std::fs::create_dir(&dir_as_file).unwrap();

    let loaded = ModelRegistry::load_layered_from_paths(None, None, Some(&dir_as_file)).unwrap();

    let opus = loaded.get_model_spec("claude-opus-4-6").unwrap();
    assert_eq!(opus.source, ModelSource::Embedded);
}
1473
/// A well-formed project layer must make loading succeed even when the user
/// layer's whole document is a bare scalar.
///
/// The project file supplies both a model and a provider; the model it
/// declares must end up attributed to [`ModelSource::Project`].
#[test]
fn project_layer_recovers_after_user_replaces_top_level_with_scalar() {
    let scratch = tempfile::tempdir().unwrap();
    // User layer: the entire document is the quoted string "junk".
    let user_layer = write_yaml(scratch.path(), "user.yaml", "\"junk\"\n");
    // Project layer: a complete document with one model and one provider.
    let project_layer = write_yaml(
        scratch.path(),
        "project.yaml",
        r#"
version: "1"
models:
  - provider: "claude"
    model: "Project Rescue"
    api_identifier: "claude-rescue"
    max_output_tokens: 1
    input_context: 1
    generation: 1.0
    tier: "flagship"
providers:
  custom-provider:
    name: "Custom"
    api_base: "https://example.invalid"
    default_model: "custom-default"
    tiers: {}
    defaults:
      max_output_tokens: 100
      input_context: 1000
"#,
    );

    let loaded =
        ModelRegistry::load_layered_from_paths(Some(&project_layer), Some(&user_layer), None)
            .unwrap();

    let rescued = loaded.get_model_spec("claude-rescue").unwrap();
    assert_eq!(rescued.source, ModelSource::Project);
}
1513
/// A project layer with a proper `models` sequence must make loading succeed
/// even when the user layer set `models` to a scalar.
///
/// The model declared by the project file must end up attributed to
/// [`ModelSource::Project`].
#[test]
fn project_layer_recovers_after_user_replaces_models_with_scalar() {
    let scratch = tempfile::tempdir().unwrap();
    // User layer: `models` is the integer 42 rather than a list.
    let user_layer = write_yaml(
        scratch.path(),
        "user.yaml",
        r#"
version: "1"
models: 42
"#,
    );
    // Project layer: `models` is a well-formed single-entry list.
    let project_layer = write_yaml(
        scratch.path(),
        "project.yaml",
        r#"
version: "1"
models:
  - provider: "claude"
    model: "Project Rescue"
    api_identifier: "claude-rescue"
    max_output_tokens: 1
    input_context: 1
    generation: 1.0
    tier: "flagship"
"#,
    );

    let loaded =
        ModelRegistry::load_layered_from_paths(Some(&project_layer), Some(&user_layer), None)
            .unwrap();

    let rescued = loaded.get_model_spec("claude-rescue").unwrap();
    assert_eq!(rescued.source, ModelSource::Project);
}
1548
/// A project layer with a proper `providers` mapping must make loading succeed
/// even when the user layer set `providers` to a scalar.
///
/// The provider declared by the project file must be retrievable with the
/// name it was given and attributed to [`ModelSource::Project`].
#[test]
fn project_layer_recovers_after_user_replaces_providers_with_scalar() {
    let scratch = tempfile::tempdir().unwrap();
    // User layer: `providers` is the integer 42 rather than a mapping.
    let user_layer = write_yaml(
        scratch.path(),
        "user.yaml",
        r#"
version: "1"
providers: 42
"#,
    );
    // Project layer: `providers` is a well-formed mapping with one entry.
    let project_layer = write_yaml(
        scratch.path(),
        "project.yaml",
        r#"
version: "1"
providers:
  custom-provider:
    name: "Custom"
    api_base: "https://example.invalid"
    default_model: "custom-default"
    tiers: {}
    defaults:
      max_output_tokens: 100
      input_context: 1000
"#,
    );

    let loaded =
        ModelRegistry::load_layered_from_paths(Some(&project_layer), Some(&user_layer), None)
            .unwrap();

    let custom = loaded.get_provider_config("custom-provider").unwrap();
    assert_eq!(custom.name, "Custom");
    assert_eq!(custom.source, ModelSource::Project);
}
1584
/// An empty string must resolve to no path, while a non-empty string must
/// resolve to the corresponding path.
///
/// NOTE(review): this exercises a local copy of the "drop empty, then convert
/// to `PathBuf`" chain rather than calling the production resolver, so it only
/// pins the intended semantics; confirm the real env-var lookup uses the same
/// chain.
#[test]
fn empty_omni_dev_models_yaml_env_var_is_ignored() {
    // Shared resolution chain: discard empty values, convert the rest to a path.
    let resolve = |raw: &str| -> Option<PathBuf> {
        Some(raw.to_string())
            .filter(|candidate| !candidate.is_empty())
            .map(PathBuf::from)
    };

    let from_empty = resolve("");
    assert!(from_empty.is_none());

    let from_path = resolve("/some/path");
    assert_eq!(from_path.as_deref(), Some(Path::new("/some/path")));
}
1599}