// sr_ai/commands/commit.rs
1use crate::ai::{AiEvent, AiRequest, BackendConfig, resolve_backend};
2use crate::cache::{CacheLookup, CacheManager};
3use crate::git::{GitRepo, SnapshotGuard};
4use crate::ui;
5use anyhow::{Context, Result, bail};
6use indicatif::ProgressBar;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use tokio::sync::mpsc;
10
/// The full set of commits proposed by the AI (or loaded from cache),
/// in the order they should be created.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommitPlan {
    pub commits: Vec<PlannedCommit>,
}
15
/// One planned commit: a conventional-commit header plus optional body and
/// footer, and the exact files to stage for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlannedCommit {
    // 1-based position in the plan (validate_plan renumbers from 1).
    pub order: Option<u32>,
    // Header line: `type(scope): subject`.
    pub message: String,
    // Commit body explaining why the change was made.
    pub body: Option<String>,
    // Footer: BREAKING CHANGE notes, issue references.
    pub footer: Option<String>,
    // Repo-root-relative paths staged for this commit; each file must appear
    // in at most one commit (enforced by validate_plan).
    pub files: Vec<String>,
}
24
// CLI arguments for `sr commit`. NOTE: the `///` doc comments below double as
// clap help text — changing them changes `--help` output, so they are left
// untouched.
#[derive(Debug, clap::Args)]
pub struct CommitArgs {
    /// Only analyze staged changes
    #[arg(short, long)]
    pub staged: bool,

    /// Additional context or instructions for commit generation
    // Extra prompt context appended in build_user_prompt — this is NOT the
    // commit message itself. NOTE(review): capital -M presumably avoids a
    // short-flag collision — confirm against the parent command.
    #[arg(short = 'M', long)]
    pub message: Option<String>,

    /// Display plan without executing
    #[arg(short = 'n', long)]
    pub dry_run: bool,

    /// Skip confirmation prompt
    #[arg(short, long)]
    pub yes: bool,

