1use anyhow::{Result, bail};
7use serde::{Deserialize, Serialize};
8
/// Describes which set of changes a task should operate on.
///
/// Serialized by serde as an internally tagged enum: the variant name is
/// written in `snake_case` under the `mode` key, alongside the variant's own
/// fields (e.g. `{"mode": "range", "from": "...", "to": "..."}`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(tag = "mode", rename_all = "snake_case")]
pub enum TaskContext {
    /// Work on the staged changes.
    Staged {
        /// When `true`, unstaged changes are taken into account as well.
        include_unstaged: bool,
    },

    /// Work on a single commit.
    Commit {
        /// Revision identifying the commit. `diff_hint` interpolates it
        /// verbatim into `<commit_id>^1`, so any revision spelling that
        /// accepts that suffix is presumably usable — verify against callers.
        commit_id: String,
    },

    /// Work on the changes between two refs.
    Range {
        /// Ref the comparison starts from.
        from: String,
        /// Ref the comparison ends at.
        to: String,
    },

    /// Produce a changelog for the changes between two refs.
    Changelog {
        /// Ref the changelog starts from.
        from: String,
        /// Ref the changelog ends at.
        to: String,
        /// Optional version label; the `Display` impl shows "unreleased"
        /// when it is absent.
        version_name: Option<String>,
        /// Date attached to the changelog entry. `for_changelog` fills in
        /// today's local date formatted as `%Y-%m-%d` when none is supplied;
        /// the format is not validated here.
        date: String,
    },

    /// Amend the previous commit.
    Amend {
        /// Message carried along for the amend; presumably the message of
        /// the commit being amended — confirm against the caller.
        original_message: String,
    },

