llm_git/
map_reduce.rs

1//! Map-reduce pattern for large diff analysis
2//!
3//! When diffs exceed the token threshold, this module splits analysis across
4//! files, then synthesizes results for accurate classification.
5
6use std::path::Path;
7
8use rayon::prelude::*;
9use serde::{Deserialize, Serialize};
10
11use crate::{
12   api::retry_api_call,
13   config::CommitConfig,
14   diff::{FileDiff, parse_diff, reconstruct_diff},
15   error::{CommitGenError, Result},
16   templates,
17   tokens::TokenCounter,
18   types::ConventionalAnalysis,
19};
20
/// Observation from a single file during map phase
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileObservation {
   // Filename as it appears in the diff.
   pub file:         String,
   // Factual observations the model extracted for this file.
   pub observations: Vec<String>,
   // Added/deleted line counts — intended to come from the parsed `FileDiff`.
   // NOTE(review): map_single_file leaves these at 0; confirm a caller fills
   // them before the reduce phase relies on them.
   pub additions:    usize,
   pub deletions:    usize,
}
29
/// Minimum files to justify map-reduce overhead (below this, unified is fine)
const MIN_FILES_FOR_MAP_REDUCE: usize = 4;

/// Maximum tokens per file in map phase (leave headroom for prompt template +
/// context). Files above this are truncated before being sent to the model.
const MAX_FILE_TOKENS: usize = 50_000;
36
37/// Check if map-reduce should be used
38/// Always use map-reduce except for:
39/// 1. Explicitly disabled in config
40/// 2. Very small diffs (≤3 files) where overhead isn't worth it
41pub fn should_use_map_reduce(diff: &str, config: &CommitConfig, counter: &TokenCounter) -> bool {
42   if !config.map_reduce_enabled {
43      return false;
44   }
45
46   let files = parse_diff(diff);
47   let file_count = files
48      .iter()
49      .filter(|f| {
50         !config
51            .excluded_files
52            .iter()
53            .any(|ex| f.filename.ends_with(ex))
54      })
55      .count();
56
57   // Use map-reduce for 4+ files, or if any single file would need truncation
58   file_count >= MIN_FILES_FOR_MAP_REDUCE
59      || files
60         .iter()
61         .any(|f| f.token_estimate(counter) > MAX_FILE_TOKENS)
62}
63
/// Maximum files to include in context header (prevent token explosion).
/// When exceeded, only the largest changes (by line count) are listed.
const MAX_CONTEXT_FILES: usize = 20;
66
67/// Generate context header summarizing other files for cross-file awareness
68fn generate_context_header(files: &[FileDiff], current_file: &str) -> String {
69   // Skip context header for very large commits (diminishing returns)
70   if files.len() > 100 {
71      return format!("(Large commit with {} total files)", files.len());
72   }
73
74   let mut lines = vec!["OTHER FILES IN THIS CHANGE:".to_string()];
75
76   let other_files: Vec<_> = files
77      .iter()
78      .filter(|f| f.filename != current_file)
79      .collect();
80
81   let total_other = other_files.len();
82
83   // Only show top files by change size if too many
84   let to_show: Vec<&FileDiff> = if total_other > MAX_CONTEXT_FILES {
85      let mut sorted = other_files;
86      sorted.sort_by_key(|f| std::cmp::Reverse(f.additions + f.deletions));
87      sorted.truncate(MAX_CONTEXT_FILES);
88      sorted
89   } else {
90      other_files
91   };
92
93   for file in &to_show {
94      let line_count = file.additions + file.deletions;
95      let description = infer_file_description(&file.filename, &file.content);
96      lines.push(format!("- {} ({} lines): {}", file.filename, line_count, description));
97   }
98
99   if to_show.len() < total_other {
100      lines.push(format!("... and {} more files", total_other - to_show.len()));
101   }
102
103   if lines.len() == 1 {
104      return String::new(); // No other files
105   }
106
107   lines.join("\n")
108}
109
/// Infer a brief description of what a file likely contains based on
/// name/content
///
/// Filename heuristics are checked first (in priority order), then content
/// heuristics, finally falling back to a generic label.
fn infer_file_description(filename: &str, content: &str) -> &'static str {
   let lower = filename.to_lowercase();

   // True when the file's extension equals `wanted`, case-insensitively.
   let ext_is = |wanted: &str| {
      Path::new(filename)
         .extension()
         .is_some_and(|e| e.eq_ignore_ascii_case(wanted))
   };

   if lower.contains("test") {
      "test file"
   } else if ext_is("md") {
      "documentation"
   } else if lower.contains("config") || ext_is("toml") || ext_is("yaml") || ext_is("yml") {
      "configuration"
   } else if lower.contains("error") {
      "error definitions"
   } else if lower.contains("type") {
      "type definitions"
   } else if lower.ends_with("mod.rs") || lower.ends_with("lib.rs") {
      "module exports"
   } else if ["main.rs", "main.go", "main.py"].iter().any(|m| lower.ends_with(m)) {
      "entry point"
   } else if content.contains("impl ") || content.contains("fn ") {
      // Content-based fallbacks when the name is uninformative.
      "implementation"
   } else if content.contains("struct ") || content.contains("enum ") {
      "type definitions"
   } else if content.contains("async ") || content.contains("await") {
      "async code"
   } else {
      "source code"
   }
}
162
163/// Map phase: analyze each file individually and extract observations
164fn map_phase(
165   files: &[FileDiff],
166   model_name: &str,
167   config: &CommitConfig,
168   counter: &TokenCounter,
169) -> Result<Vec<FileObservation>> {
170   // Process files in parallel using rayon
171   let observations: Vec<Result<FileObservation>> = files
172      .par_iter()
173      .map(|file| {
174         if file.is_binary {
175            return Ok(FileObservation {
176               file:         file.filename.clone(),
177               observations: vec!["Binary file changed.".to_string()],
178               additions:    0,
179               deletions:    0,
180            });
181         }
182
183         let context_header = generate_context_header(files, &file.filename);
184
185         // Truncate large files to fit API limits
186         let mut file_clone = file.clone();
187         let file_tokens = file_clone.token_estimate(counter);
188         if file_tokens > MAX_FILE_TOKENS {
189            let target_size = MAX_FILE_TOKENS * 4; // Convert tokens to chars
190            file_clone.truncate(target_size);
191            eprintln!(
192               "  {} truncated {} ({} → {} tokens)",
193               crate::style::icons::WARNING,
194               file.filename,
195               file_tokens,
196               file_clone.token_estimate(counter)
197            );
198         }
199
200         let file_diff = reconstruct_diff(&[file_clone]);
201
202         map_single_file(&file.filename, &file_diff, &context_header, model_name, config)
203      })
204      .collect();
205
206   // Collect results, failing fast on first error
207   observations.into_iter().collect()
208}
209
/// Analyze a single file and extract observations
///
/// Makes one chat-completions request that forces the
/// `create_file_observation` tool. Runs inside `retry_api_call`: the closure
/// returns `Ok((true, None))` to request another attempt (server errors) and
/// `Ok((false, Some(_)))` on success.
///
/// NOTE(review): `additions`/`deletions` are left at 0 here — confirm the
/// caller fills them in from the corresponding `FileDiff`.
///
/// # Errors
/// Fails on transport errors, non-retryable HTTP statuses, empty choices,
/// unparseable tool arguments, or a response with neither tool call nor
/// content.
fn map_single_file(
   filename: &str,
   file_diff: &str,
   context_header: &str,
   model_name: &str,
   config: &CommitConfig,
) -> Result<FileObservation> {
   retry_api_call(config, || {
      // Fresh client per attempt (carries the configured timeouts).
      let client = build_client(config);

      let tool = build_observation_tool();

      let prompt = templates::render_map_prompt("default", filename, file_diff, context_header)?;

      let request = build_api_request(model_name, config.temperature, vec![tool], &prompt);

      let mut request_builder = client
         .post(format!("{}/chat/completions", config.api_base_url))
         .header("content-type", "application/json");

      // Bearer auth is optional (e.g. local proxies need no key).
      if let Some(api_key) = &config.api_key {
         request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
      }

      let response = request_builder
         .json(&request)
         .send()
         .map_err(CommitGenError::HttpError)?;

      let status = response.status();

      // 5xx: log and signal the retry wrapper; transient server trouble.
      if status.is_server_error() {
         let error_text = response
            .text()
            .unwrap_or_else(|_| "Unknown error".to_string());
         eprintln!("{}", crate::style::error(&format!("Server error {status}: {error_text}")));
         return Ok((true, None)); // Retry
      }

      // Any other non-2xx (4xx etc.) is not retryable — surface it.
      if !status.is_success() {
         let error_text = response
            .text()
            .unwrap_or_else(|_| "Unknown error".to_string());
         return Err(CommitGenError::ApiError { status: status.as_u16(), body: error_text });
      }

      let api_response: ApiResponse = response.json().map_err(CommitGenError::HttpError)?;

      if api_response.choices.is_empty() {
         return Err(CommitGenError::Other(
            "API returned empty response for file observation".to_string(),
         ));
      }

      let message = &api_response.choices[0].message;

      // Preferred path: the forced tool call with JSON arguments.
      if !message.tool_calls.is_empty() {
         let tool_call = &message.tool_calls[0];
         if tool_call.function.name == "create_file_observation" {
            let args = &tool_call.function.arguments;
            if args.is_empty() {
               return Err(CommitGenError::Other(
                  "Model returned empty function arguments for observation".to_string(),
               ));
            }

            let obs: FileObservationResponse = serde_json::from_str(args).map_err(|e| {
               CommitGenError::Other(format!("Failed to parse observation response: {e}"))
            })?;

            return Ok((
               false,
               Some(FileObservation {
                  file:         filename.to_string(),
                  observations: obs.observations,
                  additions:    0, // Will be filled from FileDiff
                  deletions:    0,
               }),
            ));
         }
      }

      // Fallback: try to parse content (some models answer in plain text
      // instead of honoring the forced tool call).
      if let Some(content) = &message.content {
         let obs: FileObservationResponse =
            serde_json::from_str(content.trim()).map_err(CommitGenError::JsonError)?;
         return Ok((
            false,
            Some(FileObservation {
               file:         filename.to_string(),
               observations: obs.observations,
               additions:    0,
               deletions:    0,
            }),
         ));
      }

      Err(CommitGenError::Other("No observation found in API response".to_string()))
   })
}
311
/// Reduce phase: synthesize all observations into final analysis
///
/// Sends the per-file observations (as pretty JSON), the diff stat, and scope
/// candidates to the model, forcing the `create_conventional_analysis` tool.
/// Runs inside `retry_api_call`: `Ok((true, None))` requests another attempt
/// (server errors), `Ok((false, Some(_)))` is success.
///
/// # Errors
/// Fails on transport errors, non-retryable HTTP statuses, empty choices,
/// unparseable tool arguments, or a response with neither tool call nor
/// content.
pub fn reduce_phase(
   observations: &[FileObservation],
   stat: &str,
   scope_candidates: &str,
   model_name: &str,
   config: &CommitConfig,
) -> Result<ConventionalAnalysis> {
   retry_api_call(config, || {
      let client = build_client(config);

      // Build type enum from config
      let type_enum: Vec<&str> = config.types.keys().map(|s| s.as_str()).collect();

      let tool = build_analysis_tool(&type_enum);

      // Serialization of plain strings should not fail; degrade to an empty
      // list rather than aborting the whole reduce.
      let observations_json =
         serde_json::to_string_pretty(observations).unwrap_or_else(|_| "[]".to_string());

      let types_description = crate::api::format_types_description(config);
      let prompt = templates::render_reduce_prompt(
         "default",
         &observations_json,
         stat,
         scope_candidates,
         Some(&types_description),
      )?;

      let request = build_api_request(model_name, config.temperature, vec![tool], &prompt);

      let mut request_builder = client
         .post(format!("{}/chat/completions", config.api_base_url))
         .header("content-type", "application/json");

      // Bearer auth is optional (e.g. local proxies need no key).
      if let Some(api_key) = &config.api_key {
         request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"));
      }

      let response = request_builder
         .json(&request)
         .send()
         .map_err(CommitGenError::HttpError)?;

      let status = response.status();

      // 5xx: log and signal the retry wrapper; transient server trouble.
      if status.is_server_error() {
         let error_text = response
            .text()
            .unwrap_or_else(|_| "Unknown error".to_string());
         eprintln!("{}", crate::style::error(&format!("Server error {status}: {error_text}")));
         return Ok((true, None)); // Retry
      }

      // Any other non-2xx (4xx etc.) is not retryable — surface it.
      if !status.is_success() {
         let error_text = response
            .text()
            .unwrap_or_else(|_| "Unknown error".to_string());
         return Err(CommitGenError::ApiError { status: status.as_u16(), body: error_text });
      }

      let api_response: ApiResponse = response.json().map_err(CommitGenError::HttpError)?;

      if api_response.choices.is_empty() {
         return Err(CommitGenError::Other(
            "API returned empty response for synthesis".to_string(),
         ));
      }

      let message = &api_response.choices[0].message;

      // Preferred path: the forced tool call with JSON arguments.
      if !message.tool_calls.is_empty() {
         let tool_call = &message.tool_calls[0];
         if tool_call.function.name == "create_conventional_analysis" {
            let args = &tool_call.function.arguments;
            if args.is_empty() {
               return Err(CommitGenError::Other(
                  "Model returned empty function arguments for synthesis".to_string(),
               ));
            }

            let analysis: ConventionalAnalysis = serde_json::from_str(args).map_err(|e| {
               CommitGenError::Other(format!("Failed to parse synthesis response: {e}"))
            })?;

            return Ok((false, Some(analysis)));
         }
      }

      // Fallback: some models answer in plain content instead of the tool.
      if let Some(content) = &message.content {
         let analysis: ConventionalAnalysis =
            serde_json::from_str(content.trim()).map_err(CommitGenError::JsonError)?;
         return Ok((false, Some(analysis)));
      }

      Err(CommitGenError::Other("No analysis found in synthesis response".to_string()))
   })
}
410
411/// Run full map-reduce pipeline for large diffs
412pub fn run_map_reduce(
413   diff: &str,
414   stat: &str,
415   scope_candidates: &str,
416   model_name: &str,
417   config: &CommitConfig,
418   counter: &TokenCounter,
419) -> Result<ConventionalAnalysis> {
420   let mut files = parse_diff(diff);
421
422   // Filter excluded files
423   files.retain(|f| {
424      !config
425         .excluded_files
426         .iter()
427         .any(|excluded| f.filename.ends_with(excluded))
428   });
429
430   if files.is_empty() {
431      return Err(CommitGenError::Other(
432         "No relevant files to analyze after filtering".to_string(),
433      ));
434   }
435
436   let file_count = files.len();
437   crate::style::print_info(&format!("Running map-reduce on {file_count} files..."));
438
439   // Map phase
440   let observations = map_phase(&files, model_name, config, counter)?;
441
442   // Reduce phase
443   reduce_phase(&observations, stat, scope_candidates, model_name, config)
444}
445
446// ============================================================================
447// API types (duplicated from api.rs to avoid circular deps)
448// ============================================================================
449
450use std::time::Duration;
451
452fn build_client(config: &CommitConfig) -> reqwest::blocking::Client {
453   reqwest::blocking::Client::builder()
454      .timeout(Duration::from_secs(config.request_timeout_secs))
455      .connect_timeout(Duration::from_secs(config.connect_timeout_secs))
456      .build()
457      .expect("Failed to build HTTP client")
458}
459
/// One chat message in the request payload.
#[derive(Debug, Serialize)]
struct Message {
   role:    String,
   content: String,
}

