1use anyhow::{Result, bail};
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
11pub struct PullRequestTemplateContext {
12 pub path: String,
14 pub body: String,
16}
17
18#[derive(Debug, Clone, Default, Serialize, Deserialize)]
23#[serde(tag = "mode", rename_all = "snake_case")]
24pub enum TaskContext {
25 Staged {
27 include_unstaged: bool,
29 },
30
31 Commit {
33 commit_id: String,
35 },
36
37 Range {
39 from: String,
41 to: String,
43 },
44
45 PullRequest {
47 from: String,
49 to: String,
51 #[serde(skip)]
53 existing_body: Option<String>,
54 #[serde(skip)]
56 template: Option<PullRequestTemplateContext>,
57 },
58
59 Changelog {
61 from: String,
63 to: String,
65 version_name: Option<String>,
67 date: String,
69 },
70
71 Amend {
74 original_message: String,
76 },
77
78 #[default]
80 Discover,
81}
82
83impl TaskContext {
84 #[must_use]
87 pub fn for_gen() -> Self {
88 Self::Staged {
89 include_unstaged: false,
90 }
91 }
92
93 #[must_use]
96 pub fn for_amend(original_message: String) -> Self {
97 Self::Amend { original_message }
98 }
99
100 pub fn for_review(
112 commit: Option<String>,
113 from: Option<String>,
114 to: Option<String>,
115 include_unstaged: bool,
116 ) -> Result<Self> {
117 Self::for_review_with_base(commit, from, to, include_unstaged, "main")
118 }
119
120 pub fn for_review_with_base(
130 commit: Option<String>,
131 from: Option<String>,
132 to: Option<String>,
133 include_unstaged: bool,
134 default_base: &str,
135 ) -> Result<Self> {
136 if from.is_some() && to.is_none() {
138 bail!("When using --from, you must also specify --to for branch comparison reviews");
139 }
140
141 if commit.is_some() && (from.is_some() || to.is_some()) {
143 bail!("Cannot use --commit with --from/--to. These are mutually exclusive options");
144 }
145
146 if include_unstaged && (from.is_some() || to.is_some()) {
148 bail!(
149 "Cannot use --include-unstaged with --from/--to. Branch comparison reviews don't include working directory changes"
150 );
151 }
152
153 Ok(match (commit, from, to) {
155 (Some(id), _, _) => Self::Commit { commit_id: id },
156 (_, Some(f), Some(t)) => Self::Range { from: f, to: t },
157 (None, None, Some(t)) => Self::Range {
158 from: default_base.to_string(),
159 to: t,
160 },
161 _ => Self::Staged { include_unstaged },
162 })
163 }
164
165 #[must_use]
173 pub fn for_pr(from: Option<String>, to: Option<String>) -> Self {
174 Self::for_pr_with_base(from, to, "main")
175 }
176
177 #[must_use]
181 pub fn for_pr_with_base(from: Option<String>, to: Option<String>, default_base: &str) -> Self {
182 Self::for_pr_update_with_base(from, to, default_base, None, None)
183 }
184
185 #[must_use]
189 pub fn for_pr_update_with_base(
190 from: Option<String>,
191 to: Option<String>,
192 default_base: &str,
193 existing_body: Option<String>,
194 template: Option<PullRequestTemplateContext>,
195 ) -> Self {
196 let (from, to) = match (from, to) {
197 (Some(f), Some(t)) => (f, t),
198 (Some(f), None) => (f, "HEAD".to_string()),
199 (None, Some(t)) => (default_base.to_string(), t),
200 (None, None) => (default_base.to_string(), "HEAD".to_string()),
201 };
202
203 Self::PullRequest {
204 from,
205 to,
206 existing_body,
207 template,
208 }
209 }
210
211 #[must_use]
216 pub fn for_changelog(
217 from: String,
218 to: Option<String>,
219 version_name: Option<String>,
220 date: Option<String>,
221 ) -> Self {
222 Self::Changelog {
223 from,
224 to: to.unwrap_or_else(|| "HEAD".to_string()),
225 version_name,
226 date: date.unwrap_or_else(|| chrono::Local::now().format("%Y-%m-%d").to_string()),
227 }
228 }
229
230 #[must_use]
232 pub fn to_prompt_context(&self) -> String {
233 serde_json::to_string_pretty(self).unwrap_or_else(|_| format!("{self:?}"))
234 }
235
236 #[must_use]
238 pub fn diff_hint(&self) -> String {
239 match self {
240 Self::Staged { include_unstaged } => {
241 if *include_unstaged {
242 "git_diff() for staged changes, then check unstaged files".to_string()
243 } else {
244 "git_diff() for staged changes".to_string()
245 }
246 }
247 Self::Commit { commit_id } => {
248 format!("git_diff(from=\"{commit_id}^1\", to=\"{commit_id}\")")
249 }
250 Self::Range { from, to }
251 | Self::PullRequest { from, to, .. }
252 | Self::Changelog { from, to, .. } => {
253 format!("git_diff(from=\"{from}\", to=\"{to}\")")
254 }
255 Self::Amend { .. } => {
256 "git_diff(from=\"HEAD^1\") for combined amend diff (original commit + new staged changes)".to_string()
257 }
258 Self::Discover => "git_diff() to discover current changes".to_string(),
259 }
260 }
261
262 #[must_use]
264 pub fn is_range(&self) -> bool {
265 matches!(self, Self::Range { .. } | Self::PullRequest { .. })
266 }
267
268 #[must_use]
270 pub fn includes_unstaged(&self) -> bool {
271 matches!(
272 self,
273 Self::Staged {
274 include_unstaged: true
275 }
276 )
277 }
278
279 #[must_use]
281 pub fn is_amend(&self) -> bool {
282 matches!(self, Self::Amend { .. })
283 }
284
285 #[must_use]
287 pub fn original_message(&self) -> Option<&str> {
288 match self {
289 Self::Amend { original_message } => Some(original_message),
290 _ => None,
291 }
292 }
293
294 #[must_use]
296 pub fn existing_pull_request_body(&self) -> Option<&str> {
297 match self {
298 Self::PullRequest {
299 existing_body: Some(body),
300 ..
301 } => Some(body),
302 _ => None,
303 }
304 }
305
306 #[must_use]
308 pub fn pull_request_template(&self) -> Option<&PullRequestTemplateContext> {
309 match self {
310 Self::PullRequest {
311 template: Some(template),
312 ..
313 } => Some(template),
314 _ => None,
315 }
316 }
317}
318
319impl std::fmt::Display for TaskContext {
320 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321 match self {
322 Self::Staged { include_unstaged } => {
323 if *include_unstaged {
324 write!(f, "staged and unstaged changes")
325 } else {
326 write!(f, "staged changes")
327 }
328 }
329 Self::Commit { commit_id } => write!(f, "commit {commit_id}"),
330 Self::Range { from, to } => write!(f, "changes from {from} to {to}"),
331 Self::PullRequest { from, to, .. } => {
332 write!(f, "pull request changes from {from} to {to}")
333 }
334 Self::Changelog {
335 from,
336 to,
337 version_name,
338 date,
339 } => {
340 let version_str = version_name
341 .as_ref()
342 .map_or_else(|| "unreleased".to_string(), |v| format!("v{v}"));
343 write!(f, "changelog {version_str} ({date}) from {from} to {to}")
344 }
345 Self::Amend { .. } => write!(f, "amending previous commit"),
346 Self::Discover => write!(f, "auto-discovered changes"),
347 }
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_for_gen() {
        let ctx = TaskContext::for_gen();
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: false
            }
        ));
    }

    #[test]
    fn test_default_is_discover() {
        assert!(matches!(TaskContext::default(), TaskContext::Discover));
    }

    #[test]
    fn test_review_staged_only() {
        let ctx = TaskContext::for_review(None, None, None, false).expect("should succeed");
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: false
            }
        ));
    }

    #[test]
    fn test_review_with_unstaged() {
        let ctx = TaskContext::for_review(None, None, None, true).expect("should succeed");
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: true
            }
        ));
    }

    #[test]
    fn test_review_single_commit() {
        let ctx = TaskContext::for_review(Some("abc123".to_string()), None, None, false)
            .expect("should succeed");
        assert!(matches!(ctx, TaskContext::Commit { commit_id } if commit_id == "abc123"));
    }

    #[test]
    fn test_review_range() {
        let ctx = TaskContext::for_review(
            None,
            Some("main".to_string()),
            Some("feature".to_string()),
            false,
        )
        .expect("should succeed");
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "main" && to == "feature")
        );
    }

    #[test]
    fn test_review_to_only_defaults_from_explicit_base() {
        let ctx = TaskContext::for_review_with_base(
            None,
            None,
            Some("feature".to_string()),
            false,
            "trunk",
        )
        .expect("should succeed");
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "trunk" && to == "feature")
        );
    }

    #[test]
    fn test_review_from_without_to_fails() {
        let result = TaskContext::for_review(None, Some("main".to_string()), None, false);
        assert!(result.is_err());
        assert!(
            result
                .expect_err("should be err")
                .to_string()
                .contains("--to")
        );
    }

    #[test]
    fn test_review_commit_with_range_fails() {
        let result = TaskContext::for_review(
            Some("abc123".to_string()),
            Some("main".to_string()),
            Some("feature".to_string()),
            false,
        );
        assert!(result.is_err());
        assert!(
            result
                .expect_err("should be err")
                .to_string()
                .contains("mutually exclusive")
        );
    }

    #[test]
    fn test_review_unstaged_with_range_fails() {
        let result = TaskContext::for_review(
            None,
            Some("main".to_string()),
            Some("feature".to_string()),
            true,
        );
        assert!(result.is_err());
        assert!(
            result
                .expect_err("should be err")
                .to_string()
                .contains("include-unstaged")
        );
    }

    #[test]
    fn test_pr_defaults() {
        let ctx = TaskContext::for_pr_with_base(None, None, "trunk");
        assert!(
            matches!(ctx, TaskContext::PullRequest { from, to, existing_body, template } if from == "trunk" && to == "HEAD" && existing_body.is_none() && template.is_none())
        );
    }

    #[test]
    fn test_pr_from_only() {
        let ctx = TaskContext::for_pr(Some("develop".to_string()), None);
        assert!(
            matches!(ctx, TaskContext::PullRequest { from, to, existing_body, template } if from == "develop" && to == "HEAD" && existing_body.is_none() && template.is_none())
        );
    }

    #[test]
    fn test_pr_existing_body() {
        let ctx = TaskContext::for_pr_update_with_base(
            Some("main".to_string()),
            Some("feature".to_string()),
            "trunk",
            Some("Existing body".to_string()),
            None,
        );
        assert!(
            matches!(ctx, TaskContext::PullRequest { from, to, existing_body, .. } if from == "main" && to == "feature" && existing_body == Some("Existing body".to_string()))
        );
    }

    #[test]
    fn test_pr_update_accessors() {
        let template = PullRequestTemplateContext {
            path: ".github/pull_request_template.md".to_string(),
            body: "## Summary".to_string(),
        };
        let ctx = TaskContext::for_pr_update_with_base(
            None,
            None,
            "trunk",
            Some("Existing body".to_string()),
            Some(template.clone()),
        );
        assert_eq!(ctx.existing_pull_request_body(), Some("Existing body"));
        assert_eq!(ctx.pull_request_template(), Some(&template));
        assert!(ctx.is_range());

        let staged = TaskContext::for_gen();
        assert_eq!(staged.existing_pull_request_body(), None);
        assert_eq!(staged.pull_request_template(), None);
    }

    #[test]
    fn test_prompt_context_omits_skipped_fields() {
        let ctx = TaskContext::for_pr_update_with_base(
            Some("main".to_string()),
            Some("feature".to_string()),
            "trunk",
            Some("Existing body".to_string()),
            None,
        );
        let prompt = ctx.to_prompt_context();
        assert!(prompt.contains("\"mode\": \"pull_request\""));
        assert!(prompt.contains("\"from\": \"main\""));
        // `existing_body` is `#[serde(skip)]`, so it must not leak into the prompt JSON.
        assert!(!prompt.contains("Existing body"));
    }

    #[test]
    fn test_deserialize_pull_request_defaults_skipped_fields() {
        let ctx: TaskContext =
            serde_json::from_str(r#"{"mode":"pull_request","from":"main","to":"feature"}"#)
                .expect("valid pull_request JSON");
        assert!(matches!(
            ctx,
            TaskContext::PullRequest {
                from,
                to,
                existing_body: None,
                template: None,
            } if from == "main" && to == "feature"
        ));
    }

    #[test]
    fn test_changelog() {
        let ctx = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            Some("1.1.0".to_string()),
            Some("2025-01-15".to_string()),
        );
        assert!(matches!(
            ctx,
            TaskContext::Changelog { from, to, version_name, date }
                if from == "v1.0.0" && to == "HEAD"
                && version_name == Some("1.1.0".to_string())
                && date == "2025-01-15"
        ));
    }

    #[test]
    fn test_changelog_default_date() {
        let ctx = TaskContext::for_changelog("v1.0.0".to_string(), None, None, None);
        match ctx {
            TaskContext::Changelog { date, .. } => {
                // The default must be a real `%Y-%m-%d` date, not just any string with a dash.
                assert!(
                    chrono::NaiveDate::parse_from_str(&date, "%Y-%m-%d").is_ok(),
                    "unexpected default date format: {date}"
                );
            }
            other => panic!("Expected Changelog variant, got {other:?}"),
        }
    }

    #[test]
    fn test_changelog_display() {
        let released = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            Some("1.1.0".to_string()),
            Some("2025-01-15".to_string()),
        );
        assert_eq!(
            format!("{released}"),
            "changelog v1.1.0 (2025-01-15) from v1.0.0 to HEAD"
        );

        let unreleased = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            None,
            Some("2025-01-15".to_string()),
        );
        assert_eq!(
            format!("{unreleased}"),
            "changelog unreleased (2025-01-15) from v1.0.0 to HEAD"
        );
    }

    #[test]
    fn test_diff_hint() {
        let staged = TaskContext::for_gen();
        assert!(staged.diff_hint().contains("staged"));

        let commit = TaskContext::Commit {
            commit_id: "abc".to_string(),
        };
        assert!(commit.diff_hint().contains("abc^1"));

        let range = TaskContext::Range {
            from: "main".to_string(),
            to: "dev".to_string(),
        };
        assert!(range.diff_hint().contains("main"));
        assert!(range.diff_hint().contains("dev"));

        let pr = TaskContext::for_pr_with_base(None, None, "main");
        assert_eq!(pr.diff_hint(), "git_diff(from=\"main\", to=\"HEAD\")");

        let changelog = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            None,
            Some("2025-01-15".to_string()),
        );
        assert_eq!(changelog.diff_hint(), "git_diff(from=\"v1.0.0\", to=\"HEAD\")");

        assert!(TaskContext::Discover.diff_hint().contains("discover"));

        let amend = TaskContext::for_amend("Fix bug".to_string());
        assert!(amend.diff_hint().contains("HEAD^1"));
    }

    #[test]
    fn test_amend_context() {
        let ctx = TaskContext::for_amend("Initial commit message".to_string());
        assert!(ctx.is_amend());
        assert_eq!(ctx.original_message(), Some("Initial commit message"));
        assert!(!ctx.is_range());
        assert!(!ctx.includes_unstaged());
    }

    #[test]
    fn test_amend_display() {
        let ctx = TaskContext::for_amend("Fix bug".to_string());
        assert_eq!(format!("{ctx}"), "amending previous commit");
    }
}