use anyhow::Result;
use inquire::{Select, Text, validator::Validation};
use std::error::Error;
use std::fmt;
use std::path::PathBuf;
use crate::git::GitRepo;
/// Signature of a validation callback that can be attached to a text prompt.
///
/// The callback receives the input as `&str` and returns either a
/// [`Validation`] verdict or a boxed, thread-safe error.
pub type ValidatorFn = fn(&str) -> Result<Validation, Box<dyn Error + Send + Sync>>;
/// One row of a grouped git-reference list.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum GitRefOption {
    /// A reference row: `name` is the value the providers in this file hand
    /// back on selection, `display` is the text rendered by `Display`.
    Reference { name: String, display: String },
    /// A divider row; the payload is its label and may be empty.
    Separator(String),
}

impl fmt::Display for GitRefOption {
    /// Renders a reference as its `display` text, a labelled separator as
    /// `─── label ───`, and an unlabelled separator as the empty string.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            GitRefOption::Reference { display, .. } => return f.write_str(display),
            GitRefOption::Separator(label) => label,
        };
        if label.is_empty() {
            // Unlabelled separators produce no output at all.
            return Ok(());
        }
        write!(f, "─── {} ───", label)
    }
}
/// Abstraction over the prompts used by this module, so callers can be driven
/// either by a real interactive UI or by a canned response.
pub trait SelectionProvider {
/// Asks for one entry out of `options` and returns the chosen entry.
fn select(&self, prompt: &str, options: Vec<String>) -> Result<String>;
/// Asks for one reference out of a list that mixes
/// `GitRefOption::Reference` rows with `GitRefOption::Separator` rows.
/// The implementations in this file return the chosen reference's `name`.
fn select_grouped(&self, prompt: &str, options: Vec<GitRefOption>) -> Result<String>;
/// Asks for free-form text; `validator`, when present, is the validation
/// callback for that input.
fn get_text_input(&self, prompt: &str, validator: Option<ValidatorFn>) -> Result<String>;
}
/// [`SelectionProvider`] backed by interactive `inquire` prompts.
pub struct RealSelectionProvider;

impl SelectionProvider for RealSelectionProvider {
    /// Shows a single-choice list (page size 10, vim mode enabled) and
    /// returns the chosen entry.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the prompt.
    fn select(&self, prompt: &str, options: Vec<String>) -> Result<String> {
        let selection = Select::new(prompt, options)
            .with_page_size(10)
            .with_vim_mode(true)
            .prompt()?;
        Ok(selection)
    }

    /// Picks a reference name from a separator-delimited list.
    ///
    /// References are collected into groups: every separator closes the
    /// group collected so far, and a non-empty separator label names the
    /// groups that follow it. With exactly one group the references are
    /// offered directly under `prompt`; otherwise the user first picks a
    /// category and then a reference inside it. Only `name` is shown; the
    /// `display` field of a reference is not used here.
    ///
    /// # Errors
    ///
    /// Fails if `options` contains no references, if a prompt fails, or if
    /// the chosen category cannot be matched back to a group.
    fn select_grouped(&self, prompt: &str, options: Vec<GitRefOption>) -> Result<String> {
        /// Menu text for one category; kept in one place because the same
        /// text is used to map the user's choice back to its group.
        fn category_entry(label: &str, count: usize) -> String {
            format!("{} ({} items)", label, count)
        }

        let mut groups: Vec<(String, Vec<String>)> = Vec::new();
        let mut current_label = String::new();
        let mut current_refs: Vec<String> = Vec::new();

        for option in options {
            match option {
                GitRefOption::Separator(label) => {
                    if !current_refs.is_empty() {
                        // Move the collected names out instead of cloning them.
                        groups.push((current_label.clone(), std::mem::take(&mut current_refs)));
                    }
                    // An unlabelled separator only closes the group; the
                    // previous label stays in effect.
                    if !label.is_empty() {
                        current_label = label;
                    }
                }
                GitRefOption::Reference { name, .. } => current_refs.push(name),
            }
        }
        if !current_refs.is_empty() {
            groups.push((current_label, current_refs));
        }

        // Without this guard an empty category list would be sent to the prompt.
        if groups.is_empty() {
            anyhow::bail!("No selectable references were provided");
        }

        // A single category needs no intermediate menu.
        if groups.len() == 1 {
            let (_, refs) = groups.remove(0);
            return self.select(prompt, refs);
        }

        let menu: Vec<String> = groups
            .iter()
            .map(|(label, refs)| category_entry(label, refs.len()))
            .collect();
        let selected_group = self.select("Choose a category:", menu)?;

        // Groups with identical label and size render identically; the first
        // one that matches wins.
        let chosen = groups
            .into_iter()
            .find(|(label, refs)| category_entry(label, refs.len()) == selected_group);
        match chosen {
            Some((label, refs)) => self.select(&format!("Choose from {}:", label), refs),
            None => anyhow::bail!("Selected group not found"),
        }
    }

