1use std::path::Path;
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6 api::{AnalysisContext, generate_conventional_analysis, openai_prompt_cache_key},
7 config::CommitConfig,
8 diff::smart_truncate_diff,
9 error::{CommitGenError, Result},
10 git::{get_git_diff, get_git_stat, get_head_hash, git_commit},
11 normalization::{format_commit_message, post_process_commit_message},
12 patch::{reset_staging, stage_group_changes},
13 style,
14 tokens::create_token_counter,
15 types::{Args, ChangeGroup, CommitType, ComposeAnalysis, ConventionalCommit, Mode},
16 validation::validate_commit_message,
17};
18
/// One chat message (`role` + `content`) in an outgoing chat-completions request.
#[derive(Debug, Serialize)]
struct Message {
    role: String,
    content: String,
}

/// JSON-Schema-style description of a tool's parameters.
#[derive(Debug, Serialize, Deserialize)]
struct FunctionParameters {
    // Serialized as `"type"` because `type` is a Rust keyword; this file sets it to "object".
    #[serde(rename = "type")]
    param_type: String,
    /// Schema of each property, built with `serde_json::json!`.
    properties: serde_json::Value,
    /// Property names listed as required in the schema.
    required: Vec<String>,
}

/// A callable function offered to the model.
#[derive(Debug, Serialize, Deserialize)]
struct Function {
    name: String,
    description: String,
    parameters: FunctionParameters,
}

/// A tool entry in the request's `tools` array (`{"type": ..., "function": {...}}`).
#[derive(Debug, Serialize, Deserialize)]
struct Tool {
    // Serialized as `"type"`; this file sets it to "function".
    #[serde(rename = "type")]
    tool_type: String,
    function: Function,
}

/// Body of the chat-completions request sent by `analyze_for_compose`.
#[derive(Debug, Serialize)]
struct ApiRequest {
    model: String,
    /// Completion token limit sent to the API.
    max_tokens: u32,
    temperature: f32,
    /// Tool definitions offered to the model.
    tools: Vec<Tool>,
    // Omitted from the JSON when `None`; `analyze_for_compose` uses it to force one tool.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<serde_json::Value>,
    // Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_cache_key: Option<String>,
    messages: Vec<Message>,
}
59
/// One entry of an assistant message's `tool_calls` array.
///
/// `Serialize` is derived (here and on the other response types) so the whole response can
/// be dumped by `debug_compose_response`.
#[derive(Debug, Deserialize, Serialize)]
struct ToolCall {
    function: FunctionCall,
}

/// A function invocation requested by the model; used both inside `tool_calls` entries and
/// for the legacy top-level `function_call` field.
#[derive(Debug, Deserialize, Serialize)]
struct FunctionCall {
    /// Name of the invoked function; compared against `create_compose_analysis`.
    name: String,
    /// Arguments as a JSON-encoded string, parsed by `parse_compose_groups_from_json`.
    arguments: String,
}

/// One alternative from the response's `choices` array.
#[derive(Debug, Deserialize, Serialize)]
struct Choice {
    message: ResponseMessage,
}
75
76#[derive(Debug, Deserialize, Serialize)]
77struct ResponseMessage {
78 #[serde(default)]
79 tool_calls: Vec<ToolCall>,
80 #[serde(default)]
81 content: Option<String>,
82 #[serde(default)]
83 function_call: Option<FunctionCall>,
84}
85
/// Top-level chat-completions response.
///
/// Only `choices` is declared; serde ignores any other fields in the payload.
#[derive(Debug, Deserialize, Serialize)]
struct ApiResponse {
    choices: Vec<Choice>,
}
90
/// System prompt for the compose analysis request.
///
/// It tells the model to split a diff into exhaustive, atomic commit groups. Files are
/// selected either whole (`"ALL"`) or by 1-indexed `{start, end}` line ranges, and ordering
/// is expressed through 0-based `dependencies` indices. It is also passed to
/// `openai_prompt_cache_key` when the request is built.
const COMPOSE_SYSTEM_PROMPT: &str = r#"You split git diffs into logical, atomic commit groups for compose mode.

