use std::collections::HashMap;
use std::path::{Path, PathBuf};
use crate::display::format_relative_time_short;
use anyhow::{Context, bail};
use color_print::cformat;
use dunce::canonicalize;
use serde::Serialize;
use worktrunk::HookType;
use worktrunk::config::{
UserConfig, ValidationScope, expand_template, template_references_var, validate_template,
};
use worktrunk::git::remote_ref::{
self, AzureDevOpsProvider, GitHubProvider, GitLabProvider, GiteaProvider, RemoteRefInfo,
RemoteRefProvider,
};
use worktrunk::git::{
GitError, GitRemoteUrl, RefContext, RefType, Repository, SwitchSuggestionCtx,
current_or_recover,
};
use worktrunk::shell_exec::{ShellEscapeMode, directive_shell_escape_mode, shell_escape_for};
use worktrunk::styling::{
eprintln, format_with_gutter, hint_message, info_message, progress_message, suggest_command,
warning_message,
};
use super::resolve::{compute_worktree_path, offer_bare_repo_worktree_path_fix, path_mismatch};
use super::types::{CreationMethod, SwitchBranchInfo, SwitchPlan, SwitchResult};
use crate::cli::SwitchFormat;
use crate::commands::backup::back_up_clobbered_path_now;
use crate::commands::command_approval::approve_hooks;
use crate::commands::command_executor::FailureStrategy;
use crate::commands::command_executor::{CommandContext, build_hook_context};
use crate::commands::hook_plan::{ApprovedHookPlan, HookPlanBuilder, register_planned};
use crate::commands::hooks::{HookAnnouncer, execute_hook};
use crate::commands::template_vars::TemplateVars;
use crate::output::{
execute_user_command, handle_switch_output, is_shell_integration_active,
prompt_shell_integration,
};
/// Result of interpreting the user's switch argument (branch name, path-like
/// name, `pr:<n>` or `mr:<n>`).
struct ResolvedTarget {
    /// Local branch name the worktree will be associated with.
    branch: String,
    /// How the worktree (and possibly its branch) should be created.
    method: CreationMethod,
}
/// Renders a two-line, styled summary of a PR/MR.
///
/// Line one is the bold title followed by the ref symbol and number in
/// parentheses. Line two joins the author, the state, an optional `draft`
/// marker and the source ref with ` · `, then appends the URL in dim text.
fn format_ref_context(ctx: &impl RefContext) -> String {
    let author = format!("by @{}", ctx.author());
    let state = ctx.state().to_string();
    let draft = ctx.draft().then(|| "draft".to_string());
    let source = ctx.source_ref();
    let status_line = std::iter::once(author)
        .chain(std::iter::once(state))
        .chain(draft)
        .chain(std::iter::once(source))
        .collect::<Vec<String>>()
        .join(" · ");
    cformat!(
        "<bold>{}</> ({}{})\n{status_line} · <bright-black>{}</>",
        ctx.title(),
        ctx.ref_type().symbol(),
        ctx.number(),
        ctx.url()
    )
}
/// Forge backend used to resolve `pr:<number>` arguments.
///
/// GitLab is intentionally not a choice here: GitLab merge requests are
/// addressed with the `mr:` prefix instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PrProviderChoice {
    /// Resolve pull requests through the GitHub provider.
    GitHub,
    /// Resolve pull requests through the Gitea provider.
    Gitea,
    /// Resolve pull requests through the Azure DevOps provider.
    AzureDevOps,
}
/// Decides which forge should answer a `pr:<number>` lookup.
///
/// Precedence:
/// 1. `forge.platform` from the project config (case-insensitive). `gitlab`
///    and unknown values are errors.
/// 2. The first platform recognised among all remote URLs, checked in the
///    order GitHub, Gitea, Azure DevOps. A GitLab remote is an error that
///    points the user at `mr:`.
/// 3. The primary remote's host: Gitea when only the Gitea CLI is authed
///    for it, otherwise GitHub. GitHub is also the answer when no host can
///    be determined.
fn choose_pr_provider(repo: &Repository) -> anyhow::Result<PrProviderChoice> {
    let configured = repo
        .load_project_config()?
        .and_then(|c| c.forge_platform().map(str::to_string));
    if let Some(platform_raw) = configured {
        return match platform_raw.to_ascii_lowercase().as_str() {
            "github" => Ok(PrProviderChoice::GitHub),
            "gitea" => Ok(PrProviderChoice::Gitea),
            "azure-devops" | "azuredevops" => Ok(PrProviderChoice::AzureDevOps),
            "gitlab" => {
                bail!("forge.platform is set to gitlab; use mr:<number> instead of pr:<number>")
            }
            _ => bail!(
                "Invalid forge.platform value `{platform_raw}` in .config/wt.toml; \
                 expected one of: github, gitlab, gitea, azure-devops"
            ),
        };
    }

    let parsed_remotes: Vec<_> = repo
        .all_remote_urls()
        .into_iter()
        .filter_map(|(_, url)| GitRemoteUrl::parse(&url))
        .collect();
    let from_remotes = if parsed_remotes.iter().any(|u| u.is_github()) {
        Some(PrProviderChoice::GitHub)
    } else if parsed_remotes.iter().any(|u| u.is_gitea()) {
        Some(PrProviderChoice::Gitea)
    } else if parsed_remotes.iter().any(|u| u.is_azure_devops()) {
        Some(PrProviderChoice::AzureDevOps)
    } else if parsed_remotes.iter().any(|u| u.is_gitlab()) {
        bail!("Detected GitLab remote; use mr:<number> instead of pr:<number>")
    } else {
        None
    };
    if let Some(choice) = from_remotes {
        return Ok(choice);
    }

    // Unrecognised host (e.g. self-hosted): infer the forge from CLI auth.
    let primary_host = repo
        .primary_remote()
        .ok()
        .and_then(|remote| repo.remote_url(&remote))
        .and_then(|url| GitRemoteUrl::parse(&url))
        .map(|u| u.host().to_string());
    let prefer_gitea = primary_host.is_some_and(|host| {
        remote_ref::gitea::is_authed_for(&host) && !remote_ref::github::is_authed_for(&host)
    });
    Ok(if prefer_gitea {
        PrProviderChoice::Gitea
    } else {
        PrProviderChoice::GitHub
    })
}
fn resolve_pr_target(
repo: &Repository,
number: u32,
create: bool,
base: Option<&str>,
) -> anyhow::Result<ResolvedTarget> {
if base.is_some() {
return Err(GitError::RefBaseConflict {
ref_type: RefType::Pr,
number,
}
.into());
}
match choose_pr_provider(repo)? {
PrProviderChoice::GitHub => resolve_remote_ref(repo, &GitHubProvider, number, create, base),
PrProviderChoice::Gitea => resolve_remote_ref(repo, &GiteaProvider, number, create, base),
PrProviderChoice::AzureDevOps => {
resolve_remote_ref(repo, &AzureDevOpsProvider, number, create, base)
}
}
}
fn resolve_pr_base(
repo: &Repository,
number: u32,
) -> anyhow::Result<(String, Option<(String, String)>)> {
match choose_pr_provider(repo)? {
PrProviderChoice::GitHub => resolve_remote_ref_as_base(repo, &GitHubProvider, number),
PrProviderChoice::Gitea => resolve_remote_ref_as_base(repo, &GiteaProvider, number),
PrProviderChoice::AzureDevOps => {
resolve_remote_ref_as_base(repo, &AzureDevOpsProvider, number)
}
}
}
/// Resolves a PR/MR number into the branch to switch to.
///
/// Rejects `--base` (conflicts with a remote ref), then fetches and prints the
/// ref's metadata. `--create` is rejected only after the fetch, so the error
/// can name the ref's source branch. Cross-repo (fork) refs go through
/// [`resolve_fork_ref`]; same-repo refs through [`resolve_same_repo_ref`].
fn resolve_remote_ref(
    repo: &Repository,
    provider: &dyn RemoteRefProvider,
    number: u32,
    create: bool,
    base: Option<&str>,
) -> anyhow::Result<ResolvedTarget> {
    let ref_type = provider.ref_type();
    let symbol = ref_type.symbol();
    if base.is_some() {
        return Err(GitError::RefBaseConflict { ref_type, number }.into());
    }

    let fetching = cformat!("Fetching {} {symbol}{number}...", ref_type.name());
    eprintln!("{}", progress_message(fetching));
    let info = provider.fetch_info(number, repo)?;
    let summary = format_ref_context(&info);
    eprintln!("{}", format_with_gutter(&summary, None));

    if create {
        let conflict = GitError::RefCreateConflict {
            ref_type,
            number,
            branch: info.source_branch.clone(),
        };
        return Err(conflict.into());
    }

    if info.is_cross_repo {
        resolve_fork_ref(repo, provider, number, &info)
    } else {
        resolve_same_repo_ref(repo, &info)
    }
}
/// Chooses the local branch and creation method for a cross-repo (fork) PR/MR.
///
/// Resolution order:
/// 1. If `local_branch` (from `remote_ref::local_branch_name`) already tracks
///    this ref, reuse it without creating anything.
/// 2. If it tracks something else, try the prefixed name from
///    `info.prefixed_local_branch_name()`. Reuse it if it tracks this ref,
///    error if it tracks a different one, or plan a `ForkRef` creation under
///    that name (without a push URL) if it does not exist yet.
/// 3. Otherwise plan a `ForkRef` creation under `local_branch`. For PRs, the
///    remote comes from `find_remote` and the push URL from `info`. For MRs,
///    both are derived from the project URLs reported by glab.
///
/// Errors if the conflicting name has no prefixed alternative, if glab
/// omits the target/source project URL, or if no local remote matches the
/// target project.
fn resolve_fork_ref(
    repo: &Repository,
    provider: &dyn RemoteRefProvider,
    number: u32,
    info: &RemoteRefInfo,
) -> anyhow::Result<ResolvedTarget> {
    let ref_type = provider.ref_type();
    let repo_root = repo.repo_path()?;
    let local_branch = remote_ref::local_branch_name(info);
    // Best effort: used only to sharpen the tracking check below, so a lookup
    // failure is logged and treated as "unknown remote".
    let expected_remote = match remote_ref::find_remote(repo, info) {
        Ok(remote) => Some(remote),
        Err(e) => {
            log::debug!("Could not resolve remote for {}: {e:#}", ref_type.name());
            None
        }
    };
    // `Some(_)` means the tracking check produced an answer for `local_branch`.
    // NOTE(review): presumably `None` means the branch does not exist locally;
    // confirm against `branch_tracks_ref`.
    if let Some(tracks_this) = remote_ref::branch_tracks_ref(
        repo_root,
        &local_branch,
        provider,
        number,
        expected_remote.as_deref(),
    ) {
        if tracks_this {
            // Already set up for this ref: switch to it as a plain branch.
            eprintln!(
                "{}",
                info_message(cformat!(
                    "Branch <bold>{local_branch}</> already configured for {}",
                    ref_type.display(number)
                ))
            );
            return Ok(ResolvedTarget {
                branch: local_branch,
                method: CreationMethod::Regular {
                    create_branch: false,
                    base_branch: None,
                    base_pr_upstream: None,
                },
            });
        }
        // Name is taken by a branch tracking something else; fall back to the
        // prefixed name when the ref provides one.
        if let Some(prefixed) = info.prefixed_local_branch_name() {
            if let Some(prefixed_tracks) = remote_ref::branch_tracks_ref(
                repo_root,
                &prefixed,
                provider,
                number,
                expected_remote.as_deref(),
            ) {
                if prefixed_tracks {
                    eprintln!(
                        "{}",
                        info_message(cformat!(
                            "Branch <bold>{prefixed}</> already configured for {}",
                            ref_type.display(number)
                        ))
                    );
                    return Ok(ResolvedTarget {
                        branch: prefixed,
                        method: CreationMethod::Regular {
                            create_branch: false,
                            base_branch: None,
                            base_pr_upstream: None,
                        },
                    });
                }
                return Err(GitError::BranchTracksDifferentRef {
                    branch: prefixed,
                    ref_type,
                    number,
                }
                .into());
            }
            // Prefixed branch is free: create it. No push URL is configured for
            // prefixed branches (the caller warns about this).
            let remote = remote_ref::find_remote(repo, info)?;
            return Ok(ResolvedTarget {
                branch: prefixed,
                method: CreationMethod::ForkRef {
                    ref_type,
                    number,
                    ref_path: provider.ref_path(number),
                    fork_push_url: None,
                    ref_url: info.url.clone(),
                    remote,
                },
            });
        }
        return Err(GitError::BranchTracksDifferentRef {
            branch: local_branch,
            ref_type,
            number,
        }
        .into());
    }
    // Fresh branch: determine the remote to fetch the ref from, and where
    // pushes should go.
    let (fork_push_url, remote) = match ref_type {
        RefType::Pr => {
            let remote = remote_ref::find_remote(repo, info)?;
            (info.fork_push_url.clone(), remote)
        }
        RefType::Mr => {
            // GitLab: ask glab for the target and source project URLs; the MR
            // ref is fetched from the remote matching the target project.
            let urls =
                worktrunk::git::remote_ref::gitlab::fetch_gitlab_project_urls(info, repo_root)?;
            let target_url = urls.target_url.ok_or_else(|| {
                anyhow::anyhow!(
                    "{} is from a fork but glab didn't provide target project URL; \
                     upgrade glab or checkout the fork branch manually",
                    ref_type.display(number)
                )
            })?;
            let remote = repo.find_remote_by_url(&target_url).ok_or_else(|| {
                anyhow::anyhow!(
                    "No remote found for target project; \
                     add a remote pointing to {} (e.g., `git remote add upstream {}`)",
                    target_url,
                    target_url
                )
            })?;
            if urls.fork_push_url.is_none() {
                anyhow::bail!(
                    "{} is from a fork but glab didn't provide source project URL; \
                     upgrade glab or checkout the fork branch manually",
                    ref_type.display(number)
                );
            }
            (urls.fork_push_url, remote)
        }
    };
    Ok(ResolvedTarget {
        branch: local_branch,
        method: CreationMethod::ForkRef {
            ref_type,
            number,
            ref_path: provider.ref_path(number),
            fork_push_url,
            ref_url: info.url.clone(),
            remote,
        },
    })
}
/// Resolves a same-repo PR/MR: refreshes its source branch from the remote,
/// then targets that branch by name without creating a new one.
fn resolve_same_repo_ref(
    repo: &Repository,
    info: &RemoteRefInfo,
) -> anyhow::Result<ResolvedTarget> {
    fetch_same_repo_branch(repo, info)?;
    let method = CreationMethod::Regular {
        create_branch: false,
        base_branch: None,
        base_pr_upstream: None,
    };
    Ok(ResolvedTarget {
        branch: info.source_branch.clone(),
        method,
    })
}
/// Force-fetches a same-repo ref's source branch into
/// `refs/remotes/<remote>/<branch>` (the `+` refspec allows non-fast-forward
/// updates, e.g. after a force-push).
fn fetch_same_repo_branch(repo: &Repository, info: &RemoteRefInfo) -> anyhow::Result<()> {
    let remote = remote_ref::find_remote(repo, info)?;
    let branch = info.source_branch.as_str();
    let announce = cformat!("Fetching <bold>{branch}</> from {remote}...");
    eprintln!("{}", progress_message(announce));

    let refspec = format!("+refs/heads/{branch}:refs/remotes/{remote}/{branch}");
    repo.run_command(&["fetch", "--", &remote, &refspec])
        .map(|_| ())
        .with_context(|| cformat!("Failed to fetch branch <bold>{}</> from {}", branch, remote))
}
/// Turns a `--base` argument into a start point for a new branch.
///
/// `pr:<n>` and `mr:<n>` (numeric suffix only) resolve through the forge and
/// may also yield an upstream `(remote, branch)` pair. Anything else is
/// resolved as a worktree/branch name. When that name is not a known ref but
/// exists on exactly one remote, it becomes `<remote>/<name>`.
fn resolve_base_ref(
    repo: &Repository,
    base: &str,
) -> anyhow::Result<(String, Option<(String, String)>)> {
    let numbered = |prefix: &str| {
        base.strip_prefix(prefix)
            .and_then(|suffix| suffix.parse::<u32>().ok())
    };
    if let Some(number) = numbered("pr:") {
        return resolve_pr_base(repo, number);
    }
    if let Some(number) = numbered("mr:") {
        return resolve_remote_ref_as_base(repo, &GitLabProvider, number);
    }

    let resolved = repo.resolve_worktree_name(base)?;
    if repo.ref_exists(&resolved)? {
        return Ok((resolved, None));
    }
    let remotes = repo.branch(&resolved).remotes()?;
    if let [only_remote] = &remotes[..] {
        return Ok((format!("{}/{}", only_remote, resolved), None));
    }
    Ok((resolved, None))
}
/// Resolves a PR/MR used as `--base` into a start point.
///
/// Cross-repo refs: the provider's tracking ref is fetched from the matching
/// remote, and the resulting `FETCH_HEAD` commit SHA is returned with no
/// upstream. Same-repo refs: the source branch is fetched, and its name is
/// returned together with `(remote, source_branch)` so the caller can set it
/// as the new branch's upstream.
///
/// NOTE(review): for same-repo refs the fetch only updates
/// `refs/remotes/<remote>/<branch>`, yet the bare branch name is returned.
/// Confirm callers can resolve it when no local branch of that name exists.
fn resolve_remote_ref_as_base(
    repo: &Repository,
    provider: &dyn RemoteRefProvider,
    number: u32,
) -> anyhow::Result<(String, Option<(String, String)>)> {
    let ref_type = provider.ref_type();
    let symbol = ref_type.symbol();
    let announce = cformat!(
        "Fetching base {} {symbol}{number}...",
        ref_type.name()
    );
    eprintln!("{}", progress_message(announce));
    let info = provider.fetch_info(number, repo)?;
    eprintln!("{}", format_with_gutter(&format_ref_context(&info), None));

    if info.is_cross_repo {
        let remote = remote_ref::find_remote(repo, &info)?;
        let display = ref_type.display(number);
        let tracking = provider.tracking_ref(number);
        repo.run_command(&["fetch", "--", &remote, &tracking])
            .with_context(|| cformat!("Failed to fetch <bold>{display}</> from {remote}"))?;
        let fetched = repo
            .run_command(&["rev-parse", "FETCH_HEAD"])
            .context("Failed to resolve FETCH_HEAD to a commit SHA")?;
        return Ok((fetched.trim().to_string(), None));
    }

    fetch_same_repo_branch(repo, &info)?;
    let remote = remote_ref::find_remote(repo, &info)?;
    let source_branch = info.source_branch.clone();
    Ok((source_branch.clone(), Some((remote, source_branch))))
}
/// Interprets the user's switch argument and decides how the target branch
/// will be obtained.
///
/// - `pr:<n>` / `mr:<n>` with a numeric suffix are resolved through the forge.
///   A non-numeric suffix falls through and is treated as a plain name.
/// - Otherwise the name goes through `resolve_worktree_name`. Without
///   `--create`, a remote-prefixed name that has no worktree and no local
///   branch has its remote prefix stripped.
/// - `--base` is only honoured together with `--create` (otherwise a warning
///   is printed and it is ignored). The resolved base must exist as a ref.
/// - With `--create`, an existing local branch is an error. A same-named
///   remote branch only triggers a warning plus a hint. When no `--base` is
///   given, the default base is `resolve_target_branch(None)`, and only if it
///   exists locally.
fn resolve_switch_target(
    repo: &Repository,
    branch: &str,
    create: bool,
    base: Option<&str>,
) -> anyhow::Result<ResolvedTarget> {
    if let Some(suffix) = branch.strip_prefix("pr:")
        && let Ok(number) = suffix.parse::<u32>()
    {
        return resolve_pr_target(repo, number, create, base);
    }
    if let Some(suffix) = branch.strip_prefix("mr:")
        && let Ok(number) = suffix.parse::<u32>()
    {
        return resolve_remote_ref(repo, &GitLabProvider, number, create, base);
    }
    let mut resolved_branch = repo
        .resolve_worktree_name(branch)
        .context("Failed to resolve branch name")?;
    // When switching (not creating) to a name with no worktree and no local
    // branch, accept a remote-qualified name by stripping the remote prefix.
    // NOTE(review): presumably `origin/feature` -> `feature`; confirm in
    // `strip_remote_prefix`.
    if !create
        && repo.worktree_for_branch(&resolved_branch)?.is_none()
        && !repo.branch(&resolved_branch).exists_locally()?
        && let Some(local_name) = repo.strip_remote_prefix(&resolved_branch)
    {
        resolved_branch = local_name;
    }
    // `base_pr_upstream` is only populated when the base is a same-repo PR/MR.
    let (resolved_base, base_pr_upstream) = if let Some(base_str) = base {
        if !create {
            eprintln!(
                "{}",
                warning_message("--base flag is only used with --create, ignoring")
            );
            (None, None)
        } else {
            let (resolved, upstream) = resolve_base_ref(repo, base_str)?;
            if !repo.ref_exists(&resolved)? {
                return Err(GitError::ReferenceNotFound {
                    reference: resolved,
                }
                .into());
            }
            (Some(resolved), upstream)
        }
    } else {
        (None, None)
    };
    if create {
        let branch_handle = repo.branch(&resolved_branch);
        if branch_handle.exists_locally()? {
            return Err(GitError::BranchAlreadyExists {
                branch: resolved_branch,
            }
            .into());
        }
        // A same-named remote branch is not an error: the new branch is created
        // from the base, and the user is told how to get the remote one instead.
        let remotes = branch_handle.remotes()?;
        if !remotes.is_empty() {
            let remote_ref = format!("{}/{}", remotes[0], resolved_branch);
            eprintln!(
                "{}",
                warning_message(cformat!(
                    "Branch <bold>{resolved_branch}</> exists on remote ({remote_ref}); creating new branch from base instead"
                ))
            );
            let remove_cmd = suggest_command("remove", &[&resolved_branch], &["--foreground"]);
            let switch_cmd = suggest_command("switch", &[&resolved_branch], &[]);
            eprintln!(
                "{}",
                hint_message(cformat!(
                    "To switch to the remote branch, delete this branch and run without <underline>--create</>: <underline>{remove_cmd} && {switch_cmd}</>"
                ))
            );
        }
    }
    // Default base when creating: the repo's target branch, but only if it
    // exists locally (lookup errors are treated as "no default").
    let base_branch = if create {
        resolved_base.or_else(|| {
            repo.resolve_target_branch(None)
                .ok()
                .filter(|b| repo.branch(b).exists_locally().unwrap_or(false))
        })
    } else {
        None
    };
    Ok(ResolvedTarget {
        branch: resolved_branch,
        method: CreationMethod::Regular {
            create_branch: create,
            base_branch,
            base_pr_upstream,
        },
    })
}
/// Checks that a worktree can be created for `branch` at `path`.
///
/// Returns `Ok(true)` when `path` exists on disk and `clobber` was requested,
/// meaning the caller must back it up first. Returns `Ok(false)` when `path`
/// is free.
///
/// Errors when:
/// - an existing branch is required (`Regular` with `create_branch: false`)
///   but does not exist;
/// - git already registers a worktree at `path` (missing on disk, or
///   occupied);
/// - `path` exists on disk and `clobber` is false.
fn validate_worktree_creation(
    repo: &Repository,
    branch: &str,
    path: &Path,
    clobber: bool,
    method: &CreationMethod,
) -> anyhow::Result<bool> {
    let needs_existing_branch = matches!(
        method,
        CreationMethod::Regular {
            create_branch: false,
            ..
        }
    );
    if needs_existing_branch && !repo.branch(branch).exists()? {
        let missing = GitError::BranchNotFound {
            branch: branch.to_string(),
            show_create_hint: true,
            last_fetch_ago: format_last_fetch_ago(repo),
            pr_mr_platform: repo.detect_ref_type(),
        };
        return Err(missing.into());
    }

    if let Some((existing_path, occupant)) = repo.worktree_at_path(path)? {
        let conflict = if existing_path.exists() {
            GitError::WorktreePathOccupied {
                branch: branch.to_string(),
                path: path.to_path_buf(),
                occupant,
            }
        } else {
            GitError::WorktreeMissing {
                branch: occupant.unwrap_or_else(|| branch.to_string()),
            }
        };
        return Err(conflict.into());
    }

    match (path.exists(), clobber) {
        (false, _) => Ok(false),
        (true, true) => Ok(true),
        (true, false) => Err(GitError::WorktreePathExists {
            branch: branch.to_string(),
            path: path.to_path_buf(),
            create: matches!(
                method,
                CreationMethod::Regular {
                    create_branch: true,
                    ..
                }
            ),
        }
        .into()),
    }
}
/// Creates `branch` from `FETCH_HEAD`, wires it to track the fork ref, then
/// adds a worktree for it at `worktree_path`.
///
/// Tracking is written as `branch.<name>.remote = remote` and
/// `branch.<name>.merge = refs/<remote_ref>`. `branch.<name>.pushRemote` is
/// set only when `fork_push_url` is given. `label` is only used in error
/// messages. On failure the local branch may already exist; the caller is
/// responsible for cleaning it up.
fn setup_fork_branch(
    repo: &Repository,
    branch: &str,
    remote: &str,
    remote_ref: &str,
    fork_push_url: Option<&str>,
    worktree_path: &Path,
    label: &str,
) -> anyhow::Result<()> {
    repo.run_command(&["branch", "--", branch, "FETCH_HEAD"])
        .with_context(|| {
            cformat!(
                "Failed to create local branch <bold>{}</> from {}",
                branch,
                label
            )
        })?;

    let merge_ref = format!("refs/{remote_ref}");
    let tracking_settings = [("remote", remote), ("merge", merge_ref.as_str())];
    for (setting, value) in tracking_settings {
        repo.set_config(&format!("branch.{branch}.{setting}"), value)
            .with_context(|| format!("Failed to configure branch.{branch}.{setting}"))?;
    }
    if let Some(url) = fork_push_url {
        repo.set_config(&format!("branch.{branch}.pushRemote"), url)
            .with_context(|| format!("Failed to configure branch.{branch}.pushRemote"))?;
    }

    let path_arg = worktree_path.to_string_lossy();
    let worktree_args = ["worktree", "add", "--", &*path_arg, branch];
    let progress =
        progress_message(cformat!("Creating worktree for <bold>{}</>...", branch)).to_string();
    repo.run_command_delayed_stream(
        &worktree_args,
        Repository::SLOW_OPERATION_DELAY_MS,
        Some(progress),
    )
    .map_err(|e| worktree_creation_error(&e, branch.to_string(), None))?;
    Ok(())
}
/// Builds a [`SwitchPlan`] without side effects on worktrees.
///
/// The plan is `Existing` when the resolved branch already has a worktree on
/// disk. Without `--create`, it is also `Existing` when the argument is a
/// path (absolute, or relative with more than one component) to a registered
/// worktree. Otherwise it is `Create`, at the configured worktree path, after
/// [`validate_worktree_creation`] approves it. A branch whose registered
/// worktree is missing on disk is an error.
fn plan_switch(
    repo: &Repository,
    branch: &str,
    create: bool,
    base: Option<&str>,
    clobber: bool,
    config: &UserConfig,
) -> anyhow::Result<SwitchPlan> {
    let new_previous = repo.current_worktree().branch().ok().flatten();
    let target = resolve_switch_target(repo, branch, create, base)?;

    if let Some(existing_path) = repo.worktree_for_branch(&target.branch)? {
        if !existing_path.exists() {
            return Err(GitError::WorktreeMissing {
                branch: target.branch,
            }
            .into());
        }
        return Ok(SwitchPlan::Existing {
            path: canonicalize(&existing_path).unwrap_or(existing_path),
            branch: Some(target.branch),
            new_previous,
        });
    }

    // Without --create the argument may be a worktree directory, not a branch.
    let path_like = if create {
        None
    } else {
        let candidate = Path::new(branch);
        if candidate.is_absolute() {
            Some(candidate.to_path_buf())
        } else if candidate.components().count() > 1 {
            std::env::current_dir().ok().map(|cwd| cwd.join(candidate))
        } else {
            None
        }
    };
    if let Some(abs_path) = path_like
        && let Some((path, wt_branch)) = repo.worktree_at_path(&abs_path)?
    {
        return Ok(SwitchPlan::Existing {
            path: canonicalize(&path).unwrap_or_else(|_| path.clone()),
            branch: wt_branch,
            new_previous,
        });
    }

    let worktree_path = compute_worktree_path(repo, &target.branch, config)?;
    let needs_clobber_backup = validate_worktree_creation(
        repo,
        &target.branch,
        &worktree_path,
        clobber,
        &target.method,
    )?;
    Ok(SwitchPlan::Create {
        branch: target.branch,
        worktree_path,
        method: target.method,
        needs_clobber_backup,
        new_previous,
    })
}
/// Carries out a [`SwitchPlan`] produced by `plan_switch`.
///
/// `SwitchPlan::Existing`: nothing is created. The result is `AlreadyAt` when
/// the canonicalized cwd equals the plan path, otherwise `Existing`. The
/// "previous" switch target is recorded only when actually moving.
///
/// `SwitchPlan::Create`:
/// 1. Back up a clobbered path if needed.
/// 2. Create the worktree: `git worktree add` for regular branches, or fetch
///    the PR/MR ref plus `setup_fork_branch` for fork refs.
/// 3. Run pre-create hooks when `run_hooks` is set.
/// 4. Record the previous switch target and return the creation details.
///
/// `force` is forwarded to `CommandContext::new` (the pipeline passes its
/// `yes` flag here). The returned `SwitchBranchInfo::expected_path` is always
/// `None`.
fn execute_switch(
    repo: &Repository,
    plan: SwitchPlan,
    config: &UserConfig,
    force: bool,
    run_hooks: bool,
    hook_plan: &ApprovedHookPlan,
) -> anyhow::Result<(SwitchResult, SwitchBranchInfo)> {
    match plan {
        SwitchPlan::Existing {
            path,
            branch,
            new_previous,
        } => {
            // Canonicalize cwd so it compares against `path`, which
            // `plan_switch` canonicalizes when it can.
            let current_dir = std::env::current_dir()
                .ok()
                .and_then(|p| canonicalize(&p).ok());
            let already_at_worktree = current_dir
                .as_ref()
                .map(|cur| cur == &path)
                .unwrap_or(false);
            // Best effort: failing to record the previous target must not
            // abort the switch.
            if !already_at_worktree {
                let _ = repo.set_switch_previous(new_previous.as_deref());
            }
            let result = if already_at_worktree {
                SwitchResult::AlreadyAt(path)
            } else {
                SwitchResult::Existing { path }
            };
            Ok((
                result,
                SwitchBranchInfo {
                    branch,
                    expected_path: None,
                },
            ))
        }
        SwitchPlan::Create {
            branch,
            worktree_path,
            method,
            needs_clobber_backup,
            new_previous,
        } => {
            // --clobber: move whatever occupies the path aside before creating.
            if needs_clobber_backup {
                let backup_path = back_up_clobbered_path_now(&worktree_path)?;
                let path_display = worktrunk::path::format_path_for_display(&worktree_path);
                let backup_display = worktrunk::path::format_path_for_display(&backup_path);
                eprintln!(
                    "{}",
                    warning_message(cformat!(
                        "Moved <bold>{path_display}</> to <bold>{backup_display}</> (--clobber)"
                    ))
                );
            }
            let (created_branch, base_branch, from_remote) = match &method {
                CreationMethod::Regular {
                    create_branch,
                    base_branch,
                    base_pr_upstream,
                } => {
                    let branch_handle = repo.branch(&branch);
                    // Only meaningful when reusing a branch: was it already local?
                    let local_branch_existed =
                        !create_branch && branch_handle.exists_locally().unwrap_or(false);
                    let worktree_path_str = worktree_path.to_string_lossy();
                    let mut args: Vec<&str> = vec!["worktree", "add"];
                    // Declared before use so the formatted ref outlives `args`,
                    // which borrows it.
                    let tracking_ref;
                    let trailing_ref: Option<&str> = if *create_branch {
                        // New branch: `-b <branch>`, starting from the base if any.
                        args.push("-b");
                        args.push(&branch);
                        base_branch.as_deref()
                    } else if !local_branch_existed {
                        // No local branch: when exactly one remote has it, create
                        // the local branch from `<remote>/<branch>`; otherwise pass
                        // the bare name through to git.
                        let remotes = branch_handle.remotes().unwrap_or_default();
                        if remotes.len() == 1 {
                            tracking_ref = format!("{}/{}", remotes[0], branch);
                            args.extend(["-b", &branch]);
                            Some(tracking_ref.as_str())
                        } else {
                            Some(branch.as_str())
                        }
                    } else {
                        Some(branch.as_str())
                    };
                    args.push("--");
                    args.push(worktree_path_str.as_ref());
                    if let Some(r) = trailing_ref {
                        args.push(r);
                    }
                    let progress_msg = Some(
                        progress_message(cformat!("Creating worktree for <bold>{}</>...", branch))
                            .to_string(),
                    );
                    if let Err(e) = repo.run_command_delayed_stream(
                        &args,
                        Repository::SLOW_OPERATION_DELAY_MS,
                        progress_msg,
                    ) {
                        return Err(worktree_creation_error(
                            &e,
                            branch.clone(),
                            base_branch.clone(),
                        )
                        .into());
                    }
                    // A new branch based on a remote-tracking branch should not
                    // keep that base as its upstream.
                    // NOTE(review): presumably undoing git's automatic upstream
                    // setup for remote start points.
                    if *create_branch
                        && let Some(base) = base_branch
                        && repo.is_remote_tracking_branch(base)
                    {
                        branch_handle.unset_upstream()?;
                    }
                    // Base was a same-repo PR/MR: track the PR's source branch.
                    if *create_branch
                        && let Some((upstream_remote, upstream_branch)) = base_pr_upstream
                    {
                        repo.set_config(&format!("branch.{branch}.remote"), upstream_remote)?;
                        repo.set_config(
                            &format!("branch.{branch}.merge"),
                            &format!("refs/heads/{upstream_branch}"),
                        )?;
                    }
                    // Report the upstream only for branches newly materialized
                    // from a remote.
                    let from_remote = if !create_branch && !local_branch_existed {
                        branch_handle.upstream()?
                    } else {
                        None
                    };
                    (*create_branch, base_branch.clone(), from_remote)
                }
                CreationMethod::ForkRef {
                    ref_type,
                    number,
                    ref_path,
                    fork_push_url,
                    ref_url: _,
                    remote,
                } => {
                    let label = ref_type.display(*number);
                    // Populates FETCH_HEAD, which `setup_fork_branch` branches from.
                    repo.run_command(&["fetch", "--", remote, ref_path])
                        .with_context(|| format!("Failed to fetch {} from {}", label, remote))?;
                    let setup_result = setup_fork_branch(
                        repo,
                        &branch,
                        remote,
                        ref_path,
                        fork_push_url.as_deref(),
                        &worktree_path,
                        &label,
                    );
                    // Roll back the partially created branch; deletion is best
                    // effort and the original error is returned.
                    if let Err(e) = setup_result {
                        let _ = repo.run_command(&["branch", "-D", "--", &branch]);
                        return Err(e);
                    }
                    // No push URL is how `resolve_fork_ref` marks a prefixed branch.
                    if let Some(url) = fork_push_url {
                        eprintln!(
                            "{}",
                            info_message(cformat!("Push configured to fork: <underline>{url}</>"))
                        );
                    } else {
                        eprintln!(
                            "{}",
                            warning_message(cformat!(
                                "Using prefixed branch name <bold>{branch}</> due to name conflict"
                            ))
                        );
                        eprintln!(
                            "{}",
                            hint_message(
                                "Push to fork is not supported with prefixed branches; feedback welcome at https://github.com/max-sixty/worktrunk/issues/714",
                            )
                        );
                    }
                    (false, None, Some(label))
                }
            };
            // Look up the base's worktree via a freshly opened repository.
            // NOTE(review): presumably to avoid stale cached worktree data after
            // the add above; confirm.
            let base_worktree_path = base_branch
                .as_ref()
                .and_then(|b| {
                    Repository::at(repo.discovery_path())
                        .and_then(|fresh| fresh.worktree_for_branch(b))
                        .ok()
                        .flatten()
                })
                .map(|p| worktrunk::path::to_posix_path(&p.to_string_lossy()));
            let (pr_number, pr_url) = match &method {
                CreationMethod::ForkRef {
                    number, ref_url, ..
                } => (Some(*number), Some(ref_url.clone())),
                CreationMethod::Regular { .. } => (None, None),
            };
            // Pre-create hooks run inside the new worktree, before returning.
            if run_hooks {
                let hook_repo = Repository::at(&worktree_path)?;
                let ctx =
                    CommandContext::new(&hook_repo, config, Some(&branch), &worktree_path, force);
                let mut vars = TemplateVars::new()
                    .with_target(&branch)
                    .with_target_worktree_path(&worktree_path);
                match &method {
                    CreationMethod::Regular { base_branch, .. } => {
                        vars = vars
                            .with_base_strs(base_branch.as_deref(), base_worktree_path.as_deref());
                    }
                    CreationMethod::ForkRef {
                        number, ref_url, ..
                    } => {
                        vars = vars.with_pr(Some(*number), Some(ref_url));
                    }
                }
                ctx.execute_pre_create_commands(&vars.as_extra_vars(), hook_plan, &worktree_path)?;
            }
            // Best effort, as in the Existing arm.
            let _ = repo.set_switch_previous(new_previous.as_deref());
            Ok((
                SwitchResult::Created {
                    path: worktree_path,
                    created_branch,
                    base_branch,
                    base_worktree_path,
                    from_remote,
                    pr_number,
                    pr_url,
                },
                SwitchBranchInfo {
                    branch: Some(branch),
                    expected_path: None,
                },
            ))
        }
    }
}
fn worktree_creation_error(
err: &anyhow::Error,
branch: String,
base_branch: Option<String>,
) -> GitError {
let (output, command) = Repository::extract_failed_command(err);
GitError::WorktreeCreationFailed {
branch,
base_branch,
error: output,
command,
}
}
/// Describes how long ago the repository was last fetched, for error hints.
///
/// Returns `None` when the last fetch time is unknown. Very recent (`now`) or
/// future timestamps are reported as "just now".
fn format_last_fetch_ago(repo: &Repository) -> Option<String> {
    repo.last_fetch_epoch().map(|epoch| {
        let relative = format_relative_time_short(epoch as i64);
        let is_immediate = relative == "now" || relative == "future";
        if is_immediate {
            String::from("last fetched just now")
        } else {
            format!("last fetched {relative} ago")
        }
    })
}
/// Machine-readable summary of a switch, printed for `--format json`.
///
/// Field order defines the JSON key order. `None` fields are omitted.
#[derive(Serialize)]
struct SwitchJsonOutput {
    /// One of `already_at`, `existing`, `created`.
    action: &'static str,
    /// Branch of the target worktree, when known.
    #[serde(skip_serializing_if = "Option::is_none")]
    branch: Option<String>,
    /// Path of the target worktree.
    path: PathBuf,
    /// Only for `created`: whether a new branch was created.
    #[serde(skip_serializing_if = "Option::is_none")]
    created_branch: Option<bool>,
    /// Only for `created`: the base the new branch started from.
    #[serde(skip_serializing_if = "Option::is_none")]
    base_branch: Option<String>,
    /// Only for `created`: the remote source/upstream of the checked-out branch.
    #[serde(skip_serializing_if = "Option::is_none")]
    from_remote: Option<String>,
}
impl SwitchJsonOutput {
    /// Projects a [`SwitchResult`] plus branch info onto the JSON shape.
    ///
    /// Creation-only fields are populated solely for `SwitchResult::Created`.
    fn from_result(result: &SwitchResult, branch_info: &SwitchBranchInfo) -> Self {
        let (action, path) = match result {
            SwitchResult::AlreadyAt(path) => ("already_at", path),
            SwitchResult::Existing { path } => ("existing", path),
            SwitchResult::Created { path, .. } => ("created", path),
        };
        let (created_branch, base_branch, from_remote) = match result {
            SwitchResult::Created {
                created_branch,
                base_branch,
                from_remote,
                ..
            } => (
                Some(*created_branch),
                base_branch.clone(),
                from_remote.clone(),
            ),
            SwitchResult::AlreadyAt(_) | SwitchResult::Existing { .. } => (None, None, None),
        };
        Self {
            action,
            branch: branch_info.branch.clone(),
            path: path.clone(),
            created_branch,
            base_branch,
            from_remote,
        }
    }
}
fn emit_switch_json(
format: SwitchFormat,
result: &SwitchResult,
branch_info: &SwitchBranchInfo,
) -> anyhow::Result<()> {
if format != SwitchFormat::Json {
return Ok(());
}
let json = SwitchJsonOutput::from_result(result, branch_info);
let json = serde_json::to_string(&json).context("Failed to serialize to JSON")?;
println!("{json}");
Ok(())
}
/// Command-line options for `switch`, as consumed by [`run_switch`].
pub struct SwitchOptions<'a> {
    /// Branch name, worktree path, `pr:<n>` or `mr:<n>` to switch to.
    pub branch: &'a str,
    /// Create a new branch (`--create`).
    pub create: bool,
    /// Start point for `--create` (`--base`); ignored without `--create`.
    pub base: Option<&'a str>,
    /// Command template to run after switching (`--execute`).
    pub execute: Option<&'a str>,
    /// Extra arguments appended to the `--execute` command, each
    /// template-expanded and shell-escaped.
    pub execute_args: &'a [String],
    /// Skip approval prompts.
    pub yes: bool,
    /// Back up and replace an existing path at the worktree location.
    pub clobber: bool,
    /// Explicit cd-after-switch choice; `None` defers to user/project config.
    pub change_dir: Option<bool>,
    /// Run hooks; when false, no hooks run.
    pub verify: bool,
    /// Output format (e.g. JSON).
    pub format: crate::cli::SwitchFormat,
}
/// Runs `pre-switch` hooks from the current worktree before any switch work.
///
/// The target is resolved with `resolve_worktree_name`, falling back to the
/// raw argument if that fails. Hooks run only after approval. Template vars
/// describe the current worktree as `base`, plus the target branch and its
/// worktree path when one already exists. A failing hook aborts the switch
/// (`FailureStrategy::FailFast`).
///
/// Fix: the original used the mis-encoded token `¤t_path` (mojibake of
/// `&current_path`), which does not compile.
fn run_pre_switch_hooks(
    repo: &Repository,
    config: &UserConfig,
    target_branch: &str,
    yes: bool,
) -> anyhow::Result<()> {
    let current_wt = repo.current_worktree();
    let current_path = current_wt.path().to_path_buf();
    let resolved_target = repo
        .resolve_worktree_name(target_branch)
        .unwrap_or_else(|_| target_branch.to_string());
    let pre_ctx = CommandContext::new(repo, config, Some(&resolved_target), &current_path, yes);
    let pre_switch_approved = approve_hooks(&pre_ctx, &[HookType::PreSwitch])?;
    if !pre_switch_approved {
        return Ok(());
    }

    let base_branch = current_wt.branch().ok().flatten().unwrap_or_default();
    // The destination worktree may not exist yet (e.g. `--create`).
    let dest_path = repo.worktree_for_branch(&resolved_target).ok().flatten();
    let mut vars = TemplateVars::new()
        .with_base(&base_branch, &current_path)
        .with_target(&resolved_target);
    if let Some(p) = dest_path.as_deref() {
        vars = vars.with_target_worktree_path(p).with_active_worktree(p);
    }
    let extra_vars = vars.as_extra_vars();
    execute_hook(
        &pre_ctx,
        HookType::PreSwitch,
        &extra_vars,
        FailureStrategy::FailFast,
        crate::output::pre_hook_display_path(pre_ctx.worktree_path),
    )?;
    Ok(())
}
/// Hook types whose approval is gathered up front for a switch.
///
/// Creating a worktree also involves the create-time hooks. Switching to an
/// existing worktree only involves `post-switch`.
fn switch_post_hook_types(is_create: bool) -> &'static [HookType] {
    const CREATE_HOOKS: &[HookType] = &[
        HookType::PreCreate,
        HookType::PostCreate,
        HookType::PostSwitch,
    ];
    const SWITCH_HOOKS: &[HookType] = &[HookType::PostSwitch];
    match is_create {
        true => CREATE_HOOKS,
        false => SWITCH_HOOKS,
    }
}
/// Collects and approves the post-switch (and, when creating, create-time)
/// hooks for `plan`.
///
/// Returns `(true, plan)` when approved. Returns `(false, empty plan)` when
/// `verify` is off or the user declines; a decline prints an info message and
/// the switch continues without hooks.
fn approve_switch_hooks(
    repo: &Repository,
    config: &UserConfig,
    plan: &SwitchPlan,
    yes: bool,
    verify: bool,
) -> anyhow::Result<(bool, ApprovedHookPlan)> {
    if !verify {
        return Ok((false, ApprovedHookPlan::empty()));
    }
    let project_id = repo.project_identifier().ok();
    let project_config = repo.load_project_config()?;

    let mut builder = HookPlanBuilder::new();
    builder.add(
        plan.worktree_path(),
        switch_post_hook_types(plan.is_create()),
        project_config.as_ref(),
        config,
        project_id.as_deref(),
    );

    let Some(approved) = builder.finish().approve(project_id.as_deref(), yes)? else {
        let on_decline = match plan.is_create() {
            true => "Commands declined, continuing worktree creation without hooks",
            false => "Commands declined, switching without hooks",
        };
        eprintln!("{}", info_message(on_decline));
        return Ok((false, ApprovedHookPlan::empty()));
    };
    Ok((true, approved))
}
/// Registers the approved background hooks for the switched-to worktree and
/// flushes them through a [`HookAnnouncer`].
///
/// `post-switch` is always registered. `post-create` is added, after it, only
/// when the worktree was just created.
fn spawn_switch_background_hooks(
    config: &UserConfig,
    result: &SwitchResult,
    branch: Option<&str>,
    yes: bool,
    extra_vars: &[(&str, &str)],
    hooks_display_path: Option<&Path>,
    hook_plan: &ApprovedHookPlan,
) -> anyhow::Result<()> {
    let target = result.path();
    let hook_repo = Repository::at(target)?;
    let ctx = CommandContext::new(&hook_repo, config, branch, target, yes);
    let mut announcer = HookAnnouncer::new(&hook_repo, config, false);

    let was_created = matches!(result, SwitchResult::Created { .. });
    let hook_types = [
        Some(HookType::PostSwitch),
        was_created.then_some(HookType::PostCreate),
    ];
    for hook_type in hook_types.into_iter().flatten() {
        register_planned(
            &mut announcer,
            hook_plan,
            target,
            &ctx,
            hook_type,
            extra_vars,
            hooks_display_path,
        )?;
    }
    announcer.flush()
}
/// Captures the branch and POSIX-style root path of the worktree being left.
///
/// Returns empty strings when running from a recovered state, and uses an
/// empty string for either part that cannot be determined.
fn capture_switch_source(repo: &Repository, is_recovered: bool) -> (String, String) {
    if is_recovered {
        return Default::default();
    }
    let source = repo.current_worktree();
    let branch = source.branch().ok().flatten().unwrap_or_default();
    let path = match source.root() {
        Ok(root) => worktrunk::path::to_posix_path(&root.to_string_lossy()),
        Err(_) => String::new(),
    };
    (branch, path)
}
/// Everything needed to run a switch end to end: plan, approve, execute,
/// report, then hooks and `--execute`.
pub(crate) struct SwitchPipeline<'a> {
    /// Repository the switch operates on.
    pub repo: &'a Repository,
    /// User configuration (mutable so setup prompts can update it).
    pub config: &'a mut UserConfig,
    /// Branch name, worktree path, `pr:<n>` or `mr:<n>`.
    pub identifier: &'a str,
    /// Create a new branch.
    pub create: bool,
    /// Start point for a new branch.
    pub base: Option<&'a str>,
    /// Back up and replace an existing path at the worktree location.
    pub clobber: bool,
    /// Run hooks at all.
    pub verify: bool,
    /// Skip approval prompts.
    pub yes: bool,
    /// Whether the shell should cd into the target afterwards.
    pub change_dir: bool,
    /// Output format.
    pub format: SwitchFormat,
    /// Running from a recovered repository; skips pre-switch hooks and
    /// source capture.
    pub is_recovered: bool,
    /// Context used to wrap planning `GitError`s with a suggested command.
    pub suggestion_ctx: Option<SwitchSuggestionCtx>,
    /// Record the source branch/path for post-switch template vars.
    pub capture_source: bool,
    /// `--execute` command template.
    pub execute: Option<&'a str>,
    /// Extra arguments for the `--execute` command.
    pub execute_args: &'a [String],
    /// Binary name used when offering shell-integration setup.
    pub shell_integration_binary: Option<&'a str>,
}
impl SwitchPipeline<'_> {
    /// Runs the full switch flow:
    /// 1. Offer the bare-repo path fix and run pre-switch hooks.
    /// 2. Plan the switch; planning `GitError`s are wrapped with the
    ///    suggestion context when one is provided.
    /// 3. Approve hooks and validate templates.
    /// 4. Execute the plan and emit JSON.
    /// 5. Print user-facing output and offer shell integration.
    /// 6. Spawn background hooks and finally run the `--execute` command.
    pub(crate) fn run(self) -> anyhow::Result<()> {
        let Self {
            repo,
            config,
            identifier,
            create,
            base,
            clobber,
            verify,
            yes,
            change_dir,
            format,
            is_recovered,
            suggestion_ctx,
            capture_source,
            execute,
            execute_args,
            shell_integration_binary,
        } = self;
        offer_bare_repo_worktree_path_fix(repo, config)?;
        if verify && !is_recovered {
            run_pre_switch_hooks(repo, config, identifier, yes)?;
        }
        // Capture the source before switching so post-switch hooks can refer to it.
        let (source_branch, source_path) = if capture_source {
            capture_switch_source(repo, is_recovered)
        } else {
            (String::new(), String::new())
        };
        // Only `GitError`s get the suggestion wrapper; other errors pass through.
        let plan = plan_switch(repo, identifier, create, base, clobber, config).map_err(|err| {
            match suggestion_ctx {
                Some(ref ctx) => match err.downcast::<GitError>() {
                    Ok(git_err) => GitError::WithSwitchSuggestion {
                        source: Box::new(git_err),
                        ctx: ctx.clone(),
                    }
                    .into(),
                    Err(err) => err,
                },
                None => err,
            }
        })?;
        let (hooks_approved, hook_plan) = approve_switch_hooks(repo, config, &plan, yes, verify)?;
        // NOTE(review): presumably checks the --execute/hook templates before any
        // worktree is created; `validate_switch_templates` is defined elsewhere.
        validate_switch_templates(repo, config, &plan, execute, execute_args, hooks_approved)?;
        let (result, branch_info) =
            execute_switch(repo, plan, config, yes, hooks_approved, &hook_plan)?;
        emit_switch_json(format, &result, &branch_info)?;
        // NOTE(review): when this env var is set, everything after JSON output
        // (messages, hooks, --execute) is skipped; its purpose is not visible here.
        if std::env::var_os("WORKTRUNK_FIRST_OUTPUT").is_some() {
            return Ok(());
        }
        // For existing worktrees, note when the actual path differs from the
        // configured one.
        let branch_info = match &result {
            SwitchResult::Existing { path } | SwitchResult::AlreadyAt(path) => {
                let expected_path = branch_info
                    .branch
                    .as_deref()
                    .and_then(|b| path_mismatch(repo, b, path, config));
                SwitchBranchInfo {
                    expected_path,
                    ..branch_info
                }
            }
            _ => branch_info,
        };
        let fallback_path = repo.repo_path()?.to_path_buf();
        let cwd = std::env::current_dir().unwrap_or(fallback_path.clone());
        let source_root = repo.current_worktree().root().unwrap_or(fallback_path);
        let hooks_display_path =
            handle_switch_output(&result, &branch_info, change_dir, Some(&source_root), &cwd)?;
        // Offer shell integration (best effort) when a cd was wanted but cannot
        // happen without it; skip the interactive prompt when --execute is set.
        if let Some(binary_name) = shell_integration_binary
            && change_dir
            && !is_shell_integration_active()
        {
            let skip_prompt = execute.is_some();
            let _ = prompt_shell_integration(repo, config, binary_name, skip_prompt);
        }
        let template_vars =
            TemplateVars::for_post_switch(&result, &branch_info, &source_branch, &source_path);
        let extra_vars = template_vars.as_extra_vars();
        if hooks_approved {
            spawn_switch_background_hooks(
                config,
                &result,
                branch_info.branch.as_deref(),
                yes,
                &extra_vars,
                hooks_display_path.as_deref(),
                &hook_plan,
            )?;
        }
        if let Some(cmd) = execute {
            let ctx = CommandContext::new(
                repo,
                config,
                branch_info.branch.as_deref(),
                result.path(),
                yes,
            );
            let template_vars = build_hook_context(&ctx, &extra_vars, None)?;
            let vars: HashMap<&str, &str> = template_vars
                .iter()
                .map(|(k, v)| (k.as_str(), v.as_str()))
                .collect();
            // The command itself is expanded with the shell's escape mode; trailing
            // args are expanded literally and then escaped individually.
            let escape_mode = directive_shell_escape_mode();
            let expanded_cmd = expand_template(cmd, &vars, escape_mode, repo, "--execute command")?;
            let full_cmd = if execute_args.is_empty() {
                expanded_cmd
            } else {
                let expanded_args: Result<Vec<_>, _> = execute_args
                    .iter()
                    .map(|arg| {
                        expand_template(
                            arg,
                            &vars,
                            ShellEscapeMode::Literal,
                            repo,
                            "--execute argument",
                        )
                    })
                    .collect();
                let escaped_args: Vec<_> = expanded_args?
                    .iter()
                    .map(|arg| shell_escape_for(escape_mode, arg))
                    .collect();
                format!("{} {}", expanded_cmd, escaped_args.join(" "))
            };
            execute_user_command(&full_cmd, hooks_display_path.as_deref())?;
        }
        Ok(())
    }
}
/// Entry point for `wt switch`.
///
/// Opens the repository (recovering if needed) and resolves the cd behaviour:
/// an explicit flag wins, otherwise the user/project config decides. When
/// `--execute` is set, error suggestions re-include it (shell-escaped) along
/// with its trailing arguments. Then the [`SwitchPipeline`] runs with source
/// capture and shell-integration prompting enabled.
pub fn run_switch(
    opts: SwitchOptions<'_>,
    config: &mut UserConfig,
    binary_name: &str,
) -> anyhow::Result<()> {
    let (repo, is_recovered) = current_or_recover().context("Failed to switch worktree")?;

    let change_dir = match opts.change_dir {
        Some(explicit) => explicit,
        None => {
            let project_id = repo.project_identifier().ok();
            config.resolved(project_id.as_deref()).switch.cd()
        }
    };

    let suggestion_ctx = opts.execute.map(|exec| SwitchSuggestionCtx {
        extra_flags: vec![format!(
            "--execute={}",
            shell_escape::unix::escape(exec.into())
        )],
        trailing_args: opts.execute_args.to_vec(),
    });

    let pipeline = SwitchPipeline {
        repo: &repo,
        config,
        identifier: opts.branch,
        create: opts.create,
        base: opts.base,
        clobber: opts.clobber,
        verify: opts.verify,
        yes: opts.yes,
        change_dir,
        format: opts.format,
        is_recovered,
        suggestion_ctx,
        capture_source: true,
        execute: opts.execute,
        execute_args: opts.execute_args,
        shell_integration_binary: Some(binary_name),
    };
    pipeline.run()
}
/// Returns `true` when `value` is a single bare program name or path.
///
/// The first character must be an ASCII alphanumeric or one of `.`, `_`, `/`,
/// `@` (or `\` on Windows). Later characters may also be `+` or `-` (and `:`
/// on Windows). The empty string is never clean.
fn is_clean_program_token(value: &str) -> bool {
    // Characters accepted at the very first position.
    let leading_ok = |c: char| {
        c.is_ascii_alphanumeric()
            || matches!(c, '.' | '_' | '/' | '@')
            || (cfg!(windows) && c == '\\')
    };
    // Later positions additionally allow `+`, `-`, and (on Windows) `:`.
    let trailing_ok =
        |c: char| leading_ok(c) || matches!(c, '+' | '-') || (cfg!(windows) && c == ':');

    let mut chars = value.chars();
    match chars.next() {
        Some(first) if leading_ok(first) => chars.all(trailing_ok),
        _ => false,
    }
}
/// Prints a deprecation warning and a migration hint when `--execute` was given
/// something other than a single clean program token.
///
/// The hint rebuilds today's command line: `cmd` followed by each of
/// `execute_args` escaped via `shell_escape_for`. It escapes that line once more
/// and wraps it in an explicit shell invocation (`pwsh -Command`, `fish -c`, or
/// `sh -c`), chosen from `directive_shell_escape_mode()`. Nothing is printed
/// when `cmd` is already a clean token.
fn warn_if_execute_form_deprecated(cmd: &str, execute_args: &[String]) {
    if is_clean_program_token(cmd) {
        return;
    }

    let mode = directive_shell_escape_mode();
    let (shell, flag) = if matches!(mode, ShellEscapeMode::PowerShell) {
        ("pwsh", "-Command")
    } else if matches!(mode, ShellEscapeMode::Fish) {
        ("fish", "-c")
    } else {
        ("sh", "-c")
    };

    // Append each escaped argument after a single space, mirroring `cmd args...`.
    let mut command_line = cmd.to_string();
    for arg in execute_args {
        command_line.push(' ');
        command_line.push_str(&shell_escape_for(mode, arg));
    }
    let suggested = shell_escape_for(mode, &command_line);

    let warning = warning_message(cformat!(
        "<bold>--execute</> will change in a future release: it will run a single program, with arguments after <bold>--</>, not a shell command line"
    ));
    eprintln!("{}", warning);

    let hint = hint_message(cformat!(
        "Comment at <underline>https://github.com/max-sixty/worktrunk/issues/2860</> if the new single-program form would make a workflow worse; to run this command line unchanged, pass it to a shell: <underline>--execute {shell} -- {flag} {suggested}</>"
    ));
    eprintln!("{}", hint);
}
/// Validates the templates a switch is about to expand.
///
/// - If `execute` is set, the command and every entry of `execute_args` are
///   validated under `ValidationScope::SwitchExecute`. After that succeeds,
///   `warn_if_execute_form_deprecated` may print a deprecation notice.
/// - If `hooks_approved` is true, every user and project command for each hook
///   type in `switch_post_hook_types(plan.is_create())` is validated under its
///   `ValidationScope::Hook`. Commands whose template references `vars` are
///   skipped.
///
/// # Errors
/// Returns the first validation error, or an error from loading the project
/// config.
fn validate_switch_templates(
    repo: &Repository,
    config: &UserConfig,
    plan: &SwitchPlan,
    execute: Option<&str>,
    execute_args: &[String],
    hooks_approved: bool,
) -> anyhow::Result<()> {
    if let Some(cmd) = execute {
        // Command first, then each argument, each with its own label.
        let labeled = std::iter::once((cmd, "--execute command")).chain(
            execute_args
                .iter()
                .map(|arg| (arg.as_str(), "--execute argument")),
        );
        for (template, label) in labeled {
            validate_template(template, ValidationScope::SwitchExecute, repo, label)?;
        }
        warn_if_execute_form_deprecated(cmd, execute_args);
    }

    if !hooks_approved {
        return Ok(());
    }

    let project_config = repo.load_project_config()?;
    let project_id = repo.project_identifier().ok();
    let user_hooks = config.hooks(project_id.as_deref());

    for &hook_type in switch_post_hook_types(plan.is_create()) {
        let (user_cfg, proj_cfg) = crate::commands::hooks::lookup_hook_configs(
            &user_hooks,
            project_config.as_ref(),
            hook_type,
        );
        for (source, maybe_cfg) in [("user", user_cfg), ("project", proj_cfg)] {
            let Some(cfg) = maybe_cfg else {
                continue;
            };
            for cmd in cfg.commands() {
                // Templates that reference `vars` are not validated here
                // (presumably because those values are only known later).
                if template_references_var(&cmd.template, "vars") {
                    continue;
                }
                let name = if let Some(n) = &cmd.name {
                    format!("{source} {hook_type}:{n}")
                } else {
                    format!("{source} {hook_type} hook")
                };
                validate_template(
                    &cmd.template,
                    ValidationScope::Hook(hook_type),
                    repo,
                    &name,
                )?;
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use worktrunk::testing::TestRepo;

    const GITHUB_URL: &str = "https://github.com/myorg/myrepo";
    const AZURE_URL: &str = "https://dev.azure.com/myorg/proj/_git/myrepo";

    /// Adds a git remote called `name` pointing at `url` to the test repository.
    fn add_remote(test: &TestRepo, name: &str, url: &str) {
        test.run_git(&["remote", "add", name, url]);
    }

    /// Asserts which PR provider `choose_pr_provider` picks for `test`.
    #[track_caller]
    fn assert_provider(test: &TestRepo, expected: PrProviderChoice) {
        assert_eq!(choose_pr_provider(&test.repo).unwrap(), expected);
    }

    #[test]
    fn is_clean_program_token_matches_only_bare_names() {
        let clean = [
            "git",
            "claude",
            "node18",
            "my-tool",
            "tool.sh",
            "/usr/bin/env",
            "./build",
            "_x",
            "@scope/pkg",
        ];
        let unclean = [
            "",
            "npm run dev",
            "a && b",
            "echo $HOME",
            "code {{ worktree_path }}",
            "a|b",
            "-flag",
            "+x",
        ];
        clean
            .iter()
            .for_each(|ok| assert!(is_clean_program_token(ok), "expected clean token: {ok:?}"));
        unclean
            .iter()
            .for_each(|bad| assert!(!is_clean_program_token(bad), "expected non-token: {bad:?}"));
        assert_eq!(
            is_clean_program_token(r"C:\Tools\foo.exe"),
            cfg!(windows),
            "Windows path classification should follow the target OS"
        );
    }

    #[test]
    fn capture_switch_source_returns_empty_when_recovered() {
        let test = TestRepo::with_initial_commit();
        let (branch, path) = capture_switch_source(&test.repo, true);
        assert_eq!(branch, "");
        assert_eq!(path, "");
    }

    #[test]
    fn capture_switch_source_returns_branch_and_path_normally() {
        let test = TestRepo::with_initial_commit();
        let (branch, path) = capture_switch_source(&test.repo, false);
        assert_eq!(branch, "main");
        assert!(!path.is_empty(), "source_path should be the worktree root");
    }

    #[test]
    fn choose_pr_provider_prefers_github_over_azure() {
        let test = TestRepo::with_initial_commit();
        add_remote(&test, "origin", GITHUB_URL);
        add_remote(&test, "azure", AZURE_URL);
        assert_provider(&test, PrProviderChoice::GitHub);
    }

    #[test]
    fn choose_pr_provider_azure_only() {
        let test = TestRepo::with_initial_commit();
        add_remote(&test, "origin", AZURE_URL);
        assert_provider(&test, PrProviderChoice::AzureDevOps);
    }

    #[test]
    fn choose_pr_provider_no_recognised_remote() {
        let test = TestRepo::with_initial_commit();
        assert_provider(&test, PrProviderChoice::GitHub);
    }

    #[test]
    fn choose_pr_provider_forge_platform_override_wins() {
        let test = TestRepo::with_initial_commit();
        add_remote(&test, "origin", GITHUB_URL);
        add_remote(&test, "azure", AZURE_URL);
        test.write_project_config("[forge]\nplatform = \"azure-devops\"\n");
        assert_provider(&test, PrProviderChoice::AzureDevOps);
    }

    #[test]
    fn choose_pr_provider_forge_platform_github_in_azure_only_repo() {
        let test = TestRepo::with_initial_commit();
        add_remote(&test, "origin", AZURE_URL);
        test.write_project_config("[forge]\nplatform = \"github\"\n");
        assert_provider(&test, PrProviderChoice::GitHub);
    }
}