use std::cell::RefCell;
use std::collections::BTreeMap;
use std::ffi::{OsStr, OsString};
use std::io::Write;
use clap::Command;
use clap_complete::engine::{ArgValueCompleter, CompletionCandidate, ValueCompleter};
use clap_complete::env::CompleteEnv;
use crate::cli;
use crate::display::format_relative_time_short;
use worktrunk::config::{CommandConfig, ProjectConfig, UserConfig, append_aliases};
use worktrunk::git::{BranchCategory, HookType, Repository};
pub(crate) fn maybe_handle_env_completion() -> bool {
let Some(shell_name) = std::env::var_os("COMPLETE") else {
return false;
};
if shell_name.is_empty() || shell_name == "0" {
return false;
}
let mut args: Vec<OsString> = std::env::args_os().collect();
CONTEXT.with(|ctx| *ctx.borrow_mut() = Some(CompletionContext { args: args.clone() }));
args.remove(0);
let escape_index = args
.iter()
.position(|a| *a == "--")
.map(|i| i + 1)
.unwrap_or(args.len());
args.drain(0..escape_index);
let current_dir = std::env::current_dir().ok();
if args.is_empty() {
let all_args: Vec<OsString> = std::env::args_os().collect();
let _ = CompleteEnv::with_factory(completion_command)
.try_complete(all_args, current_dir.as_deref());
CONTEXT.with(|ctx| ctx.borrow_mut().take());
return true;
}
let mut cmd = completion_command();
cmd.build();
let index: usize = std::env::var("_CLAP_COMPLETE_INDEX")
.ok()
.and_then(|i| i.parse().ok())
.unwrap_or_else(|| args.len() - 1);
let current_word = args.get(index).map(|s| s.to_string_lossy().into_owned());
let include_long_flags = current_word.as_deref() == Some("-");
let completions = match clap_complete::engine::complete(
&mut cmd,
args.clone(),
index,
current_dir.as_deref(),
) {
Ok(c) => c,
Err(_) => {
CONTEXT.with(|ctx| ctx.borrow_mut().take());
return true;
}
};
let completions = if include_long_flags {
let mut merged = completions;
let mut args_with_double_dash = args;
if let Some(word) = args_with_double_dash.get_mut(index) {
*word = OsString::from("--");
}
let mut cmd2 = completion_command();
cmd2.build();
if let Ok(long_completions) = clap_complete::engine::complete(
&mut cmd2,
args_with_double_dash,
index,
current_dir.as_deref(),
) {
for candidate in long_completions {
let value = candidate.get_value();
if !merged.iter().any(|c| c.get_value() == value) {
merged.push(candidate);
}
}
}
merged
} else {
completions
};
let shell_name = shell_name.to_string_lossy();
let completions = if shell_name.as_ref() == "bash" {
let prefix = current_word.as_deref().unwrap_or("").to_owned();
if prefix.is_empty() {
completions
} else {
completions
.into_iter()
.filter(|c| c.get_value().to_string_lossy().starts_with(&*prefix))
.collect()
}
} else {
completions
};
let ifs = std::env::var("_CLAP_IFS").ok();
let separator = ifs.as_deref().unwrap_or("\n");
let help_sep = match shell_name.as_ref() {
"zsh" => Some(":"),
"fish" | "nu" => Some("\t"),
_ => None,
};
let mut stdout = std::io::stdout();
for (i, candidate) in completions.iter().enumerate() {
if i != 0 {
let _ = write!(stdout, "{}", separator);
}
let value = candidate.get_value().to_string_lossy();
match (help_sep, candidate.get_help()) {
(Some(sep), Some(help)) => {
let _ = write!(stdout, "{}{}{}", value, sep, help);
}
_ => {
let _ = write!(stdout, "{}", value);
}
}
}
CONTEXT.with(|ctx| ctx.borrow_mut().take());
true
}
/// Completer for arguments that accept any branch: no `BranchCompleter` filter flag is set.
pub(crate) fn branch_value_completer() -> ArgValueCompleter {
    let completer = BranchCompleter {
        worktree_only: false,
        exclude_remote_only: false,
        suppress_with_create: false,
    };
    ArgValueCompleter::new(completer)
}
/// Branch completer that yields nothing once `--create`/`-c` appears on the command line.
pub(crate) fn worktree_branch_completer() -> ArgValueCompleter {
    let completer = BranchCompleter {
        worktree_only: false,
        exclude_remote_only: false,
        suppress_with_create: true,
    };
    ArgValueCompleter::new(completer)
}
/// Branch completer that never offers branches categorised as `BranchCategory::Remote`.
pub(crate) fn local_branches_completer() -> ArgValueCompleter {
    let completer = BranchCompleter {
        worktree_only: false,
        exclude_remote_only: true,
        suppress_with_create: false,
    };
    ArgValueCompleter::new(completer)
}
/// Branch completer restricted to branches categorised as `BranchCategory::Worktree`.
pub(crate) fn worktree_only_completer() -> ArgValueCompleter {
    let completer = BranchCompleter {
        worktree_only: true,
        exclude_remote_only: false,
        suppress_with_create: false,
    };
    ArgValueCompleter::new(completer)
}
/// Completer for the names of named hook commands (see `HookCommandCompleter`).
pub(crate) fn hook_command_name_completer() -> ArgValueCompleter {
    let completer = HookCommandCompleter;
    ArgValueCompleter::new(completer)
}
#[derive(Clone, Copy)]
struct HookCommandCompleter;
impl ValueCompleter for HookCommandCompleter {
fn complete(&self, current: &OsStr) -> Vec<CompletionCandidate> {
if current.to_str().is_some_and(|s| s.starts_with('-')) {
return Vec::new();
}
let hook_type = CONTEXT.with(|ctx| {
ctx.borrow().as_ref().and_then(|ctx| {
for hook in &[
"pre-start",
"post-start",
"pre-commit",
"post-commit",
"pre-merge",
"post-merge",
"pre-remove",
] {
if ctx.contains(hook) {
return Some(*hook);
}
}
if ctx.contains("post-create") {
return Some("pre-start");
}
None
})
});
let Some(hook_type_str) = hook_type else {
return Vec::new();
};
let Ok(hook_type) = hook_type_str.parse::<HookType>() else {
return Vec::new();
};
let mut candidates = Vec::new();
let add_named_commands =
|candidates: &mut Vec<_>, config: &worktrunk::config::CommandConfig| {
candidates.extend(
config
.commands()
.filter_map(|cmd| cmd.name.as_ref())
.map(|name| CompletionCandidate::new(name.clone())),
);
};
if let Ok(user_config) = UserConfig::load()
&& let Some(config) = user_config.configs.hooks.get(hook_type)
{
add_named_commands(&mut candidates, config);
}
if let Ok(repo) = Repository::current()
&& let Ok(Some(project_config)) = ProjectConfig::load(&repo, false)
&& let Some(config) = project_config.hooks.get(hook_type)
{
add_named_commands(&mut candidates, config);
}
candidates
}
}
/// Branch-name completer backed by `Repository::branches_for_completion`.
#[derive(Clone, Copy)]
struct BranchCompleter {
    /// Return nothing when `--create`/`-c` is present on the command line.
    suppress_with_create: bool,
    /// Never offer branches categorised as `BranchCategory::Remote`.
    exclude_remote_only: bool,
    /// Offer only branches categorised as `BranchCategory::Worktree`.
    worktree_only: bool,
}

