use std::cell::RefCell;
use std::collections::BTreeMap;
use std::ffi::{OsStr, OsString};
use std::io::Write;
use clap::Command;
use clap_complete::engine::{ArgValueCompleter, CompletionCandidate, ValueCompleter};
use clap_complete::env::CompleteEnv;
use crate::cli;
use crate::display::format_relative_time_short;
use worktrunk::config::{CommandConfig, ProjectConfig, UserConfig, append_aliases};
use worktrunk::git::{BranchCategory, HookType, Repository};
pub(crate) fn maybe_handle_env_completion() -> bool {
let Some(shell_name) = std::env::var_os("COMPLETE") else {
return false;
};
if shell_name.is_empty() || shell_name == "0" {
return false;
}
worktrunk::config::suppress_warnings();
let mut args: Vec<OsString> = std::env::args_os().collect();
CONTEXT.with(|ctx| *ctx.borrow_mut() = Some(CompletionContext { args: args.clone() }));
args.remove(0);
let escape_index = args
.iter()
.position(|a| *a == "--")
.map(|i| i + 1)
.unwrap_or(args.len());
args.drain(0..escape_index);
let current_dir = std::env::current_dir().ok();
if args.is_empty() {
let all_args: Vec<OsString> = std::env::args_os().collect();
let _ = CompleteEnv::with_factory(completion_command)
.try_complete(all_args, current_dir.as_deref());
CONTEXT.with(|ctx| ctx.borrow_mut().take());
return true;
}
if args.len() >= 3 {
let subcommand = args[1].to_string_lossy();
let binary = format!("wt-{subcommand}");
if cli::build_command().find_subcommand(&*subcommand).is_none()
&& which::which(&binary).is_ok()
{
if let Some(forwarded) = forward_completion_to_custom(&binary, &args[1..], &shell_name)
{
let _ = std::io::stdout().write_all(forwarded.as_bytes());
CONTEXT.with(|ctx| ctx.borrow_mut().take());
return true;
}
}
}
let mut cmd = completion_command();
cmd.build();
let index: usize = std::env::var("_CLAP_COMPLETE_INDEX")
.ok()
.and_then(|i| i.parse().ok())
.unwrap_or_else(|| args.len() - 1);
let current_word = args.get(index).map(|s| s.to_string_lossy().into_owned());
let include_long_flags = current_word.as_deref() == Some("-");
let completions = match clap_complete::engine::complete(
&mut cmd,
args.clone(),
index,
current_dir.as_deref(),
) {
Ok(c) => c,
Err(_) => {
CONTEXT.with(|ctx| ctx.borrow_mut().take());
return true;
}
};
let completions = if include_long_flags {
let mut merged = completions;
let mut args_with_double_dash = args;
if let Some(word) = args_with_double_dash.get_mut(index) {
*word = OsString::from("--");
}
let mut cmd2 = completion_command();
cmd2.build();
if let Ok(long_completions) = clap_complete::engine::complete(
&mut cmd2,
args_with_double_dash,
index,
current_dir.as_deref(),
) {
for candidate in long_completions {
let value = candidate.get_value();
if !merged.iter().any(|c| c.get_value() == value) {
merged.push(candidate);
}
}
}
merged
} else {
completions
};
let shell_name = shell_name.to_string_lossy();
let completions = if shell_name.as_ref() == "bash" {
let prefix = current_word.as_deref().unwrap_or("").to_owned();
if prefix.is_empty() {
completions
} else {
completions
.into_iter()
.filter(|c| c.get_value().to_string_lossy().starts_with(&*prefix))
.collect()
}
} else {
completions
};
let ifs = std::env::var("_CLAP_IFS").ok();
let separator = ifs.as_deref().unwrap_or("\n");
let help_sep = match shell_name.as_ref() {
"zsh" => Some(":"),
"fish" | "nu" => Some("\t"),
_ => None,
};
let mut stdout = std::io::stdout();
for (i, candidate) in completions.iter().enumerate() {
if i != 0 {
let _ = write!(stdout, "{}", separator);
}
let value = candidate.get_value().to_string_lossy();
match (help_sep, candidate.get_help()) {
(Some(sep), Some(help)) => {
let _ = write!(stdout, "{}{}{}", value, sep, help);
}
_ => {
let _ = write!(stdout, "{}", value);
}
}
}
CONTEXT.with(|ctx| ctx.borrow_mut().take());
true
}
/// Value completer offering every branch category (worktree, local, remote).
///
/// `Remote` branches may still be dropped automatically in repositories with many branches;
/// see `BranchCompleter`.
pub(crate) fn branch_value_completer() -> ArgValueCompleter {
    let completer = BranchCompleter {
        worktree_only: false,
        exclude_remote_only: false,
        suppress_with_create: false,
    };
    ArgValueCompleter::new(completer)
}

