use anyhow::{Context, Result};
use crossterm::{
    event::{self, Event, KeyCode, KeyEvent, KeyEventKind, KeyModifiers},
    execute,
    terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
};
use ratatui::{
    backend::CrosstermBackend,
    layout::{Alignment, Constraint, Direction, Layout},
    style::{Color, Modifier, Style},
    widgets::{Block, Borders, List, ListItem, Paragraph, Wrap},
    Frame, Terminal,
};
use std::io::{self, Stdout};
17
/// AI providers offered by the wizard, in the order they are listed on screen.
const PROVIDERS: &[&str] = &["groq", "openai", "anthropic"];

/// Models offered for the `openai` provider; index 0 is labelled "(recommended)".
const OPENAI_MODELS: &[&str] = &[
    "gpt-5.1", "gpt-5.1-mini", "gpt-5.1-nano",
    "gpt-5", "gpt-5-mini", "gpt-5-nano",
];

/// Models offered for the `anthropic` provider; index 0 is labelled "(recommended)".
const ANTHROPIC_MODELS: &[&str] = &["claude-sonnet-4-5", "claude-haiku-4-5", "claude-sonnet-4"];

/// Models offered for the `groq` provider; index 0 is labelled "(recommended)".
const GROQ_MODELS: &[&str] = &[
    "openai/gpt-oss-120b",
    "openai/gpt-oss-20b",
    "meta-llama/llama-4-maverick-17b-128e-instruct",
    "meta-llama/llama-4-scout-17b-16e-instruct",
    "qwen/qwen3-32b",
    "moonshotai/kimi-k2-instruct-0905",
];
43
/// The screen the configuration wizard is currently showing.
///
/// Flow: provider selection → API key input → model selection →
/// connectivity test → result.
#[derive(Debug, Clone, PartialEq)]
enum WizardScreen {
    /// Choose one of `PROVIDERS`.
    ProviderSelection,
    /// Type (or keep) the API key for the chosen provider.
    ApiKeyInput,
    /// Choose one of the provider's models.
    ModelSelection,
    /// Shown while the blocking connectivity test runs.
    ConnectivityTest,
    /// Final outcome; `message` is shown to the user.
    Result {
        success: bool,
        message: String,
    },
}
53
54fn load_existing_api_key(provider: &str) -> Option<String> {
56 match crate::semantic::config::get_api_key(provider) {
57 Ok(key) => {
58 log::debug!("Found existing API key for {}", provider);
59 Some(key)
60 }
61 Err(_) => {
62 log::debug!("No existing API key found for {}", provider);
63 None
64 }
65 }
66}
67
/// Produces a display-safe form of an API key.
///
/// Keys of 11 characters or fewer are fully replaced by `*` (one per
/// character). Longer keys show their first 7 and last 4 characters around
/// `...`, e.g. `"sk-1234...abcd"`.
///
/// Lengths and cut points are measured in characters, not bytes. Slicing
/// at byte offsets would panic on a key containing multi-byte UTF-8.
fn mask_api_key(key: &str) -> String {
    let chars: Vec<char> = key.chars().collect();
    let count = chars.len();

    if count <= 11 {
        return "*".repeat(count);
    }

    let start: String = chars[..7].iter().collect();
    let end: String = chars[count - 4..].iter().collect();
    format!("{}...{}", start, end)
}
79
/// True when `model` names one of the Groq-hosted `openai/gpt-oss-*` models.
fn is_gpt_oss_model(model: &str) -> bool {
    model.strip_prefix("openai/gpt-oss-").is_some()
}
84
85pub struct ConfigWizard {
87 screen: WizardScreen,
88 selected_provider_idx: usize,
89 api_key: String,
90 api_key_cursor: usize,
91 selected_model_idx: usize,
92 error_message: Option<String>,
93 existing_api_key: Option<String>,
94}
95
96impl ConfigWizard {
97 pub fn new() -> Self {
98 Self {
99 screen: WizardScreen::ProviderSelection,
100 selected_provider_idx: 0,
101 api_key: String::new(),
102 api_key_cursor: 0,
103 selected_model_idx: 0,
104 error_message: None,
105 existing_api_key: None,
106 }
107 }
108
109 fn selected_provider(&self) -> &str {
111 PROVIDERS[self.selected_provider_idx]
112 }
113
114 fn available_models(&self) -> &'static [&'static str] {
116 match self.selected_provider() {
117 "openai" => OPENAI_MODELS,
118 "anthropic" => ANTHROPIC_MODELS,
119 "groq" => GROQ_MODELS,
120 _ => &[],
121 }
122 }
123
124 fn selected_model(&self) -> &str {
126 let models = self.available_models();
127 models[self.selected_model_idx]
128 }
129
130 fn handle_key(&mut self, key: KeyEvent) -> Result<bool> {
132 if key.code == KeyCode::Char('c') && key.modifiers.contains(KeyModifiers::CONTROL) {
134 return Ok(true);
135 }
136
137 match &self.screen {
138 WizardScreen::ProviderSelection => self.handle_provider_selection_key(key),
139 WizardScreen::ApiKeyInput => self.handle_api_key_input_key(key),
140 WizardScreen::ModelSelection => self.handle_model_selection_key(key),
141 WizardScreen::ConnectivityTest => Ok(false), WizardScreen::Result { .. } => {
143 if key.code == KeyCode::Enter || key.code == KeyCode::Char('q') {
145 return Ok(true);
146 }
147 Ok(false)
148 }
149 }
150 }
151
152 fn handle_provider_selection_key(&mut self, key: KeyEvent) -> Result<bool> {
154 match key.code {
155 KeyCode::Up | KeyCode::Char('k') => {
156 if self.selected_provider_idx > 0 {
157 self.selected_provider_idx -= 1;
158 }
159 }
160 KeyCode::Down | KeyCode::Char('j') => {
161 if self.selected_provider_idx < PROVIDERS.len() - 1 {
162 self.selected_provider_idx += 1;
163 }
164 }
165 KeyCode::Enter => {
166 self.existing_api_key = load_existing_api_key(self.selected_provider());
168
169 self.screen = WizardScreen::ApiKeyInput;
171 self.api_key.clear();
172 self.api_key_cursor = 0;
173 }
174 KeyCode::Esc | KeyCode::Char('q') => {
175 return Ok(true); }
177 _ => {}
178 }
179 Ok(false)
180 }
181
182 fn handle_api_key_input_key(&mut self, key: KeyEvent) -> Result<bool> {
184 match key.code {
185 KeyCode::Char(c) if !key.modifiers.contains(KeyModifiers::CONTROL) => {
186 self.api_key.insert(self.api_key_cursor, c);
187 self.api_key_cursor += 1;
188 }
189 KeyCode::Backspace => {
190 if self.api_key_cursor > 0 {
191 self.api_key_cursor -= 1;
192 self.api_key.remove(self.api_key_cursor);
193 }
194 }
195 KeyCode::Delete => {
196 if self.api_key_cursor < self.api_key.len() {
197 self.api_key.remove(self.api_key_cursor);
198 }
199 }
200 KeyCode::Left => {
201 if self.api_key_cursor > 0 {
202 self.api_key_cursor -= 1;
203 }
204 }
205 KeyCode::Right => {
206 if self.api_key_cursor < self.api_key.len() {
207 self.api_key_cursor += 1;
208 }
209 }
210 KeyCode::Home => {
211 self.api_key_cursor = 0;
212 }
213 KeyCode::End => {
214 self.api_key_cursor = self.api_key.len();
215 }
216 KeyCode::Enter => {
217 if self.api_key.is_empty() {
219 if let Some(ref existing_key) = self.existing_api_key {
220 log::debug!("Keeping existing API key for {}", self.selected_provider());
221 self.api_key = existing_key.clone();
222 self.error_message = None;
223 self.selected_model_idx = 0;
224 self.screen = WizardScreen::ModelSelection;
225 } else {
226 self.error_message = Some("API key cannot be empty".to_string());
227 }
228 } else {
229 self.error_message = None;
231 self.selected_model_idx = 0;
232 self.screen = WizardScreen::ModelSelection;
233 }
234 }
235 KeyCode::Esc => {
236 self.screen = WizardScreen::ProviderSelection;
238 }
239 _ => {}
240 }
241 Ok(false)
242 }
243
244 fn handle_model_selection_key(&mut self, key: KeyEvent) -> Result<bool> {
246 let models = self.available_models();
247
248 match key.code {
249 KeyCode::Up | KeyCode::Char('k') => {
250 if self.selected_model_idx > 0 {
251 self.selected_model_idx -= 1;
252 }
253 }
254 KeyCode::Down | KeyCode::Char('j') => {
255 if self.selected_model_idx < models.len() - 1 {
256 self.selected_model_idx += 1;
257 }
258 }
259 KeyCode::Enter => {
260 self.screen = WizardScreen::ConnectivityTest;
262 }
263 KeyCode::Esc => {
264 self.screen = WizardScreen::ApiKeyInput;
266 }
267 _ => {}
268 }
269 Ok(false)
270 }
271
272 fn render(&self, frame: &mut Frame) {
274 match &self.screen {
275 WizardScreen::ProviderSelection => self.render_provider_selection(frame),
276 WizardScreen::ApiKeyInput => self.render_api_key_input(frame),
277 WizardScreen::ModelSelection => self.render_model_selection(frame),
278 WizardScreen::ConnectivityTest => self.render_connectivity_test(frame),
279 WizardScreen::Result { success, message } => {
280 self.render_result(frame, *success, message)
281 }
282 }
283 }
284
285 fn render_provider_selection(&self, frame: &mut Frame) {
287 let chunks = Layout::default()
288 .direction(Direction::Vertical)
289 .margin(2)
290 .constraints([
291 Constraint::Length(3),
292 Constraint::Min(0),
293 Constraint::Length(3),
294 ])
295 .split(frame.area());
296
297 let title = Paragraph::new("Reflex AI Configuration Wizard")
299 .style(Style::default().fg(Color::Cyan).add_modifier(Modifier::BOLD))
300 .alignment(Alignment::Center)
301 .block(Block::default().borders(Borders::ALL));
302 frame.render_widget(title, chunks[0]);
303
304 let providers: Vec<ListItem> = PROVIDERS
306 .iter()
307 .enumerate()
308 .map(|(idx, provider)| {
309 let style = if idx == self.selected_provider_idx {
310 Style::default()
311 .fg(Color::Yellow)
312 .add_modifier(Modifier::BOLD)
313 } else {
314 Style::default()
315 };
316
317 let prefix = if idx == self.selected_provider_idx {
318 "> "
319 } else {
320 " "
321 };
322
323 let provider_display = if *provider == "groq" {
324 format!("{} (recommended)", provider)
325 } else {
326 provider.to_string()
327 };
328
329 ListItem::new(format!("{}{}", prefix, provider_display)).style(style)
330 })
331 .collect();
332
333 let list = List::new(providers).block(
334 Block::default()
335 .borders(Borders::ALL)
336 .title("Select AI Provider (↑/↓ to navigate, Enter to select, Esc/q/Ctrl+C to quit)"),
337 );
338 frame.render_widget(list, chunks[1]);
339
340 let help = Paragraph::new("Use arrow keys or j/k to navigate, Enter to select, Esc/q/Ctrl+C to quit")
342 .style(Style::default().fg(Color::DarkGray))
343 .alignment(Alignment::Center);
344 frame.render_widget(help, chunks[2]);
345 }
346
347 fn render_api_key_input(&self, frame: &mut Frame) {
349 let chunks = Layout::default()
350 .direction(Direction::Vertical)
351 .margin(2)
352 .constraints([
353 Constraint::Length(3),
354 Constraint::Length(5),
355 Constraint::Min(0),
356 Constraint::Length(3),
357 ])
358 .split(frame.area());
359
360 let title = Paragraph::new(format!(
362 "Configure {} API Key",
363 self.selected_provider()
364 ))
365 .style(Style::default().fg(Color::Cyan).add_modifier(Modifier::BOLD))
366 .alignment(Alignment::Center)
367 .block(Block::default().borders(Borders::ALL));
368 frame.render_widget(title, chunks[0]);
369
370 let masked_key = "*".repeat(self.api_key.len());
372 let input_text = if self.api_key_cursor < masked_key.len() {
373 format!("{}█{}", &masked_key[..self.api_key_cursor], &masked_key[self.api_key_cursor..])
374 } else {
375 format!("{}█", masked_key)
376 };
377
378 let input = Paragraph::new(input_text)
379 .style(Style::default().fg(Color::Yellow))
380 .block(
381 Block::default()
382 .borders(Borders::ALL)
383 .title(format!("Enter API Key for {}", self.selected_provider())),
384 );
385 frame.render_widget(input, chunks[1]);
386
387 let message_widget = if let Some(ref error) = self.error_message {
389 Paragraph::new(error.as_str())
390 .style(Style::default().fg(Color::Red))
391 .alignment(Alignment::Center)
392 } else if let Some(ref existing_key) = self.existing_api_key {
393 let masked = mask_api_key(existing_key);
395 Paragraph::new(format!(
396 "Current API key: {}\n\
397 Press Enter to keep existing key, or type a new key to replace it\n\
398 Your API key will be securely stored in ~/.reflex/config.toml",
399 masked
400 ))
401 .style(Style::default().fg(Color::Yellow))
402 .alignment(Alignment::Center)
403 } else {
404 Paragraph::new("Your API key will be securely stored in ~/.reflex/config.toml")
405 .style(Style::default().fg(Color::Green))
406 .alignment(Alignment::Center)
407 };
408 frame.render_widget(message_widget, chunks[2]);
409
410 let help = Paragraph::new("Enter to continue, Esc to go back, Ctrl+C to quit")
412 .style(Style::default().fg(Color::DarkGray))
413 .alignment(Alignment::Center);
414 frame.render_widget(help, chunks[3]);
415 }
416
417 fn render_model_selection(&self, frame: &mut Frame) {
419 let chunks = Layout::default()
420 .direction(Direction::Vertical)
421 .margin(2)
422 .constraints([
423 Constraint::Length(3),
424 Constraint::Min(0),
425 Constraint::Length(3),
426 ])
427 .split(frame.area());
428
429 let title = Paragraph::new(format!(
431 "Select Model for {}",
432 self.selected_provider()
433 ))
434 .style(Style::default().fg(Color::Cyan).add_modifier(Modifier::BOLD))
435 .alignment(Alignment::Center)
436 .block(Block::default().borders(Borders::ALL));
437 frame.render_widget(title, chunks[0]);
438
439 let models = self.available_models();
441 let model_items: Vec<ListItem> = models
442 .iter()
443 .enumerate()
444 .map(|(idx, model)| {
445 let style = if idx == self.selected_model_idx {
446 Style::default()
447 .fg(Color::Yellow)
448 .add_modifier(Modifier::BOLD)
449 } else {
450 Style::default()
451 };
452
453 let prefix = if idx == self.selected_model_idx {
454 "> "
455 } else {
456 " "
457 };
458
459 let model_display = if idx == 0 {
461 format!("{} (recommended)", model)
462 } else {
463 model.to_string()
464 };
465
466 ListItem::new(format!("{}{}", prefix, model_display)).style(style)
467 })
468 .collect();
469
470 let list = List::new(model_items).block(
471 Block::default()
472 .borders(Borders::ALL)
473 .title("Select Model (↑/↓ to navigate, Enter to select, Esc to go back, Ctrl+C to quit)"),
474 );
475 frame.render_widget(list, chunks[1]);
476
477 let help = Paragraph::new("Use arrow keys or j/k to navigate, Enter to select, Esc to go back, Ctrl+C to quit")
479 .style(Style::default().fg(Color::DarkGray))
480 .alignment(Alignment::Center);
481 frame.render_widget(help, chunks[2]);
482 }
483
484 fn render_connectivity_test(&self, frame: &mut Frame) {
486 let chunks = Layout::default()
487 .direction(Direction::Vertical)
488 .margin(2)
489 .constraints([
490 Constraint::Length(3),
491 Constraint::Min(0),
492 ])
493 .split(frame.area());
494
495 let title = Paragraph::new("Testing Connection...")
497 .style(Style::default().fg(Color::Cyan).add_modifier(Modifier::BOLD))
498 .alignment(Alignment::Center)
499 .block(Block::default().borders(Borders::ALL));
500 frame.render_widget(title, chunks[0]);
501
502 let message = Paragraph::new(format!(
504 "Testing connection to {}...\n\nPlease wait...",
505 self.selected_provider()
506 ))
507 .style(Style::default().fg(Color::Yellow))
508 .alignment(Alignment::Center)
509 .wrap(Wrap { trim: true });
510 frame.render_widget(message, chunks[1]);
511 }
512
513 fn render_result(&self, frame: &mut Frame, success: bool, message: &str) {
515 let chunks = Layout::default()
516 .direction(Direction::Vertical)
517 .margin(2)
518 .constraints([
519 Constraint::Length(3),
520 Constraint::Min(0),
521 Constraint::Length(3),
522 ])
523 .split(frame.area());
524
525 let title = if success {
527 Paragraph::new("Configuration Successful!")
528 .style(Style::default().fg(Color::Green).add_modifier(Modifier::BOLD))
529 } else {
530 Paragraph::new("Configuration Failed")
531 .style(Style::default().fg(Color::Red).add_modifier(Modifier::BOLD))
532 };
533 let title = title.alignment(Alignment::Center).block(Block::default().borders(Borders::ALL));
534 frame.render_widget(title, chunks[0]);
535
536 let message_widget = Paragraph::new(message)
538 .style(if success {
539 Style::default().fg(Color::Green)
540 } else {
541 Style::default().fg(Color::Red)
542 })
543 .alignment(Alignment::Center)
544 .wrap(Wrap { trim: true });
545 frame.render_widget(message_widget, chunks[1]);
546
547 let help = Paragraph::new(if success {
549 "Press Enter, q, or Ctrl+C to exit"
550 } else {
551 "Press Enter, q, or Ctrl+C to exit (configuration not saved)"
552 })
553 .style(Style::default().fg(Color::DarkGray))
554 .alignment(Alignment::Center);
555 frame.render_widget(help, chunks[2]);
556 }
557}
558
559fn setup_terminal() -> Result<Terminal<CrosstermBackend<Stdout>>> {
561 enable_raw_mode().context("Failed to enable raw mode")?;
562 let mut stdout = io::stdout();
563 execute!(stdout, EnterAlternateScreen).context("Failed to enter alternate screen")?;
564 let backend = CrosstermBackend::new(stdout);
565 Terminal::new(backend).context("Failed to create terminal")
566}
567
568fn restore_terminal(terminal: &mut Terminal<CrosstermBackend<Stdout>>) -> Result<()> {
570 disable_raw_mode().context("Failed to disable raw mode")?;
571 execute!(terminal.backend_mut(), LeaveAlternateScreen)
572 .context("Failed to leave alternate screen")?;
573 terminal.show_cursor().context("Failed to show cursor")?;
574 Ok(())
575}
576
577pub fn run_configure_wizard() -> Result<()> {
579 let mut terminal = setup_terminal()?;
580 let mut wizard = ConfigWizard::new();
581
582 let result = run_wizard_loop(&mut terminal, &mut wizard);
583
584 restore_terminal(&mut terminal)?;
586
587 result
588}
589
590fn run_wizard_loop(
592 terminal: &mut Terminal<CrosstermBackend<Stdout>>,
593 wizard: &mut ConfigWizard,
594) -> Result<()> {
595 loop {
596 terminal.draw(|frame| wizard.render(frame))?;
598
599 if wizard.screen == WizardScreen::ConnectivityTest {
601 let result = test_connectivity(wizard.selected_provider(), &wizard.api_key);
602 match result {
603 Ok(_) => {
604 if let Err(e) = save_user_config(
606 wizard.selected_provider(),
607 &wizard.api_key,
608 wizard.selected_model(),
609 ) {
610 wizard.screen = WizardScreen::Result {
611 success: false,
612 message: format!("Failed to save configuration: {}", e),
613 };
614 } else {
615 wizard.screen = WizardScreen::Result {
616 success: true,
617 message: format!(
618 "Configuration saved successfully!\n\n\
619 Provider: {}\n\
620 Config file: ~/.reflex/config.toml\n\n\
621 You can now use 'rfx ask' to query your codebase.",
622 wizard.selected_provider()
623 ),
624 };
625 }
626 }
627 Err(e) => {
628 wizard.screen = WizardScreen::Result {
629 success: false,
630 message: format!(
631 "Connectivity test failed: {}\n\n\
632 Please check your API key and try again.",
633 e
634 ),
635 };
636 }
637 }
638 continue;
639 }
640
641 if event::poll(std::time::Duration::from_millis(100))? {
643 if let Event::Key(key) = event::read()? {
644 let should_exit = wizard.handle_key(key)?;
645 if should_exit {
646 break;
647 }
648 }
649 }
650 }
651
652 Ok(())
653}
654
655fn test_connectivity(provider_name: &str, api_key: &str) -> Result<()> {
657 let runtime = tokio::runtime::Runtime::new()
659 .context("Failed to create async runtime")?;
660
661 runtime.block_on(async {
662 let provider = crate::semantic::providers::create_provider(
664 provider_name,
665 api_key.to_string(),
666 None,
667 )?;
668
669 let test_prompt = "Please respond with valid JSON: {\"status\": \"ok\"}";
672
673 provider.complete(test_prompt, true).await?; Ok::<(), anyhow::Error>(())
677 })?;
678
679 Ok(())
680}
681
682fn save_user_config(provider: &str, api_key: &str, model: &str) -> Result<()> {
684 use serde::{Deserialize, Serialize};
685 use std::collections::HashMap;
686 use std::fs;
687
688 #[derive(Debug, Serialize, Deserialize)]
689 struct UserConfig {
690 #[serde(default)]
691 semantic: SemanticSection,
692 #[serde(default)]
693 credentials: HashMap<String, String>,
694 }
695
696 #[derive(Debug, Serialize, Deserialize)]
697 struct SemanticSection {
698 provider: String,
699 }
700
701 impl Default for SemanticSection {
702 fn default() -> Self {
703 Self {
704 provider: "openai".to_string(),
705 }
706 }
707 }
708
709 let home = dirs::home_dir()
710 .ok_or_else(|| anyhow::anyhow!("Could not determine home directory"))?;
711
712 let config_dir = home.join(".reflex");
713 fs::create_dir_all(&config_dir)
714 .context("Failed to create ~/.reflex directory")?;
715
716 let config_path = config_dir.join("config.toml");
717
718 let mut config = if config_path.exists() {
720 let config_str = fs::read_to_string(&config_path)
721 .context("Failed to read existing config file")?;
722 toml::from_str::<UserConfig>(&config_str)
723 .unwrap_or_else(|_| UserConfig {
724 semantic: SemanticSection::default(),
725 credentials: HashMap::new(),
726 })
727 } else {
728 UserConfig {
729 semantic: SemanticSection::default(),
730 credentials: HashMap::new(),
731 }
732 };
733
734 config.semantic.provider = provider.to_string();
736
737 let key_name = format!("{}_api_key", provider);
739 let model_name = format!("{}_model", provider);
740 config.credentials.insert(key_name, api_key.to_string());
741 config.credentials.insert(model_name, model.to_string());
742
743 let toml_content = toml::to_string_pretty(&config)
745 .context("Failed to serialize config to TOML")?;
746
747 let final_content = format!(
749 "# Reflex User Configuration\n\
750 # This file stores your AI provider API keys\n\
751 # Location: ~/.reflex/config.toml\n\
752 \n\
753 {}",
754 toml_content
755 );
756
757 fs::write(&config_path, final_content)
758 .context("Failed to write configuration file")?;
759
760 log::info!("Configuration saved to {:?}", config_path);
761
762 Ok(())
763}