1use super::commands::{model_entry_matches, resolve_model_key_from_default_auth};
2use super::*;
3use crate::models::model_requires_configured_credential;
4
5impl PiApp {
6 fn normalize_model_key(entry: &ModelEntry) -> (String, String) {
7 let canonical_provider =
8 crate::provider_metadata::canonical_provider_id(entry.model.provider.as_str())
9 .unwrap_or(entry.model.provider.as_str());
10 (
11 canonical_provider.to_ascii_lowercase(),
12 entry.model.id.to_ascii_lowercase(),
13 )
14 }
15
16 fn unique_model_count(models: &[ModelEntry]) -> usize {
17 models
18 .iter()
19 .map(Self::normalize_model_key)
20 .collect::<std::collections::HashSet<_>>()
21 .len()
22 }
23
24 fn available_models_with_credentials(&self) -> Vec<ModelEntry> {
25 let auth = crate::auth::AuthStorage::load(crate::config::Config::auth_path()).ok();
26 let mut provider_has_credential: std::collections::HashMap<String, bool> =
27 std::collections::HashMap::new();
28 let mut filtered = Vec::new();
29 for entry in &self.available_models {
30 let provider = entry.model.provider.as_str();
31 let canonical = crate::provider_metadata::canonical_provider_id(provider)
32 .unwrap_or(provider)
33 .to_ascii_lowercase();
34 let requires_configured_credential = model_requires_configured_credential(entry);
35 let has_inline_key = entry
36 .api_key
37 .as_ref()
38 .is_some_and(|key| !key.trim().is_empty());
39 let has_auth_key = auth.as_ref().is_some_and(|storage| {
40 *provider_has_credential
41 .entry(canonical.clone())
42 .or_insert_with(|| storage.resolve_api_key(&canonical, None).is_some())
43 });
44 if !requires_configured_credential || has_inline_key || has_auth_key {
45 filtered.push(entry.clone());
46 }
47 }
48
49 filtered.sort_by_key(Self::normalize_model_key);
50 filtered.dedup_by(|left, right| model_entry_matches(left, right));
51 filtered
52 }
53
54 pub fn open_model_selector(&mut self) {
56 if self.agent_state != AgentState::Idle {
57 self.status_message = Some("Cannot switch models while processing".to_string());
58 return;
59 }
60
61 if self.available_models.is_empty() {
62 self.status_message = Some("No models available".to_string());
63 return;
64 }
65
66 let mut overlay = crate::model_selector::ModelSelectorOverlay::new(&self.available_models);
67 overlay.set_max_visible(super::overlay_max_visible(self.term_height));
68 self.model_selector = Some(overlay);
69 }
70
71 pub(super) fn open_model_selector_configured_only(&mut self) {
72 if self.agent_state != AgentState::Idle {
73 self.status_message = Some("Cannot switch models while processing".to_string());
74 return;
75 }
76
77 if self.available_models.is_empty() {
78 self.status_message = Some("No models available".to_string());
79 return;
80 }
81
82 let filtered = self.available_models_with_credentials();
83 if filtered.is_empty() {
84 self.status_message = Some(
85 "No models are ready to use. Configure credentials with /login <provider>."
86 .to_string(),
87 );
88 return;
89 }
90
91 let mut overlay = crate::model_selector::ModelSelectorOverlay::new(&filtered);
92 overlay.set_configured_only_scope(Self::unique_model_count(&self.available_models));
93 overlay.set_max_visible(super::overlay_max_visible(self.term_height));
94 self.model_selector = Some(overlay);
95 }
96
97 pub fn handle_model_selector_key(&mut self, key: &KeyMsg) -> Option<Cmd> {
99 let selector = self.model_selector.as_mut()?;
100
101 match key.key_type {
102 KeyType::Up => selector.select_prev(),
103 KeyType::Down => selector.select_next(),
104 KeyType::Runes if key.runes == ['k'] => selector.select_prev(),
105 KeyType::Runes if key.runes == ['j'] => selector.select_next(),
106 KeyType::PgDown => selector.select_page_down(),
107 KeyType::PgUp => selector.select_page_up(),
108 KeyType::Backspace => selector.pop_char(),
109 KeyType::Runes => selector.push_chars(key.runes.iter().copied()),
110 KeyType::Enter => {
111 let selected = selector.selected_item().cloned();
112 self.model_selector = None;
113 if let Some(selected) = selected {
114 self.apply_model_selection(&selected);
115 } else {
116 self.status_message = Some("No model selected".to_string());
117 }
118 return None;
119 }
120 KeyType::Esc | KeyType::CtrlC => {
121 self.model_selector = None;
122 self.status_message = Some("Model selector cancelled".to_string());
123 }
124 _ => {} }
126
127 None
128 }
129
130 fn apply_model_selection(&mut self, selected: &crate::model_selector::ModelKey) {
132 let entry = self
134 .available_models
135 .iter()
136 .find(|e| {
137 e.model.provider.eq_ignore_ascii_case(&selected.provider)
138 && e.model.id.eq_ignore_ascii_case(&selected.id)
139 })
140 .cloned();
141
142 let Some(next) = entry else {
143 self.status_message = Some(format!("Model {} not found", selected.full_id()));
144 return;
145 };
146
147 if model_entry_matches(&next, &self.model_entry) {
148 self.status_message = Some(format!("Already using {}", selected.full_id()));
149 return;
150 }
151
152 let resolved_key_opt = resolve_model_key_from_default_auth(&next);
153 if model_requires_configured_credential(&next) && resolved_key_opt.is_none() {
154 self.status_message = Some(format!(
155 "Missing credentials for provider {}. Run /login {}.",
156 next.model.provider, next.model.provider
157 ));
158 return;
159 }
160
161 let provider_impl = match providers::create_provider(&next, self.extensions.as_ref()) {
162 Ok(p) => p,
163 Err(err) => {
164 self.status_message = Some(err.to_string());
165 return;
166 }
167 };
168
169 if let Err(message) = self.switch_active_model(
170 &next,
171 provider_impl,
172 resolved_key_opt.as_deref(),
173 "selector",
174 ) {
175 self.status_message = Some(message);
176 return;
177 }
178 self.status_message = Some(format!("Switched model: {}", self.model));
179 }
180
181 #[allow(clippy::too_many_lines)]
183 pub(super) fn render_model_selector(
184 &self,
185 selector: &crate::model_selector::ModelSelectorOverlay,
186 ) -> String {
187 use std::fmt::Write;
188 let mut output = String::new();
189
190 let _ = writeln!(output, "\n {}", self.styles.title.render("Select a model"));
191 if selector.configured_only() {
192 let _ = writeln!(
193 output,
194 " {}",
195 self.styles
196 .muted
197 .render("Only showing models that are ready to use (see README for details)")
198 );
199 }
200
201 let query = selector.query();
203 let search_line = if query.is_empty() {
204 if selector.configured_only() {
205 " >".to_string()
206 } else {
207 " > (type to filter)".to_string()
208 }
209 } else {
210 format!(" > {query}")
211 };
212 let _ = writeln!(output, "{}", self.styles.muted.render(&search_line));
213
214 let _ = writeln!(
215 output,
216 " {}",
217 self.styles.muted.render("─".repeat(50).as_str())
218 );
219
220 if selector.filtered_len() == 0 {
221 let _ = writeln!(
222 output,
223 " {}",
224 self.styles.muted_italic.render("No matching models.")
225 );
226 } else {
227 let offset = selector.scroll_offset();
228 let visible_count = selector.max_visible().min(selector.filtered_len());
229 let end = (offset + visible_count).min(selector.filtered_len());
230
231 let current_full = format!(
232 "{}/{}",
233 self.model_entry.model.provider, self.model_entry.model.id
234 );
235
236 for idx in offset..end {
237 let is_selected = idx == selector.selected_index();
238 let prefix = if is_selected { ">" } else { " " };
239
240 if let Some(key) = selector.item_at(idx) {
241 let full = key.full_id();
242 let is_current = full.eq_ignore_ascii_case(¤t_full);
243 let marker = if is_current { " *" } else { "" };
244 let row = format!("{prefix} {full}{marker}");
245 let rendered = if is_selected {
246 self.styles.accent_bold.render(&row)
247 } else if is_current {
248 self.styles.accent.render(&row)
249 } else {
250 self.styles.muted.render(&row)
251 };
252 let _ = writeln!(output, " {rendered}");
253 }
254 }
255
256 if selector.filtered_len() > visible_count {
257 let _ = writeln!(
258 output,
259 " {}",
260 self.styles.muted.render(&format!(
261 "({}-{} of {})",
262 offset + 1,
263 end,
264 selector.filtered_len()
265 ))
266 );
267 }
268
269 if selector.configured_only() {
270 let _ = writeln!(
271 output,
272 " {}",
273 self.styles.muted.render(&format!(
274 "({}/{})",
275 selector.filtered_len(),
276 selector.source_total()
277 ))
278 );
279 }
280
281 if let Some(selected) = selector.selected_item()
282 && let Some(entry) = self.available_models.iter().find(|entry| {
283 entry
284 .model
285 .provider
286 .eq_ignore_ascii_case(&selected.provider)
287 && entry.model.id.eq_ignore_ascii_case(&selected.id)
288 })
289 {
290 let _ = writeln!(
291 output,
292 "\n {}",
293 self.styles
294 .muted
295 .render(&format!("Model Name: {}", entry.model.name))
296 );
297 }
298 }
299
300 let _ = writeln!(
301 output,
302 "\n {}",
303 self.styles
304 .muted_italic
305 .render("↑/↓/j/k/PgUp/PgDn: navigate Enter: select Esc: cancel * = current")
306 );
307 output
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{Agent, AgentConfig};
    use crate::model::{StreamEvent, Usage};
    use crate::provider::{Context, InputType, Model, ModelCost, Provider, StreamOptions};
    use crate::resources::{ResourceCliOptions, ResourceLoader};
    use crate::session::Session;
    use crate::tools::ToolRegistry;
    use asupersync::channel::mpsc;
    use asupersync::runtime::RuntimeBuilder;
    use futures::stream;
    use std::collections::HashMap;
    use std::path::Path;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::OnceLock;

    /// Provider stub whose `stream` yields an empty event stream.
    struct DummyProvider;

    #[async_trait::async_trait]
    impl Provider for DummyProvider {
        fn name(&self) -> &'static str {
            "dummy"
        }

        fn api(&self) -> &'static str {
            "dummy"
        }

        fn model_id(&self) -> &'static str {
            "dummy-model"
        }

        async fn stream(
            &self,
            _context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<
            Pin<Box<dyn futures::Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
        > {
            Ok(Box::pin(stream::empty()))
        }
    }

    /// Lazily builds one process-wide multi-threaded runtime and returns its handle.
    fn runtime_handle() -> asupersync::runtime::RuntimeHandle {
        static RUNTIME: OnceLock<asupersync::runtime::Runtime> = OnceLock::new();
        let runtime = RUNTIME.get_or_init(|| {
            RuntimeBuilder::multi_thread()
                .blocking_threads(1, 8)
                .build()
                .expect("build runtime")
        });
        runtime.handle()
    }

    /// Builds a reasoning-capable, zero-cost text model entry with `auth_header = true`.
    fn model_entry(provider: &str, id: &str, api_key: Option<&str>) -> ModelEntry {
        let zero_cost = ModelCost {
            input: 0.0,
            output: 0.0,
            cache_read: 0.0,
            cache_write: 0.0,
        };
        let model = Model {
            id: id.to_string(),
            name: id.to_string(),
            api: "openai-completions".to_string(),
            provider: provider.to_string(),
            base_url: "https://example.invalid".to_string(),
            reasoning: true,
            input: vec![InputType::Text],
            cost: zero_cost,
            context_window: 128_000,
            max_tokens: 8_192,
            headers: HashMap::new(),
        };
        ModelEntry {
            model,
            api_key: api_key.map(str::to_string),
            headers: HashMap::new(),
            auth_header: true,
            compat: None,
            oauth_config: None,
        }
    }

    /// Constructs a `PiApp` backed by `DummyProvider`, with `current` as the active model
    /// and `available` as the selectable models.
    fn build_test_app(current: ModelEntry, available: Vec<ModelEntry>) -> PiApp {
        let dummy: Arc<dyn Provider> = Arc::new(DummyProvider);
        let tools = ToolRegistry::new(&[], Path::new("."), None);
        let agent = Agent::new(dummy, tools, AgentConfig::default());
        let session = Arc::new(asupersync::sync::Mutex::new(Session::in_memory()));
        let loader = ResourceLoader::empty(false);
        let cli_options = ResourceCliOptions {
            no_skills: false,
            no_prompt_templates: false,
            no_extensions: false,
            no_themes: false,
            skill_paths: Vec::new(),
            prompt_paths: Vec::new(),
            extension_paths: Vec::new(),
            theme_paths: Vec::new(),
        };
        let (event_tx, _event_rx) = mpsc::channel(64);
        let config = Config {
            last_changelog_version: Some(crate::platform::VERSION.to_string()),
            ..Config::default()
        };
        PiApp::new(
            agent,
            session,
            config,
            loader,
            cli_options,
            Path::new(".").to_path_buf(),
            current,
            Vec::new(),
            available,
            Vec::new(),
            event_tx,
            runtime_handle(),
            true,
            false,
            None,
            Some(KeyBindings::new()),
            Vec::new(),
            Usage::default(),
        )
    }

    /// Selector key addressing `entry`'s provider/id pair.
    fn key_for(entry: &ModelEntry) -> crate::model_selector::ModelKey {
        crate::model_selector::ModelKey {
            provider: entry.model.provider.clone(),
            id: entry.model.id.clone(),
        }
    }

    /// Full ids of every filtered row in the open selector.
    ///
    /// Panics with `context` when no selector is open.
    fn selector_ids(app: &PiApp, context: &str) -> Vec<String> {
        let selector = app.model_selector.as_ref().expect(context);
        (0..selector.filtered_len())
            .filter_map(|idx| selector.item_at(idx).map(crate::model_selector::ModelKey::full_id))
            .collect()
    }

    #[test]
    fn apply_model_selection_replaces_stream_options_api_key_and_headers() {
        let current = model_entry("openai", "gpt-4o-mini", Some("old-key"));
        let mut next = model_entry("openrouter", "openai/gpt-4o-mini", Some("next-key"));
        next.headers
            .insert("x-provider-header".to_string(), "next".to_string());
        let mut app = build_test_app(current.clone(), vec![current, next.clone()]);

        {
            let mut agent = app.agent.try_lock().expect("agent lock");
            let options = agent.stream_options_mut();
            options.api_key = Some("stale-token".to_string());
            options
                .headers
                .insert("x-stale".to_string(), "stale".to_string());
        }

        app.apply_model_selection(&key_for(&next));

        let mut agent = app.agent.try_lock().expect("agent lock");
        let options = agent.stream_options_mut();
        assert_eq!(options.api_key.as_deref(), Some("next-key"));
        assert_eq!(
            options.headers.get("x-provider-header").map(String::as_str),
            Some("next")
        );
        assert!(
            !options.headers.contains_key("x-stale"),
            "switching models must replace stale provider headers"
        );
    }

    #[test]
    fn apply_model_selection_clears_stale_api_key_when_next_model_has_no_key() {
        let current = model_entry("openai", "gpt-4o-mini", Some("old-key"));
        let mut next = model_entry("ollama", "llama3.2", None);
        next.auth_header = false;
        let mut app = build_test_app(current.clone(), vec![current, next.clone()]);

        {
            let mut agent = app.agent.try_lock().expect("agent lock");
            agent.stream_options_mut().api_key = Some("stale-token".to_string());
        }

        app.apply_model_selection(&key_for(&next));

        let mut agent = app.agent.try_lock().expect("agent lock");
        assert!(
            agent.stream_options_mut().api_key.is_none(),
            "switching to a keyless model must clear stale API key"
        );
    }

    #[test]
    fn apply_model_selection_clamps_thinking_level_for_non_reasoning_targets() {
        let current = model_entry("openai", "gpt-5.2", Some("old-key"));
        let mut next = model_entry("ollama", "llama3.2", None);
        next.auth_header = false;
        next.model.reasoning = false;
        let mut app = build_test_app(current.clone(), vec![current, next.clone()]);

        {
            let mut agent = app.agent.try_lock().expect("agent lock");
            agent.stream_options_mut().thinking_level = Some(crate::model::ThinkingLevel::High);
        }
        {
            let mut session = app.session.try_lock().expect("session lock");
            session.header.thinking_level =
                Some(crate::model::ThinkingLevel::High.to_string());
        }

        app.apply_model_selection(&key_for(&next));

        {
            let mut agent = app.agent.try_lock().expect("agent lock");
            assert_eq!(
                agent.stream_options_mut().thinking_level,
                Some(crate::model::ThinkingLevel::Off)
            );
        }

        let session = app.session.try_lock().expect("session lock");
        assert_eq!(session.header.thinking_level.as_deref(), Some("off"));
    }

    #[test]
    fn configured_only_selector_includes_keyless_ready_models() {
        let mut keyless = model_entry("ollama", "llama3.2", None);
        keyless.auth_header = false;

        let mut requires_creds = model_entry("acme-remote", "cloud-model", None);
        requires_creds.auth_header = true;

        let mut app = build_test_app(keyless.clone(), vec![keyless, requires_creds]);
        app.open_model_selector_configured_only();

        let ids = selector_ids(
            &app,
            "configured-only selector should open when keyless models are ready",
        );

        assert!(
            ids.iter().any(|id| id == "ollama/llama3.2"),
            "keyless local model must be considered ready"
        );
        assert!(
            !ids.iter().any(|id| id == "acme-remote/cloud-model"),
            "credentialed providers without configured auth should not appear"
        );
    }

    #[test]
    fn configured_only_selector_keeps_unknown_keyless_provider_models() {
        let mut unknown_keyless = model_entry("acme-local", "dev-model", None);
        unknown_keyless.auth_header = false;
        let mut unknown_requires = model_entry("acme-remote", "cloud-model", None);
        unknown_requires.auth_header = true;

        let mut app = build_test_app(
            unknown_keyless.clone(),
            vec![unknown_keyless, unknown_requires],
        );
        app.open_model_selector_configured_only();

        let ids = selector_ids(&app, "unknown keyless model should keep selector available");

        assert!(ids.iter().any(|id| id == "acme-local/dev-model"));
        assert!(!ids.iter().any(|id| id == "acme-remote/cloud-model"));
    }
}
599}