    /// No explicit context given; this is the `Default` value.
    #[default]
    Discover,
}
59
60impl TaskContext {
61 pub fn for_gen() -> Self {
64 Self::Staged {
65 include_unstaged: false,
66 }
67 }
68
69 pub fn for_amend(original_message: String) -> Self {
72 Self::Amend { original_message }
73 }
74
75 pub fn for_review(
82 commit: Option<String>,
83 from: Option<String>,
84 to: Option<String>,
85 include_unstaged: bool,
86 ) -> Result<Self> {
87 if from.is_some() && to.is_none() {
89 bail!("When using --from, you must also specify --to for branch comparison reviews");
90 }
91
92 if commit.is_some() && (from.is_some() || to.is_some()) {
94 bail!("Cannot use --commit with --from/--to. These are mutually exclusive options");
95 }
96
97 if include_unstaged && (from.is_some() || to.is_some()) {
99 bail!(
100 "Cannot use --include-unstaged with --from/--to. Branch comparison reviews don't include working directory changes"
101 );
102 }
103
104 Ok(match (commit, from, to) {
106 (Some(id), _, _) => Self::Commit { commit_id: id },
107 (_, Some(f), Some(t)) => Self::Range { from: f, to: t },
108 _ => Self::Staged { include_unstaged },
109 })
110 }
111
112 pub fn for_pr(from: Option<String>, to: Option<String>) -> Self {
120 match (from, to) {
121 (Some(f), Some(t)) => Self::Range { from: f, to: t },
122 (Some(f), None) => Self::Range {
123 from: f,
124 to: "HEAD".to_string(),
125 },
126 (None, Some(t)) => Self::Range {
127 from: "main".to_string(),
128 to: t,
129 },
130 (None, None) => Self::Range {
131 from: "main".to_string(),
132 to: "HEAD".to_string(),
133 },
134 }
135 }
136
137 pub fn for_changelog(
142 from: String,
143 to: Option<String>,
144 version_name: Option<String>,
145 date: Option<String>,
146 ) -> Self {
147 Self::Changelog {
148 from,
149 to: to.unwrap_or_else(|| "HEAD".to_string()),
150 version_name,
151 date: date.unwrap_or_else(|| chrono::Local::now().format("%Y-%m-%d").to_string()),
152 }
153 }
154
155 pub fn to_prompt_context(&self) -> String {
157 serde_json::to_string_pretty(self).unwrap_or_else(|_| format!("{self:?}"))
158 }
159
160 pub fn diff_hint(&self) -> String {
162 match self {
163 Self::Staged { include_unstaged } => {
164 if *include_unstaged {
165 "git_diff() for staged changes, then check unstaged files".to_string()
166 } else {
167 "git_diff() for staged changes".to_string()
168 }
169 }
170 Self::Commit { commit_id } => {
171 format!("git_diff(from=\"{commit_id}^1\", to=\"{commit_id}\")")
172 }
173 Self::Range { from, to } | Self::Changelog { from, to, .. } => {
174 format!("git_diff(from=\"{from}\", to=\"{to}\")")
175 }
176 Self::Amend { .. } => {
177 "git_diff(from=\"HEAD^1\") for combined amend diff (original commit + new staged changes)".to_string()
178 }
179 Self::Discover => "git_diff() to discover current changes".to_string(),
180 }
181 }
182
183 pub fn is_range(&self) -> bool {
185 matches!(self, Self::Range { .. })
186 }
187
188 pub fn includes_unstaged(&self) -> bool {
190 matches!(
191 self,
192 Self::Staged {
193 include_unstaged: true
194 }
195 )
196 }
197
198 pub fn is_amend(&self) -> bool {
200 matches!(self, Self::Amend { .. })
201 }
202
203 pub fn original_message(&self) -> Option<&str> {
205 match self {
206 Self::Amend { original_message } => Some(original_message),
207 _ => None,
208 }
209 }
210}
211
212impl std::fmt::Display for TaskContext {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 match self {
215 Self::Staged { include_unstaged } => {
216 if *include_unstaged {
217 write!(f, "staged and unstaged changes")
218 } else {
219 write!(f, "staged changes")
220 }
221 }
222 Self::Commit { commit_id } => write!(f, "commit {commit_id}"),
223 Self::Range { from, to } => write!(f, "changes from {from} to {to}"),
224 Self::Changelog {
225 from,
226 to,
227 version_name,
228 date,
229 } => {
230 let version_str = version_name
231 .as_ref()
232 .map_or_else(|| "unreleased".to_string(), |v| format!("v{v}"));
233 write!(f, "changelog {version_str} ({date}) from {from} to {to}")
234 }
235 Self::Amend { .. } => write!(f, "amending previous commit"),
236 Self::Discover => write!(f, "auto-discovered changes"),
237 }
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a string literal into the owned `Option<String>` the constructors take.
    fn opt(value: &str) -> Option<String> {
        Some(value.to_owned())
    }

    /// Returns the `include_unstaged` flag of a `Staged` context; panics on any other variant.
    fn staged_flag(ctx: TaskContext) -> bool {
        match ctx {
            TaskContext::Staged { include_unstaged } => include_unstaged,
            other => panic!("expected Staged variant, got {other:?}"),
        }
    }

    /// Returns the `(from, to)` pair of a `Range` context; panics on any other variant.
    fn range_bounds(ctx: TaskContext) -> (String, String) {
        match ctx {
            TaskContext::Range { from, to } => (from, to),
            other => panic!("expected Range variant, got {other:?}"),
        }
    }

    #[test]
    fn test_for_gen() {
        assert!(!staged_flag(TaskContext::for_gen()));
    }

    #[test]
    fn test_review_staged_only() {
        let ctx = TaskContext::for_review(None, None, None, false).expect("should succeed");
        assert!(!staged_flag(ctx));
    }

    #[test]
    fn test_review_with_unstaged() {
        let ctx = TaskContext::for_review(None, None, None, true).expect("should succeed");
        assert!(staged_flag(ctx));
    }

    #[test]
    fn test_review_single_commit() {
        let ctx =
            TaskContext::for_review(opt("abc123"), None, None, false).expect("should succeed");
        match ctx {
            TaskContext::Commit { commit_id } => assert_eq!(commit_id, "abc123"),
            other => panic!("expected Commit variant, got {other:?}"),
        }
    }

    #[test]
    fn test_review_range() {
        let ctx = TaskContext::for_review(None, opt("main"), opt("feature"), false)
            .expect("should succeed");
        let (from, to) = range_bounds(ctx);
        assert_eq!(from, "main");
        assert_eq!(to, "feature");
    }

    #[test]
    fn test_review_from_without_to_fails() {
        let err = TaskContext::for_review(None, opt("main"), None, false)
            .expect_err("--from without --to should be rejected");
        assert!(err.to_string().contains("--to"));
    }

    #[test]
    fn test_review_commit_with_range_fails() {
        // A commit and a ref range cannot be requested together.
        let err = TaskContext::for_review(opt("abc123"), opt("main"), opt("feature"), false)
            .expect_err("--commit with --from/--to should be rejected");
        assert!(err.to_string().contains("mutually exclusive"));
    }

    #[test]
    fn test_review_unstaged_with_range_fails() {
        let err = TaskContext::for_review(None, opt("main"), opt("feature"), true)
            .expect_err("--include-unstaged with --from/--to should be rejected");
        assert!(err.to_string().contains("include-unstaged"));
    }

    #[test]
    fn test_pr_defaults() {
        let (from, to) = range_bounds(TaskContext::for_pr(None, None));
        assert_eq!(from, "main");
        assert_eq!(to, "HEAD");
    }

    #[test]
    fn test_pr_from_only() {
        let (from, to) = range_bounds(TaskContext::for_pr(opt("develop"), None));
        assert_eq!(from, "develop");
        assert_eq!(to, "HEAD");
    }

    #[test]
    fn test_changelog() {
        let ctx = TaskContext::for_changelog(
            "v1.0.0".to_owned(),
            None,
            opt("1.1.0"),
            opt("2025-01-15"),
        );
        match ctx {
            TaskContext::Changelog {
                from,
                to,
                version_name,
                date,
            } => {
                assert_eq!(from, "v1.0.0");
                assert_eq!(to, "HEAD");
                assert_eq!(version_name.as_deref(), Some("1.1.0"));
                assert_eq!(date, "2025-01-15");
            }
            other => panic!("expected Changelog variant, got {other:?}"),
        }
    }

    #[test]
    fn test_changelog_default_date() {
        let ctx = TaskContext::for_changelog("v1.0.0".to_owned(), None, None, None);
        match ctx {
            TaskContext::Changelog { date, .. } => {
                // The defaulted date is only checked loosely: non-empty and dash-separated.
                assert!(!date.is_empty());
                assert!(date.contains('-'));
            }
            _ => panic!("Expected Changelog variant"),
        }
    }

    #[test]
    fn test_diff_hint() {
        let staged_hint = TaskContext::for_gen().diff_hint();
        assert!(staged_hint.contains("staged"));

        let commit_hint = TaskContext::Commit {
            commit_id: "abc".to_owned(),
        }
        .diff_hint();
        assert!(commit_hint.contains("abc^1"));

        let range_hint = TaskContext::Range {
            from: "main".to_owned(),
            to: "dev".to_owned(),
        }
        .diff_hint();
        assert!(range_hint.contains("main"));
        assert!(range_hint.contains("dev"));

        let amend_hint = TaskContext::for_amend("Fix bug".to_owned()).diff_hint();
        assert!(amend_hint.contains("HEAD^1"));
    }

    #[test]
    fn test_amend_context() {
        let ctx = TaskContext::for_amend("Initial commit message".to_owned());
        assert!(ctx.is_amend());
        assert_eq!(ctx.original_message(), Some("Initial commit message"));
        assert!(!ctx.is_range());
        assert!(!ctx.includes_unstaged());
    }

    #[test]
    fn test_amend_display() {
        let ctx = TaskContext::for_amend("Fix bug".to_owned());
        assert_eq!(ctx.to_string(), "amending previous commit");
    }
}