## Rules (CRITICAL)
1. **EXHAUSTIVENESS**: You MUST account for 100% of changes. Every file and hunk in the provided diff must appear in exactly one group.
2. **Atomicity**: Each group represents ONE logical change (feat/fix/refactor/etc.) that leaves codebase working.
3. **Prefer fewer groups**: Default to 1-3 commits. Only split when changes are truly independent/separable.
4. **Group related**: Implementation + tests go together. Refactoring + usage updates go together.
5. **Dependencies**: Use indices. Group 2 depending on Group 1 means: dependencies: [0].
6. **Hunk selection** (IMPORTANT - Use line numbers, NOT hunk headers):
 - If entire file → hunks: ["ALL"]
 - If partial → specify line ranges: hunks: [{start: 10, end: 25}, {start: 50, end: 60}]
 - Line numbers are 1-indexed from the ORIGINAL file (look at "-" lines in diff)
 - You can specify multiple ranges for discontinuous changes in one file

## Good Example (2 independent changes)
groups: [
 {
 changes: [
 {path: "src/api.rs", hunks: ["ALL"]},
 {path: "tests/api_test.rs", hunks: [{start: 15, end: 23}]}
 ],
 type: "feat", scope: "api", rationale: "add user endpoint with test",
 dependencies: []
 },
 {
 changes: [
 {path: "src/utils.rs", hunks: [{start: 42, end: 48}, {start: 100, end: 105}]}
 ],
 type: "fix", scope: "utils", rationale: "fix string parsing bug in two locations",
 dependencies: []
 }
]

## Bad Example (over-splitting)
❌ DON'T create 6 commits for: function rename + call sites. That's ONE refactor group.
❌ DON'T split tests from implementation unless they test something from a prior group.

## Bad Example (incomplete)
❌ DON'T forget files. If diff shows 5 files, groups must cover all 5.