/// Like [`branch_value_completer`], but offers nothing once `--create`/`-c` appears on the
/// command line.
pub(crate) fn worktree_branch_completer() -> ArgValueCompleter {
    let completer = BranchCompleter {
        worktree_only: false,
        exclude_remote_only: false,
        suppress_with_create: true,
    };
    ArgValueCompleter::new(completer)
}

/// Value completer that never offers branches categorised as `Remote`.
pub(crate) fn local_branches_completer() -> ArgValueCompleter {
    let completer = BranchCompleter {
        worktree_only: false,
        exclude_remote_only: true,
        suppress_with_create: false,
    };
    ArgValueCompleter::new(completer)
}

/// Value completer that offers only branches categorised as `Worktree`.
pub(crate) fn worktree_only_completer() -> ArgValueCompleter {
    let completer = BranchCompleter {
        worktree_only: true,
        exclude_remote_only: false,
        suppress_with_create: false,
    };
    ArgValueCompleter::new(completer)
}
/// Value completer for the command-name filter of `wt hook <type>`.
pub(crate) fn hook_command_name_completer() -> ArgValueCompleter {
    let completer = HookCommandCompleter;
    ArgValueCompleter::new(completer)
}

/// Value completer offering configured alias names.
pub(crate) fn alias_name_completer() -> ArgValueCompleter {
    let completer = AliasNameCompleter;
    ArgValueCompleter::new(completer)
}
/// Completes alias names as returned by `load_aliases_for_completion`.
#[derive(Clone, Copy)]
struct AliasNameCompleter;

