1use std::fmt;
8
9use dialoguer::{MultiSelect, Select};
10
11use crate::error::PawError;
12
/// An AI CLI that can be assigned to a branch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CliInfo {
    /// Human-readable name shown in prompts.
    pub display_name: String,
    /// Name of the executable. `TerminalPrompter`'s CLI prompts return this
    /// value for the chosen entry.
    pub binary_name: String,
}

impl fmt::Display for CliInfo {
    /// Shows only the binary name when it equals the display name; otherwise
    /// renders `display name (binary name)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            display_name,
            binary_name,
        } = self;
        if display_name == binary_name {
            write!(f, "{binary_name}")
        } else {
            write!(f, "{display_name} ({binary_name})")
        }
    }
}
36
/// How CLIs are assigned to the selected branches.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CliMode {
    /// A single CLI is chosen and assigned to every selected branch.
    Uniform,
    /// A CLI is chosen separately for each selected branch.
    PerBranch,
}

impl fmt::Display for CliMode {
    /// Writes the label shown for this mode in the mode prompt.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            CliMode::Uniform => "Same CLI for all branches",
            CliMode::PerBranch => "Different CLI per branch",
        };
        f.write_str(label)
    }
}
54
/// Outcome of [`run_selection`]: which CLI runs on which branch.
#[derive(Debug)]
pub struct SelectionResult {
    /// `(branch, cli)` pairs in the order the branches were selected, where
    /// `cli` is the CLI binary name to run on that branch.
    pub mappings: Vec<(String, String)>,
}
61
62pub trait Prompter {
68 fn select_mode(&self) -> Result<CliMode, PawError>;
70
71 fn select_branches(&self, branches: &[String]) -> Result<Vec<String>, PawError>;
73
74 fn select_cli(&self, clis: &[CliInfo]) -> Result<String, PawError>;
76
77 fn select_cli_for_branch(&self, branch: &str, clis: &[CliInfo]) -> Result<String, PawError>;
79}
80
/// [`Prompter`] that asks the user through interactive `dialoguer` prompts.
///
/// It carries no state, so it is cheap to create via `Default` and to copy.
#[derive(Debug, Default, Clone, Copy)]
pub struct TerminalPrompter;
87
88impl Prompter for TerminalPrompter {
89 fn select_mode(&self) -> Result<CliMode, PawError> {
90 let modes = [CliMode::Uniform, CliMode::PerBranch];
91 let labels: Vec<String> = modes.iter().map(ToString::to_string).collect();
92
93 let selection = Select::new()
94 .with_prompt("CLI assignment mode")
95 .items(&labels)
96 .default(0)
97 .interact_opt()
98 .map_err(|e| map_dialoguer_error(&e))?;
99
100 match selection {
101 Some(idx) => Ok(modes[idx]),
102 None => Err(PawError::UserCancelled),
103 }
104 }
105
106 fn select_branches(&self, branches: &[String]) -> Result<Vec<String>, PawError> {
107 let selection = MultiSelect::new()
108 .with_prompt("Select branches (space to toggle, enter to confirm)")
109 .items(branches)
110 .interact_opt()
111 .map_err(|e| map_dialoguer_error(&e))?;
112
113 match selection {
114 Some(indices) if indices.is_empty() => Err(PawError::UserCancelled),
115 Some(indices) => Ok(indices.into_iter().map(|i| branches[i].clone()).collect()),
116 None => Err(PawError::UserCancelled),
117 }
118 }
119
120 fn select_cli(&self, clis: &[CliInfo]) -> Result<String, PawError> {
121 let labels: Vec<String> = clis.iter().map(ToString::to_string).collect();
122
123 let selection = Select::new()
124 .with_prompt("Select AI CLI for all branches")
125 .items(&labels)
126 .default(0)
127 .interact_opt()
128 .map_err(|e| map_dialoguer_error(&e))?;
129
130 match selection {
131 Some(idx) => Ok(clis[idx].binary_name.clone()),
132 None => Err(PawError::UserCancelled),
133 }
134 }
135
136 fn select_cli_for_branch(&self, branch: &str, clis: &[CliInfo]) -> Result<String, PawError> {
137 let labels: Vec<String> = clis.iter().map(ToString::to_string).collect();
138
139 let selection = Select::new()
140 .with_prompt(format!("Select CLI for {branch}"))
141 .items(&labels)
142 .default(0)
143 .interact_opt()
144 .map_err(|e| map_dialoguer_error(&e))?;
145
146 match selection {
147 Some(idx) => Ok(clis[idx].binary_name.clone()),
148 None => Err(PawError::UserCancelled),
149 }
150 }
151}
152
153fn map_dialoguer_error(err: &dialoguer::Error) -> PawError {
156 match err {
157 dialoguer::Error::IO(io_err) if io_err.kind() == std::io::ErrorKind::Interrupted => {
158 PawError::UserCancelled
159 }
160 dialoguer::Error::IO(_) => {
161 PawError::SessionError(format!("Interactive prompt failed: {err}"))
162 }
163 }
164}
165
166pub fn run_selection(
179 prompter: &dyn Prompter,
180 branches: &[String],
181 clis: &[CliInfo],
182 cli_flag: Option<&str>,
183 branches_flag: Option<&[String]>,
184) -> Result<SelectionResult, PawError> {
185 if clis.is_empty() {
186 return Err(PawError::NoCLIsFound);
187 }
188 if branches.is_empty() {
189 return Err(PawError::BranchError("No branches available.".to_string()));
190 }
191
192 let selected_branches = if let Some(flagged) = branches_flag {
194 flagged.to_vec()
195 } else {
196 prompter.select_branches(branches)?
197 };
198
199 let mappings = if let Some(cli) = cli_flag {
201 selected_branches
202 .into_iter()
203 .map(|branch| (branch, cli.to_string()))
204 .collect()
205 } else {
206 let mode = prompter.select_mode()?;
207 match mode {
208 CliMode::Uniform => {
209 let cli = prompter.select_cli(clis)?;
210 selected_branches
211 .into_iter()
212 .map(|branch| (branch, cli.clone()))
213 .collect()
214 }
215 CliMode::PerBranch => {
216 let mut pairs = Vec::with_capacity(selected_branches.len());
217 for branch in selected_branches {
218 let cli = prompter.select_cli_for_branch(&branch, clis)?;
219 pairs.push((branch, cli));
220 }
221 pairs
222 }
223 }
224 };
225
226 Ok(SelectionResult { mappings })
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    use std::cell::Cell;

    /// How many times each `Prompter` method was invoked.
    #[derive(Debug, Default, PartialEq, Eq)]
    struct Calls {
        branches: usize,
        mode: usize,
        cli: usize,
        per_branch: usize,
    }

    /// Scripted [`Prompter`] that returns canned answers and counts every
    /// prompt it is asked to show. This lets tests assert exactly which
    /// prompts a flag combination skips.
    struct TrackingPrompter {
        mode: CliMode,
        branch_indices: Vec<usize>,
        uniform_cli: String,
        per_branch_clis: Vec<String>,
        cancel_on_branch_select: bool,
        cancel_on_cli_select: bool,
        mode_calls: Cell<usize>,
        branch_calls: Cell<usize>,
        cli_calls: Cell<usize>,
        per_branch_call_count: Cell<usize>,
    }

    impl TrackingPrompter {
        /// Shared defaults: no canned CLIs, no cancellation, zeroed counters.
        fn base(mode: CliMode, branch_indices: Vec<usize>) -> Self {
            Self {
                mode,
                branch_indices,
                uniform_cli: String::new(),
                per_branch_clis: Vec::new(),
                cancel_on_branch_select: false,
                cancel_on_cli_select: false,
                mode_calls: Cell::new(0),
                branch_calls: Cell::new(0),
                cli_calls: Cell::new(0),
                per_branch_call_count: Cell::new(0),
            }
        }

        fn uniform(branch_indices: Vec<usize>, cli: &str) -> Self {
            Self {
                uniform_cli: cli.to_string(),
                ..Self::base(CliMode::Uniform, branch_indices)
            }
        }

        fn per_branch(branch_indices: Vec<usize>, clis: Vec<&str>) -> Self {
            Self {
                per_branch_clis: clis.into_iter().map(String::from).collect(),
                ..Self::base(CliMode::PerBranch, branch_indices)
            }
        }

        fn cancel_on_branches() -> Self {
            Self {
                cancel_on_branch_select: true,
                ..Self::base(CliMode::Uniform, vec![])
            }
        }

        fn cancel_on_cli(branch_indices: Vec<usize>) -> Self {
            Self {
                cancel_on_cli_select: true,
                ..Self::base(CliMode::Uniform, branch_indices)
            }
        }

        /// Snapshot of the prompt call counters.
        fn calls(&self) -> Calls {
            Calls {
                branches: self.branch_calls.get(),
                mode: self.mode_calls.get(),
                cli: self.cli_calls.get(),
                per_branch: self.per_branch_call_count.get(),
            }
        }
    }

    /// Increments `counter` and returns its value from before the increment.
    fn bump(counter: &Cell<usize>) -> usize {
        let previous = counter.get();
        counter.set(previous + 1);
        previous
    }

    impl Prompter for TrackingPrompter {
        fn select_mode(&self) -> Result<CliMode, PawError> {
            bump(&self.mode_calls);
            Ok(self.mode)
        }

        fn select_branches(&self, branches: &[String]) -> Result<Vec<String>, PawError> {
            bump(&self.branch_calls);
            // Mirrors `TerminalPrompter`: an empty selection counts as a cancel.
            if self.cancel_on_branch_select || self.branch_indices.is_empty() {
                return Err(PawError::UserCancelled);
            }
            Ok(self
                .branch_indices
                .iter()
                .map(|&i| branches[i].clone())
                .collect())
        }

        fn select_cli(&self, _clis: &[CliInfo]) -> Result<String, PawError> {
            bump(&self.cli_calls);
            if self.cancel_on_cli_select {
                return Err(PawError::UserCancelled);
            }
            Ok(self.uniform_cli.clone())
        }

        fn select_cli_for_branch(
            &self,
            _branch: &str,
            _clis: &[CliInfo],
        ) -> Result<String, PawError> {
            let idx = bump(&self.per_branch_call_count);
            if self.cancel_on_cli_select {
                return Err(PawError::UserCancelled);
            }
            // Running out of canned answers behaves like a cancelled prompt.
            self.per_branch_clis
                .get(idx)
                .cloned()
                .ok_or(PawError::UserCancelled)
        }
    }

    fn test_clis() -> Vec<CliInfo> {
        vec![
            CliInfo {
                display_name: "Alpha CLI".to_string(),
                binary_name: "alpha".to_string(),
            },
            CliInfo {
                display_name: "Beta CLI".to_string(),
                binary_name: "beta".to_string(),
            },
        ]
    }

    fn test_branches() -> Vec<String> {
        vec!["feature/auth".to_string(), "fix/api".to_string()]
    }

    #[test]
    fn both_flags_skips_all_prompts_and_maps_cli_to_all_branches() {
        // Branch prompting would cancel, and every prompt call is counted.
        let prompter = TrackingPrompter::cancel_on_branches();
        let branches = test_branches();
        let clis = test_clis();
        let flag_branches = vec!["feature/auth".to_string(), "fix/api".to_string()];

        let result = run_selection(
            &prompter,
            &branches,
            &clis,
            Some("alpha"),
            Some(&flag_branches),
        )
        .unwrap();

        assert_eq!(
            result.mappings,
            vec![
                ("feature/auth".to_string(), "alpha".to_string()),
                ("fix/api".to_string(), "alpha".to_string()),
            ]
        );
        assert_eq!(prompter.calls(), Calls::default());
    }

    #[test]
    fn cli_flag_skips_cli_prompt_but_prompts_for_branches() {
        let prompter = TrackingPrompter::uniform(vec![0], "should-not-be-used");
        let branches = test_branches();
        let clis = test_clis();

        let result = run_selection(&prompter, &branches, &clis, Some("alpha"), None).unwrap();

        assert_eq!(
            result.mappings,
            vec![("feature/auth".to_string(), "alpha".to_string())]
        );
        assert_eq!(
            prompter.calls(),
            Calls {
                branches: 1,
                ..Calls::default()
            }
        );
    }

    #[test]
    fn branches_flag_skips_branch_prompt_but_prompts_for_cli_uniform() {
        let prompter = TrackingPrompter::uniform(vec![], "beta");
        let branches = test_branches();
        let clis = test_clis();
        let flag_branches = vec!["feature/auth".to_string(), "fix/api".to_string()];

        let result =
            run_selection(&prompter, &branches, &clis, None, Some(&flag_branches)).unwrap();

        assert_eq!(
            result.mappings,
            vec![
                ("feature/auth".to_string(), "beta".to_string()),
                ("fix/api".to_string(), "beta".to_string()),
            ]
        );
        assert_eq!(
            prompter.calls(),
            Calls {
                mode: 1,
                cli: 1,
                ..Calls::default()
            }
        );
    }

    #[test]
    fn uniform_mode_maps_same_cli_to_all_selected_branches() {
        let prompter = TrackingPrompter::uniform(vec![0, 1], "alpha");
        let branches = test_branches();
        let clis = test_clis();

        let result = run_selection(&prompter, &branches, &clis, None, None).unwrap();

        assert_eq!(
            result.mappings,
            vec![
                ("feature/auth".to_string(), "alpha".to_string()),
                ("fix/api".to_string(), "alpha".to_string()),
            ]
        );
        assert_eq!(
            prompter.calls(),
            Calls {
                branches: 1,
                mode: 1,
                cli: 1,
                per_branch: 0,
            }
        );
    }

    #[test]
    fn per_branch_mode_maps_different_cli_to_each_branch() {
        let prompter = TrackingPrompter::per_branch(vec![0, 1], vec!["alpha", "beta"]);
        let branches = test_branches();
        let clis = test_clis();

        let result = run_selection(&prompter, &branches, &clis, None, None).unwrap();

        assert_eq!(
            result.mappings,
            vec![
                ("feature/auth".to_string(), "alpha".to_string()),
                ("fix/api".to_string(), "beta".to_string()),
            ]
        );
        assert_eq!(
            prompter.calls(),
            Calls {
                branches: 1,
                mode: 1,
                cli: 0,
                per_branch: 2,
            }
        );
    }

    #[test]
    fn per_branch_mode_with_branches_flag() {
        let prompter = TrackingPrompter::per_branch(vec![], vec!["beta", "alpha"]);
        let branches = test_branches();
        let clis = test_clis();
        let flag_branches = vec!["feature/auth".to_string(), "fix/api".to_string()];

        let result =
            run_selection(&prompter, &branches, &clis, None, Some(&flag_branches)).unwrap();

        assert_eq!(
            result.mappings,
            vec![
                ("feature/auth".to_string(), "beta".to_string()),
                ("fix/api".to_string(), "alpha".to_string()),
            ]
        );
        assert_eq!(
            prompter.calls(),
            Calls {
                mode: 1,
                per_branch: 2,
                ..Calls::default()
            }
        );
    }

    #[test]
    fn no_clis_available_returns_error() {
        let prompter = TrackingPrompter::cancel_on_branches();
        let branches = test_branches();
        let clis: Vec<CliInfo> = vec![];

        let result = run_selection(&prompter, &branches, &clis, None, None);

        assert!(matches!(result, Err(PawError::NoCLIsFound)));
        assert_eq!(prompter.calls(), Calls::default());
    }

    #[test]
    fn no_branches_available_returns_error() {
        let prompter = TrackingPrompter::cancel_on_branches();
        let branches: Vec<String> = vec![];
        let clis = test_clis();

        let result = run_selection(&prompter, &branches, &clis, None, None);

        assert!(matches!(result, Err(PawError::BranchError(_))));
        assert_eq!(prompter.calls(), Calls::default());
    }

    #[test]
    fn user_cancels_branch_selection_returns_cancelled() {
        let prompter = TrackingPrompter::cancel_on_branches();
        let branches = test_branches();
        let clis = test_clis();

        let result = run_selection(&prompter, &branches, &clis, None, None);

        assert!(matches!(result, Err(PawError::UserCancelled)));
        assert_eq!(
            prompter.calls(),
            Calls {
                branches: 1,
                ..Calls::default()
            }
        );
    }

    #[test]
    fn user_selects_no_branches_returns_cancelled() {
        let prompter = TrackingPrompter::uniform(vec![], "alpha");
        let branches = test_branches();
        let clis = test_clis();

        let result = run_selection(&prompter, &branches, &clis, None, None);

        assert!(matches!(result, Err(PawError::UserCancelled)));
    }

    #[test]
    fn user_cancels_cli_selection_returns_cancelled() {
        let prompter = TrackingPrompter::cancel_on_cli(vec![0]);
        let branches = test_branches();
        let clis = test_clis();

        let result = run_selection(&prompter, &branches, &clis, None, None);

        assert!(matches!(result, Err(PawError::UserCancelled)));
        assert_eq!(
            prompter.calls(),
            Calls {
                branches: 1,
                mode: 1,
                cli: 1,
                per_branch: 0,
            }
        );
    }

    #[test]
    fn user_cancels_per_branch_cli_selection_returns_cancelled() {
        // Only one canned answer for two branches: the second prompt cancels.
        let prompter = TrackingPrompter::per_branch(vec![0, 1], vec!["alpha"]);
        let branches = test_branches();
        let clis = test_clis();

        let result = run_selection(&prompter, &branches, &clis, None, None);

        assert!(matches!(result, Err(PawError::UserCancelled)));
        assert_eq!(
            prompter.calls(),
            Calls {
                branches: 1,
                mode: 1,
                cli: 0,
                per_branch: 2,
            }
        );
    }

    #[test]
    fn selecting_subset_of_branches_works() {
        let prompter = TrackingPrompter::uniform(vec![1], "alpha");
        let branches = test_branches();
        let clis = test_clis();

        let result = run_selection(&prompter, &branches, &clis, None, None).unwrap();

        assert_eq!(
            result.mappings,
            vec![("fix/api".to_string(), "alpha".to_string())]
        );
    }

    #[test]
    fn cli_mode_display() {
        assert_eq!(CliMode::Uniform.to_string(), "Same CLI for all branches");
        assert_eq!(CliMode::PerBranch.to_string(), "Different CLI per branch");
    }

    #[test]
    fn cli_info_display_same_names() {
        let info = CliInfo {
            display_name: "claude".to_string(),
            binary_name: "claude".to_string(),
        };
        assert_eq!(info.to_string(), "claude");
    }

    #[test]
    fn cli_info_display_different_names() {
        let info = CliInfo {
            display_name: "My Agent".to_string(),
            binary_name: "my-agent".to_string(),
        };
        assert_eq!(info.to_string(), "My Agent (my-agent)");
    }
}