Skip to main content

git_iris/agents/
context.rs

//! Task context for agent operations
//!
//! This module provides structured, validated context for agent tasks,
//! replacing fragile string-based parameter passing.
5
6use anyhow::{Result, bail};
7use serde::{Deserialize, Serialize};
8
9/// Validated, structured context for agent tasks.
10///
11/// This enum represents the different modes of operation for code analysis,
12/// with validation built into the constructors.
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14#[serde(tag = "mode", rename_all = "snake_case")]
15pub enum TaskContext {
16    /// Analyze staged changes (optionally including unstaged)
17    Staged {
18        /// Whether to include unstaged changes in analysis
19        include_unstaged: bool,
20    },
21
22    /// Analyze a single commit
23    Commit {
24        /// The commit ID (hash, branch name, or commitish like HEAD~1)
25        commit_id: String,
26    },
27
28    /// Analyze a range of commits or branch comparison
29    Range {
30        /// Starting reference (exclusive)
31        from: String,
32        /// Ending reference (inclusive)
33        to: String,
34    },
35
36    /// Generate changelog or release notes with version metadata
37    Changelog {
38        /// Starting reference (exclusive)
39        from: String,
40        /// Ending reference (inclusive)
41        to: String,
42        /// Explicit version name (e.g., "1.2.0")
43        version_name: Option<String>,
44        /// Release date in YYYY-MM-DD format
45        date: String,
46    },
47
48    /// Amend the previous commit with staged changes
49    /// The agent sees the combined diff from HEAD^1 to staged state
50    Amend {
51        /// The original commit message being amended
52        original_message: String,
53    },
54
55    /// Let the agent discover context via tools (default for gen command)
56    #[default]
57    Discover,
58}
59
60impl TaskContext {
61    /// Create context for the gen (commit message) command.
62    /// Always uses staged changes only.
63    #[must_use]
64    pub fn for_gen() -> Self {
65        Self::Staged {
66            include_unstaged: false,
67        }
68    }
69
70    /// Create context for amending the previous commit.
71    /// The agent will see the combined diff from HEAD^1 to staged state.
72    #[must_use]
73    pub fn for_amend(original_message: String) -> Self {
74        Self::Amend { original_message }
75    }
76
77    /// Create context for the review command with full parameter validation.
78    ///
79    /// Validates:
80    /// - `--from` requires `--to` for explicit range comparison
81    /// - `--to` on its own compares `<fallback-base>..to`
82    /// - `--commit` is mutually exclusive with `--from/--to`
83    /// - `--include-unstaged` is incompatible with range comparisons
84    ///
85    /// # Errors
86    ///
87    /// Returns an error when the provided flag combination is invalid.
88    pub fn for_review(
89        commit: Option<String>,
90        from: Option<String>,
91        to: Option<String>,
92        include_unstaged: bool,
93    ) -> Result<Self> {
94        Self::for_review_with_base(commit, from, to, include_unstaged, "main")
95    }
96
97    /// Create review context with an explicit default base for `--to`-only comparisons.
98    ///
99    /// CLI and Studio should prefer a repo-aware base from `GitRepo::get_default_base_ref()`.
100    /// This lets branch comparisons follow the repository's actual primary branch
101    /// instead of relying on the legacy `"main"` fallback.
102    ///
103    /// # Errors
104    ///
105    /// Returns an error when the provided flag combination is invalid.
106    pub fn for_review_with_base(
107        commit: Option<String>,
108        from: Option<String>,
109        to: Option<String>,
110        include_unstaged: bool,
111        default_base: &str,
112    ) -> Result<Self> {
113        // Validate: --from requires --to
114        if from.is_some() && to.is_none() {
115            bail!("When using --from, you must also specify --to for branch comparison reviews");
116        }
117
118        // Validate: --commit is mutually exclusive with --from/--to
119        if commit.is_some() && (from.is_some() || to.is_some()) {
120            bail!("Cannot use --commit with --from/--to. These are mutually exclusive options");
121        }
122
123        // Validate: --include-unstaged incompatible with range comparisons
124        if include_unstaged && (from.is_some() || to.is_some()) {
125            bail!(
126                "Cannot use --include-unstaged with --from/--to. Branch comparison reviews don't include working directory changes"
127            );
128        }
129
130        // Route to correct variant based on parameters
131        Ok(match (commit, from, to) {
132            (Some(id), _, _) => Self::Commit { commit_id: id },
133            (_, Some(f), Some(t)) => Self::Range { from: f, to: t },
134            (None, None, Some(t)) => Self::Range {
135                from: default_base.to_string(),
136                to: t,
137            },
138            _ => Self::Staged { include_unstaged },
139        })
140    }
141
142    /// Create context for the PR command.
143    ///
144    /// PR command is more flexible - all parameter combinations are valid:
145    /// - `from` + `to`: Explicit range/branch comparison
146    /// - `from` only: Compare `from..HEAD`
147    /// - `to` only: Compare `<fallback-base>..to`
148    /// - Neither: Compare `<fallback-base>..HEAD`
149    #[must_use]
150    pub fn for_pr(from: Option<String>, to: Option<String>) -> Self {
151        Self::for_pr_with_base(from, to, "main")
152    }
153
154    /// Create PR context with an explicit default comparison base.
155    ///
156    /// CLI and Studio should prefer a repo-aware base from `GitRepo::get_default_base_ref()`.
157    #[must_use]
158    pub fn for_pr_with_base(from: Option<String>, to: Option<String>, default_base: &str) -> Self {
159        match (from, to) {
160            (Some(f), Some(t)) => Self::Range { from: f, to: t },
161            (Some(f), None) => Self::Range {
162                from: f,
163                to: "HEAD".to_string(),
164            },
165            (None, Some(t)) => Self::Range {
166                from: default_base.to_string(),
167                to: t,
168            },
169            (None, None) => Self::Range {
170                from: default_base.to_string(),
171                to: "HEAD".to_string(),
172            },
173        }
174    }
175
176    /// Create context for changelog/release-notes commands.
177    ///
178    /// These always require a `from` reference; `to` defaults to HEAD.
179    /// Automatically sets today's date if not provided.
180    #[must_use]
181    pub fn for_changelog(
182        from: String,
183        to: Option<String>,
184        version_name: Option<String>,
185        date: Option<String>,
186    ) -> Self {
187        Self::Changelog {
188            from,
189            to: to.unwrap_or_else(|| "HEAD".to_string()),
190            version_name,
191            date: date.unwrap_or_else(|| chrono::Local::now().format("%Y-%m-%d").to_string()),
192        }
193    }
194
195    /// Generate a human-readable prompt context string for the agent.
196    #[must_use]
197    pub fn to_prompt_context(&self) -> String {
198        serde_json::to_string_pretty(self).unwrap_or_else(|_| format!("{self:?}"))
199    }
200
201    /// Generate a hint for which `git_diff` call the agent should make.
202    #[must_use]
203    pub fn diff_hint(&self) -> String {
204        match self {
205            Self::Staged { include_unstaged } => {
206                if *include_unstaged {
207                    "git_diff() for staged changes, then check unstaged files".to_string()
208                } else {
209                    "git_diff() for staged changes".to_string()
210                }
211            }
212            Self::Commit { commit_id } => {
213                format!("git_diff(from=\"{commit_id}^1\", to=\"{commit_id}\")")
214            }
215            Self::Range { from, to } | Self::Changelog { from, to, .. } => {
216                format!("git_diff(from=\"{from}\", to=\"{to}\")")
217            }
218            Self::Amend { .. } => {
219                "git_diff(from=\"HEAD^1\") for combined amend diff (original commit + new staged changes)".to_string()
220            }
221            Self::Discover => "git_diff() to discover current changes".to_string(),
222        }
223    }
224
225    /// Check if this context represents a range comparison (vs staged/single commit)
226    #[must_use]
227    pub fn is_range(&self) -> bool {
228        matches!(self, Self::Range { .. })
229    }
230
231    /// Check if this context involves unstaged changes
232    #[must_use]
233    pub fn includes_unstaged(&self) -> bool {
234        matches!(
235            self,
236            Self::Staged {
237                include_unstaged: true
238            }
239        )
240    }
241
242    /// Check if this is an amend operation
243    #[must_use]
244    pub fn is_amend(&self) -> bool {
245        matches!(self, Self::Amend { .. })
246    }
247
248    /// Get the original commit message if this is an amend context
249    #[must_use]
250    pub fn original_message(&self) -> Option<&str> {
251        match self {
252            Self::Amend { original_message } => Some(original_message),
253            _ => None,
254        }
255    }
256}
257
258impl std::fmt::Display for TaskContext {
259    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260        match self {
261            Self::Staged { include_unstaged } => {
262                if *include_unstaged {
263                    write!(f, "staged and unstaged changes")
264                } else {
265                    write!(f, "staged changes")
266                }
267            }
268            Self::Commit { commit_id } => write!(f, "commit {commit_id}"),
269            Self::Range { from, to } => write!(f, "changes from {from} to {to}"),
270            Self::Changelog {
271                from,
272                to,
273                version_name,
274                date,
275            } => {
276                let version_str = version_name
277                    .as_ref()
278                    .map_or_else(|| "unreleased".to_string(), |v| format!("v{v}"));
279                write!(f, "changelog {version_str} ({date}) from {from} to {to}")
280            }
281            Self::Amend { .. } => write!(f, "amending previous commit"),
282            Self::Discover => write!(f, "auto-discovered changes"),
283        }
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_for_gen() {
        let ctx = TaskContext::for_gen();
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: false
            }
        ));
    }

    #[test]
    fn test_default_is_discover() {
        assert!(matches!(TaskContext::default(), TaskContext::Discover));
    }

    #[test]
    fn test_review_staged_only() {
        let ctx = TaskContext::for_review(None, None, None, false).expect("should succeed");
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: false
            }
        ));
    }

    #[test]
    fn test_review_with_unstaged() {
        let ctx = TaskContext::for_review(None, None, None, true).expect("should succeed");
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: true
            }
        ));
    }

    #[test]
    fn test_review_single_commit() {
        let ctx = TaskContext::for_review(Some("abc123".to_string()), None, None, false)
            .expect("should succeed");
        assert!(matches!(ctx, TaskContext::Commit { commit_id } if commit_id == "abc123"));
    }

    #[test]
    fn test_review_range() {
        let ctx = TaskContext::for_review(
            None,
            Some("main".to_string()),
            Some("feature".to_string()),
            false,
        )
        .expect("should succeed");
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "main" && to == "feature")
        );
    }

    #[test]
    fn test_review_to_only_defaults_from_explicit_base() {
        let ctx = TaskContext::for_review_with_base(
            None,
            None,
            Some("feature".to_string()),
            false,
            "trunk",
        )
        .expect("should succeed");
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "trunk" && to == "feature")
        );
    }

    #[test]
    fn test_review_to_only_uses_main_fallback() {
        // `for_review` delegates with the legacy "main" base.
        let ctx = TaskContext::for_review(None, None, Some("feature".to_string()), false)
            .expect("should succeed");
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "main" && to == "feature")
        );
    }

    #[test]
    fn test_review_from_without_to_fails() {
        let result = TaskContext::for_review(None, Some("main".to_string()), None, false);
        assert!(result.is_err());
        assert!(
            result
                .expect_err("should be err")
                .to_string()
                .contains("--to")
        );
    }

    #[test]
    fn test_review_commit_with_range_fails() {
        // commit + from + to should fail as mutually exclusive
        let result = TaskContext::for_review(
            Some("abc123".to_string()),
            Some("main".to_string()),
            Some("feature".to_string()),
            false,
        );
        assert!(result.is_err());
        assert!(
            result
                .expect_err("should be err")
                .to_string()
                .contains("mutually exclusive")
        );
    }

    #[test]
    fn test_review_unstaged_with_range_fails() {
        let result = TaskContext::for_review(
            None,
            Some("main".to_string()),
            Some("feature".to_string()),
            true,
        );
        assert!(result.is_err());
        assert!(
            result
                .expect_err("should be err")
                .to_string()
                .contains("include-unstaged")
        );
    }

    #[test]
    fn test_pr_defaults() {
        let ctx = TaskContext::for_pr_with_base(None, None, "trunk");
        assert!(matches!(ctx, TaskContext::Range { from, to } if from == "trunk" && to == "HEAD"));
    }

    #[test]
    fn test_pr_from_only() {
        let ctx = TaskContext::for_pr(Some("develop".to_string()), None);
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "develop" && to == "HEAD")
        );
    }

    #[test]
    fn test_pr_to_only() {
        let ctx = TaskContext::for_pr(None, Some("release".to_string()));
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "main" && to == "release")
        );
    }

    #[test]
    fn test_changelog() {
        let ctx = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            Some("1.1.0".to_string()),
            Some("2025-01-15".to_string()),
        );
        assert!(matches!(
            ctx,
            TaskContext::Changelog { from, to, version_name, date }
                if from == "v1.0.0" && to == "HEAD"
                && version_name == Some("1.1.0".to_string())
                && date == "2025-01-15"
        ));
    }

    #[test]
    fn test_changelog_default_date() {
        let ctx = TaskContext::for_changelog("v1.0.0".to_string(), None, None, None);
        // Should use today's date in YYYY-MM-DD format
        let TaskContext::Changelog { date, .. } = ctx else {
            panic!("Expected Changelog variant");
        };
        assert_eq!(date.len(), 10, "expected YYYY-MM-DD, got {date}");
        for (i, c) in date.chars().enumerate() {
            if i == 4 || i == 7 {
                assert_eq!(c, '-', "expected separator at index {i} in {date}");
            } else {
                assert!(c.is_ascii_digit(), "expected digit at index {i} in {date}");
            }
        }
    }

    #[test]
    fn test_diff_hint() {
        let staged = TaskContext::for_gen();
        assert!(staged.diff_hint().contains("staged"));

        let commit = TaskContext::Commit {
            commit_id: "abc".to_string(),
        };
        assert!(commit.diff_hint().contains("abc^1"));

        let range = TaskContext::Range {
            from: "main".to_string(),
            to: "dev".to_string(),
        };
        assert!(range.diff_hint().contains("main"));
        assert!(range.diff_hint().contains("dev"));

        let amend = TaskContext::for_amend("Fix bug".to_string());
        assert!(amend.diff_hint().contains("HEAD^1"));
    }

    #[test]
    fn test_prompt_context_is_tagged_json() {
        let json = TaskContext::for_gen().to_prompt_context();
        assert!(json.contains("\"mode\": \"staged\""), "got {json}");
        assert!(json.contains("\"include_unstaged\": false"), "got {json}");

        let round_trip: TaskContext = serde_json::from_str(&json).expect("valid JSON");
        assert!(matches!(
            round_trip,
            TaskContext::Staged {
                include_unstaged: false
            }
        ));
    }

    #[test]
    fn test_amend_context() {
        let ctx = TaskContext::for_amend("Initial commit message".to_string());
        assert!(ctx.is_amend());
        assert_eq!(ctx.original_message(), Some("Initial commit message"));
        assert!(!ctx.is_range());
        assert!(!ctx.includes_unstaged());
    }

    #[test]
    fn test_amend_display() {
        let ctx = TaskContext::for_amend("Fix bug".to_string());
        assert_eq!(format!("{ctx}"), "amending previous commit");
    }

    #[test]
    fn test_staged_and_range_display() {
        let staged = TaskContext::Staged {
            include_unstaged: true,
        };
        assert_eq!(staged.to_string(), "staged and unstaged changes");
        assert_eq!(TaskContext::for_gen().to_string(), "staged changes");

        let range = TaskContext::Range {
            from: "main".to_string(),
            to: "dev".to_string(),
        };
        assert_eq!(range.to_string(), "changes from main to dev");
    }

    #[test]
    fn test_changelog_display() {
        let released = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            Some("1.1.0".to_string()),
            Some("2025-01-15".to_string()),
        );
        assert_eq!(
            released.to_string(),
            "changelog v1.1.0 (2025-01-15) from v1.0.0 to HEAD"
        );

        let unreleased = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            None,
            Some("2025-01-15".to_string()),
        );
        assert_eq!(
            unreleased.to_string(),
            "changelog unreleased (2025-01-15) from v1.0.0 to HEAD"
        );
    }
}