    /// Prompts for a line of text, attaching `validator` to the prompt when
    /// one is supplied.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the prompt.
    fn get_text_input(&self, prompt: &str, validator: Option<ValidatorFn>) -> Result<String> {
        let mut text_prompt = Text::new(prompt);
        if let Some(validation_fn) = validator {
            text_prompt = text_prompt.with_validator(validation_fn);
        }
        let result = text_prompt.prompt()?;
        Ok(result)
    }
}
/// Non-interactive stand-in for a selection provider that carries a single
/// canned answer.
pub struct MockSelectionProvider {
    /// The canned answer.
    pub response: String,
}

impl MockSelectionProvider {
    /// Builds a mock whose canned answer is `response`.
    pub fn new(response: impl Into<String>) -> Self {
        let response = response.into();
        Self { response }
    }
}
impl SelectionProvider for MockSelectionProvider {
    /// Returns the canned response when it is one of `options`; errors
    /// otherwise. The prompt text is ignored.
    fn select(&self, _prompt: &str, options: Vec<String>) -> Result<String> {
        let wanted = &self.response;
        if !options.iter().any(|candidate| candidate == wanted) {
            anyhow::bail!("Mock response '{}' not found in options", self.response);
        }
        Ok(wanted.clone())
    }

    /// Returns the canned response when it equals the `name` of one of the
    /// `Reference` rows; separators never match. Errors otherwise.
    fn select_grouped(&self, _prompt: &str, options: Vec<GitRefOption>) -> Result<String> {
        let wanted = &self.response;
        let is_offered = options.iter().any(|option| match option {
            GitRefOption::Reference { name, .. } => name == wanted,
            GitRefOption::Separator(_) => false,
        });
        if !is_offered {
            anyhow::bail!(
                "Mock response '{}' not found in grouped options",
                self.response
            );
        }
        Ok(wanted.clone())
    }