impl ValueCompleter for BranchCompleter {
    fn complete(&self, current: &OsStr) -> Vec<CompletionCandidate> {
        // Above this many branches, remote-only ones are hidden (if any exist) unless the
        // completer is already restricted to worktrees.
        const LARGE_BRANCH_COUNT: usize = 100;

        let looks_like_flag = matches!(current.to_str(), Some(s) if s.starts_with('-'));
        if looks_like_flag || (self.suppress_with_create && suppress_switch_branch_completion())
        {
            return Vec::new();
        }

        let Ok(branches) =
            Repository::current().and_then(|repo| repo.branches_for_completion())
        else {
            return Vec::new();
        };

        let hide_remote = self.exclude_remote_only
            || (!self.worktree_only
                && branches.len() > LARGE_BRANCH_COUNT
                && branches
                    .iter()
                    .any(|b| matches!(b.category, BranchCategory::Remote(_))));

        let mut candidates = Vec::new();
        for branch in branches {
            let keep = match branch.category {
                BranchCategory::Worktree => true,
                BranchCategory::Local => !self.worktree_only,
                BranchCategory::Remote(_) => !self.worktree_only && !hide_remote,
            };
            if !keep {
                continue;
            }

            let age = format_relative_time_short(branch.timestamp);
            // Help text: a category marker, the relative age, and for remote branches the
            // remotes they live on.
            let help = match branch.category {
                BranchCategory::Worktree => format!("+ {age}"),
                BranchCategory::Local => format!("/ {age}"),
                BranchCategory::Remote(remotes) => format!("⇣ {age} {}", remotes.join(", ")),
            };
            candidates.push(CompletionCandidate::new(branch.name).help(Some(help.into())));
        }
        candidates
    }
}
/// Returns `true` when the captured command line contains `--create` or `-c` as a whole
/// argument. `false` when no completion context is set.
fn suppress_switch_branch_completion() -> bool {
    CONTEXT.with(|cell| match cell.borrow().as_ref() {
        Some(context) => ["--create", "-c"]
            .iter()
            .any(|flag| context.contains(flag)),
        None => false,
    })
}
/// Raw command-line arguments captured for the current completion request, letting value
/// completers (which only receive the word being completed) inspect the rest of the line.
struct CompletionContext {
    /// Every argument as captured by the caller.
    args: Vec<OsString>,
}