    /// Bypass cache (always call AI)
    #[arg(long)]
    pub no_cache: bool,
}
47
/// JSON schema handed to the AI backend (`AiRequest::json_schema`) to force
/// structured output deserializable into [`CommitPlan`]. Per commit, "order",
/// "message", "body", and "files" are required; "footer" is optional.
const COMMIT_SCHEMA: &str = r#"{
    "type": "object",
    "properties": {
        "commits": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "order": { "type": "integer" },
                    "message": { "type": "string", "description": "Header: type(scope): subject — imperative, lowercase, no period, max 72 chars" },
                    "body": { "type": "string", "description": "Body: explain WHY the change was made, wrap at 72 chars" },
                    "footer": { "type": "string", "description": "Footer: BREAKING CHANGE notes, Closes/Fixes/Refs #issue, etc." },
                    "files": { "type": "array", "items": { "type": "string" } }
                },
                "required": ["order", "message", "body", "files"]
            }
        }
    },
    "required": ["commits"]
}"#;
68
/// Render the system prompt for the commit-planning request.
///
/// `commit_pattern` is the regex every header must match and `type_names`
/// the allowed conventional-commit types from the project config; both are
/// interpolated verbatim into the prompt text.
fn build_system_prompt(commit_pattern: &str, type_names: &[&str]) -> String {
    format!(
        r#"You are an expert at analyzing git diffs and creating atomic, well-organized commits following the Angular Conventional Commits standard.

HEADER ("message" field):
- Must match this regex: {commit_pattern}
- Format: type(scope): subject
- Valid types ONLY: {types_list}
- NEVER invent types. Words like db, auth, api, etc. are scopes, not types. Use the semantically correct type for the change (e.g. feat(db): add user cache migration, fix(auth): resolve token expiry)
- scope is optional but recommended when applicable
- subject: imperative mood, lowercase first letter, no period at end, max 72 chars

BODY ("body" field — required):
- Explain WHY the change was made, not what changed (the diff shows that)
- Use imperative tense ("add" not "added")
- Wrap at 72 characters

FOOTER ("footer" field — optional):
- BREAKING CHANGE: description of what breaks and migration path
- Closes #N, Fixes #N, Refs #N for issue references
- Only include when relevant

COMMIT ORGANIZATION:
- Each commit must be atomic: one logical change per commit
- Every changed file must appear in exactly one commit
- CRITICAL: A file must NEVER appear in more than one commit. The execution engine stages entire files, not individual hunks. Splitting one file across commits will fail.
- If one file contains multiple logical changes, place it in the most fitting commit and note the secondary changes in that commit's body.
- Order: infrastructure/config -> core library -> features -> tests -> docs
- File paths must be relative to the repository root and match exactly as git reports them"#,
        commit_pattern = commit_pattern,
        types_list = type_names.join(", "),
    )
}
101
/// How the displayed plan was obtained; used only to label the plan output.
enum CacheStatus {
    /// No cache used (--no-cache, or cache unavailable)
    None,
    /// Exact cache hit — plan reused without calling the AI
    Cached,
    /// Incremental hit — previous plan fed back to the AI as hints
    Incremental,
}
110
/// Entry point for `sr commit`.
///
/// Flow: discover repo → load project config → detect changes → resolve AI
/// backend → snapshot working tree → obtain a plan (exact cache hit,
/// incremental AI call, or fresh AI call) → validate/merge → store in cache →
/// display → confirm → execute. The [`SnapshotGuard`] restores the working
/// tree on drop unless `success()` is reached, so any error after the
/// snapshot rolls the tree back.
pub async fn run(args: &CommitArgs, backend_config: &BackendConfig) -> Result<()> {
    ui::header("sr commit");

    // Phase 1: Discover repository
    let repo = GitRepo::discover()?;
    ui::phase_ok("Repository found", None);

    // Load project config for commit types and pattern.
    // A missing config file falls back to defaults instead of erroring.
    let config = sr_core::config::ReleaseConfig::find_config(repo.root().as_path())
        .map(|(path, _)| sr_core::config::ReleaseConfig::load(&path))
        .transpose()?
        .unwrap_or_default();
    let type_names: Vec<&str> = config.types.iter().map(|t| t.name.as_str()).collect();
    let system_prompt = build_system_prompt(&config.commit_pattern, &type_names);

    // Phase 2: Check for changes
    let has_changes = if args.staged {
        repo.has_staged_changes()?
    } else {
        repo.has_any_changes()?
    };

    if !has_changes {
        bail!(crate::error::SrAiError::NoChanges);
    }

    // Statuses are only used for display; failure to read them is non-fatal.
    let statuses = repo.file_statuses().unwrap_or_default();
    let file_count = statuses.len();
    ui::phase_ok(
        "Changes detected",
        Some(&format!(
            "{file_count} file{}",
            if file_count == 1 { "" } else { "s" }
        )),
    );

    // Phase 3: Resolve AI backend
    let backend = resolve_backend(backend_config).await?;
    let backend_name = backend.name().to_string();
    let model_name = backend_config
        .model
        .as_deref()
        .unwrap_or("default")
        .to_string();
    ui::phase_ok(
        "Backend resolved",
        Some(&format!("{backend_name} ({model_name})")),
    );

    // Build cache manager (may be None if cache dir unavailable)
    let cache = if args.no_cache {
        None
    } else {
        CacheManager::new(
            repo.root(),
            args.staged,
            args.message.as_deref(),
            &backend_name,
            &model_name,
        )
    };

    // Snapshot the working tree before the agent runs.
    // If anything goes wrong (agent failure, unexpected mutations),
    // the guard restores the working tree from the snapshot on drop.
    let snapshot = SnapshotGuard::new(&repo)?;
    ui::phase_ok("Working tree snapshot saved", None);

    // Phase 4: Generate plan (cache or AI)
    let (mut plan, cache_status) = match cache.as_ref().map(|c| c.lookup()) {
        // Exact hit: reuse the cached plan, no AI call at all.
        Some(CacheLookup::ExactHit(cached_plan)) => {
            ui::phase_ok(
                "Plan loaded",
                Some(&format!("{} commits · cached", cached_plan.commits.len())),
            );
            (cached_plan, CacheStatus::Cached)
        }
        // Incremental hit: call the AI, but feed it the previous plan and a
        // file delta as hints so unchanged groupings are preserved.
        Some(CacheLookup::IncrementalHit {
            previous_plan,
            delta_summary,
        }) => {
            let spinner = ui::spinner(&format!(
                "Analyzing changes with {backend_name} (incremental)..."
            ));
            let (tx, event_handler) = spawn_event_handler(&spinner);

            let user_prompt =
                build_incremental_prompt(args, &repo, &previous_plan, &delta_summary)?;

            let request = AiRequest {
                system_prompt: system_prompt.clone(),
                user_prompt,
                json_schema: Some(COMMIT_SCHEMA.to_string()),
                working_dir: repo.root().to_string_lossy().to_string(),
            };

            let response = backend.request(&request, Some(tx)).await?;
            // Let the event task drain remaining tool-call events before
            // the spinner is finalized; its result is irrelevant.
            let _ = event_handler.await;

            let p: CommitPlan = parse_plan(&response.text)?;

            let detail = format_done_detail(p.commits.len(), "incremental", &response.usage);
            ui::spinner_done(&spinner, Some(&detail));

            (p, CacheStatus::Incremental)
        }
        // Cache miss (or no cache): full AI analysis from scratch.
        _ => {
            let spinner = ui::spinner(&format!("Analyzing changes with {backend_name}..."));
            let (tx, event_handler) = spawn_event_handler(&spinner);

            let user_prompt = build_user_prompt(args, &repo)?;

            let request = AiRequest {
                system_prompt: system_prompt.clone(),
                user_prompt,
                json_schema: Some(COMMIT_SCHEMA.to_string()),
                working_dir: repo.root().to_string_lossy().to_string(),
            };

            let response = backend.request(&request, Some(tx)).await?;
            let _ = event_handler.await;

            let p: CommitPlan = parse_plan(&response.text)?;

            let detail = format_done_detail(p.commits.len(), "", &response.usage);
            ui::spinner_done(&spinner, Some(&detail));

            (p, CacheStatus::None)
        }
    };

    if plan.commits.is_empty() {
        bail!(crate::error::SrAiError::EmptyPlan);
    }

    // Validate: merge commits with shared files
    let pre_validate_count = plan.commits.len();
    plan = validate_plan(plan);
    if plan.commits.len() < pre_validate_count {
        // Merged count = pre - (post - 1): the merged result itself is one of
        // the post-validation commits.
        ui::warn(&format!(
            "Shared files detected — merged {} commits into 1",
            pre_validate_count - plan.commits.len() + 1
        ));
    }

    // Store in cache (before display/execute so dry-runs populate cache too)
    if let Some(cache) = &cache {
        cache.store(&plan, &backend_name, &model_name);
    }

    // Display plan
    let cache_label: Option<&str> = match &cache_status {
        CacheStatus::Cached => Some("cached"),
        CacheStatus::Incremental => Some("incremental"),
        CacheStatus::None => None,
    };
    ui::display_plan(&plan, &statuses, cache_label);

    if args.dry_run {
        ui::info("Dry run — no commits created");
        println!();
        return Ok(());
    }

    // Confirm
    if !args.yes && !ui::confirm("Execute plan? [y/N]")? {
        bail!(crate::error::SrAiError::Cancelled);
    }

    // Execute
    execute_plan(&repo, &plan)?;

    // All commits succeeded — clear the snapshot
    snapshot.success();

    Ok(())
}
288
289fn build_user_prompt(args: &CommitArgs, repo: &GitRepo) -> Result<String> {
290    let git_root = repo.root().to_string_lossy();
291
292    let mut prompt = if args.staged {
293        "Analyze the staged git changes and group them into atomic commits.\n\
294         Use `git diff --cached` and `git diff --cached --stat` to inspect what's staged."
295            .to_string()
296    } else {
297        "Analyze all git changes (staged, unstaged, and untracked) and group them into atomic commits.\n\
298         Use `git diff HEAD`, `git diff --cached`, `git diff`, `git status --porcelain`, and \
299         `git ls-files --others --exclude-standard` to inspect changes."
300            .to_string()
301    };
302
303    prompt.push_str(&format!("\nThe git repository root is: {git_root}"));
304
305    if let Some(msg) = &args.message {
306        prompt.push_str(&format!("\n\nAdditional context from the user:\n{msg}"));
307    }
308
309    Ok(prompt)
310}
311
312fn build_incremental_prompt(
313    args: &CommitArgs,
314    repo: &GitRepo,
315    previous_plan: &CommitPlan,
316    delta_summary: &str,
317) -> Result<String> {
318    let mut prompt = build_user_prompt(args, repo)?;
319
320    let previous_json =
321        serde_json::to_string_pretty(previous_plan).unwrap_or_else(|_| "{}".to_string());
322
323    prompt.push_str(&format!(
324        "\n\n--- INCREMENTAL HINTS ---\n\
325         A previous commit plan exists for a similar set of changes. \
326         Maintain the groupings for unchanged files where possible. \
327         Only re-analyze files that have changed.\n\n\
328         Previous plan:\n```json\n{previous_json}\n```\n\n\
329         File delta:\n{delta_summary}"
330    ));
331
332    Ok(prompt)
333}
334
335/// Validate that no file appears in multiple commits. If duplicates are found,
336/// merge affected commits into one.
337fn validate_plan(plan: CommitPlan) -> CommitPlan {
338    // Count file occurrences
339    let mut file_counts: HashMap<String, usize> = HashMap::new();
340    for commit in &plan.commits {
341        for file in &commit.files {
342            *file_counts.entry(file.clone()).or_default() += 1;
343        }
344    }
345
346    let dupes: Vec<&String> = file_counts
347        .iter()
348        .filter(|(_, count)| **count > 1)
349        .map(|(file, _)| file)
350        .collect();
351
352    if dupes.is_empty() {
353        return plan;
354    }
355
356    // Partition into tainted (has any dupe file) and clean
357    let mut tainted = Vec::new();
358    let mut clean = Vec::new();
359
360    for commit in plan.commits {
361        let is_tainted = commit.files.iter().any(|f| dupes.contains(&f));
362        if is_tainted {
363            tainted.push(commit);
364        } else {
365            clean.push(commit);
366        }
367    }
368
369    // Merge all tainted commits into one
370    let merged_message = tainted
371        .first()
372        .map(|c| c.message.clone())
373        .unwrap_or_default();
374
375    let merged_body = tainted
376        .iter()
377        .filter_map(|c| c.body.as_ref())
378        .filter(|b| !b.is_empty())
379        .cloned()
380        .collect::<Vec<_>>()
381        .join("\n\n");
382
383    let merged_footer = tainted
384        .iter()
385        .filter_map(|c| c.footer.as_ref())
386        .filter(|f| !f.is_empty())
387        .cloned()
388        .collect::<Vec<_>>()
389        .join("\n");
390
391    let mut merged_files: Vec<String> = tainted
392        .iter()
393        .flat_map(|c| c.files.iter().cloned())
394        .collect();
395    merged_files.sort();
396    merged_files.dedup();
397
398    let merged_commit = PlannedCommit {
399        order: Some(1),
400        message: merged_message,
401        body: if merged_body.is_empty() {
402            None
403        } else {
404            Some(merged_body)
405        },
406        footer: if merged_footer.is_empty() {
407            None
408        } else {
409            Some(merged_footer)
410        },
411        files: merged_files,
412    };
413
414    // Re-number: merged first, then clean commits
415    let mut result = vec![merged_commit];
416    for (i, mut commit) in clean.into_iter().enumerate() {
417        commit.order = Some(i as u32 + 2);
418        result.push(commit);
419    }
420
421    CommitPlan { commits: result }
422}
423
424/// Parse a commit plan from JSON text, tolerating duplicate fields.
425fn parse_plan(text: &str) -> Result<CommitPlan> {
426    // Parse to Value first — serde_json::Value keeps the last value for duplicate keys,
427    // while #[derive(Deserialize)] rejects them. This handles AI responses that
428    // occasionally produce duplicate fields when schema is embedded in the prompt.
429    let value: serde_json::Value =
430        serde_json::from_str(text).context("failed to parse JSON from AI response")?;
431    serde_json::from_value(value).context("failed to parse commit plan from AI response")
432}
433
434/// Spawn a background task that renders AI events (tool calls) above a spinner.
435fn spawn_event_handler(
436    spinner: &ProgressBar,
437) -> (mpsc::UnboundedSender<AiEvent>, tokio::task::JoinHandle<()>) {
438    let (tx, mut rx) = mpsc::unbounded_channel();
439    let pb = spinner.clone();
440    let handle = tokio::spawn(async move {
441        while let Some(event) = rx.recv().await {
442            match event {
443                AiEvent::ToolCall { input, .. } => ui::tool_call(&pb, &input),
444            }
445        }
446    });
447    (tx, handle)
448}
449
450/// Format the detail string for spinner_done, including usage if available.
451fn format_done_detail(
452    commit_count: usize,
453    extra: &str,
454    usage: &Option<crate::ai::AiUsage>,
455) -> String {
456    let commits = format!(
457        "{commit_count} commit{}",
458        if commit_count == 1 { "" } else { "s" }
459    );
460    let extra_part = if extra.is_empty() {
461        String::new()
462    } else {
463        format!(" · {extra}")
464    };
465    let usage_part = match usage {
466        Some(u) => {
467            let cost = u
468                .cost_usd
469                .map(|c| format!(" · ${c:.4}"))
470                .unwrap_or_default();
471            format!(
472                " · {} in / {} out{}",
473                ui::format_tokens(u.input_tokens),
474                ui::format_tokens(u.output_tokens),
475                cost
476            )
477        }
478        None => String::new(),
479    };
480    format!("{commits}{extra_part}{usage_part}")
481}
482
483fn execute_plan(repo: &GitRepo, plan: &CommitPlan) -> Result<()> {
484    // Unstage everything first
485    repo.reset_head()?;
486
487    // Collect all files belonging to future commits so we can hide them from
488    // pre-commit hooks. Hooks like `golangci-lint run ./...` or `go test ./...`
489    // scan the entire working tree, not just staged files. If future-commit
490    // files are visible on disk, the project may be in an inconsistent state
491    // (e.g. go.mod updated but new source files not yet committed) and the
492    // hook will fail.
493    let all_plan_files: Vec<Vec<String>> = plan.commits.iter().map(|c| c.files.clone()).collect();
494
495    let total = plan.commits.len();
496    let mut created: Vec<(String, String)> = Vec::new();
497
498    // Build a temp dir outside the repo to stash future-commit files
499    let stash_dir = tempfile::tempdir().context("failed to create temp dir for commit stash")?;
500
501    for (i, commit) in plan.commits.iter().enumerate() {
502        ui::commit_start(i + 1, total, &commit.message);
503
504        // Stage files for this commit
505        for file in &commit.files {
506            let ok = repo.stage_file(file)?;
507            ui::file_staged(file, ok);
508        }
509
510        // Hide files belonging to future commits so hooks see a consistent tree
511        let future_files: Vec<&str> = all_plan_files[i + 1..]
512            .iter()
513            .flatten()
514            .map(|s| s.as_str())
515            .collect();
516        let hidden = hide_files(repo.root(), &future_files, stash_dir.path());
517
518        // Build full commit message
519        let mut full_message = commit.message.clone();
520        if let Some(body) = &commit.body
521            && !body.is_empty()
522        {
523            full_message.push_str("\n\n");
524            full_message.push_str(body);
525        }
526        if let Some(footer) = &commit.footer
527            && !footer.is_empty()
528        {
529            full_message.push_str("\n\n");
530            full_message.push_str(footer);
531        }
532
533        // Create commit (only if there are staged files)
534        let result = if repo.has_staged_after_add()? {
535            repo.commit(&full_message)
536        } else {
537            ui::commit_skipped();
538            Ok(())
539        };
540
541        // Restore hidden files before checking for errors
542        restore_files(repo.root(), &hidden, stash_dir.path());
543
544        result?;
545
546        let sha = repo.head_short().unwrap_or_else(|_| "???????".to_string());
547        ui::commit_created(&sha);
548        created.push((sha, commit.message.clone()));
549    }
550
551    ui::summary(&created);
552
553    Ok(())
554}
555
556/// Move files out of the working tree into a temp directory.
557/// Returns the list of files that were actually moved.
/// Move files out of the working tree into a temp directory, mirroring their
/// relative paths. Returns the list of files that were actually moved.
///
/// Fix: `fs::rename` fails (EXDEV) when `stash_dir` — often a tmpfs `/tmp` —
/// lives on a different filesystem than the repo, which previously meant the
/// files were silently NOT hidden and hooks could still see them. Fall back
/// to copy + remove so hiding works across filesystem boundaries. Files that
/// still cannot be moved are left in place (best-effort, as before).
fn hide_files(
    repo_root: &std::path::Path,
    files: &[&str],
    stash_dir: &std::path::Path,
) -> Vec<String> {
    let mut hidden = Vec::new();
    for &file in files {
        let src = repo_root.join(file);
        if !src.exists() {
            continue;
        }
        let dest = stash_dir.join(file);
        if let Some(parent) = dest.parent() {
            std::fs::create_dir_all(parent).ok();
        }
        // Fast path: same-filesystem rename. Fallback: copy then delete the
        // original. If the copy succeeds but the delete fails, the file stays
        // visible and is NOT recorded as hidden (restore would be a no-op).
        let moved = std::fs::rename(&src, &dest).is_ok()
            || (std::fs::copy(&src, &dest).is_ok() && std::fs::remove_file(&src).is_ok());
        if moved {
            hidden.push(file.to_string());
        }
    }
    hidden
}
579
580/// Move files back from the temp directory into the working tree.
/// Move files back from the temp directory into the working tree.
///
/// Fix: a cross-filesystem `rename` failure previously stranded the file in
/// the stash directory — which is deleted when the `TempDir` drops, losing
/// the user's file. Fall back to copy + remove so restoration succeeds across
/// filesystem boundaries; the copy is only deleted from the stash after it
/// lands in the working tree. Still best-effort: remaining errors are ignored.
fn restore_files(repo_root: &std::path::Path, files: &[String], stash_dir: &std::path::Path) {
    for file in files {
        let src = stash_dir.join(file);
        let dest = repo_root.join(file);
        if src.exists() {
            if let Some(parent) = dest.parent() {
                std::fs::create_dir_all(parent).ok();
            }
            if std::fs::rename(&src, &dest).is_err() && std::fs::copy(&src, &dest).is_ok() {
                std::fs::remove_file(&src).ok();
            }
        }
    }
}
593
#[cfg(test)]
mod tests {
    use super::*;

