git_iris/agents/context.rs

//! Task context for agent operations
//!
//! This module provides structured, validated context for agent tasks,
//! replacing fragile string-based parameter passing.
5
6use anyhow::{Result, bail};
7use serde::{Deserialize, Serialize};
8
9/// Validated, structured context for agent tasks.
10///
11/// This enum represents the different modes of operation for code analysis,
12/// with validation built into the constructors.
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14#[serde(tag = "mode", rename_all = "snake_case")]
15pub enum TaskContext {
16    /// Analyze staged changes (optionally including unstaged)
17    Staged {
18        /// Whether to include unstaged changes in analysis
19        include_unstaged: bool,
20    },
21
22    /// Analyze a single commit
23    Commit {
24        /// The commit ID (hash, branch name, or commitish like HEAD~1)
25        commit_id: String,
26    },
27
28    /// Analyze a range of commits or branch comparison
29    Range {
30        /// Starting reference (exclusive)
31        from: String,
32        /// Ending reference (inclusive)
33        to: String,
34    },
35
36    /// Generate changelog or release notes with version metadata
37    Changelog {
38        /// Starting reference (exclusive)
39        from: String,
40        /// Ending reference (inclusive)
41        to: String,
42        /// Explicit version name (e.g., "1.2.0")
43        version_name: Option<String>,
44        /// Release date in YYYY-MM-DD format
45        date: String,
46    },
47
48    /// Amend the previous commit with staged changes
49    /// The agent sees the combined diff from HEAD^1 to staged state
50    Amend {
51        /// The original commit message being amended
52        original_message: String,
53    },
54
55    /// Let the agent discover context via tools (default for gen command)
56    #[default]
57    Discover,
58}
59
60impl TaskContext {
61    /// Create context for the gen (commit message) command.
62    /// Always uses staged changes only.
63    pub fn for_gen() -> Self {
64        Self::Staged {
65            include_unstaged: false,
66        }
67    }
68
69    /// Create context for amending the previous commit.
70    /// The agent will see the combined diff from HEAD^1 to staged state.
71    pub fn for_amend(original_message: String) -> Self {
72        Self::Amend { original_message }
73    }
74
75    /// Create context for the review command with full parameter validation.
76    ///
77    /// Validates:
78    /// - `--from` requires `--to` for range comparison
79    /// - `--commit` is mutually exclusive with `--from/--to`
80    /// - `--include-unstaged` is incompatible with range comparisons
81    pub fn for_review(
82        commit: Option<String>,
83        from: Option<String>,
84        to: Option<String>,
85        include_unstaged: bool,
86    ) -> Result<Self> {
87        // Validate: --from requires --to
88        if from.is_some() && to.is_none() {
89            bail!("When using --from, you must also specify --to for branch comparison reviews");
90        }
91
92        // Validate: --commit is mutually exclusive with --from/--to
93        if commit.is_some() && (from.is_some() || to.is_some()) {
94            bail!("Cannot use --commit with --from/--to. These are mutually exclusive options");
95        }
96
97        // Validate: --include-unstaged incompatible with range comparisons
98        if include_unstaged && (from.is_some() || to.is_some()) {
99            bail!(
100                "Cannot use --include-unstaged with --from/--to. Branch comparison reviews don't include working directory changes"
101            );
102        }
103
104        // Route to correct variant based on parameters
105        Ok(match (commit, from, to) {
106            (Some(id), _, _) => Self::Commit { commit_id: id },
107            (_, Some(f), Some(t)) => Self::Range { from: f, to: t },
108            _ => Self::Staged { include_unstaged },
109        })
110    }
111
112    /// Create context for the PR command.
113    ///
114    /// PR command is more flexible - all parameter combinations are valid:
115    /// - `from` + `to`: Explicit range/branch comparison
116    /// - `from` only: Compare `from..HEAD`
117    /// - `to` only: Compare `main..to`
118    /// - Neither: Compare `main..HEAD`
119    pub fn for_pr(from: Option<String>, to: Option<String>) -> Self {
120        match (from, to) {
121            (Some(f), Some(t)) => Self::Range { from: f, to: t },
122            (Some(f), None) => Self::Range {
123                from: f,
124                to: "HEAD".to_string(),
125            },
126            (None, Some(t)) => Self::Range {
127                from: "main".to_string(),
128                to: t,
129            },
130            (None, None) => Self::Range {
131                from: "main".to_string(),
132                to: "HEAD".to_string(),
133            },
134        }
135    }
136
137    /// Create context for changelog/release-notes commands.
138    ///
139    /// These always require a `from` reference; `to` defaults to HEAD.
140    /// Automatically sets today's date if not provided.
141    pub fn for_changelog(
142        from: String,
143        to: Option<String>,
144        version_name: Option<String>,
145        date: Option<String>,
146    ) -> Self {
147        Self::Changelog {
148            from,
149            to: to.unwrap_or_else(|| "HEAD".to_string()),
150            version_name,
151            date: date.unwrap_or_else(|| chrono::Local::now().format("%Y-%m-%d").to_string()),
152        }
153    }
154
155    /// Generate a human-readable prompt context string for the agent.
156    pub fn to_prompt_context(&self) -> String {
157        serde_json::to_string_pretty(self).unwrap_or_else(|_| format!("{self:?}"))
158    }
159
160    /// Generate a hint for which `git_diff` call the agent should make.
161    pub fn diff_hint(&self) -> String {
162        match self {
163            Self::Staged { include_unstaged } => {
164                if *include_unstaged {
165                    "git_diff() for staged changes, then check unstaged files".to_string()
166                } else {
167                    "git_diff() for staged changes".to_string()
168                }
169            }
170            Self::Commit { commit_id } => {
171                format!("git_diff(from=\"{commit_id}^1\", to=\"{commit_id}\")")
172            }
173            Self::Range { from, to } | Self::Changelog { from, to, .. } => {
174                format!("git_diff(from=\"{from}\", to=\"{to}\")")
175            }
176            Self::Amend { .. } => {
177                "git_diff(from=\"HEAD^1\") for combined amend diff (original commit + new staged changes)".to_string()
178            }
179            Self::Discover => "git_diff() to discover current changes".to_string(),
180        }
181    }
182
183    /// Check if this context represents a range comparison (vs staged/single commit)
184    pub fn is_range(&self) -> bool {
185        matches!(self, Self::Range { .. })
186    }
187
188    /// Check if this context involves unstaged changes
189    pub fn includes_unstaged(&self) -> bool {
190        matches!(
191            self,
192            Self::Staged {
193                include_unstaged: true
194            }
195        )
196    }
197
198    /// Check if this is an amend operation
199    pub fn is_amend(&self) -> bool {
200        matches!(self, Self::Amend { .. })
201    }
202
203    /// Get the original commit message if this is an amend context
204    pub fn original_message(&self) -> Option<&str> {
205        match self {
206            Self::Amend { original_message } => Some(original_message),
207            _ => None,
208        }
209    }
210}
211
212impl std::fmt::Display for TaskContext {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        match self {
215            Self::Staged { include_unstaged } => {
216                if *include_unstaged {
217                    write!(f, "staged and unstaged changes")
218                } else {
219                    write!(f, "staged changes")
220                }
221            }
222            Self::Commit { commit_id } => write!(f, "commit {commit_id}"),
223            Self::Range { from, to } => write!(f, "changes from {from} to {to}"),
224            Self::Changelog {
225                from,
226                to,
227                version_name,
228                date,
229            } => {
230                let version_str = version_name
231                    .as_ref()
232                    .map_or_else(|| "unreleased".to_string(), |v| format!("v{v}"));
233                write!(f, "changelog {version_str} ({date}) from {from} to {to}")
234            }
235            Self::Amend { .. } => write!(f, "amending previous commit"),
236            Self::Discover => write!(f, "auto-discovered changes"),
237        }
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_for_gen() {
        let ctx = TaskContext::for_gen();
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: false
            }
        ));
        assert!(!ctx.includes_unstaged());
    }

    #[test]
    fn test_default_is_discover() {
        assert!(matches!(TaskContext::default(), TaskContext::Discover));
    }

    #[test]
    fn test_review_staged_only() {
        let ctx = TaskContext::for_review(None, None, None, false).expect("should succeed");
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: false
            }
        ));
    }

    #[test]
    fn test_review_with_unstaged() {
        let ctx = TaskContext::for_review(None, None, None, true).expect("should succeed");
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: true
            }
        ));
        assert!(ctx.includes_unstaged());
    }

    #[test]
    fn test_review_single_commit() {
        let ctx = TaskContext::for_review(Some("abc123".to_string()), None, None, false)
            .expect("should succeed");
        assert!(!ctx.is_range());
        assert!(matches!(ctx, TaskContext::Commit { commit_id } if commit_id == "abc123"));
    }

    #[test]
    fn test_review_range() {
        let ctx = TaskContext::for_review(
            None,
            Some("main".to_string()),
            Some("feature".to_string()),
            false,
        )
        .expect("should succeed");
        assert!(ctx.is_range());
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "main" && to == "feature")
        );
    }

    #[test]
    fn test_review_from_without_to_fails() {
        let result = TaskContext::for_review(None, Some("main".to_string()), None, false);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("--to"));
    }

    #[test]
    fn test_review_commit_with_range_fails() {
        // commit + from + to should fail as mutually exclusive
        let result = TaskContext::for_review(
            Some("abc123".to_string()),
            Some("main".to_string()),
            Some("feature".to_string()),
            false,
        );
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("mutually exclusive")
        );
    }

    #[test]
    fn test_review_unstaged_with_range_fails() {
        let result = TaskContext::for_review(
            None,
            Some("main".to_string()),
            Some("feature".to_string()),
            true,
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("include-unstaged"));
    }

    #[test]
    fn test_pr_defaults() {
        let ctx = TaskContext::for_pr(None, None);
        assert!(matches!(ctx, TaskContext::Range { from, to } if from == "main" && to == "HEAD"));
    }

    #[test]
    fn test_pr_from_only() {
        let ctx = TaskContext::for_pr(Some("develop".to_string()), None);
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "develop" && to == "HEAD")
        );
    }

    #[test]
    fn test_pr_to_only() {
        let ctx = TaskContext::for_pr(None, Some("feature".to_string()));
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "main" && to == "feature")
        );
    }

    #[test]
    fn test_pr_explicit_range() {
        let ctx = TaskContext::for_pr(Some("v1".to_string()), Some("v2".to_string()));
        assert!(matches!(ctx, TaskContext::Range { from, to } if from == "v1" && to == "v2"));
    }

    #[test]
    fn test_changelog() {
        let ctx = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            Some("1.1.0".to_string()),
            Some("2025-01-15".to_string()),
        );
        assert!(matches!(
            ctx,
            TaskContext::Changelog { from, to, version_name, date }
                if from == "v1.0.0" && to == "HEAD"
                && version_name.as_deref() == Some("1.1.0")
                && date == "2025-01-15"
        ));
    }

    #[test]
    fn test_changelog_default_date() {
        let ctx = TaskContext::for_changelog("v1.0.0".to_string(), None, None, None);
        // Should use today's date as a real calendar date in YYYY-MM-DD format
        if let TaskContext::Changelog { date, .. } = ctx {
            assert_eq!(date.len(), 10, "unexpected date format: {date}");
            assert!(
                chrono::NaiveDate::parse_from_str(&date, "%Y-%m-%d").is_ok(),
                "not a YYYY-MM-DD date: {date}"
            );
        } else {
            panic!("Expected Changelog variant");
        }
    }

    #[test]
    fn test_diff_hint() {
        let staged = TaskContext::for_gen();
        assert!(staged.diff_hint().contains("staged"));

        let with_unstaged = TaskContext::Staged {
            include_unstaged: true,
        };
        assert!(with_unstaged.diff_hint().contains("unstaged"));

        let commit = TaskContext::Commit {
            commit_id: "abc".to_string(),
        };
        assert!(commit.diff_hint().contains("abc^1"));

        let range = TaskContext::Range {
            from: "main".to_string(),
            to: "dev".to_string(),
        };
        assert!(range.diff_hint().contains("main"));
        assert!(range.diff_hint().contains("dev"));

        let changelog = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            None,
            Some("2025-01-15".to_string()),
        );
        assert_eq!(
            changelog.diff_hint(),
            "git_diff(from=\"v1.0.0\", to=\"HEAD\")"
        );

        let amend = TaskContext::for_amend("Fix bug".to_string());
        assert!(amend.diff_hint().contains("HEAD^1"));

        assert!(TaskContext::Discover.diff_hint().contains("discover"));
    }

    #[test]
    fn test_amend_context() {
        let ctx = TaskContext::for_amend("Initial commit message".to_string());
        assert!(ctx.is_amend());
        assert_eq!(ctx.original_message(), Some("Initial commit message"));
        assert!(!ctx.is_range());
        assert!(!ctx.includes_unstaged());
    }

    #[test]
    fn test_non_amend_has_no_original_message() {
        let ctx = TaskContext::for_gen();
        assert!(!ctx.is_amend());
        assert_eq!(ctx.original_message(), None);
    }

    #[test]
    fn test_amend_display() {
        let ctx = TaskContext::for_amend("Fix bug".to_string());
        assert_eq!(format!("{ctx}"), "amending previous commit");
    }

    #[test]
    fn test_display_variants() {
        assert_eq!(TaskContext::for_gen().to_string(), "staged changes");
        assert_eq!(
            TaskContext::Staged {
                include_unstaged: true
            }
            .to_string(),
            "staged and unstaged changes"
        );
        assert_eq!(
            TaskContext::Commit {
                commit_id: "abc".to_string()
            }
            .to_string(),
            "commit abc"
        );
        assert_eq!(
            TaskContext::for_pr(None, None).to_string(),
            "changes from main to HEAD"
        );
        assert_eq!(
            TaskContext::for_changelog(
                "v1.0.0".to_string(),
                None,
                Some("1.1.0".to_string()),
                Some("2025-01-15".to_string()),
            )
            .to_string(),
            "changelog v1.1.0 (2025-01-15) from v1.0.0 to HEAD"
        );
        assert_eq!(
            TaskContext::for_changelog(
                "v1.0.0".to_string(),
                None,
                None,
                Some("2025-01-15".to_string()),
            )
            .to_string(),
            "changelog unreleased (2025-01-15) from v1.0.0 to HEAD"
        );
        assert_eq!(TaskContext::Discover.to_string(), "auto-discovered changes");
    }

    #[test]
    fn test_prompt_context_uses_mode_tag_and_round_trips() {
        let prompt = TaskContext::for_gen().to_prompt_context();
        assert!(prompt.contains("\"mode\": \"staged\""), "got: {prompt}");
        let back: TaskContext = serde_json::from_str(&prompt).expect("should deserialize");
        assert!(matches!(
            back,
            TaskContext::Staged {
                include_unstaged: false
            }
        ));

        let commit = TaskContext::Commit {
            commit_id: "abc123".to_string(),
        };
        let value = serde_json::to_value(&commit).expect("should serialize");
        assert_eq!(value["mode"], "commit");
        assert_eq!(value["commit_id"], "abc123");

        let discover = serde_json::to_value(TaskContext::Discover).expect("should serialize");
        assert_eq!(discover["mode"], "discover");
    }
}