impl CompletionContext {
    /// Returns `true` when some argument equals `needle` exactly, comparing each
    /// argument's lossy UTF-8 form.
    fn contains(&self, needle: &str) -> bool {
        for arg in &self.args {
            if arg.to_string_lossy() == needle {
                return true;
            }
        }
        false
    }
}
thread_local! {
    /// Command line of the in-flight completion request. `maybe_handle_env_completion` sets
    /// it so value completers, which only see the current word, can check other arguments.
    /// `None` outside a completion request.
    static CONTEXT: RefCell<Option<CompletionContext>> = const { RefCell::new(None) };
}
/// Build the command tree used for completion. It is the real CLI, plus configured aliases
/// as `step` subcommands, with every non-positional option hidden.
fn completion_command() -> Command {
    hide_non_positional_options_for_completion(inject_alias_subcommands(cli::build_command()))
}
/// Add every configured alias as a subcommand of `step` so completion can offer it.
///
/// - Aliases whose name matches an existing `step` subcommand are skipped, so existing
///   subcommands take precedence.
/// - Each injected subcommand's `about` is `alias: ` followed by the truncated first line of
///   its first command template.
/// - Each declares `--dry-run`, `-y/--yes`, and a repeatable `--var <value>`, presumably
///   mirroring the flags accepted when running an alias.
///
/// Returns `cmd` unchanged when no aliases are configured.
fn inject_alias_subcommands(cmd: Command) -> Command {
    let aliases = load_aliases_for_completion();
    if aliases.is_empty() {
        return cmd;
    }
    cmd.mut_subcommand("step", |mut step| {
        for (name, cmd_config) in aliases {
            if step
                .get_subcommands()
                .any(|s| s.get_name() == name.as_str())
            {
                continue;
            }
            let first_template = cmd_config
                .commands()
                .next()
                .map(|c| c.template.as_str())
                .unwrap_or("");
            let help = truncate_template(first_template);
            // NOTE(review): leaked to obtain `&'static str` for `Command::new`/`about`. Each
            // build leaks one pair of strings per injected alias. That looks acceptable for a
            // short-lived completion process; confirm if this is ever called in a loop.
            let name: &'static str = Box::leak(name.into_boxed_str());
            let about: &'static str = Box::leak(format!("alias: {help}").into_boxed_str());
            let sub = Command::new(name)
                .about(about)
                // Boolean flags need `SetTrue`. clap 4 defaults an `Arg` to `ArgAction::Set`,
                // which would make `--dry-run`/`--yes` consume the following word as their
                // value and derail completion of whatever comes after them.
                .arg(
                    clap::Arg::new("dry-run")
                        .long("dry-run")
                        .action(clap::ArgAction::SetTrue),
                )
                .arg(
                    clap::Arg::new("yes")
                        .short('y')
                        .long("yes")
                        .action(clap::ArgAction::SetTrue),
                )
                .arg(
                    clap::Arg::new("var")
                        .long("var")
                        .num_args(1)
                        .action(clap::ArgAction::Append),
                );
            step = step.subcommand(sub);
        }
        step
    })
}
/// Collect alias definitions for completion.
///
/// User-config aliases come first. Inside a repository they are scoped by the project
/// identifier when it can be determined. Then the project config's aliases are merged in
/// via `append_aliases`. Load failures are ignored, so missing configs just mean fewer
/// aliases.
fn load_aliases_for_completion() -> BTreeMap<String, CommandConfig> {
    let repo = Repository::current().ok();
    let mut aliases = BTreeMap::new();

    if let Ok(user_config) = UserConfig::load() {
        let project_id = repo
            .as_ref()
            .and_then(|r| r.project_identifier().ok());
        aliases.extend(user_config.aliases(project_id.as_deref()));
    }

    if let Some(repo) = &repo
        && let Ok(Some(project_config)) = ProjectConfig::load(repo, false)
        && let Some(project_aliases) = &project_config.aliases
    {
        append_aliases(&mut aliases, project_aliases);
    }

    aliases
}
/// Shorten a command template for use as one-line help text.
///
/// Surrounding whitespace is trimmed and only the first line is kept. If that line is
/// longer than 60 bytes, it is cut to at most 57 bytes, backing off to the nearest UTF-8
/// character boundary. No ellipsis is appended.
fn truncate_template(template: &str) -> &str {
    const MAX_LEN: usize = 60;
    const CUT_LEN: usize = 57;

    let trimmed = template.trim();
    let first_line = trimmed.lines().next().unwrap_or(trimmed);
    if first_line.len() <= MAX_LEN {
        return first_line;
    }
    // Byte 0 is always a char boundary, so `find` always succeeds.
    let cut = (0..=CUT_LEN)
        .rev()
        .find(|&i| first_line.is_char_boundary(i))
        .unwrap_or(0);
    &first_line[..cut]
}
/// Prepare a command tree for completion by hiding every non-positional argument.
///
/// clap's built-in help flag (every command) and version flag (root only) are disabled and
/// replaced with explicit `-h/--help` and `-V/--version` args. Presumably this is so they
/// pass through `mut_args` and get hidden like every other option.
fn hide_non_positional_options_for_completion(cmd: Command) -> Command {
    /// Explicit stand-in for clap's automatic help flag.
    fn explicit_help_flag() -> clap::Arg {
        clap::Arg::new("help")
            .short('h')
            .long("help")
            .action(clap::ArgAction::Help)
            .help("Print help (see more with '--help')")
    }

    /// Explicit stand-in for clap's automatic version flag.
    fn explicit_version_flag() -> clap::Arg {
        clap::Arg::new("version")
            .short('V')
            .long("version")
            .action(clap::ArgAction::Version)
            .help("Print version")
    }

    /// Apply the transformation to `cmd` and, recursively, all of its subcommands.
    fn hide_recursively(cmd: Command, is_root: bool) -> Command {
        let mut cmd = cmd.disable_help_flag(true).arg(explicit_help_flag());
        if is_root {
            cmd = cmd.disable_version_flag(true).arg(explicit_version_flag());
        }
        cmd.mut_args(|arg| {
            if arg.is_positional() || arg.is_hide_set() {
                arg
            } else {
                arg.hide(true)
            }
        })
        .mut_subcommands(|sub| hide_recursively(sub, false))
    }

    hide_recursively(cmd, true)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;

    #[test]
    fn test_truncate_template() {
        assert_eq!(truncate_template("echo hello"), "echo hello");
        assert_eq!(truncate_template("line one\nline two"), "line one");
        assert_eq!(truncate_template("  spaced  \n"), "spaced");
        let s60 = "a".repeat(60);
        assert_eq!(truncate_template(&s60), s60.as_str());
        let s61 = "b".repeat(61);
        assert_eq!(truncate_template(&s61), &"b".repeat(57));
        let multi = "a".repeat(56) + "€€";
        let result = truncate_template(&multi);
        assert_eq!(result.len(), 56);
        assert_eq!(result, "a".repeat(56));
    }

    /// Empty and whitespace-only templates yield an empty help string.
    #[test]
    fn test_truncate_template_empty_and_blank() {
        assert_eq!(truncate_template(""), "");
        assert_eq!(truncate_template("   \n\t "), "");
    }

    /// CRLF line endings are stripped. The 60-byte limit is measured in bytes and never
    /// splits a multi-byte character.
    #[test]
    fn test_truncate_template_crlf_and_multibyte_boundary() {
        assert_eq!(truncate_template("first\r\nsecond"), "first");
        let exactly_60 = "€".repeat(20);
        assert_eq!(truncate_template(&exactly_60), exactly_60.as_str());
        let over = "€".repeat(21);
        assert_eq!(truncate_template(&over), "€".repeat(19));
    }

    /// `contains` matches whole arguments only, never prefixes or substrings.
    #[test]
    fn test_completion_context_contains_exact_matches_only() {
        let ctx = CompletionContext {
            args: vec![
                OsString::from("wt"),
                OsString::from("switch"),
                OsString::from("--create"),
            ],
        };
        assert!(ctx.contains("--create"));
        assert!(!ctx.contains("-c"));
        assert!(!ctx.contains("--cre"));
    }
}