impl ValueCompleter for AliasNameCompleter {
    /// Returns no candidates while the user is typing a flag (word starts with `-`).
    fn complete(&self, current: &OsStr) -> Vec<CompletionCandidate> {
        match current.to_str() {
            Some(word) if word.starts_with('-') => Vec::new(),
            _ => load_aliases_for_completion()
                .into_keys()
                .map(CompletionCandidate::new)
                .collect(),
        }
    }
}
#[derive(Clone, Copy)]
struct HookCommandCompleter;
impl ValueCompleter for HookCommandCompleter {
fn complete(&self, current: &OsStr) -> Vec<CompletionCandidate> {
if current.to_str().is_some_and(|s| s.starts_with('-')) {
return Vec::new();
}
let hook_type = CONTEXT.with(|ctx| {
ctx.borrow().as_ref().and_then(|ctx| {
for hook in &[
"pre-start",
"post-start",
"pre-commit",
"post-commit",
"pre-merge",
"post-merge",
"pre-remove",
] {
if ctx.contains(hook) {
return Some(*hook);
}
}
if ctx.contains("post-create") {
return Some("pre-start");
}
None
})
});
let Some(hook_type_str) = hook_type else {
return Vec::new();
};
let Ok(hook_type) = hook_type_str.parse::<HookType>() else {
return Vec::new();
};
let mut candidates = Vec::new();
let add_named_commands =
|candidates: &mut Vec<_>, config: &worktrunk::config::CommandConfig| {
candidates.extend(
config
.commands()
.filter_map(|cmd| cmd.name.as_ref())
.map(|name| CompletionCandidate::new(name.clone())),
);
};
if let Ok(user_config) = UserConfig::load()
&& let Some(config) = user_config.hooks.get(hook_type)
{
add_named_commands(&mut candidates, config);
}
if let Ok(repo) = Repository::current()
&& let Ok(Some(project_config)) = ProjectConfig::load(&repo, false)
&& let Some(config) = project_config.hooks.get(hook_type)
{
add_named_commands(&mut candidates, config);
}
candidates
}
}
#[derive(Clone, Copy)]
struct BranchCompleter {
suppress_with_create: bool,
exclude_remote_only: bool,
worktree_only: bool,
}
impl ValueCompleter for BranchCompleter {
fn complete(&self, current: &OsStr) -> Vec<CompletionCandidate> {
if current.to_str().is_some_and(|s| s.starts_with('-')) {
return Vec::new();
}
if self.suppress_with_create && suppress_switch_branch_completion() {
return Vec::new();
}
let branches = match Repository::current().and_then(|repo| repo.branches_for_completion()) {
Ok(b) => b,
Err(_) => return Vec::new(),
};
if branches.is_empty() {
return Vec::new();
}
let exclude_remote_only = self.exclude_remote_only
|| (!self.worktree_only
&& branches.len() > 100
&& branches
.iter()
.any(|b| matches!(b.category, BranchCategory::Remote(_))));
branches
.into_iter()
.filter(|branch| {
if self.worktree_only {
matches!(branch.category, BranchCategory::Worktree)
} else if exclude_remote_only {
!matches!(branch.category, BranchCategory::Remote(_))
} else {
true
}
})
.map(|branch| {
let time_str = format_relative_time_short(branch.timestamp);
let help = match branch.category {
BranchCategory::Worktree => format!("+ {}", time_str),
BranchCategory::Local => format!("/ {}", time_str),
BranchCategory::Remote(remotes) => {
format!("⇣ {} {}", time_str, remotes.join(", "))
}
};
CompletionCandidate::new(branch.name).help(Some(help.into()))
})
.collect()
}
}
/// Returns `true` when the command line of the current completion request contains
/// `--create` or `-c` as a whole word; `false` if no request context is set.
fn suppress_switch_branch_completion() -> bool {
    CONTEXT.with(|cell| match cell.borrow().as_ref() {
        Some(ctx) => ["--create", "-c"].iter().any(|flag| ctx.contains(flag)),
        None => false,
    })
}
/// Raw process arguments captured at the start of a completion request.
struct CompletionContext {
    args: Vec<OsString>,
}

