1use anyhow::{Context, Result};
4use crossterm::{
5 event::{self, Event, KeyCode, KeyEvent, KeyModifiers},
6 execute,
7 terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
8};
9use ratatui::{
10 backend::CrosstermBackend,
11 layout::{Alignment, Constraint, Direction, Layout},
12 style::{Color, Modifier, Style},
13 widgets::{Block, Borders, List, ListItem, ListState, Paragraph, Wrap},
14 Frame, Terminal,
15};
16use std::io::{self, Stdout};
17
/// Provider identifiers offered by the wizard, in the order they are listed on
/// the provider-selection screen. `ConfigWizard::selected_provider_idx` indexes
/// into this slice.
const PROVIDERS: &[&str] = &["openai", "anthropic", "openrouter"];
20
/// Model identifiers offered when the selected provider is `"openai"`.
/// The first entry is the one labelled "(recommended)" on the model screen.
const OPENAI_MODELS: &[&str] = &[
    "gpt-5.1",
    "gpt-5.1-mini",
    "gpt-5.1-nano",
    "gpt-5",
    "gpt-5-mini",
    "gpt-5-nano",
];
/// Model identifiers offered when the selected provider is `"anthropic"`.
/// The first entry is the one labelled "(recommended)" on the model screen.
const ANTHROPIC_MODELS: &[&str] = &[
    "claude-sonnet-4-5",
    "claude-haiku-4-5",
    "claude-sonnet-4",
];
35use crate::semantic::providers::openrouter::OpenRouterModel;
36
/// OpenRouter sort strategies as `(value, description)` pairs.
/// The value is what gets saved as `openrouter_sort`; the description is shown
/// next to it in the wizard. The first entry is labelled "(recommended)".
const OPENROUTER_SORT_STRATEGIES: &[(&str, &str)] = &[
    ("price", "Cheapest provider for the model"),
    ("latency", "Fastest response time (lowest latency)"),
    ("throughput", "Highest tokens per second"),
];
43
/// The screens the wizard moves through.
#[derive(Debug, Clone, PartialEq)]
enum WizardScreen {
    /// Choosing one of `PROVIDERS`.
    ProviderSelection,
    /// Typing an API key (or keeping an existing one).
    ApiKeyInput,
    /// Shown while the wizard loop fetches the OpenRouter model list; key input is ignored.
    FetchingModels,
    /// Choosing a model for the selected provider.
    ModelSelection,
    /// Choosing an OpenRouter sort strategy (reached only for the `"openrouter"` provider).
    SortStrategySelection,
    /// Shown while the wizard loop runs the connectivity test; key input is ignored.
    ConnectivityTest,
    /// Final screen carrying the outcome and the message to display.
    Result { success: bool, message: String },
}
55
56fn load_existing_api_key(provider: &str) -> Option<String> {
58 match crate::semantic::config::get_api_key(provider) {
59 Ok(key) => {
60 log::debug!("Found existing API key for {}", provider);
61 Some(key)
62 }
63 Err(_) => {
64 log::debug!("No existing API key found for {}", provider);
65 None
66 }
67 }
68}
69
/// Masks an API key for display so that only a short prefix and suffix stay visible.
///
/// Keys of 11 characters or fewer are replaced entirely by `*` (one per
/// character). Longer keys are rendered as the first 7 characters, `...`, and
/// the last 4 characters.
///
/// All positions are counted in characters rather than bytes, so keys that
/// contain multi-byte UTF-8 characters are masked without panicking.
fn mask_api_key(key: &str) -> String {
    const PREFIX_CHARS: usize = 7;
    const SUFFIX_CHARS: usize = 4;

    let char_count = key.chars().count();
    if char_count <= PREFIX_CHARS + SUFFIX_CHARS {
        return "*".repeat(char_count);
    }

    // Translate character positions into byte offsets that are guaranteed to
    // sit on char boundaries before slicing.
    let prefix_end = key
        .char_indices()
        .nth(PREFIX_CHARS)
        .map_or(key.len(), |(offset, _)| offset);
    let suffix_start = key
        .char_indices()
        .nth(char_count - SUFFIX_CHARS)
        .map_or(key.len(), |(offset, _)| offset);

    format!("{}...{}", &key[..prefix_end], &key[suffix_start..])
}
81
/// State of the interactive configuration wizard.
pub struct ConfigWizard {
    /// Screen currently displayed.
    screen: WizardScreen,
    /// Index into `PROVIDERS` of the highlighted provider.
    selected_provider_idx: usize,
    /// API key text entered so far (or the existing key once the user keeps it).
    api_key: String,
    /// Insertion position within `api_key`; used as a `String` index by the input handler.
    api_key_cursor: usize,
    /// Index of the highlighted entry in the (filtered) model list.
    selected_model_idx: usize,
    /// Index into `OPENROUTER_SORT_STRATEGIES` of the highlighted strategy.
    selected_sort_idx: usize,
    /// Validation error shown on the API key screen, if any.
    error_message: Option<String>,
    /// Previously stored key for the selected provider, loaded when the provider is confirmed.
    existing_api_key: Option<String>,
    /// Models returned by the OpenRouter fetch; empty until a fetch succeeds.
    fetched_models: Vec<OpenRouterModel>,
    /// Case-insensitive substring filter matched against fetched model ids and names.
    model_filter: String,
}
97
98impl ConfigWizard {
99 pub fn new() -> Self {
100 Self {
101 screen: WizardScreen::ProviderSelection,
102 selected_provider_idx: 0,
103 api_key: String::new(),
104 api_key_cursor: 0,
105 selected_model_idx: 0,
106 selected_sort_idx: 0,
107 error_message: None,
108 existing_api_key: None,
109 fetched_models: Vec::new(),
110 model_filter: String::new(),
111 }
112 }
113
114 fn selected_provider(&self) -> &str {
116 PROVIDERS[self.selected_provider_idx]
117 }
118
119 fn static_models(&self) -> &'static [&'static str] {
121 match self.selected_provider() {
122 "openai" => OPENAI_MODELS,
123 "anthropic" => ANTHROPIC_MODELS,
124 _ => &[],
125 }
126 }
127
128 fn filtered_model_ids(&self) -> Vec<String> {
130 if self.selected_provider() == "openrouter" {
131 let filter = self.model_filter.to_lowercase();
132 self.fetched_models
133 .iter()
134 .filter(|m| {
135 if filter.is_empty() {
136 return true;
137 }
138 m.id.to_lowercase().contains(&filter)
139 || m.name.to_lowercase().contains(&filter)
140 })
141 .map(|m| m.id.clone())
142 .collect()
143 } else {
144 self.static_models().iter().map(|s| s.to_string()).collect()
145 }
146 }
147
148 fn selected_sort(&self) -> &str {
150 OPENROUTER_SORT_STRATEGIES[self.selected_sort_idx].0
151 }
152
153 fn selected_model(&self) -> String {
155 let models = self.filtered_model_ids();
156 if self.selected_model_idx < models.len() {
157 models[self.selected_model_idx].clone()
158 } else if !models.is_empty() {
159 models[0].clone()
160 } else {
161 String::new()
162 }
163 }
164
165 fn filtered_openrouter_model(&self, idx: usize) -> Option<&OpenRouterModel> {
167 let filter = self.model_filter.to_lowercase();
168 self.fetched_models
169 .iter()
170 .filter(|m| {
171 if filter.is_empty() {
172 return true;
173 }
174 m.id.to_lowercase().contains(&filter)
175 || m.name.to_lowercase().contains(&filter)
176 })
177 .nth(idx)
178 }
179
180 fn handle_key(&mut self, key: KeyEvent) -> Result<bool> {
182 if key.code == KeyCode::Char('c') && key.modifiers.contains(KeyModifiers::CONTROL) {
184 return Ok(true);
185 }
186
187 match &self.screen {
188 WizardScreen::ProviderSelection => self.handle_provider_selection_key(key),
189 WizardScreen::ApiKeyInput => self.handle_api_key_input_key(key),
190 WizardScreen::FetchingModels => Ok(false), WizardScreen::ModelSelection => self.handle_model_selection_key(key),
192 WizardScreen::SortStrategySelection => self.handle_sort_strategy_key(key),
193 WizardScreen::ConnectivityTest => Ok(false), WizardScreen::Result { .. } => {
195 if key.code == KeyCode::Enter || key.code == KeyCode::Char('q') {
197 return Ok(true);
198 }
199 Ok(false)
200 }
201 }
202 }
203
204 fn handle_provider_selection_key(&mut self, key: KeyEvent) -> Result<bool> {
206 match key.code {
207 KeyCode::Up | KeyCode::Char('k') => {
208 if self.selected_provider_idx > 0 {
209 self.selected_provider_idx -= 1;
210 }
211 }
212 KeyCode::Down | KeyCode::Char('j') => {
213 if self.selected_provider_idx < PROVIDERS.len() - 1 {
214 self.selected_provider_idx += 1;
215 }
216 }
217 KeyCode::Enter => {
218 self.existing_api_key = load_existing_api_key(self.selected_provider());
220
221 self.screen = WizardScreen::ApiKeyInput;
223 self.api_key.clear();
224 self.api_key_cursor = 0;
225 }
226 KeyCode::Esc | KeyCode::Char('q') => {
227 return Ok(true); }
229 _ => {}
230 }
231 Ok(false)
232 }
233
234 fn handle_api_key_input_key(&mut self, key: KeyEvent) -> Result<bool> {
236 match key.code {
237 KeyCode::Char(c) if !key.modifiers.contains(KeyModifiers::CONTROL) => {
238 self.api_key.insert(self.api_key_cursor, c);
239 self.api_key_cursor += 1;
240 }
241 KeyCode::Backspace => {
242 if self.api_key_cursor > 0 {
243 self.api_key_cursor -= 1;
244 self.api_key.remove(self.api_key_cursor);
245 }
246 }
247 KeyCode::Delete => {
248 if self.api_key_cursor < self.api_key.len() {
249 self.api_key.remove(self.api_key_cursor);
250 }
251 }
252 KeyCode::Left => {
253 if self.api_key_cursor > 0 {
254 self.api_key_cursor -= 1;
255 }
256 }
257 KeyCode::Right => {
258 if self.api_key_cursor < self.api_key.len() {
259 self.api_key_cursor += 1;
260 }
261 }
262 KeyCode::Home => {
263 self.api_key_cursor = 0;
264 }
265 KeyCode::End => {
266 self.api_key_cursor = self.api_key.len();
267 }
268 KeyCode::Enter => {
269 if self.api_key.is_empty() {
271 if let Some(ref existing_key) = self.existing_api_key {
272 log::debug!("Keeping existing API key for {}", self.selected_provider());
273 self.api_key = existing_key.clone();
274 self.error_message = None;
275 self.selected_model_idx = 0;
276 self.model_filter.clear();
277 if self.selected_provider() == "openrouter" {
278 self.screen = WizardScreen::FetchingModels;
279 } else {
280 self.screen = WizardScreen::ModelSelection;
281 }
282 } else {
283 self.error_message = Some("API key cannot be empty".to_string());
284 }
285 } else {
286 self.error_message = None;
288 self.selected_model_idx = 0;
289 self.model_filter.clear();
290 if self.selected_provider() == "openrouter" {
291 self.screen = WizardScreen::FetchingModels;
292 } else {
293 self.screen = WizardScreen::ModelSelection;
294 }
295 }
296 }
297 KeyCode::Esc => {
298 self.screen = WizardScreen::ProviderSelection;
300 }
301 _ => {}
302 }
303 Ok(false)
304 }
305
306 fn handle_model_selection_key(&mut self, key: KeyEvent) -> Result<bool> {
308 let is_openrouter = self.selected_provider() == "openrouter";
309 let model_count = self.filtered_model_ids().len();
310
311 match key.code {
312 KeyCode::Up => {
313 if self.selected_model_idx > 0 {
314 self.selected_model_idx -= 1;
315 }
316 }
317 KeyCode::Down => {
318 if model_count > 0 && self.selected_model_idx < model_count - 1 {
319 self.selected_model_idx += 1;
320 }
321 }
322 KeyCode::Char('k') if !is_openrouter => {
323 if self.selected_model_idx > 0 {
324 self.selected_model_idx -= 1;
325 }
326 }
327 KeyCode::Char('j') if !is_openrouter => {
328 if model_count > 0 && self.selected_model_idx < model_count - 1 {
329 self.selected_model_idx += 1;
330 }
331 }
332 KeyCode::Char(c) if is_openrouter && !key.modifiers.contains(KeyModifiers::CONTROL) => {
333 self.model_filter.push(c);
334 self.selected_model_idx = 0;
335 }
336 KeyCode::Backspace if is_openrouter => {
337 self.model_filter.pop();
338 self.selected_model_idx = 0;
339 }
340 KeyCode::Enter => {
341 if model_count == 0 {
342 return Ok(false);
344 }
345 if is_openrouter {
346 self.selected_sort_idx = 0;
347 self.screen = WizardScreen::SortStrategySelection;
348 } else {
349 self.screen = WizardScreen::ConnectivityTest;
350 }
351 }
352 KeyCode::Esc => {
353 self.model_filter.clear();
354 self.screen = WizardScreen::ApiKeyInput;
355 }
356 _ => {}
357 }
358 Ok(false)
359 }
360
361 fn handle_sort_strategy_key(&mut self, key: KeyEvent) -> Result<bool> {
363 match key.code {
364 KeyCode::Up | KeyCode::Char('k') => {
365 if self.selected_sort_idx > 0 {
366 self.selected_sort_idx -= 1;
367 }
368 }
369 KeyCode::Down | KeyCode::Char('j') => {
370 if self.selected_sort_idx < OPENROUTER_SORT_STRATEGIES.len() - 1 {
371 self.selected_sort_idx += 1;
372 }
373 }
374 KeyCode::Enter => {
375 self.screen = WizardScreen::ConnectivityTest;
376 }
377 KeyCode::Esc => {
378 self.screen = WizardScreen::ModelSelection;
380 }
381 _ => {}
382 }
383 Ok(false)
384 }
385
386 fn render(&mut self, frame: &mut Frame) {
388 let screen = self.screen.clone();
390 match &screen {
391 WizardScreen::ProviderSelection => self.render_provider_selection(frame),
392 WizardScreen::ApiKeyInput => self.render_api_key_input(frame),
393 WizardScreen::FetchingModels => self.render_fetching_models(frame),
394 WizardScreen::ModelSelection => self.render_model_selection(frame),
395 WizardScreen::SortStrategySelection => self.render_sort_strategy_selection(frame),
396 WizardScreen::ConnectivityTest => self.render_connectivity_test(frame),
397 WizardScreen::Result { success, message } => {
398 self.render_result(frame, *success, message)
399 }
400 }
401 }
402
403 fn render_provider_selection(&mut self, frame: &mut Frame) {
405 let chunks = Layout::default()
406 .direction(Direction::Vertical)
407 .margin(2)
408 .constraints([
409 Constraint::Length(3),
410 Constraint::Min(0),
411 Constraint::Length(3),
412 ])
413 .split(frame.area());
414
415 let title = Paragraph::new("Reflex AI Configuration Wizard")
417 .style(Style::default().fg(Color::Cyan).add_modifier(Modifier::BOLD))
418 .alignment(Alignment::Center)
419 .block(Block::default().borders(Borders::ALL));
420 frame.render_widget(title, chunks[0]);
421
422 let providers: Vec<ListItem> = PROVIDERS
424 .iter()
425 .map(|provider| {
426 let provider_display = match *provider {
427 "openrouter" => format!("{} (200+ models)", provider),
428 _ => provider.to_string(),
429 };
430
431 ListItem::new(provider_display)
432 })
433 .collect();
434
435 let list = List::new(providers)
436 .block(
437 Block::default()
438 .borders(Borders::ALL)
439 .title("Select AI Provider (↑/↓ to navigate, Enter to select, Esc/q/Ctrl+C to quit)"),
440 )
441 .highlight_style(Style::default().fg(Color::Yellow).add_modifier(Modifier::BOLD))
442 .highlight_symbol("> ");
443
444 let mut list_state = ListState::default().with_selected(Some(self.selected_provider_idx));
445 frame.render_stateful_widget(list, chunks[1], &mut list_state);
446
447 let help = Paragraph::new("Use arrow keys or j/k to navigate, Enter to select, Esc/q/Ctrl+C to quit")
449 .style(Style::default().fg(Color::DarkGray))
450 .alignment(Alignment::Center);
451 frame.render_widget(help, chunks[2]);
452 }
453
454 fn render_api_key_input(&mut self, frame: &mut Frame) {
456 let chunks = Layout::default()
457 .direction(Direction::Vertical)
458 .margin(2)
459 .constraints([
460 Constraint::Length(3),
461 Constraint::Length(5),
462 Constraint::Min(0),
463 Constraint::Length(3),
464 ])
465 .split(frame.area());
466
467 let title = Paragraph::new(format!(
469 "Configure {} API Key",
470 self.selected_provider()
471 ))
472 .style(Style::default().fg(Color::Cyan).add_modifier(Modifier::BOLD))
473 .alignment(Alignment::Center)
474 .block(Block::default().borders(Borders::ALL));
475 frame.render_widget(title, chunks[0]);
476
477 let masked_key = "*".repeat(self.api_key.len());
479 let input_text = if self.api_key_cursor < masked_key.len() {
480 format!("{}█{}", &masked_key[..self.api_key_cursor], &masked_key[self.api_key_cursor..])
481 } else {
482 format!("{}█", masked_key)
483 };
484
485 let input = Paragraph::new(input_text)
486 .style(Style::default().fg(Color::Yellow))
487 .block(
488 Block::default()
489 .borders(Borders::ALL)
490 .title(format!("Enter API Key for {}", self.selected_provider())),
491 );
492 frame.render_widget(input, chunks[1]);
493
494 let message_widget = if let Some(ref error) = self.error_message {
496 Paragraph::new(error.as_str())
497 .style(Style::default().fg(Color::Red))
498 .alignment(Alignment::Center)
499 } else if let Some(ref existing_key) = self.existing_api_key {
500 let masked = mask_api_key(existing_key);
502 Paragraph::new(format!(
503 "Current API key: {}\n\
504 Press Enter to keep existing key, or type a new key to replace it\n\
505 Your API key will be securely stored in ~/.reflex/config.toml",
506 masked
507 ))
508 .style(Style::default().fg(Color::Yellow))
509 .alignment(Alignment::Center)
510 } else {
511 Paragraph::new("Your API key will be securely stored in ~/.reflex/config.toml")
512 .style(Style::default().fg(Color::Green))
513 .alignment(Alignment::Center)
514 };
515 frame.render_widget(message_widget, chunks[2]);
516
517 let help = Paragraph::new("Enter to continue, Esc to go back, Ctrl+C to quit")
519 .style(Style::default().fg(Color::DarkGray))
520 .alignment(Alignment::Center);
521 frame.render_widget(help, chunks[3]);
522 }
523
524 fn render_model_selection(&mut self, frame: &mut Frame) {
526 let is_openrouter = self.selected_provider() == "openrouter";
527 let filtered = self.filtered_model_ids();
528 let model_count = filtered.len();
529
530 let constraints = if is_openrouter {
531 vec![
532 Constraint::Length(3), Constraint::Length(3), Constraint::Min(0), Constraint::Length(3), ]
537 } else {
538 vec![
539 Constraint::Length(3), Constraint::Min(0), Constraint::Length(3), ]
543 };
544
545 let chunks = Layout::default()
546 .direction(Direction::Vertical)
547 .margin(2)
548 .constraints(constraints)
549 .split(frame.area());
550
551 let title_text = if is_openrouter {
553 format!("Select Model for {} ({} models)", self.selected_provider(), model_count)
554 } else {
555 format!("Select Model for {}", self.selected_provider())
556 };
557 let title = Paragraph::new(title_text)
558 .style(Style::default().fg(Color::Cyan).add_modifier(Modifier::BOLD))
559 .alignment(Alignment::Center)
560 .block(Block::default().borders(Borders::ALL));
561 frame.render_widget(title, chunks[0]);
562
563 let (list_chunk, help_chunk) = if is_openrouter {
565 let filter_text = format!("{}█", self.model_filter);
566 let filter_input = Paragraph::new(filter_text)
567 .style(Style::default().fg(Color::Yellow))
568 .block(
569 Block::default()
570 .borders(Borders::ALL)
571 .title("Filter (type to search)"),
572 );
573 frame.render_widget(filter_input, chunks[1]);
574 (chunks[2], chunks[3])
575 } else {
576 (chunks[1], chunks[2])
577 };
578
579 if model_count == 0 && is_openrouter {
581 let empty_msg = Paragraph::new("No models match filter")
582 .style(Style::default().fg(Color::DarkGray))
583 .alignment(Alignment::Center)
584 .block(Block::default().borders(Borders::ALL).title("Models"));
585 frame.render_widget(empty_msg, list_chunk);
586 } else {
587 let model_items: Vec<ListItem> = filtered
588 .iter()
589 .enumerate()
590 .map(|(idx, model_id)| {
591 let model_display = if is_openrouter {
592 if let Some(m) = self.filtered_openrouter_model(idx) {
593 format!("{} ${:.2} / ${:.2} per 1M tokens",
594 model_id, m.prompt_price, m.completion_price)
595 } else {
596 model_id.to_string()
597 }
598 } else if idx == 0 {
599 format!("{} (recommended)", model_id)
600 } else {
601 model_id.to_string()
602 };
603
604 ListItem::new(model_display)
605 })
606 .collect();
607
608 let list_title = if is_openrouter {
609 "Models (↑/↓ to navigate, type to filter, Enter to select, Esc to go back)"
610 } else {
611 "Select Model (↑/↓ to navigate, Enter to select, Esc to go back, Ctrl+C to quit)"
612 };
613 let list = List::new(model_items)
614 .block(
615 Block::default()
616 .borders(Borders::ALL)
617 .title(list_title),
618 )
619 .highlight_style(Style::default().fg(Color::Yellow).add_modifier(Modifier::BOLD))
620 .highlight_symbol("> ");
621
622 let mut list_state = ListState::default().with_selected(Some(self.selected_model_idx));
623 frame.render_stateful_widget(list, list_chunk, &mut list_state);
624 }
625
626 let help_text = if is_openrouter {
628 "Type to filter, ↑/↓ to navigate, Enter to select, Esc to go back, Ctrl+C to quit"
629 } else {
630 "Use arrow keys or j/k to navigate, Enter to select, Esc to go back, Ctrl+C to quit"
631 };
632 let help = Paragraph::new(help_text)
633 .style(Style::default().fg(Color::DarkGray))
634 .alignment(Alignment::Center);
635 frame.render_widget(help, help_chunk);
636 }
637
638 fn render_fetching_models(&mut self, frame: &mut Frame) {
640 let chunks = Layout::default()
641 .direction(Direction::Vertical)
642 .margin(2)
643 .constraints([
644 Constraint::Length(3),
645 Constraint::Min(0),
646 ])
647 .split(frame.area());
648
649 let title = Paragraph::new("Fetching Available Models...")
651 .style(Style::default().fg(Color::Cyan).add_modifier(Modifier::BOLD))
652 .alignment(Alignment::Center)
653 .block(Block::default().borders(Borders::ALL));
654 frame.render_widget(title, chunks[0]);
655
656 let message = Paragraph::new("Loading models from OpenRouter...\n\nPlease wait...")
658 .style(Style::default().fg(Color::Yellow))
659 .alignment(Alignment::Center)
660 .wrap(Wrap { trim: true });
661 frame.render_widget(message, chunks[1]);
662 }
663
664 fn render_sort_strategy_selection(&mut self, frame: &mut Frame) {
666 let chunks = Layout::default()
667 .direction(Direction::Vertical)
668 .margin(2)
669 .constraints([
670 Constraint::Length(3),
671 Constraint::Min(0),
672 Constraint::Length(3),
673 ])
674 .split(frame.area());
675
676 let title = Paragraph::new("Select Provider Sort Strategy (OpenRouter)")
678 .style(Style::default().fg(Color::Cyan).add_modifier(Modifier::BOLD))
679 .alignment(Alignment::Center)
680 .block(Block::default().borders(Borders::ALL));
681 frame.render_widget(title, chunks[0]);
682
683 let strategy_items: Vec<ListItem> = OPENROUTER_SORT_STRATEGIES
685 .iter()
686 .enumerate()
687 .map(|(idx, (name, description))| {
688 let display = if idx == 0 {
689 format!("{} - {} (recommended)", name, description)
690 } else {
691 format!("{} - {}", name, description)
692 };
693
694 ListItem::new(display)
695 })
696 .collect();
697
698 let list = List::new(strategy_items)
699 .block(
700 Block::default()
701 .borders(Borders::ALL)
702 .title("Select Sort Strategy (↑/↓ to navigate, Enter to select, Esc to go back)"),
703 )
704 .highlight_style(Style::default().fg(Color::Yellow).add_modifier(Modifier::BOLD))
705 .highlight_symbol("> ");
706
707 let mut list_state = ListState::default().with_selected(Some(self.selected_sort_idx));
708 frame.render_stateful_widget(list, chunks[1], &mut list_state);
709
710 let help = Paragraph::new("Controls how OpenRouter selects the upstream provider for your chosen model")
712 .style(Style::default().fg(Color::DarkGray))
713 .alignment(Alignment::Center);
714 frame.render_widget(help, chunks[2]);
715 }
716
717 fn render_connectivity_test(&mut self, frame: &mut Frame) {
719 let chunks = Layout::default()
720 .direction(Direction::Vertical)
721 .margin(2)
722 .constraints([
723 Constraint::Length(3),
724 Constraint::Min(0),
725 ])
726 .split(frame.area());
727
728 let title = Paragraph::new("Testing Connection...")
730 .style(Style::default().fg(Color::Cyan).add_modifier(Modifier::BOLD))
731 .alignment(Alignment::Center)
732 .block(Block::default().borders(Borders::ALL));
733 frame.render_widget(title, chunks[0]);
734
735 let message = Paragraph::new(format!(
737 "Testing connection to {}...\n\nPlease wait...",
738 self.selected_provider()
739 ))
740 .style(Style::default().fg(Color::Yellow))
741 .alignment(Alignment::Center)
742 .wrap(Wrap { trim: true });
743 frame.render_widget(message, chunks[1]);
744 }
745
746 fn render_result(&mut self, frame: &mut Frame, success: bool, message: &str) {
748 let chunks = Layout::default()
749 .direction(Direction::Vertical)
750 .margin(2)
751 .constraints([
752 Constraint::Length(3),
753 Constraint::Min(0),
754 Constraint::Length(3),
755 ])
756 .split(frame.area());
757
758 let title = if success {
760 Paragraph::new("Configuration Successful!")
761 .style(Style::default().fg(Color::Green).add_modifier(Modifier::BOLD))
762 } else {
763 Paragraph::new("Configuration Failed")
764 .style(Style::default().fg(Color::Red).add_modifier(Modifier::BOLD))
765 };
766 let title = title.alignment(Alignment::Center).block(Block::default().borders(Borders::ALL));
767 frame.render_widget(title, chunks[0]);
768
769 let message_widget = Paragraph::new(message)
771 .style(if success {
772 Style::default().fg(Color::Green)
773 } else {
774 Style::default().fg(Color::Red)
775 })
776 .alignment(Alignment::Center)
777 .wrap(Wrap { trim: true });
778 frame.render_widget(message_widget, chunks[1]);
779
780 let help = Paragraph::new(if success {
782 "Press Enter, q, or Ctrl+C to exit"
783 } else {
784 "Press Enter, q, or Ctrl+C to exit (configuration not saved)"
785 })
786 .style(Style::default().fg(Color::DarkGray))
787 .alignment(Alignment::Center);
788 frame.render_widget(help, chunks[2]);
789 }
790}
791
792fn setup_terminal() -> Result<Terminal<CrosstermBackend<Stdout>>> {
794 enable_raw_mode().context("Failed to enable raw mode")?;
795 let mut stdout = io::stdout();
796 execute!(stdout, EnterAlternateScreen).context("Failed to enter alternate screen")?;
797 let backend = CrosstermBackend::new(stdout);
798 Terminal::new(backend).context("Failed to create terminal")
799}
800
801fn restore_terminal(terminal: &mut Terminal<CrosstermBackend<Stdout>>) -> Result<()> {
803 disable_raw_mode().context("Failed to disable raw mode")?;
804 execute!(terminal.backend_mut(), LeaveAlternateScreen)
805 .context("Failed to leave alternate screen")?;
806 terminal.show_cursor().context("Failed to show cursor")?;
807 Ok(())
808}
809
810pub fn run_configure_wizard() -> Result<()> {
812 let mut terminal = setup_terminal()?;
813 let mut wizard = ConfigWizard::new();
814
815 let result = run_wizard_loop(&mut terminal, &mut wizard);
816
817 restore_terminal(&mut terminal)?;
819
820 result
821}
822
823fn run_wizard_loop(
825 terminal: &mut Terminal<CrosstermBackend<Stdout>>,
826 wizard: &mut ConfigWizard,
827) -> Result<()> {
828 loop {
829 terminal.draw(|frame| wizard.render(frame))?;
831
832 if wizard.screen == WizardScreen::FetchingModels {
834 let result = fetch_openrouter_models(&wizard.api_key);
835 match result {
836 Ok(models) => {
837 wizard.fetched_models = models;
838 wizard.selected_model_idx = 0;
839 wizard.screen = WizardScreen::ModelSelection;
840 }
841 Err(e) => {
842 wizard.screen = WizardScreen::Result {
843 success: false,
844 message: format!(
845 "Failed to fetch models from OpenRouter: {}\n\n\
846 Please check your API key and try again.",
847 e
848 ),
849 };
850 }
851 }
852 continue;
853 }
854
855 if wizard.screen == WizardScreen::ConnectivityTest {
857 let selected_model = wizard.selected_model();
858 let result = test_connectivity(wizard.selected_provider(), &wizard.api_key);
859 match result {
860 Ok(_) => {
861 let sort = if wizard.selected_provider() == "openrouter" {
863 Some(wizard.selected_sort())
864 } else {
865 None
866 };
867 if let Err(e) = save_user_config(
868 wizard.selected_provider(),
869 &wizard.api_key,
870 &selected_model,
871 sort,
872 ) {
873 wizard.screen = WizardScreen::Result {
874 success: false,
875 message: format!("Failed to save configuration: {}", e),
876 };
877 } else {
878 wizard.screen = WizardScreen::Result {
879 success: true,
880 message: format!(
881 "Configuration saved successfully!\n\n\
882 Provider: {}\n\
883 Config file: ~/.reflex/config.toml\n\n\
884 You can now use 'rfx ask' to query your codebase.",
885 wizard.selected_provider()
886 ),
887 };
888 }
889 }
890 Err(e) => {
891 wizard.screen = WizardScreen::Result {
892 success: false,
893 message: format!(
894 "Connectivity test failed: {}\n\n\
895 Please check your API key and try again.",
896 e
897 ),
898 };
899 }
900 }
901 continue;
902 }
903
904 if event::poll(std::time::Duration::from_millis(100))? {
906 if let Event::Key(key) = event::read()? {
907 let should_exit = wizard.handle_key(key)?;
908 if should_exit {
909 break;
910 }
911 }
912 }
913 }
914
915 Ok(())
916}
917
918fn test_connectivity(provider_name: &str, api_key: &str) -> Result<()> {
920 let runtime = tokio::runtime::Runtime::new()
922 .context("Failed to create async runtime")?;
923
924 runtime.block_on(async {
925 let provider = crate::semantic::providers::create_provider(
927 provider_name,
928 api_key.to_string(),
929 None,
930 None,
931 )?;
932
933 let test_prompt = "Please respond with valid JSON: {\"status\": \"ok\"}";
936
937 provider.complete(test_prompt, true).await?; Ok::<(), anyhow::Error>(())
941 })?;
942
943 Ok(())
944}
945
946fn fetch_openrouter_models(api_key: &str) -> Result<Vec<OpenRouterModel>> {
948 let runtime = tokio::runtime::Runtime::new()
949 .context("Failed to create async runtime")?;
950 runtime.block_on(async {
951 crate::semantic::providers::openrouter::fetch_models(api_key).await
952 })
953}
954
955fn save_user_config(provider: &str, api_key: &str, model: &str, sort: Option<&str>) -> Result<()> {
957 use serde::{Deserialize, Serialize};
958 use std::collections::HashMap;
959 use std::fs;
960
961 #[derive(Debug, Serialize, Deserialize)]
962 struct UserConfig {
963 #[serde(default)]
964 semantic: SemanticSection,
965 #[serde(default)]
966 credentials: HashMap<String, String>,
967 }
968
969 #[derive(Debug, Serialize, Deserialize)]
970 struct SemanticSection {
971 provider: String,
972 }
973
974 impl Default for SemanticSection {
975 fn default() -> Self {
976 Self {
977 provider: "openai".to_string(),
978 }
979 }
980 }
981
982 let home = dirs::home_dir()
983 .ok_or_else(|| anyhow::anyhow!("Could not determine home directory"))?;
984
985 let config_dir = home.join(".reflex");
986 fs::create_dir_all(&config_dir)
987 .context("Failed to create ~/.reflex directory")?;
988
989 let config_path = config_dir.join("config.toml");
990
991 let mut config = if config_path.exists() {
993 let config_str = fs::read_to_string(&config_path)
994 .context("Failed to read existing config file")?;
995 toml::from_str::<UserConfig>(&config_str)
996 .unwrap_or_else(|_| UserConfig {
997 semantic: SemanticSection::default(),
998 credentials: HashMap::new(),
999 })
1000 } else {
1001 UserConfig {
1002 semantic: SemanticSection::default(),
1003 credentials: HashMap::new(),
1004 }
1005 };
1006
1007 config.semantic.provider = provider.to_string();
1009
1010 let key_name = format!("{}_api_key", provider);
1012 let model_name = format!("{}_model", provider);
1013 config.credentials.insert(key_name, api_key.to_string());
1014 config.credentials.insert(model_name, model.to_string());
1015
1016 if let Some(sort_value) = sort {
1018 config.credentials.insert("openrouter_sort".to_string(), sort_value.to_string());
1019 }
1020
1021 let toml_content = toml::to_string_pretty(&config)
1023 .context("Failed to serialize config to TOML")?;
1024
1025 let final_content = format!(
1027 "# Reflex User Configuration\n\
1028 # This file stores your AI provider API keys\n\
1029 # Location: ~/.reflex/config.toml\n\
1030 \n\
1031 {}",
1032 toml_content
1033 );
1034
1035 fs::write(&config_path, final_content)
1036 .context("Failed to write configuration file")?;
1037
1038 log::info!("Configuration saved to {:?}", config_path);
1039
1040 Ok(())
1041}