use std::collections::HashMap;
use std::path::{Path, PathBuf};
use crate::display::format_relative_time_short;
use anyhow::Context;
use color_print::cformat;
use dunce::canonicalize;
use serde::Serialize;
use worktrunk::HookType;
use worktrunk::config::{
UserConfig, ValidationScope, expand_template, template_references_var, validate_template,
};
use worktrunk::git::remote_ref::{
self, GitHubProvider, GitLabProvider, RemoteRefInfo, RemoteRefProvider,
};
use worktrunk::git::{
GitError, RefContext, RefType, Repository, SwitchSuggestionCtx, current_or_recover,
};
use worktrunk::styling::{
eprintln, format_with_gutter, hint_message, info_message, progress_message, suggest_command,
warning_message,
};
use super::resolve::{
compute_clobber_backup, compute_worktree_path, offer_bare_repo_worktree_path_fix, path_mismatch,
};
use super::types::{CreationMethod, SwitchBranchInfo, SwitchPlan, SwitchResult};
use crate::cli::SwitchFormat;
use crate::commands::command_approval::{approve_hooks, approve_or_skip};
use crate::commands::command_executor::FailureStrategy;
use crate::commands::command_executor::{CommandContext, build_hook_context};
use crate::commands::hooks::{HookAnnouncer, execute_hook};
use crate::commands::template_vars::TemplateVars;
use crate::output::{
execute_user_command, handle_switch_output, is_shell_integration_active,
prompt_shell_integration,
};
/// Result of resolving the user's switch argument: which local branch the
/// worktree will use and how that worktree (and possibly branch) gets created.
struct ResolvedTarget {
    /// Local branch name for the worktree.
    branch: String,
    /// How the worktree should be created (regular checkout/creation, or from a fork PR/MR ref).
    method: CreationMethod,
}
/// Render PR/MR metadata for display.
///
/// The first line is the bold title followed by the ref-type symbol and number.
/// The second line joins author, state, an optional `draft` marker and the
/// source ref with ` · `, and ends with the dimmed URL.
fn format_ref_context(ctx: &impl RefContext) -> String {
    // Parts are evaluated left to right; the draft marker is only present for drafts.
    let status_line = [
        Some(format!("by @{}", ctx.author())),
        Some(ctx.state().to_string()),
        ctx.draft().then(|| String::from("draft")),
        Some(ctx.source_ref()),
    ]
    .into_iter()
    .flatten()
    .collect::<Vec<String>>()
    .join(" · ");

    cformat!(
        "<bold>{}</> ({}{})\n{status_line} · <bright-black>{}</>",
        ctx.title(),
        ctx.ref_type().symbol(),
        ctx.number(),
        ctx.url()
    )
}
/// Resolve a `pr:N` / `mr:N` switch target through its forge provider.
///
/// Prints a progress line and the PR/MR summary, then dispatches to the fork
/// or same-repo resolution.
///
/// # Errors
/// - `RefBaseConflict` if `--base` was supplied (checked before any network call).
/// - `RefCreateConflict` if `--create` was supplied (checked after fetching, so
///   the error can name the PR/MR's source branch).
/// - Any provider or git error from fetching or resolving the ref.
fn resolve_remote_ref(
    repo: &Repository,
    provider: &dyn RemoteRefProvider,
    number: u32,
    create: bool,
    base: Option<&str>,
) -> anyhow::Result<ResolvedTarget> {
    let ref_type = provider.ref_type();

    if base.is_some() {
        return Err(GitError::RefBaseConflict { ref_type, number }.into());
    }

    let announcement = cformat!(
        "Fetching {} {}{number}...",
        ref_type.name(),
        ref_type.symbol()
    );
    eprintln!("{}", progress_message(announcement));

    let info = provider.fetch_info(number, repo)?;
    let summary = format_ref_context(&info);
    eprintln!("{}", format_with_gutter(&summary, None));

    if create {
        let branch = info.source_branch.clone();
        return Err(GitError::RefCreateConflict {
            ref_type,
            number,
            branch,
        }
        .into());
    }

    if info.is_cross_repo {
        resolve_fork_ref(repo, provider, number, &info)
    } else {
        resolve_same_repo_ref(repo, &info)
    }
}
/// Resolve a cross-repository (fork) PR/MR to a local branch and creation method.
///
/// Decision order:
/// 1. If `branch_tracks_ref` reports `Some(true)` for the default local name
///    (`remote_ref::local_branch_name`), reuse that branch as-is.
/// 2. If it reports `Some(false)` (the name is taken by something else), try
///    `info.prefixed_local_branch_name()`:
///    - `Some(true)`: reuse the prefixed branch;
///    - `Some(false)`: error with `BranchTracksDifferentRef`;
///    - `None`: plan a `ForkRef` creation under the prefixed name with no
///      fork push URL.
///    If there is no prefixed name, error with `BranchTracksDifferentRef` for the
///    default name.
/// 3. If it reports `None`, plan a `ForkRef` creation under the default name. The
///    fork push URL comes from `info` for PRs, or from glab project URLs for MRs.
///
/// # Errors
/// `BranchTracksDifferentRef` as described above; remote lookup failures; glab
/// lookups that fail or do not return the target/source project URLs.
fn resolve_fork_ref(
    repo: &Repository,
    provider: &dyn RemoteRefProvider,
    number: u32,
    info: &RemoteRefInfo,
) -> anyhow::Result<ResolvedTarget> {
    let ref_type = provider.ref_type();
    let repo_root = repo.repo_path()?;
    let local_branch = remote_ref::local_branch_name(info);
    // Best-effort: used only to refine the tracking check below, so a lookup
    // failure is logged and treated as "unknown" rather than returned.
    let expected_remote = match remote_ref::find_remote(repo, info) {
        Ok(remote) => Some(remote),
        Err(e) => {
            log::debug!("Could not resolve remote for {}: {e:#}", ref_type.name());
            None
        }
    };
    if let Some(tracks_this) = remote_ref::branch_tracks_ref(
        repo_root,
        &local_branch,
        provider,
        number,
        expected_remote.as_deref(),
    ) {
        // The default name is already used; reuse it if it points at this PR/MR.
        if tracks_this {
            eprintln!(
                "{}",
                info_message(cformat!(
                    "Branch <bold>{local_branch}</> already configured for {}",
                    ref_type.display(number)
                ))
            );
            return Ok(ResolvedTarget {
                branch: local_branch,
                method: CreationMethod::Regular {
                    create_branch: false,
                    base_branch: None,
                    base_pr_upstream: None,
                },
            });
        }
        // Name collision with an unrelated branch: fall back to the prefixed name.
        if let Some(prefixed) = info.prefixed_local_branch_name() {
            if let Some(prefixed_tracks) = remote_ref::branch_tracks_ref(
                repo_root,
                &prefixed,
                provider,
                number,
                expected_remote.as_deref(),
            ) {
                if prefixed_tracks {
                    eprintln!(
                        "{}",
                        info_message(cformat!(
                            "Branch <bold>{prefixed}</> already configured for {}",
                            ref_type.display(number)
                        ))
                    );
                    return Ok(ResolvedTarget {
                        branch: prefixed,
                        method: CreationMethod::Regular {
                            create_branch: false,
                            base_branch: None,
                            base_pr_upstream: None,
                        },
                    });
                }
                return Err(GitError::BranchTracksDifferentRef {
                    branch: prefixed,
                    ref_type,
                    number,
                }
                .into());
            }
            // Create under the prefixed name. `fork_push_url: None` here means
            // push-to-fork is not configured for prefixed branches.
            // NOTE(review): this path uses `find_remote` for both PRs and MRs,
            // whereas the unprefixed MR path below locates the remote via glab's
            // target project URL — confirm `find_remote` is valid for GitLab forks.
            let remote = remote_ref::find_remote(repo, info)?;
            return Ok(ResolvedTarget {
                branch: prefixed,
                method: CreationMethod::ForkRef {
                    ref_type,
                    number,
                    ref_path: provider.ref_path(number),
                    fork_push_url: None,
                    ref_url: info.url.clone(),
                    remote,
                },
            });
        }
        return Err(GitError::BranchTracksDifferentRef {
            branch: local_branch,
            ref_type,
            number,
        }
        .into());
    }
    // Fresh branch under the default name: locate the remote to fetch from and
    // the fork URL to push to.
    let (fork_push_url, remote) = match ref_type {
        RefType::Pr => {
            let remote = remote_ref::find_remote(repo, info)?;
            (info.fork_push_url.clone(), remote)
        }
        RefType::Mr => {
            // GitLab: glab supplies both the target (upstream) and source (fork) project URLs.
            let urls =
                worktrunk::git::remote_ref::gitlab::fetch_gitlab_project_urls(info, repo_root)?;
            let target_url = urls.target_url.ok_or_else(|| {
                anyhow::anyhow!(
                    "{} is from a fork but glab didn't provide target project URL; \
                     upgrade glab or checkout the fork branch manually",
                    ref_type.display(number)
                )
            })?;
            let remote = repo.find_remote_by_url(&target_url).ok_or_else(|| {
                anyhow::anyhow!(
                    "No remote found for target project; \
                     add a remote pointing to {} (e.g., `git remote add upstream {}`)",
                    target_url,
                    target_url
                )
            })?;
            if urls.fork_push_url.is_none() {
                anyhow::bail!(
                    "{} is from a fork but glab didn't provide source project URL; \
                     upgrade glab or checkout the fork branch manually",
                    ref_type.display(number)
                );
            }
            (urls.fork_push_url, remote)
        }
    };
    Ok(ResolvedTarget {
        branch: local_branch,
        method: CreationMethod::ForkRef {
            ref_type,
            number,
            ref_path: provider.ref_path(number),
            fork_push_url,
            ref_url: info.url.clone(),
            remote,
        },
    })
}
/// Resolve a same-repository PR/MR.
///
/// The source branch lives on one of our own remotes, so it is fetched and then
/// checked out under its own name without creating a new branch.
fn resolve_same_repo_ref(
    repo: &Repository,
    info: &RemoteRefInfo,
) -> anyhow::Result<ResolvedTarget> {
    fetch_same_repo_branch(repo, info)?;

    let method = CreationMethod::Regular {
        create_branch: false,
        base_branch: None,
        base_pr_upstream: None,
    };
    let branch = info.source_branch.clone();
    Ok(ResolvedTarget { branch, method })
}
/// Fetch a same-repo PR/MR's source branch into its remote-tracking ref.
///
/// The refspec is forced (`+`), so `refs/remotes/<remote>/<branch>` is updated
/// even when the branch was rewritten upstream.
///
/// # Errors
/// Fails if no matching remote is found or `git fetch` fails.
fn fetch_same_repo_branch(repo: &Repository, info: &RemoteRefInfo) -> anyhow::Result<()> {
    let remote = remote_ref::find_remote(repo, info)?;
    let branch = &info.source_branch;

    let announcement = cformat!("Fetching <bold>{branch}</> from {remote}...");
    eprintln!("{}", progress_message(announcement));

    let refspec = format!("+refs/heads/{branch}:refs/remotes/{remote}/{branch}");
    repo.run_command(&["fetch", "--", &remote, &refspec])
        .map(|_| ())
        .with_context(|| cformat!("Failed to fetch branch <bold>{}</> from {}", branch, remote))
}
/// Resolve the `--base` argument to a ref usable by `git worktree add`.
///
/// `pr:N` / `mr:N` are resolved through the matching forge provider. Any other
/// value goes through worktree-name resolution. If that name is not an existing
/// ref and exactly one remote has a branch with that name, `<remote>/<name>` is
/// used instead.
///
/// Returns `(base_ref, upstream)`. `upstream` is only ever `Some` for a
/// same-repo PR/MR base.
fn resolve_base_ref(
    repo: &Repository,
    base: &str,
) -> anyhow::Result<(String, Option<(String, String)>)> {
    let forge_number = |prefix: &str| {
        base.strip_prefix(prefix)
            .and_then(|suffix| suffix.parse::<u32>().ok())
    };
    if let Some(number) = forge_number("pr:") {
        return resolve_remote_ref_as_base(repo, &GitHubProvider, number);
    }
    if let Some(number) = forge_number("mr:") {
        return resolve_remote_ref_as_base(repo, &GitLabProvider, number);
    }

    let resolved = repo.resolve_worktree_name(base)?;
    if repo.ref_exists(&resolved)? {
        return Ok((resolved, None));
    }

    // Not a known ref: qualify it with its remote, but only when that is unambiguous.
    let remotes = repo.branch(&resolved).remotes()?;
    let base_ref = if remotes.len() == 1 {
        format!("{}/{}", remotes[0], resolved)
    } else {
        resolved
    };
    Ok((base_ref, None))
}
/// Resolve a `pr:N` / `mr:N` value given as `--base`.
///
/// - Fork PR/MR: fetch the provider's tracking ref and use the fetched commit
///   SHA (`FETCH_HEAD`) as the base, with no upstream.
/// - Same-repo PR/MR: fetch the source branch and return its name as the base,
///   plus `(remote, source_branch)` so the new branch can track it.
fn resolve_remote_ref_as_base(
    repo: &Repository,
    provider: &dyn RemoteRefProvider,
    number: u32,
) -> anyhow::Result<(String, Option<(String, String)>)> {
    let ref_type = provider.ref_type();
    let announcement = cformat!(
        "Fetching base {} {}{number}...",
        ref_type.name(),
        ref_type.symbol()
    );
    eprintln!("{}", progress_message(announcement));

    let info = provider.fetch_info(number, repo)?;
    eprintln!("{}", format_with_gutter(&format_ref_context(&info), None));

    if info.is_cross_repo {
        // The fork's branch isn't on any of our remotes; pin the base to a commit instead.
        let remote = remote_ref::find_remote(repo, &info)?;
        let display = ref_type.display(number);
        let tracking_ref = provider.tracking_ref(number);
        repo.run_command(&["fetch", "--", &remote, &tracking_ref])
            .with_context(|| cformat!("Failed to fetch <bold>{display}</> from {remote}"))?;
        let fetched = repo
            .run_command(&["rev-parse", "FETCH_HEAD"])
            .context("Failed to resolve FETCH_HEAD to a commit SHA")?;
        return Ok((fetched.trim().to_string(), None));
    }

    fetch_same_repo_branch(repo, &info)?;
    let remote = remote_ref::find_remote(repo, &info)?;
    let source_branch = info.source_branch.clone();
    Ok((source_branch.clone(), Some((remote, source_branch))))
}
/// Turn the user's `switch` argument into a concrete branch and creation method.
///
/// - `pr:N` / `mr:N` are delegated to [`resolve_remote_ref`].
/// - Otherwise the name goes through worktree-name resolution. Without
///   `--create`, a remote-qualified name (`origin/foo`) maps to the local name.
/// - `--base` is only honoured with `--create` (a warning is printed otherwise),
///   and it must resolve to an existing ref.
/// - With `--create`, an existing local branch is an error. A same-named remote
///   branch only triggers a warning plus a hint.
/// - Without an explicit base, `--create` defaults to the repo's target branch
///   when that branch exists locally.
fn resolve_switch_target(
    repo: &Repository,
    branch: &str,
    create: bool,
    base: Option<&str>,
) -> anyhow::Result<ResolvedTarget> {
    let forge_number = |prefix: &str| {
        branch
            .strip_prefix(prefix)
            .and_then(|suffix| suffix.parse::<u32>().ok())
    };
    if let Some(number) = forge_number("pr:") {
        return resolve_remote_ref(repo, &GitHubProvider, number, create, base);
    }
    if let Some(number) = forge_number("mr:") {
        return resolve_remote_ref(repo, &GitLabProvider, number, create, base);
    }

    let resolved_name = repo
        .resolve_worktree_name(branch)
        .context("Failed to resolve branch name")?;
    let resolved_branch = if create {
        resolved_name
    } else {
        repo.strip_remote_prefix(&resolved_name)
            .unwrap_or(resolved_name)
    };

    let (resolved_base, base_pr_upstream) = match base {
        None => (None, None),
        Some(_) if !create => {
            eprintln!(
                "{}",
                warning_message("--base flag is only used with --create, ignoring")
            );
            (None, None)
        }
        Some(base_str) => {
            let (base_ref, upstream) = resolve_base_ref(repo, base_str)?;
            if !repo.ref_exists(&base_ref)? {
                return Err(GitError::ReferenceNotFound {
                    reference: base_ref,
                }
                .into());
            }
            (Some(base_ref), upstream)
        }
    };

    if create {
        let handle = repo.branch(&resolved_branch);
        if handle.exists_locally()? {
            return Err(GitError::BranchAlreadyExists {
                branch: resolved_branch,
            }
            .into());
        }
        let remotes = handle.remotes()?;
        if let Some(first_remote) = remotes.first() {
            let remote_ref = format!("{}/{}", first_remote, resolved_branch);
            eprintln!(
                "{}",
                warning_message(cformat!(
                    "Branch <bold>{resolved_branch}</> exists on remote ({remote_ref}); creating new branch from base instead"
                ))
            );
            let remove_cmd = suggest_command("remove", &[&resolved_branch], &["--foreground"]);
            let switch_cmd = suggest_command("switch", &[&resolved_branch], &[]);
            eprintln!(
                "{}",
                hint_message(cformat!(
                    "To switch to the remote branch, delete this branch and run without <underline>--create</>: <underline>{remove_cmd} && {switch_cmd}</>"
                ))
            );
        }
    }

    // Only new branches need a base; fall back to the target branch if it exists locally.
    let base_branch = create
        .then(|| {
            resolved_base.or_else(|| {
                repo.resolve_target_branch(None)
                    .ok()
                    .filter(|b| repo.branch(b).exists_locally().unwrap_or(false))
            })
        })
        .flatten();

    Ok(ResolvedTarget {
        branch: resolved_branch,
        method: CreationMethod::Regular {
            create_branch: create,
            base_branch,
            base_pr_upstream,
        },
    })
}
/// Check that a worktree for `branch` can be created at `path`.
///
/// - Checking out an existing branch (`Regular` without `create_branch`)
///   requires the branch to exist; otherwise `BranchNotFound` is returned with
///   fetch-age and PR/MR hints.
/// - A worktree already registered at `path` is `WorktreeMissing` if its
///   directory is gone, else `WorktreePathOccupied`.
///
/// On success, returns the clobber backup path (if any) from
/// `compute_clobber_backup`.
fn validate_worktree_creation(
    repo: &Repository,
    branch: &str,
    path: &Path,
    clobber: bool,
    method: &CreationMethod,
) -> anyhow::Result<Option<std::path::PathBuf>> {
    let checks_out_existing = matches!(
        method,
        CreationMethod::Regular {
            create_branch: false,
            ..
        }
    );
    if checks_out_existing && !repo.branch(branch).exists()? {
        return Err(GitError::BranchNotFound {
            branch: branch.to_string(),
            show_create_hint: true,
            last_fetch_ago: format_last_fetch_ago(repo),
            pr_mr_platform: repo.detect_ref_type(),
        }
        .into());
    }

    match repo.worktree_at_path(path)? {
        None => {
            let creates_branch = matches!(
                method,
                CreationMethod::Regular {
                    create_branch: true,
                    ..
                }
            );
            compute_clobber_backup(path, branch, clobber, creates_branch)
        }
        Some((existing_path, occupant)) if existing_path.exists() => {
            Err(GitError::WorktreePathOccupied {
                branch: branch.to_string(),
                path: path.to_path_buf(),
                occupant,
            }
            .into())
        }
        Some((_, occupant)) => Err(GitError::WorktreeMissing {
            branch: occupant.unwrap_or_else(|| branch.to_string()),
        }
        .into()),
    }
}
/// Materialise a fork PR/MR as a local branch plus worktree.
///
/// Creates `branch` at `FETCH_HEAD` (the caller is expected to have just fetched
/// the PR/MR ref). It then sets `branch.<branch>.remote = remote` and
/// `branch.<branch>.merge = refs/<remote_ref>`, plus `branch.<branch>.pushRemote`
/// when a fork push URL is known, and finally runs `git worktree add`.
///
/// # Errors
/// Any failing git step. A failed `worktree add` is mapped to
/// `GitError::WorktreeCreationFailed`.
fn setup_fork_branch(
    repo: &Repository,
    branch: &str,
    remote: &str,
    remote_ref: &str,
    fork_push_url: Option<&str>,
    worktree_path: &Path,
    label: &str,
) -> anyhow::Result<()> {
    repo.run_command(&["branch", "--", branch, "FETCH_HEAD"])
        .with_context(|| {
            cformat!(
                "Failed to create local branch <bold>{}</> from {}",
                branch,
                label
            )
        })?;

    // Upstream tracking: pull from the PR/MR ref on the base remote.
    let merge_ref = format!("refs/{remote_ref}");
    for (setting, value) in [("remote", remote), ("merge", merge_ref.as_str())] {
        let key = format!("branch.{branch}.{setting}");
        repo.set_config(&key, value)
            .with_context(|| format!("Failed to configure {key}"))?;
    }

    // Push target: the fork itself, when we know its URL.
    if let Some(url) = fork_push_url {
        let key = format!("branch.{branch}.pushRemote");
        repo.set_config(&key, url)
            .with_context(|| format!("Failed to configure {key}"))?;
    }

    let path_arg = worktree_path.to_string_lossy();
    let progress =
        progress_message(cformat!("Creating worktree for <bold>{}</>...", branch)).to_string();
    repo.run_command_delayed_stream(
        &["worktree", "add", "--", path_arg.as_ref(), branch],
        Repository::SLOW_OPERATION_DELAY_MS,
        Some(progress),
    )
    .map_err(|e| worktree_creation_error(&e, branch.to_string(), None))?;
    Ok(())
}
pub fn plan_switch(
repo: &Repository,
branch: &str,
create: bool,
base: Option<&str>,
clobber: bool,
config: &UserConfig,
) -> anyhow::Result<SwitchPlan> {
let new_previous = repo.current_worktree().branch().ok().flatten();
let target = resolve_switch_target(repo, branch, create, base)?;
match repo.worktree_for_branch(&target.branch)? {
Some(existing_path) if existing_path.exists() => {
return Ok(SwitchPlan::Existing {
path: canonicalize(&existing_path).unwrap_or(existing_path),
branch: Some(target.branch),
new_previous,
});
}
Some(_) => {
return Err(GitError::WorktreeMissing {
branch: target.branch,
}
.into());
}
None => {}
}
if !create {
let candidate = Path::new(branch);
let abs_path = if candidate.is_absolute() {
Some(candidate.to_path_buf())
} else if candidate.components().count() > 1 {
std::env::current_dir().ok().map(|cwd| cwd.join(candidate))
} else {
None
};
if let Some(abs_path) = abs_path
&& let Some((path, wt_branch)) = repo.worktree_at_path(&abs_path)?
{
let canonical = canonicalize(&path).unwrap_or_else(|_| path.clone());
return Ok(SwitchPlan::Existing {
path: canonical,
branch: wt_branch,
new_previous,
});
}
}
let expected_path = compute_worktree_path(repo, &target.branch, config)?;
let clobber_backup = validate_worktree_creation(
repo,
&target.branch,
&expected_path,
clobber,
&target.method,
)?;
Ok(SwitchPlan::Create {
branch: target.branch,
worktree_path: expected_path,
method: target.method,
clobber_backup,
new_previous,
})
}
/// Carry out a [`SwitchPlan`].
///
/// - `Existing`: compares the canonicalized current directory to the plan's
///   path. The result is `AlreadyAt` or `Existing`, and the "previous" branch is
///   recorded only when actually moving.
/// - `Create`: optionally moves a clobbered path to its backup, creates the
///   worktree (regular checkout/creation or fork PR/MR setup), runs pre-start
///   commands when `run_hooks` is set, records the previous branch, and returns
///   `Created` with branch/base/PR metadata.
///
/// `force` is passed through to the hook [`CommandContext`].
///
/// # Errors
/// Filesystem errors from the clobber move, `WorktreeCreationFailed` from
/// `git worktree add`, git/config errors, and pre-start command failures.
pub fn execute_switch(
    repo: &Repository,
    plan: SwitchPlan,
    config: &UserConfig,
    force: bool,
    run_hooks: bool,
) -> anyhow::Result<(SwitchResult, SwitchBranchInfo)> {
    match plan {
        SwitchPlan::Existing {
            path,
            branch,
            new_previous,
        } => {
            // Canonicalize so the comparison with the (canonical) plan path is meaningful.
            let current_dir = std::env::current_dir()
                .ok()
                .and_then(|p| canonicalize(&p).ok());
            let already_at_worktree = current_dir
                .as_ref()
                .map(|cur| cur == &path)
                .unwrap_or(false);
            // Best-effort: failing to record the previous branch doesn't abort the switch.
            if !already_at_worktree {
                let _ = repo.set_switch_previous(new_previous.as_deref());
            }
            let result = if already_at_worktree {
                SwitchResult::AlreadyAt(path)
            } else {
                SwitchResult::Existing { path }
            };
            Ok((
                result,
                SwitchBranchInfo {
                    branch,
                    expected_path: None,
                },
            ))
        }
        SwitchPlan::Create {
            branch,
            worktree_path,
            method,
            clobber_backup,
            new_previous,
        } => {
            // --clobber: move whatever occupies the target path out of the way first.
            if let Some(backup_path) = &clobber_backup {
                let path_display = worktrunk::path::format_path_for_display(&worktree_path);
                let backup_display = worktrunk::path::format_path_for_display(backup_path);
                eprintln!(
                    "{}",
                    warning_message(cformat!(
                        "Moving <bold>{path_display}</> to <bold>{backup_display}</> (--clobber)"
                    ))
                );
                std::fs::rename(&worktree_path, backup_path).with_context(|| {
                    format!("Failed to move {path_display} to {backup_display}")
                })?;
            }
            let (created_branch, base_branch, from_remote) = match &method {
                CreationMethod::Regular {
                    create_branch,
                    base_branch,
                    base_pr_upstream,
                } => {
                    let branch_handle = repo.branch(&branch);
                    let local_branch_existed =
                        !create_branch && branch_handle.exists_locally().unwrap_or(false);
                    let worktree_path_str = worktree_path.to_string_lossy();
                    let mut args = vec!["worktree", "add", worktree_path_str.as_ref()];
                    // Declared here so the formatted `<remote>/<branch>` outlives `args`.
                    let tracking_ref;
                    if *create_branch {
                        // New branch: `-b <branch> [<base>]`.
                        args.push("-b");
                        args.push(&branch);
                        if let Some(base) = base_branch {
                            args.push(base);
                        }
                    } else if !local_branch_existed {
                        // Remote-only branch: with exactly one remote, create the local
                        // branch from `<remote>/<branch>` explicitly; otherwise pass the
                        // bare name and let git decide.
                        let remotes = branch_handle.remotes().unwrap_or_default();
                        if remotes.len() == 1 {
                            tracking_ref = format!("{}/{}", remotes[0], branch);
                            args.extend(["-b", &branch, tracking_ref.as_str()]);
                        } else {
                            args.push(&branch);
                        }
                    } else {
                        // Existing local branch: plain checkout.
                        args.push(&branch);
                    }
                    let progress_msg = Some(
                        progress_message(cformat!("Creating worktree for <bold>{}</>...", branch))
                            .to_string(),
                    );
                    if let Err(e) = repo.run_command_delayed_stream(
                        &args,
                        Repository::SLOW_OPERATION_DELAY_MS,
                        progress_msg,
                    ) {
                        return Err(worktree_creation_error(
                            &e,
                            branch.clone(),
                            base_branch.clone(),
                        )
                        .into());
                    }
                    // A new branch created from a remote-tracking base has its upstream
                    // unset here.
                    if *create_branch
                        && let Some(base) = base_branch
                        && repo.is_remote_tracking_branch(base)
                    {
                        branch_handle.unset_upstream()?;
                    }
                    // Base was a same-repo PR/MR: track that PR/MR's source branch instead.
                    if *create_branch
                        && let Some((upstream_remote, upstream_branch)) = base_pr_upstream
                    {
                        repo.set_config(&format!("branch.{branch}.remote"), upstream_remote)?;
                        repo.set_config(
                            &format!("branch.{branch}.merge"),
                            &format!("refs/heads/{upstream_branch}"),
                        )?;
                    }
                    // Only a branch that was created from a remote reports its upstream.
                    let from_remote = if !create_branch && !local_branch_existed {
                        branch_handle.upstream()?
                    } else {
                        None
                    };
                    (*create_branch, base_branch.clone(), from_remote)
                }
                CreationMethod::ForkRef {
                    ref_type,
                    number,
                    ref_path,
                    fork_push_url,
                    ref_url: _,
                    remote,
                } => {
                    let label = ref_type.display(*number);
                    // Populates FETCH_HEAD, which `setup_fork_branch` branches from.
                    repo.run_command(&["fetch", "--", remote, ref_path])
                        .with_context(|| format!("Failed to fetch {} from {}", label, remote))?;
                    let setup_result = setup_fork_branch(
                        repo,
                        &branch,
                        remote,
                        ref_path,
                        fork_push_url.as_deref(),
                        &worktree_path,
                        &label,
                    );
                    // Roll back the (possibly half-configured) branch on failure; the
                    // cleanup result is ignored so the original error is reported.
                    // NOTE(review): this also runs if `git branch` itself failed (e.g. the
                    // name became taken after planning), which would delete that branch.
                    if let Err(e) = setup_result {
                        let _ = repo.run_command(&["branch", "-D", "--", &branch]);
                        return Err(e);
                    }
                    // No push URL only happens for the prefixed-name fallback.
                    if let Some(url) = fork_push_url {
                        eprintln!(
                            "{}",
                            info_message(cformat!("Push configured to fork: <underline>{url}</>"))
                        );
                    } else {
                        eprintln!(
                            "{}",
                            warning_message(cformat!(
                                "Using prefixed branch name <bold>{branch}</> due to name conflict"
                            ))
                        );
                        eprintln!(
                            "{}",
                            hint_message(
                                "Push to fork is not supported with prefixed branches; feedback welcome at https://github.com/max-sixty/worktrunk/issues/714",
                            )
                        );
                    }
                    (false, None, Some(label))
                }
            };
            // POSIX-style path of the base branch's worktree, if it has one (for templates/output).
            let base_worktree_path = base_branch
                .as_ref()
                .and_then(|b| repo.worktree_for_branch(b).ok().flatten())
                .map(|p| worktrunk::path::to_posix_path(&p.to_string_lossy()));
            let (pr_number, pr_url) = match &method {
                CreationMethod::ForkRef {
                    number, ref_url, ..
                } => (Some(*number), Some(ref_url.clone())),
                CreationMethod::Regular { .. } => (None, None),
            };
            // Pre-start commands run synchronously inside the new worktree.
            if run_hooks {
                let ctx = CommandContext::new(repo, config, Some(&branch), &worktree_path, force);
                let mut vars = TemplateVars::new()
                    .with_target(&branch)
                    .with_target_worktree_path(&worktree_path);
                match &method {
                    CreationMethod::Regular { base_branch, .. } => {
                        vars = vars
                            .with_base_strs(base_branch.as_deref(), base_worktree_path.as_deref());
                    }
                    CreationMethod::ForkRef {
                        number, ref_url, ..
                    } => {
                        vars = vars.with_pr(Some(*number), Some(ref_url));
                    }
                }
                ctx.execute_pre_start_commands(&vars.as_extra_vars())?;
            }
            let _ = repo.set_switch_previous(new_previous.as_deref());
            Ok((
                SwitchResult::Created {
                    path: worktree_path,
                    created_branch,
                    base_branch,
                    base_worktree_path,
                    from_remote,
                    pr_number,
                    pr_url,
                },
                SwitchBranchInfo {
                    branch: Some(branch),
                    expected_path: None,
                },
            ))
        }
    }
}
fn worktree_creation_error(
err: &anyhow::Error,
branch: String,
base_branch: Option<String>,
) -> GitError {
let (output, command) = Repository::extract_failed_command(err);
GitError::WorktreeCreationFailed {
branch,
base_branch,
error: output,
command,
}
}
/// Describe how long ago the repository was last fetched, e.g.
/// `"last fetched 3h ago"`.
///
/// Returns `None` when no fetch time is known. Relative times of `now` or
/// `future` are reported as `"last fetched just now"`.
fn format_last_fetch_ago(repo: &Repository) -> Option<String> {
    let relative = format_relative_time_short(repo.last_fetch_epoch()? as i64);
    let phrase = match &*relative {
        "now" | "future" => "last fetched just now".to_string(),
        ago => format!("last fetched {ago} ago"),
    };
    Some(phrase)
}
/// JSON shape printed for `switch` with JSON output format.
///
/// `None` fields are omitted. Field order here is the key order in the output.
#[derive(Serialize)]
struct SwitchJsonOutput {
    /// One of `"already_at"`, `"existing"`, `"created"`.
    action: &'static str,
    #[serde(skip_serializing_if = "Option::is_none")]
    branch: Option<String>,
    /// Worktree path.
    path: PathBuf,
    /// Only present for `"created"`: whether a new branch was created.
    #[serde(skip_serializing_if = "Option::is_none")]
    created_branch: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    base_branch: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    from_remote: Option<String>,
}
impl SwitchJsonOutput {
    /// Build the JSON payload from a switch result.
    ///
    /// Only `Created` results carry `created_branch`, `base_branch` and
    /// `from_remote`; the other variants report just the action and path.
    fn from_result(result: &SwitchResult, branch_info: &SwitchBranchInfo) -> Self {
        let (action, path) = match result {
            SwitchResult::AlreadyAt(path) => ("already_at", path),
            SwitchResult::Existing { path } => ("existing", path),
            SwitchResult::Created { path, .. } => ("created", path),
        };
        let (created_branch, base_branch, from_remote) = match result {
            SwitchResult::Created {
                created_branch,
                base_branch,
                from_remote,
                ..
            } => (
                Some(*created_branch),
                base_branch.clone(),
                from_remote.clone(),
            ),
            SwitchResult::AlreadyAt(_) | SwitchResult::Existing { .. } => (None, None, None),
        };
        Self {
            action,
            branch: branch_info.branch.clone(),
            path: path.clone(),
            created_branch,
            base_branch,
            from_remote,
        }
    }
}
/// Parsed arguments for [`run_switch`].
pub struct SwitchOptions<'a> {
    /// Branch name, `pr:N` / `mr:N`, or (without `create`) a worktree path.
    pub branch: &'a str,
    /// `--create`: create a new branch.
    pub create: bool,
    /// `--base`: starting point for a new branch (only used with `create`).
    pub base: Option<&'a str>,
    /// `--execute`: command template to run after switching.
    pub execute: Option<&'a str>,
    /// Extra arguments appended (template-expanded and shell-escaped) to `execute`.
    pub execute_args: &'a [String],
    /// Skip approval prompts.
    pub yes: bool,
    /// `--clobber`: move an existing path at the target aside.
    pub clobber: bool,
    /// Whether to change directory; `None` falls back to the resolved config.
    pub change_dir: Option<bool>,
    /// Run hooks (pre-switch, and post-switch after approval).
    pub verify: bool,
    /// Output format (human or JSON).
    pub format: crate::cli::SwitchFormat,
}
/// Run `pre-switch` hooks from the current worktree before switching to `target_branch`.
///
/// The target is resolved with `resolve_worktree_name`; if that fails, the raw
/// argument is used so approval and hook templates still get a name. Hooks run
/// only after approval. Template variables describe the current worktree as the
/// base. When the target branch already has a worktree, that path is exposed as
/// both the target and active worktree path.
///
/// # Errors
/// Propagates approval errors and the first failing hook (`FailureStrategy::FailFast`).
pub(crate) fn run_pre_switch_hooks(
    repo: &Repository,
    config: &UserConfig,
    target_branch: &str,
    yes: bool,
) -> anyhow::Result<()> {
    let current_wt = repo.current_worktree();
    let current_path = current_wt.path().to_path_buf();
    let resolved_target = repo
        .resolve_worktree_name(target_branch)
        .unwrap_or_else(|_| target_branch.to_string());
    // The hook context is the worktree being left, not the destination.
    let pre_ctx = CommandContext::new(repo, config, Some(&resolved_target), &current_path, yes);
    let pre_switch_approved = approve_hooks(&pre_ctx, &[HookType::PreSwitch])?;
    if pre_switch_approved {
        let base_branch = current_wt.branch().ok().flatten().unwrap_or_default();
        let dest_path = repo.worktree_for_branch(&resolved_target).ok().flatten();
        let mut vars = TemplateVars::new()
            .with_base(&base_branch, &current_path)
            .with_target(&resolved_target);
        // Only a destination that already has a worktree has a path to expose.
        if let Some(p) = dest_path.as_deref() {
            vars = vars.with_target_worktree_path(p).with_active_worktree(p);
        }
        let extra_vars = vars.as_extra_vars();
        execute_hook(
            &pre_ctx,
            HookType::PreSwitch,
            &extra_vars,
            FailureStrategy::FailFast,
            crate::output::pre_hook_display_path(pre_ctx.worktree_path),
        )?;
    }
    Ok(())
}
/// Hook types that follow a switch.
///
/// Creating a worktree adds pre-start and post-start hooks ahead of post-switch.
/// Switching to an existing worktree only runs post-switch.
fn switch_post_hook_types(is_create: bool) -> &'static [HookType] {
    const AFTER_CREATE: &[HookType] = &[
        HookType::PreStart,
        HookType::PostStart,
        HookType::PostSwitch,
    ];
    const AFTER_SWITCH: &[HookType] = &[HookType::PostSwitch];

    if is_create { AFTER_CREATE } else { AFTER_SWITCH }
}
/// Ask for approval of the post-switch hooks implied by `plan`.
///
/// Returns `Ok(false)` without prompting when `verify` is off. Otherwise
/// delegates to `approve_or_skip` with a decline message suited to whether a
/// worktree is being created.
pub(crate) fn approve_switch_hooks(
    repo: &Repository,
    config: &UserConfig,
    plan: &SwitchPlan,
    yes: bool,
    verify: bool,
) -> anyhow::Result<bool> {
    if verify {
        let ctx = CommandContext::new(repo, config, plan.branch(), plan.worktree_path(), yes);
        let creating = plan.is_create();
        let on_decline = if creating {
            "Commands declined, continuing worktree creation"
        } else {
            "Commands declined"
        };
        approve_or_skip(&ctx, switch_post_hook_types(creating), on_decline)
    } else {
        Ok(false)
    }
}
pub(crate) fn spawn_switch_background_hooks(
repo: &Repository,
config: &UserConfig,
result: &SwitchResult,
branch: Option<&str>,
yes: bool,
extra_vars: &[(&str, &str)],
hooks_display_path: Option<&Path>,
) -> anyhow::Result<()> {
let ctx = CommandContext::new(repo, config, branch, result.path(), yes);
let mut announcer = HookAnnouncer::new(repo, config, false);
announcer.register(&ctx, HookType::PostSwitch, extra_vars, hooks_display_path)?;
if matches!(result, SwitchResult::Created { .. }) {
announcer.register(&ctx, HookType::PostStart, extra_vars, hooks_display_path)?;
}
announcer.flush()
}
/// Entry point for `switch`.
///
/// Flow:
/// 1. Open (or recover) the repository and decide whether to `cd`, from the
///    flag or the resolved config.
/// 2. Run pre-switch hooks (when `verify` and the repo was not recovered),
///    offer the bare-repo worktree-path fix, then build a [`SwitchPlan`]. Plan
///    errors that are `GitError`s are wrapped with the `--execute` suggestion
///    when one was given.
/// 3. Approve post-switch hooks and validate templates before executing the plan.
/// 4. Execute the plan, print JSON when requested, render human output, and
///    possibly prompt for shell integration.
/// 5. Spawn background hooks (if approved), then expand and run `--execute`.
pub fn run_switch(
    opts: SwitchOptions<'_>,
    config: &mut UserConfig,
    binary_name: &str,
) -> anyhow::Result<()> {
    let SwitchOptions {
        branch,
        create,
        base,
        execute,
        execute_args,
        yes,
        clobber,
        change_dir: change_dir_flag,
        verify,
        format,
    } = opts;
    let (repo, is_recovered) = current_or_recover().context("Failed to switch worktree")?;
    // Explicit flag wins; otherwise use the project-resolved `switch.cd` setting.
    let change_dir = change_dir_flag.unwrap_or_else(|| {
        let project_id = repo.project_identifier().ok();
        config.resolved(project_id.as_deref()).switch.cd()
    });
    // Presumably lets error hints repeat the user's `--execute` invocation; confirm
    // against `GitError::WithSwitchSuggestion` rendering.
    let suggestion_ctx = execute.map(|exec| {
        let escaped = shell_escape::escape(exec.into());
        SwitchSuggestionCtx {
            extra_flags: vec![format!("--execute={escaped}")],
            trailing_args: execute_args.to_vec(),
        }
    });
    if verify && !is_recovered {
        run_pre_switch_hooks(&repo, config, branch, yes)?;
    }
    offer_bare_repo_worktree_path_fix(&repo, config)?;
    let plan = plan_switch(&repo, branch, create, base, clobber, config).map_err(|err| {
        match suggestion_ctx {
            Some(ref ctx) => match err.downcast::<GitError>() {
                Ok(git_err) => GitError::WithSwitchSuggestion {
                    source: Box::new(git_err),
                    ctx: ctx.clone(),
                }
                .into(),
                Err(err) => err,
            },
            None => err,
        }
    })?;
    let hooks_approved = approve_switch_hooks(&repo, config, &plan, yes, verify)?;
    validate_switch_templates(&repo, config, &plan, execute, execute_args, hooks_approved)?;
    // Capture where we're switching *from*; used for post-switch template variables.
    let source_branch = repo
        .current_worktree()
        .branch()
        .ok()
        .flatten()
        .unwrap_or_default();
    let source_path = repo
        .current_worktree()
        .root()
        .ok()
        .map(|p| worktrunk::path::to_posix_path(&p.to_string_lossy()))
        .unwrap_or_default();
    let (result, branch_info) = execute_switch(&repo, plan, config, yes, hooks_approved)?;
    if format == SwitchFormat::Json {
        let json = SwitchJsonOutput::from_result(&result, &branch_info);
        let json = serde_json::to_string(&json).context("Failed to serialize to JSON")?;
        println!("{json}");
    }
    // When set, stop here: no human output, background hooks, or `--execute`.
    if std::env::var_os("WORKTRUNK_FIRST_OUTPUT").is_some() {
        return Ok(());
    }
    // For existing worktrees, fill in `expected_path` from `path_mismatch` (presumably
    // set when the worktree isn't where the path template would put it).
    let branch_info = match &result {
        SwitchResult::Existing { path } | SwitchResult::AlreadyAt(path) => {
            let expected_path = branch_info
                .branch
                .as_deref()
                .and_then(|b| path_mismatch(&repo, b, path, config));
            SwitchBranchInfo {
                expected_path,
                ..branch_info
            }
        }
        _ => branch_info,
    };
    // Fall back to the repo path when cwd or the current worktree root can't be determined.
    let fallback_path = repo.repo_path()?.to_path_buf();
    let cwd = std::env::current_dir().unwrap_or(fallback_path.clone());
    let source_root = repo.current_worktree().root().unwrap_or(fallback_path);
    let hooks_display_path =
        handle_switch_output(&result, &branch_info, change_dir, Some(&source_root), &cwd)?;
    // Prompt errors are ignored; with `--execute` the prompt is skipped.
    if change_dir && !is_shell_integration_active() {
        let skip_prompt = execute.is_some();
        let _ = prompt_shell_integration(&repo, config, binary_name, skip_prompt);
    }
    let template_vars =
        TemplateVars::for_post_switch(&result, &branch_info, &source_branch, &source_path);
    let extra_vars = template_vars.as_extra_vars();
    if hooks_approved {
        spawn_switch_background_hooks(
            &repo,
            config,
            &result,
            branch_info.branch.as_deref(),
            yes,
            &extra_vars,
            hooks_display_path.as_deref(),
        )?;
    }
    if let Some(cmd) = execute {
        let ctx = CommandContext::new(
            &repo,
            config,
            branch_info.branch.as_deref(),
            result.path(),
            yes,
        );
        let template_vars = build_hook_context(&ctx, &extra_vars, None)?;
        let vars: HashMap<&str, &str> = template_vars
            .iter()
            .map(|(k, v)| (k.as_str(), v.as_str()))
            .collect();
        // NOTE(review): the boolean passed to `expand_template` is `true` for the command
        // and `false` for arguments — presumably controls escaping of substituted values;
        // confirm. Arguments are shell-escaped after expansion below.
        let expanded_cmd = expand_template(cmd, &vars, true, &repo, "--execute command")?;
        let full_cmd = if execute_args.is_empty() {
            expanded_cmd
        } else {
            let expanded_args: Result<Vec<_>, _> = execute_args
                .iter()
                .map(|arg| expand_template(arg, &vars, false, &repo, "--execute argument"))
                .collect();
            let escaped_args: Vec<_> = expanded_args?
                .iter()
                .map(|arg| shell_escape::escape(arg.into()).into_owned())
                .collect();
            format!("{} {}", expanded_cmd, escaped_args.join(" "))
        };
        execute_user_command(&full_cmd, hooks_display_path.as_deref())?;
    }
    Ok(())
}
/// Validate every template the switch may expand.
///
/// The `--execute` command and its arguments are always validated. When hooks
/// were approved, the user and project templates for each post-switch hook type
/// are validated as well. Templates that reference `vars` are skipped.
///
/// # Errors
/// The first invalid template, labelled by its source (e.g. `user post-switch:name`).
fn validate_switch_templates(
    repo: &Repository,
    config: &UserConfig,
    plan: &SwitchPlan,
    execute: Option<&str>,
    execute_args: &[String],
    hooks_approved: bool,
) -> anyhow::Result<()> {
    if let Some(cmd) = execute {
        let labelled = std::iter::once((cmd, "--execute command")).chain(
            execute_args
                .iter()
                .map(|arg| (arg.as_str(), "--execute argument")),
        );
        for (template, label) in labelled {
            validate_template(template, ValidationScope::SwitchExecute, repo, label)?;
        }
    }

    if hooks_approved {
        let project_config = repo.load_project_config()?;
        let user_hooks = config.hooks(repo.project_identifier().ok().as_deref());
        for &hook_type in switch_post_hook_types(plan.is_create()) {
            let (user_cfg, proj_cfg) = crate::commands::hooks::lookup_hook_configs(
                &user_hooks,
                project_config.as_ref(),
                hook_type,
            );
            let present = [("user", user_cfg), ("project", proj_cfg)]
                .into_iter()
                .filter_map(|(source, cfg)| cfg.map(|cfg| (source, cfg)));
            for (source, cfg) in present {
                for cmd in cfg.commands() {
                    if template_references_var(&cmd.template, "vars") {
                        continue;
                    }
                    let name = cmd.name.as_ref().map_or_else(
                        || format!("{source} {hook_type} hook"),
                        |n| format!("{source} {hook_type}:{n}"),
                    );
                    validate_template(
                        &cmd.template,
                        ValidationScope::Hook(hook_type),
                        repo,
                        &name,
                    )?;
                }
            }
        }
    }

    Ok(())
}