1use crate::ai::{AiEvent, AiRequest, BackendConfig, resolve_backend};
2use crate::cache::{CacheLookup, CacheManager};
3use crate::git::{GitRepo, SnapshotGuard};
4use crate::ui;
5use anyhow::{Context, Result, bail};
6use indicatif::ProgressBar;
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use tokio::sync::mpsc;
11
/// A commit plan: the list of commits the AI backend proposes for the current changes.
///
/// Deserialized from the backend's JSON response (whose shape is requested via
/// `COMMIT_SCHEMA`) and serialized again when a previous plan is embedded in the
/// incremental prompt. `execute_plan` creates the commits in vector order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommitPlan {
    /// Commits to create, in execution order.
    pub commits: Vec<PlannedCommit>,
}
16
/// One commit within a [`CommitPlan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlannedCommit {
    /// Position suggested by the AI. Informational only: `execute_plan` follows the
    /// order of `CommitPlan::commits`, and `validate_plan` overwrites this value
    /// when it merges commits.
    pub order: Option<u32>,
    /// Commit header (`type(scope): subject`); checked against the configured
    /// commit pattern by `validate_messages` before execution.
    pub message: String,
    /// Optional commit body; appended after a blank line when non-empty.
    pub body: Option<String>,
    /// Optional commit footer; appended after a blank line when non-empty.
    pub footer: Option<String>,
    /// Paths to stage for this commit. Whole files are staged, so a path should
    /// appear in only one commit of the plan.
    pub files: Vec<String>,
}
25
26#[derive(Debug, clap::Args)]
27pub struct CommitArgs {
28 #[arg(short, long)]
30 pub staged: bool,
31
32 #[arg(short = 'M', long)]
34 pub message: Option<String>,
35
36 #[arg(short = 'n', long)]
38 pub dry_run: bool,
39
40 #[arg(short, long)]
42 pub yes: bool,
43
44 #[arg(long)]
46 pub no_cache: bool,
47}
48
/// JSON schema sent with every planning request (`AiRequest::json_schema`) to
/// constrain the backend's answer to the shape of [`CommitPlan`] /
/// [`PlannedCommit`]. Within a commit, `footer` is the only property not listed
/// as required.
const COMMIT_SCHEMA: &str = r#"{
  "type": "object",
  "properties": {
    "commits": {
      "type": "array",
      "items": {
        "type": "object",
        "properties": {
          "order": { "type": "integer" },
          "message": { "type": "string", "description": "Header: type(scope): subject — imperative, lowercase, no period, max 72 chars" },
          "body": { "type": "string", "description": "Body: explain WHY the change was made, wrap at 72 chars" },
          "footer": { "type": "string", "description": "Footer: BREAKING CHANGE notes, Closes/Fixes/Refs #issue, etc." },
          "files": { "type": "array", "items": { "type": "string" } }
        },
        "required": ["order", "message", "body", "files"]
      }
    }
  },
  "required": ["commits"]
}"#;
69
/// Render the system prompt that instructs the AI how to split changes into commits.
///
/// * `commit_pattern` – regex every commit header must match; quoted verbatim in the prompt.
/// * `type_names` – the only commit types the AI may use, listed comma-separated.
///
/// Returns the complete prompt text. The file-per-commit rules in it mirror how
/// `execute_plan` stages whole files.
fn build_system_prompt(commit_pattern: &str, type_names: &[&str]) -> String {
    format!(
        r#"You are an expert at analyzing git diffs and creating atomic, well-organized commits following the Angular Conventional Commits standard.

HEADER ("message" field):
- Must match this regex: {commit_pattern}
- Format: type(scope): subject
- Valid types ONLY: {types_list}
- NEVER invent types. Words like db, auth, api, etc. are scopes, not types. Use the semantically correct type for the change (e.g. feat(db): add user cache migration, fix(auth): resolve token expiry)
- scope is optional but recommended when applicable
- subject: imperative mood, lowercase first letter, no period at end, max 72 chars

BODY ("body" field — required):
- Explain WHY the change was made, not what changed (the diff shows that)
- Use imperative tense ("add" not "added")
- Wrap at 72 characters

FOOTER ("footer" field — optional):
- BREAKING CHANGE: description of what breaks and migration path
- Closes #N, Fixes #N, Refs #N for issue references
- Only include when relevant

COMMIT ORGANIZATION:
- Each commit must be atomic: one logical change per commit
- Every changed file must appear in exactly one commit
- CRITICAL: A file must NEVER appear in more than one commit. The execution engine stages entire files, not individual hunks. Splitting one file across commits will fail.
- If one file contains multiple logical changes, place it in the most fitting commit and note the secondary changes in that commit's body.
- Order: infrastructure/config -> core library -> features -> tests -> docs
- File paths must be relative to the repository root and match exactly as git reports them"#,
        types_list = type_names.join(", "),
    )
}
102
/// Where the commit plan shown to the user came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CacheStatus {
    /// Freshly generated by the AI backend; no cache involvement.
    None,
    /// Reused unchanged from an exact cache hit.
    Cached,
    /// Regenerated by the AI using a previous cached plan as a hint.
    Incremental,
}
111
112pub async fn run(args: &CommitArgs, backend_config: &BackendConfig) -> Result<()> {
113 ui::header("sr commit");
114
115 let repo = GitRepo::discover()?;
117 ui::phase_ok("Repository found", None);
118
119 let config = sr_core::config::ReleaseConfig::find_config(repo.root().as_path())
121 .map(|(path, _)| sr_core::config::ReleaseConfig::load(&path))
122 .transpose()?
123 .unwrap_or_default();
124 let type_names: Vec<&str> = config.types.iter().map(|t| t.name.as_str()).collect();
125 let system_prompt = build_system_prompt(&config.commit_pattern, &type_names);
126
127 let has_changes = if args.staged {
129 repo.has_staged_changes()?
130 } else {
131 repo.has_any_changes()?
132 };
133
134 if !has_changes {
135 bail!(crate::error::SrAiError::NoChanges);
136 }
137
138 let statuses = repo.file_statuses().unwrap_or_default();
139 let file_count = statuses.len();
140 ui::phase_ok(
141 "Changes detected",
142 Some(&format!(
143 "{file_count} file{}",
144 if file_count == 1 { "" } else { "s" }
145 )),
146 );
147
148 let backend = resolve_backend(backend_config).await?;
150 let backend_name = backend.name().to_string();
151 let model_name = backend_config
152 .model
153 .as_deref()
154 .unwrap_or("default")
155 .to_string();
156 ui::phase_ok(
157 "Backend resolved",
158 Some(&format!("{backend_name} ({model_name})")),
159 );
160
161 let cache = if args.no_cache {
163 None
164 } else {
165 CacheManager::new(
166 repo.root(),
167 args.staged,
168 args.message.as_deref(),
169 &backend_name,
170 &model_name,
171 )
172 };
173
174 let snapshot = SnapshotGuard::new(&repo)?;
178 ui::phase_ok("Working tree snapshot saved", None);
179
180 let (mut plan, cache_status) = match cache.as_ref().map(|c| c.lookup()) {
182 Some(CacheLookup::ExactHit(cached_plan)) => {
183 ui::phase_ok(
184 "Plan loaded",
185 Some(&format!("{} commits · cached", cached_plan.commits.len())),
186 );
187 (cached_plan, CacheStatus::Cached)
188 }
189 Some(CacheLookup::IncrementalHit {
190 previous_plan,
191 delta_summary,
192 }) => {
193 let spinner = ui::spinner(&format!(
194 "Analyzing changes with {backend_name} (incremental)..."
195 ));
196 let (tx, event_handler) = spawn_event_handler(&spinner);
197
198 let user_prompt =
199 build_incremental_prompt(args, &repo, &previous_plan, &delta_summary)?;
200
201 let request = AiRequest {
202 system_prompt: system_prompt.clone(),
203 user_prompt,
204 json_schema: Some(COMMIT_SCHEMA.to_string()),
205 working_dir: repo.root().to_string_lossy().to_string(),
206 };
207
208 let response = backend.request(&request, Some(tx)).await?;
209 let _ = event_handler.await;
210
211 let p: CommitPlan = parse_plan(&response.text)?;
212
213 let detail = format_done_detail(p.commits.len(), "incremental", &response.usage);
214 ui::spinner_done(&spinner, Some(&detail));
215
216 (p, CacheStatus::Incremental)
217 }
218 _ => {
219 let spinner = ui::spinner(&format!("Analyzing changes with {backend_name}..."));
220 let (tx, event_handler) = spawn_event_handler(&spinner);
221
222 let user_prompt = build_user_prompt(args, &repo)?;
223
224 let request = AiRequest {
225 system_prompt: system_prompt.clone(),
226 user_prompt,
227 json_schema: Some(COMMIT_SCHEMA.to_string()),
228 working_dir: repo.root().to_string_lossy().to_string(),
229 };
230
231 let response = backend.request(&request, Some(tx)).await?;
232 let _ = event_handler.await;
233
234 let p: CommitPlan = parse_plan(&response.text)?;
235
236 let detail = format_done_detail(p.commits.len(), "", &response.usage);
237 ui::spinner_done(&spinner, Some(&detail));
238
239 (p, CacheStatus::None)
240 }
241 };
242
243 if plan.commits.is_empty() {
244 bail!(crate::error::SrAiError::EmptyPlan);
245 }
246
247 let pre_validate_count = plan.commits.len();
249 plan = validate_plan(plan);
250 if plan.commits.len() < pre_validate_count {
251 ui::warn(&format!(
252 "Shared files detected — merged {} commits into 1",
253 pre_validate_count - plan.commits.len() + 1
254 ));
255 }
256
257 if let Some(cache) = &cache {
259 cache.store(&plan, &backend_name, &model_name);
260 }
261
262 let cache_label: Option<&str> = match &cache_status {
264 CacheStatus::Cached => Some("cached"),
265 CacheStatus::Incremental => Some("incremental"),
266 CacheStatus::None => None,
267 };
268 ui::display_plan(&plan, &statuses, cache_label);
269
270 if args.dry_run {
271 ui::info("Dry run — no commits created");
272 println!();
273 snapshot.success();
274 return Ok(());
275 }
276
277 if !args.yes && !ui::confirm("Execute plan? [y/N]")? {
279 bail!(crate::error::SrAiError::Cancelled);
280 }
281
282 let invalid = validate_messages(&plan, &config.commit_pattern);
284 if !invalid.is_empty() {
285 ui::invalid_messages(&invalid);
286 if !args.yes && !ui::confirm("Continue anyway? Invalid commits will likely fail. [y/N]")? {
287 bail!(crate::error::SrAiError::Cancelled);
288 }
289 }
290
291 execute_plan(&repo, &plan)?;
293
294 snapshot.success();
296
297 Ok(())
298}
299
300fn build_user_prompt(args: &CommitArgs, repo: &GitRepo) -> Result<String> {
301 let git_root = repo.root().to_string_lossy();
302
303 let mut prompt = if args.staged {
304 "Analyze the staged git changes and group them into atomic commits.\n\
305 Use `git diff --cached` and `git diff --cached --stat` to inspect what's staged."
306 .to_string()
307 } else {
308 "Analyze all git changes (staged, unstaged, and untracked) and group them into atomic commits.\n\
309 Use `git diff HEAD`, `git diff --cached`, `git diff`, `git status --porcelain`, and \
310 `git ls-files --others --exclude-standard` to inspect changes."
311 .to_string()
312 };
313
314 prompt.push_str(&format!("\nThe git repository root is: {git_root}"));
315
316 if let Some(msg) = &args.message {
317 prompt.push_str(&format!("\n\nAdditional context from the user:\n{msg}"));
318 }
319
320 Ok(prompt)
321}
322
323fn build_incremental_prompt(
324 args: &CommitArgs,
325 repo: &GitRepo,
326 previous_plan: &CommitPlan,
327 delta_summary: &str,
328) -> Result<String> {
329 let mut prompt = build_user_prompt(args, repo)?;
330
331 let previous_json =
332 serde_json::to_string_pretty(previous_plan).unwrap_or_else(|_| "{}".to_string());
333
334 prompt.push_str(&format!(
335 "\n\n--- INCREMENTAL HINTS ---\n\
336 A previous commit plan exists for a similar set of changes. \
337 Maintain the groupings for unchanged files where possible. \
338 Only re-analyze files that have changed.\n\n\
339 Previous plan:\n```json\n{previous_json}\n```\n\n\
340 File delta:\n{delta_summary}"
341 ));
342
343 Ok(prompt)
344}
345
346fn validate_plan(plan: CommitPlan) -> CommitPlan {
349 let mut file_counts: HashMap<String, usize> = HashMap::new();
351 for commit in &plan.commits {
352 for file in &commit.files {
353 *file_counts.entry(file.clone()).or_default() += 1;
354 }
355 }
356
357 let dupes: Vec<&String> = file_counts
358 .iter()
359 .filter(|(_, count)| **count > 1)
360 .map(|(file, _)| file)
361 .collect();
362
363 if dupes.is_empty() {
364 return plan;
365 }
366
367 let mut tainted = Vec::new();
369 let mut clean = Vec::new();
370
371 for commit in plan.commits {
372 let is_tainted = commit.files.iter().any(|f| dupes.contains(&f));
373 if is_tainted {
374 tainted.push(commit);
375 } else {
376 clean.push(commit);
377 }
378 }
379
380 let merged_message = tainted
382 .first()
383 .map(|c| c.message.clone())
384 .unwrap_or_default();
385
386 let merged_body = tainted
387 .iter()
388 .filter_map(|c| c.body.as_ref())
389 .filter(|b| !b.is_empty())
390 .cloned()
391 .collect::<Vec<_>>()
392 .join("\n\n");
393
394 let merged_footer = tainted
395 .iter()
396 .filter_map(|c| c.footer.as_ref())
397 .filter(|f| !f.is_empty())
398 .cloned()
399 .collect::<Vec<_>>()
400 .join("\n");
401
402 let mut merged_files: Vec<String> = tainted
403 .iter()
404 .flat_map(|c| c.files.iter().cloned())
405 .collect();
406 merged_files.sort();
407 merged_files.dedup();
408
409 let merged_commit = PlannedCommit {
410 order: Some(1),
411 message: merged_message,
412 body: if merged_body.is_empty() {
413 None
414 } else {
415 Some(merged_body)
416 },
417 footer: if merged_footer.is_empty() {
418 None
419 } else {
420 Some(merged_footer)
421 },
422 files: merged_files,
423 };
424
425 let mut result = vec![merged_commit];
427 for (i, mut commit) in clean.into_iter().enumerate() {
428 commit.order = Some(i as u32 + 2);
429 result.push(commit);
430 }
431
432 CommitPlan { commits: result }
433}
434
435fn parse_plan(text: &str) -> Result<CommitPlan> {
437 let value: serde_json::Value =
441 serde_json::from_str(text).context("failed to parse JSON from AI response")?;
442 serde_json::from_value(value).context("failed to parse commit plan from AI response")
443}
444
445fn spawn_event_handler(
447 spinner: &ProgressBar,
448) -> (mpsc::UnboundedSender<AiEvent>, tokio::task::JoinHandle<()>) {
449 let (tx, mut rx) = mpsc::unbounded_channel();
450 let pb = spinner.clone();
451 let handle = tokio::spawn(async move {
452 while let Some(event) = rx.recv().await {
453 match event {
454 AiEvent::ToolCall { input, .. } => ui::tool_call(&pb, &input),
455 }
456 }
457 });
458 (tx, handle)
459}
460
461fn format_done_detail(
463 commit_count: usize,
464 extra: &str,
465 usage: &Option<crate::ai::AiUsage>,
466) -> String {
467 let commits = format!(
468 "{commit_count} commit{}",
469 if commit_count == 1 { "" } else { "s" }
470 );
471 let extra_part = if extra.is_empty() {
472 String::new()
473 } else {
474 format!(" · {extra}")
475 };
476 let usage_part = match usage {
477 Some(u) => {
478 let cost = u
479 .cost_usd
480 .map(|c| format!(" · ${c:.4}"))
481 .unwrap_or_default();
482 format!(
483 " · {} in / {} out{}",
484 ui::format_tokens(u.input_tokens),
485 ui::format_tokens(u.output_tokens),
486 cost
487 )
488 }
489 None => String::new(),
490 };
491 format!("{commits}{extra_part}{usage_part}")
492}
493
494fn validate_messages(plan: &CommitPlan, commit_pattern: &str) -> Vec<(usize, String, String)> {
497 let re = match Regex::new(commit_pattern) {
498 Ok(re) => re,
499 Err(e) => {
500 return plan
502 .commits
503 .iter()
504 .enumerate()
505 .map(|(i, c)| (i + 1, c.message.clone(), format!("invalid pattern: {e}")))
506 .collect();
507 }
508 };
509
510 plan.commits
511 .iter()
512 .enumerate()
513 .filter(|(_, c)| !re.is_match(&c.message))
514 .map(|(i, c)| {
515 (
516 i + 1,
517 c.message.clone(),
518 format!("does not match pattern: {commit_pattern}"),
519 )
520 })
521 .collect()
522}
523
524fn execute_plan(repo: &GitRepo, plan: &CommitPlan) -> Result<()> {
525 repo.reset_head()?;
527
528 let total = plan.commits.len();
529 let mut created: Vec<(String, String)> = Vec::new();
530 let mut failed: Vec<(usize, String, String)> = Vec::new();
531
532 for (i, commit) in plan.commits.iter().enumerate() {
533 ui::commit_start(i + 1, total, &commit.message);
534
535 for file in &commit.files {
537 let ok = repo.stage_file(file)?;
538 ui::file_staged(file, ok);
539 }
540
541 let mut full_message = commit.message.clone();
543 if let Some(body) = &commit.body
544 && !body.is_empty()
545 {
546 full_message.push_str("\n\n");
547 full_message.push_str(body);
548 }
549 if let Some(footer) = &commit.footer
550 && !footer.is_empty()
551 {
552 full_message.push_str("\n\n");
553 full_message.push_str(footer);
554 }
555
556 if repo.has_staged_after_add()? {
558 match repo.commit(&full_message) {
559 Ok(()) => {
560 let sha = repo.head_short().unwrap_or_else(|_| "???????".to_string());
561 ui::commit_created(&sha);
562 created.push((sha, commit.message.clone()));
563 }
564 Err(e) => {
565 ui::commit_failed(&format!("{e:#}"));
566 failed.push((i + 1, commit.message.clone(), format!("{e:#}")));
567 repo.reset_head()?;
569 }
570 }
571 } else {
572 ui::commit_skipped();
573 }
574 }
575
576 ui::summary(&created);
577
578 if !failed.is_empty() {
579 ui::failed_commits(&failed);
580 if created.is_empty() {
581 bail!("all {} commits failed", failed.len());
582 }
583 }
584
585 Ok(())
586}
587
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a [`PlannedCommit`] with no footer.
    fn planned(order: u32, message: &str, body: Option<&str>, files: &[&str]) -> PlannedCommit {
        PlannedCommit {
            order: Some(order),
            message: message.to_string(),
            body: body.map(str::to_string),
            footer: None,
            files: files.iter().map(|f| f.to_string()).collect(),
        }
    }

    /// Wrap commits into a [`CommitPlan`].
    fn plan_of(commits: Vec<PlannedCommit>) -> CommitPlan {
        CommitPlan { commits }
    }

    #[test]
    fn validate_plan_no_dupes() {
        let plan = plan_of(vec![
            planned(1, "feat: add foo", Some("reason"), &["a.rs"]),
            planned(2, "fix: fix bar", Some("reason"), &["b.rs"]),
        ]);

        assert_eq!(validate_plan(plan).commits.len(), 2);
    }

    #[test]
    fn validate_plan_merges_dupes() {
        let plan = plan_of(vec![
            planned(1, "feat: add foo", Some("reason 1"), &["shared.rs", "a.rs"]),
            planned(2, "fix: fix bar", Some("reason 2"), &["shared.rs", "b.rs"]),
            planned(3, "docs: update readme", Some("docs"), &["README.md"]),
        ]);

        let result = validate_plan(plan);
        assert_eq!(result.commits.len(), 2);

        let merged = &result.commits[0];
        assert_eq!(merged.message, "feat: add foo");
        for expected in ["shared.rs", "a.rs", "b.rs"] {
            assert!(merged.files.contains(&expected.to_string()));
        }

        let docs = &result.commits[1];
        assert_eq!(docs.message, "docs: update readme");
        assert_eq!(docs.order, Some(2));
    }

    #[test]
    fn validate_messages_all_valid() {
        let plan = plan_of(vec![
            planned(1, "feat: add foo", None, &[]),
            planned(2, "fix(core): null check", None, &[]),
        ]);

        let invalid = validate_messages(&plan, sr_core::commit::DEFAULT_COMMIT_PATTERN);
        assert!(invalid.is_empty());
    }

    #[test]
    fn validate_messages_catches_invalid() {
        let plan = plan_of(vec![
            planned(1, "feat: add foo", None, &[]),
            planned(2, "not a conventional commit", None, &[]),
            planned(3, "fix: valid one", None, &[]),
        ]);

        let invalid = validate_messages(&plan, sr_core::commit::DEFAULT_COMMIT_PATTERN);
        assert_eq!(invalid.len(), 1);
        assert_eq!(invalid[0].0, 2);
        assert_eq!(invalid[0].1, "not a conventional commit");
    }

    #[test]
    fn validate_messages_invalid_pattern() {
        let plan = plan_of(vec![planned(1, "feat: add foo", None, &[])]);

        let invalid = validate_messages(&plan, "[invalid regex");
        assert_eq!(invalid.len(), 1);
        assert!(invalid[0].2.contains("invalid pattern"));
    }
}