/// JSON-schema parameters object for a tool's function.
#[derive(Debug, Serialize, Deserialize)]
struct FunctionParameters {
   // Serialized as "type" (reserved word in Rust).
   #[serde(rename = "type")]
   param_type: String,
   properties: serde_json::Value,
   required:   Vec<String>,
}

/// Function definition inside a tool.
#[derive(Debug, Serialize, Deserialize)]
struct Function {
   name:        String,
   description: String,
   parameters:  FunctionParameters,
}

/// Tool wrapper in OpenAI-compatible "tools" format.
#[derive(Debug, Serialize, Deserialize)]
struct Tool {
   // Serialized as "type" (reserved word in Rust).
   #[serde(rename = "type")]
   tool_type: String,
   function:  Function,
}

/// Request body for the /chat/completions endpoint.
#[derive(Debug, Serialize)]
struct ApiRequest {
   model:       String,
   max_tokens:  u32,
   temperature: f32,
   tools:       Vec<Tool>,
   // Omitted entirely when no tool choice is forced.
   #[serde(skip_serializing_if = "Option::is_none")]
   tool_choice: Option<serde_json::Value>,
   messages:    Vec<Message>,
}
498
/// A tool invocation returned by the model.
#[derive(Debug, Deserialize)]
struct ToolCall {
   function: FunctionCall,
}

