1use std::borrow::Cow;
2
3use crate::{client::Provider, transport::TransportError, types::AiLibError};
4
5use super::catalog;
6
/// Stateless helper that decides which model name to send to a provider and
/// what to try next when the provider rejects a model.
///
/// Carries no data, so it is `Copy` and `Default`; `ModelResolver::new()` and
/// `ModelResolver::default()` are interchangeable.
#[derive(Debug, Clone, Copy, Default)]
pub struct ModelResolver;
9
10#[derive(Debug, Clone)]
11pub struct ModelResolution {
12 pub model: String,
13 pub source: ModelResolutionSource,
14 pub doc_url: &'static str,
15}
16
17impl ModelResolution {
18 pub fn new(
19 model: impl Into<String>,
20 source: ModelResolutionSource,
21 doc_url: &'static str,
22 ) -> Self {
23 Self {
24 model: model.into(),
25 source,
26 doc_url,
27 }
28 }
29}
30
/// Records which rule chose the model in a `ModelResolution`.
///
/// Field-less, so it is `Copy` and can be compared with `==` or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelResolutionSource {
    /// The caller named the model explicitly.
    Explicit,
    /// NOTE(review): not produced by `ModelResolver` itself; presumably set by
    /// callers that configure their own default model — confirm with callers.
    CustomDefault,
    /// Taken from the `<PREFIX>_MODEL` environment variable for the provider.
    EnvOverride,
    /// The provider profile's default chat model.
    ProviderDefault,
    /// One of the provider profile's fallback models.
    ProfileFallback,
}
39
40impl ModelResolver {
41 pub fn new() -> Self {
42 Self
43 }
44
45 pub fn resolve_chat_model(
46 &self,
47 provider: Provider,
48 requested: Option<&str>,
49 ) -> ModelResolution {
50 let profile = catalog::profile(provider);
51 let doc_url = profile.doc_url;
52 if let Some(model) = requested {
53 return ModelResolution::new(
54 model.to_string(),
55 ModelResolutionSource::Explicit,
56 doc_url,
57 );
58 }
59
60 if let Some(env_model) = self.env_override(provider) {
61 return ModelResolution::new(env_model, ModelResolutionSource::EnvOverride, doc_url);
62 }
63
64 ModelResolution::new(
65 profile.default_chat_model(),
66 ModelResolutionSource::ProviderDefault,
67 doc_url,
68 )
69 }
70
71 pub fn fallback_after_invalid(
72 &self,
73 provider: Provider,
74 failed_model: &str,
75 ) -> Option<ModelResolution> {
76 let profile = catalog::profile(provider);
77 let doc_url = profile.doc_url;
78
79 if let Some(env_model) = self.env_override(provider) {
80 if !equals_ignore_case(&env_model, failed_model) {
81 return Some(ModelResolution::new(
82 env_model,
83 ModelResolutionSource::EnvOverride,
84 doc_url,
85 ));
86 }
87 }
88
89 for candidate in profile.fallback_models {
90 if !equals_ignore_case(candidate, failed_model) {
91 return Some(ModelResolution::new(
92 *candidate,
93 ModelResolutionSource::ProfileFallback,
94 doc_url,
95 ));
96 }
97 }
98
99 if !equals_ignore_case(profile.default_chat_model(), failed_model) {
100 return Some(ModelResolution::new(
101 profile.default_chat_model(),
102 ModelResolutionSource::ProviderDefault,
103 doc_url,
104 ));
105 }
106
107 None
108 }
109
110 pub fn doc_url(&self, provider: Provider) -> &'static str {
111 catalog::profile(provider).doc_url
112 }
113
114 pub fn suggestions(&self, provider: Provider) -> Vec<String> {
115 let mut list = Vec::new();
116
117 if let Some(env_override) = self.env_override(provider) {
118 push_unique(&mut list, env_override);
119 }
120
121 let profile = catalog::profile(provider);
122 push_unique(&mut list, profile.default_chat_model().to_string());
123 for candidate in profile.fallback_models {
124 push_unique(&mut list, (*candidate).to_string());
125 }
126
127 list
128 }
129
130 pub fn looks_like_invalid_model(&self, err: &AiLibError) -> bool {
131 match err {
132 AiLibError::ModelNotFound(_) => true,
133 AiLibError::InvalidRequest(msg)
134 | AiLibError::ProviderError(msg)
135 | AiLibError::InvalidModelResponse(msg) => contains_invalid_keyword(msg),
136 AiLibError::TransportError(TransportError::ClientError { status, message })
137 | AiLibError::TransportError(TransportError::ServerError { status, message }) => {
138 (*status == 400 || *status == 404) && contains_invalid_keyword(message)
139 }
140 _ => false,
141 }
142 }
143
144 pub fn decorate_invalid_model_error(
145 &self,
146 provider: Provider,
147 requested_model: &str,
148 err: AiLibError,
149 ) -> AiLibError {
150 let doc_url = self.doc_url(provider);
151 let suggestions = self.suggestions(provider);
152 let provider_name = format!("{provider:?}");
153 let suggestion_text = if suggestions.is_empty() {
154 Cow::Borrowed("no known fallback models configured")
155 } else {
156 Cow::Owned(suggestions.join(", "))
157 };
158
159 AiLibError::ModelNotFound(format!(
160 "Model `{}` is not available for provider {}. Try: {}. Docs: {}. Original error: {}",
161 requested_model, provider_name, suggestion_text, doc_url, err
162 ))
163 }
164
165 fn env_override(&self, provider: Provider) -> Option<String> {
166 let var = format!("{}_MODEL", provider.env_prefix());
167 std::env::var(var)
168 .ok()
169 .filter(|value| !value.trim().is_empty())
170 }
171}
172
173fn push_unique(list: &mut Vec<String>, value: String) {
174 if !list
175 .iter()
176 .any(|existing| equals_ignore_case(existing, &value))
177 {
178 list.push(value);
179 }
180}
181
/// Compares two strings for equality, folding ASCII letters only.
/// Non-ASCII characters must match byte-for-byte.
fn equals_ignore_case(a: &str, b: &str) -> bool {
    let (left, right) = (a.as_bytes(), b.as_bytes());
    left.len() == right.len()
        && left
            .iter()
            .zip(right)
            .all(|(x, y)| x.eq_ignore_ascii_case(y))
}
185
/// Reports whether a provider error message mentions an unknown or invalid
/// model. Matching is ASCII case-insensitive against a fixed set of phrases,
/// plus a literal `"code":"1500"` JSON fragment.
fn contains_invalid_keyword(message: &str) -> bool {
    const NEEDLES: [&str; 6] = [
        "invalid model",
        "model_not_found",
        "model not found",
        "unknown model",
        "unsupported model",
        "\"code\":\"1500\"",
    ];
    let haystack = message.to_ascii_lowercase();
    NEEDLES.iter().any(|needle| haystack.contains(needle))
}