    // A plan with disjoint file sets must pass through validation unchanged.
    #[test]
    fn validate_plan_no_dupes() {
        let plan = CommitPlan {
            commits: vec![
                PlannedCommit {
                    order: Some(1),
                    message: "feat: add foo".into(),
                    body: Some("reason".into()),
                    footer: None,
                    files: vec!["a.rs".into()],
                },
                PlannedCommit {
                    order: Some(2),
                    message: "fix: fix bar".into(),
                    body: Some("reason".into()),
                    footer: None,
                    files: vec!["b.rs".into()],
                },
            ],
        };

        let result = validate_plan(plan);
        assert_eq!(result.commits.len(), 2);
    }

    // Commits sharing a file are merged into one (keeping the first header
    // and the union of files), while unaffected commits survive and are
    // renumbered after the merged commit.
    #[test]
    fn validate_plan_merges_dupes() {
        let plan = CommitPlan {
            commits: vec![
                PlannedCommit {
                    order: Some(1),
                    message: "feat: add foo".into(),
                    body: Some("reason 1".into()),
                    footer: None,
                    files: vec!["shared.rs".into(), "a.rs".into()],
                },
                PlannedCommit {
                    order: Some(2),
                    message: "fix: fix bar".into(),
                    body: Some("reason 2".into()),
                    footer: None,
                    files: vec!["shared.rs".into(), "b.rs".into()],
                },
                PlannedCommit {
                    order: Some(3),
                    message: "docs: update readme".into(),
                    body: Some("docs".into()),
                    footer: None,
                    files: vec!["README.md".into()],
                },
            ],
        };

        let result = validate_plan(plan);
        // Two tainted merged into one + one clean = 2
        assert_eq!(result.commits.len(), 2);
        assert_eq!(result.commits[0].message, "feat: add foo");
        assert!(result.commits[0].files.contains(&"shared.rs".to_string()));
        assert!(result.commits[0].files.contains(&"a.rs".to_string()));
        assert!(result.commits[0].files.contains(&"b.rs".to_string()));
        assert_eq!(result.commits[1].message, "docs: update readme");
        assert_eq!(result.commits[1].order, Some(2));
    }
}
661}