impl CompletionContext {
    /// Returns `true` if some argument, after lossy UTF-8 conversion, equals `needle` exactly.
    fn contains(&self, needle: &str) -> bool {
        self.args
            .iter()
            .map(|arg| arg.to_string_lossy())
            .any(|arg| arg == needle)
    }
}
thread_local! {
    /// Command line of the in-flight completion request, stored by
    /// `maybe_handle_env_completion` and read by value completers (which cannot receive
    /// it through clap's `ValueCompleter::complete` signature). `None` outside a request.
    static CONTEXT: RefCell<Option<CompletionContext>> = const { RefCell::new(None) };
}
fn completion_command() -> Command {
let cmd = cli::build_command();
let cmd = inject_alias_subcommands(cmd);
let cmd = inject_hook_subcommands(cmd);
let cmd = inject_custom_subcommands(cmd);
hide_non_positional_options_for_completion(cmd)
}
/// Adds a completion stand-in under `hook` for every name in `cli::HOOK_TYPE_NAMES` that is
/// not already a subcommand of `hook`.
///
/// Relies on `cmd` having a `hook` subcommand (`Command::mut_subcommand` requires it).
pub(crate) fn inject_hook_subcommands(cmd: Command) -> Command {
    cmd.mut_subcommand("hook", |hook| {
        cli::HOOK_TYPE_NAMES.iter().copied().fold(hook, |hook, name| {
            let already_defined = hook.get_subcommands().any(|s| s.get_name() == name);
            if already_defined {
                hook
            } else {
                hook.subcommand(build_hook_completion_command(name))
            }
        })
    })
}
/// Builds the completion stand-in for `wt hook <name>`.
///
/// Declares the flags offered for completion (`--dry-run`, `--foreground`, `-y/--yes`,
/// `--var KEY=VALUE`) plus any number of positional command names, which are completed by
/// [`hook_command_name_completer`].
fn build_hook_completion_command(name: &'static str) -> Command {
    Command::new(name)
        // `about` accepts an owned `String` (converted into clap's `StyledStr`), so the text
        // does not need to be leaked to obtain a `'static` str.
        .about(format!("Run {name} hooks"))
        .arg(
            clap::Arg::new("dry-run")
                .long("dry-run")
                .action(clap::ArgAction::SetTrue)
                .help("Show what would run without executing"),
        )
        .arg(
            clap::Arg::new("foreground")
                .long("foreground")
                .action(clap::ArgAction::SetTrue)
                .help("Run in foreground (block until complete)"),
        )
        .arg(
            clap::Arg::new("yes")
                .short('y')
                .long("yes")
                .action(clap::ArgAction::SetTrue)
                .help("Skip approval prompts for project hooks"),
        )
        .arg(
            clap::Arg::new("var")
                .long("var")
                .value_name("KEY=VALUE")
                .num_args(1)
                .action(clap::ArgAction::Append)
                .help("Set template variable (deprecated — prefer --KEY=VALUE)"),
        )
        .arg(
            clap::Arg::new("name")
                .num_args(0..)
                .add(hook_command_name_completer())
                .help("Filter by command name(s)"),
        )
}
/// Adds a completion stand-in for each configured alias, both at the top level and under
/// `step`, skipping any name that already exists at that level.
///
/// Returns `cmd` untouched when no aliases are configured. Otherwise it relies on `cmd`
/// having a `step` subcommand (`Command::mut_subcommand` requires it).
fn inject_alias_subcommands(cmd: Command) -> Command {
    let aliases = load_aliases_for_completion();
    if aliases.is_empty() {
        return cmd;
    }

    let has_subcommand =
        |parent: &Command, name: &str| parent.get_subcommands().any(|s| s.get_name() == name);

    let root = aliases.iter().fold(cmd, |acc, (name, cmd_config)| {
        if has_subcommand(&acc, name) {
            acc
        } else {
            acc.subcommand(build_alias_completion_command(name, cmd_config))
        }
    });

    root.mut_subcommand("step", |step| {
        aliases.iter().fold(step, |acc, (name, cmd_config)| {
            if has_subcommand(&acc, name) {
                acc
            } else {
                acc.subcommand(build_alias_completion_command(name, cmd_config))
            }
        })
    })
}
/// Builds the completion stand-in for alias `name`.
///
/// The about text is `alias: ` followed by the first line of the alias's first command
/// template, truncated by [`truncate_template`]. The flags match the hook stand-in:
/// `--dry-run` and `-y/--yes` are boolean switches, and `--var` takes one `KEY=VALUE` per
/// occurrence.
fn build_alias_completion_command(name: &str, cmd_config: &CommandConfig) -> Command {
    let first_template = cmd_config
        .commands()
        .next()
        .map(|c| c.template.as_str())
        .unwrap_or("");
    // `about` accepts an owned `String`; no leak needed.
    let about = format!("alias: {}", truncate_template(first_template));
    // `Command::new` needs a `'static` name unless clap's `string` feature is enabled, so
    // the (small, per-process) name is leaked.
    let name: &'static str = Box::leak(name.to_owned().into_boxed_str());
    Command::new(name)
        .about(about)
        // Without an explicit action clap treats an option as value-taking (`ArgAction::Set`),
        // so the engine would consume the next word as the flag's value.
        .arg(
            clap::Arg::new("dry-run")
                .long("dry-run")
                .action(clap::ArgAction::SetTrue),
        )
        .arg(
            clap::Arg::new("yes")
                .short('y')
                .long("yes")
                .action(clap::ArgAction::SetTrue),
        )
        .arg(
            clap::Arg::new("var")
                .long("var")
                .value_name("KEY=VALUE")
                .num_args(1)
                .action(clap::ArgAction::Append),
        )
}
fn load_aliases_for_completion() -> BTreeMap<String, CommandConfig> {
let mut aliases = BTreeMap::new();
if let Ok(repo) = Repository::current() {
if let Ok(user_config) = UserConfig::load() {
let project_id = repo.project_identifier().ok();
aliases.extend(user_config.aliases(project_id.as_deref()));
}
if let Ok(Some(project_config)) = ProjectConfig::load(&repo, false) {
append_aliases(&mut aliases, &project_config.aliases);
}
} else if let Ok(user_config) = UserConfig::load() {
aliases.extend(user_config.aliases(None));
}
aliases
}
/// Returns the first line of `template` (after trimming surrounding whitespace), shortened
/// for use as a one-line description.
///
/// Lines longer than 60 bytes are cut to at most 57 bytes, backing up to the nearest UTF-8
/// character boundary so the result is always a valid `&str` slice of the input.
fn truncate_template(template: &str) -> &str {
    const MAX_LEN: usize = 60;
    const KEEP: usize = 57;

    let trimmed = template.trim();
    let line = trimmed.lines().next().unwrap_or(trimmed);
    if line.len() <= MAX_LEN {
        return line;
    }
    // Index 0 is always a char boundary, so this always finds a cut point.
    let cut = (0..=KEEP)
        .rev()
        .find(|&i| line.is_char_boundary(i))
        .unwrap_or(0);
    &line[..cut]
}
/// Forwards a completion request to an external `wt-<name>` executable.
///
/// `args` is the completed command line starting at the subcommand word (e.g.
/// `["sync", "--fo"]`). The child is run as `<binary> -- <binary> <args[1..]>` with
/// `COMPLETE=<shell>`, so it sees itself as the program being completed. `_CLAP_IFS` is
/// passed through (default `\n`). `_CLAP_COMPLETE_INDEX`, if set, is decremented by one
/// (saturating) to account for the dropped `wt` word.
///
/// The child's stdin is closed and its stderr discarded. Returns its stdout if it exits
/// successfully with UTF-8 output; returns `None` if it cannot be spawned, exits unsuccessfully,
/// or prints non-UTF-8.
fn forward_completion_to_custom(binary: &str, args: &[OsString], shell: &OsStr) -> Option<String> {
    use std::process::Stdio;

    let index = std::env::var("_CLAP_COMPLETE_INDEX")
        .ok()
        .and_then(|i| i.parse::<usize>().ok())
        .map(|i| i.saturating_sub(1));
    let ifs = std::env::var("_CLAP_IFS").unwrap_or_else(|_| "\n".to_string());

    let mut cmd = std::process::Command::new(binary);
    cmd.arg("--")
        .arg(binary)
        // `get(1..)` instead of `[1..]`: an empty slice must not panic.
        .args(args.get(1..).unwrap_or_default())
        .env("COMPLETE", shell)
        .env("_CLAP_IFS", ifs)
        // Completion runs inside the user's interactive shell; the child must not read from
        // (or block on) the terminal's stdin.
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .stderr(Stdio::null());
    if let Some(idx) = index {
        cmd.env("_CLAP_COMPLETE_INDEX", idx.to_string());
    }

    let output = cmd.output().ok()?;
    if output.status.success() {
        String::from_utf8(output.stdout).ok()
    } else {
        None
    }
}
/// Adds a passthrough subcommand for every `wt-<name>` executable discovered on `PATH`.
fn inject_custom_subcommands(cmd: Command) -> Command {
    let discovered = discover_custom_subcommands();
    inject_custom_subcommand_list(cmd, discovered)
}
/// Adds one subcommand per name in `customs`, unless `cmd` already resolves that name via
/// `find_subcommand` (built-ins win).
///
/// Each added subcommand is described as `custom: wt-<name>` and accepts arbitrary trailing
/// arguments (`allow_external_subcommands`).
fn inject_custom_subcommand_list(mut cmd: Command, customs: Vec<String>) -> Command {
    for name in customs {
        if cmd.find_subcommand(&name).is_some() {
            continue;
        }
        // `about` accepts an owned `String`; only the name needs `'static` for `Command::new`
        // (unless clap's `string` feature is enabled).
        let about = format!("custom: wt-{name}");
        let name: &'static str = Box::leak(name.into_boxed_str());
        cmd = cmd.subcommand(
            Command::new(name)
                .about(about)
                .allow_external_subcommands(true),
        );
    }
    cmd
}
/// Lists `wt-<name>` executables reachable through the `PATH` environment variable; empty
/// when `PATH` is unset.
fn discover_custom_subcommands() -> Vec<String> {
    std::env::var_os("PATH")
        .map(|path_var| discover_custom_subcommands_in(&path_var))
        .unwrap_or_default()
}
/// Lists the `<name>` part of every runnable `wt-<name>` file in the directories of
/// `path_var` (a `PATH`-style list), deduplicated and sorted.
///
/// - Unreadable directories and non-UTF-8 file names are skipped.
/// - On Windows the extension is stripped (`wt-sync.exe` → `sync`).
/// - A name is claimed only by an entry that can actually be run. A non-executable
///   `wt-foo`, or a directory named `wt-foo`, earlier in the list therefore no longer hides
///   an executable `wt-foo` later in the list.
fn discover_custom_subcommands_in(path_var: &OsStr) -> Vec<String> {
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
    let mut result = Vec::new();
    for dir in std::env::split_paths(path_var) {
        let Ok(entries) = std::fs::read_dir(&dir) else {
            continue;
        };
        for entry in entries.flatten() {
            let file_name = entry.file_name();
            let Some(name) = file_name.to_str() else {
                continue;
            };
            let Some(subcommand) = name.strip_prefix("wt-") else {
                continue;
            };
            #[cfg(windows)]
            let subcommand = std::path::Path::new(subcommand)
                .file_stem()
                .and_then(|s| s.to_str())
                .unwrap_or(subcommand);
            if subcommand.is_empty() || seen.contains(subcommand) {
                continue;
            }
            // Check runnability *before* recording the name, so a non-runnable entry does
            // not shadow a runnable one further along the PATH.
            if !is_runnable_file(&entry.path()) {
                continue;
            }
            seen.insert(subcommand.to_owned());
            result.push(subcommand.to_owned());
        }
    }
    result.sort();
    result
}

