1use std::{path::Path, sync::OnceLock, time::Duration};
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6 api::{AnalysisContext, generate_conventional_analysis},
7 config::CommitConfig,
8 diff::smart_truncate_diff,
9 error::{CommitGenError, Result},
10 git::{get_git_diff, get_git_stat, get_head_hash, git_commit},
11 normalization::{format_commit_message, post_process_commit_message},
12 patch::{reset_staging, stage_group_changes},
13 style,
14 tokens::create_token_counter,
15 types::{Args, ChangeGroup, CommitType, ComposeAnalysis, ConventionalCommit, Mode},
16 validation::validate_commit_message,
17};
18
19static CLIENT: OnceLock<reqwest::blocking::Client> = OnceLock::new();
20
21fn get_client() -> &'static reqwest::blocking::Client {
22 CLIENT.get_or_init(|| {
23 reqwest::blocking::Client::builder()
24 .timeout(Duration::from_secs(120))
25 .connect_timeout(Duration::from_secs(30))
26 .build()
27 .expect("Failed to build HTTP client")
28 })
29}
30
31#[derive(Debug, Serialize)]
32struct Message {
33 role: String,
34 content: String,
35}
36
37#[derive(Debug, Serialize, Deserialize)]
38struct FunctionParameters {
39 #[serde(rename = "type")]
40 param_type: String,
41 properties: serde_json::Value,
42 required: Vec<String>,
43}
44
45#[derive(Debug, Serialize, Deserialize)]
46struct Function {
47 name: String,
48 description: String,
49 parameters: FunctionParameters,
50}
51
52#[derive(Debug, Serialize, Deserialize)]
53struct Tool {
54 #[serde(rename = "type")]
55 tool_type: String,
56 function: Function,
57}
58
59#[derive(Debug, Serialize)]
60struct ApiRequest {
61 model: String,
62 max_tokens: u32,
63 temperature: f32,
64 tools: Vec<Tool>,
65 #[serde(skip_serializing_if = "Option::is_none")]
66 tool_choice: Option<serde_json::Value>,
67 messages: Vec<Message>,
68}
69
70#[derive(Debug, Deserialize, Serialize)]
71struct ToolCall {
72 function: FunctionCall,
73}
74
75#[derive(Debug, Deserialize, Serialize)]
76struct FunctionCall {
77 name: String,
78 arguments: String,
79}
80
81#[derive(Debug, Deserialize, Serialize)]
82struct Choice {
83 message: ResponseMessage,
84}
85
86#[derive(Debug, Deserialize, Serialize)]
87struct ResponseMessage {
88 #[serde(default)]
89 tool_calls: Vec<ToolCall>,
90 #[serde(default)]
91 content: Option<String>,
92 #[serde(default)]
93 function_call: Option<FunctionCall>,
94}
95
96#[derive(Debug, Deserialize, Serialize)]
97struct ApiResponse {
98 choices: Vec<Choice>,
99}
100
/// Prompt template asking the model to split a diff into commit groups.
///
/// The placeholders `{STAT}`, `{DIFF}` and `{MAX_COMMITS}` are filled in by
/// `analyze_for_compose`. The example is valid JSON in the `{"groups": [...]}` shape, so a
/// model that answers in plain text instead of a tool call produces output the content
/// fallback parser can read.
const COMPOSE_PROMPT: &str = r#"Split this git diff into 1-{MAX_COMMITS} logical, atomic commit groups.

## Git Stat
{STAT}

## Git Diff
{DIFF}

## Rules (CRITICAL)
1. **EXHAUSTIVENESS**: You MUST account for 100% of changes. Every file and hunk in the diff above must appear in exactly one group.
2. **Atomicity**: Each group represents ONE logical change (feat/fix/refactor/etc.) that leaves the codebase working.
3. **Prefer fewer groups**: Default to 1-3 commits. Only split when changes are truly independent/separable.
4. **Group related**: Implementation + tests go together. Refactoring + usage updates go together.
5. **Dependencies**: Use 0-based group indices. If the second group depends on the first, it has "dependencies": [0].
6. **Hunk selection** (IMPORTANT - Use line numbers, NOT hunk headers):
   - If entire file → "hunks": ["ALL"]
   - If partial → specify line ranges: "hunks": [{"start": 10, "end": 25}, {"start": 50, "end": 60}]
   - Line numbers are 1-indexed from the ORIGINAL file (look at "-" lines in diff)
   - You can specify multiple ranges for discontinuous changes in one file

