1use std::path::Path;
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6 api::{AnalysisContext, generate_conventional_analysis},
7 config::CommitConfig,
8 diff::smart_truncate_diff,
9 error::{CommitGenError, Result},
10 git::{get_git_diff, get_git_stat, get_head_hash, git_commit},
11 normalization::{format_commit_message, post_process_commit_message},
12 patch::{reset_staging, stage_group_changes},
13 style,
14 tokens::create_token_counter,
15 types::{Args, ChangeGroup, CommitType, ComposeAnalysis, ConventionalCommit, Mode},
16 validation::validate_commit_message,
17};
18
/// One chat message in the request body (this module sends a single `"user"` message).
#[derive(Debug, Serialize)]
struct Message {
    role: String,
    content: String,
}

/// JSON-schema description of a tool function's arguments.
/// `param_type` is serialized under the key `type`.
#[derive(Debug, Serialize, Deserialize)]
struct FunctionParameters {
    #[serde(rename = "type")]
    param_type: String,
    // Raw JSON-schema `properties` object, built with `serde_json::json!`.
    properties: serde_json::Value,
    required: Vec<String>,
}

/// A callable function advertised to the model as a tool.
#[derive(Debug, Serialize, Deserialize)]
struct Function {
    name: String,
    description: String,
    parameters: FunctionParameters,
}

/// Tool wrapper; `tool_type` is serialized as `type` (this module always uses `"function"`).
#[derive(Debug, Serialize, Deserialize)]
struct Tool {
    #[serde(rename = "type")]
    tool_type: String,
    function: Function,
}

/// Request body POSTed to `{api_base_url}/chat/completions`.
/// Field order here is the key order in the serialized JSON.
#[derive(Debug, Serialize)]
struct ApiRequest {
    model: String,
    max_tokens: u32,
    temperature: f32,
    tools: Vec<Tool>,
    // Omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<serde_json::Value>,
    messages: Vec<Message>,
}
57
58#[derive(Debug, Deserialize, Serialize)]
59struct ToolCall {
60 function: FunctionCall,
61}
62
63#[derive(Debug, Deserialize, Serialize)]
64struct FunctionCall {
65 name: String,
66 arguments: String,
67}
68
69#[derive(Debug, Deserialize, Serialize)]
70struct Choice {
71 message: ResponseMessage,
72}
73
74#[derive(Debug, Deserialize, Serialize)]
75struct ResponseMessage {
76 #[serde(default)]
77 tool_calls: Vec<ToolCall>,
78 #[serde(default)]
79 content: Option<String>,
80 #[serde(default)]
81 function_call: Option<FunctionCall>,
82}
83
84#[derive(Debug, Deserialize, Serialize)]
85struct ApiResponse {
86 choices: Vec<Choice>,
87}
88
/// User prompt for the compose analysis. The placeholders `{MAX_COMMITS}`, `{STAT}` and
/// `{DIFF}` are substituted by `analyze_for_compose` before sending.
const COMPOSE_PROMPT: &str = r#"Split this git diff into 1-{MAX_COMMITS} logical, atomic commit groups.

## Git Stat
{STAT}

## Git Diff
{DIFF}

## Rules (CRITICAL)
1. **EXHAUSTIVENESS**: You MUST account for 100% of changes. Every file and hunk in the diff above must appear in exactly one group.
2. **Atomicity**: Each group represents ONE logical change (feat/fix/refactor/etc.) that leaves the codebase working.
3. **Prefer fewer groups**: Default to 1-3 commits. Only split when changes are truly independent/separable.
4. **Group related**: Implementation + tests go together. Refactoring + usage updates go together.
5. **Dependencies**: Use 0-based group indices. If the second group (index 1) depends on the first group (index 0), set its dependencies: [0].
6. **Hunk selection** (IMPORTANT - Use line numbers, NOT hunk headers):
   - If entire file → hunks: ["ALL"]
   - If partial → specify line ranges: hunks: [{start: 10, end: 25}, {start: 50, end: 60}]
   - Line numbers are 1-indexed from the ORIGINAL file (look at "-" lines in diff)
   - You can specify multiple ranges for discontinuous changes in one file

