1use crate::auth::{AuthStorage, SapResolvedCredentials, resolve_sap_credentials};
4use crate::error::Error;
5use crate::provider::{Api, InputType, Model, ModelCost};
6use crate::provider_metadata::{
7 ProviderRoutingDefaults, canonical_provider_id, provider_routing_defaults,
8};
9use regex::Regex;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12use std::fs;
13use std::io::Write;
14use std::path::{Path, PathBuf};
15use std::sync::OnceLock;
16
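/// A fully resolved model: the catalog [`Model`] plus the credential, extra
/// headers, and compatibility settings needed to actually call it.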
17#[derive(Debug, Clone)]
18pub struct ModelEntry {
19 pub model: Model,
20 pub api_key: Option<String>,
21 pub headers: HashMap<String, String>,
22 pub auth_header: bool,
23 pub compat: Option<CompatConfig>,
24 pub oauth_config: Option<OAuthConfig>,
26}
27
28impl ModelEntry {
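    /// Whether this model id is known to accept the `xhigh` thinking level.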
29 pub fn supports_xhigh(&self) -> bool {
31 matches!(
32 self.model.id.as_str(),
33 "gpt-5.1-codex-max"
34 | "gpt-5.2"
35 | "gpt-5.2-codex"
36 | "gpt-5.3-codex"
37 | "gpt-5.3-codex-spark"
38 )
39 }
40
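    /// Clamps a requested thinking level to what this model supports:
    /// non-reasoning models are forced to `Off`, and `XHigh` falls back to
    /// `High` unless [`supports_xhigh`](Self::supports_xhigh) is true.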
41 pub fn clamp_thinking_level(
46 &self,
47 thinking: crate::model::ThinkingLevel,
48 ) -> crate::model::ThinkingLevel {
49 if !self.model.reasoning {
50 return crate::model::ThinkingLevel::Off;
51 }
52 if thinking == crate::model::ThinkingLevel::XHigh && !self.supports_xhigh() {
53 return crate::model::ThinkingLevel::High;
54 }
55 thinking
56 }
57}
58
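/// OAuth client settings (endpoints, client id, scopes, optional redirect URI)
/// for providers that authenticate via an OAuth flow rather than a static key.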
59#[derive(Debug, Clone)]
61pub struct OAuthConfig {
62 pub auth_url: String,
63 pub token_url: String,
64 pub client_id: String,
65 pub scopes: Vec<String>,
66 pub redirect_uri: Option<String>,
67}
68
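/// Root schema of `models.json`: provider id mapped to its configuration.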
69#[derive(Debug, Clone, Default, Deserialize)]
70#[serde(rename_all = "camelCase")]
71pub struct ModelsConfig {
72 pub providers: HashMap<String, ProviderConfig>,
73}
74
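/// Provider-level settings from `models.json`. When `models` is omitted, the
/// entry patches matching built-in models in place (base URL, API, headers,
/// key, compat, auth header); when `models` is present, the provider's
/// built-in entries are replaced by the listed definitions.
///
/// Illustrative fragment (field names follow the camelCase serde rename; the
/// proxy URL and environment variable name are placeholders):
///
/// ```json
/// {
///   "providers": {
///     "anthropic": {
///       "baseUrl": "https://proxy.example/v1/messages",
///       "apiKey": "env:ANTHROPIC_API_KEY"
///     }
///   }
/// }
/// ```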
75#[derive(Debug, Clone, Default, Deserialize)]
76#[serde(rename_all = "camelCase")]
77pub struct ProviderConfig {
78 pub base_url: Option<String>,
79 pub api: Option<String>,
80 pub api_key: Option<String>,
81 pub headers: Option<HashMap<String, String>>,
82 pub auth_header: Option<bool>,
83 pub compat: Option<CompatConfig>,
84 pub models: Option<Vec<ModelConfig>>,
85}
86
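/// A single model definition under a provider in `models.json`; unset fields
/// fall back to provider-level settings or the provider's routing defaults.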
87#[derive(Debug, Clone, Default, Deserialize)]
88#[serde(rename_all = "camelCase")]
89pub struct ModelConfig {
90 pub id: String,
91 pub name: Option<String>,
92 pub api: Option<String>,
93 pub reasoning: Option<bool>,
94 pub input: Option<Vec<String>>,
95 pub cost: Option<ModelCost>,
96 pub context_window: Option<u32>,
97 pub max_tokens: Option<u32>,
98 pub headers: Option<HashMap<String, String>>,
99 pub compat: Option<CompatConfig>,
100}
101
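/// Optional compatibility overrides (supported features, field/role names,
/// extra headers, routing hints). Model-level values take precedence over
/// provider-level values when merged via `merge_compat`.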
102#[derive(Debug, Clone, Default, Deserialize, Serialize)]
103#[serde(rename_all = "camelCase")]
104pub struct CompatConfig {
105 pub supports_store: Option<bool>,
107 pub supports_developer_role: Option<bool>,
108 pub supports_reasoning_effort: Option<bool>,
109 pub supports_usage_in_streaming: Option<bool>,
110 pub supports_tools: Option<bool>,
111 pub supports_streaming: Option<bool>,
112 pub supports_parallel_tool_calls: Option<bool>,
113
114 pub max_tokens_field: Option<String>,
117 pub system_role_name: Option<String>,
119 pub stop_reason_field: Option<String>,
121
122 pub custom_headers: Option<HashMap<String, String>>,
126
127 pub open_router_routing: Option<serde_json::Value>,
129 pub vercel_gateway_routing: Option<serde_json::Value>,
130}
131
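/// The in-memory model catalog: built-in models (legacy catalog plus upstream
/// snapshot) with any customizations from `models.json` applied. A failure to
/// read or parse `models.json` is recorded in `error` without aborting the load.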
132#[derive(Debug, Clone)]
133pub struct ModelRegistry {
134 models: Vec<ModelEntry>,
135 error: Option<String>,
136}
137
138#[derive(Debug, Clone)]
139pub struct ModelAutocompleteCandidate {
140 pub slug: String,
141 pub description: Option<String>,
142}
143
144#[derive(Debug, Clone, Deserialize, Serialize)]
145#[serde(rename_all = "camelCase")]
146struct LegacyGeneratedModel {
147 id: String,
148 name: String,
149 api: String,
150 provider: String,
151 #[serde(default)]
152 base_url: String,
153 #[serde(default)]
154 reasoning: bool,
155 #[serde(default)]
156 input: Vec<String>,
157 #[serde(default)]
158 cost: Option<ModelCost>,
159 #[serde(default)]
160 context_window: Option<u32>,
161 #[serde(default)]
162 max_tokens: Option<u32>,
163 #[serde(default)]
164 headers: HashMap<String, String>,
165 #[serde(default)]
166 compat: Option<CompatConfig>,
167}
168
169const LEGACY_MODELS_GENERATED_TS: &str =
170 include_str!("../legacy_pi_mono_code/pi-mono/packages/ai/src/models.generated.ts");
171const UPSTREAM_PROVIDER_MODEL_IDS_JSON: &str =
172 include_str!("../docs/provider-upstream-model-ids-snapshot.json");
173const CODEX_RESPONSES_API_URL: &str = "https://chatgpt.com/backend-api/codex/responses";
174const GOOGLE_GEMINI_CLI_API_URL: &str = "https://cloudcode-pa.googleapis.com";
175const GOOGLE_ANTIGRAVITY_API_URL: &str = "https://daily-cloudcode-pa.sandbox.googleapis.com";
176
177static LEGACY_GENERATED_MODELS_CACHE: OnceLock<Vec<LegacyGeneratedModel>> = OnceLock::new();
178static UPSTREAM_PROVIDER_MODEL_IDS_CACHE: OnceLock<HashMap<String, Vec<String>>> = OnceLock::new();
179static MODEL_AUTOCOMPLETE_CACHE: OnceLock<Vec<ModelAutocompleteCandidate>> = OnceLock::new();
180static MODEL_CATALOG_CACHE_FINGERPRINT: OnceLock<u64> = OnceLock::new();
181static SATISFIES_RE: OnceLock<Regex> = OnceLock::new();
182const INPUT_TEXT_ONLY: [InputType; 1] = [InputType::Text];
183const INPUT_TEXT_AND_IMAGE: [InputType; 2] = [InputType::Text, InputType::Image];
184
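/// Maps a handful of shorthand OpenRouter model ids to their canonical
/// `vendor/model` slugs; anything unrecognized passes through trimmed.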
185fn canonicalize_openrouter_model_id(model_id: &str) -> String {
186 let trimmed = model_id.trim();
187 match trimmed.to_ascii_lowercase().as_str() {
188 "auto" => "openrouter/auto".to_string(),
189 "gpt-4o-mini" => "openai/gpt-4o-mini".to_string(),
190 "gpt-4o" => "openai/gpt-4o".to_string(),
191 "claude-3.5-sonnet" => "anthropic/claude-3.5-sonnet".to_string(),
192 "gemini-2.5-pro" => "google/gemini-2.5-pro".to_string(),
193 _ => trimmed.to_string(),
194 }
195}
196
197fn canonicalize_model_id_for_provider(provider: &str, model_id: &str) -> String {
198 if canonical_provider_id(provider).is_some_and(|canonical| canonical == "openrouter") {
199 return canonicalize_openrouter_model_id(model_id);
200 }
201 model_id.trim().to_string()
202}
203
204fn normalized_registry_key(provider: &str, model_id: &str) -> (String, String) {
205 let provider = provider.trim();
206 let canonical_provider = canonical_provider_id(provider).unwrap_or(provider);
207 let canonical_model_id = canonicalize_model_id_for_provider(canonical_provider, model_id);
208 (
209 canonical_provider.to_ascii_lowercase(),
210 canonical_model_id.to_ascii_lowercase(),
211 )
212}
213
214fn openrouter_model_lookup_ids(model_id: &str) -> Vec<String> {
215 let raw = model_id.trim().to_string();
216 let canonical = canonicalize_openrouter_model_id(model_id);
217 if canonical.eq_ignore_ascii_case(&raw) {
218 vec![canonical]
219 } else {
220 vec![raw, canonical]
221 }
222}
223
224fn api_fallback_base_url(api: &str) -> Option<&'static str> {
225 match api {
226 "openai-codex-responses" => Some(CODEX_RESPONSES_API_URL),
227 "google-gemini-cli" => Some(GOOGLE_GEMINI_CLI_API_URL),
228 "google-antigravity" => Some(GOOGLE_ANTIGRAVITY_API_URL),
229 _ => None,
230 }
231}
232
233fn parse_input_types(input: &[String]) -> Vec<InputType> {
234 input
235 .iter()
236 .filter_map(|value| match value.as_str() {
237 "text" => Some(InputType::Text),
238 "image" => Some(InputType::Image),
239 _ => None,
240 })
241 .collect()
242}
243
244fn legacy_generated_models_cache_path() -> Option<PathBuf> {
245 let checksum = crc32c::crc32c(LEGACY_MODELS_GENERATED_TS.as_bytes());
246 dirs::cache_dir().map(|dir| {
247 dir.join("pi")
248 .join("models-cache")
249 .join(format!("legacy-generated-models-{checksum:08x}.json"))
250 })
251}
252
253fn load_legacy_generated_models_cache() -> Option<Vec<LegacyGeneratedModel>> {
254 let path = legacy_generated_models_cache_path()?;
255 let cache = fs::read_to_string(path).ok()?;
256 serde_json::from_str::<Vec<LegacyGeneratedModel>>(&cache).ok()
257}
258
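/// Best-effort persistence of the parsed legacy catalog: written to a
/// process-unique temp file and renamed into place, skipping work if the cache
/// file already exists. All failures are silently ignored.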
259fn persist_legacy_generated_models_cache(models: &[LegacyGeneratedModel]) {
260 let Some(path) = legacy_generated_models_cache_path() else {
261 return;
262 };
263 if path.exists() {
264 return;
265 }
266 let Some(parent) = path.parent() else {
267 return;
268 };
269 if fs::create_dir_all(parent).is_err() {
270 return;
271 }
272
273 let temp_path = path.with_extension(format!("tmp-{}", std::process::id()));
274 let Ok(file) = fs::OpenOptions::new()
275 .write(true)
276 .create_new(true)
277 .open(&temp_path)
278 else {
279 return;
280 };
281 let mut writer = std::io::BufWriter::new(file);
282 if serde_json::to_writer(&mut writer, models).is_ok() && writer.flush().is_ok() {
283 let _ = fs::rename(temp_path, path);
284 }
285}
286
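/// Parses the vendored `models.generated.ts`: checks the on-disk cache first,
/// then extracts the `export const MODELS = { ... } as const;` object literal,
/// strips the `satisfies Model<...>` annotations, and reads it as JSON5.
/// Returns an empty list (with a warning) if the source cannot be parsed.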
287fn parse_legacy_generated_models() -> Vec<LegacyGeneratedModel> {
288 if let Some(cached) = load_legacy_generated_models_cache() {
289 return cached;
290 }
291
292 let Some(models_decl_start) = LEGACY_MODELS_GENERATED_TS.find("export const MODELS =") else {
293 tracing::warn!("Legacy model catalog missing MODELS declaration");
294 return Vec::new();
295 };
296 let Some(object_start_rel) = LEGACY_MODELS_GENERATED_TS[models_decl_start..].find('{') else {
297 tracing::warn!("Legacy model catalog missing object start after MODELS declaration");
298 return Vec::new();
299 };
300 let object_start = models_decl_start + object_start_rel;
301 let Some(end_marker_rel) = LEGACY_MODELS_GENERATED_TS[object_start..].rfind("} as const;")
302 else {
303 tracing::warn!("Legacy model catalog missing end marker");
304 return Vec::new();
305 };
306 let end_marker = object_start + end_marker_rel;
307
308 let mut object_source = LEGACY_MODELS_GENERATED_TS[object_start..=end_marker]
309 .trim_end_matches(" as const;")
310 .to_string();
311 let satisfies_re = SATISFIES_RE.get_or_init(|| {
312 Regex::new(r#"\s+satisfies\s+Model<"[^"]+">"#).expect("valid satisfies regex")
313 });
314 object_source = satisfies_re.replace_all(&object_source, "").into_owned();
315
316 let parsed: HashMap<String, HashMap<String, LegacyGeneratedModel>> =
317 match json5::from_str(&object_source) {
318 Ok(value) => value,
319 Err(err) => {
320 tracing::warn!(error = %err, "Failed to parse legacy model catalog");
321 return Vec::new();
322 }
323 };
324
325 let mut models = parsed
326 .into_values()
327 .flat_map(HashMap::into_values)
328 .collect::<Vec<_>>();
329 models.sort_by(|a, b| {
330 a.provider
331 .cmp(&b.provider)
332 .then_with(|| a.id.cmp(&b.id))
333 .then_with(|| a.api.cmp(&b.api))
334 });
335 persist_legacy_generated_models_cache(&models);
336 models
337}
338
339fn legacy_generated_models() -> &'static [LegacyGeneratedModel] {
340 LEGACY_GENERATED_MODELS_CACHE
341 .get_or_init(parse_legacy_generated_models)
342 .as_slice()
343}
344
345fn parse_upstream_provider_model_ids() -> HashMap<String, Vec<String>> {
346 let parsed: HashMap<String, Vec<String>> =
347 match serde_json::from_str(UPSTREAM_PROVIDER_MODEL_IDS_JSON) {
348 Ok(value) => value,
349 Err(err) => {
350 tracing::warn!(error = %err, "Failed to parse upstream provider model snapshot");
351 return HashMap::new();
352 }
353 };
354
355 let mut by_provider: HashMap<String, Vec<String>> = HashMap::new();
356 for (provider, ids) in parsed {
357 let provider = provider.trim();
358 if provider.is_empty() {
359 continue;
360 }
361 let canonical_provider = canonical_provider_id(provider)
362 .unwrap_or(provider)
363 .to_string();
364 let entry = by_provider.entry(canonical_provider.clone()).or_default();
365 for model_id in ids {
366 let normalized = canonicalize_model_id_for_provider(&canonical_provider, &model_id);
367 if !normalized.is_empty() {
368 entry.push(normalized);
369 }
370 }
371 }
372
373 for ids in by_provider.values_mut() {
374 ids.sort_unstable();
375 ids.dedup();
376 }
377 by_provider
378}
379
380fn upstream_provider_model_ids() -> &'static HashMap<String, Vec<String>> {
381 UPSTREAM_PROVIDER_MODEL_IDS_CACHE.get_or_init(parse_upstream_provider_model_ids)
382}
383
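/// Cached `provider/model` slugs offered for autocomplete: every legacy catalog
/// entry, every upstream snapshot id, plus a pinned `anthropic/claude-sonnet-4-6`
/// entry, sorted and deduplicated case-insensitively.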
384pub fn model_autocomplete_candidates() -> &'static [ModelAutocompleteCandidate] {
385 MODEL_AUTOCOMPLETE_CACHE
386 .get_or_init(|| {
387 let mut candidates = legacy_generated_models()
388 .iter()
389 .map(|entry| ModelAutocompleteCandidate {
390 slug: format!("{}/{}", entry.provider, entry.id),
391 description: Some(entry.name.clone()).filter(|name| !name.trim().is_empty()),
392 })
393 .collect::<Vec<_>>();
394 for (provider, ids) in upstream_provider_model_ids() {
395 let provider = provider.trim();
396 if provider.is_empty() {
397 continue;
398 }
399 for id in ids {
400 if id.trim().is_empty() {
401 continue;
402 }
403 candidates.push(ModelAutocompleteCandidate {
404 slug: format!("{provider}/{id}"),
405 description: None,
406 });
407 }
408 }
409 candidates.push(ModelAutocompleteCandidate {
410 slug: "anthropic/claude-sonnet-4-6".to_string(),
411 description: Some("Claude Sonnet 4.6".to_string()),
412 });
413 candidates.sort_by_key(|candidate| candidate.slug.to_ascii_lowercase());
414 candidates.dedup_by(|a, b| a.slug.eq_ignore_ascii_case(&b.slug));
415 candidates
416 })
417 .as_slice()
418}
419
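/// A stable fingerprint of the embedded catalogs: the legacy catalog's CRC32C
/// in the high 32 bits and the upstream snapshot's CRC32C in the low 32 bits.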
420pub fn model_catalog_cache_fingerprint() -> u64 {
421 *MODEL_CATALOG_CACHE_FINGERPRINT.get_or_init(|| {
422 let legacy = u64::from(crc32c::crc32c(LEGACY_MODELS_GENERATED_TS.as_bytes()));
423 let upstream = u64::from(crc32c::crc32c(UPSTREAM_PROVIDER_MODEL_IDS_JSON.as_bytes()));
424 (legacy << 32) | upstream
425 })
426}
427
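/// A model needs a configured credential if it sends an auth header, its
/// provider declares auth environment keys, or it carries an OAuth config.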
428fn model_requires_configured_credential(entry: &ModelEntry) -> bool {
429 let provider = entry.model.provider.as_str();
430 entry.auth_header
431 || crate::provider_metadata::provider_metadata(provider)
432 .is_some_and(|meta| !meta.auth_env_keys.is_empty())
433 || entry.oauth_config.is_some()
434}
435
436fn model_entry_is_ready(entry: &ModelEntry) -> bool {
437 !model_requires_configured_credential(entry)
438 || entry
439 .api_key
440 .as_ref()
441 .is_some_and(|value| !value.trim().is_empty())
442}
443
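/// How much work to do when building the built-in catalog: `Full` resolves
/// APIs, base URLs, display names, costs, headers, and compat; `ListingLite`
/// skips that resolution and keeps cheap placeholder values for fast listing.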
444#[derive(Clone, Copy, Debug, PartialEq, Eq)]
445enum ModelRegistryLoadMode {
446 Full,
447 ListingLite,
448}
449
450impl ModelRegistry {
451 pub fn load(auth: &AuthStorage, models_path: Option<PathBuf>) -> Self {
452 Self::load_with_mode(auth, models_path, ModelRegistryLoadMode::Full)
453 }
454
455 pub fn load_for_listing(auth: &AuthStorage, models_path: Option<PathBuf>) -> Self {
456 Self::load_with_mode(auth, models_path, ModelRegistryLoadMode::ListingLite)
457 }
458
459 fn load_with_mode(
460 auth: &AuthStorage,
461 models_path: Option<PathBuf>,
462 mode: ModelRegistryLoadMode,
463 ) -> Self {
464 let mut models = built_in_models(auth, mode);
465 let mut error = None;
466
467 if let Some(path) = models_path {
468 if path.exists() {
469 match std::fs::read_to_string(&path)
470 .map_err(|e| Error::config(format!("Failed to read models.json: {e}")))
471 .and_then(|s| serde_json::from_str::<ModelsConfig>(&s).map_err(Error::from))
472 {
473 Ok(config) => {
474 apply_custom_models(auth, &mut models, &config);
475 }
476 Err(e) => {
477 error = Some(format!("{e}\n\nFile: {}", path.display()));
478 }
479 }
480 }
481 }
482
483 Self { models, error }
484 }
485
486 pub fn models(&self) -> &[ModelEntry] {
487 &self.models
488 }
489
490 pub fn error(&self) -> Option<&str> {
491 self.error.as_deref()
492 }
493
494 pub fn available_models(&self) -> Vec<&ModelEntry> {
495 self.models
496 .iter()
497 .filter(|m| model_entry_is_ready(m))
498 .collect()
499 }
500
501 pub fn get_available(&self) -> Vec<ModelEntry> {
502 self.available_models().into_iter().cloned().collect()
503 }
504
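    /// Looks up a model by provider and id, accepting provider aliases in
    /// either direction and, for OpenRouter, shorthand model ids.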
505 pub fn find(&self, provider: &str, id: &str) -> Option<ModelEntry> {
506 let provider = provider.trim();
507 let canonical_provider = canonical_provider_id(provider).unwrap_or(provider);
508 let is_openrouter = canonical_provider.eq_ignore_ascii_case("openrouter");
509 let openrouter_ids = if is_openrouter {
511 openrouter_model_lookup_ids(id)
512 } else {
513 Vec::new()
514 };
515 let trimmed_id = id.trim();
516
517 self.models
518 .iter()
519 .find(|m| {
520 let model_provider = m.model.provider.as_str();
521 let model_provider_canonical =
522 canonical_provider_id(model_provider).unwrap_or(model_provider);
523 let provider_matches = model_provider.eq_ignore_ascii_case(provider)
524 || model_provider.eq_ignore_ascii_case(canonical_provider)
525 || model_provider_canonical.eq_ignore_ascii_case(provider)
526 || model_provider_canonical.eq_ignore_ascii_case(canonical_provider);
527 provider_matches
528 && if is_openrouter {
529 openrouter_ids
530 .iter()
531 .any(|lookup_id| m.model.id.eq_ignore_ascii_case(lookup_id))
532 } else {
533 m.model.id.eq_ignore_ascii_case(trimmed_id)
534 }
535 })
536 .cloned()
537 }
538
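    /// Case-insensitive lookup by model id alone; returns the first match
    /// across providers.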
539 pub fn find_by_id(&self, id: &str) -> Option<ModelEntry> {
542 let id = id.trim();
543 self.models
544 .iter()
545 .find(|m| m.model.id.eq_ignore_ascii_case(id))
546 .cloned()
547 }
548
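    /// Appends entries whose (canonical provider, canonical model id) pair is
    /// not already present in the registry.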
549 pub fn merge_entries(&mut self, entries: Vec<ModelEntry>) {
551 for entry in entries {
552 let entry_key = normalized_registry_key(&entry.model.provider, &entry.model.id);
554 let exists = self
555 .models
556 .iter()
557 .any(|m| normalized_registry_key(&m.model.provider, &m.model.id) == entry_key);
558 if !exists {
559 self.models.push(entry);
560 }
561 }
562 }
563}
564
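/// Fallback defaults for a few providers (Azure OpenAI, GitHub Copilot, GitLab,
/// SAP AI Core) used when `provider_routing_defaults` has nothing for them, so
/// upstream snapshot ids for these providers still become model entries.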
565fn native_adapter_seed_defaults(provider: &str) -> Option<AdHocProviderDefaults> {
566 match provider {
567 "azure-openai" => Some(AdHocProviderDefaults {
568 api: "openai-completions",
569 base_url: "",
570 auth_header: false,
571 reasoning: true,
572 input: &INPUT_TEXT_AND_IMAGE,
573 context_window: 128_000,
574 max_tokens: 16_384,
575 }),
576 "github-copilot" | "gitlab" | "sap-ai-core" => Some(AdHocProviderDefaults {
577 api: "openai-completions",
578 base_url: "",
579 auth_header: true,
580 reasoning: true,
581 input: &INPUT_TEXT_ONLY,
582 context_window: 128_000,
583 max_tokens: 16_384,
584 }),
585 _ => None,
586 }
587}
588
589fn legacy_provider_ids() -> HashSet<String> {
590 legacy_generated_models()
591 .iter()
592 .map(|model| {
593 let provider = model.provider.trim();
594 canonical_provider_id(provider)
595 .unwrap_or(provider)
596 .to_ascii_lowercase()
597 })
598 .collect()
599}
600
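/// Resolves an API key for a provider, trying the canonical provider id first
/// and falling back to the raw alias, memoizing lookups in the supplied caches.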
601fn resolve_provider_api_key_cached(
602 auth: &AuthStorage,
603 canonical_provider: &str,
604 provider: &str,
605 canonical_cache: &mut HashMap<String, Option<String>>,
606 provider_cache: &mut HashMap<String, Option<String>>,
607) -> Option<String> {
608 let canonical_key = canonical_provider.to_ascii_lowercase();
609 let canonical_result = canonical_cache
610 .entry(canonical_key)
611 .or_insert_with(|| auth.resolve_api_key(canonical_provider, None))
612 .clone();
613
614 if canonical_result.is_some() || canonical_provider.eq_ignore_ascii_case(provider) {
615 return canonical_result;
616 }
617
618 provider_cache
619 .entry(provider.to_ascii_lowercase())
620 .or_insert_with(|| auth.resolve_api_key(provider, None))
621 .clone()
622}
623
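/// Adds entries for upstream snapshot models whose providers are absent from
/// the legacy catalog, using routing or native-adapter defaults and
/// deduplicating against already-seen provider/model pairs.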
624fn append_upstream_nonlegacy_models(
625 auth: &AuthStorage,
626 models: &mut Vec<ModelEntry>,
627 seen: &mut HashSet<String>,
628 canonical_api_key_cache: &mut HashMap<String, Option<String>>,
629 provider_api_key_cache: &mut HashMap<String, Option<String>>,
630) {
631 let legacy_providers = legacy_provider_ids();
632 for (provider, ids) in upstream_provider_model_ids() {
633 let provider = provider.trim();
634 if provider.is_empty() {
635 continue;
636 }
637 let canonical_provider = canonical_provider_id(provider).unwrap_or(provider);
638 if legacy_providers.contains(&canonical_provider.to_ascii_lowercase()) {
639 continue;
640 }
641
642 let Some(defaults) = ad_hoc_provider_defaults(canonical_provider)
643 .or_else(|| native_adapter_seed_defaults(canonical_provider))
644 else {
645 continue;
646 };
647
648 let api_key = resolve_provider_api_key_cached(
649 auth,
650 canonical_provider,
651 provider,
652 canonical_api_key_cache,
653 provider_api_key_cache,
654 );
655
656 for model_id in ids {
657 let normalized_model_id =
658 canonicalize_model_id_for_provider(canonical_provider, model_id);
659 if normalized_model_id.is_empty() {
660 continue;
661 }
662 let dedupe_key = format!(
663 "{}::{}",
664 canonical_provider.to_ascii_lowercase(),
665 normalized_model_id.to_ascii_lowercase()
666 );
667 if !seen.insert(dedupe_key) {
668 continue;
669 }
670
671 models.push(ModelEntry {
672 model: Model {
673 id: normalized_model_id.clone(),
674 name: normalized_model_id,
675 api: defaults.api.to_string(),
676 provider: canonical_provider.to_string(),
677 base_url: defaults.base_url.to_string(),
678 reasoning: defaults.reasoning,
679 input: defaults.input.to_vec(),
680 cost: ModelCost {
681 input: 0.0,
682 output: 0.0,
683 cache_read: 0.0,
684 cache_write: 0.0,
685 },
686 context_window: defaults.context_window,
687 max_tokens: defaults.max_tokens,
688 headers: HashMap::new(),
689 },
690 api_key: api_key.clone(),
691 headers: HashMap::new(),
692 auth_header: defaults.auth_header,
693 compat: None,
694 oauth_config: None,
695 });
696 }
697 }
698}
699
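/// Builds the built-in catalog: legacy catalog entries first, then upstream
/// snapshot entries for non-legacy providers, then hard-coded fallbacks
/// (claude-sonnet-4-6 and the gpt-5.3-codex family) if they are still missing.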
700#[allow(clippy::too_many_lines)]
701fn built_in_models(auth: &AuthStorage, mode: ModelRegistryLoadMode) -> Vec<ModelEntry> {
702 let mut models = Vec::with_capacity(legacy_generated_models().len() + 8);
703 let mut seen = HashSet::new();
704 let mut canonical_api_key_cache: HashMap<String, Option<String>> = HashMap::new();
705 let mut provider_api_key_cache: HashMap<String, Option<String>> = HashMap::new();
706
707 for legacy in legacy_generated_models() {
708 let provider = legacy.provider.trim();
709 if provider.is_empty() {
710 continue;
711 }
712
713 let normalized_model_id = canonicalize_model_id_for_provider(provider, &legacy.id);
714 if normalized_model_id.is_empty() {
715 continue;
716 }
717
718 let dedupe_key = format!(
719 "{}::{}",
720 provider.to_ascii_lowercase(),
721 normalized_model_id.to_ascii_lowercase()
722 );
723 if !seen.insert(dedupe_key) {
724 continue;
725 }
726
727 let routing_defaults = provider_routing_defaults(provider);
728 let api_string = if mode == ModelRegistryLoadMode::Full {
729 legacy
730 .api
731 .parse::<Api>()
732 .unwrap_or_else(|_| Api::Custom(legacy.api.clone()))
733 .to_string()
734 } else {
735 legacy.api.clone()
736 };
737
738 let base_url = if mode == ModelRegistryLoadMode::Full {
739 if !legacy.base_url.trim().is_empty() {
740 legacy.base_url.trim().to_string()
741 } else if let Some(default_base) = routing_defaults
742 .map(|defaults| defaults.base_url)
743 .or_else(|| api_fallback_base_url(api_string.as_str()))
744 {
745 default_base.to_string()
746 } else {
747 String::new()
748 }
749 } else {
750 String::new()
751 };
752
753 let input = {
754 let parsed = parse_input_types(&legacy.input);
755 if parsed.is_empty() {
756 routing_defaults
757 .map_or_else(|| vec![InputType::Text], |defaults| defaults.input.to_vec())
758 } else {
759 parsed
760 }
761 };
762
763 let auth_header = match api_string.as_str() {
764 "openai-codex-responses" | "google-gemini-cli" => true,
765 _ => routing_defaults.is_some_and(|defaults| defaults.auth_header),
766 };
767
768 let canonical_provider = canonical_provider_id(provider).unwrap_or(provider);
769 let api_key = resolve_provider_api_key_cached(
770 auth,
771 canonical_provider,
772 provider,
773 &mut canonical_api_key_cache,
774 &mut provider_api_key_cache,
775 );
776
777 let default_cost = ModelCost {
778 input: 0.0,
779 output: 0.0,
780 cache_read: 0.0,
781 cache_write: 0.0,
782 };
783 let model_name = if mode == ModelRegistryLoadMode::Full && !legacy.name.trim().is_empty() {
784 legacy.name.clone()
785 } else {
786 normalized_model_id.clone()
787 };
788 let model_headers = if mode == ModelRegistryLoadMode::Full {
789 legacy.headers.clone()
790 } else {
791 HashMap::new()
792 };
793 let entry_headers = if mode == ModelRegistryLoadMode::Full {
794 legacy.headers.clone()
795 } else {
796 HashMap::new()
797 };
798
799 models.push(ModelEntry {
800 model: Model {
801 id: normalized_model_id.clone(),
802 name: model_name,
803 api: api_string,
804 provider: provider.to_string(),
805 base_url,
806 reasoning: legacy.reasoning,
807 input,
808 cost: if mode == ModelRegistryLoadMode::Full {
809 legacy.cost.clone().unwrap_or_else(|| default_cost.clone())
810 } else {
811 default_cost
812 },
813 context_window: legacy.context_window.unwrap_or_else(|| {
814 routing_defaults.map_or(128_000, |defaults| defaults.context_window)
815 }),
816 max_tokens: legacy.max_tokens.unwrap_or_else(|| {
817 routing_defaults.map_or(16_384, |defaults| defaults.max_tokens)
818 }),
819 headers: model_headers,
820 },
821 api_key,
822 headers: entry_headers,
823 auth_header,
824 compat: if mode == ModelRegistryLoadMode::Full {
825 legacy.compat.clone()
826 } else {
827 None
828 },
829 oauth_config: None,
830 });
831 }
832
833 append_upstream_nonlegacy_models(
834 auth,
835 &mut models,
836 &mut seen,
837 &mut canonical_api_key_cache,
838 &mut provider_api_key_cache,
839 );
840
841 if !models.iter().any(|entry| {
843 entry.model.provider == "anthropic"
844 && (entry.model.id == "claude-sonnet-4-6"
845 || entry.model.id == "claude-sonnet-4-6-20260217")
846 }) {
847 models.push(ModelEntry {
848 model: Model {
849 id: "claude-sonnet-4-6".to_string(),
850 name: "Claude Sonnet 4.6".to_string(),
851 api: if mode == ModelRegistryLoadMode::Full {
852 Api::AnthropicMessages.to_string()
853 } else {
854 "anthropic-messages".to_string()
855 },
856 provider: "anthropic".to_string(),
857 base_url: if mode == ModelRegistryLoadMode::Full {
858 "https://api.anthropic.com/v1/messages".to_string()
859 } else {
860 String::new()
861 },
862 reasoning: true,
863 input: vec![InputType::Text, InputType::Image],
864 cost: ModelCost {
865 input: 0.0,
866 output: 0.0,
867 cache_read: 0.0,
868 cache_write: 0.0,
869 },
870 context_window: 1_000_000,
871 max_tokens: 128_000,
872 headers: HashMap::new(),
873 },
874 api_key: resolve_provider_api_key_cached(
875 auth,
876 "anthropic",
877 "anthropic",
878 &mut canonical_api_key_cache,
879 &mut provider_api_key_cache,
880 ),
881 headers: HashMap::new(),
882 auth_header: false,
883 compat: None,
884 oauth_config: None,
885 });
886 }
887
888 if !models
893 .iter()
894 .any(|entry| entry.model.provider == "openai-codex" && entry.model.id == "gpt-5.3-codex")
895 {
896 models.push(ModelEntry {
897 model: Model {
898 id: "gpt-5.3-codex".to_string(),
899 name: "GPT-5.3 Codex".to_string(),
900 api: if mode == ModelRegistryLoadMode::Full {
901 Api::OpenAICodexResponses.to_string()
902 } else {
903 "openai-codex-responses".to_string()
904 },
905 provider: "openai-codex".to_string(),
906 base_url: if mode == ModelRegistryLoadMode::Full {
907 "https://chatgpt.com/backend-api".to_string()
908 } else {
909 String::new()
910 },
911 reasoning: true,
912 input: vec![InputType::Text, InputType::Image],
913 cost: ModelCost {
914 input: 0.0,
915 output: 0.0,
916 cache_read: 0.0,
917 cache_write: 0.0,
918 },
919 context_window: 272_000,
920 max_tokens: 128_000,
921 headers: HashMap::new(),
922 },
923 api_key: resolve_provider_api_key_cached(
924 auth,
925 "openai-codex",
926 "openai-codex",
927 &mut canonical_api_key_cache,
928 &mut provider_api_key_cache,
929 ),
930 headers: HashMap::new(),
931 auth_header: true,
932 compat: None,
933 oauth_config: None,
934 });
935 }
936
937 if !models.iter().any(|entry| {
939 entry.model.provider == "openai-codex" && entry.model.id == "gpt-5.3-codex-spark"
940 }) {
941 models.push(ModelEntry {
942 model: Model {
943 id: "gpt-5.3-codex-spark".to_string(),
944 name: "GPT-5.3 Codex Spark".to_string(),
945 api: if mode == ModelRegistryLoadMode::Full {
946 Api::OpenAICodexResponses.to_string()
947 } else {
948 "openai-codex-responses".to_string()
949 },
950 provider: "openai-codex".to_string(),
951 base_url: if mode == ModelRegistryLoadMode::Full {
952 "https://chatgpt.com/backend-api".to_string()
953 } else {
954 String::new()
955 },
956 reasoning: true,
957 input: vec![InputType::Text, InputType::Image],
958 cost: ModelCost {
959 input: 0.0,
960 output: 0.0,
961 cache_read: 0.0,
962 cache_write: 0.0,
963 },
964 context_window: 272_000,
965 max_tokens: 128_000,
966 headers: HashMap::new(),
967 },
968 api_key: resolve_provider_api_key_cached(
969 auth,
970 "openai-codex",
971 "openai-codex",
972 &mut canonical_api_key_cache,
973 &mut provider_api_key_cache,
974 ),
975 headers: HashMap::new(),
976 auth_header: true,
977 compat: None,
978 oauth_config: None,
979 });
980 }
981
982 models
983}
984
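/// Applies `models.json` on top of the built-in catalog. Providers without a
/// `models` list patch matching built-in entries in place; providers with a
/// `models` list replace their built-in entries with the configured ones,
/// normalizing ids and skipping OpenRouter duplicates introduced by alias
/// normalization.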
985#[allow(clippy::too_many_lines)]
986fn apply_custom_models(auth: &AuthStorage, models: &mut Vec<ModelEntry>, config: &ModelsConfig) {
987 for (provider_id, provider_cfg) in &config.providers {
988 let provider_id_str = provider_id.as_str();
989 let routing_defaults = provider_routing_defaults(provider_id);
990 let default_api = routing_defaults.map_or("openai-completions", |defaults| defaults.api);
991 let provider_api = provider_cfg.api.as_deref().unwrap_or(default_api);
992 let provider_api_parsed: Api = provider_api
993 .parse()
994 .unwrap_or_else(|_| Api::Custom(provider_api.to_string()));
995 let provider_api_string = provider_api_parsed.to_string();
996 let provider_base = provider_cfg.base_url.clone().unwrap_or_else(|| {
997 routing_defaults.map_or_else(
998 || "https://api.openai.com/v1".to_string(),
999 |defaults| defaults.base_url.to_string(),
1000 )
1001 });
1002
1003 let provider_headers = resolve_headers(provider_cfg.headers.as_ref());
1004 let canonical_provider = canonical_provider_id(provider_id).unwrap_or(provider_id_str);
1005 let provider_matches = |candidate_provider: &str| {
1006 let candidate_canonical =
1007 canonical_provider_id(candidate_provider).unwrap_or(candidate_provider);
1008 candidate_provider.eq_ignore_ascii_case(provider_id_str)
1009 || candidate_provider.eq_ignore_ascii_case(canonical_provider)
1010 || candidate_canonical.eq_ignore_ascii_case(provider_id_str)
1011 || candidate_canonical.eq_ignore_ascii_case(canonical_provider)
1012 };
1013 let provider_key = provider_cfg
1014 .api_key
1015 .as_deref()
1016 .and_then(resolve_value)
1017 .or_else(|| auth.resolve_api_key(canonical_provider, None));
1018
1019 let auth_header = provider_cfg
1020 .auth_header
1021 .unwrap_or_else(|| routing_defaults.is_some_and(|defaults| defaults.auth_header));
1022
1023 if routing_defaults.is_some() {
1024 tracing::debug!(
1025 event = "pi.provider.schema_defaults",
1026 provider = %provider_id,
1027 canonical_provider = %canonical_provider,
1028 api = %provider_api_string,
1029 base_url = %provider_base,
1030 auth_header,
1031 "Applied provider metadata defaults"
1032 );
1033 }
1034
1035 let has_models = provider_cfg.models.as_ref().is_some();
1036 let is_override = !has_models;
1037
1038 if is_override {
1039 for entry in models
1040 .iter_mut()
1041 .filter(|m| provider_matches(&m.model.provider))
1042 {
1043 if provider_cfg.base_url.is_some() {
1046 entry.model.base_url.clone_from(&provider_base);
1047 }
1048 if provider_cfg.api.is_some() {
1049 entry.model.api.clone_from(&provider_api_string);
1050 }
1051 if provider_cfg.headers.is_some() {
1052 entry.headers.clone_from(&provider_headers);
1053 }
1054 if provider_key.is_some() {
1055 entry.api_key.clone_from(&provider_key);
1056 }
1057 if provider_cfg.compat.is_some() {
1058 entry.compat.clone_from(&provider_cfg.compat);
1059 }
1060 if provider_cfg.auth_header.is_some() {
1061 entry.auth_header = auth_header;
1062 }
1063 }
1064 continue;
1065 }
1066
1067 models.retain(|m| !provider_matches(&m.model.provider));
1069
1070 let mut normalized_provider_ids = HashSet::new();
1071 for model_cfg in provider_cfg.models.clone().unwrap_or_default() {
1072 let normalized_model_id =
1073 canonicalize_model_id_for_provider(provider_id, &model_cfg.id);
1074 if normalized_model_id.is_empty() {
1075 tracing::warn!(
1076 provider = %provider_id,
1077 model_id = %model_cfg.id,
1078 "Skipping model with empty normalized id"
1079 );
1080 continue;
1081 }
1082
1083 if canonical_provider == "openrouter"
1084 && !normalized_provider_ids.insert(normalized_model_id.to_ascii_lowercase())
1085 {
1086 tracing::warn!(
1087 provider = %provider_id,
1088 model_id = %normalized_model_id,
1089 "Skipping duplicate OpenRouter model id after alias normalization"
1090 );
1091 continue;
1092 }
1093
1094 let model_api = model_cfg.api.as_deref().unwrap_or(provider_api);
1095 let model_api_parsed: Api = model_api
1096 .parse()
1097 .unwrap_or_else(|_| Api::Custom(model_api.to_string()));
1098 let model_headers = merge_headers(
1099 &provider_headers,
1100 resolve_headers(model_cfg.headers.as_ref()),
1101 );
1102 let default_input_types = routing_defaults
1103 .map_or_else(|| vec![InputType::Text], |defaults| defaults.input.to_vec());
1104 let input_types = model_cfg.input.as_ref().map_or_else(
1105 || default_input_types.clone(),
1106 |input| {
1107 input
1108 .iter()
1109 .filter_map(|i| match i.as_str() {
1110 "text" => Some(InputType::Text),
1111 "image" => Some(InputType::Image),
1112 _ => None,
1113 })
1114 .collect::<Vec<_>>()
1115 },
1116 );
1117 let input_types = if input_types.is_empty() {
1118 default_input_types
1119 } else {
1120 input_types
1121 };
1122 let default_reasoning = routing_defaults.is_some_and(|defaults| defaults.reasoning);
1123 let default_context_window =
1124 routing_defaults.map_or(128_000, |defaults| defaults.context_window);
1125 let default_max_tokens =
1126 routing_defaults.map_or(16_384, |defaults| defaults.max_tokens);
1127
1128 let model = Model {
1129 id: normalized_model_id.clone(),
1130 name: model_cfg
1131 .name
1132 .clone()
1133 .unwrap_or_else(|| normalized_model_id.clone()),
1134 api: model_api_parsed.to_string(),
1135 provider: provider_id.clone(),
1136 base_url: provider_base.clone(),
1137 reasoning: model_cfg.reasoning.unwrap_or(default_reasoning),
1138 input: input_types,
1139 cost: model_cfg.cost.clone().unwrap_or(ModelCost {
1140 input: 0.0,
1141 output: 0.0,
1142 cache_read: 0.0,
1143 cache_write: 0.0,
1144 }),
1145 context_window: model_cfg.context_window.unwrap_or(default_context_window),
1146 max_tokens: model_cfg.max_tokens.unwrap_or(default_max_tokens),
1147 headers: HashMap::new(),
1148 };
1149
1150 models.push(ModelEntry {
1151 model,
1152 api_key: provider_key.clone(),
1153 headers: model_headers,
1154 auth_header,
1155 compat: merge_compat(provider_cfg.compat.as_ref(), model_cfg.compat.as_ref()),
1156 oauth_config: None,
1157 });
1158 }
1159 }
1160}
1161
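/// Merges provider- and model-level compat configs: model-level values win
/// field by field, and custom headers are unioned with model entries taking
/// precedence on conflicts.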
1162fn merge_compat(
1163 provider_compat: Option<&CompatConfig>,
1164 model_compat: Option<&CompatConfig>,
1165) -> Option<CompatConfig> {
1166 match (provider_compat, model_compat) {
1167 (None, None) => None,
1168 (Some(provider), None) => Some(provider.clone()),
1169 (None, Some(model)) => Some(model.clone()),
1170 (Some(provider), Some(model)) => {
1171 let custom_headers = match (&provider.custom_headers, &model.custom_headers) {
1172 (None, None) => None,
1173 (Some(headers), None) | (None, Some(headers)) => Some(headers.clone()),
1174 (Some(provider_headers), Some(model_headers)) => {
1175 let mut merged = provider_headers.clone();
1176 for (key, value) in model_headers {
1177 merged.insert(key.clone(), value.clone());
1178 }
1179 Some(merged)
1180 }
1181 };
1182
1183 Some(CompatConfig {
1184 supports_store: model.supports_store.or(provider.supports_store),
1185 supports_developer_role: model
1186 .supports_developer_role
1187 .or(provider.supports_developer_role),
1188 supports_reasoning_effort: model
1189 .supports_reasoning_effort
1190 .or(provider.supports_reasoning_effort),
1191 supports_usage_in_streaming: model
1192 .supports_usage_in_streaming
1193 .or(provider.supports_usage_in_streaming),
1194 supports_tools: model.supports_tools.or(provider.supports_tools),
1195 supports_streaming: model.supports_streaming.or(provider.supports_streaming),
1196 supports_parallel_tool_calls: model
1197 .supports_parallel_tool_calls
1198 .or(provider.supports_parallel_tool_calls),
1199 max_tokens_field: model
1200 .max_tokens_field
1201 .clone()
1202 .or_else(|| provider.max_tokens_field.clone()),
1203 system_role_name: model
1204 .system_role_name
1205 .clone()
1206 .or_else(|| provider.system_role_name.clone()),
1207 stop_reason_field: model
1208 .stop_reason_field
1209 .clone()
1210 .or_else(|| provider.stop_reason_field.clone()),
1211 custom_headers,
1212 open_router_routing: model
1213 .open_router_routing
1214 .clone()
1215 .or_else(|| provider.open_router_routing.clone()),
1216 vercel_gateway_routing: model
1217 .vercel_gateway_routing
1218 .clone()
1219 .or_else(|| provider.vercel_gateway_routing.clone()),
1220 })
1221 }
1222 }
1223}
1224
1225fn merge_headers(
1226 base: &HashMap<String, String>,
1227 override_headers: HashMap<String, String>,
1228) -> HashMap<String, String> {
1229 let mut merged = base.clone();
1230 for (k, v) in override_headers {
1231 merged.insert(k, v);
1232 }
1233 merged
1234}
1235
1236fn resolve_headers(headers: Option<&HashMap<String, String>>) -> HashMap<String, String> {
1237 let mut resolved = HashMap::new();
1238 if let Some(headers) = headers {
1239 for (k, v) in headers {
1240 if let Some(val) = resolve_value(v) {
1241 resolved.insert(k.clone(), val);
1242 }
1243 }
1244 }
1245 resolved
1246}
1247
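/// Resolves a configured secret or header value:
/// - `"!<cmd>"` runs the command in a shell and uses its trimmed stdout,
/// - `"env:<NAME>"` reads the named environment variable,
/// - `"file:<path>"` reads the file and trims it,
/// - anything else is returned verbatim; empty results become `None`.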
1248fn resolve_value(value: &str) -> Option<String> {
1249 if let Some(rest) = value.strip_prefix('!') {
1250 return resolve_shell(rest);
1251 }
1252
1253 if let Some(var_name) = value.strip_prefix("env:") {
1254 if var_name.is_empty() {
1255 return None;
1256 }
1257 return std::env::var(var_name).ok().filter(|v| !v.is_empty());
1258 }
1259
1260 if let Some(file_path) = value.strip_prefix("file:") {
1261 if file_path.is_empty() {
1262 return None;
1263 }
1264 return std::fs::read_to_string(file_path)
1265 .ok()
1266 .map(|contents| contents.trim().to_string())
1267 .filter(|v| !v.is_empty());
1268 }
1269
1270 if value.is_empty() {
1271 None
1272 } else {
1273 Some(value.to_string())
1274 }
1275}
1276
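/// Runs `cmd /C <cmd>` on Windows and `sh -c <cmd>` elsewhere, returning the
/// trimmed stdout of a successful invocation.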
1277fn resolve_shell(cmd: &str) -> Option<String> {
1278 let output = if cfg!(windows) {
1279 std::process::Command::new("cmd")
1280 .args(["/C", cmd])
1281 .output()
1282 .ok()?
1283 } else {
1284 std::process::Command::new("sh")
1285 .arg("-c")
1286 .arg(cmd)
1287 .output()
1288 .ok()?
1289 };
1290
1291 if !output.status.success() {
1292 return None;
1293 }
1294 let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
1295 if stdout.is_empty() {
1296 None
1297 } else {
1298 Some(stdout)
1299 }
1300}
1301
1302pub fn default_models_path(agent_dir: &Path) -> PathBuf {
1304 agent_dir.join("models.json")
1305}
1306
1307#[derive(Debug, Clone, Copy)]
1310struct AdHocProviderDefaults {
1311 api: &'static str,
1312 base_url: &'static str,
1313 auth_header: bool,
1314 reasoning: bool,
1315 input: &'static [InputType],
1316 context_window: u32,
1317 max_tokens: u32,
1318}
1319
1320impl From<ProviderRoutingDefaults> for AdHocProviderDefaults {
1321 fn from(value: ProviderRoutingDefaults) -> Self {
1322 Self {
1323 api: value.api,
1324 base_url: value.base_url,
1325 auth_header: value.auth_header,
1326 reasoning: value.reasoning,
1327 input: value.input,
1328 context_window: value.context_window,
1329 max_tokens: value.max_tokens,
1330 }
1331 }
1332}
1333
1334fn ad_hoc_provider_defaults(provider: &str) -> Option<AdHocProviderDefaults> {
1335 provider_routing_defaults(provider).map(AdHocProviderDefaults::from)
1336}
1337
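/// Builds the SAP AI Core chat-completions URL:
/// `{service_url}/v2/inference/deployments/{model_id}/chat/completions`.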
1338fn sap_chat_completions_endpoint(service_url: &str, model_id: &str) -> Option<String> {
1339 let base = service_url.trim().trim_end_matches('/');
1340 let deployment = model_id.trim();
1341 if base.is_empty() || deployment.is_empty() {
1342 return None;
1343 }
1344 Some(format!(
1345 "{base}/v2/inference/deployments/{deployment}/chat/completions"
1346 ))
1347}
1348
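/// Builds a model entry on the fly for a provider/model pair. SAP AI Core
/// entries resolve credentials via `resolve_sap` to derive the deployment URL;
/// other providers fall back to their routing defaults.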
1349fn ad_hoc_model_entry_with_sap_resolver<F>(
1350 provider: &str,
1351 model_id: &str,
1352 mut resolve_sap: F,
1353) -> Option<ModelEntry>
1354where
1355 F: FnMut() -> Option<SapResolvedCredentials>,
1356{
1357 if canonical_provider_id(provider).is_some_and(|canonical| canonical == "sap-ai-core") {
1358 let sap_creds = resolve_sap()?;
1359 let base_url = sap_chat_completions_endpoint(&sap_creds.service_url, model_id)?;
1360 return Some(ModelEntry {
1361 model: Model {
1362 id: model_id.to_string(),
1363 name: model_id.to_string(),
1364 api: "openai-completions".to_string(),
1365 provider: provider.to_string(),
1366 base_url,
1367 reasoning: true,
1368 input: vec![InputType::Text],
1369 cost: ModelCost {
1370 input: 0.0,
1371 output: 0.0,
1372 cache_read: 0.0,
1373 cache_write: 0.0,
1374 },
1375 context_window: 128_000,
1376 max_tokens: 16_384,
1377 headers: HashMap::new(),
1378 },
1379 api_key: None,
1380 headers: HashMap::new(),
1381 auth_header: true,
1382 compat: None,
1383 oauth_config: None,
1384 });
1385 }
1386
1387 let defaults = ad_hoc_provider_defaults(provider)?;
1388 let normalized_model_id = canonicalize_model_id_for_provider(provider, model_id);
1389 if normalized_model_id.is_empty() {
1390 return None;
1391 }
1392 Some(ModelEntry {
1393 model: Model {
1394 id: normalized_model_id.clone(),
1395 name: normalized_model_id,
1396 api: defaults.api.to_string(),
1397 provider: provider.to_string(),
1398 base_url: defaults.base_url.to_string(),
1399 reasoning: defaults.reasoning,
1400 input: defaults.input.to_vec(),
1401 cost: ModelCost {
1402 input: 0.0,
1403 output: 0.0,
1404 cache_read: 0.0,
1405 cache_write: 0.0,
1406 },
1407 context_window: defaults.context_window,
1408 max_tokens: defaults.max_tokens,
1409 headers: HashMap::new(),
1410 },
1411 api_key: None,
1412 headers: HashMap::new(),
1413 auth_header: defaults.auth_header,
1414 compat: None,
1415 oauth_config: None,
1416 })
1417}
1418
1419pub(crate) fn ad_hoc_model_entry(provider: &str, model_id: &str) -> Option<ModelEntry> {
1420 ad_hoc_model_entry_with_sap_resolver(provider, model_id, || {
1421 let auth = AuthStorage::load(crate::config::Config::auth_path()).ok()?;
1422 resolve_sap_credentials(&auth)
1423 })
1424}
1425
1426#[cfg(test)]
1427mod tests {
1428 use super::*;
1429 use crate::auth::{AuthCredential, AuthStorage};
1430 use tempfile::tempdir;
1431
1432 fn test_auth_storage() -> (tempfile::TempDir, AuthStorage) {
1433 let dir = tempdir().expect("tempdir");
1434 let auth_path = dir.path().join("auth.json");
1435 let mut auth = AuthStorage::load(auth_path).expect("load auth");
1436 auth.set(
1437 "anthropic",
1438 AuthCredential::ApiKey {
1439 key: "anthropic-auth-key".to_string(),
1440 },
1441 );
1442 auth.set(
1443 "openai",
1444 AuthCredential::ApiKey {
1445 key: "openai-auth-key".to_string(),
1446 },
1447 );
1448 auth.set(
1449 "google",
1450 AuthCredential::ApiKey {
1451 key: "google-auth-key".to_string(),
1452 },
1453 );
1454 auth.set(
1455 "openrouter",
1456 AuthCredential::ApiKey {
1457 key: "openrouter-auth-key".to_string(),
1458 },
1459 );
1460 auth.set(
1461 "acme",
1462 AuthCredential::ApiKey {
1463 key: "acme-auth-key".to_string(),
1464 },
1465 );
1466 (dir, auth)
1467 }
1468
1469 fn expected_env_pair() -> (String, String) {
1470 let key = ["PATH", "HOME", "PWD"]
1471 .iter()
1472 .find_map(|k| {
1473 std::env::var(k)
1474 .ok()
1475 .filter(|v| !v.is_empty())
1476 .map(|v| ((*k).to_string(), v))
1477 })
1478 .expect("expected at least one non-empty environment variable");
1479 (key.0, key.1)
1480 }
1481
1482 #[test]
1483 fn parse_legacy_generated_models_extracts_known_legacy_only_providers() {
1484 let parsed = parse_legacy_generated_models();
1485 assert!(
1486 !parsed.is_empty(),
1487 "legacy generated model catalog should parse into entries"
1488 );
1489
1490 assert!(
1491 parsed
1492 .iter()
1493 .any(|m| m.provider == "azure-openai-responses")
1494 );
1495 assert!(parsed.iter().any(|m| m.provider == "vercel-ai-gateway"));
1496 assert!(parsed.iter().any(|m| m.provider == "kimi-coding"));
1497 }
1498
1499 #[test]
1500 fn built_in_models_include_all_legacy_provider_model_pairs() {
1501 let (_dir, auth) = test_auth_storage();
1502 let built = built_in_models(&auth, ModelRegistryLoadMode::Full);
1503
1504 let built_keys: HashSet<(String, String)> = built
1505 .iter()
1506 .map(|entry| {
1507 (
1508 entry.model.provider.to_ascii_lowercase(),
1509 entry.model.id.to_ascii_lowercase(),
1510 )
1511 })
1512 .collect();
1513
1514 let mut missing = Vec::new();
1515 for legacy in legacy_generated_models() {
1516 let normalized_id = canonicalize_model_id_for_provider(&legacy.provider, &legacy.id);
1517 if normalized_id.is_empty() {
1518 continue;
1519 }
1520 let key = (
1521 legacy.provider.to_ascii_lowercase(),
1522 normalized_id.to_ascii_lowercase(),
1523 );
1524 if !built_keys.contains(&key) {
1525 missing.push(format!("{}/{}", legacy.provider, legacy.id));
1526 }
1527 }
1528
1529 assert!(
1530 missing.is_empty(),
1531 "missing legacy provider/model entries in built-in registry: {}",
1532 missing.join(", ")
1533 );
1534 }
1535
1536 #[test]
1537 fn built_in_models_preserve_legacy_model_display_names() {
1538 let (_dir, auth) = test_auth_storage();
1539 let built = built_in_models(&auth, ModelRegistryLoadMode::Full);
1540
1541 let name_by_key: HashMap<(String, String), String> = built
1542 .iter()
1543 .map(|entry| {
1544 (
1545 (
1546 entry.model.provider.to_ascii_lowercase(),
1547 entry.model.id.to_ascii_lowercase(),
1548 ),
1549 entry.model.name.clone(),
1550 )
1551 })
1552 .collect();
1553
1554 let mut mismatches = Vec::new();
1555 for legacy in legacy_generated_models() {
1556 let normalized_id = canonicalize_model_id_for_provider(&legacy.provider, &legacy.id);
1557 if normalized_id.is_empty() {
1558 continue;
1559 }
1560 let key = (
1561 legacy.provider.to_ascii_lowercase(),
1562 normalized_id.to_ascii_lowercase(),
1563 );
1564 let Some(built_name) = name_by_key.get(&key) else {
1565 continue;
1566 };
1567 if !legacy.name.trim().is_empty() && built_name != &legacy.name {
1568 mismatches.push(format!(
1569 "{}/{} => expected {:?}, got {:?}",
1570 legacy.provider, legacy.id, legacy.name, built_name
1571 ));
1572 }
1573 }
1574
1575 assert!(
1576 mismatches.is_empty(),
1577 "legacy model display name mismatches: {}",
1578 mismatches.join("; ")
1579 );
1580 }
1581
1582 #[test]
1583 fn built_in_models_include_core_provider_entries() {
1584 let (_dir, auth) = test_auth_storage();
1585 let models = built_in_models(&auth, ModelRegistryLoadMode::Full);
1586
1587 assert!(
1588 models.iter().any(
1589 |m| m.model.provider == "anthropic" && m.model.id == "claude-sonnet-4-20250514"
1590 )
1591 );
1592 assert!(
1593 models
1594 .iter()
1595 .any(|m| m.model.provider == "openai" && m.model.id == "gpt-4o")
1596 );
1597 assert!(
1598 models
1599 .iter()
1600 .any(|m| m.model.provider == "google" && m.model.id == "gemini-2.5-pro")
1601 );
1602 assert!(
1603 models
1604 .iter()
1605 .any(|m| m.model.provider == "openrouter" && m.model.id == "openrouter/auto")
1606 );
1607
1608 let anthropic = models
1609 .iter()
1610 .find(|m| m.model.provider == "anthropic")
1611 .expect("anthropic model");
1612 let openai = models
1613 .iter()
1614 .find(|m| m.model.provider == "openai")
1615 .expect("openai model");
1616 let google = models
1617 .iter()
1618 .find(|m| m.model.provider == "google")
1619 .expect("google model");
1620 let openrouter = models
1621 .iter()
1622 .find(|m| m.model.provider == "openrouter")
1623 .expect("openrouter model");
1624 assert_eq!(anthropic.api_key.as_deref(), Some("anthropic-auth-key"));
1625 assert_eq!(openai.api_key.as_deref(), Some("openai-auth-key"));
1626 assert_eq!(google.api_key.as_deref(), Some("google-auth-key"));
1627 assert_eq!(openrouter.api_key.as_deref(), Some("openrouter-auth-key"));
1628 }
1629
1630 #[test]
1631 fn built_in_models_include_legacy_oauth_provider_entries() {
1632 let (_dir, auth) = test_auth_storage();
1633 let models = built_in_models(&auth, ModelRegistryLoadMode::Full);
1634
1635 assert!(models.iter().any(|m| {
1636 m.model.provider == "openai-codex"
1637 && m.model.api == "openai-codex-responses"
1638 && m.model.id == "gpt-5.2-codex"
1639 }));
1640 assert!(models.iter().any(|m| {
1641 m.model.provider == "google-gemini-cli"
1642 && m.model.api == "google-gemini-cli"
1643 && m.model.id == "gemini-2.5-pro"
1644 }));
1645 assert!(models.iter().any(|m| {
1646 m.model.provider == "google-antigravity"
1647 && m.model.api == "google-gemini-cli"
1648 && m.model.id == "gemini-3-flash"
1649 }));
1650 }
1651
1652 #[test]
1653 fn built_in_models_include_non_legacy_provider_model_strings_from_snapshot() {
1654 let (_dir, auth) = test_auth_storage();
1655 let models = built_in_models(&auth, ModelRegistryLoadMode::Full);
1656
1657 assert!(
1658 models
1659 .iter()
1660 .any(|m| { m.model.provider == "groq" && m.model.id == "llama-3.3-70b-versatile" })
1661 );
1662 assert!(
1663 models
1664 .iter()
1665 .any(|m| { m.model.provider == "zhipuai" && m.model.id == "glm-4.6" })
1666 );
1667 assert!(models.iter().any(|m| {
1668 m.model.provider == "openrouter" && m.model.id == "anthropic/claude-sonnet-4"
1669 }));
1670 }
1671
1672 #[test]
1673 fn autocomplete_candidates_include_legacy_and_latest_entries() {
1674 let candidates = model_autocomplete_candidates();
1675 assert!(
1676 candidates
1677 .iter()
1678 .any(|candidate| candidate.slug == "openai-codex/gpt-5.2-codex")
1679 );
1680 assert!(
1681 candidates
1682 .iter()
1683 .any(|candidate| candidate.slug == "google-gemini-cli/gemini-2.5-pro")
1684 );
1685 assert!(
1686 candidates
1687 .iter()
1688 .any(|candidate| candidate.slug == "anthropic/claude-opus-4-5")
1689 );
1690 assert!(
1691 candidates
1692 .iter()
1693 .any(|candidate| candidate.slug == "groq/llama-3.3-70b-versatile")
1694 );
1695 assert!(
1696 candidates
1697 .iter()
1698 .any(|candidate| candidate.slug == "openrouter/anthropic/claude-sonnet-4.6")
1699 );
1700 }
1701
1702 #[test]
1703 fn autocomplete_candidates_are_case_insensitively_unique() {
1704 let candidates = model_autocomplete_candidates();
1705 let mut seen = HashSet::new();
1706 for candidate in candidates {
1707 let key = candidate.slug.to_ascii_lowercase();
1708 assert!(
1709 seen.insert(key),
1710 "duplicate autocomplete slug (case-insensitive): {}",
1711 candidate.slug
1712 );
1713 }
1714 }
1715
1716 #[test]
1717 fn apply_custom_models_overrides_provider_fields() {
1718 let (_dir, auth) = test_auth_storage();
1719 let mut models = built_in_models(&auth, ModelRegistryLoadMode::Full);
1720 let (env_key, env_val) = expected_env_pair();
1721 let mut provider_headers = HashMap::new();
1722 provider_headers.insert("x-provider".to_string(), "provider-header".to_string());
1723
1724 let config = ModelsConfig {
1725 providers: HashMap::from([(
1726 "anthropic".to_string(),
1727 ProviderConfig {
1728 base_url: Some("https://proxy.example/v1/messages".to_string()),
1729 api: Some("anthropic-messages".to_string()),
1730 api_key: Some(format!("env:{env_key}")),
1731 headers: Some(provider_headers),
1732 auth_header: Some(true),
1733 compat: Some(CompatConfig {
1734 supports_store: Some(true),
1735 ..CompatConfig::default()
1736 }),
1737 models: None,
1738 },
1739 )]),
1740 };
1741
1742 apply_custom_models(&auth, &mut models, &config);
1743
1744 for entry in models.iter().filter(|m| m.model.provider == "anthropic") {
1745 assert_eq!(entry.model.base_url, "https://proxy.example/v1/messages");
1746 assert_eq!(entry.model.api, "anthropic-messages");
1747 assert_eq!(entry.api_key.as_deref(), Some(env_val.as_str()));
1748 assert_eq!(
1749 entry.headers.get("x-provider").map(String::as_str),
1750 Some("provider-header")
1751 );
1752 assert!(entry.auth_header);
1753 assert!(
1754 entry
1755 .compat
1756 .as_ref()
1757 .and_then(|c| c.supports_store)
1758 .unwrap_or(false)
1759 );
1760 }
1761 }
1762
1763 #[test]
1764 fn apply_custom_models_uses_schema_defaults_for_provider_models() {
1765 let (_dir, auth) = test_auth_storage();
1766 let mut models = Vec::new();
1767 let config = ModelsConfig {
1768 providers: HashMap::from([(
1769 "cohere".to_string(),
1770 ProviderConfig {
1771 models: Some(vec![ModelConfig {
1772 id: "command-r-plus".to_string(),
1773 ..ModelConfig::default()
1774 }]),
1775 ..ProviderConfig::default()
1776 },
1777 )]),
1778 };
1779
1780 apply_custom_models(&auth, &mut models, &config);
1781
1782 let cohere = models
1783 .iter()
1784 .find(|entry| entry.model.provider == "cohere")
1785 .expect("cohere model should be added");
1786 assert_eq!(cohere.model.api, "cohere-chat");
1787 assert_eq!(cohere.model.base_url, "https://api.cohere.com/v2");
1788 assert!(cohere.model.reasoning);
1789 assert_eq!(cohere.model.input, vec![InputType::Text]);
1790 assert_eq!(cohere.model.context_window, 128_000);
1791 assert_eq!(cohere.model.max_tokens, 8192);
1792 assert!(!cohere.auth_header);
1793 }
1794
1795 #[test]
1796 fn apply_custom_models_merges_provider_and_model_compat() {
1797 let (_dir, auth) = test_auth_storage();
1798 let mut models = Vec::new();
1799 let config = ModelsConfig {
1800 providers: HashMap::from([(
1801 "custom-openai".to_string(),
1802 ProviderConfig {
1803 api: Some("openai-completions".to_string()),
1804 base_url: Some("https://compat.example/v1".to_string()),
1805 compat: Some(CompatConfig {
1806 supports_tools: Some(false),
1807 supports_usage_in_streaming: Some(false),
1808 max_tokens_field: Some("max_completion_tokens".to_string()),
1809 custom_headers: Some(HashMap::from([
1810 ("x-provider-only".to_string(), "provider".to_string()),
1811 ("x-shared".to_string(), "provider".to_string()),
1812 ])),
1813 ..CompatConfig::default()
1814 }),
1815 models: Some(vec![ModelConfig {
1816 id: "custom-model".to_string(),
1817 compat: Some(CompatConfig {
1818 supports_tools: Some(true),
1819 system_role_name: Some("developer".to_string()),
1820 custom_headers: Some(HashMap::from([
1821 ("x-model-only".to_string(), "model".to_string()),
1822 ("x-shared".to_string(), "model".to_string()),
1823 ])),
1824 ..CompatConfig::default()
1825 }),
1826 ..ModelConfig::default()
1827 }]),
1828 ..ProviderConfig::default()
1829 },
1830 )]),
1831 };
1832
1833 apply_custom_models(&auth, &mut models, &config);
1834
1835 let entry = models
1836 .iter()
1837 .find(|m| m.model.provider == "custom-openai" && m.model.id == "custom-model")
1838 .expect("custom model should be added");
1839 let compat = entry.compat.as_ref().expect("compat should be merged");
1840 assert_eq!(
1841 compat.max_tokens_field.as_deref(),
1842 Some("max_completion_tokens")
1843 );
1844 assert_eq!(compat.system_role_name.as_deref(), Some("developer"));
1845 assert_eq!(compat.supports_usage_in_streaming, Some(false));
1846 assert_eq!(compat.supports_tools, Some(true));
1847 let custom_headers = compat
1848 .custom_headers
1849 .as_ref()
1850 .expect("custom headers should be merged");
1851 assert_eq!(
1852 custom_headers.get("x-provider-only").map(String::as_str),
1853 Some("provider")
1854 );
1855 assert_eq!(
1856 custom_headers.get("x-model-only").map(String::as_str),
1857 Some("model")
1858 );
1859 assert_eq!(
1860 custom_headers.get("x-shared").map(String::as_str),
1861 Some("model")
1862 );
1863 }
1864
1865 #[test]
1866 fn apply_custom_models_uses_schema_defaults_for_native_anthropic_models() {
1867 let (_dir, auth) = test_auth_storage();
1868 let mut models = Vec::new();
1869 let config = ModelsConfig {
1870 providers: HashMap::from([(
1871 "anthropic".to_string(),
1872 ProviderConfig {
1873 models: Some(vec![ModelConfig {
1874 id: "claude-schema-default".to_string(),
1875 ..ModelConfig::default()
1876 }]),
1877 ..ProviderConfig::default()
1878 },
1879 )]),
1880 };
1881
1882 apply_custom_models(&auth, &mut models, &config);
1883
1884 let anthropic = models
1885 .iter()
1886 .find(|entry| entry.model.provider == "anthropic")
1887 .expect("anthropic model should be added");
1888 assert_eq!(anthropic.model.api, "anthropic-messages");
1889 assert_eq!(
1890 anthropic.model.base_url,
1891 "https://api.anthropic.com/v1/messages"
1892 );
1893 assert!(anthropic.model.reasoning);
1894 assert_eq!(
1895 anthropic.model.input,
1896 vec![InputType::Text, InputType::Image]
1897 );
1898 assert_eq!(anthropic.model.context_window, 200_000);
1899 assert_eq!(anthropic.model.max_tokens, 8192);
1900 assert!(!anthropic.auth_header);
1901 }
1902
1903 #[test]
1904 fn apply_custom_models_alias_resolves_canonical_provider_api_key() {
1905 let (_dir, mut auth) = test_auth_storage();
1906 auth.set(
1907 "moonshotai",
1908 AuthCredential::ApiKey {
1909 key: "moonshot-auth-key".to_string(),
1910 },
1911 );
1912
1913 let mut models = Vec::new();
1914 let config = ModelsConfig {
1915 providers: HashMap::from([(
1916 "kimi".to_string(),
1917 ProviderConfig {
1918 models: Some(vec![ModelConfig {
1919 id: "kimi-k2-instruct".to_string(),
1920 ..ModelConfig::default()
1921 }]),
1922 ..ProviderConfig::default()
1923 },
1924 )]),
1925 };
1926
1927 apply_custom_models(&auth, &mut models, &config);
1928
1929 let kimi = models
1930 .iter()
1931 .find(|entry| entry.model.provider == "kimi")
1932 .expect("kimi model should be added");
1933 assert_eq!(kimi.model.api, "openai-completions");
1934 assert_eq!(kimi.model.base_url, "https://api.moonshot.ai/v1");
1935 assert_eq!(kimi.api_key.as_deref(), Some("moonshot-auth-key"));
1936 assert!(kimi.auth_header);
1937 }
1938
1939 #[test]
1940 fn model_registry_find_and_find_by_id_work() {
1941 let (_dir, auth) = test_auth_storage();
1942 let registry = ModelRegistry::load(&auth, None);
1943
1944 let by_provider_and_id = registry
1945 .find("openai", "gpt-4o")
1946 .expect("openai/gpt-4o should exist");
1947 assert_eq!(by_provider_and_id.model.provider, "openai");
1948 assert_eq!(by_provider_and_id.model.id, "gpt-4o");
1949
1950 let by_id = registry
1951 .find_by_id("claude-opus-4-5")
1952 .expect("claude-opus-4-5 should exist");
1953 assert_eq!(by_id.model.provider, "anthropic");
1954 assert_eq!(by_id.model.id, "claude-opus-4-5");
1955
1956 assert!(registry.find("openai", "does-not-exist").is_none());
1957 assert!(registry.find_by_id("does-not-exist").is_none());
1958 }
1959
1960 #[test]
1961 fn model_registry_find_by_id_is_case_insensitive() {
1962 let (_dir, auth) = test_auth_storage();
1963 let registry = ModelRegistry::load(&auth, None);
1964
1965 let by_id = registry
1966 .find_by_id("GPT-5.2-CODEX")
1967 .expect("gpt-5.2-codex should resolve case-insensitively");
1968 assert_eq!(by_id.model.id, "gpt-5.2-codex");
1969 }
1970
1971 #[test]
1972 fn model_registry_find_normalizes_openrouter_model_aliases() {
1973 let (_dir, auth) = test_auth_storage();
1974 let registry = ModelRegistry::load(&auth, None);
1975
1976 let gpt4o_mini = registry
1977 .find("openrouter", "gpt-4o-mini")
1978 .expect("openrouter alias should resolve");
1979 assert_eq!(gpt4o_mini.model.provider, "openrouter");
1980 assert_eq!(gpt4o_mini.model.id, "openai/gpt-4o-mini");
1981
1982 let auto = registry
1983 .find("openrouter", "auto")
1984 .expect("openrouter auto alias should resolve");
1985 assert_eq!(auto.model.id, "openrouter/auto");
1986
1987 let provider_alias = registry
1988 .find("open-router", "gpt-4o-mini")
1989 .expect("open-router provider alias should resolve");
1990 assert_eq!(provider_alias.model.provider, "openrouter");
1991 assert_eq!(provider_alias.model.id, "openai/gpt-4o-mini");
1992 }
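// Sketch of the lookups above: both the provider id and the model id are
// normalized before matching, e.g.
//
//   find("open-router", "gpt-4o-mini") -> provider "openrouter", id "openai/gpt-4o-mini"
//   find("openrouter",  "auto")        -> provider "openrouter", id "openrouter/auto"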
1993
1994 #[test]
1995 fn ad_hoc_model_entry_normalizes_openrouter_aliases() {
1996 let auto = ad_hoc_model_entry("openrouter", "auto").expect("openrouter auto ad-hoc");
1997 assert_eq!(auto.model.id, "openrouter/auto");
1998
1999 let gpt4o_mini =
2000 ad_hoc_model_entry("openrouter", "gpt-4o-mini").expect("openrouter gpt-4o-mini ad-hoc");
2001 assert_eq!(gpt4o_mini.model.id, "openai/gpt-4o-mini");
2002 }
2003
2004 #[test]
2005 fn model_registry_merge_entries_deduplicates() {
2006 let (_dir, auth) = test_auth_storage();
2007 let mut registry = ModelRegistry::load(&auth, None);
2008 let before = registry.models().len();
2009 let duplicate = registry
2010 .find("openai", "gpt-4o")
2011 .expect("expected built-in openai model");
2012
2013 let new_entry = ModelEntry {
2014 model: Model {
2015 id: "acme-chat".to_string(),
2016 name: "Acme Chat".to_string(),
2017 api: "openai-completions".to_string(),
2018 provider: "acme".to_string(),
2019 base_url: "https://acme.example/v1".to_string(),
2020 reasoning: true,
2021 input: vec![InputType::Text],
2022 cost: ModelCost {
2023 input: 0.0,
2024 output: 0.0,
2025 cache_read: 0.0,
2026 cache_write: 0.0,
2027 },
2028 context_window: 64_000,
2029 max_tokens: 4096,
2030 headers: HashMap::new(),
2031 },
2032 api_key: Some("acme-auth-key".to_string()),
2033 headers: HashMap::new(),
2034 auth_header: true,
2035 compat: None,
2036 oauth_config: None,
2037 };
2038
2039 registry.merge_entries(vec![duplicate, new_entry]);
2040 assert_eq!(registry.models().len(), before + 1);
2041 assert!(registry.find("acme", "acme-chat").is_some());
2042 }
2043
2044 #[test]
2045 fn model_registry_merge_entries_deduplicates_alias_and_case_variants() {
2046 let (_dir, auth) = test_auth_storage();
2047 let mut registry = ModelRegistry::load(&auth, None);
2048 let before = registry.models().len();
2049
2050 let source = registry
2051 .find("openrouter", "gpt-4o-mini")
2052 .or_else(|| registry.find("openrouter", "openai/gpt-4o-mini"))
2053 .expect("expected built-in openrouter gpt-4o-mini model");
2054
2055 let mut alias_case_variant = source.clone();
2056 alias_case_variant.model.provider = "open-router".to_string();
2057 alias_case_variant.model.id = source.model.id.to_ascii_uppercase();
2058
2059 registry.merge_entries(vec![alias_case_variant]);
2060 assert_eq!(registry.models().len(), before);
2061 }
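// Sketch (assuming the dedup key the two merge_entries tests above imply):
//
//   key("openrouter",  "openai/gpt-4o-mini")
//     == key("open-router", "OPENAI/GPT-4O-MINI")      // alias + case variants collapse
//   key("openai", "gpt-4o") != key("acme", "acme-chat") // distinct pairs are both kept
//
// where key(..) stands for whatever canonical (provider, model-id) form
// merge_entries uses internally; the name is illustrative only.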
2062
2063 #[test]
2064 fn apply_custom_models_dedupes_openrouter_alias_conflicts() {
2065 let (_dir, auth) = test_auth_storage();
2066 let mut models = Vec::new();
2067 let config = ModelsConfig {
2068 providers: HashMap::from([(
2069 "openrouter".to_string(),
2070 ProviderConfig {
2071 models: Some(vec![
2072 ModelConfig {
2073 id: "gpt-4o-mini".to_string(),
2074 ..ModelConfig::default()
2075 },
2076 ModelConfig {
2077 id: "openai/gpt-4o-mini".to_string(),
2078 ..ModelConfig::default()
2079 },
2080 ModelConfig {
2081 id: "auto".to_string(),
2082 ..ModelConfig::default()
2083 },
2084 ]),
2085 ..ProviderConfig::default()
2086 },
2087 )]),
2088 };
2089
2090 apply_custom_models(&auth, &mut models, &config);
2091
2092 let openrouter_models: Vec<&ModelEntry> = models
2093 .iter()
2094 .filter(|entry| entry.model.provider == "openrouter")
2095 .collect();
2096 assert_eq!(openrouter_models.len(), 2);
2097 assert!(
2098 openrouter_models
2099 .iter()
2100 .any(|entry| entry.model.id == "openai/gpt-4o-mini")
2101 );
2102 assert!(
2103 openrouter_models
2104 .iter()
2105 .any(|entry| entry.model.id == "openrouter/auto")
2106 );
2107 }
2108
2109 #[test]
2110 fn resolve_value_supports_env_and_file_prefixes() {
2111 let (env_key, env_val) = expected_env_pair();
2112 assert_eq!(
2113 resolve_value(&format!("env:{env_key}")).as_deref(),
2114 Some(env_val.as_str())
2115 );
2116
2117 let dir = tempdir().expect("tempdir");
2118 let key_path = dir.path().join("api_key.txt");
2119 std::fs::write(&key_path, "file-key\n").expect("write key file");
2120 assert_eq!(
2121 resolve_value(&format!("file:{}", key_path.display())).as_deref(),
2122 Some("file-key")
2123 );
2124 assert!(resolve_value("file:/definitely/missing/path").is_none());
2125 }
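// Sketch: together with the smaller resolve_value tests further down, the
// supported value forms look like
//
//   "literal-key"        -> Some("literal-key")
//   "env:SOME_VAR"       -> value of SOME_VAR, if set ("env:" alone is None)
//   "file:/path/to/key"  -> trimmed file contents, None if the file is missing
//   "!some-command"      -> trimmed stdout, None if the command fails
//   ""                   -> None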
2126
2127 #[test]
2128 fn model_registry_load_reads_models_json_and_applies_config() {
2129 let (dir, auth) = test_auth_storage();
2130 let models_path = dir.path().join("models.json");
2131 let key_path = dir.path().join("custom_key.txt");
2132 std::fs::write(&key_path, "acme-file-key\n").expect("write custom key");
2133
2134 let models_json = serde_json::json!({
2135 "providers": {
2136 "acme": {
2137 "baseUrl": "https://acme.example/v1",
2138 "api": "openai-completions",
2139 "apiKey": format!("file:{}", key_path.display()),
2140 "headers": {
2141 "x-provider": "provider-level"
2142 },
2143 "authHeader": true,
2144 "models": [
2145 {
2146 "id": "acme-chat",
2147 "name": "Acme Chat",
2148 "input": ["text", "image"],
2149 "reasoning": true,
2150 "contextWindow": 64000,
2151 "maxTokens": 4096,
2152 "headers": {
2153 "x-model": "model-level"
2154 }
2155 }
2156 ]
2157 }
2158 }
2159 });
2160
2161 std::fs::write(
2162 &models_path,
2163 serde_json::to_string_pretty(&models_json).expect("serialize models json"),
2164 )
2165 .expect("write models.json");
2166
2167 let registry = ModelRegistry::load(&auth, Some(models_path));
2168 let acme = registry
2169 .find("acme", "acme-chat")
2170 .expect("custom acme model should load from models.json");
2171
2172 assert_eq!(acme.model.name, "Acme Chat");
2173 assert_eq!(acme.model.api, "openai-completions");
2174 assert_eq!(acme.model.base_url, "https://acme.example/v1");
2175 assert_eq!(acme.model.context_window, 64_000);
2176 assert_eq!(acme.model.max_tokens, 4096);
2177 assert_eq!(acme.api_key.as_deref(), Some("acme-file-key"));
2178 assert!(acme.auth_header);
2179 assert_eq!(
2180 acme.headers.get("x-provider").map(String::as_str),
2181 Some("provider-level")
2182 );
2183 assert_eq!(
2184 acme.headers.get("x-model").map(String::as_str),
2185 Some("model-level")
2186 );
2187 assert_eq!(acme.model.input, vec![InputType::Text, InputType::Image]);
2188 }
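// Sketch: the models.json written above is a reasonable template for a custom
// OpenAI-compatible provider; provider-level headers and the file:-based apiKey
// are merged into each entry alongside any model-level headers (the key path
// below is illustrative):
//
//   {
//     "providers": {
//       "acme": {
//         "baseUrl": "https://acme.example/v1",
//         "api": "openai-completions",
//         "apiKey": "file:/path/to/custom_key.txt",
//         "headers": { "x-provider": "provider-level" },
//         "authHeader": true,
//         "models": [
//           { "id": "acme-chat", "name": "Acme Chat",
//             "input": ["text", "image"], "reasoning": true,
//             "contextWindow": 64000, "maxTokens": 4096,
//             "headers": { "x-model": "model-level" } }
//         ]
//       }
//     }
//   }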
2189
2190 fn make_model_entry(id: &str, reasoning: bool) -> ModelEntry {
2193 ModelEntry {
2194 model: Model {
2195 id: id.to_string(),
2196 name: id.to_string(),
2197 api: "openai-responses".to_string(),
2198 provider: "test".to_string(),
2199 base_url: "https://example.com".to_string(),
2200 reasoning,
2201 input: vec![InputType::Text],
2202 cost: ModelCost {
2203 input: 0.0,
2204 output: 0.0,
2205 cache_read: 0.0,
2206 cache_write: 0.0,
2207 },
2208 context_window: 128_000,
2209 max_tokens: 8192,
2210 headers: HashMap::new(),
2211 },
2212 api_key: None,
2213 headers: HashMap::new(),
2214 auth_header: false,
2215 compat: None,
2216 oauth_config: None,
2217 }
2218 }
2219
2220 #[test]
2221 fn supports_xhigh_for_known_models() {
2222 assert!(make_model_entry("gpt-5.1-codex-max", true).supports_xhigh());
2223 assert!(make_model_entry("gpt-5.2", true).supports_xhigh());
2224 assert!(make_model_entry("gpt-5.2-codex", true).supports_xhigh());
2225 assert!(make_model_entry("gpt-5.3-codex", true).supports_xhigh());
2226 assert!(make_model_entry("gpt-5.3-codex-spark", true).supports_xhigh());
2227 }
2228
2229 #[test]
2230 fn supports_xhigh_false_for_other_models() {
2231 assert!(!make_model_entry("gpt-4o", true).supports_xhigh());
2232 assert!(!make_model_entry("claude-sonnet-4-20250514", true).supports_xhigh());
2233 assert!(!make_model_entry("gemini-2.5-pro", true).supports_xhigh());
2234 }
2235
2236 #[test]
2239 fn clamp_non_reasoning_always_off() {
2240 use crate::model::ThinkingLevel;
2241 let entry = make_model_entry("gpt-4o-mini", false);
2242 assert_eq!(
2243 entry.clamp_thinking_level(ThinkingLevel::High),
2244 ThinkingLevel::Off
2245 );
2246 assert_eq!(
2247 entry.clamp_thinking_level(ThinkingLevel::Medium),
2248 ThinkingLevel::Off
2249 );
2250 assert_eq!(
2251 entry.clamp_thinking_level(ThinkingLevel::Off),
2252 ThinkingLevel::Off
2253 );
2254 }
2255
2256 #[test]
2257 fn clamp_xhigh_downgraded_without_support() {
2258 use crate::model::ThinkingLevel;
2259 let entry = make_model_entry("claude-sonnet-4-20250514", true);
2260 assert_eq!(
2261 entry.clamp_thinking_level(ThinkingLevel::XHigh),
2262 ThinkingLevel::High,
2263 );
2264 }
2265
2266 #[test]
2267 fn clamp_xhigh_preserved_with_support() {
2268 use crate::model::ThinkingLevel;
2269 let entry = make_model_entry("gpt-5.2", true);
2270 assert_eq!(
2271 entry.clamp_thinking_level(ThinkingLevel::XHigh),
2272 ThinkingLevel::XHigh,
2273 );
2274 }
2275
2276 #[test]
2277 fn clamp_passthrough_for_regular_levels() {
2278 use crate::model::ThinkingLevel;
2279 let entry = make_model_entry("claude-sonnet-4-20250514", true);
2280 assert_eq!(
2281 entry.clamp_thinking_level(ThinkingLevel::High),
2282 ThinkingLevel::High
2283 );
2284 assert_eq!(
2285 entry.clamp_thinking_level(ThinkingLevel::Medium),
2286 ThinkingLevel::Medium
2287 );
2288 assert_eq!(
2289 entry.clamp_thinking_level(ThinkingLevel::Low),
2290 ThinkingLevel::Low
2291 );
2292 assert_eq!(
2293 entry.clamp_thinking_level(ThinkingLevel::Minimal),
2294 ThinkingLevel::Minimal
2295 );
2296 assert_eq!(
2297 entry.clamp_thinking_level(ThinkingLevel::Off),
2298 ThinkingLevel::Off
2299 );
2300 }
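// Sketch: taken together, the four clamp tests above pin down the behaviour as,
// roughly,
//
//   non-reasoning model                -> always ThinkingLevel::Off
//   XHigh on a model without xhigh     -> ThinkingLevel::High
//   anything else                      -> passed through unchanged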
2301
2302 #[test]
2305 fn ad_hoc_known_providers() {
2306 let providers = [
2307 "anthropic",
2308 "openai",
2309 "google",
2310 "cohere",
2311 "amazon-bedrock",
2312 "groq",
2313 "deepinfra",
2314 "cerebras",
2315 "openrouter",
2316 "mistral",
2317 "deepseek",
2318 "fireworks",
2319 "togetherai",
2320 "perplexity",
2321 "xai",
2322 "baseten",
2323 "llama",
2324 "lmstudio",
2325 "ollama-cloud",
2326 ];
2327 for provider in providers {
2328 assert!(
2329 ad_hoc_provider_defaults(provider).is_some(),
2330 "expected defaults for '{provider}'"
2331 );
2332 }
2333 }
2334
2335 #[test]
2336 fn ad_hoc_alibaba_aliases() {
2337 for alias in ["alibaba", "dashscope", "qwen"] {
2338 let defaults = ad_hoc_provider_defaults(alias)
2339 .unwrap_or_else(|| panic!("expected defaults for '{alias}'"));
2340 assert!(defaults.base_url.contains("dashscope"));
2341 }
2342 }
2343
2344 #[test]
2345 fn ad_hoc_moonshot_aliases() {
2346 for alias in ["moonshotai", "moonshot", "kimi"] {
2347 let defaults = ad_hoc_provider_defaults(alias)
2348 .unwrap_or_else(|| panic!("expected defaults for '{alias}'"));
2349 assert!(defaults.base_url.contains("moonshot"));
2350 }
2351 }
2352
2353 #[test]
2354 fn ad_hoc_batch_b1_defaults_resolve_expected_routes() {
2355 let alibaba_cn =
2356 ad_hoc_provider_defaults("alibaba-cn").expect("expected defaults for alibaba-cn");
2357 assert_eq!(alibaba_cn.api, "openai-completions");
2358 assert!(alibaba_cn.auth_header);
2359 assert!(alibaba_cn.base_url.contains("dashscope.aliyuncs.com"));
2360
2361 let kimi_for_coding = ad_hoc_provider_defaults("kimi-for-coding")
2362 .expect("expected defaults for kimi-for-coding");
2363 assert_eq!(kimi_for_coding.api, "anthropic-messages");
2364 assert!(!kimi_for_coding.auth_header);
2365 assert!(kimi_for_coding.base_url.contains("api.kimi.com/coding"));
2366
2367 for provider in [
2368 "minimax",
2369 "minimax-cn",
2370 "minimax-coding-plan",
2371 "minimax-cn-coding-plan",
2372 ] {
2373 let defaults =
2374 ad_hoc_provider_defaults(provider).unwrap_or_else(|| panic!("defaults {provider}"));
2375 assert_eq!(defaults.api, "anthropic-messages");
2376 assert!(!defaults.auth_header);
2377 assert!(defaults.base_url.contains("api.minimax"));
2378 }
2379 }
2380
2381 #[test]
2382 fn ad_hoc_batch_b2_defaults_resolve_expected_routes() {
2383 let cases = [
2384 ("modelscope", "https://api-inference.modelscope.cn/v1"),
2385 ("moonshotai-cn", "https://api.moonshot.cn/v1"),
2386 ("nebius", "https://api.tokenfactory.nebius.com/v1"),
2387 (
2388 "ovhcloud",
2389 "https://oai.endpoints.kepler.ai.cloud.ovh.net/v1",
2390 ),
2391 ("scaleway", "https://api.scaleway.ai/v1"),
2392 ];
2393 for (provider, expected_base_url) in &cases {
2394 let defaults =
2395 ad_hoc_provider_defaults(provider).unwrap_or_else(|| panic!("defaults {provider}"));
2396 assert_eq!(defaults.api, "openai-completions");
2397 assert!(defaults.auth_header);
2398 assert_eq!(defaults.base_url, *expected_base_url);
2399 }
2400 }
2401
2402 #[test]
2403 fn ad_hoc_batch_b3_defaults_resolve_expected_routes() {
2404 let cases = [
2405 ("siliconflow", "https://api.siliconflow.com/v1"),
2406 ("siliconflow-cn", "https://api.siliconflow.cn/v1"),
2407 ("upstage", "https://api.upstage.ai/v1/solar"),
2408 ("venice", "https://api.venice.ai/api/v1"),
2409 ("zai", "https://api.z.ai/api/paas/v4"),
2410 ("zai-coding-plan", "https://api.z.ai/api/coding/paas/v4"),
2411 ("zhipuai", "https://open.bigmodel.cn/api/paas/v4"),
2412 (
2413 "zhipuai-coding-plan",
2414 "https://open.bigmodel.cn/api/coding/paas/v4",
2415 ),
2416 ];
2417 for (provider, expected_base_url) in &cases {
2418 let defaults =
2419 ad_hoc_provider_defaults(provider).unwrap_or_else(|| panic!("defaults {provider}"));
2420 assert_eq!(defaults.api, "openai-completions");
2421 assert!(defaults.auth_header);
2422 assert_eq!(defaults.base_url, *expected_base_url);
2423 }
2424 }
2425
2426 #[test]
2427 fn ad_hoc_batch_b3_coding_plan_and_regional_variants_remain_distinct() {
2428 let siliconflow = ad_hoc_provider_defaults("siliconflow").expect("siliconflow defaults");
2429 let siliconflow_cn =
2430 ad_hoc_provider_defaults("siliconflow-cn").expect("siliconflow-cn defaults");
2431 assert_eq!(canonical_provider_id("siliconflow"), Some("siliconflow"));
2432 assert_eq!(
2433 canonical_provider_id("siliconflow-cn"),
2434 Some("siliconflow-cn")
2435 );
2436 assert_ne!(siliconflow.base_url, siliconflow_cn.base_url);
2437
2438 let zai = ad_hoc_provider_defaults("zai").expect("zai defaults");
2439 let zai_coding = ad_hoc_provider_defaults("zai-coding-plan").expect("zai-coding defaults");
2440 assert_eq!(canonical_provider_id("zai"), Some("zai"));
2441 assert_eq!(
2442 canonical_provider_id("zai-coding-plan"),
2443 Some("zai-coding-plan")
2444 );
2445 assert_eq!(zai.api, "openai-completions");
2446 assert_eq!(zai_coding.api, "openai-completions");
2447 assert_ne!(zai.base_url, zai_coding.base_url);
2448
2449 let zhipu = ad_hoc_provider_defaults("zhipuai").expect("zhipu defaults");
2450 let zhipu_coding =
2451 ad_hoc_provider_defaults("zhipuai-coding-plan").expect("zhipu-coding defaults");
2452 assert_eq!(canonical_provider_id("zhipuai"), Some("zhipuai"));
2453 assert_eq!(
2454 canonical_provider_id("zhipuai-coding-plan"),
2455 Some("zhipuai-coding-plan")
2456 );
2457 assert_eq!(zhipu.api, "openai-completions");
2458 assert_eq!(zhipu_coding.api, "openai-completions");
2459 assert_ne!(zhipu.base_url, zhipu_coding.base_url);
2460 }
2461
2462 #[test]
2463 fn ad_hoc_batch_c1_defaults_resolve_expected_routes() {
2464 let cases = [
2465 ("baseten", "https://inference.baseten.co/v1"),
2466 ("llama", "https://api.llama.com/compat/v1"),
2467 ("lmstudio", "http://127.0.0.1:1234/v1"),
2468 ("ollama-cloud", "https://ollama.com/v1"),
2469 ];
2470 for (provider, expected_base_url) in &cases {
2471 let defaults =
2472 ad_hoc_provider_defaults(provider).unwrap_or_else(|| panic!("defaults {provider}"));
2473 assert_eq!(defaults.api, "openai-completions");
2474 assert!(defaults.auth_header);
2475 assert_eq!(defaults.base_url, *expected_base_url);
2476 }
2477 }
2478
2479 #[test]
2480 fn ad_hoc_kimi_alias_and_kimi_for_coding_remain_distinct() {
2481 assert_eq!(canonical_provider_id("kimi"), Some("moonshotai"));
2482 assert_eq!(
2483 canonical_provider_id("kimi-for-coding"),
2484 Some("kimi-for-coding")
2485 );
2486
2487 let kimi_alias = ad_hoc_provider_defaults("kimi").expect("kimi alias defaults");
2488 let kimi_for_coding =
2489 ad_hoc_provider_defaults("kimi-for-coding").expect("kimi-for-coding defaults");
2490 assert!(kimi_alias.base_url.contains("moonshot.ai"));
2491 assert!(kimi_for_coding.base_url.contains("api.kimi.com"));
2492 assert_ne!(kimi_alias.base_url, kimi_for_coding.base_url);
2493 assert_ne!(kimi_alias.api, kimi_for_coding.api);
2494 }
2495
2496 #[test]
2497 fn ad_hoc_alibaba_cn_is_distinct_from_alibaba_family_aliases() {
2498 let alibaba = ad_hoc_provider_defaults("alibaba").expect("alibaba defaults");
2499 let alibaba_cn = ad_hoc_provider_defaults("alibaba-cn").expect("alibaba-cn defaults");
2500 assert_eq!(canonical_provider_id("dashscope"), Some("alibaba"));
2501 assert_eq!(canonical_provider_id("alibaba-cn"), Some("alibaba-cn"));
2502 assert_eq!(alibaba.api, "openai-completions");
2503 assert_eq!(alibaba_cn.api, "openai-completions");
2504 assert_ne!(alibaba.base_url, alibaba_cn.base_url);
2505 }
2506
2507 #[test]
2508 fn ad_hoc_moonshot_cn_is_distinct_from_global_moonshot_aliases() {
2509 let moonshot_global = ad_hoc_provider_defaults("moonshot").expect("moonshot defaults");
2510 let moonshot_cn =
2511 ad_hoc_provider_defaults("moonshotai-cn").expect("moonshotai-cn defaults");
2512 assert_eq!(canonical_provider_id("moonshot"), Some("moonshotai"));
2513 assert_eq!(
2514 canonical_provider_id("moonshotai-cn"),
2515 Some("moonshotai-cn")
2516 );
2517 assert_eq!(moonshot_global.api, "openai-completions");
2518 assert_eq!(moonshot_cn.api, "openai-completions");
2519 assert_ne!(moonshot_global.base_url, moonshot_cn.base_url);
2520 }
2521
2522 #[test]
2523 fn ad_hoc_unknown_returns_none() {
2524 assert!(ad_hoc_provider_defaults("unknown-provider").is_none());
2525 assert!(ad_hoc_provider_defaults("").is_none());
2526 }
2527
2528 #[test]
2529 fn ad_hoc_anthropic_uses_messages_api() {
2530 let defaults = ad_hoc_provider_defaults("anthropic").unwrap();
2531 assert_eq!(defaults.api, "anthropic-messages");
2532 assert_eq!(defaults.base_url, "https://api.anthropic.com/v1/messages");
2533 assert!(defaults.reasoning);
2534 }
2535
2536 #[test]
2537 fn ad_hoc_openai_uses_responses_api() {
2538 let defaults = ad_hoc_provider_defaults("openai").unwrap();
2539 assert_eq!(defaults.api, "openai-responses");
2540 }
2541
2542 #[test]
2543 fn ad_hoc_groq_uses_completions_api() {
2544 let defaults = ad_hoc_provider_defaults("groq").unwrap();
2545 assert_eq!(defaults.api, "openai-completions");
2546 assert!(defaults.base_url.contains("groq.com"));
2547 }
2548
2549 #[test]
2550 fn ad_hoc_bedrock_uses_converse_api() {
2551 let defaults = ad_hoc_provider_defaults("amazon-bedrock").unwrap();
2552 assert_eq!(defaults.api, "bedrock-converse-stream");
2553 assert_eq!(defaults.base_url, "");
2554 assert!(!defaults.auth_header);
2555 }
2556
2557 #[test]
2560 fn ad_hoc_model_entry_creates_valid_entry() {
2561 let entry = ad_hoc_model_entry("groq", "llama-3-70b").unwrap();
2562 assert_eq!(entry.model.id, "llama-3-70b");
2563 assert_eq!(entry.model.name, "llama-3-70b");
2564 assert_eq!(entry.model.provider, "groq");
2565 assert_eq!(entry.model.api, "openai-completions");
2566 assert!(entry.model.base_url.contains("groq.com"));
2567 assert!(entry.auth_header);
2568 assert!(entry.api_key.is_none());
2569 }
2570
2571 #[test]
2572 fn ad_hoc_model_entry_anthropic_no_auth_header() {
2573 let entry = ad_hoc_model_entry("anthropic", "claude-custom").unwrap();
2574 assert!(!entry.auth_header);
2575 }
2576
2577 #[test]
2578 fn ad_hoc_model_entry_unknown_returns_none() {
2579 assert!(ad_hoc_model_entry("nonexistent", "model").is_none());
2580 }
2581
2582 #[test]
2583 fn sap_chat_completions_endpoint_formats_expected_path() {
2584 let endpoint =
2585 sap_chat_completions_endpoint("https://api.ai.sap.example.com/", "deployment-a")
2586 .expect("endpoint");
2587 assert_eq!(
2588 endpoint,
2589 "https://api.ai.sap.example.com/v2/inference/deployments/deployment-a/chat/completions"
2590 );
2591 }
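// Sketch: the endpoint shape exercised here and in the property tests near the
// end of the module is
//
//   {service_url}/v2/inference/deployments/{deployment_id}/chat/completions
//
// with a trailing slash on the service URL tolerated and empty or blank inputs
// rejected (None).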
2592
2593 #[test]
2594 fn ad_hoc_model_entry_supports_sap_with_resolved_service_key() {
2595 let entry = ad_hoc_model_entry_with_sap_resolver("sap-ai-core", "dep-123", || {
2596 Some(SapResolvedCredentials {
2597 client_id: "id".to_string(),
2598 client_secret: "secret".to_string(),
2599 token_url: "https://auth.sap.example.com/oauth/token".to_string(),
2600 service_url: "https://api.ai.sap.example.com".to_string(),
2601 })
2602 })
2603 .expect("sap ad-hoc entry");
2604
2605 assert_eq!(entry.model.provider, "sap-ai-core");
2606 assert_eq!(entry.model.api, "openai-completions");
2607 assert_eq!(
2608 entry.model.base_url,
2609 "https://api.ai.sap.example.com/v2/inference/deployments/dep-123/chat/completions"
2610 );
2611 assert!(entry.auth_header);
2612 }
2613
2614 #[test]
2615 fn ad_hoc_model_entry_supports_sap_alias() {
2616 let entry = ad_hoc_model_entry_with_sap_resolver("sap", "dep-123", || {
2617 Some(SapResolvedCredentials {
2618 client_id: "id".to_string(),
2619 client_secret: "secret".to_string(),
2620 token_url: "https://auth.sap.example.com/oauth/token".to_string(),
2621 service_url: "https://api.ai.sap.example.com".to_string(),
2622 })
2623 })
2624 .expect("sap alias ad-hoc entry");
2625
2626 assert_eq!(entry.model.provider, "sap");
2627 assert_eq!(entry.model.api, "openai-completions");
2628 assert!(entry.auth_header);
2629 }
2630
2631 #[test]
2632 fn ad_hoc_model_entry_sap_without_credentials_returns_none() {
2633 assert!(ad_hoc_model_entry_with_sap_resolver("sap-ai-core", "dep-123", || None).is_none());
2634 }
2635
2636 #[test]
2639 fn merge_headers_combines_both() {
2640 let base = HashMap::from([
2641 ("a".to_string(), "1".to_string()),
2642 ("b".to_string(), "2".to_string()),
2643 ]);
2644 let overrides = HashMap::from([
2645 ("b".to_string(), "override".to_string()),
2646 ("c".to_string(), "3".to_string()),
2647 ]);
2648 let merged = merge_headers(&base, overrides);
2649 assert_eq!(merged.get("a").unwrap(), "1");
2650 assert_eq!(merged.get("b").unwrap(), "override");
2651 assert_eq!(merged.get("c").unwrap(), "3");
2652 }
2653
2654 #[test]
2655 fn merge_headers_empty_base() {
2656 let merged = merge_headers(
2657 &HashMap::new(),
2658 HashMap::from([("x".to_string(), "y".to_string())]),
2659 );
2660 assert_eq!(merged.len(), 1);
2661 assert_eq!(merged.get("x").unwrap(), "y");
2662 }
2663
2664 #[test]
2665 fn merge_headers_empty_overrides() {
2666 let base = HashMap::from([("x".to_string(), "y".to_string())]);
2667 let merged = merge_headers(&base, HashMap::new());
2668 assert_eq!(merged, base);
2669 }
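// Sketch: merge_headers is plain "base map, then overrides on top"; for the
// combines_both case above,
//
//   base      = { a: 1, b: 2 }
//   overrides = { b: override, c: 3 }
//   merged    = { a: 1, b: override, c: 3 }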
2670
2671 #[test]
2674 fn resolve_value_plain_literal() {
2675 assert_eq!(resolve_value("my-key").as_deref(), Some("my-key"));
2676 }
2677
2678 #[test]
2679 fn resolve_value_empty_returns_none() {
2680 assert!(resolve_value("").is_none());
2681 }
2682
2683 #[test]
2684 fn resolve_value_env_empty_var_name_returns_none() {
2685 assert!(resolve_value("env:").is_none());
2686 }
2687
2688 #[test]
2689 fn resolve_value_file_empty_path_returns_none() {
2690 assert!(resolve_value("file:").is_none());
2691 }
2692
2693 #[test]
2694 fn resolve_value_file_missing_returns_none() {
2695 assert!(resolve_value("file:/nonexistent/path/key.txt").is_none());
2696 }
2697
2698 #[test]
2699 fn resolve_value_shell_echo() {
2700 let result = resolve_value("!echo hello");
2701 assert_eq!(result.as_deref(), Some("hello"));
2702 }
2703
2704 #[test]
2705 fn resolve_value_shell_failing_command() {
2706 assert!(resolve_value("!false").is_none());
2707 }
2708
2709 #[test]
2712 fn resolve_headers_none_returns_empty() {
2713 assert!(resolve_headers(None).is_empty());
2714 }
2715
2716 #[test]
2717 fn resolve_headers_resolves_literal_values() {
2718 let mut headers = HashMap::new();
2719 headers.insert("x-key".to_string(), "literal-value".to_string());
2720 let resolved = resolve_headers(Some(&headers));
2721 assert_eq!(resolved.get("x-key").unwrap(), "literal-value");
2722 }
2723
2724 #[test]
2727 fn model_registry_get_available_returns_only_ready_models() {
2728 let (_dir, auth) = test_auth_storage();
2729 let registry = ModelRegistry::load(&auth, None);
2730 let available = registry.get_available();
2731 assert!(!available.is_empty());
2732 for entry in &available {
2733 assert!(
2734 model_entry_is_ready(entry),
2735 "all available models should be ready for use"
2736 );
2737 }
2738 }
2739
2740 #[test]
2741 fn model_registry_get_available_includes_keyless_models() {
2742 let dir = tempdir().expect("tempdir");
2743 let auth = AuthStorage::load(dir.path().join("auth.json")).expect("auth");
2744 let models_path = dir.path().join("models.json");
2745 let config = serde_json::json!({
2746 "providers": {
2747 "acme-local": {
2748 "baseUrl": "http://127.0.0.1:11434/v1",
2749 "api": "openai-completions",
2750 "authHeader": false,
2751 "models": [
2752 { "id": "dev-model", "name": "Dev Model", "reasoning": false }
2753 ]
2754 }
2755 }
2756 });
2757 std::fs::write(
2758 &models_path,
2759 serde_json::to_string(&config).expect("serialize models"),
2760 )
2761 .expect("write models.json");
2762
2763 let registry = ModelRegistry::load(&auth, Some(models_path));
2764 let available = registry.get_available();
2765 assert!(
2766 available
2767 .iter()
2768 .any(|entry| entry.model.provider == "acme-local" && entry.model.id == "dev-model"),
2769 "keyless models should be considered available"
2770 );
2771 }
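// Sketch: a keyless local provider only needs a base URL and authHeader: false
// to show up in get_available(), e.g. the config used above:
//
//   { "providers": { "acme-local": {
//       "baseUrl": "http://127.0.0.1:11434/v1",
//       "api": "openai-completions",
//       "authHeader": false,
//       "models": [ { "id": "dev-model", "name": "Dev Model", "reasoning": false } ] } } }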
2772
2773 #[test]
2774 fn model_registry_error_none_for_valid_load() {
2775 let (_dir, auth) = test_auth_storage();
2776 let registry = ModelRegistry::load(&auth, None);
2777 assert!(registry.error().is_none());
2778 }
2779
2780 #[test]
2781 fn model_registry_error_on_invalid_json() {
2782 let dir = tempdir().expect("tempdir");
2783 let auth = AuthStorage::load(dir.path().join("auth.json")).expect("auth");
2784 let models_path = dir.path().join("models.json");
2785 std::fs::write(&models_path, "not valid json").expect("write bad json");
2786 let registry = ModelRegistry::load(&auth, Some(models_path));
2787 assert!(registry.error().is_some());
2788 }
2789
2790 #[test]
2791 fn model_registry_load_missing_models_json_is_fine() {
2792 let dir = tempdir().expect("tempdir");
2793 let auth = AuthStorage::load(dir.path().join("auth.json")).expect("auth");
2794 let registry = ModelRegistry::load(&auth, Some(dir.path().join("nonexistent.json")));
2795 assert!(registry.error().is_none());
2796 }
2797
2798 #[test]
2801 fn default_models_path_joins_correctly() {
2802 let path = default_models_path(Path::new("/home/user/.pi"));
2803 assert_eq!(path, PathBuf::from("/home/user/.pi/models.json"));
2804 }
2805
2806 #[test]
2809 fn models_config_deserialize_camel_case() {
2810 let json = r#"{
2811 "providers": {
2812 "acme": {
2813 "baseUrl": "https://acme.com/v1",
2814 "apiKey": "env:ACME_KEY",
2815 "authHeader": true,
2816 "models": [{
2817 "id": "acme-1",
2818 "contextWindow": 32000,
2819 "maxTokens": 2048
2820 }]
2821 }
2822 }
2823 }"#;
2824 let config: ModelsConfig = serde_json::from_str(json).expect("parse");
2825 let acme = config.providers.get("acme").expect("acme provider");
2826 assert_eq!(acme.base_url.as_deref(), Some("https://acme.com/v1"));
2827 assert_eq!(acme.auth_header, Some(true));
2828 let model = &acme.models.as_ref().unwrap()[0];
2829 assert_eq!(model.context_window, Some(32000));
2830 assert_eq!(model.max_tokens, Some(2048));
2831 }
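// Sketch: field names in models.json are camelCase and map onto the snake_case
// struct fields, as exercised above, e.g.
//
//   "baseUrl"       -> base_url
//   "apiKey"        -> api_key
//   "authHeader"    -> auth_header
//   "contextWindow" -> context_window
//   "maxTokens"     -> max_tokens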
2832
2833 #[test]
2834 fn models_config_empty_providers_ok() {
2835 let json = r#"{"providers": {}}"#;
2836 let config: ModelsConfig = serde_json::from_str(json).expect("parse");
2837 assert!(config.providers.is_empty());
2838 }
2839
2840 #[test]
2841 fn compat_config_deserialize() {
2842 let json = r#"{
2843 "supportsStore": true,
2844 "supportsDeveloperRole": false,
2845 "supportsReasoningEffort": true,
2846 "supportsUsageInStreaming": false,
2847 "maxTokensField": "max_completion_tokens"
2848 }"#;
2849 let compat: CompatConfig = serde_json::from_str(json).expect("parse");
2850 assert_eq!(compat.supports_store, Some(true));
2851 assert_eq!(compat.supports_developer_role, Some(false));
2852 assert_eq!(compat.supports_reasoning_effort, Some(true));
2853 assert_eq!(compat.supports_usage_in_streaming, Some(false));
2854 assert_eq!(
2855 compat.max_tokens_field.as_deref(),
2856 Some("max_completion_tokens")
2857 );
2858 }
2859
2860 #[test]
2861 fn compat_config_deserialize_all_fields() {
2862 let json = r#"{
2863 "supportsStore": true,
2864 "supportsDeveloperRole": true,
2865 "supportsReasoningEffort": false,
2866 "supportsUsageInStreaming": false,
2867 "supportsTools": false,
2868 "supportsStreaming": true,
2869 "supportsParallelToolCalls": false,
2870 "maxTokensField": "max_completion_tokens",
2871 "systemRoleName": "developer",
2872 "stopReasonField": "finish_reason",
2873 "customHeaders": {"X-Region": "us-east-1", "X-Tag": "override"},
2874 "openRouterRouting": {"order": ["fallback"]},
2875 "vercelGatewayRouting": {"priority": 1}
2876 }"#;
2877 let compat: CompatConfig = serde_json::from_str(json).expect("parse");
2878 assert_eq!(compat.supports_tools, Some(false));
2879 assert_eq!(compat.supports_streaming, Some(true));
2880 assert_eq!(compat.supports_parallel_tool_calls, Some(false));
2881 assert_eq!(compat.system_role_name.as_deref(), Some("developer"));
2882 assert_eq!(compat.stop_reason_field.as_deref(), Some("finish_reason"));
2883 let custom = compat.custom_headers.as_ref().expect("custom_headers");
2884 assert_eq!(
2885 custom.get("X-Region").map(String::as_str),
2886 Some("us-east-1")
2887 );
2888 assert_eq!(custom.get("X-Tag").map(String::as_str), Some("override"));
2889 assert!(compat.open_router_routing.is_some());
2890 assert!(compat.vercel_gateway_routing.is_some());
2891 }
2892
2893 #[test]
2894 fn compat_config_default_all_none() {
2895 let compat = CompatConfig::default();
2896 assert!(compat.supports_store.is_none());
2897 assert!(compat.supports_tools.is_none());
2898 assert!(compat.supports_streaming.is_none());
2899 assert!(compat.max_tokens_field.is_none());
2900 assert!(compat.system_role_name.is_none());
2901 assert!(compat.stop_reason_field.is_none());
2902 assert!(compat.custom_headers.is_none());
2903 }
2904
2905 #[test]
2906 fn compat_config_deserialize_empty_object() {
2907 let compat: CompatConfig = serde_json::from_str("{}").expect("parse");
2908 assert!(compat.supports_store.is_none());
2909 assert!(compat.supports_tools.is_none());
2910 assert!(compat.custom_headers.is_none());
2911 }
2912
2913 #[test]
2916 fn apply_custom_models_replaces_built_in_when_models_specified() {
2917 let (_dir, auth) = test_auth_storage();
2918 let mut models = built_in_models(&auth, ModelRegistryLoadMode::Full);
2919 let anthropic_before = models
2920 .iter()
2921 .filter(|m| m.model.provider == "anthropic")
2922 .count();
2923 assert!(anthropic_before > 0);
2924
2925 let config = ModelsConfig {
2926 providers: HashMap::from([(
2927 "anthropic".to_string(),
2928 ProviderConfig {
2929 base_url: Some("https://proxy.example/v1".to_string()),
2930 api: Some("anthropic-messages".to_string()),
2931 models: Some(vec![ModelConfig {
2932 id: "custom-claude".to_string(),
2933 name: Some("Custom Claude".to_string()),
2934 ..ModelConfig::default()
2935 }]),
2936 ..ProviderConfig::default()
2937 },
2938 )]),
2939 };
2940
2941 apply_custom_models(&auth, &mut models, &config);
2942
2943 let anthropic_after: Vec<_> = models
2945 .iter()
2946 .filter(|m| m.model.provider == "anthropic")
2947 .collect();
2948 assert_eq!(anthropic_after.len(), 1);
2949 assert_eq!(anthropic_after[0].model.id, "custom-claude");
2950 }
2951
2952 #[test]
2953 fn apply_custom_models_alias_replaces_canonical_built_ins_when_models_specified() {
2954 let (_dir, auth) = test_auth_storage();
2955 let mut models = built_in_models(&auth, ModelRegistryLoadMode::Full);
2956 let google_before = models
2957 .iter()
2958 .filter(|m| m.model.provider == "google")
2959 .count();
2960 assert!(google_before > 0);
2961
2962 let config = ModelsConfig {
2963 providers: HashMap::from([(
2964 "gemini".to_string(),
2965 ProviderConfig {
2966 models: Some(vec![ModelConfig {
2967 id: "gemini-custom".to_string(),
2968 name: Some("Gemini Custom".to_string()),
2969 ..ModelConfig::default()
2970 }]),
2971 ..ProviderConfig::default()
2972 },
2973 )]),
2974 };
2975
2976 apply_custom_models(&auth, &mut models, &config);
2977
2978 assert!(
2979 !models.iter().any(|m| m.model.provider == "google"),
2980 "canonical google built-ins should be replaced when alias config provides explicit models"
2981 );
2982 let gemini_models: Vec<_> = models
2983 .iter()
2984 .filter(|m| m.model.provider == "gemini")
2985 .collect();
2986 assert_eq!(gemini_models.len(), 1);
2987 assert_eq!(gemini_models[0].model.id, "gemini-custom");
2988 }
2989
2990 #[test]
2991 fn apply_custom_models_alias_override_without_models_updates_canonical_provider_models() {
2992 let (_dir, auth) = test_auth_storage();
2993 let mut models = built_in_models(&auth, ModelRegistryLoadMode::Full);
2994 let google_before = models
2995 .iter()
2996 .filter(|m| m.model.provider == "google")
2997 .count();
2998 assert!(google_before > 0);
2999
3000 let config = ModelsConfig {
3001 providers: HashMap::from([(
3002 "gemini".to_string(),
3003 ProviderConfig {
3004 base_url: Some("https://proxy.example/v1".to_string()),
3005 api: Some("google-generative-ai".to_string()),
3006 auth_header: Some(true),
3007 ..ProviderConfig::default()
3008 },
3009 )]),
3010 };
3011
3012 apply_custom_models(&auth, &mut models, &config);
3013
3014 let google_after: Vec<_> = models
3015 .iter()
3016 .filter(|m| m.model.provider == "google")
3017 .collect();
3018 assert_eq!(google_after.len(), google_before);
3019 assert!(
3020 google_after
3021 .iter()
3022 .all(|m| m.model.base_url == "https://proxy.example/v1")
3023 );
3024 assert!(
3025 google_after
3026 .iter()
3027 .all(|m| m.model.api == "google-generative-ai")
3028 );
3029 assert!(google_after.iter().all(|m| m.auth_header));
3030 }
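// Sketch, contrasting with the two replacement tests above: when an alias entry
// ("gemini") carries only provider-level overrides and no "models" list, the
// canonical provider's built-ins are kept and rewritten in place (base URL, API,
// auth header) instead of being replaced, e.g.
//
//   { "providers": { "gemini": {
//       "baseUrl": "https://proxy.example/v1",
//       "api": "google-generative-ai",
//       "authHeader": true } } }   // no "models" key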
3031
3032 #[test]
3033 fn model_registry_find_canonical_provider_matches_alias_backed_custom_model() {
3034 let (_dir, auth) = test_auth_storage();
3035 let mut models = Vec::new();
3036 let config = ModelsConfig {
3037 providers: HashMap::from([(
3038 "gemini".to_string(),
3039 ProviderConfig {
3040 models: Some(vec![ModelConfig {
3041 id: "gemini-custom-find".to_string(),
3042 ..ModelConfig::default()
3043 }]),
3044 ..ProviderConfig::default()
3045 },
3046 )]),
3047 };
3048
3049 apply_custom_models(&auth, &mut models, &config);
3050 let registry = ModelRegistry {
3051 models,
3052 error: None,
3053 };
3054
3055 assert!(
3056 registry.find("gemini", "gemini-custom-find").is_some(),
3057 "alias lookup should resolve"
3058 );
3059 assert!(
3060 registry.find("google", "gemini-custom-find").is_some(),
3061 "canonical provider lookup should also match alias-backed model"
3062 );
3063 }
3064
3065 #[test]
3068 fn oauth_config_fields() {
3069 let config = OAuthConfig {
3070 auth_url: "https://auth.example.com/authorize".to_string(),
3071 token_url: "https://auth.example.com/token".to_string(),
3072 client_id: "client-123".to_string(),
3073 scopes: vec!["read".to_string(), "write".to_string()],
3074 redirect_uri: Some("http://localhost:8080/callback".to_string()),
3075 };
3076 assert_eq!(config.client_id, "client-123");
3077 assert_eq!(config.scopes.len(), 2);
3078 assert!(config.redirect_uri.is_some());
3079 }
3080
3081 #[test]
3084 fn built_in_anthropic_models_use_correct_api() {
3085 let (_dir, auth) = test_auth_storage();
3086 let models = built_in_models(&auth, ModelRegistryLoadMode::Full);
3087 for m in models.iter().filter(|m| m.model.provider == "anthropic") {
3088 assert_eq!(m.model.api, "anthropic-messages");
3089 assert!(!m.auth_header, "anthropic uses x-api-key, not auth header");
3090 assert!(
3091 m.model.context_window >= 200_000,
3092 "anthropic model {} should expose a modern context window",
3093 m.model.id
3094 );
3095 }
3096 }
3097
3098 #[test]
3099 fn built_in_openai_models_use_auth_header() {
3100 let (_dir, auth) = test_auth_storage();
3101 let models = built_in_models(&auth, ModelRegistryLoadMode::Full);
3102 for m in models.iter().filter(|m| m.model.provider == "openai") {
3103 assert!(m.auth_header, "openai uses Authorization header");
3104 assert_eq!(m.model.api, "openai-responses");
3105 }
3106 }
3107
3108 #[test]
3109 fn built_in_google_models_no_auth_header() {
3110 let (_dir, auth) = test_auth_storage();
3111 let models = built_in_models(&auth, ModelRegistryLoadMode::Full);
3112 for m in models.iter().filter(|m| m.model.provider == "google") {
3113 assert!(!m.auth_header, "google uses api key in URL, not header");
3114 assert_eq!(m.model.api, "google-generative-ai");
3115 }
3116 }
3117
3118 #[test]
3119 fn built_in_reasoning_models_marked_correctly() {
3120 let (_dir, auth) = test_auth_storage();
3121 let models = built_in_models(&auth, ModelRegistryLoadMode::Full);
3122 for m in models
3124 .iter()
3125 .filter(|m| m.model.id.contains("3-5-haiku-20241022"))
3126 {
3127 assert!(!m.model.reasoning, "{} should be non-reasoning", m.model.id);
3128 }
3129 let anthropic_opus_sonnet = models
3130 .iter()
3131 .filter(|m| {
3132 m.model.provider == "anthropic"
3133 && (m.model.id.contains("opus") || m.model.id.contains("sonnet"))
3134 })
3135 .collect::<Vec<_>>();
3136 assert!(
3137 !anthropic_opus_sonnet.is_empty(),
3138 "expected anthropic opus/sonnet models in built-ins"
3139 );
3140 assert!(
3141 anthropic_opus_sonnet.iter().any(|m| m.model.reasoning),
3142 "expected at least one reasoning anthropic opus/sonnet model"
3143 );
3144
3145 for m in anthropic_opus_sonnet
3147 .iter()
3148 .filter(|m| m.model.id.contains("opus-4") || m.model.id.contains("sonnet-4"))
3149 {
3150 assert!(m.model.reasoning, "{} should be reasoning", m.model.id);
3151 }
3152 }
3153
3154 mod proptest_models {
3155 use super::*;
3156 use proptest::prelude::*;
3157
3158 fn dummy_model(id: &str, reasoning: bool) -> ModelEntry {
3159 ModelEntry {
3160 model: Model {
3161 id: id.to_string(),
3162 name: id.to_string(),
3163 provider: "test".to_string(),
3164 api: "messages".to_string(),
3165 base_url: String::new(),
3166 reasoning,
3167 input: vec![InputType::Text],
3168 context_window: 128_000,
3169 max_tokens: 4096,
3170 cost: ModelCost {
3171 input: 0.0,
3172 output: 0.0,
3173 cache_read: 0.0,
3174 cache_write: 0.0,
3175 },
3176 headers: HashMap::new(),
3177 },
3178 api_key: None,
3179 headers: HashMap::new(),
3180 auth_header: false,
3181 compat: None,
3182 oauth_config: None,
3183 }
3184 }
3185
3186 proptest! {
3187 #[test]
3189 fn clamp_thinking_non_reasoning(level_idx in 0..6usize) {
3190 use crate::model::ThinkingLevel;
3191 let levels = [
3192 ThinkingLevel::Off,
3193 ThinkingLevel::Minimal,
3194 ThinkingLevel::Low,
3195 ThinkingLevel::Medium,
3196 ThinkingLevel::High,
3197 ThinkingLevel::XHigh,
3198 ];
3199 let entry = dummy_model("non-reasoning-model", false);
3200 assert_eq!(entry.clamp_thinking_level(levels[level_idx]), ThinkingLevel::Off);
3201 }
3202
3203 #[test]
3205 fn clamp_thinking_reasoning_no_xhigh(level_idx in 0..6usize) {
3206 use crate::model::ThinkingLevel;
3207 let levels = [
3208 ThinkingLevel::Off,
3209 ThinkingLevel::Minimal,
3210 ThinkingLevel::Low,
3211 ThinkingLevel::Medium,
3212 ThinkingLevel::High,
3213 ThinkingLevel::XHigh,
3214 ];
3215 let entry = dummy_model("claude-sonnet-4-5", true);
3216 let result = entry.clamp_thinking_level(levels[level_idx]);
3217 if levels[level_idx] == ThinkingLevel::XHigh {
3218 assert_eq!(result, ThinkingLevel::High);
3219 } else {
3220 assert_eq!(result, levels[level_idx]);
3221 }
3222 }
3223
3224 #[test]
3226 fn supports_xhigh_specific_ids(id in "[a-z\\-0-9]{5,20}") {
3227 let entry = dummy_model(&id, true);
3228 let expected = matches!(
3229 id.as_str(),
3230 "gpt-5.1-codex-max"
3231 | "gpt-5.2"
3232 | "gpt-5.2-codex"
3233 | "gpt-5.3-codex"
3234 | "gpt-5.3-codex-spark"
3235 );
3236 assert_eq!(entry.supports_xhigh(), expected);
3237 }
3238
3239 #[test]
3241 fn openrouter_known_aliases(idx in 0..5usize) {
3242 let pairs = [
3243 ("auto", "openrouter/auto"),
3244 ("gpt-4o-mini", "openai/gpt-4o-mini"),
3245 ("gpt-4o", "openai/gpt-4o"),
3246 ("claude-3.5-sonnet", "anthropic/claude-3.5-sonnet"),
3247 ("gemini-2.5-pro", "google/gemini-2.5-pro"),
3248 ];
3249 let (input, expected) = pairs[idx];
3250 assert_eq!(canonicalize_openrouter_model_id(input), expected);
3251 }
3252
3253 #[test]
3255 fn openrouter_case_insensitive(idx in 0..5usize) {
3256 let aliases = ["auto", "gpt-4o-mini", "gpt-4o", "claude-3.5-sonnet", "gemini-2.5-pro"];
3257 let lower = canonicalize_openrouter_model_id(aliases[idx]);
3258 let upper = canonicalize_openrouter_model_id(&aliases[idx].to_uppercase());
3259 assert_eq!(lower, upper);
3260 }
3261
3262 #[test]
3264 fn openrouter_passthrough(id in "[a-z]/[a-z]{5,15}") {
3265 let result = canonicalize_openrouter_model_id(&id);
3266 assert_eq!(result, id);
3267 }
3268
3269 #[test]
3271 fn openrouter_lookup_includes_canonical(id in "[a-z\\-0-9]{1,20}") {
3272 let ids = openrouter_model_lookup_ids(&id);
3273 let canonical = canonicalize_openrouter_model_id(&id);
3274 assert!(ids.contains(&canonical));
3275 }
3276
3277 #[test]
3279 fn merge_headers_override_wins(key in "[a-z]{1,5}", v1 in "[a-z]{1,5}", v2 in "[a-z]{1,5}") {
3280 let base = HashMap::from([(key.clone(), v1)]);
3281 let over = HashMap::from([(key.clone(), v2.clone())]);
3282 let merged = merge_headers(&base, over);
3283 assert_eq!(merged.get(&key).unwrap(), &v2);
3284 }
3285
3286 #[test]
3288 fn merge_headers_preserves_both(k1 in "[a-z]{1,5}", k2 in "[A-Z]{1,5}", v1 in "[a-z]{1,5}", v2 in "[a-z]{1,5}") {
3289 let base = HashMap::from([(k1.clone(), v1.clone())]);
3290 let over = HashMap::from([(k2.clone(), v2.clone())]);
3291 let merged = merge_headers(&base, over);
3292 assert_eq!(merged.get(&k1), Some(&v1));
3293 assert_eq!(merged.get(&k2), Some(&v2));
3294 }
3295
3296 #[test]
3298 fn sap_endpoint_rejects_empty(s in "[a-z]{0,10}") {
3299 assert_eq!(sap_chat_completions_endpoint("", &s), None);
3300 assert_eq!(sap_chat_completions_endpoint(&s, ""), None);
3301 assert_eq!(sap_chat_completions_endpoint(" ", &s), None);
3302 }
3303
3304 #[test]
3306 fn sap_endpoint_format(base in "[a-z]{3,10}", deployment in "[a-z]{3,10}") {
3307 let url = format!("https://{base}.example.com");
3308 let result = sap_chat_completions_endpoint(&url, &deployment);
3309 assert!(result.is_some());
3310 let endpoint = result.unwrap();
3311 assert!(endpoint.contains(&deployment));
3312 assert!(endpoint.contains("/v2/inference/deployments/"));
3313 assert!(endpoint.ends_with("/chat/completions"));
3314 }
3315
3316 #[test]
3318 fn sap_endpoint_strips_trailing_slash(base in "[a-z]{5,10}") {
3319 let url_no_slash = format!("https://{base}");
3320 let url_slash = format!("https://{base}/");
3321 let r1 = sap_chat_completions_endpoint(&url_no_slash, "model");
3322 let r2 = sap_chat_completions_endpoint(&url_slash, "model");
3323 assert_eq!(r1, r2);
3324 }
3325 }
3326 }
3327}