/// Returns `true` if `path` resolves (following symlinks) to a regular file that is
/// executable. On Unix this means any execute bit is set; elsewhere any regular file counts.
///
/// Uses `std::fs::metadata`, which follows symlinks, rather than `DirEntry::metadata`, which
/// does not on Unix. Otherwise a symlink's own permissions (typically `0o777` on Linux) would
/// be checked instead of its target's.
fn is_runnable_file(path: &std::path::Path) -> bool {
    let Ok(meta) = std::fs::metadata(path) else {
        return false;
    };
    if !meta.is_file() {
        return false;
    }
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        meta.permissions().mode() & 0o111 != 0
    }
    #[cfg(not(unix))]
    {
        true
    }
}
/// Rewrites the command tree so that options are hidden from completion listings.
///
/// Every command (recursively) gets an explicit `-h/--help` flag in place of clap's built-in
/// one; the root also gets an explicit `-V/--version`. Then every argument that is neither
/// positional nor already hidden is marked hidden, including those new flags.
/// NOTE(review): presumably this relies on clap_complete preferring visible candidates, so
/// flags surface only when nothing visible matches (e.g. after typing `--`); confirm.
fn hide_non_positional_options_for_completion(cmd: Command) -> Command {
    fn explicit_help() -> clap::Arg {
        clap::Arg::new("help")
            .short('h')
            .long("help")
            .action(clap::ArgAction::Help)
            .help("Print help (see more with '--help')")
    }

    fn explicit_version() -> clap::Arg {
        clap::Arg::new("version")
            .short('V')
            .long("version")
            .action(clap::ArgAction::Version)
            .help("Print version")
    }

    fn hide_options(cmd: Command, is_root: bool) -> Command {
        let mut cmd = cmd.disable_help_flag(true).arg(explicit_help());
        if is_root {
            cmd = cmd.disable_version_flag(true).arg(explicit_version());
        }
        cmd.mut_args(|arg| {
            if !arg.is_positional() && !arg.is_hide_set() {
                arg.hide(true)
            } else {
                arg
            }
        })
        .mut_subcommands(|sub| hide_options(sub, false))
    }

    hide_options(cmd, true)
}
/// Unit tests for the pure helpers of the completion module: template truncation,
/// `wt-*` discovery on a PATH-style list, custom-subcommand injection, and forwarding to
/// external completers. Filesystem tests use throwaway `tempfile` directories; the
/// executable-bit tests are Unix-only.
#[cfg(test)]
mod tests {
    use super::*;

