1use std::collections::HashMap;
6use std::path::{Path, PathBuf};
7use std::process::Command;
8
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11
12use console::style;
13
14use crate::error::{CwError, Result};
15
/// Every event name a hook may be registered under, in canonical
/// `<area>.<phase>` form. Worktree events put the action in the phase
/// part (e.g. `worktree.pre_create`).
pub const HOOK_EVENTS: &[&str] = &[
    // worktree
    "worktree.pre_create",
    "worktree.post_create",
    "worktree.pre_delete",
    "worktree.post_delete",
    // merge
    "merge.pre",
    "merge.post",
    // pr
    "pr.pre",
    "pr.post",
    // resume
    "resume.pre",
    "resume.post",
    // sync
    "sync.pre",
    "sync.post",
];

/// File name, relative to the repository root, of the local config file
/// whose `"hooks"` key stores the hook configuration.
const LOCAL_CONFIG_FILE: &str = ".cwconfig.json";
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct HookEntry {
38 pub id: String,
39 pub command: String,
40 #[serde(default = "default_true")]
41 pub enabled: bool,
42 #[serde(default)]
43 pub description: String,
44}
45
46fn default_true() -> bool {
47 true
48}
49
/// Finds the nearest enclosing repository root.
///
/// Starts at `start_path`, or at the current working directory when `None`,
/// and walks up through its ancestors. Returns the first directory that
/// contains a `.git` entry (directory or file). The start path is
/// canonicalized first; if that fails, it is used exactly as given.
///
/// Returns `None` when no ancestor has `.git`, or when `start_path` is
/// `None` and the current directory cannot be determined.
fn find_repo_root(start_path: Option<&Path>) -> Option<PathBuf> {
    let origin: PathBuf = match start_path {
        Some(path) => path.to_path_buf(),
        None => std::env::current_dir().ok()?,
    };
    let origin = origin.canonicalize().unwrap_or(origin);

    // `ancestors()` yields the path itself, then each successive parent.
    origin
        .ancestors()
        .find(|dir| dir.join(".git").exists())
        .map(Path::to_path_buf)
}
67
68fn get_hooks_file_path(repo_root: Option<&Path>) -> Option<PathBuf> {
70 let root = if let Some(r) = repo_root {
71 r.to_path_buf()
72 } else {
73 find_repo_root(None)?
74 };
75 Some(root.join(LOCAL_CONFIG_FILE))
76}
77
78pub fn load_hooks_config(repo_root: Option<&Path>) -> HashMap<String, Vec<HookEntry>> {
80 let hooks_file = match get_hooks_file_path(repo_root) {
81 Some(p) if p.exists() => p,
82 _ => return HashMap::new(),
83 };
84
85 let content = match std::fs::read_to_string(&hooks_file) {
86 Ok(c) => c,
87 Err(_) => return HashMap::new(),
88 };
89
90 let data: Value = match serde_json::from_str(&content) {
91 Ok(v) => v,
92 Err(_) => return HashMap::new(),
93 };
94
95 let hooks_obj = match data.get("hooks") {
96 Some(Value::Object(m)) => m,
97 _ => return HashMap::new(),
98 };
99
100 let mut result = HashMap::new();
101 for (event, entries) in hooks_obj {
102 if let Ok(hooks) = serde_json::from_value::<Vec<HookEntry>>(entries.clone()) {
103 result.insert(event.clone(), hooks);
104 }
105 }
106 result
107}
108
109pub fn save_hooks_config(
111 hooks: &HashMap<String, Vec<HookEntry>>,
112 repo_root: Option<&Path>,
113) -> Result<()> {
114 let root = if let Some(r) = repo_root {
115 r.to_path_buf()
116 } else {
117 find_repo_root(None).ok_or_else(|| CwError::Hook("Not in a git repository".to_string()))?
118 };
119
120 let config_file = root.join(LOCAL_CONFIG_FILE);
121 let data = serde_json::json!({ "hooks": hooks });
122 let content = serde_json::to_string_pretty(&data)?;
123 std::fs::write(&config_file, content)?;
124 Ok(())
125}
126
/// Derives a deterministic hook ID of the form `hook-xxxxxxxx` from a command.
///
/// Uses 32-bit FNV-1a over the command's UTF-8 bytes. The same command
/// therefore yields the same ID whichever Rust toolchain built the binary.
/// That matters because these IDs are written to the config file and used
/// for duplicate detection. `DefaultHasher`, used previously, explicitly
/// does not guarantee a stable algorithm across Rust releases.
fn generate_hook_id(command: &str) -> String {
    const FNV_OFFSET_BASIS: u32 = 0x811c_9dc5;
    const FNV_PRIME: u32 = 0x0100_0193;

    let hash = command.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u32::from(byte)).wrapping_mul(FNV_PRIME)
    });
    format!("hook-{:08x}", hash)
}
135
136pub fn normalize_event_name(event: &str) -> String {
139 if HOOK_EVENTS.contains(&event) {
141 return event.to_string();
142 }
143
144 let normalized = event.replace('-', "_");
146 if HOOK_EVENTS.contains(&normalized.as_str()) {
147 return normalized;
148 }
149
150 let short_aliases = [
152 ("pre_create", "worktree.pre_create"),
153 ("post_create", "worktree.post_create"),
154 ("pre_delete", "worktree.pre_delete"),
155 ("post_delete", "worktree.post_delete"),
156 ("pre_merge", "merge.pre"),
157 ("post_merge", "merge.post"),
158 ("pre_pr", "pr.pre"),
159 ("post_pr", "pr.post"),
160 ("pre_resume", "resume.pre"),
161 ("post_resume", "resume.post"),
162 ("pre_sync", "sync.pre"),
163 ("post_sync", "sync.post"),
164 ];
165
166 let kebab_to_snake = event.replace('-', "_");
167 for (alias, canonical) in &short_aliases {
168 if kebab_to_snake == *alias {
169 return canonical.to_string();
170 }
171 }
172
173 event.to_string()
175}
176
177pub fn add_hook(
179 event: &str,
180 command: &str,
181 hook_id: Option<&str>,
182 description: Option<&str>,
183) -> Result<String> {
184 let event = normalize_event_name(event);
185 if !HOOK_EVENTS.contains(&event.as_str()) {
186 return Err(CwError::Hook(format!(
187 "Invalid hook event: {}.\n\nValid events:\n{}",
188 event,
189 HOOK_EVENTS
190 .iter()
191 .map(|e| format!(" {}", e))
192 .collect::<Vec<_>>()
193 .join("\n")
194 )));
195 }
196
197 let mut hooks = load_hooks_config(None);
198 let event_hooks = hooks.entry(event.clone()).or_default();
199
200 let id = hook_id
201 .map(|s| s.to_string())
202 .unwrap_or_else(|| generate_hook_id(command));
203
204 if event_hooks.iter().any(|h| h.id == id) {
206 return Err(CwError::Hook(format!(
207 "Hook with ID '{}' already exists for event '{}'",
208 id, event
209 )));
210 }
211
212 event_hooks.push(HookEntry {
213 id: id.clone(),
214 command: command.to_string(),
215 enabled: true,
216 description: description.unwrap_or("").to_string(),
217 });
218
219 save_hooks_config(&hooks, None)?;
220 Ok(id)
221}
222
223pub fn remove_hook(event: &str, hook_id: &str) -> Result<()> {
225 let mut hooks = load_hooks_config(None);
226 let event_hooks = hooks
227 .get_mut(event)
228 .ok_or_else(|| CwError::Hook(format!("No hooks found for event '{}'", event)))?;
229
230 let original_len = event_hooks.len();
231 event_hooks.retain(|h| h.id != hook_id);
232
233 if event_hooks.len() == original_len {
234 return Err(CwError::Hook(format!(
235 "Hook '{}' not found for event '{}'",
236 hook_id, event
237 )));
238 }
239
240 save_hooks_config(&hooks, None)?;
241 println!("* Removed hook '{}' from {}", hook_id, event);
242 Ok(())
243}
244
245pub fn set_hook_enabled(event: &str, hook_id: &str, enabled: bool) -> Result<()> {
247 let mut hooks = load_hooks_config(None);
248 let event_hooks = hooks
249 .get_mut(event)
250 .ok_or_else(|| CwError::Hook(format!("No hooks found for event '{}'", event)))?;
251
252 let hook = event_hooks
253 .iter_mut()
254 .find(|h| h.id == hook_id)
255 .ok_or_else(|| {
256 CwError::Hook(format!(
257 "Hook '{}' not found for event '{}'",
258 hook_id, event
259 ))
260 })?;
261
262 hook.enabled = enabled;
263 save_hooks_config(&hooks, None)?;
264
265 let action = if enabled { "Enabled" } else { "Disabled" };
266 println!("* {} hook '{}'", action, hook_id);
267 Ok(())
268}
269
270pub fn get_hooks(event: &str, repo_root: Option<&Path>) -> Vec<HookEntry> {
272 let hooks = load_hooks_config(repo_root);
273 hooks.get(event).cloned().unwrap_or_default()
274}
275
276pub fn run_hooks(
281 event: &str,
282 context: &HashMap<String, String>,
283 cwd: Option<&Path>,
284 repo_root: Option<&Path>,
285) -> Result<bool> {
286 let hooks = get_hooks(event, repo_root);
287 if hooks.is_empty() {
288 return Ok(true);
289 }
290
291 let enabled: Vec<&HookEntry> = hooks.iter().filter(|h| h.enabled).collect();
292 if enabled.is_empty() {
293 return Ok(true);
294 }
295
296 let is_pre_hook = event.contains(".pre");
297
298 eprintln!(
299 "{} Running {} hook(s) for {}...",
300 style("*").cyan().bold(),
301 enabled.len(),
302 style(event).yellow()
303 );
304
305 let mut env: HashMap<String, String> = std::env::vars().collect();
307 for (key, value) in context {
308 env.insert(format!("CW_{}", key.to_uppercase()), value.clone());
309 }
310
311 let mut all_succeeded = true;
312
313 for hook in enabled {
314 let desc_suffix = if hook.description.is_empty() {
315 String::new()
316 } else {
317 format!(" ({})", hook.description)
318 };
319 eprintln!(
320 " {} {}{}",
321 style("Running:").dim(),
322 style(&hook.id).bold(),
323 style(desc_suffix).dim()
324 );
325
326 let mut cmd = if cfg!(target_os = "windows") {
327 let mut c = Command::new("cmd");
328 c.args(["/C", &hook.command]);
329 c
330 } else {
331 let mut c = Command::new("sh");
332 c.args(["-c", &hook.command]);
333 c
334 };
335
336 cmd.envs(&env);
337 if let Some(dir) = cwd {
338 cmd.current_dir(dir);
339 }
340 cmd.stdout(std::process::Stdio::piped());
341 cmd.stderr(std::process::Stdio::piped());
342
343 match cmd.output() {
344 Ok(output) => {
345 if !output.status.success() {
346 all_succeeded = false;
347 let code = output.status.code().unwrap_or(-1);
348 eprintln!(
349 " {} Hook '{}' failed (exit code {})",
350 style("x").red().bold(),
351 style(&hook.id).bold(),
352 code
353 );
354
355 let stderr = String::from_utf8_lossy(&output.stderr);
356 for line in stderr.lines().take(5) {
357 eprintln!(" {}", style(line).dim());
358 }
359
360 if is_pre_hook {
361 return Err(CwError::Hook(format!(
362 "Pre-hook '{}' failed with exit code {}. Operation aborted.",
363 hook.id, code
364 )));
365 }
366 } else {
367 eprintln!(
368 " {} Hook '{}' completed",
369 style("*").green().bold(),
370 style(&hook.id).bold()
371 );
372 }
373 }
374 Err(e) => {
375 all_succeeded = false;
376 eprintln!(
377 " {} Hook '{}' failed: {}",
378 style("x").red().bold(),
379 style(&hook.id).bold(),
380 e
381 );
382 if is_pre_hook {
383 return Err(CwError::Hook(format!(
384 "Pre-hook '{}' failed to execute: {}",
385 hook.id, e
386 )));
387 }
388 }
389 }
390 }
391
392 if !all_succeeded && !is_pre_hook {
393 eprintln!(
394 "{} Some post-hooks failed. See output above.",
395 style("Warning:").yellow().bold()
396 );
397 }
398
399 Ok(all_succeeded)
400}