## Good Example (2 independent changes)
{
  "groups": [
    {
      "changes": [
        {"path": "src/api.rs", "hunks": ["ALL"]},
        {"path": "tests/api_test.rs", "hunks": [{"start": 15, "end": 23}]}
      ],
      "type": "feat", "scope": "api", "rationale": "add user endpoint with test",
      "dependencies": []
    },
    {
      "changes": [
        {"path": "src/utils.rs", "hunks": [{"start": 42, "end": 48}, {"start": 100, "end": 105}]}
      ],
      "type": "fix", "scope": "utils", "rationale": "fix string parsing bug in two locations",
      "dependencies": []
    }
  ]
}

## Bad Example (over-splitting)
❌ DON'T create 6 commits for: function rename + call sites. That's ONE refactor group.
❌ DON'T split tests from implementation unless they test something from a prior group.

## Bad Example (incomplete)
❌ DON'T forget files. If diff shows 5 files, groups must cover all 5.

Return groups in dependency order."#;
148
149#[derive(Deserialize)]
150struct ComposeResult {
151 groups: Vec<ChangeGroup>,
152}
153
154fn parse_compose_groups_from_content(content: &str) -> Result<Vec<ChangeGroup>> {
155 fn try_parse(input: &str) -> Option<Vec<ChangeGroup>> {
156 let trimmed = input.trim();
157 if trimmed.is_empty() {
158 return None;
159 }
160
161 serde_json::from_str::<ComposeResult>(trimmed)
162 .map(|r| r.groups)
163 .ok()
164 }
165
166 let trimmed = content.trim();
167 if trimmed.is_empty() {
168 return Err(CommitGenError::Other(
169 "Model returned an empty compose analysis response".to_string(),
170 ));
171 }
172
173 if let Some(groups) = try_parse(trimmed) {
174 return Ok(groups);
175 }
176
177 if let (Some(start), Some(end)) = (trimmed.find('{'), trimmed.rfind('}'))
178 && end >= start
179 {
180 let candidate = &trimmed[start..=end];
181 if let Some(groups) = try_parse(candidate) {
182 return Ok(groups);
183 }
184 }
185
186 let segments: Vec<&str> = trimmed.split("```").collect();
187 for (idx, segment) in segments.iter().enumerate() {
188 if idx % 2 == 1 {
189 let block = segment.trim();
190 let mut lines = block.lines();
191 let first_line = lines.next().unwrap_or_default();
192
193 let mut owned_candidate: Option<String> = None;
194 let json_candidate = if first_line.trim_start().starts_with('{') {
195 block
196 } else {
197 let rest: String = lines.collect::<Vec<_>>().join("\n");
198 let trimmed_rest = rest.trim();
199 if trimmed_rest.is_empty() {
200 block
201 } else {
202 owned_candidate = Some(trimmed_rest.to_string());
203 owned_candidate.as_deref().unwrap()
204 }
205 };
206
207 if let Some(groups) = try_parse(json_candidate) {
208 return Ok(groups);
209 }
210 }
211 }
212
213 Err(CommitGenError::Other("Failed to parse compose analysis from model response".to_string()))
214}
215
216fn parse_compose_groups_from_json(
217 raw: &str,
218) -> std::result::Result<Vec<ChangeGroup>, serde_json::Error> {
219 let trimmed = raw.trim();
220 if trimmed.starts_with('[') {
221 serde_json::from_str::<Vec<ChangeGroup>>(trimmed)
222 } else {
223 serde_json::from_str::<ComposeResult>(trimmed).map(|r| r.groups)
224 }
225}
226
227fn debug_failed_payload(source: &str, payload: &str, err: &serde_json::Error) {
228 let preview = payload.trim();
229 let preview = if preview.len() > 2000 {
230 format!("{}…", &preview[..2000])
231 } else {
232 preview.to_string()
233 };
234 eprintln!("Compose debug: failed to parse {source} payload ({err}); preview: {preview}");
235}
236
237fn group_affects_only_dependency_files(group: &ChangeGroup) -> bool {
238 group
239 .changes
240 .iter()
241 .all(|change| is_dependency_manifest(&change.path))
242}
243
/// Reports whether `path` names a dependency manifest or lockfile.
///
/// Only the final path component is examined. It matches if it equals one of the
/// well-known manifest names below (case-sensitive), or if its extension is `lock` or
/// `lockb` (ASCII case-insensitive).
fn is_dependency_manifest(path: &str) -> bool {
    const DEP_MANIFESTS: &[&str] = &[
        "Cargo.toml",
        "Cargo.lock",
        "package.json",
        "package-lock.json",
        "pnpm-lock.yaml",
        "yarn.lock",
        "bun.lock",
        "bun.lockb",
        "go.mod",
        "go.sum",
        "requirements.txt",
        "Pipfile",
        "Pipfile.lock",
        "pyproject.toml",
        "Gemfile",
        "Gemfile.lock",
        "composer.json",
        "composer.lock",
        "build.gradle",
        "build.gradle.kts",
        "gradle.properties",
        "pom.xml",
    ];

    let file_name = match Path::new(path).file_name().and_then(|name| name.to_str()) {
        Some(name) => name,
        None => return false,
    };

    let has_lock_extension = Path::new(file_name).extension().is_some_and(|ext| {
        ["lock", "lockb"]
            .into_iter()
            .any(|lock_ext| ext.eq_ignore_ascii_case(lock_ext))
    });

    DEP_MANIFESTS.contains(&file_name) || has_lock_extension
}
283
284pub fn analyze_for_compose(
286 diff: &str,
287 stat: &str,
288 config: &CommitConfig,
289 max_commits: usize,
290) -> Result<ComposeAnalysis> {
291 let client = get_client();
292
293 let tool = Tool {
294 tool_type: "function".to_string(),
295 function: Function {
296 name: "create_compose_analysis".to_string(),
297 description: "Split changes into logical commit groups with dependencies".to_string(),
298 parameters: FunctionParameters {
299 param_type: "object".to_string(),
300 properties: serde_json::json!({
301 "groups": {
302 "type": "array",
303 "description": "Array of change groups in dependency order",
304 "items": {
305 "type": "object",
306 "properties": {
307 "changes": {
308 "type": "array",
309 "description": "File changes with specific hunks",
310 "items": {
311 "type": "object",
312 "properties": {
313 "path": {
314 "type": "string",
315 "description": "File path"
316 },
317 "hunks": {
318 "type": "array",
319 "description": "Either ['ALL'] for entire file, or line range objects: [{start: 10, end: 25}]. Line numbers are 1-indexed from ORIGINAL file.",
320 "items": {
321 "oneOf": [
322 { "type": "string", "const": "ALL" },
323 {
324 "type": "object",
325 "properties": {
326 "start": { "type": "integer", "minimum": 1 },
327 "end": { "type": "integer", "minimum": 1 }
328 },
329 "required": ["start", "end"]
330 }
331 ]
332 }
333 }
334 },
335 "required": ["path", "hunks"]
336 }
337 },
338 "type": {
339 "type": "string",
340 "enum": ["feat", "fix", "refactor", "docs", "test", "chore", "style", "perf", "build", "ci", "revert"],
341 "description": "Commit type for this group"
342 },
343 "scope": {
344 "type": "string",
345 "description": "Optional scope (module/component). Omit if broad."
346 },
347 "rationale": {
348 "type": "string",
349 "description": "Brief explanation of why these changes belong together"
350 },
351 "dependencies": {
352 "type": "array",
353 "description": "Indices of groups this depends on (e.g., [0, 1])",
354 "items": { "type": "integer" }
355 }
356 },
357 "required": ["changes", "type", "rationale", "dependencies"]
358 }
359 }
360 }),
361 required: vec!["groups".to_string()],
362 },
363 },
364 };
365
366 let prompt = COMPOSE_PROMPT
367 .replace("{STAT}", stat)
368 .replace("{DIFF}", diff)
369 .replace("{MAX_COMMITS}", &max_commits.to_string());
370
371 let request = ApiRequest {
372 model: config.analysis_model.clone(),
373 max_tokens: 8000,
374 temperature: config.temperature,
375 tools: vec![tool],
376 tool_choice: Some(
377 serde_json::json!({ "type": "function", "function": { "name": "create_compose_analysis" } }),
378 ),
379 messages: vec![Message { role: "user".to_string(), content: prompt }],
380 };
381
382 let response = client
383 .post(format!("{}/chat/completions", config.api_base_url))
384 .header("content-type", "application/json")
385 .json(&request)
386 .send()
387 .map_err(CommitGenError::HttpError)?;
388
389 let status = response.status();
390 if !status.is_success() {
391 let error_text = response
392 .text()
393 .unwrap_or_else(|_| "Unknown error".to_string());
394 return Err(CommitGenError::ApiError { status: status.as_u16(), body: error_text });
395 }
396
397 let api_response: ApiResponse = response.json().map_err(CommitGenError::HttpError)?;
398
399 if api_response.choices.is_empty() {
400 return Err(CommitGenError::Other(
401 "API returned empty response for compose analysis".to_string(),
402 ));
403 }
404
405 let mut last_parse_error: Option<CommitGenError> = None;
406
407 for choice in &api_response.choices {
408 let message = &choice.message;
409
410 if let Some(tool_call) = message.tool_calls.first()
411 && tool_call.function.name == "create_compose_analysis"
412 {
413 let args = &tool_call.function.arguments;
414 match parse_compose_groups_from_json(args) {
415 Ok(groups) => {
416 let dependency_order = compute_dependency_order(&groups)?;
417 return Ok(ComposeAnalysis { groups, dependency_order });
418 },
419 Err(err) => {
420 debug_failed_payload("tool_call", args, &err);
421 last_parse_error =
422 Some(CommitGenError::Other(format!("Failed to parse compose analysis: {err}")));
423 },
424 }
425 }
426
427 if let Some(function_call) = &message.function_call
428 && function_call.name == "create_compose_analysis"
429 {
430 let args = &function_call.arguments;
431 match parse_compose_groups_from_json(args) {
432 Ok(groups) => {
433 let dependency_order = compute_dependency_order(&groups)?;
434 return Ok(ComposeAnalysis { groups, dependency_order });
435 },
436 Err(err) => {
437 debug_failed_payload("function_call", args, &err);
438 last_parse_error =
439 Some(CommitGenError::Other(format!("Failed to parse compose analysis: {err}")));
440 },
441 }
442 }
443
444 if let Some(content) = &message.content {
445 match parse_compose_groups_from_content(content) {
446 Ok(groups) => {
447 let dependency_order = compute_dependency_order(&groups)?;
448 return Ok(ComposeAnalysis { groups, dependency_order });
449 },
450 Err(err) => last_parse_error = Some(err),
451 }
452 }
453 }
454
455 if let Some(err) = last_parse_error {
456 debug_compose_response(&api_response);
457 return Err(err);
458 }
459
460 debug_compose_response(&api_response);
461 Err(CommitGenError::Other("No compose analysis found in API response".to_string()))
462}
463
464fn debug_compose_response(response: &ApiResponse) {
465 let raw_preview = serde_json::to_string(response).map_or_else(
466 |_| "<failed to serialize response>".to_string(),
467 |json| {
468 if json.len() > 4000 {
469 format!("{}…", &json[..4000])
470 } else {
471 json
472 }
473 },
474 );
475
476 eprintln!(
477 "Compose debug: received {} choice(s) from analysis model\n raw: {}",
478 response.choices.len(),
479 raw_preview
480 );
481
482 for (idx, choice) in response.choices.iter().enumerate() {
483 let message = &choice.message;
484 let tool_call = message.tool_calls.first();
485 let tool_name = tool_call.map_or("<none>", |tc| tc.function.name.as_str());
486 let tool_args_len = tool_call.map_or(0, |tc| tc.function.arguments.len());
487
488 let function_call_name = message
489 .function_call
490 .as_ref()
491 .map_or("<none>", |fc| fc.name.as_str());
492 let function_call_args_len = message
493 .function_call
494 .as_ref()
495 .map_or(0, |fc| fc.arguments.len());
496
497 let content_preview = message.content.as_deref().map_or_else(
498 || "<none>".to_string(),
499 |c| {
500 let trimmed = c.trim();
501 if trimmed.len() > 200 {
502 format!("{}…", &trimmed[..200])
503 } else {
504 trimmed.to_string()
505 }
506 },
507 );
508
509 eprintln!(
510 "Choice #{idx}: tool_call={tool_name} (args {tool_args_len} chars), \
511 function_call={function_call_name} (args {function_call_args_len} chars), \
512 content_preview={content_preview}"
513 );
514 }
515}
516
517fn compute_dependency_order(groups: &[ChangeGroup]) -> Result<Vec<usize>> {
519 let n = groups.len();
520 let mut in_degree = vec![0; n];
521 let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
522
523 for (i, group) in groups.iter().enumerate() {
525 for &dep in &group.dependencies {
526 if dep >= n {
527 return Err(CommitGenError::Other(format!(
528 "Invalid dependency index {dep} (max: {n})"
529 )));
530 }
531 adjacency[dep].push(i);
532 in_degree[i] += 1;
533 }
534 }
535
536 let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
538 let mut order = Vec::new();
539
540 while let Some(node) = queue.pop() {
541 order.push(node);
542 for &neighbor in &adjacency[node] {
543 in_degree[neighbor] -= 1;
544 if in_degree[neighbor] == 0 {
545 queue.push(neighbor);
546 }
547 }
548 }
549
550 if order.len() != n {
551 return Err(CommitGenError::Other(
552 "Circular dependency detected in commit groups".to_string(),
553 ));
554 }
555
556 Ok(order)
557}
558
559fn validate_compose_groups(groups: &[ChangeGroup], full_diff: &str) -> Result<()> {
561 use std::collections::{HashMap, HashSet};
562
563 let mut diff_files: HashSet<String> = HashSet::new();
565 for line in full_diff.lines() {
566 if line.starts_with("diff --git")
567 && let Some(b_part) = line.split_whitespace().nth(3)
568 && let Some(path) = b_part.strip_prefix("b/")
569 {
570 diff_files.insert(path.to_string());
571 }
572 }
573
574 let mut covered_files: HashSet<String> = HashSet::new();
576 let mut file_coverage: HashMap<String, usize> = HashMap::new();
577
578 for (idx, group) in groups.iter().enumerate() {
579 for change in &group.changes {
580 covered_files.insert(change.path.clone());
581 *file_coverage.entry(change.path.clone()).or_insert(0) += 1;
582
583 for selector in &change.hunks {
585 match selector {
586 crate::types::HunkSelector::All => {},
587 crate::types::HunkSelector::Lines { start, end } => {
588 if start > end {
589 eprintln!(
590 "{}",
591 style::warning(&format!(
592 "{} Warning: Group {idx} has invalid line range {start}-{end} in {}",
593 style::icons::WARNING,
594 change.path
595 ))
596 );
597 }
598 if *start == 0 {
599 eprintln!(
600 "{}",
601 style::warning(&format!(
602 "{} Warning: Group {idx} has line range starting at 0 (should be \
603 1-indexed) in {}",
604 style::icons::WARNING,
605 change.path
606 ))
607 );
608 }
609 },
610 crate::types::HunkSelector::Search { pattern } => {
611 if pattern.is_empty() {
612 eprintln!(
613 "{}",
614 style::warning(&format!(
615 "{} Warning: Group {idx} has empty search pattern in {}",
616 style::icons::WARNING,
617 change.path
618 ))
619 );
620 }
621 },
622 }
623 }
624 }
625
626 for &dep in &group.dependencies {
628 if dep >= groups.len() {
629 return Err(CommitGenError::Other(format!(
630 "Group {idx} has invalid dependency {dep} (only {} groups total)",
631 groups.len()
632 )));
633 }
634 if dep == idx {
635 return Err(CommitGenError::Other(format!("Group {idx} depends on itself (circular)")));
636 }
637 }
638 }
639
640 let missing_files: Vec<&String> = diff_files.difference(&covered_files).collect();
642 if !missing_files.is_empty() {
643 eprintln!(
644 "{}",
645 style::warning(&format!(
646 "{} Warning: Groups don't cover all files. Missing:",
647 style::icons::WARNING
648 ))
649 );
650 for file in &missing_files {
651 eprintln!(" - {file}");
652 }
653 return Err(CommitGenError::Other(format!(
654 "Non-exhaustive groups: {} file(s) not covered",
655 missing_files.len()
656 )));
657 }
658
659 let duplicates: Vec<_> = file_coverage
661 .iter()
662 .filter(|&(_, count)| *count > 1)
663 .collect();
664
665 if !duplicates.is_empty() {
666 eprintln!(
667 "{}",
668 style::warning(&format!(
669 "{} Warning: Some files appear in multiple groups:",
670 style::icons::WARNING
671 ))
672 );
673 for (file, count) in duplicates {
674 eprintln!(" - {file} ({count} times)");
675 }
676 }
677
678 for (idx, group) in groups.iter().enumerate() {
680 if group.changes.is_empty() {
681 return Err(CommitGenError::Other(format!("Group {idx} has no changes")));
682 }
683 }
684
685 Ok(())
686}
687
688pub fn execute_compose(
690 analysis: &ComposeAnalysis,
691 config: &CommitConfig,
692 args: &Args,
693) -> Result<Vec<String>> {
694 let dir = &args.dir;
695 let token_counter = create_token_counter(config);
696
697 println!("{}", style::info("Resetting staging area..."));
699 reset_staging(dir)?;
700
701 let baseline_diff_output = std::process::Command::new("git")
704 .args(["diff", "HEAD"])
705 .current_dir(dir)
706 .output()
707 .map_err(|e| CommitGenError::GitError(format!("Failed to get baseline diff: {e}")))?;
708
709 if !baseline_diff_output.status.success() {
710 let stderr = String::from_utf8_lossy(&baseline_diff_output.stderr);
711 return Err(CommitGenError::GitError(format!("git diff HEAD failed: {stderr}")));
712 }
713
714 let baseline_diff = String::from_utf8_lossy(&baseline_diff_output.stdout).to_string();
715
716 let mut commit_hashes = Vec::new();
717
718 for (idx, &group_idx) in analysis.dependency_order.iter().enumerate() {
719 let mut group = analysis.groups[group_idx].clone();
720 let dependency_only = group_affects_only_dependency_files(&group);
721
722 if dependency_only && group.commit_type.as_str() != "build" {
723 group.commit_type = CommitType::new("build")?;
724 }
725
726 println!(
727 "\n[{}/{}] Creating commit for group: {}",
728 idx + 1,
729 analysis.dependency_order.len(),
730 group.rationale
731 );
732 println!(" Type: {}", style::commit_type(&group.commit_type.to_string()));
733 if let Some(scope) = &group.scope {
734 println!(" Scope: {}", style::scope(&scope.to_string()));
735 }
736 let files: Vec<String> = group.changes.iter().map(|c| c.path.clone()).collect();
737 println!(" Files: {}", files.join(", "));
738
739 stage_group_changes(&group, dir, &baseline_diff)?;
741
742 let diff = get_git_diff(&Mode::Staged, None, dir, config)?;
744 let stat = get_git_stat(&Mode::Staged, None, dir, config)?;
745
746 let diff = if diff.len() > config.max_diff_length {
748 smart_truncate_diff(&diff, config.max_diff_length, config, &token_counter)
749 } else {
750 diff
751 };
752
753 println!(" {}", style::info("Generating commit message..."));
755 let ctx = AnalysisContext {
756 user_context: Some(&group.rationale),
757 recent_commits: None, common_scopes: None, project_context: None, };
761 let message_analysis =
762 generate_conventional_analysis(&stat, &diff, &config.analysis_model, "", &ctx, config)?;
763
764 let analysis_body = message_analysis.body_texts();
765
766 let summary = crate::api::generate_summary_from_analysis(
767 &stat,
768 group.commit_type.as_str(),
769 group.scope.as_ref().map(|s| s.as_str()),
770 &analysis_body,
771 Some(&group.rationale),
772 config,
773 )?;
774
775 let final_commit_type = if dependency_only {
776 CommitType::new("build")?
777 } else {
778 message_analysis.commit_type
779 };
780
781 let mut commit = ConventionalCommit {
782 commit_type: final_commit_type,
783 scope: message_analysis.scope,
784 summary,
785 body: analysis_body,
786 footers: vec![],
787 };
788
789 post_process_commit_message(&mut commit, config);
790
791 if let Err(e) = validate_commit_message(&commit, config) {
792 eprintln!(
793 " {}",
794 style::warning(&format!("{} Warning: Validation failed: {e}", style::icons::WARNING))
795 );
796 }
797
798 let formatted_message = format_commit_message(&commit);
799
800 println!(
801 " Message:\n{}",
802 formatted_message
803 .lines()
804 .take(3)
805 .collect::<Vec<_>>()
806 .join("\n")
807 );
808
809 if !args.compose_preview {
811 let sign = args.sign || config.gpg_sign;
812 git_commit(&formatted_message, false, dir, sign, args.skip_hooks)?;
813 let hash = get_head_hash(dir)?;
814 commit_hashes.push(hash);
815
816 if args.compose_test_after_each {
818 println!(" {}", style::info("Running tests..."));
819 let test_result = std::process::Command::new("cargo")
820 .arg("test")
821 .current_dir(dir)
822 .status();
823
824 if let Ok(status) = test_result {
825 if !status.success() {
826 return Err(CommitGenError::Other(format!(
827 "Tests failed after commit {idx}. Aborting."
828 )));
829 }
830 println!(" {}", style::success(&format!("{} Tests passed", style::icons::SUCCESS)));
831 }
832 }
833 }
834 }
835
836 Ok(commit_hashes)
837}
838
839pub fn run_compose_mode(args: &Args, config: &CommitConfig) -> Result<()> {
841 let max_rounds = config.compose_max_rounds;
842
843 for round in 1..=max_rounds {
844 if round > 1 {
845 println!(
846 "\n{}",
847 style::section_header(&format!("Compose Round {round}/{max_rounds}"), 80)
848 );
849 } else {
850 println!("{}", style::section_header("Compose Mode", 80));
851 }
852 println!("{}\n", style::info("Analyzing all changes for intelligent splitting..."));
853
854 run_compose_round(args, config, round)?;
855
856 if args.compose_preview {
858 break;
859 }
860
861 let remaining_diff_output = std::process::Command::new("git")
862 .args(["diff", "HEAD"])
863 .current_dir(&args.dir)
864 .output()
865 .map_err(|e| CommitGenError::GitError(format!("Failed to check remaining diff: {e}")))?;
866
867 if !remaining_diff_output.status.success() {
868 continue;
869 }
870
871 let remaining_diff = String::from_utf8_lossy(&remaining_diff_output.stdout);
872 if remaining_diff.trim().is_empty() {
873 println!(
874 "\n{}",
875 style::success(&format!(
876 "{} All changes committed successfully",
877 style::icons::SUCCESS
878 ))
879 );
880 break;
881 }
882
883 eprintln!(
884 "\n{}",
885 style::warning(&format!(
886 "{} Uncommitted changes remain after round {round}",
887 style::icons::WARNING
888 ))
889 );
890
891 let stat_output = std::process::Command::new("git")
892 .args(["diff", "HEAD", "--stat"])
893 .current_dir(&args.dir)
894 .output()
895 .ok();
896
897 if let Some(output) = stat_output
898 && output.status.success()
899 {
900 let stat = String::from_utf8_lossy(&output.stdout);
901 eprintln!("{stat}");
902 }
903
904 if round < max_rounds {
905 eprintln!("{}", style::info("Starting another compose round..."));
906 continue;
907 }
908 eprintln!(
909 "{}",
910 style::warning(&format!(
911 "Reached max rounds ({max_rounds}). Remaining changes need manual commit."
912 ))
913 );
914 }
915
916 Ok(())
917}
918
919fn run_compose_round(args: &Args, config: &CommitConfig, round: usize) -> Result<()> {
921 let token_counter = create_token_counter(config);
922
923 let diff_staged = get_git_diff(&Mode::Staged, None, &args.dir, config).unwrap_or_default();
925 let diff_unstaged = get_git_diff(&Mode::Unstaged, None, &args.dir, config).unwrap_or_default();
926
927 let combined_diff = if diff_staged.is_empty() {
928 diff_unstaged
929 } else if diff_unstaged.is_empty() {
930 diff_staged
931 } else {
932 format!("{diff_staged}\n{diff_unstaged}")
933 };
934
935 if combined_diff.is_empty() {
936 return Err(CommitGenError::NoChanges { mode: "working directory".to_string() });
937 }
938
939 let stat_staged = get_git_stat(&Mode::Staged, None, &args.dir, config).unwrap_or_default();
940 let stat_unstaged = get_git_stat(&Mode::Unstaged, None, &args.dir, config).unwrap_or_default();
941
942 let combined_stat = if stat_staged.is_empty() {
943 stat_unstaged
944 } else if stat_unstaged.is_empty() {
945 stat_staged
946 } else {
947 format!("{stat_staged}\n{stat_unstaged}")
948 };
949
950 let original_diff = combined_diff.clone();
952
953 let diff = if combined_diff.len() > config.max_diff_length {
955 println!(
956 "{}",
957 style::warning(&format!(
958 "{} Applying smart truncation (diff size: {} characters)",
959 style::icons::WARNING,
960 combined_diff.len()
961 ))
962 );
963 smart_truncate_diff(&combined_diff, config.max_diff_length, config, &token_counter)
964 } else {
965 combined_diff
966 };
967
968 let max_commits = args.compose_max_commits.unwrap_or(3);
969
970 println!("{}", style::info(&format!("Analyzing changes (max {max_commits} commits)...")));
971 let analysis = analyze_for_compose(&diff, &combined_stat, config, max_commits)?;
972
973 println!("{}", style::info("Validating groups..."));
975 validate_compose_groups(&analysis.groups, &original_diff)?;
976
977 println!("\n{}", style::section_header("Proposed Commit Groups", 80));
978 for (idx, &group_idx) in analysis.dependency_order.iter().enumerate() {
979 let mut group = analysis.groups[group_idx].clone();
980 if group_affects_only_dependency_files(&group) && group.commit_type.as_str() != "build" {
981 group.commit_type = CommitType::new("build")?;
982 }
983 println!(
984 "\n{}. [{}{}] {}",
985 idx + 1,
986 style::commit_type(&group.commit_type.to_string()),
987 group
988 .scope
989 .as_ref()
990 .map(|s| format!("({})", style::scope(&s.to_string())))
991 .unwrap_or_default(),
992 group.rationale
993 );
994 println!(" Changes:");
995 for change in &group.changes {
996 let is_all =
997 change.hunks.len() == 1 && matches!(&change.hunks[0], crate::types::HunkSelector::All);
998
999 if is_all {
1000 println!(" - {} (all changes)", change.path);
1001 } else {
1002 let summary: Vec<String> = change
1004 .hunks
1005 .iter()
1006 .map(|s| match s {
1007 crate::types::HunkSelector::All => "all".to_string(),
1008 crate::types::HunkSelector::Lines { start, end } => {
1009 format!("lines {start}-{end}")
1010 },
1011 crate::types::HunkSelector::Search { pattern } => {
1012 if pattern.len() > 20 {
1013 format!("search '{}'...", &pattern[..20])
1014 } else {
1015 format!("search '{pattern}'")
1016 }
1017 },
1018 })
1019 .collect();
1020 println!(" - {} ({})", change.path, summary.join(", "));
1021 }
1022 }
1023 if !group.dependencies.is_empty() {
1024 println!(" Depends on: {:?}", group.dependencies);
1025 }
1026 }
1027
1028 if args.compose_preview {
1029 println!(
1030 "\n{}",
1031 style::success(&format!(
1032 "{} Preview complete (use --compose without --compose-preview to execute)",
1033 style::icons::SUCCESS
1034 ))
1035 );
1036 return Ok(());
1037 }
1038
1039 println!("\n{}", style::info(&format!("Executing compose (round {round})...")));
1040 let hashes = execute_compose(&analysis, config, args)?;
1041
1042 println!(
1043 "{}",
1044 style::success(&format!(
1045 "{} Round {round}: Created {} commit(s)",
1046 style::icons::SUCCESS,
1047 hashes.len()
1048 ))
1049 );
1050 Ok(())
1051}