codetether_agent/autochat/
model_rotation.rs1use crate::provider::{ModelInfo, Provider, ProviderRegistry};
2use serde::{Deserialize, Serialize};
3use std::collections::HashSet;
4use std::sync::Arc;
5
/// Providers polled, in this priority order, when assembling a round-robin
/// rotation in `build_round_robin_model_rotation`. At most one model per
/// provider ends up in the rotation.
const PRIORITY_PROVIDERS: [&str; 8] = [
    "minimax",
    "zai",
    "github-copilot",
    "github-copilot-enterprise",
    "openai-codex",
    "openrouter",
    "minimax-credits",
    "openai",
];
16
/// Providers tried, in order, when resolving a bare model name (no
/// "provider/" prefix) in `resolve_provider_for_model_autochat`.
/// Superset of `PRIORITY_PROVIDERS` with additional last-resort providers.
const RESOLUTION_FALLBACKS: [&str; 12] = [
    "minimax",
    "zai",
    "github-copilot",
    "github-copilot-enterprise",
    "openai-codex",
    "openrouter",
    "minimax-credits",
    "openai",
    "anthropic",
    "novita",
    "moonshotai",
    "google",
];
31
/// Serializable round-robin state over fully-qualified "provider/model" refs.
///
/// Both fields default when absent so older persisted state still
/// deserializes.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RelayModelRotation {
    // Fully-qualified model refs ("provider/model"); at most one per provider
    // when built via `build_round_robin_model_rotation`.
    #[serde(default)]
    pub model_refs: Vec<String>,
    // Position of the next ref to hand out; interpreted modulo
    // `model_refs.len()` by `next_model_ref`.
    #[serde(default)]
    pub cursor: usize,
}
39
40impl RelayModelRotation {
41 pub fn fallback(model_ref: &str) -> Self {
42 Self {
43 model_refs: vec![model_ref.to_string()],
44 cursor: 0,
45 }
46 }
47
48 pub fn next_model_ref(&mut self, fallback: &str) -> String {
49 if self.model_refs.is_empty() {
50 return fallback.to_string();
51 }
52 let idx = self.cursor % self.model_refs.len();
53 let next = self.model_refs[idx].clone();
54 self.cursor = self.cursor.saturating_add(1);
55 next
56 }
57}
58
59pub async fn build_round_robin_model_rotation(
60 registry: &Arc<ProviderRegistry>,
61 requested_model_ref: &str,
62) -> RelayModelRotation {
63 let mut rotation = RelayModelRotation::default();
64 let mut used_providers = HashSet::<String>::new();
65
66 if let Some(model_ref) = normalize_requested_model_ref(registry, requested_model_ref) {
67 add_unique_model_ref(&mut rotation.model_refs, &mut used_providers, &model_ref);
68 }
69
70 for provider_name in PRIORITY_PROVIDERS {
71 if let Some(model_ref) = preferred_model_ref_for_provider(registry, provider_name).await {
72 add_unique_model_ref(&mut rotation.model_refs, &mut used_providers, &model_ref);
73 }
74 }
75
76 if rotation.model_refs.is_empty() {
77 rotation.model_refs.push(requested_model_ref.to_string());
78 }
79
80 rotation
81}
82
83pub fn resolve_provider_for_model_autochat(
84 registry: &Arc<ProviderRegistry>,
85 model_ref: &str,
86) -> Option<(Arc<dyn Provider>, String)> {
87 let (provider_name, model_name) = crate::provider::parse_model_string(model_ref);
88 if let Some(provider_name) = provider_name {
89 let normalized = normalize_provider_name(provider_name);
90 return registry
91 .get(normalized)
92 .map(|provider| (provider, model_name.to_string()));
93 }
94
95 for provider_name in RESOLUTION_FALLBACKS {
96 if let Some(provider) = registry.get(provider_name) {
97 return Some((provider, model_ref.to_string()));
98 }
99 }
100
101 registry
102 .list()
103 .first()
104 .copied()
105 .and_then(|name| registry.get(name))
106 .map(|provider| (provider, model_ref.to_string()))
107}
108
109fn normalize_requested_model_ref(
110 registry: &Arc<ProviderRegistry>,
111 requested_model_ref: &str,
112) -> Option<String> {
113 let (provider_name, model_name) = crate::provider::parse_model_string(requested_model_ref);
114 if let Some(provider_name) = provider_name {
115 let normalized = normalize_provider_name(provider_name);
116 if registry.get(normalized).is_some() {
117 return Some(format!("{normalized}/{model_name}"));
118 }
119 }
120
121 resolve_provider_for_model_autochat(registry, requested_model_ref)
122 .map(|(provider, model)| format!("{}/{}", provider.name(), model))
123}
124
125fn add_unique_model_ref(
126 model_refs: &mut Vec<String>,
127 used_providers: &mut HashSet<String>,
128 model_ref: &str,
129) {
130 let (provider_name, _) = crate::provider::parse_model_string(model_ref);
131 let Some(provider_name) = provider_name else {
132 return;
133 };
134 if used_providers.insert(provider_name.to_string()) {
135 model_refs.push(model_ref.to_string());
136 }
137}
138
139async fn preferred_model_ref_for_provider(
140 registry: &Arc<ProviderRegistry>,
141 provider_name: &str,
142) -> Option<String> {
143 let provider = registry.get(provider_name)?;
144 let listed = provider.list_models().await.unwrap_or_default();
145 if let Some(model_id) = choose_model_from_list(provider_name, &listed) {
146 return Some(format!("{provider_name}/{model_id}"));
147 }
148 default_model_for_provider(provider_name).map(|model_id| format!("{provider_name}/{model_id}"))
149}
150
151fn choose_model_from_list(provider_name: &str, models: &[ModelInfo]) -> Option<String> {
152 if models.is_empty() {
153 return None;
154 }
155
156 if provider_name == "openrouter" {
157 return choose_openrouter_model(models).or_else(|| models.first().map(|m| m.id.clone()));
158 }
159
160 if provider_name == "github-copilot" || provider_name == "github-copilot-enterprise" {
161 return choose_copilot_model(models).or_else(|| models.first().map(|m| m.id.clone()));
162 }
163
164 let preferred = preferred_models_for_provider(provider_name);
165 for model_id in preferred {
166 if let Some(found) = models
167 .iter()
168 .find(|model| model.id.eq_ignore_ascii_case(model_id))
169 {
170 return Some(found.id.clone());
171 }
172 }
173
174 models.first().map(|m| m.id.clone())
175}
176
177fn choose_openrouter_model(models: &[ModelInfo]) -> Option<String> {
178 let free_models: Vec<&ModelInfo> = models
179 .iter()
180 .filter(|model| is_openrouter_free_model(&model.id))
181 .collect();
182 if !free_models.is_empty() {
183 return free_models
184 .into_iter()
185 .max_by_key(|model| score_openrouter_model(&model.id))
186 .map(|model| model.id.clone());
187 }
188 models
189 .iter()
190 .max_by_key(|model| score_openrouter_model(&model.id))
191 .map(|model| model.id.clone())
192}
193
194fn choose_copilot_model(models: &[ModelInfo]) -> Option<String> {
195 let preferred = ["gpt-5-mini", "gpt-4.1", "gpt-4o"];
196 for model_id in preferred {
197 if let Some(found) = models
198 .iter()
199 .find(|model| model.id.eq_ignore_ascii_case(model_id))
200 {
201 return Some(found.id.clone());
202 }
203 }
204
205 models
206 .iter()
207 .find(|model| {
208 model.input_cost_per_million == Some(0.0) && model.output_cost_per_million == Some(0.0)
209 })
210 .map(|model| model.id.clone())
211}
212
/// True when the model id looks like an OpenRouter free tier: a ":free"
/// variant tag anywhere in the id, or a "-free" suffix (case-insensitive).
fn is_openrouter_free_model(model_id: &str) -> bool {
    let lowered = model_id.to_ascii_lowercase();
    lowered.ends_with("-free") || lowered.contains(":free")
}
217
/// Scores an OpenRouter model id for selection; higher is better.
/// Free-tier models dominate every other signal, then keyword preferences
/// are summed (strongest first). All matching is case-insensitive.
fn score_openrouter_model(model_id: &str) -> i32 {
    let id = model_id.to_ascii_lowercase();

    // Free markers: a ":free" variant tag or a "-free" suffix.
    let free_bonus = if id.contains(":free") || id.ends_with("-free") {
        1000
    } else {
        0
    };

    const KEYWORD_POINTS: [(&str, i32); 4] = [
        ("glm-5", 250),
        ("minimax", 220),
        ("gpt-5-mini", 180),
        ("coder", 70),
    ];
    let keyword_bonus: i32 = KEYWORD_POINTS
        .iter()
        .filter(|(keyword, _)| id.contains(keyword))
        .map(|(_, points)| *points)
        .sum();

    free_bonus + keyword_bonus
}
238
/// Hand-curated model preference lists per provider; the first entry doubles
/// as the provider's default model. Unknown providers get an empty slice.
fn preferred_models_for_provider(provider_name: &str) -> &'static [&'static str] {
    match provider_name {
        "github-copilot" | "github-copilot-enterprise" => &["gpt-5-mini", "gpt-4.1", "gpt-4o"],
        "minimax" => &["MiniMax-M2.5", "MiniMax-M2.1", "MiniMax-M2"],
        "minimax-credits" => &["MiniMax-M2.5-highspeed", "MiniMax-M2.1-highspeed"],
        "openai-codex" => &["gpt-5-mini", "gpt-5", "gpt-5.1-codex"],
        "openrouter" => &["z-ai/glm-5:free", "z-ai/glm-5", "z-ai/glm-4.7:free"],
        "zai" => &["glm-5", "glm-4.7", "glm-4.7-flash"],
        _ => &[],
    }
}
250
251fn default_model_for_provider(provider_name: &str) -> Option<&'static str> {
252 let defaults = preferred_models_for_provider(provider_name);
253 defaults.first().copied()
254}
255
/// Maps external provider spellings to their internal name: "zhipuai" and
/// "z-ai" (any case) both mean "zai"; everything else passes through as-is.
fn normalize_provider_name(provider_name: &str) -> &str {
    let is_zai_alias = ["zhipuai", "z-ai"]
        .iter()
        .any(|alias| provider_name.eq_ignore_ascii_case(alias));
    if is_zai_alias { "zai" } else { provider_name }
}
263
#[cfg(test)]
mod tests {
    use super::{
        RelayModelRotation, choose_copilot_model, choose_openrouter_model, normalize_provider_name,
    };
    use crate::provider::ModelInfo;