/// Function name plus raw JSON argument string from a tool call.
#[derive(Debug, Deserialize)]
struct FunctionCall {
   name:      String,
   arguments: String,
}

/// One completion choice in the API response.
#[derive(Debug, Deserialize)]
struct Choice {
   message: ResponseMessage,
}

/// Assistant message: tool calls and/or plain content, both optional.
#[derive(Debug, Deserialize)]
struct ResponseMessage {
   // Defaults to empty when the model made no tool calls.
   #[serde(default)]
   tool_calls: Vec<ToolCall>,
   // Defaults to None when the response carries only tool calls.
   #[serde(default)]
   content:    Option<String>,
}

/// Top-level /chat/completions response body.
#[derive(Debug, Deserialize)]
struct ApiResponse {
   choices: Vec<Choice>,
}

/// Payload shape of the `create_file_observation` tool arguments.
#[derive(Debug, Deserialize)]
struct FileObservationResponse {
   observations: Vec<String>,
}
532
533fn build_observation_tool() -> Tool {
534   Tool {
535      tool_type: "function".to_string(),
536      function:  Function {
537         name:        "create_file_observation".to_string(),
538         description: "Extract observations from a single file's changes".to_string(),
539         parameters:  FunctionParameters {
540            param_type: "object".to_string(),
541            properties: serde_json::json!({
542               "observations": {
543                  "type": "array",
544                  "description": "List of factual observations about what changed in this file",
545                  "items": {
546                     "type": "string"
547                  }
548               }
549            }),
550            required:   vec!["observations".to_string()],
551         },
552      },
553   }
554}
555
/// Build the `create_conventional_analysis` tool used by the reduce phase.
///
/// `type_enum` lists the commit types allowed by the user's config; it is
/// embedded into the schema so the model can only pick a configured type.
fn build_analysis_tool(type_enum: &[&str]) -> Tool {
   Tool {
      tool_type: "function".to_string(),
      function:  Function {
         name:        "create_conventional_analysis".to_string(),
         description: "Synthesize observations into conventional commit analysis".to_string(),
         parameters:  FunctionParameters {
            param_type: "object".to_string(),
            // JSON schema for the synthesized analysis: commit type, optional
            // scope, detail items with changelog metadata, and issue refs.
            properties: serde_json::json!({
               "type": {
                  "type": "string",
                  "enum": type_enum,
                  "description": "Commit type based on combined changes"
               },
               "scope": {
                  "type": "string",
                  "description": "Optional scope (module/component). Omit if unclear or multi-component."
               },
               "details": {
                  "type": "array",
                  "description": "Array of 0-6 detail items with changelog metadata.",
                  "items": {
                     "type": "object",
                     "properties": {
                        "text": {
                           "type": "string",
                           "description": "Detail about change, starting with past-tense verb, ending with period"
                        },
                        "changelog_category": {
                           "type": "string",
                           "enum": ["Added", "Changed", "Fixed", "Deprecated", "Removed", "Security"],
                           "description": "Changelog category if user-visible. Omit for internal changes."
                        },
                        "user_visible": {
                           "type": "boolean",
                           "description": "True if this change affects users/API and should appear in changelog"
                        }
                     },
                     "required": ["text", "user_visible"]
                  }
               },
               "issue_refs": {
                  "type": "array",
                  "description": "Issue numbers from context (e.g., ['#123', '#456']). Empty if none.",
                  "items": {
                     "type": "string"
                  }
               }
            }),
            // "scope" is intentionally optional; the rest must be present.
            required:   vec!["type".to_string(), "details".to_string(), "issue_refs".to_string()],
         },
      },
   }
}
610
611fn build_api_request(model: &str, temperature: f32, tools: Vec<Tool>, prompt: &str) -> ApiRequest {
612   let tool_name = tools.first().map(|t| t.function.name.clone());
613
614   ApiRequest {
615      model: model.to_string(),
616      max_tokens: 1000,
617      temperature,
618      tool_choice: tool_name
619         .map(|name| serde_json::json!({ "type": "function", "function": { "name": name } })),
620      tools,
621      messages: vec![Message { role: "user".to_string(), content: prompt.to_string() }],
622   }
623}
624
#[cfg(test)]
mod tests {
   use super::*;
   use crate::tokens::TokenCounter;