Return groups in dependency order."#;
132
/// User-prompt template for compose analysis.
///
/// Placeholders: `{MAX_COMMITS}` (upper bound on groups), `{STAT}` (the `git diff --stat`
/// style summary) and `{DIFF}` (the diff text itself).
const COMPOSE_USER_PROMPT: &str = concat!(
    "Split this git diff into 1-{MAX_COMMITS} logical, atomic commit groups.\n",
    "\n",
    "## Git Stat\n",
    "{STAT}\n",
    "\n",
    "## Git Diff\n",
    "{DIFF}",
);
140
/// Object form of a compose payload, `{"groups": [...]}`, matching the tool schema's
/// top-level `groups` property.
#[derive(Deserialize)]
struct ComposeResult {
    groups: Vec<ChangeGroup>,
}
145
146fn parse_compose_groups_from_content(content: &str) -> Result<Vec<ChangeGroup>> {
147 fn try_parse(input: &str) -> Option<Vec<ChangeGroup>> {
148 let trimmed = input.trim();
149 if trimmed.is_empty() {
150 return None;
151 }
152
153 serde_json::from_str::<ComposeResult>(trimmed)
154 .map(|r| r.groups)
155 .ok()
156 }
157
158 let trimmed = content.trim();
159 if trimmed.is_empty() {
160 return Err(CommitGenError::Other(
161 "Model returned an empty compose analysis response".to_string(),
162 ));
163 }
164
165 if let Some(groups) = try_parse(trimmed) {
166 return Ok(groups);
167 }
168
169 if let (Some(start), Some(end)) = (trimmed.find('{'), trimmed.rfind('}'))
170 && end >= start
171 {
172 let candidate = &trimmed[start..=end];
173 if let Some(groups) = try_parse(candidate) {
174 return Ok(groups);
175 }
176 }
177
178 let segments: Vec<&str> = trimmed.split("```").collect();
179 for (idx, segment) in segments.iter().enumerate() {
180 if idx % 2 == 1 {
181 let block = segment.trim();
182 let mut lines = block.lines();
183 let first_line = lines.next().unwrap_or_default();
184
185 let rest_owned: String;
186 let json_candidate = if first_line.trim_start().starts_with('{') {
187 block
188 } else {
189 let rest: String = lines.collect::<Vec<_>>().join("\n");
190 let trimmed_rest = rest.trim();
191 if trimmed_rest.is_empty() {
192 block
193 } else {
194 rest_owned = trimmed_rest.to_string();
195 &rest_owned
196 }
197 };
198
199 if let Some(groups) = try_parse(json_candidate) {
200 return Ok(groups);
201 }
202 }
203 }
204
205 Err(CommitGenError::Other("Failed to parse compose analysis from model response".to_string()))
206}
207
208fn parse_compose_groups_from_json(
209 raw: &str,
210) -> std::result::Result<Vec<ChangeGroup>, serde_json::Error> {
211 let trimmed = raw.trim();
212 if trimmed.starts_with('[') {
213 serde_json::from_str::<Vec<ChangeGroup>>(trimmed)
214 } else {
215 serde_json::from_str::<ComposeResult>(trimmed).map(|r| r.groups)
216 }
217}
218
219fn debug_failed_payload(source: &str, payload: &str, err: &serde_json::Error) {
220 let preview = payload.trim();
221 let preview = if preview.len() > 2000 {
222 format!("{}…", &preview[..2000])
223 } else {
224 preview.to_string()
225 };
226 eprintln!("Compose debug: failed to parse {source} payload ({err}); preview: {preview}");
227}
228
229fn group_affects_only_dependency_files(group: &ChangeGroup) -> bool {
230 group
231 .changes
232 .iter()
233 .all(|change| is_dependency_manifest(&change.path))
234}
235
/// Returns `true` when `path` names a package-manager manifest or lockfile.
///
/// Only the final path component is inspected. It qualifies if either:
/// - it exactly matches one of the well-known manifest names (case-sensitive), or
/// - it has a `.lock` / `.lockb` extension (case-insensitive), which catches lockfiles that
///   are not listed by name.
fn is_dependency_manifest(path: &str) -> bool {
    const KNOWN_MANIFESTS: &[&str] = &[
        "Cargo.toml",
        "Cargo.lock",
        "package.json",
        "package-lock.json",
        "pnpm-lock.yaml",
        "yarn.lock",
        "bun.lock",
        "bun.lockb",
        "go.mod",
        "go.sum",
        "requirements.txt",
        "Pipfile",
        "Pipfile.lock",
        "pyproject.toml",
        "Gemfile",
        "Gemfile.lock",
        "composer.json",
        "composer.lock",
        "build.gradle",
        "build.gradle.kts",
        "gradle.properties",
        "pom.xml",
    ];

    let candidate = Path::new(path);
    match candidate.file_name().and_then(|name| name.to_str()) {
        None => false,
        Some(name) if KNOWN_MANIFESTS.contains(&name) => true,
        // `Path::extension` is derived from the file name, so checking the full path is
        // equivalent to checking just the final component.
        Some(_) => candidate.extension().is_some_and(|ext| {
            ext.eq_ignore_ascii_case("lock") || ext.eq_ignore_ascii_case("lockb")
        }),
    }
}
275
276pub async fn analyze_for_compose(
278 diff: &str,
279 stat: &str,
280 config: &CommitConfig,
281 max_commits: usize,
282) -> Result<ComposeAnalysis> {
283 let client = crate::api::get_client(config);
284
285 let tool = Tool {
286 tool_type: "function".to_string(),
287 function: Function {
288 name: "create_compose_analysis".to_string(),
289 description: "Split changes into logical commit groups with dependencies".to_string(),
290 parameters: FunctionParameters {
291 param_type: "object".to_string(),
292 properties: serde_json::json!({
293 "groups": {
294 "type": "array",
295 "description": "Array of change groups in dependency order",
296 "items": {
297 "type": "object",
298 "properties": {
299 "changes": {
300 "type": "array",
301 "description": "File changes with specific hunks",
302 "items": {
303 "type": "object",
304 "properties": {
305 "path": {
306 "type": "string",
307 "description": "File path"
308 },
309 "hunks": {
310 "type": "array",
311 "description": "Either ['ALL'] for entire file, or line range objects: [{start: 10, end: 25}]. Line numbers are 1-indexed from ORIGINAL file.",
312 "items": {
313 "oneOf": [
314 { "type": "string", "const": "ALL" },
315 {
316 "type": "object",
317 "properties": {
318 "start": { "type": "integer", "minimum": 1 },
319 "end": { "type": "integer", "minimum": 1 }
320 },
321 "required": ["start", "end"]
322 }
323 ]
324 }
325 }
326 },
327 "required": ["path", "hunks"]
328 }
329 },
330 "type": {
331 "type": "string",
332 "enum": ["feat", "fix", "refactor", "docs", "test", "chore", "style", "perf", "build", "ci", "revert"],
333 "description": "Commit type for this group"
334 },
335 "scope": {
336 "type": "string",
337 "description": "Optional scope (module/component). Omit if broad."
338 },
339 "rationale": {
340 "type": "string",
341 "description": "Brief explanation of why these changes belong together"
342 },
343 "dependencies": {
344 "type": "array",
345 "description": "Indices of groups this depends on (e.g., [0, 1])",
346 "items": { "type": "integer" }
347 }
348 },
349 "required": ["changes", "type", "rationale", "dependencies"]
350 }
351 }
352 }),
353 required: vec!["groups".to_string()],
354 },
355 },
356 };
357
358 let user_prompt = COMPOSE_USER_PROMPT
359 .replace("{STAT}", stat)
360 .replace("{DIFF}", diff)
361 .replace("{MAX_COMMITS}", &max_commits.to_string());
362 let prompt_cache_key =
363 openai_prompt_cache_key(config, &config.model, "compose", "default", COMPOSE_SYSTEM_PROMPT);
364
365 let request = ApiRequest {
366 model: config.model.clone(),
367 max_tokens: 8000,
368 temperature: config.temperature,
369 tools: vec![tool],
370 tool_choice: Some(
371 serde_json::json!({ "type": "function", "function": { "name": "create_compose_analysis" } }),
372 ),
373 prompt_cache_key,
374 messages: vec![
375 Message { role: "system".to_string(), content: COMPOSE_SYSTEM_PROMPT.to_string() },
376 Message { role: "user".to_string(), content: user_prompt },
377 ],
378 };
379
380 let mut request_builder = client
381 .post(format!("{}/chat/completions", config.api_base_url))
382 .header("content-type", "application/json");
383
384 if let Some(api_key) = &config.api_key {
385 request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
386 }
387
388 let (status, response_text) =
389 crate::api::timed_send(request_builder.json(&request), "compose", &config.model).await?;
390 if !status.is_success() {
391 return Err(CommitGenError::ApiError { status: status.as_u16(), body: response_text });
392 }
393
394 let api_response: ApiResponse =
395 serde_json::from_str(&response_text).map_err(CommitGenError::JsonError)?;
396
397 if api_response.choices.is_empty() {
398 return Err(CommitGenError::Other(
399 "API returned empty response for compose analysis".to_string(),
400 ));
401 }
402
403 let mut last_parse_error: Option<CommitGenError> = None;
404
405 for choice in &api_response.choices {
406 let message = &choice.message;
407
408 if let Some(tool_call) = message.tool_calls.first()
409 && tool_call.function.name.ends_with("create_compose_analysis")
410 {
411 let args = &tool_call.function.arguments;
412 match parse_compose_groups_from_json(args) {
413 Ok(groups) => {
414 let dependency_order = compute_dependency_order(&groups)?;
415 return Ok(ComposeAnalysis { groups, dependency_order });
416 },
417 Err(err) => {
418 debug_failed_payload("tool_call", args, &err);
419 last_parse_error =
420 Some(CommitGenError::Other(format!("Failed to parse compose analysis: {err}")));
421 },
422 }
423 }
424
425 if let Some(function_call) = &message.function_call
426 && function_call.name == "create_compose_analysis"
427 {
428 let args = &function_call.arguments;
429 match parse_compose_groups_from_json(args) {
430 Ok(groups) => {
431 let dependency_order = compute_dependency_order(&groups)?;
432 return Ok(ComposeAnalysis { groups, dependency_order });
433 },
434 Err(err) => {
435 debug_failed_payload("function_call", args, &err);
436 last_parse_error =
437 Some(CommitGenError::Other(format!("Failed to parse compose analysis: {err}")));
438 },
439 }
440 }
441
442 if let Some(content) = &message.content {
443 match parse_compose_groups_from_content(content) {
444 Ok(groups) => {
445 let dependency_order = compute_dependency_order(&groups)?;
446 return Ok(ComposeAnalysis { groups, dependency_order });
447 },
448 Err(err) => last_parse_error = Some(err),
449 }
450 }
451 }
452
453 if let Some(err) = last_parse_error {
454 debug_compose_response(&api_response);
455 return Err(err);
456 }
457
458 debug_compose_response(&api_response);
459 Err(CommitGenError::Other("No compose analysis found in API response".to_string()))
460}
461
462fn debug_compose_response(response: &ApiResponse) {
463 let raw_preview = serde_json::to_string(response).map_or_else(
464 |_| "<failed to serialize response>".to_string(),
465 |json| {
466 if json.len() > 4000 {
467 format!("{}…", &json[..4000])
468 } else {
469 json
470 }
471 },
472 );
473
474 eprintln!(
475 "Compose debug: received {} choice(s) from analysis model\n raw: {}",
476 response.choices.len(),
477 raw_preview
478 );
479
480 for (idx, choice) in response.choices.iter().enumerate() {
481 let message = &choice.message;
482 let tool_call = message.tool_calls.first();
483 let tool_name = tool_call.map_or("<none>", |tc| tc.function.name.as_str());
484 let tool_args_len = tool_call.map_or(0, |tc| tc.function.arguments.len());
485
486 let function_call_name = message
487 .function_call
488 .as_ref()
489 .map_or("<none>", |fc| fc.name.as_str());
490 let function_call_args_len = message
491 .function_call
492 .as_ref()
493 .map_or(0, |fc| fc.arguments.len());
494
495 let content_preview = message.content.as_deref().map_or_else(
496 || "<none>".to_string(),
497 |c| {
498 let trimmed = c.trim();
499 if trimmed.len() > 200 {
500 format!("{}…", &trimmed[..200])
501 } else {
502 trimmed.to_string()
503 }
504 },
505 );
506
507 eprintln!(
508 "Choice #{idx}: tool_call={tool_name} (args {tool_args_len} chars), \
509 function_call={function_call_name} (args {function_call_args_len} chars), \
510 content_preview={content_preview}"
511 );
512 }
513}
514
515fn compute_dependency_order(groups: &[ChangeGroup]) -> Result<Vec<usize>> {
517 let n = groups.len();
518 let mut in_degree = vec![0; n];
519 let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
520
521 for (i, group) in groups.iter().enumerate() {
523 for &dep in &group.dependencies {
524 if dep >= n {
525 return Err(CommitGenError::Other(format!(
526 "Invalid dependency index {dep} (max: {n})"
527 )));
528 }
529 adjacency[dep].push(i);
530 in_degree[i] += 1;
531 }
532 }
533
534 let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
536 let mut order = Vec::new();
537
538 while let Some(node) = queue.pop() {
539 order.push(node);
540 for &neighbor in &adjacency[node] {
541 in_degree[neighbor] -= 1;
542 if in_degree[neighbor] == 0 {
543 queue.push(neighbor);
544 }
545 }
546 }
547
548 if order.len() != n {
549 return Err(CommitGenError::Other(
550 "Circular dependency detected in commit groups".to_string(),
551 ));
552 }
553
554 Ok(order)
555}
556
557fn validate_compose_groups(groups: &[ChangeGroup], full_diff: &str) -> Result<()> {
559 use std::collections::{HashMap, HashSet};
560
561 let mut diff_files: HashSet<String> = HashSet::new();
563 for line in full_diff.lines() {
564 if line.starts_with("diff --git")
565 && let Some(b_part) = line.split_whitespace().nth(3)
566 && let Some(path) = b_part.strip_prefix("b/")
567 {
568 diff_files.insert(path.to_string());
569 }
570 }
571
572 let mut covered_files: HashSet<String> = HashSet::new();
574 let mut file_coverage: HashMap<String, usize> = HashMap::new();
575
576 for (idx, group) in groups.iter().enumerate() {
577 for change in &group.changes {
578 covered_files.insert(change.path.clone());
579 *file_coverage.entry(change.path.clone()).or_insert(0) += 1;
580
581 for selector in &change.hunks {
583 match selector {
584 crate::types::HunkSelector::All => {},
585 crate::types::HunkSelector::Lines { start, end } => {
586 if start > end {
587 eprintln!(
588 "{}",
589 style::warning(&format!(
590 "{} Warning: Group {idx} has invalid line range {start}-{end} in {}",
591 style::icons::WARNING,
592 change.path
593 ))
594 );
595 }
596 if *start == 0 {
597 eprintln!(
598 "{}",
599 style::warning(&format!(
600 "{} Warning: Group {idx} has line range starting at 0 (should be \
601 1-indexed) in {}",
602 style::icons::WARNING,
603 change.path
604 ))
605 );
606 }
607 },
608 crate::types::HunkSelector::Search { pattern } => {
609 if pattern.is_empty() {
610 eprintln!(
611 "{}",
612 style::warning(&format!(
613 "{} Warning: Group {idx} has empty search pattern in {}",
614 style::icons::WARNING,
615 change.path
616 ))
617 );
618 }
619 },
620 }
621 }
622 }
623
624 for &dep in &group.dependencies {
626 if dep >= groups.len() {
627 return Err(CommitGenError::Other(format!(
628 "Group {idx} has invalid dependency {dep} (only {} groups total)",
629 groups.len()
630 )));
631 }
632 if dep == idx {
633 return Err(CommitGenError::Other(format!("Group {idx} depends on itself (circular)")));
634 }
635 }
636 }
637
638 let missing_files: Vec<&String> = diff_files.difference(&covered_files).collect();
640 if !missing_files.is_empty() {
641 eprintln!(
642 "{}",
643 style::warning(&format!(
644 "{} Warning: Groups don't cover all files. Missing:",
645 style::icons::WARNING
646 ))
647 );
648 for file in &missing_files {
649 eprintln!(" - {file}");
650 }
651 return Err(CommitGenError::Other(format!(
652 "Non-exhaustive groups: {} file(s) not covered",
653 missing_files.len()
654 )));
655 }
656
657 let duplicates: Vec<_> = file_coverage
659 .iter()
660 .filter(|&(_, count)| *count > 1)
661 .collect();
662
663 if !duplicates.is_empty() {
664 eprintln!(
665 "{}",
666 style::warning(&format!(
667 "{} Warning: Some files appear in multiple groups:",
668 style::icons::WARNING
669 ))
670 );
671 for (file, count) in duplicates {
672 eprintln!(" - {file} ({count} times)");
673 }
674 }
675
676 for (idx, group) in groups.iter().enumerate() {
678 if group.changes.is_empty() {
679 return Err(CommitGenError::Other(format!("Group {idx} has no changes")));
680 }
681 }
682
683 Ok(())
684}
685
686pub async fn execute_compose(
688 analysis: &ComposeAnalysis,
689 config: &CommitConfig,
690 args: &Args,
691) -> Result<Vec<String>> {
692 let dir = &args.dir;
693 let token_counter = create_token_counter(config);
694
695 println!("{}", style::info("Resetting staging area..."));
697 reset_staging(dir)?;
698
699 let baseline_diff_output = std::process::Command::new("git")
702 .args(["diff", "HEAD"])
703 .current_dir(dir)
704 .output()
705 .map_err(|e| CommitGenError::git(format!("Failed to get baseline diff: {e}")))?;
706
707 if !baseline_diff_output.status.success() {
708 let stderr = String::from_utf8_lossy(&baseline_diff_output.stderr);
709 return Err(CommitGenError::git(format!("git diff HEAD failed: {stderr}")));
710 }
711
712 let baseline_diff = String::from_utf8_lossy(&baseline_diff_output.stdout).to_string();
713
714 let mut commit_hashes = Vec::new();
715
716 for (idx, &group_idx) in analysis.dependency_order.iter().enumerate() {
717 let mut group = analysis.groups[group_idx].clone();
718 let dependency_only = group_affects_only_dependency_files(&group);
719
720 if dependency_only && group.commit_type.as_str() != "build" {
721 group.commit_type = CommitType::new("build")?;
722 }
723
724 println!(
725 "\n[{}/{}] Creating commit for group: {}",
726 idx + 1,
727 analysis.dependency_order.len(),
728 group.rationale
729 );
730 println!(" Type: {}", style::commit_type(&group.commit_type.to_string()));
731 if let Some(scope) = &group.scope {
732 println!(" Scope: {}", style::scope(&scope.to_string()));
733 }
734 let files: Vec<String> = group.changes.iter().map(|c| c.path.clone()).collect();
735 println!(" Files: {}", files.join(", "));
736
737 stage_group_changes(&group, dir, &baseline_diff)?;
739
740 let diff = get_git_diff(&Mode::Staged, None, dir, config)?;
742 let stat = get_git_stat(&Mode::Staged, None, dir, config)?;
743
744 let diff = if diff.len() > config.max_diff_length {
746 smart_truncate_diff(&diff, config.max_diff_length, config, &token_counter)
747 } else {
748 diff
749 };
750
751 println!(" {}", style::info("Generating commit message..."));
753 let debug_prefix = format!("compose-{}", idx + 1);
754 let ctx = AnalysisContext {
755 user_context: Some(&group.rationale),
756 recent_commits: None, common_scopes: None, project_context: None, debug_output: args.debug_output.as_deref(),
760 debug_prefix: Some(&debug_prefix),
761 };
762 let message_analysis =
763 generate_conventional_analysis(&stat, &diff, &config.model, "", &ctx, config).await?;
764
765 let analysis_body = message_analysis.body_texts();
766
767 let summary = crate::api::generate_summary_from_analysis(
768 &stat,
769 group.commit_type.as_str(),
770 group.scope.as_ref().map(|s| s.as_str()),
771 &analysis_body,
772 Some(&group.rationale),
773 config,
774 args.debug_output.as_deref(),
775 Some(&debug_prefix),
776 )
777 .await?;
778
779 let final_commit_type = if dependency_only {
780 CommitType::new("build")?
781 } else {
782 message_analysis.commit_type
783 };
784
785 let mut commit = ConventionalCommit {
786 commit_type: final_commit_type,
787 scope: message_analysis.scope,
788 summary,
789 body: analysis_body,
790 footers: vec![],
791 };
792
793 post_process_commit_message(&mut commit, config);
794
795 if let Err(e) = validate_commit_message(&commit, config) {
796 eprintln!(
797 " {}",
798 style::warning(&format!("{} Warning: Validation failed: {e}", style::icons::WARNING))
799 );
800 }
801
802 let formatted_message = format_commit_message(&commit);
803
804 println!(
805 " Message:\n{}",
806 formatted_message
807 .lines()
808 .take(3)
809 .collect::<Vec<_>>()
810 .join("\n")
811 );
812
813 if !args.compose_preview {
815 let sign = args.sign || config.gpg_sign;
816 let signoff = args.signoff || config.signoff;
817 git_commit(&formatted_message, false, dir, sign, signoff, args.skip_hooks, false)?;
818 let hash = get_head_hash(dir)?;
819 commit_hashes.push(hash);
820
821 if args.compose_test_after_each {
823 println!(" {}", style::info("Running tests..."));
824 let test_result = std::process::Command::new("cargo")
825 .arg("test")
826 .current_dir(dir)
827 .status();
828
829 if let Ok(status) = test_result {
830 if !status.success() {
831 return Err(CommitGenError::Other(format!(
832 "Tests failed after commit {idx}. Aborting."
833 )));
834 }
835 println!(" {}", style::success(&format!("{} Tests passed", style::icons::SUCCESS)));
836 }
837 }
838 }
839 }
840
841 Ok(commit_hashes)
842}
843
844pub async fn run_compose_mode(args: &Args, config: &CommitConfig) -> Result<()> {
846 let max_rounds = config.compose_max_rounds;
847
848 for round in 1..=max_rounds {
849 if round > 1 {
850 println!(
851 "\n{}",
852 style::section_header(&format!("Compose Round {round}/{max_rounds}"), 80)
853 );
854 } else {
855 println!("{}", style::section_header("Compose Mode", 80));
856 }
857 println!("{}\n", style::info("Analyzing all changes for intelligent splitting..."));
858
859 run_compose_round(args, config, round).await?;
860
861 if args.compose_preview {
863 break;
864 }
865
866 let remaining_diff_output = std::process::Command::new("git")
867 .args(["diff", "HEAD"])
868 .current_dir(&args.dir)
869 .output()
870 .map_err(|e| CommitGenError::git(format!("Failed to check remaining diff: {e}")))?;
871
872 if !remaining_diff_output.status.success() {
873 continue;
874 }
875
876 let remaining_diff = String::from_utf8_lossy(&remaining_diff_output.stdout);
877 if remaining_diff.trim().is_empty() {
878 println!(
879 "\n{}",
880 style::success(&format!(
881 "{} All changes committed successfully",
882 style::icons::SUCCESS
883 ))
884 );
885 break;
886 }
887
888 eprintln!(
889 "\n{}",
890 style::warning(&format!(
891 "{} Uncommitted changes remain after round {round}",
892 style::icons::WARNING
893 ))
894 );
895
896 let stat_output = std::process::Command::new("git")
897 .args(["diff", "HEAD", "--stat"])
898 .current_dir(&args.dir)
899 .output()
900 .ok();
901
902 if let Some(output) = stat_output
903 && output.status.success()
904 {
905 let stat = String::from_utf8_lossy(&output.stdout);
906 eprintln!("{stat}");
907 }
908
909 if round < max_rounds {
910 eprintln!("{}", style::info("Starting another compose round..."));
911 continue;
912 }
913 eprintln!(
914 "{}",
915 style::warning(&format!(
916 "Reached max rounds ({max_rounds}). Remaining changes need manual commit."
917 ))
918 );
919 }
920
921 Ok(())
922}
923
924async fn run_compose_round(args: &Args, config: &CommitConfig, round: usize) -> Result<()> {
926 let token_counter = create_token_counter(config);
927
928 let diff_staged = get_git_diff(&Mode::Staged, None, &args.dir, config).unwrap_or_default();
930 let diff_unstaged = get_git_diff(&Mode::Unstaged, None, &args.dir, config).unwrap_or_default();
931
932 let combined_diff = if diff_staged.is_empty() {
933 diff_unstaged
934 } else if diff_unstaged.is_empty() {
935 diff_staged
936 } else {
937 format!("{diff_staged}\n{diff_unstaged}")
938 };
939
940 if combined_diff.is_empty() {
941 return Err(CommitGenError::NoChanges { mode: "working directory".to_string() });
942 }
943
944 let stat_staged = get_git_stat(&Mode::Staged, None, &args.dir, config).unwrap_or_default();
945 let stat_unstaged = get_git_stat(&Mode::Unstaged, None, &args.dir, config).unwrap_or_default();
946
947 let combined_stat = if stat_staged.is_empty() {
948 stat_unstaged
949 } else if stat_unstaged.is_empty() {
950 stat_staged
951 } else {
952 format!("{stat_staged}\n{stat_unstaged}")
953 };
954
955 let original_diff = combined_diff.clone();
957
958 let diff = if combined_diff.len() > config.max_diff_length {
960 println!(
961 "{}",
962 style::warning(&format!(
963 "{} Applying smart truncation (diff size: {} characters)",
964 style::icons::WARNING,
965 combined_diff.len()
966 ))
967 );
968 smart_truncate_diff(&combined_diff, config.max_diff_length, config, &token_counter)
969 } else {
970 combined_diff
971 };
972
973 let max_commits = args.compose_max_commits.unwrap_or(3);
974
975 println!("{}", style::info(&format!("Analyzing changes (max {max_commits} commits)...")));
976 let analysis = analyze_for_compose(&diff, &combined_stat, config, max_commits).await?;
977
978 println!("{}", style::info("Validating groups..."));
980 validate_compose_groups(&analysis.groups, &original_diff)?;
981
982 println!("\n{}", style::section_header("Proposed Commit Groups", 80));
983 for (idx, &group_idx) in analysis.dependency_order.iter().enumerate() {
984 let mut group = analysis.groups[group_idx].clone();
985 if group_affects_only_dependency_files(&group) && group.commit_type.as_str() != "build" {
986 group.commit_type = CommitType::new("build")?;
987 }
988 println!(
989 "\n{}. [{}{}] {}",
990 idx + 1,
991 style::commit_type(&group.commit_type.to_string()),
992 group
993 .scope
994 .as_ref()
995 .map(|s| format!("({})", style::scope(&s.to_string())))
996 .unwrap_or_default(),
997 group.rationale
998 );
999 println!(" Changes:");
1000 for change in &group.changes {
1001 let is_all =
1002 change.hunks.len() == 1 && matches!(&change.hunks[0], crate::types::HunkSelector::All);
1003
1004 if is_all {
1005 println!(" - {} (all changes)", change.path);
1006 } else {
1007 let summary: Vec<String> = change
1009 .hunks
1010 .iter()
1011 .map(|s| match s {
1012 crate::types::HunkSelector::All => "all".to_string(),
1013 crate::types::HunkSelector::Lines { start, end } => {
1014 format!("lines {start}-{end}")
1015 },
1016 crate::types::HunkSelector::Search { pattern } => {
1017 if pattern.len() > 20 {
1018 format!("search '{}'...", &pattern[..20])
1019 } else {
1020 format!("search '{pattern}'")
1021 }
1022 },
1023 })
1024 .collect();
1025 println!(" - {} ({})", change.path, summary.join(", "));
1026 }
1027 }
1028 if !group.dependencies.is_empty() {
1029 println!(" Depends on: {:?}", group.dependencies);
1030 }
1031 }
1032
1033 if args.compose_preview {
1034 println!(
1035 "\n{}",
1036 style::success(&format!(
1037 "{} Preview complete (use --compose without --compose-preview to execute)",
1038 style::icons::SUCCESS
1039 ))
1040 );
1041 return Ok(());
1042 }
1043
1044 println!("\n{}", style::info(&format!("Executing compose (round {round})...")));
1045 let hashes = execute_compose(&analysis, config, args).await?;
1046
1047 println!(
1048 "{}",
1049 style::success(&format!(
1050 "{} Round {round}: Created {} commit(s)",
1051 style::icons::SUCCESS,
1052 hashes.len()
1053 ))
1054 );
1055 Ok(())
1056}