Skip to main content

sr_ai/commands/
commit.rs

use crate::ai::{AiRequest, BackendConfig, resolve_backend};
use crate::cache::{CacheLookup, CacheManager};
use crate::git::GitRepo;
use crate::ui;
use anyhow::{Context, Result, bail};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
8
/// A complete commit plan: the list of commits to create.
///
/// Parsed from the AI backend's JSON response (requested with
/// `COMMIT_SCHEMA`), and also stored in / returned from the plan cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommitPlan {
    /// Planned commits. `execute_plan` creates them in this vector's order.
    pub commits: Vec<PlannedCommit>,
}

/// One commit within a [`CommitPlan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlannedCommit {
    /// Position suggested by the model. `execute_plan` iterates in vector
    /// order and does not read this field; `validate_plan` rewrites it when
    /// it merges commits.
    pub order: Option<u32>,
    /// Conventional Commits header line, e.g. `feat(scope): subject`.
    pub message: String,
    /// Optional body; appended to the header after a blank line when non-empty.
    pub body: Option<String>,
    /// Optional footer (e.g. BREAKING CHANGE notes, issue references);
    /// appended after a blank line when non-empty.
    pub footer: Option<String>,
    /// Paths relative to the repository root, staged one by one via
    /// `GitRepo::stage_file` before this commit is created.
    pub files: Vec<String>,
}
22
// Command-line options for the `commit` command, parsed with clap's derive
// API. NOTE: clap renders the `///` comments on each field as `--help` text,
// so edits to those lines are user-visible; explanatory notes use `//`.
#[derive(Debug, clap::Args)]
pub struct CommitArgs {
    // `-s` / `--staged`: restricts both the "any changes?" check and the AI
    // prompt to the index.
    /// Only analyze staged changes
    #[arg(short, long)]
    pub staged: bool,

    // `-M` (uppercase) / `--message`: appended verbatim to the AI user prompt
    // and also passed to `CacheManager::new`.
    /// Additional context or instructions for commit generation
    #[arg(short = 'M', long)]
    pub message: Option<String>,

    // `-n` / `--dry-run`: the plan is still generated and cached, just not
    // executed.
    /// Display plan without executing
    #[arg(short = 'n', long)]
    pub dry_run: bool,

    // `-y` / `--yes`
    /// Skip confirmation prompt
    #[arg(short, long)]
    pub yes: bool,

    // `--no-cache`: no `CacheManager` is built, so the cache is neither read
    // nor written for this run.
    /// Bypass cache (always call AI)
    #[arg(long)]
    pub no_cache: bool,
}
45
/// JSON Schema sent as `AiRequest::json_schema` to constrain the backend's
/// response to the shape of a [`CommitPlan`].
///
/// Per commit, `order`, `message`, `body` and `files` are required and
/// `footer` is optional. The `description` strings double as formatting
/// guidance for the model.
const COMMIT_SCHEMA: &str = r#"{
    "type": "object",
    "properties": {
        "commits": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "order": { "type": "integer" },
                    "message": { "type": "string", "description": "Header: type(scope): subject — imperative, lowercase, no period, max 72 chars" },
                    "body": { "type": "string", "description": "Body: explain WHY the change was made, wrap at 72 chars" },
                    "footer": { "type": "string", "description": "Footer: BREAKING CHANGE notes, Closes/Fixes/Refs #issue, etc." },
                    "files": { "type": "array", "items": { "type": "string" } }
                },
                "required": ["order", "message", "body", "files"]
            }
        }
    },
    "required": ["commits"]
}"#;
66
/// System prompt sent with every commit-planning request.
///
/// Describes the Conventional Commits format for the `message`, `body` and
/// `footer` fields, and the grouping rules the executor depends on. In
/// particular, a file may appear in only one commit, because whole files are
/// staged.
///
/// The header regex is embedded verbatim. This is a raw string, so `\s`
/// stays a regex whitespace class; a doubled `\\s` would mean a literal
/// backslash.
const SYSTEM_PROMPT: &str = r#"You are an expert at analyzing git diffs and creating atomic, well-organized commits following the Angular Conventional Commits standard.

HEADER ("message" field):
- Must match this regex: (?s)(build|bump|chore|ci|docs|feat|fix|perf|refactor|revert|style|test)(\(\S+\))?!?: ([^\n\r]+)((\n\n.*)|(\s*))?$
- Format: type(scope): subject
- Valid types ONLY: build, bump, chore, ci, docs, feat, fix, perf, refactor, revert, style, test
- NEVER invent types. Words like db, auth, api, etc. are scopes, not types. Use the semantically correct type for the change (e.g. feat(db): add user cache migration, fix(auth): resolve token expiry)
- scope is optional but recommended when applicable
- subject: imperative mood, lowercase first letter, no period at end, max 72 chars

