1use anyhow::{Result, bail};
7use serde::{Deserialize, Serialize};
8
/// Describes which set of changes an agent task should look at.
///
/// Serialized as an internally tagged object: the variant name, converted to
/// `snake_case`, is stored under the `"mode"` key alongside the variant's
/// fields (e.g. `{"mode": "commit", "commit_id": "abc123"}`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(tag = "mode", rename_all = "snake_case")]
pub enum TaskContext {
    /// Staged (index) changes.
    Staged {
        /// When `true`, unstaged working-tree changes should be inspected as well.
        include_unstaged: bool,
    },

    /// A single commit, diffed against its first parent (`{commit_id}^1`).
    Commit {
        /// Revision identifier of the commit to inspect.
        commit_id: String,
    },

    /// The changes between two revisions.
    Range {
        /// Base revision of the comparison.
        from: String,
        /// Target revision of the comparison.
        to: String,
    },

    /// Changelog generation over a revision range.
    Changelog {
        /// Start revision of the range.
        from: String,
        /// End revision of the range (`for_changelog` defaults this to `"HEAD"`).
        to: String,
        /// Version label for the entry; `None` is displayed as "unreleased".
        version_name: Option<String>,
        /// Release date string (`for_changelog` defaults this to today's local
        /// date formatted as `%Y-%m-%d`).
        date: String,
    },

    /// Amending the previous commit.
    Amend {
        /// Message of the commit being amended.
        original_message: String,
    },

