1use super::commands::{
2 model_entry_matches, model_requires_configured_credential, resolve_model_key_from_default_auth,
3};
4use super::*;
5
6impl PiApp {
7 fn normalize_model_key(entry: &ModelEntry) -> (String, String) {
8 let canonical_provider =
9 crate::provider_metadata::canonical_provider_id(entry.model.provider.as_str())
10 .unwrap_or(entry.model.provider.as_str());
11 (
12 canonical_provider.to_ascii_lowercase(),
13 entry.model.id.to_ascii_lowercase(),
14 )
15 }
16
17 fn unique_model_count(models: &[ModelEntry]) -> usize {
18 models
19 .iter()
20 .map(Self::normalize_model_key)
21 .collect::<std::collections::HashSet<_>>()
22 .len()
23 }
24
25 fn available_models_with_credentials(&self) -> Vec<ModelEntry> {
26 let auth = crate::auth::AuthStorage::load(crate::config::Config::auth_path()).ok();
27 let mut provider_has_credential: std::collections::HashMap<String, bool> =
28 std::collections::HashMap::new();
29 let mut filtered = Vec::new();
30 for entry in &self.available_models {
31 let provider = entry.model.provider.as_str();
32 let canonical = crate::provider_metadata::canonical_provider_id(provider)
33 .unwrap_or(provider)
34 .to_ascii_lowercase();
35 let requires_configured_credential = model_requires_configured_credential(entry);
36 let has_inline_key = entry
37 .api_key
38 .as_ref()
39 .is_some_and(|key| !key.trim().is_empty());
40 let has_auth_key = auth.as_ref().is_some_and(|storage| {
41 *provider_has_credential
42 .entry(canonical.clone())
43 .or_insert_with(|| storage.resolve_api_key(&canonical, None).is_some())
44 });
45 if !requires_configured_credential || has_inline_key || has_auth_key {
46 filtered.push(entry.clone());
47 }
48 }
49
50 filtered.sort_by_key(Self::normalize_model_key);
51 filtered.dedup_by(|left, right| model_entry_matches(left, right));
52 filtered
53 }
54
55 pub fn open_model_selector(&mut self) {
57 if self.agent_state != AgentState::Idle {
58 self.status_message = Some("Cannot switch models while processing".to_string());
59 return;
60 }
61
62 if self.available_models.is_empty() {
63 self.status_message = Some("No models available".to_string());
64 return;
65 }
66
67 self.model_selector = Some(crate::model_selector::ModelSelectorOverlay::new(
68 &self.available_models,
69 ));
70 }
71
72 pub(super) fn open_model_selector_configured_only(&mut self) {
73 if self.agent_state != AgentState::Idle {
74 self.status_message = Some("Cannot switch models while processing".to_string());
75 return;
76 }
77
78 if self.available_models.is_empty() {
79 self.status_message = Some("No models available".to_string());
80 return;
81 }
82
83 let filtered = self.available_models_with_credentials();
84 if filtered.is_empty() {
85 self.status_message = Some(
86 "No models are ready to use. Configure credentials with /login <provider>."
87 .to_string(),
88 );
89 return;
90 }
91
92 let mut overlay = crate::model_selector::ModelSelectorOverlay::new(&filtered);
93 overlay.set_configured_only_scope(Self::unique_model_count(&self.available_models));
94 self.model_selector = Some(overlay);
95 }
96
97 pub fn handle_model_selector_key(&mut self, key: &KeyMsg) -> Option<Cmd> {
99 let selector = self.model_selector.as_mut()?;
100
101 match key.key_type {
102 KeyType::Up => selector.select_prev(),
103 KeyType::Down => selector.select_next(),
104 KeyType::Runes if key.runes == ['k'] => selector.select_prev(),
105 KeyType::Runes if key.runes == ['j'] => selector.select_next(),
106 KeyType::PgDown => selector.select_page_down(),
107 KeyType::PgUp => selector.select_page_up(),
108 KeyType::Backspace => selector.pop_char(),
109 KeyType::Runes => selector.push_chars(key.runes.iter().copied()),
110 KeyType::Enter => {
111 let selected = selector.selected_item().cloned();
112 self.model_selector = None;
113 if let Some(selected) = selected {
114 self.apply_model_selection(&selected);
115 } else {
116 self.status_message = Some("No model selected".to_string());
117 }
118 return None;
119 }
120 KeyType::Esc | KeyType::CtrlC => {
121 self.model_selector = None;
122 self.status_message = Some("Model selector cancelled".to_string());
123 }
124 _ => {} }
126
127 None
128 }
129
130 fn apply_model_selection(&mut self, selected: &crate::model_selector::ModelKey) {
132 let entry = self
134 .available_models
135 .iter()
136 .find(|e| {
137 e.model.provider.eq_ignore_ascii_case(&selected.provider)
138 && e.model.id.eq_ignore_ascii_case(&selected.id)
139 })
140 .cloned();
141
142 let Some(next) = entry else {
143 self.status_message = Some(format!("Model {} not found", selected.full_id()));
144 return;
145 };
146
147 if model_entry_matches(&next, &self.model_entry) {
148 self.status_message = Some(format!("Already using {}", selected.full_id()));
149 return;
150 }
151
152 let resolved_key_opt = resolve_model_key_from_default_auth(&next);
153 if model_requires_configured_credential(&next) && resolved_key_opt.is_none() {
154 self.status_message = Some(format!(
155 "Missing credentials for provider {}. Run /login {}.",
156 next.model.provider, next.model.provider
157 ));
158 return;
159 }
160
161 let provider_impl = match providers::create_provider(&next, self.extensions.as_ref()) {
162 Ok(p) => p,
163 Err(err) => {
164 self.status_message = Some(err.to_string());
165 return;
166 }
167 };
168
169 let Ok(mut agent_guard) = self.agent.try_lock() else {
170 self.status_message = Some("Agent busy; try again".to_string());
171 return;
172 };
173 agent_guard.set_provider(provider_impl);
174 agent_guard
175 .stream_options_mut()
176 .api_key
177 .clone_from(&resolved_key_opt);
178 agent_guard
179 .stream_options_mut()
180 .headers
181 .clone_from(&next.headers);
182 drop(agent_guard);
183
184 let Ok(mut session_guard) = self.session.try_lock() else {
185 self.status_message = Some("Session busy; try again".to_string());
186 return;
187 };
188 session_guard.header.provider = Some(next.model.provider.clone());
189 session_guard.header.model_id = Some(next.model.id.clone());
190 session_guard.append_model_change(next.model.provider.clone(), next.model.id.clone());
191 drop(session_guard);
192 self.spawn_save_session();
193
194 self.model_entry = next.clone();
195 if let Ok(mut guard) = self.model_entry_shared.lock() {
196 *guard = next;
197 }
198 self.model = format!(
199 "{}/{}",
200 self.model_entry.model.provider, self.model_entry.model.id
201 );
202 self.status_message = Some(format!("Switched model: {}", self.model));
203 }
204
205 #[allow(clippy::too_many_lines)]
207 pub(super) fn render_model_selector(
208 &self,
209 selector: &crate::model_selector::ModelSelectorOverlay,
210 ) -> String {
211 use std::fmt::Write;
212 let mut output = String::new();
213
214 let _ = writeln!(output, "\n {}", self.styles.title.render("Select a model"));
215 if selector.configured_only() {
216 let _ = writeln!(
217 output,
218 " {}",
219 self.styles
220 .muted
221 .render("Only showing models that are ready to use (see README for details)")
222 );
223 }
224
225 let query = selector.query();
227 let search_line = if query.is_empty() {
228 if selector.configured_only() {
229 " >".to_string()
230 } else {
231 " > (type to filter)".to_string()
232 }
233 } else {
234 format!(" > {query}")
235 };
236 let _ = writeln!(output, "{}", self.styles.muted.render(&search_line));
237
238 let _ = writeln!(
239 output,
240 " {}",
241 self.styles.muted.render("─".repeat(50).as_str())
242 );
243
244 if selector.filtered_len() == 0 {
245 let _ = writeln!(
246 output,
247 " {}",
248 self.styles.muted_italic.render("No matching models.")
249 );
250 } else {
251 let offset = selector.scroll_offset();
252 let visible_count = selector.max_visible().min(selector.filtered_len());
253 let end = (offset + visible_count).min(selector.filtered_len());
254
255 let current_full = format!(
256 "{}/{}",
257 self.model_entry.model.provider, self.model_entry.model.id
258 );
259
260 for idx in offset..end {
261 let is_selected = idx == selector.selected_index();
262 let prefix = if is_selected { ">" } else { " " };
263
264 if let Some(key) = selector.item_at(idx) {
265 let full = key.full_id();
266 let is_current = full.eq_ignore_ascii_case(¤t_full);
267 let marker = if is_current { " *" } else { "" };
268 let row = format!("{prefix} {full}{marker}");
269 let rendered = if is_selected {
270 self.styles.accent_bold.render(&row)
271 } else if is_current {
272 self.styles.accent.render(&row)
273 } else {
274 self.styles.muted.render(&row)
275 };
276 let _ = writeln!(output, " {rendered}");
277 }
278 }
279
280 if selector.filtered_len() > visible_count {
281 let _ = writeln!(
282 output,
283 " {}",
284 self.styles.muted.render(&format!(
285 "({}-{} of {})",
286 offset + 1,
287 end,
288 selector.filtered_len()
289 ))
290 );
291 }
292
293 if selector.configured_only() {
294 let _ = writeln!(
295 output,
296 " {}",
297 self.styles.muted.render(&format!(
298 "({}/{})",
299 selector.filtered_len(),
300 selector.source_total()
301 ))
302 );
303 }
304
305 if let Some(selected) = selector.selected_item()
306 && let Some(entry) = self.available_models.iter().find(|entry| {
307 entry
308 .model
309 .provider
310 .eq_ignore_ascii_case(&selected.provider)
311 && entry.model.id.eq_ignore_ascii_case(&selected.id)
312 })
313 {
314 let _ = writeln!(
315 output,
316 "\n {}",
317 self.styles
318 .muted
319 .render(&format!("Model Name: {}", entry.model.name))
320 );
321 }
322 }
323
324 let _ = writeln!(
325 output,
326 "\n {}",
327 self.styles
328 .muted_italic
329 .render("↑/↓/j/k: navigate Enter: select Esc: cancel * = current")
330 );
331 output
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{Agent, AgentConfig};
    use crate::model::{StreamEvent, Usage};
    use crate::provider::{Context, InputType, Model, ModelCost, Provider, StreamOptions};
    use crate::resources::{ResourceCliOptions, ResourceLoader};
    use crate::session::Session;
    use crate::tools::ToolRegistry;
    use asupersync::channel::mpsc;
    use asupersync::runtime::RuntimeBuilder;
    use futures::stream;
    use std::collections::HashMap;
    use std::path::Path;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::sync::OnceLock;

    /// Provider stub whose stream is empty; it only exists so an `Agent` can be built.
    struct DummyProvider;

    #[async_trait::async_trait]
    impl Provider for DummyProvider {
        fn name(&self) -> &'static str {
            "dummy"
        }

        fn api(&self) -> &'static str {
            "dummy"
        }

        fn model_id(&self) -> &'static str {
            "dummy-model"
        }

        async fn stream(
            &self,
            _context: &Context<'_>,
            _options: &StreamOptions,
        ) -> crate::error::Result<
            Pin<Box<dyn futures::Stream<Item = crate::error::Result<StreamEvent>> + Send>>,
        > {
            Ok(Box::pin(stream::empty()))
        }
    }

    /// Handle to a multi-threaded runtime that is built once and shared by every test.
    fn runtime_handle() -> asupersync::runtime::RuntimeHandle {
        static SHARED_RUNTIME: OnceLock<asupersync::runtime::Runtime> = OnceLock::new();
        let runtime = SHARED_RUNTIME.get_or_init(|| {
            RuntimeBuilder::multi_thread()
                .blocking_threads(1, 8)
                .build()
                .expect("build runtime")
        });
        runtime.handle()
    }

    /// Builds a zero-cost, text-only model entry for `provider`/`id`.
    ///
    /// `auth_header` defaults to `true`. `api_key` becomes the inline key when given.
    fn model_entry(provider: &str, id: &str, api_key: Option<&str>) -> ModelEntry {
        let zero_cost = ModelCost {
            input: 0.0,
            output: 0.0,
            cache_read: 0.0,
            cache_write: 0.0,
        };
        let model = Model {
            id: id.to_string(),
            name: id.to_string(),
            api: "openai-completions".to_string(),
            provider: provider.to_string(),
            base_url: "https://example.invalid".to_string(),
            reasoning: true,
            input: vec![InputType::Text],
            cost: zero_cost,
            context_window: 128_000,
            max_tokens: 8_192,
            headers: HashMap::new(),
        };
        ModelEntry {
            model,
            api_key: api_key.map(str::to_string),
            headers: HashMap::new(),
            auth_header: true,
            compat: None,
            oauth_config: None,
        }
    }

    /// Creates a `PiApp` with a dummy provider, an in-memory session and empty
    /// resources, using `current` as the active model and `available` as the catalog.
    fn build_test_app(current: ModelEntry, available: Vec<ModelEntry>) -> PiApp {
        let dummy: Arc<dyn Provider> = Arc::new(DummyProvider);
        let tool_registry = ToolRegistry::new(&[], Path::new("."), None);
        let agent = Agent::new(dummy, tool_registry, AgentConfig::default());
        let session = Arc::new(asupersync::sync::Mutex::new(Session::in_memory()));
        let loader = ResourceLoader::empty(false);
        let cli_options = ResourceCliOptions {
            no_skills: false,
            no_prompt_templates: false,
            no_extensions: false,
            no_themes: false,
            skill_paths: Vec::new(),
            prompt_paths: Vec::new(),
            extension_paths: Vec::new(),
            theme_paths: Vec::new(),
        };
        let (event_tx, _event_rx) = mpsc::channel(64);
        let cwd = Path::new(".").to_path_buf();
        PiApp::new(
            agent,
            session,
            Config::default(),
            loader,
            cli_options,
            cwd,
            current,
            Vec::new(),
            available,
            Vec::new(),
            event_tx,
            runtime_handle(),
            true,
            None,
            Some(KeyBindings::new()),
            Vec::new(),
            Usage::default(),
        )
    }

    /// Collects the `provider/id` of every row matched by the open model selector.
    ///
    /// Panics with `missing` when no selector is open.
    fn open_selector_ids(app: &PiApp, missing: &str) -> Vec<String> {
        let selector = app.model_selector.as_ref().expect(missing);
        (0..selector.filtered_len())
            .filter_map(|idx| selector.item_at(idx))
            .map(|item| item.full_id().to_string())
            .collect()
    }

    #[test]
    fn apply_model_selection_replaces_stream_options_api_key_and_headers() {
        let active = model_entry("openai", "gpt-4o-mini", Some("old-key"));
        let mut target = model_entry("openrouter", "openai/gpt-4o-mini", Some("next-key"));
        target
            .headers
            .insert("x-provider-header".to_string(), "next".to_string());

        let mut app = build_test_app(active.clone(), vec![active, target.clone()]);

        // Seed stale values that the switch is expected to overwrite.
        {
            let mut guard = app.agent.try_lock().expect("agent lock");
            let options = guard.stream_options_mut();
            options.api_key = Some("stale-key".to_string());
            options
                .headers
                .insert("x-stale".to_string(), "stale".to_string());
        }

        let selection = crate::model_selector::ModelKey {
            provider: target.model.provider.clone(),
            id: target.model.id,
        };
        app.apply_model_selection(&selection);

        let mut guard = app.agent.try_lock().expect("agent lock");
        let options = guard.stream_options_mut();
        assert_eq!(options.api_key.as_deref(), Some("next-key"));
        assert_eq!(
            options
                .headers
                .get("x-provider-header")
                .map(String::as_str),
            Some("next")
        );
        assert!(
            !options.headers.contains_key("x-stale"),
            "switching models must replace stale provider headers"
        );
    }

    #[test]
    fn apply_model_selection_clears_stale_api_key_when_next_model_has_no_key() {
        let active = model_entry("openai", "gpt-4o-mini", Some("old-key"));
        let mut keyless_target = model_entry("ollama", "llama3.2", None);
        keyless_target.auth_header = false;
        let mut app = build_test_app(active.clone(), vec![active, keyless_target.clone()]);

        {
            let mut guard = app.agent.try_lock().expect("agent lock");
            guard.stream_options_mut().api_key = Some("stale-key".to_string());
        }

        let selection = crate::model_selector::ModelKey {
            provider: keyless_target.model.provider.clone(),
            id: keyless_target.model.id,
        };
        app.apply_model_selection(&selection);

        let mut guard = app.agent.try_lock().expect("agent lock");
        let stale_key_cleared = guard.stream_options_mut().api_key.is_none();
        assert!(
            stale_key_cleared,
            "switching to a keyless model must clear stale API key"
        );
    }

    #[test]
    fn configured_only_selector_includes_keyless_ready_models() {
        let mut local = model_entry("ollama", "llama3.2", None);
        local.auth_header = false;

        let mut remote = model_entry("acme-remote", "cloud-model", None);
        remote.auth_header = true;

        let mut app = build_test_app(local.clone(), vec![local, remote]);
        app.open_model_selector_configured_only();

        let ids = open_selector_ids(
            &app,
            "configured-only selector should open when keyless models are ready",
        );

        assert!(
            ids.iter().any(|id| id == "ollama/llama3.2"),
            "keyless local model must be considered ready"
        );
        assert!(
            !ids.iter().any(|id| id == "acme-remote/cloud-model"),
            "credentialed providers without configured auth should not appear"
        );
    }

    #[test]
    fn configured_only_selector_keeps_unknown_keyless_provider_models() {
        let mut local = model_entry("acme-local", "dev-model", None);
        local.auth_header = false;
        let mut remote = model_entry("acme-remote", "cloud-model", None);
        remote.auth_header = true;

        let mut app = build_test_app(local.clone(), vec![local, remote]);
        app.open_model_selector_configured_only();

        let ids = open_selector_ids(&app, "unknown keyless model should keep selector available");

        assert!(ids.iter().any(|id| id == "acme-local/dev-model"));
        assert!(!ids.iter().any(|id| id == "acme-remote/cloud-model"));
    }
}
585}