BODY ("body" field — required):
- Explain WHY the change was made, not what changed (the diff shows that)
- Use imperative tense ("add" not "added")
- Wrap at 72 characters

FOOTER ("footer" field — optional):
- BREAKING CHANGE: description of what breaks and migration path
- Closes #N, Fixes #N, Refs #N for issue references
- Only include when relevant

COMMIT ORGANIZATION:
- Each commit must be atomic: one logical change per commit
- Every changed file must appear in exactly one commit
- CRITICAL: A file must NEVER appear in more than one commit. The execution engine stages entire files, not individual hunks. Splitting one file across commits will fail.
- If one file contains multiple logical changes, place it in the most fitting commit and note the secondary changes in that commit's body.
- Order: infrastructure/config -> core library -> features -> tests -> docs
- File paths must be relative to the repository root and match exactly as git reports them"#;
94
/// How this run's commit plan was obtained. Only used to print a one-line
/// status tag before the plan is displayed.
enum CacheStatus {
    /// Plan came from a fresh AI call with no cache hint: caching disabled
    /// (`--no-cache`), cache unavailable, or no usable cache entry (miss).
    None,
    /// Exact cache hit; no AI request was made.
    Cached,
    /// A previous plan for similar changes was sent to the backend as a hint
    /// and a new plan was generated. `cached` is the number of commits in that
    /// previous plan; `reanalyzed` is currently always 1 (the single AI call).
    Incremental { cached: usize, reanalyzed: usize },
}
103
104pub async fn run(args: &CommitArgs, backend_config: &BackendConfig) -> Result<()> {
105    let repo = GitRepo::discover()?;
106
107    // Check for changes
108    let has_changes = if args.staged {
109        repo.has_staged_changes()?
110    } else {
111        repo.has_any_changes()?
112    };
113
114    if !has_changes {
115        bail!(crate::error::SrAiError::NoChanges);
116    }
117
118    // Resolve AI backend
119    let backend = resolve_backend(backend_config).await?;
120    let backend_name = backend.name().to_string();
121    let model_name = backend_config
122        .model
123        .as_deref()
124        .unwrap_or("default")
125        .to_string();
126
127    // Build cache manager (may be None if cache dir unavailable)
128    let cache = if args.no_cache {
129        None
130    } else {
131        CacheManager::new(
132            repo.root(),
133            args.staged,
134            args.message.as_deref(),
135            &backend_name,
136            &model_name,
137        )
138    };
139
140    // Cache lookup
141    let (mut plan, cache_status) = match cache.as_ref().map(|c| c.lookup()) {
142        Some(CacheLookup::ExactHit(cached_plan)) => (cached_plan, CacheStatus::Cached),
143        Some(CacheLookup::IncrementalHit {
144            previous_plan,
145            delta_summary,
146        }) => {
147            let cached_count = previous_plan.commits.len();
148
149            let spinner = ui::spinner(&format!(
150                "Analyzing changes with {} (incremental)...",
151                backend_name
152            ));
153
154            let user_prompt =
155                build_incremental_prompt(args, &repo, &previous_plan, &delta_summary)?;
156
157            let request = AiRequest {
158                system_prompt: SYSTEM_PROMPT.to_string(),
159                user_prompt,
160                json_schema: Some(COMMIT_SCHEMA.to_string()),
161                working_dir: repo.root().to_string_lossy().to_string(),
162            };
163
164            let response = backend.request(&request).await?;
165            spinner.finish_and_clear();
166
167            let p: CommitPlan = serde_json::from_str(&response.text)
168                .context("failed to parse commit plan from AI response")?;
169
170            (
171                p,
172                CacheStatus::Incremental {
173                    cached: cached_count,
174                    reanalyzed: 1, // at least one AI call was made
175                },
176            )
177        }
178        _ => {
179            // Miss or no cache
180            let spinner = ui::spinner(&format!("Analyzing changes with {}...", backend_name));
181
182            let user_prompt = build_user_prompt(args, &repo)?;
183
184            let request = AiRequest {
185                system_prompt: SYSTEM_PROMPT.to_string(),
186                user_prompt,
187                json_schema: Some(COMMIT_SCHEMA.to_string()),
188                working_dir: repo.root().to_string_lossy().to_string(),
189            };
190
191            let response = backend.request(&request).await?;
192            spinner.finish_and_clear();
193
194            let p: CommitPlan = serde_json::from_str(&response.text)
195                .context("failed to parse commit plan from AI response")?;
196
197            (p, CacheStatus::None)
198        }
199    };
200
201    if plan.commits.is_empty() {
202        bail!(crate::error::SrAiError::EmptyPlan);
203    }
204
205    // Validate: merge commits with shared files
206    plan = validate_plan(plan);
207
208    // Store in cache (before display/execute so dry-runs populate cache too)
209    if let Some(cache) = &cache {
210        cache.store(&plan, &backend_name, &model_name);
211    }
212
213    // Display plan with cache status indicator
214    match &cache_status {
215        CacheStatus::Cached => println!("[cached]"),
216        CacheStatus::Incremental { cached, reanalyzed } => {
217            println!("[incremental: {cached} cached, {reanalyzed} re-analyzed]")
218        }
219        CacheStatus::None => {}
220    }
221    ui::display_plan(&plan);
222
223    if args.dry_run {
224        println!();
225        println!("(dry run - no commits created)");
226        return Ok(());
227    }
228
229    // Confirm
230    if !args.yes {
231        println!();
232        if !ui::confirm("Execute this plan? [y/N]")? {
233            bail!(crate::error::SrAiError::Cancelled);
234        }
235    }
236
237    // Execute
238    execute_plan(&repo, &plan)?;
239
240    Ok(())
241}
242
243fn build_user_prompt(args: &CommitArgs, repo: &GitRepo) -> Result<String> {
244    let git_root = repo.root().to_string_lossy();
245
246    let mut prompt = if args.staged {
247        "Analyze the staged git changes and group them into atomic commits.\n\
248         Use `git diff --cached` and `git diff --cached --stat` to inspect what's staged."
249            .to_string()
250    } else {
251        "Analyze all git changes (staged, unstaged, and untracked) and group them into atomic commits.\n\
252         Use `git diff HEAD`, `git diff --cached`, `git diff`, `git status --porcelain`, and \
253         `git ls-files --others --exclude-standard` to inspect changes."
254            .to_string()
255    };
256
257    prompt.push_str(&format!("\nThe git repository root is: {git_root}"));
258
259    if let Some(msg) = &args.message {
260        prompt.push_str(&format!("\n\nAdditional context from the user:\n{msg}"));
261    }
262
263    Ok(prompt)
264}
265
266fn build_incremental_prompt(
267    args: &CommitArgs,
268    repo: &GitRepo,
269    previous_plan: &CommitPlan,
270    delta_summary: &str,
271) -> Result<String> {
272    let mut prompt = build_user_prompt(args, repo)?;
273
274    let previous_json =
275        serde_json::to_string_pretty(previous_plan).unwrap_or_else(|_| "{}".to_string());
276
277    prompt.push_str(&format!(
278        "\n\n--- INCREMENTAL HINTS ---\n\
279         A previous commit plan exists for a similar set of changes. \
280         Maintain the groupings for unchanged files where possible. \
281         Only re-analyze files that have changed.\n\n\
282         Previous plan:\n```json\n{previous_json}\n```\n\n\
283         File delta:\n{delta_summary}"
284    ));
285
286    Ok(prompt)
287}
288
289/// Validate that no file appears in multiple commits. If duplicates are found,
290/// merge affected commits into one.
291fn validate_plan(plan: CommitPlan) -> CommitPlan {
292    // Count file occurrences
293    let mut file_counts: HashMap<String, usize> = HashMap::new();
294    for commit in &plan.commits {
295        for file in &commit.files {
296            *file_counts.entry(file.clone()).or_default() += 1;
297        }
298    }
299
300    let dupes: Vec<&String> = file_counts
301        .iter()
302        .filter(|(_, count)| **count > 1)
303        .map(|(file, _)| file)
304        .collect();
305
306    if dupes.is_empty() {
307        return plan;
308    }
309
310    eprintln!();
311    eprintln!("Notice: shared files detected across commits — merging affected commits.");
312    eprintln!(
313        "Shared files: {}",
314        dupes
315            .iter()
316            .map(|s| s.as_str())
317            .collect::<Vec<_>>()
318            .join(" ")
319    );
320
321    // Partition into tainted (has any dupe file) and clean
322    let mut tainted = Vec::new();
323    let mut clean = Vec::new();
324
325    for commit in plan.commits {
326        let is_tainted = commit.files.iter().any(|f| dupes.contains(&f));
327        if is_tainted {
328            tainted.push(commit);
329        } else {
330            clean.push(commit);
331        }
332    }
333
334    // Merge all tainted commits into one
335    let merged_message = tainted
336        .first()
337        .map(|c| c.message.clone())
338        .unwrap_or_default();
339
340    let merged_body = tainted
341        .iter()
342        .filter_map(|c| c.body.as_ref())
343        .filter(|b| !b.is_empty())
344        .cloned()
345        .collect::<Vec<_>>()
346        .join("\n\n");
347
348    let merged_footer = tainted
349        .iter()
350        .filter_map(|c| c.footer.as_ref())
351        .filter(|f| !f.is_empty())
352        .cloned()
353        .collect::<Vec<_>>()
354        .join("\n");
355
356    let mut merged_files: Vec<String> = tainted
357        .iter()
358        .flat_map(|c| c.files.iter().cloned())
359        .collect();
360    merged_files.sort();
361    merged_files.dedup();
362
363    let merged_commit = PlannedCommit {
364        order: Some(1),
365        message: merged_message,
366        body: if merged_body.is_empty() {
367            None
368        } else {
369            Some(merged_body)
370        },
371        footer: if merged_footer.is_empty() {
372            None
373        } else {
374            Some(merged_footer)
375        },
376        files: merged_files,
377    };
378
379    // Re-number: merged first, then clean commits
380    let mut result = vec![merged_commit];
381    for (i, mut commit) in clean.into_iter().enumerate() {
382        commit.order = Some(i as u32 + 2);
383        result.push(commit);
384    }
385
386    CommitPlan { commits: result }
387}
388
389fn execute_plan(repo: &GitRepo, plan: &CommitPlan) -> Result<()> {
390    // Unstage everything first
391    repo.reset_head()?;
392
393    let total = plan.commits.len();
394
395    for (i, commit) in plan.commits.iter().enumerate() {
396        println!();
397        println!("Creating commit {}/{total}: {}", i + 1, commit.message);
398
399        // Stage files for this commit
400        for file in &commit.files {
401            if !repo.stage_file(file)? {
402                eprintln!("  Warning: file not found: {file}");
403            }
404        }
405
406        // Build full commit message
407        let mut full_message = commit.message.clone();
408        if let Some(body) = &commit.body
409            && !body.is_empty()
410        {
411            full_message.push_str("\n\n");
412            full_message.push_str(body);
413        }
414        if let Some(footer) = &commit.footer
415            && !footer.is_empty()
416        {
417            full_message.push_str("\n\n");
418            full_message.push_str(footer);
419        }
420
421        // Create commit (only if there are staged files)
422        if repo.has_staged_after_add()? {
423            repo.commit(&full_message)?;
424        } else {
425            eprintln!(
426                "  Warning: no files staged for this commit (may already be committed or missing)"
427            );
428        }
429    }
430
431    println!();
432    println!("Done! Recent commits:");
433    println!("{}", repo.recent_commits(total)?);
434
435    Ok(())
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a footer-less `PlannedCommit` from borrowed pieces.
    fn planned(order: u32, message: &str, body: &str, files: &[&str]) -> PlannedCommit {
        PlannedCommit {
            order: Some(order),
            message: message.into(),
            body: Some(body.into()),
            footer: None,
            files: files.iter().map(|&f| f.to_owned()).collect(),
        }
    }

    #[test]
    fn validate_plan_no_dupes() {
        let plan = CommitPlan {
            commits: vec![
                planned(1, "feat: add foo", "reason", &["a.rs"]),
                planned(2, "fix: fix bar", "reason", &["b.rs"]),
            ],
        };

        assert_eq!(validate_plan(plan).commits.len(), 2);
    }

    #[test]
    fn validate_plan_merges_dupes() {
        let plan = CommitPlan {
            commits: vec![
                planned(1, "feat: add foo", "reason 1", &["shared.rs", "a.rs"]),
                planned(2, "fix: fix bar", "reason 2", &["shared.rs", "b.rs"]),
                planned(3, "docs: update readme", "docs", &["README.md"]),
            ],
        };

        let result = validate_plan(plan);

        // The two commits sharing `shared.rs` collapse into one, and the clean
        // docs commit survives on its own: two commits in total.
        assert_eq!(result.commits.len(), 2);

        let merged = &result.commits[0];
        assert_eq!(merged.message, "feat: add foo");
        for expected in ["shared.rs", "a.rs", "b.rs"] {
            assert!(merged.files.iter().any(|f| f == expected));
        }

        let untouched = &result.commits[1];
        assert_eq!(untouched.message, "docs: update readme");
        assert_eq!(untouched.order, Some(2));
    }
}