Skip to main content

git_iris/agents/
context.rs

1//! Task context for agent operations
2//!
3//! This module provides structured, validated context for agent tasks,
4//! replacing fragile string-based parameter passing.
5
6use anyhow::{Result, bail};
7use serde::{Deserialize, Serialize};
8
9/// Validated, structured context for agent tasks.
10///
11/// This enum represents the different modes of operation for code analysis,
12/// with validation built into the constructors.
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14#[serde(tag = "mode", rename_all = "snake_case")]
15pub enum TaskContext {
16    /// Analyze staged changes (optionally including unstaged)
17    Staged {
18        /// Whether to include unstaged changes in analysis
19        include_unstaged: bool,
20    },
21
22    /// Analyze a single commit
23    Commit {
24        /// The commit ID (hash, branch name, or commitish like HEAD~1)
25        commit_id: String,
26    },
27
28    /// Analyze a range of commits or branch comparison
29    Range {
30        /// Starting reference (exclusive)
31        from: String,
32        /// Ending reference (inclusive)
33        to: String,
34    },
35
36    /// Generate changelog or release notes with version metadata
37    Changelog {
38        /// Starting reference (exclusive)
39        from: String,
40        /// Ending reference (inclusive)
41        to: String,
42        /// Explicit version name (e.g., "1.2.0")
43        version_name: Option<String>,
44        /// Release date in YYYY-MM-DD format
45        date: String,
46    },
47
48    /// Amend the previous commit with staged changes
49    /// The agent sees the combined diff from HEAD^1 to staged state
50    Amend {
51        /// The original commit message being amended
52        original_message: String,
53    },
54
55    /// Let the agent discover context via tools (default for gen command)
56    #[default]
57    Discover,
58}
59
60impl TaskContext {
61    /// Create context for the gen (commit message) command.
62    /// Always uses staged changes only.
63    pub fn for_gen() -> Self {
64        Self::Staged {
65            include_unstaged: false,
66        }
67    }
68
69    /// Create context for amending the previous commit.
70    /// The agent will see the combined diff from HEAD^1 to staged state.
71    pub fn for_amend(original_message: String) -> Self {
72        Self::Amend { original_message }
73    }
74
75    /// Create context for the review command with full parameter validation.
76    ///
77    /// Validates:
78    /// - `--from` requires `--to` for range comparison
79    /// - `--commit` is mutually exclusive with `--from/--to`
80    /// - `--include-unstaged` is incompatible with range comparisons
81    pub fn for_review(
82        commit: Option<String>,
83        from: Option<String>,
84        to: Option<String>,
85        include_unstaged: bool,
86    ) -> Result<Self> {
87        // Validate: --from requires --to
88        if from.is_some() && to.is_none() {
89            bail!("When using --from, you must also specify --to for branch comparison reviews");
90        }
91
92        // Validate: --commit is mutually exclusive with --from/--to
93        if commit.is_some() && (from.is_some() || to.is_some()) {
94            bail!("Cannot use --commit with --from/--to. These are mutually exclusive options");
95        }
96
97        // Validate: --include-unstaged incompatible with range comparisons
98        if include_unstaged && (from.is_some() || to.is_some()) {
99            bail!(
100                "Cannot use --include-unstaged with --from/--to. Branch comparison reviews don't include working directory changes"
101            );
102        }
103
104        // Route to correct variant based on parameters
105        Ok(match (commit, from, to) {
106            (Some(id), _, _) => Self::Commit { commit_id: id },
107            (_, Some(f), Some(t)) => Self::Range { from: f, to: t },
108            _ => Self::Staged { include_unstaged },
109        })
110    }
111
112    /// Create context for the PR command.
113    ///
114    /// PR command is more flexible - all parameter combinations are valid:
115    /// - `from` + `to`: Explicit range/branch comparison
116    /// - `from` only: Compare `from..HEAD`
117    /// - `to` only: Compare `main..to`
118    /// - Neither: Compare `main..HEAD`
119    pub fn for_pr(from: Option<String>, to: Option<String>) -> Self {
120        match (from, to) {
121            (Some(f), Some(t)) => Self::Range { from: f, to: t },
122            (Some(f), None) => Self::Range {
123                from: f,
124                to: "HEAD".to_string(),
125            },
126            (None, Some(t)) => Self::Range {
127                from: "main".to_string(),
128                to: t,
129            },
130            (None, None) => Self::Range {
131                from: "main".to_string(),
132                to: "HEAD".to_string(),
133            },
134        }
135    }
136
137    /// Create context for changelog/release-notes commands.
138    ///
139    /// These always require a `from` reference; `to` defaults to HEAD.
140    /// Automatically sets today's date if not provided.
141    pub fn for_changelog(
142        from: String,
143        to: Option<String>,
144        version_name: Option<String>,
145        date: Option<String>,
146    ) -> Self {
147        Self::Changelog {
148            from,
149            to: to.unwrap_or_else(|| "HEAD".to_string()),
150            version_name,
151            date: date.unwrap_or_else(|| chrono::Local::now().format("%Y-%m-%d").to_string()),
152        }
153    }
154
155    /// Generate a human-readable prompt context string for the agent.
156    pub fn to_prompt_context(&self) -> String {
157        serde_json::to_string_pretty(self).unwrap_or_else(|_| format!("{self:?}"))
158    }
159
160    /// Generate a hint for which `git_diff` call the agent should make.
161    pub fn diff_hint(&self) -> String {
162        match self {
163            Self::Staged { include_unstaged } => {
164                if *include_unstaged {
165                    "git_diff() for staged changes, then check unstaged files".to_string()
166                } else {
167                    "git_diff() for staged changes".to_string()
168                }
169            }
170            Self::Commit { commit_id } => {
171                format!("git_diff(from=\"{commit_id}^1\", to=\"{commit_id}\")")
172            }
173            Self::Range { from, to } | Self::Changelog { from, to, .. } => {
174                format!("git_diff(from=\"{from}\", to=\"{to}\")")
175            }
176            Self::Amend { .. } => {
177                "git_diff(from=\"HEAD^1\") for combined amend diff (original commit + new staged changes)".to_string()
178            }
179            Self::Discover => "git_diff() to discover current changes".to_string(),
180        }
181    }
182
183    /// Check if this context represents a range comparison (vs staged/single commit)
184    pub fn is_range(&self) -> bool {
185        matches!(self, Self::Range { .. })
186    }
187
188    /// Check if this context involves unstaged changes
189    pub fn includes_unstaged(&self) -> bool {
190        matches!(
191            self,
192            Self::Staged {
193                include_unstaged: true
194            }
195        )
196    }
197
198    /// Check if this is an amend operation
199    pub fn is_amend(&self) -> bool {
200        matches!(self, Self::Amend { .. })
201    }
202
203    /// Get the original commit message if this is an amend context
204    pub fn original_message(&self) -> Option<&str> {
205        match self {
206            Self::Amend { original_message } => Some(original_message),
207            _ => None,
208        }
209    }
210}
211
212impl std::fmt::Display for TaskContext {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        match self {
215            Self::Staged { include_unstaged } => {
216                if *include_unstaged {
217                    write!(f, "staged and unstaged changes")
218                } else {
219                    write!(f, "staged changes")
220                }
221            }
222            Self::Commit { commit_id } => write!(f, "commit {commit_id}"),
223            Self::Range { from, to } => write!(f, "changes from {from} to {to}"),
224            Self::Changelog {
225                from,
226                to,
227                version_name,
228                date,
229            } => {
230                let version_str = version_name
231                    .as_ref()
232                    .map_or_else(|| "unreleased".to_string(), |v| format!("v{v}"));
233                write!(f, "changelog {version_str} ({date}) from {from} to {to}")
234            }
235            Self::Amend { .. } => write!(f, "amending previous commit"),
236            Self::Discover => write!(f, "auto-discovered changes"),
237        }
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_for_gen() {
        let ctx = TaskContext::for_gen();
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: false
            }
        ));
    }

    #[test]
    fn test_default_is_discover() {
        let ctx = TaskContext::default();
        assert!(matches!(ctx, TaskContext::Discover));
        assert_eq!(ctx.to_string(), "auto-discovered changes");
    }

    #[test]
    fn test_review_staged_only() {
        let ctx = TaskContext::for_review(None, None, None, false).expect("should succeed");
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: false
            }
        ));
    }

    #[test]
    fn test_review_with_unstaged() {
        let ctx = TaskContext::for_review(None, None, None, true).expect("should succeed");
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: true
            }
        ));
    }

    #[test]
    fn test_review_single_commit() {
        let ctx = TaskContext::for_review(Some("abc123".to_string()), None, None, false)
            .expect("should succeed");
        assert!(matches!(ctx, TaskContext::Commit { commit_id } if commit_id == "abc123"));
    }

    #[test]
    fn test_review_range() {
        let ctx = TaskContext::for_review(
            None,
            Some("main".to_string()),
            Some("feature".to_string()),
            false,
        )
        .expect("should succeed");
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "main" && to == "feature")
        );
    }

    #[test]
    fn test_review_from_without_to_fails() {
        let err = TaskContext::for_review(None, Some("main".to_string()), None, false)
            .expect_err("--from without --to should be rejected");
        assert!(err.to_string().contains("--to"));
    }

    #[test]
    fn test_review_commit_with_range_fails() {
        // commit + from + to should fail as mutually exclusive
        let err = TaskContext::for_review(
            Some("abc123".to_string()),
            Some("main".to_string()),
            Some("feature".to_string()),
            false,
        )
        .expect_err("--commit with a range should be rejected");
        assert!(err.to_string().contains("mutually exclusive"));
    }

    #[test]
    fn test_review_unstaged_with_range_fails() {
        let err = TaskContext::for_review(
            None,
            Some("main".to_string()),
            Some("feature".to_string()),
            true,
        )
        .expect_err("--include-unstaged with a range should be rejected");
        assert!(err.to_string().contains("include-unstaged"));
    }

    #[test]
    fn test_review_unstaged_with_to_only_fails() {
        // A lone --to still makes this a range review, so unstaged is incompatible.
        let err = TaskContext::for_review(None, None, Some("feature".to_string()), true)
            .expect_err("--include-unstaged with --to should be rejected");
        assert!(err.to_string().contains("include-unstaged"));
    }

    #[test]
    fn test_pr_defaults() {
        let ctx = TaskContext::for_pr(None, None);
        assert!(matches!(ctx, TaskContext::Range { from, to } if from == "main" && to == "HEAD"));
    }

    #[test]
    fn test_pr_from_only() {
        let ctx = TaskContext::for_pr(Some("develop".to_string()), None);
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "develop" && to == "HEAD")
        );
    }

    #[test]
    fn test_pr_to_only() {
        let ctx = TaskContext::for_pr(None, Some("feature".to_string()));
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "main" && to == "feature")
        );
    }

    #[test]
    fn test_pr_explicit_range() {
        let ctx = TaskContext::for_pr(Some("develop".to_string()), Some("feature".to_string()));
        assert!(ctx.is_range());
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "develop" && to == "feature")
        );
    }

    #[test]
    fn test_changelog() {
        let ctx = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            Some("1.1.0".to_string()),
            Some("2025-01-15".to_string()),
        );
        assert!(matches!(
            ctx,
            TaskContext::Changelog { from, to, version_name, date }
                if from == "v1.0.0" && to == "HEAD"
                && version_name == Some("1.1.0".to_string())
                && date == "2025-01-15"
        ));
    }

    #[test]
    fn test_changelog_default_date() {
        let ctx = TaskContext::for_changelog("v1.0.0".to_string(), None, None, None);
        let TaskContext::Changelog { date, .. } = ctx else {
            panic!("Expected Changelog variant");
        };
        // Must be a real calendar date in zero-padded YYYY-MM-DD form, not merely
        // any string containing a dash.
        assert_eq!(date.len(), 10, "unexpected default date: {date}");
        assert!(
            chrono::NaiveDate::parse_from_str(&date, "%Y-%m-%d").is_ok(),
            "unexpected default date: {date}"
        );
    }

    #[test]
    fn test_changelog_display() {
        let released = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            Some("v1.1.0".to_string()),
            Some("1.1.0".to_string()),
            Some("2025-01-15".to_string()),
        );
        assert_eq!(
            released.to_string(),
            "changelog v1.1.0 (2025-01-15) from v1.0.0 to v1.1.0"
        );

        let unreleased = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            None,
            Some("2025-01-15".to_string()),
        );
        assert_eq!(
            unreleased.to_string(),
            "changelog unreleased (2025-01-15) from v1.0.0 to HEAD"
        );
    }

    #[test]
    fn test_prompt_context_uses_mode_tag() {
        let ctx = TaskContext::for_pr(Some("main".to_string()), Some("dev".to_string()));
        let json: serde_json::Value = serde_json::from_str(&ctx.to_prompt_context())
            .expect("prompt context should be valid JSON");
        assert_eq!(json["mode"], "range");
        assert_eq!(json["from"], "main");
        assert_eq!(json["to"], "dev");

        // The tagged representation must round-trip back to the same variant.
        let round_trip: TaskContext =
            serde_json::from_value(json).expect("should deserialize back");
        assert!(
            matches!(round_trip, TaskContext::Range { from, to } if from == "main" && to == "dev")
        );
    }

    #[test]
    fn test_diff_hint() {
        let staged = TaskContext::for_gen();
        assert!(staged.diff_hint().contains("staged"));

        let commit = TaskContext::Commit {
            commit_id: "abc".to_string(),
        };
        assert!(commit.diff_hint().contains("abc^1"));

        let range = TaskContext::Range {
            from: "main".to_string(),
            to: "dev".to_string(),
        };
        assert!(range.diff_hint().contains("main"));
        assert!(range.diff_hint().contains("dev"));

        let amend = TaskContext::for_amend("Fix bug".to_string());
        assert!(amend.diff_hint().contains("HEAD^1"));
    }

    #[test]
    fn test_amend_context() {
        let ctx = TaskContext::for_amend("Initial commit message".to_string());
        assert!(ctx.is_amend());
        assert_eq!(ctx.original_message(), Some("Initial commit message"));
        assert!(!ctx.is_range());
        assert!(!ctx.includes_unstaged());
    }

    #[test]
    fn test_amend_display() {
        let ctx = TaskContext::for_amend("Fix bug".to_string());
        assert_eq!(format!("{ctx}"), "amending previous commit");
    }
}