## Good Example (2 independent changes)
groups: [
  {
    changes: [
      {path: "src/api.rs", hunks: ["ALL"]},
      {path: "tests/api_test.rs", hunks: [{start: 15, end: 23}]}
    ],
    type: "feat", scope: "api", rationale: "add user endpoint with test",
    dependencies: []
  },
  {
    changes: [
      {path: "src/utils.rs", hunks: [{start: 42, end: 48}, {start: 100, end: 105}]}
    ],
    type: "fix", scope: "utils", rationale: "fix string parsing bug in two locations",
    dependencies: []
  }
]

## Bad Example (over-splitting)
❌ DON'T create 6 commits for: function rename + call sites. That's ONE refactor group.
❌ DON'T split tests from implementation unless they test something from a prior group.

## Bad Example (incomplete)
❌ DON'T forget files. If diff shows 5 files, groups must cover all 5.

Return groups in dependency order."#;
136
/// Wrapper matching the `{"groups": [...]}` object shape that the
/// `create_compose_analysis` tool schema declares (its only required property is `groups`).
#[derive(Deserialize)]
struct ComposeResult {
    groups: Vec<ChangeGroup>,
}
141
142fn parse_compose_groups_from_content(content: &str) -> Result<Vec<ChangeGroup>> {
143 fn try_parse(input: &str) -> Option<Vec<ChangeGroup>> {
144 let trimmed = input.trim();
145 if trimmed.is_empty() {
146 return None;
147 }
148
149 serde_json::from_str::<ComposeResult>(trimmed)
150 .map(|r| r.groups)
151 .ok()
152 }
153
154 let trimmed = content.trim();
155 if trimmed.is_empty() {
156 return Err(CommitGenError::Other(
157 "Model returned an empty compose analysis response".to_string(),
158 ));
159 }
160
161 if let Some(groups) = try_parse(trimmed) {
162 return Ok(groups);
163 }
164
165 if let (Some(start), Some(end)) = (trimmed.find('{'), trimmed.rfind('}'))
166 && end >= start
167 {
168 let candidate = &trimmed[start..=end];
169 if let Some(groups) = try_parse(candidate) {
170 return Ok(groups);
171 }
172 }
173
174 let segments: Vec<&str> = trimmed.split("```").collect();
175 for (idx, segment) in segments.iter().enumerate() {
176 if idx % 2 == 1 {
177 let block = segment.trim();
178 let mut lines = block.lines();
179 let first_line = lines.next().unwrap_or_default();
180
181 let rest_owned: String;
182 let json_candidate = if first_line.trim_start().starts_with('{') {
183 block
184 } else {
185 let rest: String = lines.collect::<Vec<_>>().join("\n");
186 let trimmed_rest = rest.trim();
187 if trimmed_rest.is_empty() {
188 block
189 } else {
190 rest_owned = trimmed_rest.to_string();
191 &rest_owned
192 }
193 };
194
195 if let Some(groups) = try_parse(json_candidate) {
196 return Ok(groups);
197 }
198 }
199 }
200
201 Err(CommitGenError::Other("Failed to parse compose analysis from model response".to_string()))
202}
203
204fn parse_compose_groups_from_json(
205 raw: &str,
206) -> std::result::Result<Vec<ChangeGroup>, serde_json::Error> {
207 let trimmed = raw.trim();
208 if trimmed.starts_with('[') {
209 serde_json::from_str::<Vec<ChangeGroup>>(trimmed)
210 } else {
211 serde_json::from_str::<ComposeResult>(trimmed).map(|r| r.groups)
212 }
213}
214
215fn debug_failed_payload(source: &str, payload: &str, err: &serde_json::Error) {
216 let preview = payload.trim();
217 let preview = if preview.len() > 2000 {
218 format!("{}…", &preview[..2000])
219 } else {
220 preview.to_string()
221 };
222 eprintln!("Compose debug: failed to parse {source} payload ({err}); preview: {preview}");
223}
224
225fn group_affects_only_dependency_files(group: &ChangeGroup) -> bool {
226 group
227 .changes
228 .iter()
229 .all(|change| is_dependency_manifest(&change.path))
230}
231
/// Reports whether `path` names a dependency manifest or lockfile.
///
/// Only the final path component is examined. It matches if it is one of a fixed list of
/// well-known manifest names (exact, case-sensitive), or if its extension is `lock` / `lockb`
/// (ASCII case-insensitive). Paths with no file name, or a non-UTF-8 one, never match.
fn is_dependency_manifest(path: &str) -> bool {
    const KNOWN_MANIFESTS: &[&str] = &[
        "Cargo.toml",
        "Cargo.lock",
        "package.json",
        "package-lock.json",
        "pnpm-lock.yaml",
        "yarn.lock",
        "bun.lock",
        "bun.lockb",
        "go.mod",
        "go.sum",
        "requirements.txt",
        "Pipfile",
        "Pipfile.lock",
        "pyproject.toml",
        "Gemfile",
        "Gemfile.lock",
        "composer.json",
        "composer.lock",
        "build.gradle",
        "build.gradle.kts",
        "gradle.properties",
        "pom.xml",
    ];
    const LOCK_EXTENSIONS: [&str; 2] = ["lock", "lockb"];

    match Path::new(path).file_name().and_then(|name| name.to_str()) {
        None => false,
        Some(name) if KNOWN_MANIFESTS.iter().any(|known| *known == name) => true,
        Some(name) => match Path::new(name).extension() {
            Some(ext) => LOCK_EXTENSIONS
                .iter()
                .any(|lock_ext| ext.eq_ignore_ascii_case(*lock_ext)),
            None => false,
        },
    }
}
271
272pub async fn analyze_for_compose(
274 diff: &str,
275 stat: &str,
276 config: &CommitConfig,
277 max_commits: usize,
278) -> Result<ComposeAnalysis> {
279 let client = crate::api::get_client(config);
280
281 let tool = Tool {
282 tool_type: "function".to_string(),
283 function: Function {
284 name: "create_compose_analysis".to_string(),
285 description: "Split changes into logical commit groups with dependencies".to_string(),
286 parameters: FunctionParameters {
287 param_type: "object".to_string(),
288 properties: serde_json::json!({
289 "groups": {
290 "type": "array",
291 "description": "Array of change groups in dependency order",
292 "items": {
293 "type": "object",
294 "properties": {
295 "changes": {
296 "type": "array",
297 "description": "File changes with specific hunks",
298 "items": {
299 "type": "object",
300 "properties": {
301 "path": {
302 "type": "string",
303 "description": "File path"
304 },
305 "hunks": {
306 "type": "array",
307 "description": "Either ['ALL'] for entire file, or line range objects: [{start: 10, end: 25}]. Line numbers are 1-indexed from ORIGINAL file.",
308 "items": {
309 "oneOf": [
310 { "type": "string", "const": "ALL" },
311 {
312 "type": "object",
313 "properties": {
314 "start": { "type": "integer", "minimum": 1 },
315 "end": { "type": "integer", "minimum": 1 }
316 },
317 "required": ["start", "end"]
318 }
319 ]
320 }
321 }
322 },
323 "required": ["path", "hunks"]
324 }
325 },
326 "type": {
327 "type": "string",
328 "enum": ["feat", "fix", "refactor", "docs", "test", "chore", "style", "perf", "build", "ci", "revert"],
329 "description": "Commit type for this group"
330 },
331 "scope": {
332 "type": "string",
333 "description": "Optional scope (module/component). Omit if broad."
334 },
335 "rationale": {
336 "type": "string",
337 "description": "Brief explanation of why these changes belong together"
338 },
339 "dependencies": {
340 "type": "array",
341 "description": "Indices of groups this depends on (e.g., [0, 1])",
342 "items": { "type": "integer" }
343 }
344 },
345 "required": ["changes", "type", "rationale", "dependencies"]
346 }
347 }
348 }),
349 required: vec!["groups".to_string()],
350 },
351 },
352 };
353
354 let prompt = COMPOSE_PROMPT
355 .replace("{STAT}", stat)
356 .replace("{DIFF}", diff)
357 .replace("{MAX_COMMITS}", &max_commits.to_string());
358
359 let request = ApiRequest {
360 model: config.model.clone(),
361 max_tokens: 8000,
362 temperature: config.temperature,
363 tools: vec![tool],
364 tool_choice: Some(
365 serde_json::json!({ "type": "function", "function": { "name": "create_compose_analysis" } }),
366 ),
367 messages: vec![Message { role: "user".to_string(), content: prompt }],
368 };
369
370 let request_builder = client
371 .post(format!("{}/chat/completions", config.api_base_url))
372 .header("content-type", "application/json");
373
374 let (status, response_text) =
375 crate::api::timed_send(request_builder.json(&request), "compose", &config.model).await?;
376 if !status.is_success() {
377 return Err(CommitGenError::ApiError { status: status.as_u16(), body: response_text });
378 }
379
380 let api_response: ApiResponse =
381 serde_json::from_str(&response_text).map_err(CommitGenError::JsonError)?;
382
383 if api_response.choices.is_empty() {
384 return Err(CommitGenError::Other(
385 "API returned empty response for compose analysis".to_string(),
386 ));
387 }
388
389 let mut last_parse_error: Option<CommitGenError> = None;
390
391 for choice in &api_response.choices {
392 let message = &choice.message;
393
394 if let Some(tool_call) = message.tool_calls.first()
395 && tool_call.function.name.ends_with("create_compose_analysis")
396 {
397 let args = &tool_call.function.arguments;
398 match parse_compose_groups_from_json(args) {
399 Ok(groups) => {
400 let dependency_order = compute_dependency_order(&groups)?;
401 return Ok(ComposeAnalysis { groups, dependency_order });
402 },
403 Err(err) => {
404 debug_failed_payload("tool_call", args, &err);
405 last_parse_error =
406 Some(CommitGenError::Other(format!("Failed to parse compose analysis: {err}")));
407 },
408 }
409 }
410
411 if let Some(function_call) = &message.function_call
412 && function_call.name == "create_compose_analysis"
413 {
414 let args = &function_call.arguments;
415 match parse_compose_groups_from_json(args) {
416 Ok(groups) => {
417 let dependency_order = compute_dependency_order(&groups)?;
418 return Ok(ComposeAnalysis { groups, dependency_order });
419 },
420 Err(err) => {
421 debug_failed_payload("function_call", args, &err);
422 last_parse_error =
423 Some(CommitGenError::Other(format!("Failed to parse compose analysis: {err}")));
424 },
425 }
426 }
427
428 if let Some(content) = &message.content {
429 match parse_compose_groups_from_content(content) {
430 Ok(groups) => {
431 let dependency_order = compute_dependency_order(&groups)?;
432 return Ok(ComposeAnalysis { groups, dependency_order });
433 },
434 Err(err) => last_parse_error = Some(err),
435 }
436 }
437 }
438
439 if let Some(err) = last_parse_error {
440 debug_compose_response(&api_response);
441 return Err(err);
442 }
443
444 debug_compose_response(&api_response);
445 Err(CommitGenError::Other("No compose analysis found in API response".to_string()))
446}
447
448fn debug_compose_response(response: &ApiResponse) {
449 let raw_preview = serde_json::to_string(response).map_or_else(
450 |_| "<failed to serialize response>".to_string(),
451 |json| {
452 if json.len() > 4000 {
453 format!("{}…", &json[..4000])
454 } else {
455 json
456 }
457 },
458 );
459
460 eprintln!(
461 "Compose debug: received {} choice(s) from analysis model\n raw: {}",
462 response.choices.len(),
463 raw_preview
464 );
465
466 for (idx, choice) in response.choices.iter().enumerate() {
467 let message = &choice.message;
468 let tool_call = message.tool_calls.first();
469 let tool_name = tool_call.map_or("<none>", |tc| tc.function.name.as_str());
470 let tool_args_len = tool_call.map_or(0, |tc| tc.function.arguments.len());
471
472 let function_call_name = message
473 .function_call
474 .as_ref()
475 .map_or("<none>", |fc| fc.name.as_str());
476 let function_call_args_len = message
477 .function_call
478 .as_ref()
479 .map_or(0, |fc| fc.arguments.len());
480
481 let content_preview = message.content.as_deref().map_or_else(
482 || "<none>".to_string(),
483 |c| {
484 let trimmed = c.trim();
485 if trimmed.len() > 200 {
486 format!("{}…", &trimmed[..200])
487 } else {
488 trimmed.to_string()
489 }
490 },
491 );
492
493 eprintln!(
494 "Choice #{idx}: tool_call={tool_name} (args {tool_args_len} chars), \
495 function_call={function_call_name} (args {function_call_args_len} chars), \
496 content_preview={content_preview}"
497 );
498 }
499}
500
501fn compute_dependency_order(groups: &[ChangeGroup]) -> Result<Vec<usize>> {
503 let n = groups.len();
504 let mut in_degree = vec![0; n];
505 let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
506
507 for (i, group) in groups.iter().enumerate() {
509 for &dep in &group.dependencies {
510 if dep >= n {
511 return Err(CommitGenError::Other(format!(
512 "Invalid dependency index {dep} (max: {n})"
513 )));
514 }
515 adjacency[dep].push(i);
516 in_degree[i] += 1;
517 }
518 }
519
520 let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
522 let mut order = Vec::new();
523
524 while let Some(node) = queue.pop() {
525 order.push(node);
526 for &neighbor in &adjacency[node] {
527 in_degree[neighbor] -= 1;
528 if in_degree[neighbor] == 0 {
529 queue.push(neighbor);
530 }
531 }
532 }
533
534 if order.len() != n {
535 return Err(CommitGenError::Other(
536 "Circular dependency detected in commit groups".to_string(),
537 ));
538 }
539
540 Ok(order)
541}
542
543fn validate_compose_groups(groups: &[ChangeGroup], full_diff: &str) -> Result<()> {
545 use std::collections::{HashMap, HashSet};
546
547 let mut diff_files: HashSet<String> = HashSet::new();
549 for line in full_diff.lines() {
550 if line.starts_with("diff --git")
551 && let Some(b_part) = line.split_whitespace().nth(3)
552 && let Some(path) = b_part.strip_prefix("b/")
553 {
554 diff_files.insert(path.to_string());
555 }
556 }
557
558 let mut covered_files: HashSet<String> = HashSet::new();
560 let mut file_coverage: HashMap<String, usize> = HashMap::new();
561
562 for (idx, group) in groups.iter().enumerate() {
563 for change in &group.changes {
564 covered_files.insert(change.path.clone());
565 *file_coverage.entry(change.path.clone()).or_insert(0) += 1;
566
567 for selector in &change.hunks {
569 match selector {
570 crate::types::HunkSelector::All => {},
571 crate::types::HunkSelector::Lines { start, end } => {
572 if start > end {
573 eprintln!(
574 "{}",
575 style::warning(&format!(
576 "{} Warning: Group {idx} has invalid line range {start}-{end} in {}",
577 style::icons::WARNING,
578 change.path
579 ))
580 );
581 }
582 if *start == 0 {
583 eprintln!(
584 "{}",
585 style::warning(&format!(
586 "{} Warning: Group {idx} has line range starting at 0 (should be \
587 1-indexed) in {}",
588 style::icons::WARNING,
589 change.path
590 ))
591 );
592 }
593 },
594 crate::types::HunkSelector::Search { pattern } => {
595 if pattern.is_empty() {
596 eprintln!(
597 "{}",
598 style::warning(&format!(
599 "{} Warning: Group {idx} has empty search pattern in {}",
600 style::icons::WARNING,
601 change.path
602 ))
603 );
604 }
605 },
606 }
607 }
608 }
609
610 for &dep in &group.dependencies {
612 if dep >= groups.len() {
613 return Err(CommitGenError::Other(format!(
614 "Group {idx} has invalid dependency {dep} (only {} groups total)",
615 groups.len()
616 )));
617 }
618 if dep == idx {
619 return Err(CommitGenError::Other(format!("Group {idx} depends on itself (circular)")));
620 }
621 }
622 }
623
624 let missing_files: Vec<&String> = diff_files.difference(&covered_files).collect();
626 if !missing_files.is_empty() {
627 eprintln!(
628 "{}",
629 style::warning(&format!(
630 "{} Warning: Groups don't cover all files. Missing:",
631 style::icons::WARNING
632 ))
633 );
634 for file in &missing_files {
635 eprintln!(" - {file}");
636 }
637 return Err(CommitGenError::Other(format!(
638 "Non-exhaustive groups: {} file(s) not covered",
639 missing_files.len()
640 )));
641 }
642
643 let duplicates: Vec<_> = file_coverage
645 .iter()
646 .filter(|&(_, count)| *count > 1)
647 .collect();
648
649 if !duplicates.is_empty() {
650 eprintln!(
651 "{}",
652 style::warning(&format!(
653 "{} Warning: Some files appear in multiple groups:",
654 style::icons::WARNING
655 ))
656 );
657 for (file, count) in duplicates {
658 eprintln!(" - {file} ({count} times)");
659 }
660 }
661
662 for (idx, group) in groups.iter().enumerate() {
664 if group.changes.is_empty() {
665 return Err(CommitGenError::Other(format!("Group {idx} has no changes")));
666 }
667 }
668
669 Ok(())
670}
671
672pub async fn execute_compose(
674 analysis: &ComposeAnalysis,
675 config: &CommitConfig,
676 args: &Args,
677) -> Result<Vec<String>> {
678 let dir = &args.dir;
679 let token_counter = create_token_counter(config);
680
681 println!("{}", style::info("Resetting staging area..."));
683 reset_staging(dir)?;
684
685 let baseline_diff_output = std::process::Command::new("git")
688 .args(["diff", "HEAD"])
689 .current_dir(dir)
690 .output()
691 .map_err(|e| CommitGenError::git(format!("Failed to get baseline diff: {e}")))?;
692
693 if !baseline_diff_output.status.success() {
694 let stderr = String::from_utf8_lossy(&baseline_diff_output.stderr);
695 return Err(CommitGenError::git(format!("git diff HEAD failed: {stderr}")));
696 }
697
698 let baseline_diff = String::from_utf8_lossy(&baseline_diff_output.stdout).to_string();
699
700 let mut commit_hashes = Vec::new();
701
702 for (idx, &group_idx) in analysis.dependency_order.iter().enumerate() {
703 let mut group = analysis.groups[group_idx].clone();
704 let dependency_only = group_affects_only_dependency_files(&group);
705
706 if dependency_only && group.commit_type.as_str() != "build" {
707 group.commit_type = CommitType::new("build")?;
708 }
709
710 println!(
711 "\n[{}/{}] Creating commit for group: {}",
712 idx + 1,
713 analysis.dependency_order.len(),
714 group.rationale
715 );
716 println!(" Type: {}", style::commit_type(&group.commit_type.to_string()));
717 if let Some(scope) = &group.scope {
718 println!(" Scope: {}", style::scope(&scope.to_string()));
719 }
720 let files: Vec<String> = group.changes.iter().map(|c| c.path.clone()).collect();
721 println!(" Files: {}", files.join(", "));
722
723 stage_group_changes(&group, dir, &baseline_diff)?;
725
726 let diff = get_git_diff(&Mode::Staged, None, dir, config)?;
728 let stat = get_git_stat(&Mode::Staged, None, dir, config)?;
729
730 let diff = if diff.len() > config.max_diff_length {
732 smart_truncate_diff(&diff, config.max_diff_length, config, &token_counter)
733 } else {
734 diff
735 };
736
737 println!(" {}", style::info("Generating commit message..."));
739 let debug_prefix = format!("compose-{}", idx + 1);
740 let ctx = AnalysisContext {
741 user_context: Some(&group.rationale),
742 recent_commits: None, common_scopes: None, project_context: None, debug_output: args.debug_output.as_deref(),
746 debug_prefix: Some(&debug_prefix),
747 };
748 let message_analysis =
749 generate_conventional_analysis(&stat, &diff, &config.model, "", &ctx, config).await?;
750
751 let analysis_body = message_analysis.body_texts();
752
753 let summary = crate::api::generate_summary_from_analysis(
754 &stat,
755 group.commit_type.as_str(),
756 group.scope.as_ref().map(|s| s.as_str()),
757 &analysis_body,
758 Some(&group.rationale),
759 config,
760 args.debug_output.as_deref(),
761 Some(&debug_prefix),
762 )
763 .await?;
764
765 let final_commit_type = if dependency_only {
766 CommitType::new("build")?
767 } else {
768 message_analysis.commit_type
769 };
770
771 let mut commit = ConventionalCommit {
772 commit_type: final_commit_type,
773 scope: message_analysis.scope,
774 summary,
775 body: analysis_body,
776 footers: vec![],
777 };
778
779 post_process_commit_message(&mut commit, config);
780
781 if let Err(e) = validate_commit_message(&commit, config) {
782 eprintln!(
783 " {}",
784 style::warning(&format!("{} Warning: Validation failed: {e}", style::icons::WARNING))
785 );
786 }
787
788 let formatted_message = format_commit_message(&commit);
789
790 println!(
791 " Message:\n{}",
792 formatted_message
793 .lines()
794 .take(3)
795 .collect::<Vec<_>>()
796 .join("\n")
797 );
798
799 if !args.compose_preview {
801 let sign = args.sign || config.gpg_sign;
802 let signoff = args.signoff || config.signoff;
803 git_commit(&formatted_message, false, dir, sign, signoff, args.skip_hooks, false)?;
804 let hash = get_head_hash(dir)?;
805 commit_hashes.push(hash);
806
807 if args.compose_test_after_each {
809 println!(" {}", style::info("Running tests..."));
810 let test_result = std::process::Command::new("cargo")
811 .arg("test")
812 .current_dir(dir)
813 .status();
814
815 if let Ok(status) = test_result {
816 if !status.success() {
817 return Err(CommitGenError::Other(format!(
818 "Tests failed after commit {idx}. Aborting."
819 )));
820 }
821 println!(" {}", style::success(&format!("{} Tests passed", style::icons::SUCCESS)));
822 }
823 }
824 }
825 }
826
827 Ok(commit_hashes)
828}
829
830pub async fn run_compose_mode(args: &Args, config: &CommitConfig) -> Result<()> {
832 let max_rounds = config.compose_max_rounds;
833
834 for round in 1..=max_rounds {
835 if round > 1 {
836 println!(
837 "\n{}",
838 style::section_header(&format!("Compose Round {round}/{max_rounds}"), 80)
839 );
840 } else {
841 println!("{}", style::section_header("Compose Mode", 80));
842 }
843 println!("{}\n", style::info("Analyzing all changes for intelligent splitting..."));
844
845 run_compose_round(args, config, round).await?;
846
847 if args.compose_preview {
849 break;
850 }
851
852 let remaining_diff_output = std::process::Command::new("git")
853 .args(["diff", "HEAD"])
854 .current_dir(&args.dir)
855 .output()
856 .map_err(|e| CommitGenError::git(format!("Failed to check remaining diff: {e}")))?;
857
858 if !remaining_diff_output.status.success() {
859 continue;
860 }
861
862 let remaining_diff = String::from_utf8_lossy(&remaining_diff_output.stdout);
863 if remaining_diff.trim().is_empty() {
864 println!(
865 "\n{}",
866 style::success(&format!(
867 "{} All changes committed successfully",
868 style::icons::SUCCESS
869 ))
870 );
871 break;
872 }
873
874 eprintln!(
875 "\n{}",
876 style::warning(&format!(
877 "{} Uncommitted changes remain after round {round}",
878 style::icons::WARNING
879 ))
880 );
881
882 let stat_output = std::process::Command::new("git")
883 .args(["diff", "HEAD", "--stat"])
884 .current_dir(&args.dir)
885 .output()
886 .ok();
887
888 if let Some(output) = stat_output
889 && output.status.success()
890 {
891 let stat = String::from_utf8_lossy(&output.stdout);
892 eprintln!("{stat}");
893 }
894
895 if round < max_rounds {
896 eprintln!("{}", style::info("Starting another compose round..."));
897 continue;
898 }
899 eprintln!(
900 "{}",
901 style::warning(&format!(
902 "Reached max rounds ({max_rounds}). Remaining changes need manual commit."
903 ))
904 );
905 }
906
907 Ok(())
908}
909
910async fn run_compose_round(args: &Args, config: &CommitConfig, round: usize) -> Result<()> {
912 let token_counter = create_token_counter(config);
913
914 let diff_staged = get_git_diff(&Mode::Staged, None, &args.dir, config).unwrap_or_default();
916 let diff_unstaged = get_git_diff(&Mode::Unstaged, None, &args.dir, config).unwrap_or_default();
917
918 let combined_diff = if diff_staged.is_empty() {
919 diff_unstaged
920 } else if diff_unstaged.is_empty() {
921 diff_staged
922 } else {
923 format!("{diff_staged}\n{diff_unstaged}")
924 };
925
926 if combined_diff.is_empty() {
927 return Err(CommitGenError::NoChanges { mode: "working directory".to_string() });
928 }
929
930 let stat_staged = get_git_stat(&Mode::Staged, None, &args.dir, config).unwrap_or_default();
931 let stat_unstaged = get_git_stat(&Mode::Unstaged, None, &args.dir, config).unwrap_or_default();
932
933 let combined_stat = if stat_staged.is_empty() {
934 stat_unstaged
935 } else if stat_unstaged.is_empty() {
936 stat_staged
937 } else {
938 format!("{stat_staged}\n{stat_unstaged}")
939 };
940
941 let original_diff = combined_diff.clone();
943
944 let diff = if combined_diff.len() > config.max_diff_length {
946 println!(
947 "{}",
948 style::warning(&format!(
949 "{} Applying smart truncation (diff size: {} characters)",
950 style::icons::WARNING,
951 combined_diff.len()
952 ))
953 );
954 smart_truncate_diff(&combined_diff, config.max_diff_length, config, &token_counter)
955 } else {
956 combined_diff
957 };
958
959 let max_commits = args.compose_max_commits.unwrap_or(3);
960
961 println!("{}", style::info(&format!("Analyzing changes (max {max_commits} commits)...")));
962 let analysis = analyze_for_compose(&diff, &combined_stat, config, max_commits).await?;
963
964 println!("{}", style::info("Validating groups..."));
966 validate_compose_groups(&analysis.groups, &original_diff)?;
967
968 println!("\n{}", style::section_header("Proposed Commit Groups", 80));
969 for (idx, &group_idx) in analysis.dependency_order.iter().enumerate() {
970 let mut group = analysis.groups[group_idx].clone();
971 if group_affects_only_dependency_files(&group) && group.commit_type.as_str() != "build" {
972 group.commit_type = CommitType::new("build")?;
973 }
974 println!(
975 "\n{}. [{}{}] {}",
976 idx + 1,
977 style::commit_type(&group.commit_type.to_string()),
978 group
979 .scope
980 .as_ref()
981 .map(|s| format!("({})", style::scope(&s.to_string())))
982 .unwrap_or_default(),
983 group.rationale
984 );
985 println!(" Changes:");
986 for change in &group.changes {
987 let is_all =
988 change.hunks.len() == 1 && matches!(&change.hunks[0], crate::types::HunkSelector::All);
989
990 if is_all {
991 println!(" - {} (all changes)", change.path);
992 } else {
993 let summary: Vec<String> = change
995 .hunks
996 .iter()
997 .map(|s| match s {
998 crate::types::HunkSelector::All => "all".to_string(),
999 crate::types::HunkSelector::Lines { start, end } => {
1000 format!("lines {start}-{end}")
1001 },
1002 crate::types::HunkSelector::Search { pattern } => {
1003 if pattern.len() > 20 {
1004 format!("search '{}'...", &pattern[..20])
1005 } else {
1006 format!("search '{pattern}'")
1007 }
1008 },
1009 })
1010 .collect();
1011 println!(" - {} ({})", change.path, summary.join(", "));
1012 }
1013 }
1014 if !group.dependencies.is_empty() {
1015 println!(" Depends on: {:?}", group.dependencies);
1016 }
1017 }
1018
1019 if args.compose_preview {
1020 println!(
1021 "\n{}",
1022 style::success(&format!(
1023 "{} Preview complete (use --compose without --compose-preview to execute)",
1024 style::icons::SUCCESS
1025 ))
1026 );
1027 return Ok(());
1028 }
1029
1030 println!("\n{}", style::info(&format!("Executing compose (round {round})...")));
1031 let hashes = execute_compose(&analysis, config, args).await?;
1032
1033 println!(
1034 "{}",
1035 style::success(&format!(
1036 "{} Round {round}: Created {} commit(s)",
1037 style::icons::SUCCESS,
1038 hashes.len()
1039 ))
1040 );
1041 Ok(())
1042}