1use std::path::Path;
7
8use rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 api::retry_api_call,
13 config::CommitConfig,
14 diff::{FileDiff, parse_diff, reconstruct_diff},
15 error::{CommitGenError, Result},
16 templates,
17 tokens::TokenCounter,
18 types::ConventionalAnalysis,
19};
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct FileObservation {
24 pub file: String,
25 pub observations: Vec<String>,
26 pub additions: usize,
27 pub deletions: usize,
28}
29
/// Minimum number of non-excluded files in a diff before map-reduce is chosen.
const MIN_FILES_FOR_MAP_REDUCE: usize = 4;

/// Per-file token estimate above which a file forces map-reduce and is
/// truncated before it is sent to the model.
const MAX_FILE_TOKENS: usize = 50_000;
36
37pub fn should_use_map_reduce(diff: &str, config: &CommitConfig, counter: &TokenCounter) -> bool {
42 if !config.map_reduce_enabled {
43 return false;
44 }
45
46 let files = parse_diff(diff);
47 let file_count = files
48 .iter()
49 .filter(|f| {
50 !config
51 .excluded_files
52 .iter()
53 .any(|ex| f.filename.ends_with(ex))
54 })
55 .count();
56
57 file_count >= MIN_FILES_FOR_MAP_REDUCE
59 || files
60 .iter()
61 .any(|f| f.token_estimate(counter) > MAX_FILE_TOKENS)
62}
63
/// Maximum number of *other* files listed in a per-file context header. When
/// there are more, the ones with the most added + deleted lines are kept.
const MAX_CONTEXT_FILES: usize = 20;
66
67fn generate_context_header(files: &[FileDiff], current_file: &str) -> String {
69 if files.len() > 100 {
71 return format!("(Large commit with {} total files)", files.len());
72 }
73
74 let mut lines = vec!["OTHER FILES IN THIS CHANGE:".to_string()];
75
76 let other_files: Vec<_> = files
77 .iter()
78 .filter(|f| f.filename != current_file)
79 .collect();
80
81 let total_other = other_files.len();
82
83 let to_show: Vec<&FileDiff> = if total_other > MAX_CONTEXT_FILES {
85 let mut sorted = other_files;
86 sorted.sort_by_key(|f| std::cmp::Reverse(f.additions + f.deletions));
87 sorted.truncate(MAX_CONTEXT_FILES);
88 sorted
89 } else {
90 other_files
91 };
92
93 for file in &to_show {
94 let line_count = file.additions + file.deletions;
95 let description = infer_file_description(&file.filename, &file.content);
96 lines.push(format!("- {} ({} lines): {}", file.filename, line_count, description));
97 }
98
99 if to_show.len() < total_other {
100 lines.push(format!("... and {} more files", total_other - to_show.len()));
101 }
102
103 if lines.len() == 1 {
104 return String::new(); }
106
107 lines.join("\n")
108}
109
/// Returns a coarse, human-readable category for a changed file. Used in the
/// per-file context header.
///
/// Path rules are checked first, case-insensitively and in this order:
/// - test files, Markdown documentation, configuration (name contains
///   "config" or extension toml/yaml/yml);
/// - names containing "error" or "type";
/// - module roots (`mod.rs` / `lib.rs`);
/// - entry points (`main.rs` / `main.go` / `main.py`).
///
/// If none match, `content` is scanned for a few keywords. Anything else is
/// "source code".
///
/// Module-root and entry-point rules compare the *last path component*
/// exactly. A plain suffix check on the whole path would misclassify names
/// like `src/domain.rs` (entry point) or `src/calib.rs` (module exports).
fn infer_file_description(filename: &str, content: &str) -> &'static str {
    let filename_lower = filename.to_lowercase();
    let path = Path::new(filename);
    let ext_is = |wanted: &str| {
        path.extension()
            .is_some_and(|e| e.eq_ignore_ascii_case(wanted))
    };
    // Lower-cased final component; both separators are accepted so that
    // behaviour does not depend on the host platform's path rules.
    let file_name = filename_lower
        .rsplit(|c: char| c == '/' || c == '\\')
        .next()
        .unwrap_or("");

    if filename_lower.contains("test") {
        return "test file";
    }
    if ext_is("md") {
        return "documentation";
    }
    if filename_lower.contains("config") || ext_is("toml") || ext_is("yaml") || ext_is("yml") {
        return "configuration";
    }
    if filename_lower.contains("error") {
        return "error definitions";
    }
    if filename_lower.contains("type") {
        return "type definitions";
    }
    if matches!(file_name, "mod.rs" | "lib.rs") {
        return "module exports";
    }
    if matches!(file_name, "main.rs" | "main.go" | "main.py") {
        return "entry point";
    }

    // Content heuristics. Order matters: e.g. Rust `async fn` hits "fn " first.
    if content.contains("impl ") || content.contains("fn ") {
        return "implementation";
    }
    if content.contains("struct ") || content.contains("enum ") {
        return "type definitions";
    }
    if content.contains("async ") || content.contains("await") {
        return "async code";
    }

    "source code"
}
162
163fn map_phase(
165 files: &[FileDiff],
166 model_name: &str,
167 config: &CommitConfig,
168 counter: &TokenCounter,
169) -> Result<Vec<FileObservation>> {
170 let observations: Vec<Result<FileObservation>> = files
172 .par_iter()
173 .map(|file| {
174 if file.is_binary {
175 return Ok(FileObservation {
176 file: file.filename.clone(),
177 observations: vec!["Binary file changed.".to_string()],
178 additions: 0,
179 deletions: 0,
180 });
181 }
182
183 let context_header = generate_context_header(files, &file.filename);
184
185 let mut file_clone = file.clone();
187 let file_tokens = file_clone.token_estimate(counter);
188 if file_tokens > MAX_FILE_TOKENS {
189 let target_size = MAX_FILE_TOKENS * 4; file_clone.truncate(target_size);
191 eprintln!(
192 " {} truncated {} ({} → {} tokens)",
193 crate::style::icons::WARNING,
194 file.filename,
195 file_tokens,
196 file_clone.token_estimate(counter)
197 );
198 }
199
200 let file_diff = reconstruct_diff(&[file_clone]);
201
202 map_single_file(&file.filename, &file_diff, &context_header, model_name, config)
203 })
204 .collect();
205
206 observations.into_iter().collect()
208}
209
210fn map_single_file(
212 filename: &str,
213 file_diff: &str,
214 context_header: &str,
215 model_name: &str,
216 config: &CommitConfig,
217) -> Result<FileObservation> {
218 retry_api_call(config, || {
219 let client = build_client(config);
220
221 let tool = build_observation_tool();
222
223 let prompt = templates::render_map_prompt("default", filename, file_diff, context_header)?;
224
225 let request = build_api_request(model_name, config.temperature, vec![tool], &prompt);
226
227 let mut request_builder = client
228 .post(format!("{}/chat/completions", config.api_base_url))
229 .header("content-type", "application/json");
230
231 if let Some(api_key) = &config.api_key {
232 request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
233 }
234
235 let response = request_builder
236 .json(&request)
237 .send()
238 .map_err(CommitGenError::HttpError)?;
239
240 let status = response.status();
241
242 if status.is_server_error() {
243 let error_text = response
244 .text()
245 .unwrap_or_else(|_| "Unknown error".to_string());
246 eprintln!("{}", crate::style::error(&format!("Server error {status}: {error_text}")));
247 return Ok((true, None)); }
249
250 if !status.is_success() {
251 let error_text = response
252 .text()
253 .unwrap_or_else(|_| "Unknown error".to_string());
254 return Err(CommitGenError::ApiError { status: status.as_u16(), body: error_text });
255 }
256
257 let api_response: ApiResponse = response.json().map_err(CommitGenError::HttpError)?;
258
259 if api_response.choices.is_empty() {
260 return Err(CommitGenError::Other(
261 "API returned empty response for file observation".to_string(),
262 ));
263 }
264
265 let message = &api_response.choices[0].message;
266
267 if !message.tool_calls.is_empty() {
268 let tool_call = &message.tool_calls[0];
269 if tool_call.function.name == "create_file_observation" {
270 let args = &tool_call.function.arguments;
271 if args.is_empty() {
272 return Err(CommitGenError::Other(
273 "Model returned empty function arguments for observation".to_string(),
274 ));
275 }
276
277 let obs: FileObservationResponse = serde_json::from_str(args).map_err(|e| {
278 CommitGenError::Other(format!("Failed to parse observation response: {e}"))
279 })?;
280
281 return Ok((
282 false,
283 Some(FileObservation {
284 file: filename.to_string(),
285 observations: obs.observations,
286 additions: 0, deletions: 0,
288 }),
289 ));
290 }
291 }
292
293 if let Some(content) = &message.content {
295 let obs: FileObservationResponse =
296 serde_json::from_str(content.trim()).map_err(CommitGenError::JsonError)?;
297 return Ok((
298 false,
299 Some(FileObservation {
300 file: filename.to_string(),
301 observations: obs.observations,
302 additions: 0,
303 deletions: 0,
304 }),
305 ));
306 }
307
308 Err(CommitGenError::Other("No observation found in API response".to_string()))
309 })
310}
311
312pub fn reduce_phase(
314 observations: &[FileObservation],
315 stat: &str,
316 scope_candidates: &str,
317 model_name: &str,
318 config: &CommitConfig,
319) -> Result<ConventionalAnalysis> {
320 retry_api_call(config, || {
321 let client = build_client(config);
322
323 let type_enum: Vec<&str> = config.types.keys().map(|s| s.as_str()).collect();
325
326 let tool = build_analysis_tool(&type_enum);
327
328 let observations_json =
329 serde_json::to_string_pretty(observations).unwrap_or_else(|_| "[]".to_string());
330
331 let types_description = crate::api::format_types_description(config);
332 let prompt = templates::render_reduce_prompt(
333 "default",
334 &observations_json,
335 stat,
336 scope_candidates,
337 Some(&types_description),
338 )?;
339
340 let request = build_api_request(model_name, config.temperature, vec![tool], &prompt);
341
342 let mut request_builder = client
343 .post(format!("{}/chat/completions", config.api_base_url))
344 .header("content-type", "application/json");
345
346 if let Some(api_key) = &config.api_key {
347 request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
348 }
349
350 let response = request_builder
351 .json(&request)
352 .send()
353 .map_err(CommitGenError::HttpError)?;
354
355 let status = response.status();
356
357 if status.is_server_error() {
358 let error_text = response
359 .text()
360 .unwrap_or_else(|_| "Unknown error".to_string());
361 eprintln!("{}", crate::style::error(&format!("Server error {status}: {error_text}")));
362 return Ok((true, None)); }
364
365 if !status.is_success() {
366 let error_text = response
367 .text()
368 .unwrap_or_else(|_| "Unknown error".to_string());
369 return Err(CommitGenError::ApiError { status: status.as_u16(), body: error_text });
370 }
371
372 let api_response: ApiResponse = response.json().map_err(CommitGenError::HttpError)?;
373
374 if api_response.choices.is_empty() {
375 return Err(CommitGenError::Other(
376 "API returned empty response for synthesis".to_string(),
377 ));
378 }
379
380 let message = &api_response.choices[0].message;
381
382 if !message.tool_calls.is_empty() {
383 let tool_call = &message.tool_calls[0];
384 if tool_call.function.name == "create_conventional_analysis" {
385 let args = &tool_call.function.arguments;
386 if args.is_empty() {
387 return Err(CommitGenError::Other(
388 "Model returned empty function arguments for synthesis".to_string(),
389 ));
390 }
391
392 let analysis: ConventionalAnalysis = serde_json::from_str(args).map_err(|e| {
393 CommitGenError::Other(format!("Failed to parse synthesis response: {e}"))
394 })?;
395
396 return Ok((false, Some(analysis)));
397 }
398 }
399
400 if let Some(content) = &message.content {
402 let analysis: ConventionalAnalysis =
403 serde_json::from_str(content.trim()).map_err(CommitGenError::JsonError)?;
404 return Ok((false, Some(analysis)));
405 }
406
407 Err(CommitGenError::Other("No analysis found in synthesis response".to_string()))
408 })
409}
410
411pub fn run_map_reduce(
413 diff: &str,
414 stat: &str,
415 scope_candidates: &str,
416 model_name: &str,
417 config: &CommitConfig,
418 counter: &TokenCounter,
419) -> Result<ConventionalAnalysis> {
420 let mut files = parse_diff(diff);
421
422 files.retain(|f| {
424 !config
425 .excluded_files
426 .iter()
427 .any(|excluded| f.filename.ends_with(excluded))
428 });
429
430 if files.is_empty() {
431 return Err(CommitGenError::Other(
432 "No relevant files to analyze after filtering".to_string(),
433 ));
434 }
435
436 let file_count = files.len();
437 crate::style::print_info(&format!("Running map-reduce on {file_count} files..."));
438
439 let observations = map_phase(&files, model_name, config, counter)?;
441
442 reduce_phase(&observations, stat, scope_candidates, model_name, config)
444}
445
446use std::time::Duration;
451
452fn build_client(config: &CommitConfig) -> reqwest::blocking::Client {
453 reqwest::blocking::Client::builder()
454 .timeout(Duration::from_secs(config.request_timeout_secs))
455 .connect_timeout(Duration::from_secs(config.connect_timeout_secs))
456 .build()
457 .expect("Failed to build HTTP client")
458}
459
460#[derive(Debug, Serialize)]
461struct Message {
462 role: String,
463 content: String,
464}
465
466#[derive(Debug, Serialize, Deserialize)]
467struct FunctionParameters {
468 #[serde(rename = "type")]
469 param_type: String,
470 properties: serde_json::Value,
471 required: Vec<String>,
472}
473
474#[derive(Debug, Serialize, Deserialize)]
475struct Function {
476 name: String,
477 description: String,
478 parameters: FunctionParameters,
479}
480
481#[derive(Debug, Serialize, Deserialize)]
482struct Tool {
483 #[serde(rename = "type")]
484 tool_type: String,
485 function: Function,
486}
487
488#[derive(Debug, Serialize)]
489struct ApiRequest {
490 model: String,
491 max_tokens: u32,
492 temperature: f32,
493 tools: Vec<Tool>,
494 #[serde(skip_serializing_if = "Option::is_none")]
495 tool_choice: Option<serde_json::Value>,
496 messages: Vec<Message>,
497}
498
499#[derive(Debug, Deserialize)]
500struct ToolCall {
501 function: FunctionCall,
502}
503
504#[derive(Debug, Deserialize)]
505struct FunctionCall {
506 name: String,
507 arguments: String,
508}
509
510#[derive(Debug, Deserialize)]
511struct Choice {
512 message: ResponseMessage,
513}
514
515#[derive(Debug, Deserialize)]
516struct ResponseMessage {
517 #[serde(default)]
518 tool_calls: Vec<ToolCall>,
519 #[serde(default)]
520 content: Option<String>,
521}
522
523#[derive(Debug, Deserialize)]
524struct ApiResponse {
525 choices: Vec<Choice>,
526}
527
528#[derive(Debug, Deserialize)]
529struct FileObservationResponse {
530 observations: Vec<String>,
531}
532
533fn build_observation_tool() -> Tool {
534 Tool {
535 tool_type: "function".to_string(),
536 function: Function {
537 name: "create_file_observation".to_string(),
538 description: "Extract observations from a single file's changes".to_string(),
539 parameters: FunctionParameters {
540 param_type: "object".to_string(),
541 properties: serde_json::json!({
542 "observations": {
543 "type": "array",
544 "description": "List of factual observations about what changed in this file",
545 "items": {
546 "type": "string"
547 }
548 }
549 }),
550 required: vec!["observations".to_string()],
551 },
552 },
553 }
554}
555
556fn build_analysis_tool(type_enum: &[&str]) -> Tool {
557 Tool {
558 tool_type: "function".to_string(),
559 function: Function {
560 name: "create_conventional_analysis".to_string(),
561 description: "Synthesize observations into conventional commit analysis".to_string(),
562 parameters: FunctionParameters {
563 param_type: "object".to_string(),
564 properties: serde_json::json!({
565 "type": {
566 "type": "string",
567 "enum": type_enum,
568 "description": "Commit type based on combined changes"
569 },
570 "scope": {
571 "type": "string",
572 "description": "Optional scope (module/component). Omit if unclear or multi-component."
573 },
574 "details": {
575 "type": "array",
576 "description": "Array of 0-6 detail items with changelog metadata.",
577 "items": {
578 "type": "object",
579 "properties": {
580 "text": {
581 "type": "string",
582 "description": "Detail about change, starting with past-tense verb, ending with period"
583 },
584 "changelog_category": {
585 "type": "string",
586 "enum": ["Added", "Changed", "Fixed", "Deprecated", "Removed", "Security"],
587 "description": "Changelog category if user-visible. Omit for internal changes."
588 },
589 "user_visible": {
590 "type": "boolean",
591 "description": "True if this change affects users/API and should appear in changelog"
592 }
593 },
594 "required": ["text", "user_visible"]
595 }
596 },
597 "issue_refs": {
598 "type": "array",
599 "description": "Issue numbers from context (e.g., ['#123', '#456']). Empty if none.",
600 "items": {
601 "type": "string"
602 }
603 }
604 }),
605 required: vec!["type".to_string(), "details".to_string(), "issue_refs".to_string()],
606 },
607 },
608 }
609}
610
611fn build_api_request(model: &str, temperature: f32, tools: Vec<Tool>, prompt: &str) -> ApiRequest {
612 let tool_name = tools.first().map(|t| t.function.name.clone());
613
614 ApiRequest {
615 model: model.to_string(),
616 max_tokens: 1000,
617 temperature,
618 tool_choice: tool_name
619 .map(|name| serde_json::json!({ "type": "function", "function": { "name": name } })),
620 tools,
621 messages: vec![Message { role: "user".to_string(), content: prompt.to_string() }],
622 }
623}
624
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokens::TokenCounter;

    fn test_counter() -> TokenCounter {
        TokenCounter::new("http://localhost:4000", None, "claude-sonnet-4.5")
    }

    /// Builds a minimal unified diff with one added line per file name.
    ///
    /// Generating the `diff --git a/<name> b/<name>` headers keeps the `a/` and
    /// `b/` paths consistent, rather than relying on hand-typed fixtures.
    fn diff_for(names: &[&str]) -> String {
        names
            .iter()
            .map(|name| format!("diff --git a/{name} b/{name}\n@@ -0,0 +1 @@\n+{name}"))
            .collect::<Vec<_>>()
            .join("\n")
    }

    #[test]
    fn test_should_use_map_reduce_disabled() {
        let config = CommitConfig { map_reduce_enabled: false, ..Default::default() };
        let counter = test_counter();
        let diff = diff_for(&["a.rs", "b.rs", "c.rs", "d.rs"]);
        assert!(!should_use_map_reduce(&diff, &config, &counter));
    }

    #[test]
    fn test_should_use_map_reduce_few_files() {
        let config = CommitConfig::default();
        let counter = test_counter();
        let diff = diff_for(&["a.rs", "b.rs"]);
        assert!(!should_use_map_reduce(&diff, &config, &counter));
    }

    #[test]
    fn test_should_use_map_reduce_many_files() {
        let config = CommitConfig::default();
        let counter = test_counter();
        let diff = diff_for(&["a.rs", "b.rs", "c.rs", "d.rs", "e.rs"]);
        assert!(should_use_map_reduce(&diff, &config, &counter));
    }

    #[test]
    fn test_generate_context_header_empty() {
        let files = vec![FileDiff {
            filename: "only.rs".to_string(),
            header: String::new(),
            content: String::new(),
            additions: 10,
            deletions: 5,
            is_binary: false,
        }];
        let header = generate_context_header(&files, "only.rs");
        assert!(header.is_empty());
    }

    #[test]
    fn test_generate_context_header_multiple() {
        let files = vec![
            FileDiff {
                filename: "src/main.rs".to_string(),
                header: String::new(),
                content: "fn main() {}".to_string(),
                additions: 10,
                deletions: 5,
                is_binary: false,
            },
            FileDiff {
                filename: "src/lib.rs".to_string(),
                header: String::new(),
                content: "mod test;".to_string(),
                additions: 3,
                deletions: 1,
                is_binary: false,
            },
            FileDiff {
                filename: "tests/test.rs".to_string(),
                header: String::new(),
                content: "#[test]".to_string(),
                additions: 20,
                deletions: 0,
                is_binary: false,
            },
        ];

        let header = generate_context_header(&files, "src/main.rs");
        assert!(header.contains("OTHER FILES IN THIS CHANGE:"));
        assert!(header.contains("src/lib.rs"));
        assert!(header.contains("tests/test.rs"));
        // The current file is never listed as one of the "other" files.
        assert!(!header.contains("src/main.rs"));
    }

    #[test]
    fn test_generate_context_header_truncates_to_largest() {
        // 26 files: the current one (file00) plus 25 others, where fileNN has NN
        // added lines. With MAX_CONTEXT_FILES = 20, file06..=file25 are listed
        // and file01..=file05 are summarized as "5 more".
        let files: Vec<FileDiff> = (0..26)
            .map(|i| FileDiff {
                filename: format!("src/file{i:02}.rs"),
                header: String::new(),
                content: String::new(),
                additions: i,
                deletions: 0,
                is_binary: false,
            })
            .collect();

        let header = generate_context_header(&files, "src/file00.rs");
        let listed = header.lines().filter(|l| l.starts_with("- ")).count();
        assert_eq!(listed, MAX_CONTEXT_FILES);
        assert!(header.contains("src/file25.rs"));
        assert!(header.contains("src/file06.rs"));
        assert!(!header.contains("src/file05.rs"));
        assert!(!header.contains("src/file00.rs"));
        assert!(header.ends_with("... and 5 more files"));
    }

    #[test]
    fn test_generate_context_header_large_commit() {
        let files: Vec<FileDiff> = (0..101)
            .map(|i| FileDiff {
                filename: format!("f{i}.rs"),
                header: String::new(),
                content: String::new(),
                additions: 1,
                deletions: 0,
                is_binary: false,
            })
            .collect();

        assert_eq!(
            generate_context_header(&files, "f0.rs"),
            "(Large commit with 101 total files)"
        );
    }

    #[test]
    fn test_infer_file_description() {
        assert_eq!(infer_file_description("src/test_utils.rs", ""), "test file");
        assert_eq!(infer_file_description("README.md", ""), "documentation");
        assert_eq!(infer_file_description("config.toml", ""), "configuration");
        assert_eq!(infer_file_description("src/error.rs", ""), "error definitions");
        assert_eq!(infer_file_description("src/types.rs", ""), "type definitions");
        assert_eq!(infer_file_description("src/mod.rs", ""), "module exports");
        assert_eq!(infer_file_description("src/main.rs", ""), "entry point");
        assert_eq!(infer_file_description("src/api.rs", "fn call()"), "implementation");
        assert_eq!(infer_file_description("src/models.rs", "struct Foo"), "type definitions");
        assert_eq!(infer_file_description("src/unknown.xyz", ""), "source code");
    }
}