    // First line only, trimmed; >60 bytes cut to <=57 bytes on a char boundary.
    #[test]
    fn test_truncate_template() {
        assert_eq!(truncate_template("echo hello"), "echo hello");
        assert_eq!(truncate_template("line one\nline two"), "line one");
        assert_eq!(truncate_template(" spaced \n"), "spaced");
        let s60 = "a".repeat(60);
        assert_eq!(truncate_template(&s60), s60.as_str());
        let s61 = "b".repeat(61);
        assert_eq!(truncate_template(&s61), &"b".repeat(57));
        // "€" is 3 bytes and straddles byte 57, so the cut backs up to byte 56.
        let multi = "a".repeat(56) + "€€";
        let result = truncate_template(&multi);
        assert_eq!(result.len(), 56);
        assert_eq!(result, "a".repeat(56));
    }

    // An empty PATH-style list has no directories to scan.
    #[test]
    fn test_discover_empty_path() {
        let result = discover_custom_subcommands_in(OsStr::new(""));
        assert!(result.is_empty());
    }

    // Unreadable/missing directories are skipped silently.
    #[test]
    fn test_discover_nonexistent_dir() {
        let result =
            discover_custom_subcommands_in(OsStr::new("/nonexistent/path/xxxxxxxx_wt_test"));
        assert!(result.is_empty());
    }