    /// Returns the canned response unconditionally; the validator is not run.
    fn get_text_input(&self, _prompt: &str, _validator: Option<ValidatorFn>) -> Result<String> {
        Ok(self.response.to_owned())
    }
}
/// Extracts the path from a selection formatted as `<label> (<path>)`.
///
/// The path is the text between the last `" ("` in `selection` and the `)`
/// that ends the string.
///
/// # Errors
///
/// Returns an error if `selection` contains no `" ("` marker or does not end
/// with `)`. Checking for the closing parenthesis explicitly avoids slicing
/// by `len() - 1`, which panics when the string ends right after the marker
/// or in a multi-byte character.
pub fn extract_path_from_selection(selection: &str) -> Result<PathBuf> {
    let path_str = selection
        .rsplit_once(" (")
        .and_then(|(_, rest)| rest.strip_suffix(')'));
    match path_str {
        Some(path_str) => Ok(PathBuf::from(path_str)),
        None => anyhow::bail!("Invalid selection format: {}", selection),
    }
}
/// Extracts the branch from a selection formatted as `<prefix>/<branch> (<path>)`.
///
/// The label is everything before the last `" ("`; the returned branch is
/// the part of that label after its last `/`.
///
/// NOTE(review): because the split is on the *last* `/`, a label such as
/// `repo/feature/x` yields `x`. Confirm against the code that builds these
/// labels whether branch names containing `/` can occur.
///
/// # Errors
///
/// Returns an error if `selection` has no `" ("` marker or the label has no `/`.
pub fn extract_branch_from_selection(selection: &str) -> Result<String> {
    let branch = selection
        .rsplit_once(" (")
        .and_then(|(label, _)| label.rsplit_once('/'))
        .map(|(_, branch)| branch);
    match branch {
        Some(branch) => Ok(branch.to_string()),
        None => anyhow::bail!("Invalid selection format: {}", selection),
    }
}
/// Lets the user pick the git reference a worktree is created from.
///
/// Local branches, remote branches and tags are read from `git_repo` and
/// arranged into labelled sections ("Local Branches", "Remote Branches",
/// "Tags"), with an unlabelled separator between consecutive sections. Empty
/// sections are left out. The list is then handed to
/// `provider.select_grouped`, whose result is returned.
///
/// # Errors
///
/// Fails if any of the three listings fails, if all of them are empty, or if
/// the provider's selection fails.
pub fn select_git_reference_interactive(
    git_repo: &GitRepo,
    provider: &dyn SelectionProvider,
) -> Result<String> {
    let local_branches = git_repo.list_local_branches()?;
    let remote_branches = git_repo.list_remote_branches()?;
    let tags = git_repo.list_tags()?;

    let sections = [
        ("Local Branches", &local_branches),
        ("Remote Branches", &remote_branches),
        ("Tags", &tags),
    ];
    if sections.iter().all(|(_, refs)| refs.is_empty()) {
        anyhow::bail!("No git references found");
    }

    // At least one section is non-empty here, so `options` cannot end up
    // empty and needs no second check after the loop.
    let mut options = Vec::new();
    for (label, refs) in sections {
        if refs.is_empty() {
            continue;
        }
        if !options.is_empty() {
            // Unlabelled separator between consecutive sections.
            options.push(GitRefOption::Separator(String::new()));
        }
        options.push(GitRefOption::Separator(label.to_string()));
        options.extend(refs.iter().map(|name| GitRefOption::Reference {
            name: name.clone(),
            display: format!(" {}", name),
        }));
    }

    provider.select_grouped("Select git reference to create worktree from:", options)
}
/// Extracts the reference from a selection formatted as
/// `<reference> (<annotation>)`.
///
/// The reference is everything before the first `" ("`.
///
/// # Errors
///
/// Returns an error if `selection` contains no `" ("` marker.
pub fn extract_reference_from_selection(selection: &str) -> Result<String> {
    match selection.split_once(" (") {
        Some((reference, _annotation)) => Ok(reference.to_string()),
        None => anyhow::bail!("Invalid reference selection format: {}", selection),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Reference` row whose display text is the name prefixed with
    /// a single space.
    fn reference(name: &str) -> GitRefOption {
        GitRefOption::Reference {
            name: name.to_string(),
            display: format!(" {}", name),
        }
    }

    /// The two plain options offered to the mock in the `select` tests.
    fn two_options() -> Vec<String> {
        vec![String::from("option1"), String::from("option2")]
    }

    #[test]
    fn test_mock_selection_provider_valid_response() {
        let provider = MockSelectionProvider::new("option1");
        let picked = provider.select("Test prompt", two_options());
        assert_eq!(picked.ok().as_deref(), Some("option1"));
    }

    #[test]
    fn test_mock_selection_provider_invalid_response() {
        let provider = MockSelectionProvider::new("invalid");
        assert!(provider.select("Test prompt", two_options()).is_err());
    }

    #[test]
    fn test_extract_path_from_selection() {
        let path = extract_path_from_selection("repo/branch (/some/path)").ok();
        assert_eq!(path, Some(PathBuf::from("/some/path")));
    }

    #[test]
    fn test_extract_branch_from_selection() {
        let branch = extract_branch_from_selection("repo/feature-branch (/some/path)").ok();
        assert_eq!(branch.as_deref(), Some("feature-branch"));
    }

    #[test]
    fn test_extract_from_invalid_selection() {
        let invalid_selection = "invalid format";
        assert!(extract_path_from_selection(invalid_selection).is_err());
        assert!(extract_branch_from_selection(invalid_selection).is_err());
    }

    #[test]
    fn test_extract_reference_from_selection() {
        let cases = [
            ("main (local branch)", "main"),
            ("origin/feature (remote branch)", "origin/feature"),
            ("v1.0.0 (tag)", "v1.0.0"),
        ];
        for &(selection, expected) in cases.iter() {
            let extracted = extract_reference_from_selection(selection).ok();
            assert_eq!(extracted.as_deref(), Some(expected));
        }
    }

    #[test]
    fn test_extract_reference_from_invalid_selection() {
        assert!(extract_reference_from_selection("invalid format").is_err());
    }

    #[test]
    fn test_git_ref_option_formatting() {
        let local_ref = reference("main");
        let separator = GitRefOption::Separator("Local Branches".to_string());
        let empty_separator = GitRefOption::Separator(String::new());

        assert_eq!(local_ref.to_string(), " main");
        assert_eq!(separator.to_string(), "─── Local Branches ───");
        assert_eq!(empty_separator.to_string(), "");

        match local_ref {
            GitRefOption::Reference { name, .. } => assert_eq!(name, "main"),
            GitRefOption::Separator(_) => unreachable!("Expected Reference variant"),
        }
    }

    #[test]
    fn test_select_grouped_functionality() {
        let options = vec![
            GitRefOption::Separator("Local Branches".to_string()),
            reference("main"),
            reference("feature"),
            GitRefOption::Separator(String::new()),
            GitRefOption::Separator("Remote Branches".to_string()),
            reference("origin/develop"),
        ];
        let provider = MockSelectionProvider::new("main");
        let result = provider.select_grouped("Choose base for new branch:", options);
        match &result {
            Ok(selected) => assert_eq!(selected, "main"),
            Err(_) => unreachable!("Selection should succeed in test: {:?}", result),
        }
    }

    #[test]
    fn test_git_ref_option_extraction() {
        if let GitRefOption::Reference { name, .. } = reference("origin/feature") {
            assert_eq!(name, "origin/feature");
        } else {
            unreachable!("Expected Reference, got Separator");
        }

        if let GitRefOption::Separator(label) =
            GitRefOption::Separator("Remote Branches".to_string())
        {
            assert_eq!(label, "Remote Branches");
        } else {
            unreachable!("Expected Separator, got Reference");
        }
    }
}