Skip to main content

pi/interactive/
model_selector_ui.rs

1use super::commands::{model_entry_matches, resolve_model_key_from_default_auth};
2use super::*;
3use crate::models::model_requires_configured_credential;
4
5impl PiApp {
6    fn normalize_model_key(entry: &ModelEntry) -> (String, String) {
7        let canonical_provider =
8            crate::provider_metadata::canonical_provider_id(entry.model.provider.as_str())
9                .unwrap_or(entry.model.provider.as_str());
10        (
11            canonical_provider.to_ascii_lowercase(),
12            entry.model.id.to_ascii_lowercase(),
13        )
14    }
15
16    fn unique_model_count(models: &[ModelEntry]) -> usize {
17        models
18            .iter()
19            .map(Self::normalize_model_key)
20            .collect::<std::collections::HashSet<_>>()
21            .len()
22    }
23
24    fn available_models_with_credentials(&self) -> Vec<ModelEntry> {
25        let auth = crate::auth::AuthStorage::load(crate::config::Config::auth_path()).ok();
26        let mut provider_has_credential: std::collections::HashMap<String, bool> =
27            std::collections::HashMap::new();
28        let mut filtered = Vec::new();
29        for entry in &self.available_models {
30            let provider = entry.model.provider.as_str();
31            let canonical = crate::provider_metadata::canonical_provider_id(provider)
32                .unwrap_or(provider)
33                .to_ascii_lowercase();
34            let requires_configured_credential = model_requires_configured_credential(entry);
35            let has_inline_key = entry
36                .api_key
37                .as_ref()
38                .is_some_and(|key| !key.trim().is_empty());
39            let has_auth_key = auth.as_ref().is_some_and(|storage| {
40                *provider_has_credential
41                    .entry(canonical.clone())
42                    .or_insert_with(|| storage.resolve_api_key(&canonical, None).is_some())
43            });
44            if !requires_configured_credential || has_inline_key || has_auth_key {
45                filtered.push(entry.clone());
46            }
47        }
48
49        filtered.sort_by_key(Self::normalize_model_key);
50        filtered.dedup_by(|left, right| model_entry_matches(left, right));
51        filtered
52    }
53
54    /// Open the model selector overlay.
55    pub fn open_model_selector(&mut self) {
56        if self.agent_state != AgentState::Idle {
57            self.status_message = Some("Cannot switch models while processing".to_string());
58            return;
59        }
60
61        if self.available_models.is_empty() {
62            self.status_message = Some("No models available".to_string());
63            return;
64        }
65
66        let mut overlay = crate::model_selector::ModelSelectorOverlay::new(&self.available_models);
67        overlay.set_max_visible(super::overlay_max_visible(self.term_height));
68        self.model_selector = Some(overlay);
69    }
70
71    pub(super) fn open_model_selector_configured_only(&mut self) {
72        if self.agent_state != AgentState::Idle {
73            self.status_message = Some("Cannot switch models while processing".to_string());
74            return;
75        }
76
77        if self.available_models.is_empty() {
78            self.status_message = Some("No models available".to_string());
79            return;
80        }
81
82        let filtered = self.available_models_with_credentials();
83        if filtered.is_empty() {
84            self.status_message = Some(
85                "No models are ready to use. Configure credentials with /login <provider>."
86                    .to_string(),
87            );
88            return;
89        }
90
91        let mut overlay = crate::model_selector::ModelSelectorOverlay::new(&filtered);
92        overlay.set_configured_only_scope(Self::unique_model_count(&self.available_models));
93        overlay.set_max_visible(super::overlay_max_visible(self.term_height));
94        self.model_selector = Some(overlay);
95    }
96
97    /// Handle keyboard input while the model selector is open.
98    pub fn handle_model_selector_key(&mut self, key: &KeyMsg) -> Option<Cmd> {
99        let selector = self.model_selector.as_mut()?;
100
101        match key.key_type {
102            KeyType::Up => selector.select_prev(),
103            KeyType::Down => selector.select_next(),
104            KeyType::Runes if key.runes == ['k'] => selector.select_prev(),
105            KeyType::Runes if key.runes == ['j'] => selector.select_next(),
106            KeyType::PgDown => selector.select_page_down(),
107            KeyType::PgUp => selector.select_page_up(),
108            KeyType::Backspace => selector.pop_char(),
109            KeyType::Runes => selector.push_chars(key.runes.iter().copied()),
110            KeyType::Enter => {
111                let selected = selector.selected_item().cloned();
112                self.model_selector = None;
113                if let Some(selected) = selected {
114                    self.apply_model_selection(&selected);
115                } else {
116                    self.status_message = Some("No model selected".to_string());
117                }
118                return None;
119            }
120            KeyType::Esc | KeyType::CtrlC => {
121                self.model_selector = None;
122                self.status_message = Some("Model selector cancelled".to_string());
123            }
124            _ => {} // consume all other input while selector is open
125        }
126
127        None
128    }
129
130    /// Apply a model selection from the model selector overlay.
131    fn apply_model_selection(&mut self, selected: &crate::model_selector::ModelKey) {
132        // Find the matching ModelEntry from available_models
133        let entry = self
134            .available_models
135            .iter()
136            .find(|e| {
137                e.model.provider.eq_ignore_ascii_case(&selected.provider)
138                    && e.model.id.eq_ignore_ascii_case(&selected.id)
139            })
140            .cloned();
141
142        let Some(next) = entry else {
143            self.status_message = Some(format!("Model {} not found", selected.full_id()));
144            return;
145        };
146
147        if model_entry_matches(&next, &self.model_entry) {
148            self.status_message = Some(format!("Already using {}", selected.full_id()));
149            return;
150        }
151
152        let resolved_key_opt = resolve_model_key_from_default_auth(&next);
153        if model_requires_configured_credential(&next) && resolved_key_opt.is_none() {
154            self.status_message = Some(format!(
155                "Missing credentials for provider {}. Run /login {}.",
156                next.model.provider, next.model.provider
157            ));
158            return;
159        }
160
161        let provider_impl = match providers::create_provider(&next, self.extensions.as_ref()) {
162            Ok(p) => p,
163            Err(err) => {
164                self.status_message = Some(err.to_string());
165                return;
166            }
167        };
168
169        if let Err(message) = self.switch_active_model(
170            &next,
171            provider_impl,
172            resolved_key_opt.as_deref(),
173            "selector",
174        ) {
175            self.status_message = Some(message);
176            return;
177        }
178        self.status_message = Some(format!("Switched model: {}", self.model));
179    }
180
181    /// Render the model selector overlay.
182    #[allow(clippy::too_many_lines)]
183    pub(super) fn render_model_selector(
184        &self,
185        selector: &crate::model_selector::ModelSelectorOverlay,
186    ) -> String {
187        use std::fmt::Write;
188        let mut output = String::new();
189
190        let _ = writeln!(output, "\n  {}", self.styles.title.render("Select a model"));
191        if selector.configured_only() {
192            let _ = writeln!(
193                output,
194                "  {}",
195                self.styles
196                    .muted
197                    .render("Only showing models that are ready to use (see README for details)")
198            );
199        }
200
201        // Search field
202        let query = selector.query();
203        let search_line = if query.is_empty() {
204            if selector.configured_only() {
205                "  >".to_string()
206            } else {
207                "  > (type to filter)".to_string()
208            }
209        } else {
210            format!("  > {query}")
211        };
212        let _ = writeln!(output, "{}", self.styles.muted.render(&search_line));
213
214        let _ = writeln!(
215            output,
216            "  {}",
217            self.styles.muted.render("─".repeat(50).as_str())
218        );
219
220        if selector.filtered_len() == 0 {
221            let _ = writeln!(
222                output,
223                "  {}",
224                self.styles.muted_italic.render("No matching models.")
225            );
226        } else {
227            let offset = selector.scroll_offset();
228            let visible_count = selector.max_visible().min(selector.filtered_len());
229            let end = (offset + visible_count).min(selector.filtered_len());
230
231            let current_full = format!(
232                "{}/{}",
233                self.model_entry.model.provider, self.model_entry.model.id
234            );
235
236            for idx in offset..end {
237                let is_selected = idx == selector.selected_index();
238                let prefix = if is_selected { ">" } else { " " };
239
240                if let Some(key) = selector.item_at(idx) {
241                    let full = key.full_id();
242                    let is_current = full.eq_ignore_ascii_case(&current_full);
243                    let marker = if is_current { " *" } else { "" };
244                    let row = format!("{prefix} {full}{marker}");
245                    let rendered = if is_selected {
246                        self.styles.accent_bold.render(&row)
247                    } else if is_current {
248                        self.styles.accent.render(&row)
249                    } else {
250                        self.styles.muted.render(&row)
251                    };
252                    let _ = writeln!(output, "  {rendered}");
253                }
254            }
255
256            if selector.filtered_len() > visible_count {
257                let _ = writeln!(
258                    output,
259                    "  {}",
260                    self.styles.muted.render(&format!(
261                        "({}-{} of {})",
262                        offset + 1,
263                        end,
264                        selector.filtered_len()
265                    ))
266                );
267            }
268
269            if selector.configured_only() {
270                let _ = writeln!(
271                    output,
272                    "  {}",
273                    self.styles.muted.render(&format!(
274                        "({}/{})",
275                        selector.filtered_len(),
276                        selector.source_total()
277                    ))
278                );
279            }
280
281            if let Some(selected) = selector.selected_item()
282                && let Some(entry) = self.available_models.iter().find(|entry| {
283                    entry
284                        .model
285                        .provider
286                        .eq_ignore_ascii_case(&selected.provider)
287                        && entry.model.id.eq_ignore_ascii_case(&selected.id)
288                })
289            {
290                let _ = writeln!(
291                    output,
292                    "\n  {}",
293                    self.styles
294                        .muted
295                        .render(&format!("Model Name: {}", entry.model.name))
296                );
297            }
298        }
299
300        let _ = writeln!(
301            output,
302            "\n  {}",
303            self.styles
304                .muted_italic
305                .render("↑/↓/j/k/PgUp/PgDn: navigate  Enter: select  Esc: cancel  * = current")
306        );
307        output
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{Agent, AgentConfig};
    use crate::model::{StreamEvent, Usage};
    use crate::provider::{Context, InputType, Model, ModelCost, Provider, StreamOptions};
    use crate::resources::{ResourceCliOptions, ResourceLoader};
    use crate::session::Session;
    use crate::tools::ToolRegistry;
    use asupersync::channel::mpsc;
    use asupersync::runtime::RuntimeBuilder;
    use futures::stream;
    use std::collections::HashMap;
    use std::path::Path;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::OnceLock;

    /// Provider stub whose `stream` returns an empty event stream.
    struct DummyProvider;

    #[async_trait::async_trait]
    impl Provider for DummyProvider {
        fn name(&self) -> &'static str {
            "dummy"
        }

        fn api(&self) -> &'static str {
            "dummy"
        }

        fn model_id(&self) -> &'static str {
            "dummy-model"
        }

        async fn stream(
            &self,
            _context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<
            Pin<Box<dyn futures::Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
        > {
            Ok(Box::pin(stream::empty()))
        }
    }

    /// Handle to a lazily built runtime shared by every test in this module.
    fn runtime_handle() -> asupersync::runtime::RuntimeHandle {
        static RT: OnceLock<asupersync::runtime::Runtime> = OnceLock::new();
        RT.get_or_init(|| {
            RuntimeBuilder::multi_thread()
                .blocking_threads(1, 8)
                .build()
                .expect("build runtime")
        })
        .handle()
    }

    /// Minimal reasoning-capable model entry for `provider/id` with an optional inline key.
    fn model_entry(provider: &str, id: &str, api_key: Option<&str>) -> ModelEntry {
        ModelEntry {
            model: Model {
                id: id.to_string(),
                name: id.to_string(),
                api: "openai-completions".to_string(),
                provider: provider.to_string(),
                base_url: "https://example.invalid".to_string(),
                reasoning: true,
                input: vec![InputType::Text],
                cost: ModelCost {
                    input: 0.0,
                    output: 0.0,
                    cache_read: 0.0,
                    cache_write: 0.0,
                },
                context_window: 128_000,
                max_tokens: 8_192,
                headers: HashMap::new(),
            },
            api_key: api_key.map(str::to_string),
            headers: HashMap::new(),
            auth_header: true,
            compat: None,
            oauth_config: None,
        }
    }

    /// Build a `PiApp` backed by `DummyProvider` and an in-memory session,
    /// with `current` active and `available` offered in the selector.
    fn build_test_app(current: ModelEntry, available: Vec<ModelEntry>) -> PiApp {
        let provider: Arc<dyn Provider> = Arc::new(DummyProvider);
        let agent = Agent::new(
            provider,
            ToolRegistry::new(&[], Path::new("."), None),
            AgentConfig::default(),
        );
        let session = Arc::new(asupersync::sync::Mutex::new(Session::in_memory()));
        let resources = ResourceLoader::empty(false);
        let resource_cli = ResourceCliOptions {
            no_skills: false,
            no_prompt_templates: false,
            no_extensions: false,
            no_themes: false,
            skill_paths: Vec::new(),
            prompt_paths: Vec::new(),
            extension_paths: Vec::new(),
            theme_paths: Vec::new(),
        };
        let (event_tx, _event_rx) = mpsc::channel(64);
        let config = Config {
            last_changelog_version: Some(crate::platform::VERSION.to_string()),
            ..Config::default()
        };
        PiApp::new(
            agent,
            session,
            config,
            resources,
            resource_cli,
            Path::new(".").to_path_buf(),
            current,
            Vec::new(),
            available,
            Vec::new(),
            event_tx,
            runtime_handle(),
            true,
            false,
            None,
            Some(KeyBindings::new()),
            Vec::new(),
            Usage::default(),
        )
    }

    /// Full `provider/id` strings of every model the selector currently lists,
    /// in selector order.
    fn listed_model_ids(selector: &crate::model_selector::ModelSelectorOverlay) -> Vec<String> {
        (0..selector.filtered_len())
            .filter_map(|idx| selector.item_at(idx))
            .map(|item| item.full_id())
            .collect()
    }

    /// Switching models must install the next entry's API key and headers and
    /// drop any stale stream-option values.
    #[test]
    fn apply_model_selection_replaces_stream_options_api_key_and_headers() {
        let current = model_entry("openai", "gpt-4o-mini", Some("old-key"));
        let mut next = model_entry("openrouter", "openai/gpt-4o-mini", Some("next-key"));
        next.headers
            .insert("x-provider-header".to_string(), "next".to_string());

        let mut app = build_test_app(current.clone(), vec![current, next.clone()]);

        {
            let mut guard = app.agent.try_lock().expect("agent lock");
            guard.stream_options_mut().api_key = Some("stale-token".to_string());
            guard
                .stream_options_mut()
                .headers
                .insert("x-stale".to_string(), "stale".to_string());
        }

        app.apply_model_selection(&crate::model_selector::ModelKey {
            provider: next.model.provider.clone(),
            id: next.model.id,
        });

        let mut guard = app.agent.try_lock().expect("agent lock");
        assert_eq!(
            guard.stream_options_mut().api_key.as_deref(),
            Some("next-key")
        );
        assert_eq!(
            guard
                .stream_options_mut()
                .headers
                .get("x-provider-header")
                .map(String::as_str),
            Some("next")
        );
        assert!(
            !guard.stream_options_mut().headers.contains_key("x-stale"),
            "switching models must replace stale provider headers"
        );
    }

    /// Switching to a model without a key must not keep the previous key.
    #[test]
    fn apply_model_selection_clears_stale_api_key_when_next_model_has_no_key() {
        let current = model_entry("openai", "gpt-4o-mini", Some("old-key"));
        let mut next = model_entry("ollama", "llama3.2", None);
        next.auth_header = false;
        let mut app = build_test_app(current.clone(), vec![current, next.clone()]);

        {
            let mut guard = app.agent.try_lock().expect("agent lock");
            guard.stream_options_mut().api_key = Some("stale-token".to_string());
        }

        app.apply_model_selection(&crate::model_selector::ModelKey {
            provider: next.model.provider.clone(),
            id: next.model.id,
        });

        let mut guard = app.agent.try_lock().expect("agent lock");
        assert!(
            guard.stream_options_mut().api_key.is_none(),
            "switching to a keyless model must clear stale API key"
        );
    }

    /// Selecting a non-reasoning model must reset the thinking level to `Off`
    /// in both the agent stream options and the session header.
    #[test]
    fn apply_model_selection_clamps_thinking_level_for_non_reasoning_targets() {
        let current = model_entry("openai", "gpt-5.2", Some("old-key"));
        let mut next = model_entry("ollama", "llama3.2", None);
        next.auth_header = false;
        next.model.reasoning = false;
        let mut app = build_test_app(current.clone(), vec![current, next.clone()]);

        {
            let mut guard = app.agent.try_lock().expect("agent lock");
            guard.stream_options_mut().thinking_level = Some(crate::model::ThinkingLevel::High);
        }
        {
            let mut guard = app.session.try_lock().expect("session lock");
            guard.header.thinking_level = Some(crate::model::ThinkingLevel::High.to_string());
        }

        app.apply_model_selection(&crate::model_selector::ModelKey {
            provider: next.model.provider.clone(),
            id: next.model.id,
        });

        let mut agent_guard = app.agent.try_lock().expect("agent lock");
        assert_eq!(
            agent_guard.stream_options_mut().thinking_level,
            Some(crate::model::ThinkingLevel::Off)
        );
        drop(agent_guard);

        let session_guard = app.session.try_lock().expect("session lock");
        assert_eq!(session_guard.header.thinking_level.as_deref(), Some("off"));
    }

    /// Keyless local models count as ready, and credentialed ones without auth do not.
    #[test]
    fn configured_only_selector_includes_keyless_ready_models() {
        let mut keyless = model_entry("ollama", "llama3.2", None);
        keyless.auth_header = false;

        let mut requires_creds = model_entry("acme-remote", "cloud-model", None);
        requires_creds.auth_header = true;

        let mut app = build_test_app(keyless.clone(), vec![keyless, requires_creds]);
        app.open_model_selector_configured_only();

        let selector = app
            .model_selector
            .as_ref()
            .expect("configured-only selector should open when keyless models are ready");
        let ids = listed_model_ids(selector);

        assert!(
            ids.iter().any(|id| id == "ollama/llama3.2"),
            "keyless local model must be considered ready"
        );
        assert!(
            !ids.iter().any(|id| id == "acme-remote/cloud-model"),
            "credentialed providers without configured auth should not appear"
        );
    }

    /// Unknown providers follow the same readiness rules as known ones.
    #[test]
    fn configured_only_selector_keeps_unknown_keyless_provider_models() {
        let mut unknown_keyless = model_entry("acme-local", "dev-model", None);
        unknown_keyless.auth_header = false;
        let mut unknown_requires = model_entry("acme-remote", "cloud-model", None);
        unknown_requires.auth_header = true;

        let mut app = build_test_app(
            unknown_keyless.clone(),
            vec![unknown_keyless, unknown_requires],
        );
        app.open_model_selector_configured_only();

        let selector = app
            .model_selector
            .as_ref()
            .expect("unknown keyless model should keep selector available");
        let ids = listed_model_ids(selector);

        assert!(ids.iter().any(|id| id == "acme-local/dev-model"));
        assert!(!ids.iter().any(|id| id == "acme-remote/cloud-model"));
    }
}