    // Executable `wt-*` files yield their suffix as the subcommand name.
    #[cfg(unix)]
    #[test]
    fn test_discover_finds_wt_executables() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tempfile::tempdir().unwrap();
        for name in ["wt-alpha", "wt-beta"] {
            let path = dir.path().join(name);
            std::fs::write(&path, "#!/bin/sh\n").unwrap();
            std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o755)).unwrap();
        }
        let result = discover_custom_subcommands_in(dir.path().as_os_str());
        assert_eq!(result, vec!["alpha", "beta"]);
    }

    // Files without any execute bit are not offered.
    #[cfg(unix)]
    #[test]
    fn test_discover_skips_non_executable() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tempfile::tempdir().unwrap();
        let exec = dir.path().join("wt-exec");
        std::fs::write(&exec, "#!/bin/sh\n").unwrap();
        std::fs::set_permissions(&exec, std::fs::Permissions::from_mode(0o755)).unwrap();
        let noexec = dir.path().join("wt-noexec");
        std::fs::write(&noexec, "data").unwrap();
        std::fs::set_permissions(&noexec, std::fs::Permissions::from_mode(0o644)).unwrap();
        let result = discover_custom_subcommands_in(dir.path().as_os_str());
        assert_eq!(result, vec!["exec"]);
    }

    // The same name in two PATH directories is reported once.
    #[cfg(unix)]
    #[test]
    fn test_discover_deduplicates_across_dirs() {
        use std::os::unix::fs::PermissionsExt;
        let dir1 = tempfile::tempdir().unwrap();
        let dir2 = tempfile::tempdir().unwrap();
        for dir in [dir1.path(), dir2.path()] {
            let path = dir.join("wt-dup");
            std::fs::write(&path, "#!/bin/sh\n").unwrap();
            std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o755)).unwrap();
        }
        let path = std::env::join_paths([dir1.path(), dir2.path()]).unwrap();
        let result = discover_custom_subcommands_in(&path);
        assert_eq!(result, vec!["dup"]);
    }

    // A bare `wt-` (empty name) and files without the prefix are ignored.
    #[cfg(unix)]
    #[test]
    fn test_discover_skips_bare_prefix_and_non_matching() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tempfile::tempdir().unwrap();
        let empty = dir.path().join("wt-");
        std::fs::write(&empty, "#!/bin/sh\n").unwrap();
        std::fs::set_permissions(&empty, std::fs::Permissions::from_mode(0o755)).unwrap();
        let other = dir.path().join("other-tool");
        std::fs::write(&other, "#!/bin/sh\n").unwrap();
        std::fs::set_permissions(&other, std::fs::Permissions::from_mode(0o755)).unwrap();
        let result = discover_custom_subcommands_in(dir.path().as_os_str());
        assert!(result.is_empty());
    }

    // Output order is sorted, independent of directory iteration order.
    #[cfg(unix)]
    #[test]
    fn test_discover_results_are_sorted() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tempfile::tempdir().unwrap();
        for name in ["wt-zebra", "wt-apple", "wt-mango"] {
            let path = dir.path().join(name);
            std::fs::write(&path, "#!/bin/sh\n").unwrap();
            std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o755)).unwrap();
        }
        let result = discover_custom_subcommands_in(dir.path().as_os_str());
        assert_eq!(result, vec!["apple", "mango", "zebra"]);
    }

    // New names are added alongside the existing subcommands.
    #[test]
    fn test_inject_custom_adds_subcommands() {
        let cmd = Command::new("wt")
            .subcommand(Command::new("switch"))
            .subcommand(Command::new("list"));
        let cmd = inject_custom_subcommand_list(cmd, vec!["sync".into(), "deploy".into()]);
        assert!(cmd.find_subcommand("sync").is_some());
        assert!(cmd.find_subcommand("deploy").is_some());
        assert!(cmd.find_subcommand("switch").is_some());
        assert!(cmd.find_subcommand("list").is_some());
    }

    // A custom name that collides with a built-in leaves the built-in untouched.
    #[test]
    fn test_inject_custom_skips_builtins() {
        let cmd = Command::new("wt").subcommand(Command::new("switch").about("built-in switch"));
        let cmd = inject_custom_subcommand_list(cmd, vec!["switch".into(), "sync".into()]);
        let switch = cmd.find_subcommand("switch").unwrap();
        assert_eq!(switch.get_about().unwrap().to_string(), "built-in switch");
        let sync = cmd.find_subcommand("sync").unwrap();
        assert!(sync.get_about().unwrap().to_string().contains("custom"));
    }

    // No custom names means no change to the command tree.
    #[test]
    fn test_inject_custom_empty_list() {
        let cmd = Command::new("wt").subcommand(Command::new("switch"));
        let cmd = inject_custom_subcommand_list(cmd, vec![]);
        assert_eq!(cmd.get_subcommands().count(), 1);
    }

    // Injected subcommands accept arbitrary trailing arguments.
    #[test]
    fn test_inject_custom_allows_trailing_args() {
        let cmd = Command::new("wt");
        let cmd = inject_custom_subcommand_list(cmd, vec!["sync".into()]);
        let sync = cmd.find_subcommand("sync").unwrap();
        assert!(sync.is_allow_external_subcommands_set());
    }

    // A binary that cannot be spawned yields `None`, not an error.
    #[test]
    fn test_forward_to_nonexistent_binary() {
        let result = forward_completion_to_custom(
            "/nonexistent/binary/xxxxxxxx_wt_test",
            &[OsString::from("test")],
            OsStr::new("bash"),
        );
        assert!(result.is_none());
    }

    // Successful child output is returned verbatim.
    #[cfg(unix)]
    #[test]
    fn test_forward_to_custom_binary() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tempfile::tempdir().unwrap();
        let script = dir.path().join("wt-fake");
        std::fs::write(&script, "#!/bin/sh\nprintf '%s\\n%s' '--all' '--verbose'\n").unwrap();
        std::fs::set_permissions(&script, std::fs::Permissions::from_mode(0o755)).unwrap();
        let result = forward_completion_to_custom(
            script.to_str().unwrap(),
            &[OsString::from("fake"), OsString::from("--")],
            OsStr::new("bash"),
        );
        assert!(result.is_some());
        let output = result.unwrap();
        assert!(output.contains("--all"));
        assert!(output.contains("--verbose"));
    }

    // A non-zero exit status from the child yields `None`.
    #[cfg(unix)]
    #[test]
    fn test_forward_to_failing_binary() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tempfile::tempdir().unwrap();
        let script = dir.path().join("wt-fail");
        std::fs::write(&script, "#!/bin/sh\nexit 1\n").unwrap();
        std::fs::set_permissions(&script, std::fs::Permissions::from_mode(0o755)).unwrap();
        let result = forward_completion_to_custom(
            script.to_str().unwrap(),
            &[OsString::from("fail")],
            OsStr::new("bash"),
        );
        assert!(result.is_none());
    }
}