    /// No explicit target: the agent discovers the current changes itself.
    #[default]
    Discover,
}
59
60impl TaskContext {
61 #[must_use]
64 pub fn for_gen() -> Self {
65 Self::Staged {
66 include_unstaged: false,
67 }
68 }
69
70 #[must_use]
73 pub fn for_amend(original_message: String) -> Self {
74 Self::Amend { original_message }
75 }
76
77 pub fn for_review(
89 commit: Option<String>,
90 from: Option<String>,
91 to: Option<String>,
92 include_unstaged: bool,
93 ) -> Result<Self> {
94 Self::for_review_with_base(commit, from, to, include_unstaged, "main")
95 }
96
97 pub fn for_review_with_base(
107 commit: Option<String>,
108 from: Option<String>,
109 to: Option<String>,
110 include_unstaged: bool,
111 default_base: &str,
112 ) -> Result<Self> {
113 if from.is_some() && to.is_none() {
115 bail!("When using --from, you must also specify --to for branch comparison reviews");
116 }
117
118 if commit.is_some() && (from.is_some() || to.is_some()) {
120 bail!("Cannot use --commit with --from/--to. These are mutually exclusive options");
121 }
122
123 if include_unstaged && (from.is_some() || to.is_some()) {
125 bail!(
126 "Cannot use --include-unstaged with --from/--to. Branch comparison reviews don't include working directory changes"
127 );
128 }
129
130 Ok(match (commit, from, to) {
132 (Some(id), _, _) => Self::Commit { commit_id: id },
133 (_, Some(f), Some(t)) => Self::Range { from: f, to: t },
134 (None, None, Some(t)) => Self::Range {
135 from: default_base.to_string(),
136 to: t,
137 },
138 _ => Self::Staged { include_unstaged },
139 })
140 }
141
142 #[must_use]
150 pub fn for_pr(from: Option<String>, to: Option<String>) -> Self {
151 Self::for_pr_with_base(from, to, "main")
152 }
153
154 #[must_use]
158 pub fn for_pr_with_base(from: Option<String>, to: Option<String>, default_base: &str) -> Self {
159 match (from, to) {
160 (Some(f), Some(t)) => Self::Range { from: f, to: t },
161 (Some(f), None) => Self::Range {
162 from: f,
163 to: "HEAD".to_string(),
164 },
165 (None, Some(t)) => Self::Range {
166 from: default_base.to_string(),
167 to: t,
168 },
169 (None, None) => Self::Range {
170 from: default_base.to_string(),
171 to: "HEAD".to_string(),
172 },
173 }
174 }
175
176 #[must_use]
181 pub fn for_changelog(
182 from: String,
183 to: Option<String>,
184 version_name: Option<String>,
185 date: Option<String>,
186 ) -> Self {
187 Self::Changelog {
188 from,
189 to: to.unwrap_or_else(|| "HEAD".to_string()),
190 version_name,
191 date: date.unwrap_or_else(|| chrono::Local::now().format("%Y-%m-%d").to_string()),
192 }
193 }
194
195 #[must_use]
197 pub fn to_prompt_context(&self) -> String {
198 serde_json::to_string_pretty(self).unwrap_or_else(|_| format!("{self:?}"))
199 }
200
201 #[must_use]
203 pub fn diff_hint(&self) -> String {
204 match self {
205 Self::Staged { include_unstaged } => {
206 if *include_unstaged {
207 "git_diff() for staged changes, then check unstaged files".to_string()
208 } else {
209 "git_diff() for staged changes".to_string()
210 }
211 }
212 Self::Commit { commit_id } => {
213 format!("git_diff(from=\"{commit_id}^1\", to=\"{commit_id}\")")
214 }
215 Self::Range { from, to } | Self::Changelog { from, to, .. } => {
216 format!("git_diff(from=\"{from}\", to=\"{to}\")")
217 }
218 Self::Amend { .. } => {
219 "git_diff(from=\"HEAD^1\") for combined amend diff (original commit + new staged changes)".to_string()
220 }
221 Self::Discover => "git_diff() to discover current changes".to_string(),
222 }
223 }
224
225 #[must_use]
227 pub fn is_range(&self) -> bool {
228 matches!(self, Self::Range { .. })
229 }
230
231 #[must_use]
233 pub fn includes_unstaged(&self) -> bool {
234 matches!(
235 self,
236 Self::Staged {
237 include_unstaged: true
238 }
239 )
240 }
241
242 #[must_use]
244 pub fn is_amend(&self) -> bool {
245 matches!(self, Self::Amend { .. })
246 }
247
248 #[must_use]
250 pub fn original_message(&self) -> Option<&str> {
251 match self {
252 Self::Amend { original_message } => Some(original_message),
253 _ => None,
254 }
255 }
256}
257
258impl std::fmt::Display for TaskContext {
259 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 match self {
261 Self::Staged { include_unstaged } => {
262 if *include_unstaged {
263 write!(f, "staged and unstaged changes")
264 } else {
265 write!(f, "staged changes")
266 }
267 }
268 Self::Commit { commit_id } => write!(f, "commit {commit_id}"),
269 Self::Range { from, to } => write!(f, "changes from {from} to {to}"),
270 Self::Changelog {
271 from,
272 to,
273 version_name,
274 date,
275 } => {
276 let version_str = version_name
277 .as_ref()
278 .map_or_else(|| "unreleased".to_string(), |v| format!("v{v}"));
279 write!(f, "changelog {version_str} ({date}) from {from} to {to}")
280 }
281 Self::Amend { .. } => write!(f, "amending previous commit"),
282 Self::Discover => write!(f, "auto-discovered changes"),
283 }
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_for_gen() {
        let ctx = TaskContext::for_gen();
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: false
            }
        ));
    }

    #[test]
    fn test_default_is_discover() {
        assert!(matches!(TaskContext::default(), TaskContext::Discover));
    }

    #[test]
    fn test_review_staged_only() {
        let ctx = TaskContext::for_review(None, None, None, false).expect("should succeed");
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: false
            }
        ));
    }

    #[test]
    fn test_review_with_unstaged() {
        let ctx = TaskContext::for_review(None, None, None, true).expect("should succeed");
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: true
            }
        ));
    }

    #[test]
    fn test_review_single_commit() {
        let ctx = TaskContext::for_review(Some("abc123".to_string()), None, None, false)
            .expect("should succeed");
        assert!(matches!(ctx, TaskContext::Commit { commit_id } if commit_id == "abc123"));
    }

    #[test]
    fn test_review_range() {
        let ctx = TaskContext::for_review(
            None,
            Some("main".to_string()),
            Some("feature".to_string()),
            false,
        )
        .expect("should succeed");
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "main" && to == "feature")
        );
    }

    #[test]
    fn test_review_to_only_defaults_from_main() {
        let ctx = TaskContext::for_review(None, None, Some("feature".to_string()), false)
            .expect("should succeed");
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "main" && to == "feature")
        );
    }

    #[test]
    fn test_review_to_only_defaults_from_explicit_base() {
        let ctx = TaskContext::for_review_with_base(
            None,
            None,
            Some("feature".to_string()),
            false,
            "trunk",
        )
        .expect("should succeed");
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "trunk" && to == "feature")
        );
    }

    #[test]
    fn test_review_from_without_to_fails() {
        let result = TaskContext::for_review(None, Some("main".to_string()), None, false);
        assert!(result.is_err());
        assert!(
            result
                .expect_err("should be err")
                .to_string()
                .contains("--to")
        );
    }

    #[test]
    fn test_review_commit_with_range_fails() {
        let result = TaskContext::for_review(
            Some("abc123".to_string()),
            Some("main".to_string()),
            Some("feature".to_string()),
            false,
        );
        assert!(result.is_err());
        assert!(
            result
                .expect_err("should be err")
                .to_string()
                .contains("mutually exclusive")
        );
    }

    #[test]
    fn test_review_unstaged_with_range_fails() {
        let result = TaskContext::for_review(
            None,
            Some("main".to_string()),
            Some("feature".to_string()),
            true,
        );
        assert!(result.is_err());
        assert!(
            result
                .expect_err("should be err")
                .to_string()
                .contains("include-unstaged")
        );
    }

    #[test]
    fn test_pr_defaults() {
        let ctx = TaskContext::for_pr_with_base(None, None, "trunk");
        assert!(matches!(ctx, TaskContext::Range { from, to } if from == "trunk" && to == "HEAD"));
    }

    #[test]
    fn test_pr_from_only() {
        let ctx = TaskContext::for_pr(Some("develop".to_string()), None);
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "develop" && to == "HEAD")
        );
    }

    #[test]
    fn test_pr_to_only() {
        let ctx = TaskContext::for_pr_with_base(None, Some("feature".to_string()), "trunk");
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "trunk" && to == "feature")
        );
    }

    #[test]
    fn test_changelog() {
        let ctx = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            Some("1.1.0".to_string()),
            Some("2025-01-15".to_string()),
        );
        assert!(matches!(
            ctx,
            TaskContext::Changelog { from, to, version_name, date }
                if from == "v1.0.0" && to == "HEAD"
                    && version_name.as_deref() == Some("1.1.0")
                    && date == "2025-01-15"
        ));
    }

    #[test]
    fn test_changelog_default_date() {
        let ctx = TaskContext::for_changelog("v1.0.0".to_string(), None, None, None);
        if let TaskContext::Changelog { date, .. } = ctx {
            // Must round-trip through the same format `for_changelog` uses.
            assert!(
                chrono::NaiveDate::parse_from_str(&date, "%Y-%m-%d").is_ok(),
                "unexpected date format: {date}"
            );
        } else {
            panic!("Expected Changelog variant");
        }
    }

    #[test]
    fn test_diff_hint() {
        let staged = TaskContext::for_gen();
        assert!(staged.diff_hint().contains("staged"));

        let with_unstaged = TaskContext::for_review(None, None, None, true).expect("should succeed");
        assert!(with_unstaged.diff_hint().contains("unstaged"));

        let commit = TaskContext::Commit {
            commit_id: "abc".to_string(),
        };
        assert!(commit.diff_hint().contains("abc^1"));

        let range = TaskContext::Range {
            from: "main".to_string(),
            to: "dev".to_string(),
        };
        assert!(range.diff_hint().contains("main"));
        assert!(range.diff_hint().contains("dev"));

        let amend = TaskContext::for_amend("Fix bug".to_string());
        assert!(amend.diff_hint().contains("HEAD^1"));
    }

    #[test]
    fn test_is_range_excludes_changelog() {
        let range = TaskContext::for_pr(None, None);
        assert!(range.is_range());

        let changelog = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            None,
            Some("2025-01-15".to_string()),
        );
        assert!(!changelog.is_range());
    }

    #[test]
    fn test_includes_unstaged() {
        let with = TaskContext::for_review(None, None, None, true).expect("should succeed");
        assert!(with.includes_unstaged());
        assert!(!TaskContext::for_gen().includes_unstaged());
    }

    #[test]
    fn test_prompt_context_is_tagged_json() {
        let json = TaskContext::for_gen().to_prompt_context();
        assert!(json.contains("\"mode\": \"staged\""), "got: {json}");
        assert!(json.contains("\"include_unstaged\": false"), "got: {json}");
    }

    #[test]
    fn test_prompt_context_round_trips() {
        let ctx = TaskContext::Commit {
            commit_id: "abc123".to_string(),
        };
        let parsed: TaskContext =
            serde_json::from_str(&ctx.to_prompt_context()).expect("should deserialize");
        assert!(matches!(parsed, TaskContext::Commit { commit_id } if commit_id == "abc123"));
    }

    #[test]
    fn test_amend_context() {
        let ctx = TaskContext::for_amend("Initial commit message".to_string());
        assert!(ctx.is_amend());
        assert_eq!(ctx.original_message(), Some("Initial commit message"));
        assert!(!ctx.is_range());
        assert!(!ctx.includes_unstaged());
        assert_eq!(TaskContext::for_gen().original_message(), None);
    }

    #[test]
    fn test_amend_display() {
        let ctx = TaskContext::for_amend("Fix bug".to_string());
        assert_eq!(format!("{ctx}"), "amending previous commit");
    }

    #[test]
    fn test_display() {
        let range = TaskContext::Range {
            from: "main".to_string(),
            to: "dev".to_string(),
        };
        assert_eq!(range.to_string(), "changes from main to dev");

        let released = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            Some("1.1.0".to_string()),
            Some("2025-01-15".to_string()),
        );
        assert_eq!(
            released.to_string(),
            "changelog v1.1.0 (2025-01-15) from v1.0.0 to HEAD"
        );

        let unreleased = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            None,
            Some("2025-01-15".to_string()),
        );
        assert_eq!(
            unreleased.to_string(),
            "changelog unreleased (2025-01-15) from v1.0.0 to HEAD"
        );
    }
}