    // The cursor wraps around the ref list, so the fallback is never returned
    // while refs are present.
    #[test]
    fn relay_rotation_cycles_models() {
        let mut rotation = RelayModelRotation {
            model_refs: vec!["a/x".to_string(), "b/y".to_string()],
            cursor: 0,
        };
        assert_eq!(rotation.next_model_ref("fallback"), "a/x");
        assert_eq!(rotation.next_model_ref("fallback"), "b/y");
        assert_eq!(rotation.next_model_ref("fallback"), "a/x");
    }

    // Both external spellings collapse to "zai"; other names pass through.
    #[test]
    fn provider_alias_normalizes_zhipu_and_z_ai() {
        assert_eq!(normalize_provider_name("zhipuai"), "zai");
        assert_eq!(normalize_provider_name("z-ai"), "zai");
        assert_eq!(normalize_provider_name("openrouter"), "openrouter");
    }

    // Among free models, the glm-5 keyword bonus beats other free entries.
    #[test]
    fn openrouter_prefers_free_glm_models() {
        let models = vec![
            model("openai/gpt-4.1"),
            model("z-ai/glm-5:free"),
            model("moonshot/kimi-k2:free"),
        ];
        assert_eq!(
            choose_openrouter_model(&models).as_deref(),
            Some("z-ai/glm-5:free")
        );
    }

    // The first entry of the Copilot preference list wins when advertised.
    #[test]
    fn copilot_prefers_gpt_5_mini_when_available() {
        let models = vec![
            model("gpt-4o"),
            model("gpt-5-mini"),
            model("claude-sonnet-4"),
        ];
        assert_eq!(choose_copilot_model(&models).as_deref(), Some("gpt-5-mini"));
    }

    // Minimal ModelInfo fixture; costs are zero so the Copilot free-model
    // fallback path is also exercisable.
    fn model(id: &str) -> ModelInfo {
        ModelInfo {
            id: id.to_string(),
            name: id.to_string(),
            provider: "test".to_string(),
            context_window: 128_000,
            max_output_tokens: Some(16_384),
            supports_vision: false,
            supports_tools: true,
            supports_streaming: true,
            input_cost_per_million: Some(0.0),
            output_cost_per_million: Some(0.0),
        }
    }
}