Skip to main content

git_iris/agents/
context.rs

1//! Task context for agent operations
2//!
3//! This module provides structured, validated context for agent tasks,
4//! replacing fragile string-based parameter passing.
5
6use anyhow::{Result, bail};
7use serde::{Deserialize, Serialize};
8
9/// Pull request template context discovered from GitHub-supported locations.
10#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
11pub struct PullRequestTemplateContext {
12    /// Repository-relative template path
13    pub path: String,
14    /// Markdown template body
15    pub body: String,
16}
17
18/// Validated, structured context for agent tasks.
19///
20/// This enum represents the different modes of operation for code analysis,
21/// with validation built into the constructors.
22#[derive(Debug, Clone, Default, Serialize, Deserialize)]
23#[serde(tag = "mode", rename_all = "snake_case")]
24pub enum TaskContext {
25    /// Analyze staged changes (optionally including unstaged)
26    Staged {
27        /// Whether to include unstaged changes in analysis
28        include_unstaged: bool,
29    },
30
31    /// Analyze a single commit
32    Commit {
33        /// The commit ID (hash, branch name, or commitish like HEAD~1)
34        commit_id: String,
35    },
36
37    /// Analyze a range of commits or branch comparison
38    Range {
39        /// Starting reference (exclusive)
40        from: String,
41        /// Ending reference (inclusive)
42        to: String,
43    },
44
45    /// Generate a pull request description, optionally revising an existing body
46    PullRequest {
47        /// Starting reference (exclusive)
48        from: String,
49        /// Ending reference (inclusive)
50        to: String,
51        /// Existing GitHub PR description to revise
52        #[serde(skip)]
53        existing_body: Option<String>,
54        /// Pull request template to adapt the description around
55        #[serde(skip)]
56        template: Option<PullRequestTemplateContext>,
57    },
58
59    /// Generate changelog or release notes with version metadata
60    Changelog {
61        /// Starting reference (exclusive)
62        from: String,
63        /// Ending reference (inclusive)
64        to: String,
65        /// Explicit version name (e.g., "1.2.0")
66        version_name: Option<String>,
67        /// Release date in YYYY-MM-DD format
68        date: String,
69    },
70
71    /// Amend the previous commit with staged changes
72    /// The agent sees the combined diff from HEAD^1 to staged state
73    Amend {
74        /// The original commit message being amended
75        original_message: String,
76    },
77
78    /// Let the agent discover context via tools (default for gen command)
79    #[default]
80    Discover,
81}
82
83impl TaskContext {
84    /// Create context for the gen (commit message) command.
85    /// Always uses staged changes only.
86    #[must_use]
87    pub fn for_gen() -> Self {
88        Self::Staged {
89            include_unstaged: false,
90        }
91    }
92
93    /// Create context for amending the previous commit.
94    /// The agent will see the combined diff from HEAD^1 to staged state.
95    #[must_use]
96    pub fn for_amend(original_message: String) -> Self {
97        Self::Amend { original_message }
98    }
99
100    /// Create context for the review command with full parameter validation.
101    ///
102    /// Validates:
103    /// - `--from` requires `--to` for explicit range comparison
104    /// - `--to` on its own compares `<fallback-base>..to`
105    /// - `--commit` is mutually exclusive with `--from/--to`
106    /// - `--include-unstaged` is incompatible with range comparisons
107    ///
108    /// # Errors
109    ///
110    /// Returns an error when the provided flag combination is invalid.
111    pub fn for_review(
112        commit: Option<String>,
113        from: Option<String>,
114        to: Option<String>,
115        include_unstaged: bool,
116    ) -> Result<Self> {
117        Self::for_review_with_base(commit, from, to, include_unstaged, "main")
118    }
119
120    /// Create review context with an explicit default base for `--to`-only comparisons.
121    ///
122    /// CLI and Studio should prefer a repo-aware base from `GitRepo::get_default_base_ref()`.
123    /// This lets branch comparisons follow the repository's actual primary branch
124    /// instead of relying on the legacy `"main"` fallback.
125    ///
126    /// # Errors
127    ///
128    /// Returns an error when the provided flag combination is invalid.
129    pub fn for_review_with_base(
130        commit: Option<String>,
131        from: Option<String>,
132        to: Option<String>,
133        include_unstaged: bool,
134        default_base: &str,
135    ) -> Result<Self> {
136        // Validate: --from requires --to
137        if from.is_some() && to.is_none() {
138            bail!("When using --from, you must also specify --to for branch comparison reviews");
139        }
140
141        // Validate: --commit is mutually exclusive with --from/--to
142        if commit.is_some() && (from.is_some() || to.is_some()) {
143            bail!("Cannot use --commit with --from/--to. These are mutually exclusive options");
144        }
145
146        // Validate: --include-unstaged incompatible with range comparisons
147        if include_unstaged && (from.is_some() || to.is_some()) {
148            bail!(
149                "Cannot use --include-unstaged with --from/--to. Branch comparison reviews don't include working directory changes"
150            );
151        }
152
153        // Route to correct variant based on parameters
154        Ok(match (commit, from, to) {
155            (Some(id), _, _) => Self::Commit { commit_id: id },
156            (_, Some(f), Some(t)) => Self::Range { from: f, to: t },
157            (None, None, Some(t)) => Self::Range {
158                from: default_base.to_string(),
159                to: t,
160            },
161            _ => Self::Staged { include_unstaged },
162        })
163    }
164
165    /// Create context for the PR command.
166    ///
167    /// PR command is more flexible - all parameter combinations are valid:
168    /// - `from` + `to`: Explicit range/branch comparison
169    /// - `from` only: Compare `from..HEAD`
170    /// - `to` only: Compare `<fallback-base>..to`
171    /// - Neither: Compare `<fallback-base>..HEAD`
172    #[must_use]
173    pub fn for_pr(from: Option<String>, to: Option<String>) -> Self {
174        Self::for_pr_with_base(from, to, "main")
175    }
176
177    /// Create PR context with an explicit default comparison base.
178    ///
179    /// CLI and Studio should prefer a repo-aware base from `GitRepo::get_default_base_ref()`.
180    #[must_use]
181    pub fn for_pr_with_base(from: Option<String>, to: Option<String>, default_base: &str) -> Self {
182        Self::for_pr_update_with_base(from, to, default_base, None, None)
183    }
184
185    /// Create PR context with an optional existing PR description.
186    ///
187    /// When present, the agent should revise the existing body instead of blindly replacing it.
188    #[must_use]
189    pub fn for_pr_update_with_base(
190        from: Option<String>,
191        to: Option<String>,
192        default_base: &str,
193        existing_body: Option<String>,
194        template: Option<PullRequestTemplateContext>,
195    ) -> Self {
196        let (from, to) = match (from, to) {
197            (Some(f), Some(t)) => (f, t),
198            (Some(f), None) => (f, "HEAD".to_string()),
199            (None, Some(t)) => (default_base.to_string(), t),
200            (None, None) => (default_base.to_string(), "HEAD".to_string()),
201        };
202
203        Self::PullRequest {
204            from,
205            to,
206            existing_body,
207            template,
208        }
209    }
210
211    /// Create context for changelog/release-notes commands.
212    ///
213    /// These always require a `from` reference; `to` defaults to HEAD.
214    /// Automatically sets today's date if not provided.
215    #[must_use]
216    pub fn for_changelog(
217        from: String,
218        to: Option<String>,
219        version_name: Option<String>,
220        date: Option<String>,
221    ) -> Self {
222        Self::Changelog {
223            from,
224            to: to.unwrap_or_else(|| "HEAD".to_string()),
225            version_name,
226            date: date.unwrap_or_else(|| chrono::Local::now().format("%Y-%m-%d").to_string()),
227        }
228    }
229
230    /// Generate a human-readable prompt context string for the agent.
231    #[must_use]
232    pub fn to_prompt_context(&self) -> String {
233        serde_json::to_string_pretty(self).unwrap_or_else(|_| format!("{self:?}"))
234    }
235
236    /// Generate a hint for which `git_diff` call the agent should make.
237    #[must_use]
238    pub fn diff_hint(&self) -> String {
239        match self {
240            Self::Staged { include_unstaged } => {
241                if *include_unstaged {
242                    "git_diff() for staged changes, then check unstaged files".to_string()
243                } else {
244                    "git_diff() for staged changes".to_string()
245                }
246            }
247            Self::Commit { commit_id } => {
248                format!("git_diff(from=\"{commit_id}^1\", to=\"{commit_id}\")")
249            }
250            Self::Range { from, to }
251            | Self::PullRequest { from, to, .. }
252            | Self::Changelog { from, to, .. } => {
253                format!("git_diff(from=\"{from}\", to=\"{to}\")")
254            }
255            Self::Amend { .. } => {
256                "git_diff(from=\"HEAD^1\") for combined amend diff (original commit + new staged changes)".to_string()
257            }
258            Self::Discover => "git_diff() to discover current changes".to_string(),
259        }
260    }
261
262    /// Check if this context represents a range comparison (vs staged/single commit)
263    #[must_use]
264    pub fn is_range(&self) -> bool {
265        matches!(self, Self::Range { .. } | Self::PullRequest { .. })
266    }
267
268    /// Check if this context involves unstaged changes
269    #[must_use]
270    pub fn includes_unstaged(&self) -> bool {
271        matches!(
272            self,
273            Self::Staged {
274                include_unstaged: true
275            }
276        )
277    }
278
279    /// Check if this is an amend operation
280    #[must_use]
281    pub fn is_amend(&self) -> bool {
282        matches!(self, Self::Amend { .. })
283    }
284
285    /// Get the original commit message if this is an amend context
286    #[must_use]
287    pub fn original_message(&self) -> Option<&str> {
288        match self {
289            Self::Amend { original_message } => Some(original_message),
290            _ => None,
291        }
292    }
293
294    /// Get the existing pull request body when PR generation is revising one.
295    #[must_use]
296    pub fn existing_pull_request_body(&self) -> Option<&str> {
297        match self {
298            Self::PullRequest {
299                existing_body: Some(body),
300                ..
301            } => Some(body),
302            _ => None,
303        }
304    }
305
306    /// Get the pull request template when one was discovered.
307    #[must_use]
308    pub fn pull_request_template(&self) -> Option<&PullRequestTemplateContext> {
309        match self {
310            Self::PullRequest {
311                template: Some(template),
312                ..
313            } => Some(template),
314            _ => None,
315        }
316    }
317}
318
319impl std::fmt::Display for TaskContext {
320    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321        match self {
322            Self::Staged { include_unstaged } => {
323                if *include_unstaged {
324                    write!(f, "staged and unstaged changes")
325                } else {
326                    write!(f, "staged changes")
327                }
328            }
329            Self::Commit { commit_id } => write!(f, "commit {commit_id}"),
330            Self::Range { from, to } => write!(f, "changes from {from} to {to}"),
331            Self::PullRequest { from, to, .. } => {
332                write!(f, "pull request changes from {from} to {to}")
333            }
334            Self::Changelog {
335                from,
336                to,
337                version_name,
338                date,
339            } => {
340                let version_str = version_name
341                    .as_ref()
342                    .map_or_else(|| "unreleased".to_string(), |v| format!("v{v}"));
343                write!(f, "changelog {version_str} ({date}) from {from} to {to}")
344            }
345            Self::Amend { .. } => write!(f, "amending previous commit"),
346            Self::Discover => write!(f, "auto-discovered changes"),
347        }
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_for_gen() {
        let ctx = TaskContext::for_gen();
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: false
            }
        ));
    }

    #[test]
    fn test_review_staged_only() {
        let ctx = TaskContext::for_review(None, None, None, false).expect("should succeed");
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: false
            }
        ));
    }

    #[test]
    fn test_review_with_unstaged() {
        let ctx = TaskContext::for_review(None, None, None, true).expect("should succeed");
        assert!(matches!(
            ctx,
            TaskContext::Staged {
                include_unstaged: true
            }
        ));
    }

    #[test]
    fn test_review_single_commit() {
        let ctx = TaskContext::for_review(Some("abc123".to_string()), None, None, false)
            .expect("should succeed");
        assert!(matches!(ctx, TaskContext::Commit { commit_id } if commit_id == "abc123"));
    }

    #[test]
    fn test_review_range() {
        let ctx = TaskContext::for_review(
            None,
            Some("main".to_string()),
            Some("feature".to_string()),
            false,
        )
        .expect("should succeed");
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "main" && to == "feature")
        );
    }

    #[test]
    fn test_review_to_only_defaults_from_explicit_base() {
        let ctx = TaskContext::for_review_with_base(
            None,
            None,
            Some("feature".to_string()),
            false,
            "trunk",
        )
        .expect("should succeed");
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "trunk" && to == "feature")
        );
    }

    #[test]
    fn test_review_to_only_uses_legacy_main_base() {
        let ctx = TaskContext::for_review(None, None, Some("feature".to_string()), false)
            .expect("should succeed");
        assert!(
            matches!(ctx, TaskContext::Range { from, to } if from == "main" && to == "feature")
        );
    }

    #[test]
    fn test_review_from_without_to_fails() {
        let result = TaskContext::for_review(None, Some("main".to_string()), None, false);
        assert!(result.is_err());
        assert!(
            result
                .expect_err("should be err")
                .to_string()
                .contains("--to")
        );
    }

    #[test]
    fn test_review_commit_with_range_fails() {
        // commit + from + to should fail as mutually exclusive
        let result = TaskContext::for_review(
            Some("abc123".to_string()),
            Some("main".to_string()),
            Some("feature".to_string()),
            false,
        );
        assert!(result.is_err());
        assert!(
            result
                .expect_err("should be err")
                .to_string()
                .contains("mutually exclusive")
        );
    }

    #[test]
    fn test_review_unstaged_with_range_fails() {
        let result = TaskContext::for_review(
            None,
            Some("main".to_string()),
            Some("feature".to_string()),
            true,
        );
        assert!(result.is_err());
        assert!(
            result
                .expect_err("should be err")
                .to_string()
                .contains("include-unstaged")
        );
    }

    #[test]
    fn test_review_unstaged_with_to_only_fails() {
        let result = TaskContext::for_review(None, None, Some("feature".to_string()), true);
        assert!(result.is_err());
        assert!(
            result
                .expect_err("should be err")
                .to_string()
                .contains("include-unstaged")
        );
    }

    #[test]
    fn test_pr_defaults() {
        let ctx = TaskContext::for_pr_with_base(None, None, "trunk");
        assert!(
            matches!(ctx, TaskContext::PullRequest { from, to, existing_body, template } if from == "trunk" && to == "HEAD" && existing_body.is_none() && template.is_none())
        );
    }

    #[test]
    fn test_pr_from_only() {
        let ctx = TaskContext::for_pr(Some("develop".to_string()), None);
        assert!(
            matches!(ctx, TaskContext::PullRequest { from, to, existing_body, template } if from == "develop" && to == "HEAD" && existing_body.is_none() && template.is_none())
        );
    }

    #[test]
    fn test_pr_to_only() {
        let ctx = TaskContext::for_pr_with_base(None, Some("feature".to_string()), "trunk");
        assert!(
            matches!(ctx, TaskContext::PullRequest { from, to, .. } if from == "trunk" && to == "feature")
        );
    }

    #[test]
    fn test_pr_existing_body() {
        let ctx = TaskContext::for_pr_update_with_base(
            Some("main".to_string()),
            Some("feature".to_string()),
            "trunk",
            Some("Existing body".to_string()),
            None,
        );
        assert!(
            matches!(ctx, TaskContext::PullRequest { from, to, existing_body, .. } if from == "main" && to == "feature" && existing_body == Some("Existing body".to_string()))
        );
    }

    #[test]
    fn test_pr_accessors() {
        let template = PullRequestTemplateContext {
            path: ".github/pull_request_template.md".to_string(),
            body: "## Summary".to_string(),
        };
        let ctx = TaskContext::for_pr_update_with_base(
            None,
            None,
            "main",
            Some("Old body".to_string()),
            Some(template.clone()),
        );
        assert_eq!(ctx.existing_pull_request_body(), Some("Old body"));
        assert_eq!(ctx.pull_request_template(), Some(&template));
        assert!(ctx.is_range());

        let plain = TaskContext::for_gen();
        assert_eq!(plain.existing_pull_request_body(), None);
        assert_eq!(plain.pull_request_template(), None);
    }

    #[test]
    fn test_prompt_context_omits_skipped_pr_fields() {
        let ctx = TaskContext::for_pr_update_with_base(
            None,
            None,
            "main",
            Some("Existing PR body text".to_string()),
            None,
        );
        let prompt = ctx.to_prompt_context();
        assert!(prompt.contains("pull_request"));
        assert!(!prompt.contains("Existing PR body text"));
    }

    #[test]
    fn test_changelog() {
        let ctx = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            Some("1.1.0".to_string()),
            Some("2025-01-15".to_string()),
        );
        assert!(matches!(
            ctx,
            TaskContext::Changelog { from, to, version_name, date }
                if from == "v1.0.0" && to == "HEAD"
                && version_name == Some("1.1.0".to_string())
                && date == "2025-01-15"
        ));
    }

    #[test]
    fn test_changelog_default_date() {
        let ctx = TaskContext::for_changelog("v1.0.0".to_string(), None, None, None);
        // Should use today's date in YYYY-MM-DD format
        if let TaskContext::Changelog { date, .. } = ctx {
            assert_eq!(date.len(), 10);
            assert!(chrono::NaiveDate::parse_from_str(&date, "%Y-%m-%d").is_ok());
        } else {
            panic!("Expected Changelog variant");
        }
    }

    #[test]
    fn test_diff_hint() {
        let staged = TaskContext::for_gen();
        assert!(staged.diff_hint().contains("staged"));

        let commit = TaskContext::Commit {
            commit_id: "abc".to_string(),
        };
        assert!(commit.diff_hint().contains("abc^1"));

        let range = TaskContext::Range {
            from: "main".to_string(),
            to: "dev".to_string(),
        };
        assert!(range.diff_hint().contains("main"));
        assert!(range.diff_hint().contains("dev"));

        let changelog = TaskContext::for_changelog("v1.0.0".to_string(), None, None, None);
        assert_eq!(
            changelog.diff_hint(),
            "git_diff(from=\"v1.0.0\", to=\"HEAD\")"
        );

        let amend = TaskContext::for_amend("Fix bug".to_string());
        assert!(amend.diff_hint().contains("HEAD^1"));
    }

    #[test]
    fn test_amend_context() {
        let ctx = TaskContext::for_amend("Initial commit message".to_string());
        assert!(ctx.is_amend());
        assert_eq!(ctx.original_message(), Some("Initial commit message"));
        assert!(!ctx.is_range());
        assert!(!ctx.includes_unstaged());
    }

    #[test]
    fn test_amend_display() {
        let ctx = TaskContext::for_amend("Fix bug".to_string());
        assert_eq!(format!("{ctx}"), "amending previous commit");
    }

    #[test]
    fn test_display_variants() {
        assert_eq!(TaskContext::for_gen().to_string(), "staged changes");
        assert_eq!(
            TaskContext::Staged {
                include_unstaged: true
            }
            .to_string(),
            "staged and unstaged changes"
        );
        assert_eq!(
            TaskContext::Commit {
                commit_id: "abc".to_string()
            }
            .to_string(),
            "commit abc"
        );
        assert_eq!(
            TaskContext::for_pr_with_base(None, None, "main").to_string(),
            "pull request changes from main to HEAD"
        );
        assert_eq!(TaskContext::default().to_string(), "auto-discovered changes");

        let versioned = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            None,
            Some("1.1.0".to_string()),
            Some("2025-01-15".to_string()),
        );
        assert_eq!(
            versioned.to_string(),
            "changelog v1.1.0 (2025-01-15) from v1.0.0 to HEAD"
        );

        let unreleased = TaskContext::for_changelog(
            "v1.0.0".to_string(),
            Some("v1.1.0".to_string()),
            None,
            Some("2025-01-15".to_string()),
        );
        assert_eq!(
            unreleased.to_string(),
            "changelog unreleased (2025-01-15) from v1.0.0 to v1.1.0"
        );
    }
}