   // Counter used by tests; points at a local endpoint.
   fn test_counter() -> TokenCounter {
      TokenCounter::new("http://localhost:4000", None, "claude-sonnet-4.5")
   }

   #[test]
   fn test_should_use_map_reduce_disabled() {
      let config = CommitConfig { map_reduce_enabled: false, ..Default::default() };
      let counter = test_counter();
      // Even with many files, disabled means no map-reduce
      let diff = r"diff --git a/a.rs b/a.rs
@@ -0,0 +1 @@
+a
diff --git a/b.rs b/b.rs
@@ -0,0 +1 @@
+b
diff --git a/c.rs b/c.rs
@@ -0,0 +1 @@
+c
diff --git a/d.rs b/d.rs
@@ -0,0 +1 @@
+d";
      assert!(!should_use_map_reduce(diff, &config, &counter));
   }

   #[test]
   fn test_should_use_map_reduce_few_files() {
      let config = CommitConfig::default();
      let counter = test_counter();
      // Only 2 files - below threshold
      let diff = r"diff --git a/a.rs b/a.rs
@@ -0,0 +1 @@
+a
diff --git a/b.rs b/b.rs
@@ -0,0 +1 @@
+b";
      assert!(!should_use_map_reduce(diff, &config, &counter));
   }

   #[test]
   fn test_should_use_map_reduce_many_files() {
      let config = CommitConfig::default();
      let counter = test_counter();
      // 5 files - above threshold
      // BUG FIX: the fourth header previously read `a/d.rs d/d.rs`, which is
      // not a valid `diff --git a/... b/...` header and could make the parser
      // count only 4 files, weakening this test's "5 files" premise.
      let diff = r"diff --git a/a.rs b/a.rs
@@ -0,0 +1 @@
+a
diff --git a/b.rs b/b.rs
@@ -0,0 +1 @@
+b
diff --git a/c.rs b/c.rs
@@ -0,0 +1 @@
+c
diff --git a/d.rs b/d.rs
@@ -0,0 +1 @@
+d
diff --git a/e.rs b/e.rs
@@ -0,0 +1 @@
+e";
      assert!(should_use_map_reduce(diff, &config, &counter));
   }

