Skip to main content

codetether_agent/tui/app/state/
model_picker.rs

1//! Async model list refresh from the provider registry.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use futures::future::join_all;
7use tokio::sync::mpsc;
8
9use crate::provider::{ModelInfo, ProviderRegistry};
10use crate::tui::models::ViewMode;
11
/// Upper bound on a single provider's `list_models` call during a refresh.
const MODEL_LIST_TIMEOUT: Duration = Duration::from_millis(4_000);
13
/// Outcome counters from one model-list refresh across the provider registry.
///
/// `Default` yields the all-zero summary.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ModelRefreshSummary {
    /// Providers listed by the registry when the refresh started.
    pub provider_count: usize,
    /// Providers that contributed at least one model reference.
    pub loaded_providers: usize,
    /// Distinct model references loaded (after sorting and deduplication).
    pub loaded_models: usize,
    /// Providers whose `list_models` call returned an error.
    pub failed_providers: usize,
    /// Providers whose `list_models` call exceeded the refresh timeout.
    pub timed_out_providers: usize,
}
22
23impl ModelRefreshSummary {
24    fn empty() -> Self {
25        Self {
26            provider_count: 0,
27            loaded_providers: 0,
28            loaded_models: 0,
29            failed_providers: 0,
30            timed_out_providers: 0,
31        }
32    }
33
34    fn skipped_providers(&self) -> usize {
35        self.failed_providers + self.timed_out_providers
36    }
37
38    fn status_message(&self) -> String {
39        if self.loaded_models == 0 {
40            let skipped = self.skipped_providers();
41            if self.provider_count == 0 || skipped == 0 {
42                "No models available".to_string()
43            } else {
44                format!(
45                    "No models available ({} provider{} skipped)",
46                    skipped,
47                    plural(skipped)
48                )
49            }
50        } else {
51            let mut status = format!(
52                "Model picker: {} models from {} provider{}",
53                self.loaded_models,
54                self.loaded_providers,
55                plural(self.loaded_providers)
56            );
57            let skipped = self.skipped_providers();
58            if skipped > 0 {
59                status.push_str(&format!(" ({} skipped)", skipped));
60            }
61            status
62        }
63    }
64}
65
66#[derive(Debug)]
67pub enum ModelRefreshEvent {
68    Loaded {
69        models: Vec<String>,
70        summary: ModelRefreshSummary,
71    },
72}
73
74impl super::AppState {
75    pub fn start_model_refresh(&mut self, registry: Arc<ProviderRegistry>) {
76        let (tx, rx) = mpsc::unbounded_channel();
77        self.model_refresh_rx = Some(rx);
78        self.model_refresh_in_flight = true;
79
80        tokio::spawn(async move {
81            let (models, summary) = load_available_models(&registry).await;
82            let _ = tx.send(ModelRefreshEvent::Loaded { models, summary });
83        });
84    }
85
86    pub fn drain_model_refresh(&mut self) {
87        let mut latest = None;
88        if let Some(rx) = self.model_refresh_rx.as_mut() {
89            while let Ok(event) = rx.try_recv() {
90                latest = Some(event);
91            }
92        }
93
94        let Some(ModelRefreshEvent::Loaded { models, summary }) = latest else {
95            return;
96        };
97
98        self.set_available_models(models);
99        self.model_refresh_in_flight = false;
100        self.model_refresh_rx = None;
101
102        if let Some(target) = self.model_picker_target_model.as_deref()
103            && let Some(index) = self
104                .filtered_models()
105                .iter()
106                .position(|model| *model == target)
107        {
108            self.selected_model_index = index;
109        }
110
111        if self.model_picker_active || self.view_mode == ViewMode::Model {
112            self.status = summary.status_message();
113        }
114    }
115
116    /// Refresh the available models list from the provider registry.
117    ///
118    /// # Errors
119    ///
120    /// Returns an error if any provider call fails critically.
121    pub async fn refresh_available_models(
122        &mut self,
123        registry: Option<&Arc<ProviderRegistry>>,
124    ) -> anyhow::Result<ModelRefreshSummary> {
125        let Some(registry) = registry else {
126            self.available_models.clear();
127            return Ok(ModelRefreshSummary::empty());
128        };
129
130        let (models, summary) = load_available_models(registry).await;
131        self.set_available_models(models);
132        Ok(summary)
133    }
134}
135
136async fn load_available_models(registry: &ProviderRegistry) -> (Vec<String>, ModelRefreshSummary) {
137    let provider_names: Vec<String> = registry.list().into_iter().map(str::to_string).collect();
138    let provider_count = provider_names.len();
139
140    let fetches = provider_names.into_iter().filter_map(|provider_name| {
141        registry.get(&provider_name).map(|provider| async move {
142            let result = tokio::time::timeout(MODEL_LIST_TIMEOUT, provider.list_models()).await;
143            (provider_name, result)
144        })
145    });
146
147    let mut models = Vec::new();
148    let mut summary = ModelRefreshSummary {
149        provider_count,
150        ..ModelRefreshSummary::empty()
151    };
152
153    for (provider_name, result) in join_all(fetches).await {
154        match result {
155            Ok(Ok(provider_models)) => {
156                let before = models.len();
157                models.extend(
158                    provider_models
159                        .iter()
160                        .filter_map(|model| model_ref_for_provider(&provider_name, model)),
161                );
162                if models.len() > before {
163                    summary.loaded_providers += 1;
164                }
165            }
166            Ok(Err(err)) => {
167                summary.failed_providers += 1;
168                tracing::warn!(
169                    provider = %provider_name,
170                    error = %err,
171                    "failed to load models for TUI picker"
172                );
173            }
174            Err(_) => {
175                summary.timed_out_providers += 1;
176                tracing::warn!(
177                    provider = %provider_name,
178                    timeout_ms = MODEL_LIST_TIMEOUT.as_millis(),
179                    "timed out loading models for TUI picker"
180                );
181            }
182        }
183    }
184
185    models.sort();
186    models.dedup();
187    summary.loaded_models = models.len();
188    (models, summary)
189}
190
191fn model_ref_for_provider(provider_name: &str, model: &ModelInfo) -> Option<String> {
192    let provider_name = provider_name.trim();
193    let model_id = model.id.trim();
194    if provider_name.is_empty() || model_id.is_empty() {
195        return None;
196    }
197
198    let provider_prefix = format!("{provider_name}/");
199    if model_id.starts_with(&provider_prefix) {
200        Some(model_id.to_string())
201    } else {
202        Some(format!("{provider_name}/{model_id}"))
203    }
204}
205
/// English plural suffix for `count`: empty for exactly one, `"s"` otherwise.
fn plural(count: usize) -> &'static str {
    match count {
        1 => "",
        _ => "s",
    }
}
209
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use anyhow::Result;
    use async_trait::async_trait;
    use futures::stream::BoxStream;

    use super::*;
    use crate::provider::{
        CompletionRequest, CompletionResponse, EmbeddingRequest, EmbeddingResponse, Provider,
        StreamChunk,
    };

    /// Provider stub that serves a fixed model list and supports nothing else.
    struct FixedModels {
        name: &'static str,
        models: Vec<ModelInfo>,
    }

    #[async_trait]
    impl Provider for FixedModels {
        fn name(&self) -> &str {
            self.name
        }

        async fn list_models(&self) -> Result<Vec<ModelInfo>> {
            Ok(self.models.clone())
        }

        async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse> {
            unimplemented!("not used by model picker tests")
        }

        async fn complete_stream(
            &self,
            _request: CompletionRequest,
        ) -> Result<BoxStream<'static, StreamChunk>> {
            unimplemented!("not used by model picker tests")
        }

        async fn embed(&self, _request: EmbeddingRequest) -> Result<EmbeddingResponse> {
            unimplemented!("not used by model picker tests")
        }
    }

    /// `ModelInfo` with typical capability values; only `id` matters to the picker.
    fn model_info(id: &str, provider: &str) -> ModelInfo {
        ModelInfo {
            id: id.to_string(),
            name: id.to_string(),
            provider: provider.to_string(),
            context_window: 128_000,
            max_output_tokens: Some(16_384),
            supports_vision: false,
            supports_tools: true,
            supports_streaming: true,
            input_cost_per_million: None,
            output_cost_per_million: None,
        }
    }

    /// Registry holding a single "openrouter" provider that lists `model_ids`.
    fn openrouter_registry(model_ids: &[&str]) -> Arc<ProviderRegistry> {
        let mut registry = ProviderRegistry::new();
        registry.register(Arc::new(FixedModels {
            name: "openrouter",
            models: model_ids
                .iter()
                .map(|id| model_info(id, "openrouter"))
                .collect(),
        }));
        Arc::new(registry)
    }

    #[tokio::test]
    async fn refresh_prefixes_nested_model_ids_with_registry_provider() {
        let registry = openrouter_registry(&["openai/gpt-5.5"]);
        let mut state = super::super::AppState::default();

        let summary = state
            .refresh_available_models(Some(&registry))
            .await
            .expect("refresh should succeed");

        assert_eq!(state.available_models, vec!["openrouter/openai/gpt-5.5"]);
        assert_eq!(summary.loaded_models, 1);
        assert_eq!(summary.loaded_providers, 1);
    }

    #[tokio::test]
    async fn refresh_does_not_double_prefix_provider_qualified_ids() {
        let registry = openrouter_registry(&["openrouter/z-ai/glm-5"]);
        let mut state = super::super::AppState::default();

        state
            .refresh_available_models(Some(&registry))
            .await
            .expect("refresh should succeed");

        assert_eq!(state.available_models, vec!["openrouter/z-ai/glm-5"]);
    }
}