// pi/interactive/model_selector_ui.rs

use std::collections::{HashMap, HashSet};

use super::commands::{
    model_entry_matches, model_requires_configured_credential, resolve_model_key_from_default_auth,
};
use super::*;
5
6impl PiApp {
7    fn normalize_model_key(entry: &ModelEntry) -> (String, String) {
8        let canonical_provider =
9            crate::provider_metadata::canonical_provider_id(entry.model.provider.as_str())
10                .unwrap_or(entry.model.provider.as_str());
11        (
12            canonical_provider.to_ascii_lowercase(),
13            entry.model.id.to_ascii_lowercase(),
14        )
15    }
16
17    fn unique_model_count(models: &[ModelEntry]) -> usize {
18        models
19            .iter()
20            .map(Self::normalize_model_key)
21            .collect::<std::collections::HashSet<_>>()
22            .len()
23    }
24
25    fn available_models_with_credentials(&self) -> Vec<ModelEntry> {
26        let auth = crate::auth::AuthStorage::load(crate::config::Config::auth_path()).ok();
27        let mut provider_has_credential: std::collections::HashMap<String, bool> =
28            std::collections::HashMap::new();
29        let mut filtered = Vec::new();
30        for entry in &self.available_models {
31            let provider = entry.model.provider.as_str();
32            let canonical = crate::provider_metadata::canonical_provider_id(provider)
33                .unwrap_or(provider)
34                .to_ascii_lowercase();
35            let requires_configured_credential = model_requires_configured_credential(entry);
36            let has_inline_key = entry
37                .api_key
38                .as_ref()
39                .is_some_and(|key| !key.trim().is_empty());
40            let has_auth_key = auth.as_ref().is_some_and(|storage| {
41                *provider_has_credential
42                    .entry(canonical.clone())
43                    .or_insert_with(|| storage.resolve_api_key(&canonical, None).is_some())
44            });
45            if !requires_configured_credential || has_inline_key || has_auth_key {
46                filtered.push(entry.clone());
47            }
48        }
49
50        filtered.sort_by_key(Self::normalize_model_key);
51        filtered.dedup_by(|left, right| model_entry_matches(left, right));
52        filtered
53    }
54
55    /// Open the model selector overlay.
56    pub fn open_model_selector(&mut self) {
57        if self.agent_state != AgentState::Idle {
58            self.status_message = Some("Cannot switch models while processing".to_string());
59            return;
60        }
61
62        if self.available_models.is_empty() {
63            self.status_message = Some("No models available".to_string());
64            return;
65        }
66
67        self.model_selector = Some(crate::model_selector::ModelSelectorOverlay::new(
68            &self.available_models,
69        ));
70    }
71
72    pub(super) fn open_model_selector_configured_only(&mut self) {
73        if self.agent_state != AgentState::Idle {
74            self.status_message = Some("Cannot switch models while processing".to_string());
75            return;
76        }
77
78        if self.available_models.is_empty() {
79            self.status_message = Some("No models available".to_string());
80            return;
81        }
82
83        let filtered = self.available_models_with_credentials();
84        if filtered.is_empty() {
85            self.status_message = Some(
86                "No models are ready to use. Configure credentials with /login <provider>."
87                    .to_string(),
88            );
89            return;
90        }
91
92        let mut overlay = crate::model_selector::ModelSelectorOverlay::new(&filtered);
93        overlay.set_configured_only_scope(Self::unique_model_count(&self.available_models));
94        self.model_selector = Some(overlay);
95    }
96
97    /// Handle keyboard input while the model selector is open.
98    pub fn handle_model_selector_key(&mut self, key: &KeyMsg) -> Option<Cmd> {
99        let selector = self.model_selector.as_mut()?;
100
101        match key.key_type {
102            KeyType::Up => selector.select_prev(),
103            KeyType::Down => selector.select_next(),
104            KeyType::Runes if key.runes == ['k'] => selector.select_prev(),
105            KeyType::Runes if key.runes == ['j'] => selector.select_next(),
106            KeyType::PgDown => selector.select_page_down(),
107            KeyType::PgUp => selector.select_page_up(),
108            KeyType::Backspace => selector.pop_char(),
109            KeyType::Runes => selector.push_chars(key.runes.iter().copied()),
110            KeyType::Enter => {
111                let selected = selector.selected_item().cloned();
112                self.model_selector = None;
113                if let Some(selected) = selected {
114                    self.apply_model_selection(&selected);
115                } else {
116                    self.status_message = Some("No model selected".to_string());
117                }
118                return None;
119            }
120            KeyType::Esc | KeyType::CtrlC => {
121                self.model_selector = None;
122                self.status_message = Some("Model selector cancelled".to_string());
123            }
124            _ => {} // consume all other input while selector is open
125        }
126
127        None
128    }
129
130    /// Apply a model selection from the model selector overlay.
131    fn apply_model_selection(&mut self, selected: &crate::model_selector::ModelKey) {
132        // Find the matching ModelEntry from available_models
133        let entry = self
134            .available_models
135            .iter()
136            .find(|e| {
137                e.model.provider.eq_ignore_ascii_case(&selected.provider)
138                    && e.model.id.eq_ignore_ascii_case(&selected.id)
139            })
140            .cloned();
141
142        let Some(next) = entry else {
143            self.status_message = Some(format!("Model {} not found", selected.full_id()));
144            return;
145        };
146
147        if model_entry_matches(&next, &self.model_entry) {
148            self.status_message = Some(format!("Already using {}", selected.full_id()));
149            return;
150        }
151
152        let resolved_key_opt = resolve_model_key_from_default_auth(&next);
153        if model_requires_configured_credential(&next) && resolved_key_opt.is_none() {
154            self.status_message = Some(format!(
155                "Missing credentials for provider {}. Run /login {}.",
156                next.model.provider, next.model.provider
157            ));
158            return;
159        }
160
161        let provider_impl = match providers::create_provider(&next, self.extensions.as_ref()) {
162            Ok(p) => p,
163            Err(err) => {
164                self.status_message = Some(err.to_string());
165                return;
166            }
167        };
168
169        let Ok(mut agent_guard) = self.agent.try_lock() else {
170            self.status_message = Some("Agent busy; try again".to_string());
171            return;
172        };
173        agent_guard.set_provider(provider_impl);
174        agent_guard
175            .stream_options_mut()
176            .api_key
177            .clone_from(&resolved_key_opt);
178        agent_guard
179            .stream_options_mut()
180            .headers
181            .clone_from(&next.headers);
182        drop(agent_guard);
183
184        let Ok(mut session_guard) = self.session.try_lock() else {
185            self.status_message = Some("Session busy; try again".to_string());
186            return;
187        };
188        session_guard.header.provider = Some(next.model.provider.clone());
189        session_guard.header.model_id = Some(next.model.id.clone());
190        session_guard.append_model_change(next.model.provider.clone(), next.model.id.clone());
191        drop(session_guard);
192        self.spawn_save_session();
193
194        self.model_entry = next.clone();
195        if let Ok(mut guard) = self.model_entry_shared.lock() {
196            *guard = next;
197        }
198        self.model = format!(
199            "{}/{}",
200            self.model_entry.model.provider, self.model_entry.model.id
201        );
202        self.status_message = Some(format!("Switched model: {}", self.model));
203    }
204
205    /// Render the model selector overlay.
206    #[allow(clippy::too_many_lines)]
207    pub(super) fn render_model_selector(
208        &self,
209        selector: &crate::model_selector::ModelSelectorOverlay,
210    ) -> String {
211        use std::fmt::Write;
212        let mut output = String::new();
213
214        let _ = writeln!(output, "\n  {}", self.styles.title.render("Select a model"));
215        if selector.configured_only() {
216            let _ = writeln!(
217                output,
218                "  {}",
219                self.styles
220                    .muted
221                    .render("Only showing models that are ready to use (see README for details)")
222            );
223        }
224
225        // Search field
226        let query = selector.query();
227        let search_line = if query.is_empty() {
228            if selector.configured_only() {
229                "  >".to_string()
230            } else {
231                "  > (type to filter)".to_string()
232            }
233        } else {
234            format!("  > {query}")
235        };
236        let _ = writeln!(output, "{}", self.styles.muted.render(&search_line));
237
238        let _ = writeln!(
239            output,
240            "  {}",
241            self.styles.muted.render("─".repeat(50).as_str())
242        );
243
244        if selector.filtered_len() == 0 {
245            let _ = writeln!(
246                output,
247                "  {}",
248                self.styles.muted_italic.render("No matching models.")
249            );
250        } else {
251            let offset = selector.scroll_offset();
252            let visible_count = selector.max_visible().min(selector.filtered_len());
253            let end = (offset + visible_count).min(selector.filtered_len());
254
255            let current_full = format!(
256                "{}/{}",
257                self.model_entry.model.provider, self.model_entry.model.id
258            );
259
260            for idx in offset..end {
261                let is_selected = idx == selector.selected_index();
262                let prefix = if is_selected { ">" } else { " " };
263
264                if let Some(key) = selector.item_at(idx) {
265                    let full = key.full_id();
266                    let is_current = full.eq_ignore_ascii_case(&current_full);
267                    let marker = if is_current { " *" } else { "" };
268                    let row = format!("{prefix} {full}{marker}");
269                    let rendered = if is_selected {
270                        self.styles.accent_bold.render(&row)
271                    } else if is_current {
272                        self.styles.accent.render(&row)
273                    } else {
274                        self.styles.muted.render(&row)
275                    };
276                    let _ = writeln!(output, "  {rendered}");
277                }
278            }
279
280            if selector.filtered_len() > visible_count {
281                let _ = writeln!(
282                    output,
283                    "  {}",
284                    self.styles.muted.render(&format!(
285                        "({}-{} of {})",
286                        offset + 1,
287                        end,
288                        selector.filtered_len()
289                    ))
290                );
291            }
292
293            if selector.configured_only() {
294                let _ = writeln!(
295                    output,
296                    "  {}",
297                    self.styles.muted.render(&format!(
298                        "({}/{})",
299                        selector.filtered_len(),
300                        selector.source_total()
301                    ))
302                );
303            }
304
305            if let Some(selected) = selector.selected_item()
306                && let Some(entry) = self.available_models.iter().find(|entry| {
307                    entry
308                        .model
309                        .provider
310                        .eq_ignore_ascii_case(&selected.provider)
311                        && entry.model.id.eq_ignore_ascii_case(&selected.id)
312                })
313            {
314                let _ = writeln!(
315                    output,
316                    "\n  {}",
317                    self.styles
318                        .muted
319                        .render(&format!("Model Name: {}", entry.model.name))
320                );
321            }
322        }
323
324        let _ = writeln!(
325            output,
326            "\n  {}",
327            self.styles
328                .muted_italic
329                .render("↑/↓/j/k: navigate  Enter: select  Esc: cancel  * = current")
330        );
331        output
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{Agent, AgentConfig};
    use crate::model::{StreamEvent, Usage};
    use crate::provider::{Context, InputType, Model, ModelCost, Provider, StreamOptions};
    use crate::resources::{ResourceCliOptions, ResourceLoader};
    use crate::session::Session;
    use crate::tools::ToolRegistry;
    use asupersync::channel::mpsc;
    use asupersync::runtime::RuntimeBuilder;
    use futures::stream;
    use std::collections::HashMap;
    use std::path::Path;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::OnceLock;

    /// Provider stub: fixed names, and every stream completes immediately with no events.
    struct DummyProvider;

    #[async_trait::async_trait]
    impl Provider for DummyProvider {
        fn name(&self) -> &'static str {
            "dummy"
        }

        fn api(&self) -> &'static str {
            "dummy"
        }

        fn model_id(&self) -> &'static str {
            "dummy-model"
        }

        async fn stream(
            &self,
            _context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<
            Pin<Box<dyn futures::Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
        > {
            Ok(Box::pin(stream::empty()))
        }
    }

    /// Lazily builds one process-wide multi-thread runtime and hands out its handle.
    fn runtime_handle() -> asupersync::runtime::RuntimeHandle {
        static SHARED_RUNTIME: OnceLock<asupersync::runtime::Runtime> = OnceLock::new();
        let runtime = SHARED_RUNTIME.get_or_init(|| {
            RuntimeBuilder::multi_thread()
                .blocking_threads(1, 8)
                .build()
                .expect("build runtime")
        });
        runtime.handle()
    }

    /// A zero-cost, text-only test model whose display name equals its id.
    fn model_entry(provider: &str, id: &str, api_key: Option<&str>) -> ModelEntry {
        let model = Model {
            id: id.to_string(),
            name: id.to_string(),
            api: "openai-completions".to_string(),
            provider: provider.to_string(),
            base_url: "https://example.invalid".to_string(),
            reasoning: true,
            input: vec![InputType::Text],
            cost: ModelCost {
                input: 0.0,
                output: 0.0,
                cache_read: 0.0,
                cache_write: 0.0,
            },
            context_window: 128_000,
            max_tokens: 8_192,
            headers: HashMap::new(),
        };
        ModelEntry {
            model,
            api_key: api_key.map(str::to_string),
            headers: HashMap::new(),
            auth_header: true,
            compat: None,
            oauth_config: None,
        }
    }

    /// Builds a `PiApp` backed by `DummyProvider` and an in-memory session, with `current`
    /// active and `available` offered in the selector.
    fn build_test_app(current: ModelEntry, available: Vec<ModelEntry>) -> PiApp {
        let agent = Agent::new(
            Arc::new(DummyProvider) as Arc<dyn Provider>,
            ToolRegistry::new(&[], Path::new("."), None),
            AgentConfig::default(),
        );
        let session = Arc::new(asupersync::sync::Mutex::new(Session::in_memory()));
        let resource_cli = ResourceCliOptions {
            no_skills: false,
            no_prompt_templates: false,
            no_extensions: false,
            no_themes: false,
            skill_paths: Vec::new(),
            prompt_paths: Vec::new(),
            extension_paths: Vec::new(),
            theme_paths: Vec::new(),
        };
        let (event_tx, _event_rx) = mpsc::channel(64);
        PiApp::new(
            agent,
            session,
            Config::default(),
            ResourceLoader::empty(false),
            resource_cli,
            Path::new(".").to_path_buf(),
            current,
            Vec::new(),
            available,
            Vec::new(),
            event_tx,
            runtime_handle(),
            true,
            None,
            Some(KeyBindings::new()),
            Vec::new(),
            Usage::default(),
        )
    }

    /// Selector key addressing `entry`, as `apply_model_selection` expects it.
    fn selection_key(entry: &ModelEntry) -> crate::model_selector::ModelKey {
        crate::model_selector::ModelKey {
            provider: entry.model.provider.clone(),
            id: entry.model.id.clone(),
        }
    }

    /// Full ids of every row in the open selector; panics with `missing` when none is open.
    fn selector_full_ids(app: &PiApp, missing: &str) -> Vec<String> {
        let selector = app.model_selector.as_ref().expect(missing);
        (0..selector.filtered_len())
            .filter_map(|idx| selector.item_at(idx))
            .map(|item| item.full_id())
            .collect()
    }

    #[test]
    fn apply_model_selection_replaces_stream_options_api_key_and_headers() {
        let current = model_entry("openai", "gpt-4o-mini", Some("old-key"));
        let mut next = model_entry("openrouter", "openai/gpt-4o-mini", Some("next-key"));
        next.headers
            .insert("x-provider-header".to_string(), "next".to_string());
        let target = selection_key(&next);
        let mut app = build_test_app(current.clone(), vec![current, next]);

        {
            let mut guard = app.agent.try_lock().expect("agent lock");
            let options = guard.stream_options_mut();
            options.api_key = Some("stale-key".to_string());
            options
                .headers
                .insert("x-stale".to_string(), "stale".to_string());
        }

        app.apply_model_selection(&target);

        let mut guard = app.agent.try_lock().expect("agent lock");
        let options = guard.stream_options_mut();
        assert_eq!(options.api_key.as_deref(), Some("next-key"));
        assert_eq!(
            options
                .headers
                .get("x-provider-header")
                .map(String::as_str),
            Some("next")
        );
        assert!(
            !options.headers.contains_key("x-stale"),
            "switching models must replace stale provider headers"
        );
    }

    #[test]
    fn apply_model_selection_clears_stale_api_key_when_next_model_has_no_key() {
        let current = model_entry("openai", "gpt-4o-mini", Some("old-key"));
        let mut next = model_entry("ollama", "llama3.2", None);
        next.auth_header = false;
        let target = selection_key(&next);
        let mut app = build_test_app(current.clone(), vec![current, next]);

        {
            let mut guard = app.agent.try_lock().expect("agent lock");
            guard.stream_options_mut().api_key = Some("stale-key".to_string());
        }

        app.apply_model_selection(&target);

        let mut guard = app.agent.try_lock().expect("agent lock");
        assert!(
            guard.stream_options_mut().api_key.is_none(),
            "switching to a keyless model must clear stale API key"
        );
    }

    #[test]
    fn configured_only_selector_includes_keyless_ready_models() {
        let mut keyless = model_entry("ollama", "llama3.2", None);
        keyless.auth_header = false;
        let mut requires_creds = model_entry("acme-remote", "cloud-model", None);
        requires_creds.auth_header = true;

        let mut app = build_test_app(keyless.clone(), vec![keyless, requires_creds]);
        app.open_model_selector_configured_only();

        let ids = selector_full_ids(
            &app,
            "configured-only selector should open when keyless models are ready",
        );
        assert!(
            ids.iter().any(|id| id == "ollama/llama3.2"),
            "keyless local model must be considered ready"
        );
        assert!(
            !ids.iter().any(|id| id == "acme-remote/cloud-model"),
            "credentialed providers without configured auth should not appear"
        );
    }

    #[test]
    fn configured_only_selector_keeps_unknown_keyless_provider_models() {
        let mut unknown_keyless = model_entry("acme-local", "dev-model", None);
        unknown_keyless.auth_header = false;
        let mut unknown_requires = model_entry("acme-remote", "cloud-model", None);
        unknown_requires.auth_header = true;

        let mut app = build_test_app(
            unknown_keyless.clone(),
            vec![unknown_keyless, unknown_requires],
        );
        app.open_model_selector_configured_only();

        let ids = selector_full_ids(&app, "unknown keyless model should keep selector available");
        assert!(ids.iter().any(|id| id == "acme-local/dev-model"));
        assert!(!ids.iter().any(|id| id == "acme-remote/cloud-model"));
    }
}