   #[test]
   fn test_generate_context_header_empty() {
      // A lone file has no "other files", so the header must be empty.
      let files = vec![FileDiff {
         filename:  "only.rs".to_string(),
         header:    String::new(),
         content:   String::new(),
         additions: 10,
         deletions: 5,
         is_binary: false,
      }];
      let header = generate_context_header(&files, "only.rs");
      assert!(header.is_empty());
   }

   #[test]
   fn test_generate_context_header_multiple() {
      let files = vec![
         FileDiff {
            filename:  "src/main.rs".to_string(),
            header:    String::new(),
            content:   "fn main() {}".to_string(),
            additions: 10,
            deletions: 5,
            is_binary: false,
         },
         FileDiff {
            filename:  "src/lib.rs".to_string(),
            header:    String::new(),
            content:   "mod test;".to_string(),
            additions: 3,
            deletions: 1,
            is_binary: false,
         },
         FileDiff {
            filename:  "tests/test.rs".to_string(),
            header:    String::new(),
            content:   "#[test]".to_string(),
            additions: 20,
            deletions: 0,
            is_binary: false,
         },
      ];

      let header = generate_context_header(&files, "src/main.rs");
      assert!(header.contains("OTHER FILES IN THIS CHANGE:"));
      assert!(header.contains("src/lib.rs"));
      assert!(header.contains("tests/test.rs"));
      assert!(!header.contains("src/main.rs")); // Current file excluded
   }

   #[test]
   fn test_infer_file_description() {
      assert_eq!(infer_file_description("src/test_utils.rs", ""), "test file");
      assert_eq!(infer_file_description("README.md", ""), "documentation");
      assert_eq!(infer_file_description("config.toml", ""), "configuration");
      assert_eq!(infer_file_description("src/error.rs", ""), "error definitions");
      assert_eq!(infer_file_description("src/types.rs", ""), "type definitions");
      assert_eq!(infer_file_description("src/mod.rs", ""), "module exports");
      assert_eq!(infer_file_description("src/main.rs", ""), "entry point");
      assert_eq!(infer_file_description("src/api.rs", "fn call()"), "implementation");
      assert_eq!(infer_file_description("src/models.rs", "struct Foo"), "type definitions");
      assert_eq!(infer_file_description("src/unknown.xyz", ""), "source code");
   }
}