codetether_agent/tui/app/state/
model_picker.rs1use std::sync::Arc;
4use std::time::Duration;
5
6use futures::future::join_all;
7use tokio::sync::mpsc;
8
9use crate::provider::{ModelInfo, ProviderRegistry};
10use crate::tui::models::ViewMode;
11
/// Upper bound on a single provider's `list_models` call; slower providers are counted as
/// timed out and skipped for this refresh.
const MODEL_LIST_TIMEOUT: Duration = Duration::from_millis(4_000);
13
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct ModelRefreshSummary {
16 pub provider_count: usize,
17 pub loaded_providers: usize,
18 pub loaded_models: usize,
19 pub failed_providers: usize,
20 pub timed_out_providers: usize,
21}
22
23impl ModelRefreshSummary {
24 fn empty() -> Self {
25 Self {
26 provider_count: 0,
27 loaded_providers: 0,
28 loaded_models: 0,
29 failed_providers: 0,
30 timed_out_providers: 0,
31 }
32 }
33
34 fn skipped_providers(&self) -> usize {
35 self.failed_providers + self.timed_out_providers
36 }
37
38 fn status_message(&self) -> String {
39 if self.loaded_models == 0 {
40 let skipped = self.skipped_providers();
41 if self.provider_count == 0 || skipped == 0 {
42 "No models available".to_string()
43 } else {
44 format!(
45 "No models available ({} provider{} skipped)",
46 skipped,
47 plural(skipped)
48 )
49 }
50 } else {
51 let mut status = format!(
52 "Model picker: {} models from {} provider{}",
53 self.loaded_models,
54 self.loaded_providers,
55 plural(self.loaded_providers)
56 );
57 let skipped = self.skipped_providers();
58 if skipped > 0 {
59 status.push_str(&format!(" ({} skipped)", skipped));
60 }
61 status
62 }
63 }
64}
65
/// Message sent from the background task spawned by `start_model_refresh` back to the
/// UI state, which consumes it in `drain_model_refresh`.
#[derive(Debug)]
pub enum ModelRefreshEvent {
    /// The refresh finished. `models` holds the sorted, deduplicated `provider/model`
    /// references and `summary` the per-provider outcome counts.
    Loaded {
        models: Vec<String>,
        summary: ModelRefreshSummary,
    },
}
73
74impl super::AppState {
75 pub fn start_model_refresh(&mut self, registry: Arc<ProviderRegistry>) {
76 let (tx, rx) = mpsc::unbounded_channel();
77 self.model_refresh_rx = Some(rx);
78 self.model_refresh_in_flight = true;
79
80 tokio::spawn(async move {
81 let (models, summary) = load_available_models(®istry).await;
82 let _ = tx.send(ModelRefreshEvent::Loaded { models, summary });
83 });
84 }
85
86 pub fn drain_model_refresh(&mut self) {
87 let mut latest = None;
88 if let Some(rx) = self.model_refresh_rx.as_mut() {
89 while let Ok(event) = rx.try_recv() {
90 latest = Some(event);
91 }
92 }
93
94 let Some(ModelRefreshEvent::Loaded { models, summary }) = latest else {
95 return;
96 };
97
98 self.set_available_models(models);
99 self.model_refresh_in_flight = false;
100 self.model_refresh_rx = None;
101
102 if let Some(target) = self.model_picker_target_model.as_deref()
103 && let Some(index) = self
104 .filtered_models()
105 .iter()
106 .position(|model| *model == target)
107 {
108 self.selected_model_index = index;
109 }
110
111 if self.model_picker_active || self.view_mode == ViewMode::Model {
112 self.status = summary.status_message();
113 }
114 }
115
116 pub async fn refresh_available_models(
122 &mut self,
123 registry: Option<&Arc<ProviderRegistry>>,
124 ) -> anyhow::Result<ModelRefreshSummary> {
125 let Some(registry) = registry else {
126 self.available_models.clear();
127 return Ok(ModelRefreshSummary::empty());
128 };
129
130 let (models, summary) = load_available_models(registry).await;
131 self.set_available_models(models);
132 Ok(summary)
133 }
134}
135
136async fn load_available_models(registry: &ProviderRegistry) -> (Vec<String>, ModelRefreshSummary) {
137 let provider_names: Vec<String> = registry.list().into_iter().map(str::to_string).collect();
138 let provider_count = provider_names.len();
139
140 let fetches = provider_names.into_iter().filter_map(|provider_name| {
141 registry.get(&provider_name).map(|provider| async move {
142 let result = tokio::time::timeout(MODEL_LIST_TIMEOUT, provider.list_models()).await;
143 (provider_name, result)
144 })
145 });
146
147 let mut models = Vec::new();
148 let mut summary = ModelRefreshSummary {
149 provider_count,
150 ..ModelRefreshSummary::empty()
151 };
152
153 for (provider_name, result) in join_all(fetches).await {
154 match result {
155 Ok(Ok(provider_models)) => {
156 let before = models.len();
157 models.extend(
158 provider_models
159 .iter()
160 .filter_map(|model| model_ref_for_provider(&provider_name, model)),
161 );
162 if models.len() > before {
163 summary.loaded_providers += 1;
164 }
165 }
166 Ok(Err(err)) => {
167 summary.failed_providers += 1;
168 tracing::warn!(
169 provider = %provider_name,
170 error = %err,
171 "failed to load models for TUI picker"
172 );
173 }
174 Err(_) => {
175 summary.timed_out_providers += 1;
176 tracing::warn!(
177 provider = %provider_name,
178 timeout_ms = MODEL_LIST_TIMEOUT.as_millis(),
179 "timed out loading models for TUI picker"
180 );
181 }
182 }
183 }
184
185 models.sort();
186 models.dedup();
187 summary.loaded_models = models.len();
188 (models, summary)
189}
190
191fn model_ref_for_provider(provider_name: &str, model: &ModelInfo) -> Option<String> {
192 let provider_name = provider_name.trim();
193 let model_id = model.id.trim();
194 if provider_name.is_empty() || model_id.is_empty() {
195 return None;
196 }
197
198 let provider_prefix = format!("{provider_name}/");
199 if model_id.starts_with(&provider_prefix) {
200 Some(model_id.to_string())
201 } else {
202 Some(format!("{provider_name}/{model_id}"))
203 }
204}
205
/// English plural suffix for `count` items: empty for exactly one, `"s"` otherwise.
fn plural(count: usize) -> &'static str {
    match count {
        1 => "",
        _ => "s",
    }
}
209
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::stream::BoxStream;

    use super::*;
    use crate::provider::{
        CompletionRequest, CompletionResponse, EmbeddingRequest, EmbeddingResponse, Provider,
        StreamChunk,
    };

    /// Provider double that serves a fixed model list; any other call is a test bug.
    struct StaticProvider {
        name: &'static str,
        models: Vec<ModelInfo>,
    }

    #[async_trait]
    impl Provider for StaticProvider {
        fn name(&self) -> &str {
            self.name
        }

        async fn list_models(&self) -> Result<Vec<ModelInfo>> {
            Ok(self.models.clone())
        }

        async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse> {
            unimplemented!("not used by model picker tests")
        }

        async fn complete_stream(
            &self,
            _request: CompletionRequest,
        ) -> Result<BoxStream<'static, StreamChunk>> {
            unimplemented!("not used by model picker tests")
        }

        async fn embed(&self, _request: EmbeddingRequest) -> Result<EmbeddingResponse> {
            unimplemented!("not used by model picker tests")
        }
    }

    /// Builds a `ModelInfo` with fixed capabilities; only the id/name and provider vary.
    fn model(id: &str, provider: &str) -> ModelInfo {
        ModelInfo {
            id: id.to_owned(),
            name: id.to_owned(),
            provider: provider.to_owned(),
            context_window: 128_000,
            max_output_tokens: Some(16_384),
            supports_vision: false,
            supports_tools: true,
            supports_streaming: true,
            input_cost_per_million: None,
            output_cost_per_million: None,
        }
    }

    /// Registry holding exactly one `StaticProvider` called `name` that lists `ids`.
    fn registry_with(name: &'static str, ids: &[&str]) -> Arc<ProviderRegistry> {
        let mut registry = ProviderRegistry::new();
        let models = ids.iter().map(|id| model(id, name)).collect();
        registry.register(Arc::new(StaticProvider { name, models }));
        Arc::new(registry)
    }

    /// Refreshes a default `AppState` against `registry` and returns it with the summary.
    async fn refreshed_state(
        registry: &Arc<ProviderRegistry>,
    ) -> (super::super::AppState, ModelRefreshSummary) {
        let mut state = super::super::AppState::default();
        let summary = state
            .refresh_available_models(Some(registry))
            .await
            .expect("refresh should succeed");
        (state, summary)
    }

    #[tokio::test]
    async fn refresh_prefixes_nested_model_ids_with_registry_provider() {
        let registry = registry_with("openrouter", &["openai/gpt-5.5"]);
        let (state, summary) = refreshed_state(&registry).await;

        assert_eq!(state.available_models, vec!["openrouter/openai/gpt-5.5"]);
        assert_eq!(summary.loaded_models, 1);
        assert_eq!(summary.loaded_providers, 1);
    }

    #[tokio::test]
    async fn refresh_does_not_double_prefix_provider_qualified_ids() {
        let registry = registry_with("openrouter", &["openrouter/z-ai/glm-5"]);
        let (state, _summary) = refreshed_state(&registry).await;

        assert_eq!(state.available_models, vec!["openrouter/z-ai/glm-5"]);
    }
}