use std::collections::HashMap;
use std::path::{Path, PathBuf};
use anyhow::Context;
use serde::Serialize;
use worktrunk::HookType;
use worktrunk::config::{
UserConfig, ValidationScope, expand_template, template_references_var, validate_template,
};
use worktrunk::git::{GitError, Repository, SwitchSuggestionCtx, current_or_recover};
use crate::cli::SwitchFormat;
use super::command_approval::{approve_hooks, approve_or_skip};
use super::command_executor::FailureStrategy;
use super::command_executor::{CommandContext, build_hook_context};
use super::hooks::{execute_hook, prepare_background_pipelines, run_hooks_background};
use super::template_vars::TemplateVars;
use super::worktree::{
SwitchBranchInfo, SwitchPlan, SwitchResult, execute_switch, offer_bare_repo_worktree_path_fix,
path_mismatch, plan_switch,
};
use crate::output::{
execute_user_command, handle_switch_output, is_shell_integration_active,
prompt_shell_integration,
};
/// JSON summary printed to stdout by `handle_switch` when `format` is `SwitchFormat::Json`.
///
/// NOTE: serde serializes fields in declaration order, so reordering the fields below
/// changes the key order consumers of the JSON see.
#[derive(Serialize)]
struct SwitchJsonOutput {
    /// One of `"already_at"`, `"existing"`, or `"created"`, depending on the `SwitchResult`.
    action: &'static str,
    /// Branch of the target worktree, copied from `SwitchBranchInfo`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    branch: Option<String>,
    /// Path of the worktree the switch ended in.
    path: PathBuf,
    /// Copied from `SwitchResult::Created`; omitted for the other results.
    #[serde(skip_serializing_if = "Option::is_none")]
    created_branch: Option<bool>,
    /// Copied from `SwitchResult::Created`; omitted when `None` (always `None` otherwise).
    #[serde(skip_serializing_if = "Option::is_none")]
    base_branch: Option<String>,
    /// Copied from `SwitchResult::Created`; omitted when `None` (always `None` otherwise).
    #[serde(skip_serializing_if = "Option::is_none")]
    from_remote: Option<String>,
}
impl SwitchJsonOutput {
    /// Builds the JSON summary for a completed switch.
    ///
    /// The branch always comes from `branch_info`. Only `SwitchResult::Created` populates
    /// `created_branch`, `base_branch`, and `from_remote`; the other results leave them `None`
    /// so they are dropped from the serialized output.
    fn from_result(result: &SwitchResult, branch_info: &SwitchBranchInfo) -> Self {
        let branch = branch_info.branch.clone();
        match result {
            SwitchResult::AlreadyAt(worktree_path) => Self {
                action: "already_at",
                branch,
                path: worktree_path.clone(),
                created_branch: None,
                base_branch: None,
                from_remote: None,
            },
            SwitchResult::Existing { path: worktree_path } => Self {
                action: "existing",
                branch,
                path: worktree_path.clone(),
                created_branch: None,
                base_branch: None,
                from_remote: None,
            },
            SwitchResult::Created {
                path: worktree_path,
                created_branch: is_new_branch,
                base_branch: base,
                from_remote: remote,
                ..
            } => Self {
                action: "created",
                branch,
                path: worktree_path.clone(),
                created_branch: Some(*is_new_branch),
                base_branch: base.clone(),
                from_remote: remote.clone(),
            },
        }
    }
}
/// Options consumed by [`handle_switch`].
pub struct SwitchOptions<'a> {
    /// Switch target; passed to `plan_switch` and to the pre-switch hooks.
    pub branch: &'a str,
    /// Passed to `plan_switch` (presumably "create the branch/worktree if missing" — confirm).
    pub create: bool,
    /// Passed to `plan_switch` (presumably the base ref for a newly created branch — confirm).
    pub base: Option<&'a str>,
    /// Command template expanded with hook template variables and run after the switch.
    pub execute: Option<&'a str>,
    /// Arguments appended to `execute`; each is template-expanded and then shell-escaped.
    pub execute_args: &'a [String],
    /// Forwarded to hook approval, `CommandContext`, and `execute_switch`
    /// (presumably auto-approves prompts — TODO confirm).
    pub yes: bool,
    /// Passed to `plan_switch` (presumably allows replacing an existing path — TODO confirm).
    pub clobber: bool,
    /// Explicit change-directory choice; `None` defers to the resolved `switch.cd` config.
    pub change_dir: Option<bool>,
    /// When `false`, pre-switch hooks are skipped and post-switch hooks are neither
    /// approved nor run.
    pub verify: bool,
    /// Output format; `SwitchFormat::Json` additionally prints a JSON summary to stdout.
    pub format: crate::cli::SwitchFormat,
}
/// Runs the `pre-switch` hooks in the current worktree before switching to `target_branch`.
///
/// `target_branch` is first resolved with `Repository::resolve_worktree_name`; if that fails,
/// the raw argument is used unchanged. The hooks run only when `approve_hooks` approves them,
/// and they use `FailureStrategy::FailFast`.
///
/// Template variables provided to the hooks:
/// - base: the current worktree's branch (empty string if unknown) and its path;
/// - target: the resolved target name;
/// - if a worktree already exists for the target, its path as both the target worktree
///   path and the active worktree.
///
/// # Errors
/// Returns an error if approval fails or if a hook fails.
pub(crate) fn run_pre_switch_hooks(
    repo: &Repository,
    config: &UserConfig,
    target_branch: &str,
    yes: bool,
) -> anyhow::Result<()> {
    let current_wt = repo.current_worktree();
    let current_path = current_wt.path().to_path_buf();
    let resolved_target = repo
        .resolve_worktree_name(target_branch)
        .unwrap_or_else(|_| target_branch.to_string());
    // Pre-switch hooks execute in the *current* worktree, not the destination.
    let pre_ctx = CommandContext::new(repo, config, Some(&resolved_target), &current_path, yes);
    let pre_switch_approved = approve_hooks(&pre_ctx, &[HookType::PreSwitch])?;
    if pre_switch_approved {
        let base_branch = current_wt.branch().ok().flatten().unwrap_or_default();
        // Only known when the target already has a worktree.
        let dest_path = repo.worktree_for_branch(&resolved_target).ok().flatten();
        let mut vars = TemplateVars::new()
            .with_base(&base_branch, &current_path)
            .with_target(&resolved_target);
        if let Some(p) = dest_path.as_deref() {
            vars = vars.with_target_worktree_path(p).with_active_worktree(p);
        }
        let extra_vars = vars.as_extra_vars();
        execute_hook(
            &pre_ctx,
            HookType::PreSwitch,
            &extra_vars,
            FailureStrategy::FailFast,
            crate::output::pre_hook_display_path(pre_ctx.worktree_path),
        )?;
    }
    Ok(())
}
/// Hook types that follow a switch.
///
/// When a worktree is being created this is `PreStart`, `PostStart`, then `PostSwitch`;
/// for a switch to an existing worktree it is only `PostSwitch`.
fn switch_post_hook_types(is_create: bool) -> &'static [HookType] {
    const ON_CREATE: &[HookType] = &[
        HookType::PreStart,
        HookType::PostStart,
        HookType::PostSwitch,
    ];
    const ON_SWITCH: &[HookType] = &[HookType::PostSwitch];

    if is_create { ON_CREATE } else { ON_SWITCH }
}
/// Requests approval for the hooks that will run after this switch.
///
/// With `verify == false` this returns `Ok(false)` without prompting. Otherwise it delegates
/// to `approve_or_skip` for the hook types from `switch_post_hook_types`. The decline message
/// mentions continuing worktree creation when the plan is a create.
pub(crate) fn approve_switch_hooks(
    repo: &Repository,
    config: &UserConfig,
    plan: &SwitchPlan,
    yes: bool,
    verify: bool,
) -> anyhow::Result<bool> {
    if verify {
        let approval_ctx =
            CommandContext::new(repo, config, plan.branch(), plan.worktree_path(), yes);
        let creating = plan.is_create();
        let decline_message = if creating {
            "Commands declined, continuing worktree creation"
        } else {
            "Commands declined"
        };
        approve_or_skip(
            &approval_ctx,
            switch_post_hook_types(creating),
            decline_message,
        )
    } else {
        Ok(false)
    }
}
pub(crate) fn spawn_switch_background_hooks(
repo: &Repository,
config: &UserConfig,
result: &SwitchResult,
branch: Option<&str>,
yes: bool,
extra_vars: &[(&str, &str)],
hooks_display_path: Option<&Path>,
) -> anyhow::Result<()> {
let ctx = CommandContext::new(repo, config, branch, result.path(), yes);
let mut pipelines =
prepare_background_pipelines(&ctx, HookType::PostSwitch, extra_vars, hooks_display_path)?;
if matches!(result, SwitchResult::Created { .. }) {
pipelines.extend(prepare_background_pipelines(
&ctx,
HookType::PostStart,
extra_vars,
hooks_display_path,
)?);
}
run_hooks_background(pipelines, false)
}
pub fn handle_switch(
opts: SwitchOptions<'_>,
config: &mut UserConfig,
binary_name: &str,
) -> anyhow::Result<()> {
let SwitchOptions {
branch,
create,
base,
execute,
execute_args,
yes,
clobber,
change_dir: change_dir_flag,
verify,
format,
} = opts;
let (repo, is_recovered) = current_or_recover().context("Failed to switch worktree")?;
let change_dir = change_dir_flag.unwrap_or_else(|| {
let project_id = repo.project_identifier().ok();
config.resolved(project_id.as_deref()).switch.cd()
});
let suggestion_ctx = execute.map(|exec| {
let escaped = shell_escape::escape(exec.into());
SwitchSuggestionCtx {
extra_flags: vec![format!("--execute={escaped}")],
trailing_args: execute_args.to_vec(),
}
});
if verify && !is_recovered {
run_pre_switch_hooks(&repo, config, branch, yes)?;
}
offer_bare_repo_worktree_path_fix(&repo, config)?;
let plan = plan_switch(&repo, branch, create, base, clobber, config).map_err(|err| {
match suggestion_ctx {
Some(ref ctx) => match err.downcast::<GitError>() {
Ok(git_err) => GitError::WithSwitchSuggestion {
source: Box::new(git_err),
ctx: ctx.clone(),
}
.into(),
Err(err) => err,
},
None => err,
}
})?;
let hooks_approved = approve_switch_hooks(&repo, config, &plan, yes, verify)?;
validate_switch_templates(&repo, config, &plan, execute, execute_args, hooks_approved)?;
let source_branch = repo
.current_worktree()
.branch()
.ok()
.flatten()
.unwrap_or_default();
let source_path = repo
.current_worktree()
.root()
.ok()
.map(|p| worktrunk::path::to_posix_path(&p.to_string_lossy()))
.unwrap_or_default();
let (result, branch_info) = execute_switch(&repo, plan, config, yes, hooks_approved)?;
if format == SwitchFormat::Json {
let json = SwitchJsonOutput::from_result(&result, &branch_info);
let json = serde_json::to_string(&json).context("Failed to serialize to JSON")?;
println!("{json}");
}
if std::env::var_os("WORKTRUNK_FIRST_OUTPUT").is_some() {
return Ok(());
}
let branch_info = match &result {
SwitchResult::Existing { path } | SwitchResult::AlreadyAt(path) => {
let expected_path = branch_info
.branch
.as_deref()
.and_then(|b| path_mismatch(&repo, b, path, config));
SwitchBranchInfo {
expected_path,
..branch_info
}
}
_ => branch_info,
};
let fallback_path = repo.repo_path()?.to_path_buf();
let cwd = std::env::current_dir().unwrap_or(fallback_path.clone());
let source_root = repo.current_worktree().root().unwrap_or(fallback_path);
let hooks_display_path =
handle_switch_output(&result, &branch_info, change_dir, Some(&source_root), &cwd)?;
if change_dir && !is_shell_integration_active() {
let skip_prompt = execute.is_some();
let _ = prompt_shell_integration(config, binary_name, skip_prompt);
}
let template_vars =
TemplateVars::for_post_switch(&result, &branch_info, &source_branch, &source_path);
let extra_vars = template_vars.as_extra_vars();
if hooks_approved {
spawn_switch_background_hooks(
&repo,
config,
&result,
branch_info.branch.as_deref(),
yes,
&extra_vars,
hooks_display_path.as_deref(),
)?;
}
if let Some(cmd) = execute {
let ctx = CommandContext::new(
&repo,
config,
branch_info.branch.as_deref(),
result.path(),
yes,
);
let template_vars = build_hook_context(&ctx, &extra_vars)?;
let vars: HashMap<&str, &str> = template_vars
.iter()
.map(|(k, v)| (k.as_str(), v.as_str()))
.collect();
let expanded_cmd = expand_template(cmd, &vars, true, &repo, "--execute command")?;
let full_cmd = if execute_args.is_empty() {
expanded_cmd
} else {
let expanded_args: Result<Vec<_>, _> = execute_args
.iter()
.map(|arg| expand_template(arg, &vars, false, &repo, "--execute argument"))
.collect();
let escaped_args: Vec<_> = expanded_args?
.iter()
.map(|arg| shell_escape::escape(arg.into()).into_owned())
.collect();
format!("{} {}", expanded_cmd, escaped_args.join(" "))
};
execute_user_command(&full_cmd, hooks_display_path.as_deref())?;
}
Ok(())
}
/// Validates every template this switch may expand, before the switch is executed.
///
/// The `--execute` command and argument templates are always checked against
/// `ValidationScope::SwitchExecute`. When hooks were approved, the user and project hook
/// templates for each type in `switch_post_hook_types` are checked as well. Templates that
/// reference `vars` are skipped here (presumably because those values are only known at run
/// time — confirm).
fn validate_switch_templates(
    repo: &Repository,
    config: &UserConfig,
    plan: &SwitchPlan,
    execute: Option<&str>,
    execute_args: &[String],
    hooks_approved: bool,
) -> anyhow::Result<()> {
    if let Some(command) = execute {
        validate_template(
            command,
            ValidationScope::SwitchExecute,
            repo,
            "--execute command",
        )?;
        for argument in execute_args {
            validate_template(
                argument,
                ValidationScope::SwitchExecute,
                repo,
                "--execute argument",
            )?;
        }
    }

    if hooks_approved {
        let project_config = repo.load_project_config()?;
        let project_id = repo.project_identifier().ok();
        let user_hooks = config.hooks(project_id.as_deref());

        for &hook_type in switch_post_hook_types(plan.is_create()) {
            let (user_cfg, proj_cfg) =
                super::hooks::lookup_hook_configs(&user_hooks, project_config.as_ref(), hook_type);
            for (source, maybe_cfg) in [("user", user_cfg), ("project", proj_cfg)] {
                let Some(hook_cfg) = maybe_cfg else {
                    continue;
                };
                for command in hook_cfg.commands() {
                    if template_references_var(&command.template, "vars") {
                        continue;
                    }
                    let label = command.name.as_ref().map_or_else(
                        || format!("{source} {hook_type} hook"),
                        |n| format!("{source} {hook_type}:{n}"),
                    );
                    validate_template(
                        &command.template,
                        ValidationScope::Hook(hook_type),
                        repo,
                        &label,
                    )?;
                }
            }
        }
    }

    Ok(())
}