Skip to main content

llm_git/
map_reduce.rs

1//! Map-reduce pattern for large diff analysis
2//!
3//! When diffs exceed the token threshold, this module splits analysis across
4//! files, then synthesizes results for accurate classification.
5
6use std::path::Path;
7
8use futures::stream::{self, StreamExt};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12   api::{OneShotSpec, run_oneshot, strict_json_schema},
13   config::CommitConfig,
14   diff::{FileDiff, parse_diff, reconstruct_diff},
15   error::{CommitGenError, Result},
16   templates,
17   tokens::TokenCounter,
18   types::ConventionalAnalysis,
19};
20
/// Observation from a single file during map phase
///
/// One value is produced per analysed file; the collected list is serialized
/// to JSON and handed to the reduce phase.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileObservation {
   /// Filename of the file the observations belong to.
   pub file:         String,
   /// Free-form statements describing what changed in the file.
   pub observations: Vec<String>,
   /// Added-line count for the file (presumably mirrors `FileDiff::additions`
   /// — confirm against the code that fills it in).
   pub additions:    usize,
   /// Deleted-line count for the file (presumably mirrors
   /// `FileDiff::deletions` — confirm against the code that fills it in).
   pub deletions:    usize,
}
29
/// Minimum files to justify map-reduce overhead (below this, unified is fine).
///
/// Compared with `>=` against the number of non-excluded files in the diff.
const MIN_FILES_FOR_MAP_REDUCE: usize = 4;
32
/// Maximum tokens per file in map phase (leave headroom for prompt template +
/// context).
///
/// A file whose token estimate exceeds this value is truncated before it is
/// sent to the model, and its presence alone is enough to select map-reduce.
const MAX_FILE_TOKENS: usize = 50_000;
36
37/// Check if map-reduce should be used
38/// Always use map-reduce except for:
39/// 1. Explicitly disabled in config
40/// 2. Very small diffs (≤3 files) where overhead isn't worth it
41pub fn should_use_map_reduce(diff: &str, config: &CommitConfig, counter: &TokenCounter) -> bool {
42   if !config.map_reduce_enabled {
43      return false;
44   }
45
46   let files = parse_diff(diff);
47   let file_count = files
48      .iter()
49      .filter(|f| {
50         !config
51            .excluded_files
52            .iter()
53            .any(|ex| f.filename.ends_with(ex))
54      })
55      .count();
56
57   // Use map-reduce for 4+ files, or if any single file would need truncation
58   file_count >= MIN_FILES_FOR_MAP_REDUCE
59      || files
60         .iter()
61         .any(|f| f.token_estimate(counter) > MAX_FILE_TOKENS)
62}
63
/// Maximum files to include in context header (prevent token explosion).
///
/// When more sibling files exist, only the ones with the largest
/// additions + deletions are listed and the rest are summarized as a count.
const MAX_CONTEXT_FILES: usize = 20;
66
67/// Generate context header summarizing other files for cross-file awareness
68fn generate_context_header(files: &[FileDiff], current_file: &str) -> String {
69   // Skip context header for very large commits (diminishing returns)
70   if files.len() > 100 {
71      return format!("(Large commit with {} total files)", files.len());
72   }
73
74   let mut lines = vec!["OTHER FILES IN THIS CHANGE:".to_string()];
75
76   let other_files: Vec<_> = files
77      .iter()
78      .filter(|f| f.filename != current_file)
79      .collect();
80
81   let total_other = other_files.len();
82
83   // Only show top files by change size if too many
84   let to_show: Vec<&FileDiff> = if total_other > MAX_CONTEXT_FILES {
85      let mut sorted = other_files;
86      sorted.sort_by_key(|f| std::cmp::Reverse(f.additions + f.deletions));
87      sorted.truncate(MAX_CONTEXT_FILES);
88      sorted
89   } else {
90      other_files
91   };
92
93   for file in &to_show {
94      let line_count = file.additions + file.deletions;
95      let description = infer_file_description(&file.filename, &file.content);
96      lines.push(format!("- {} ({} lines): {}", file.filename, line_count, description));
97   }
98
99   if to_show.len() < total_other {
100      lines.push(format!("... and {} more files", total_other - to_show.len()));
101   }
102
103   if lines.len() == 1 {
104      return String::new(); // No other files
105   }
106
107   lines.join("\n")
108}
109
/// Infer a brief description of what a file likely contains based on
/// name/content.
///
/// Filename heuristics are tried first, in this order: "test" anywhere in the
/// path, a `.md` extension, "config" in the path or a `.toml`/`.yaml`/`.yml`
/// extension, "error" in the path, "type" in the path, then well-known
/// basenames (`mod.rs`/`lib.rs`, `main.rs`/`main.go`/`main.py`). If none
/// apply, the diff content is scanned for a few keywords. All filename checks
/// are case-insensitive.
///
/// Well-known basenames are compared against the final path component
/// exactly, so `src/domain.rs` is not mistaken for `main.rs` and `zlib.rs` is
/// not mistaken for `lib.rs`.
///
/// Returns a static label; `"source code"` when nothing matches.
fn infer_file_description(filename: &str, content: &str) -> &'static str {
   let filename_lower = filename.to_lowercase();

   // Check filename patterns
   if filename_lower.contains("test") {
      return "test file";
   }

   let ext = Path::new(filename).extension();
   let has_ext = |wanted: &str| ext.is_some_and(|e| e.eq_ignore_ascii_case(wanted));

   if has_ext("md") {
      return "documentation";
   }
   if filename_lower.contains("config") || has_ext("toml") || has_ext("yaml") || has_ext("yml") {
      return "configuration";
   }
   if filename_lower.contains("error") {
      return "error definitions";
   }
   if filename_lower.contains("type") {
      return "type definitions";
   }

   // Compare the last path component as a whole; a suffix match on the full
   // path would also accept names such as `domain.rs` or `glib.rs`.
   let basename = Path::new(&filename_lower)
      .file_name()
      .and_then(|name| name.to_str())
      .unwrap_or_default();

   if matches!(basename, "mod.rs" | "lib.rs") {
      return "module exports";
   }
   if matches!(basename, "main.rs" | "main.go" | "main.py") {
      return "entry point";
   }

   // Check content patterns
   if content.contains("impl ") || content.contains("fn ") {
      return "implementation";
   }
   if content.contains("struct ") || content.contains("enum ") {
      return "type definitions";
   }
   if content.contains("async ") || content.contains("await") {
      return "async code";
   }

   "source code"
}
162
163/// Map phase: analyze each file individually and extract observations
164async fn map_phase(
165   files: &[FileDiff],
166   model_name: &str,
167   config: &CommitConfig,
168   counter: &TokenCounter,
169) -> Result<Vec<FileObservation>> {
170   // Process files concurrently using futures stream
171   let observations: Vec<Result<FileObservation>> = stream::iter(files.iter())
172      .map(|file| async {
173         if file.is_binary {
174            return Ok(FileObservation {
175               file:         file.filename.clone(),
176               observations: vec!["Binary file changed.".to_string()],
177               additions:    0,
178               deletions:    0,
179            });
180         }
181
182         let context_header = generate_context_header(files, &file.filename);
183         // Truncate large files to fit API limits
184         let mut file_clone = file.clone();
185         let file_tokens = file_clone.token_estimate(counter);
186         if file_tokens > MAX_FILE_TOKENS {
187            let target_size = MAX_FILE_TOKENS * 4; // Convert tokens to chars
188            file_clone.truncate(target_size);
189            eprintln!(
190               "  {} truncated {} ({} \u{2192} {} tokens)",
191               crate::style::icons::WARNING,
192               file.filename,
193               file_tokens,
194               file_clone.token_estimate(counter)
195            );
196         }
197
198         let file_diff = reconstruct_diff(&[file_clone]);
199
200         map_single_file(&file.filename, &file_diff, &context_header, model_name, config).await
201      })
202      .buffer_unordered(8)
203      .collect()
204      .await;
205
206   // Collect results, failing fast on first error
207   observations.into_iter().collect()
208}
209
210pub async fn observe_diff_files(
211   diff: &str,
212   model_name: &str,
213   config: &CommitConfig,
214   counter: &TokenCounter,
215) -> Result<Vec<FileObservation>> {
216   let mut files = parse_diff(diff);
217
218   files.retain(|file| {
219      !config
220         .excluded_files
221         .iter()
222         .any(|excluded| file.filename.ends_with(excluded))
223   });
224
225   if files.is_empty() {
226      return Err(CommitGenError::Other(
227         "No relevant files to summarize after filtering".to_string(),
228      ));
229   }
230
231   map_phase(&files, model_name, config, counter).await
232}
233
234/// Analyze a single file and extract observations
235async fn map_single_file(
236   filename: &str,
237   file_diff: &str,
238   context_header: &str,
239   model_name: &str,
240   config: &CommitConfig,
241) -> Result<FileObservation> {
242   let parts = templates::render_map_prompt("default", filename, file_diff, context_header)?;
243   let observation_schema = build_observation_schema();
244
245   let response = run_oneshot::<FileObservationResponse>(config, &OneShotSpec {
246      operation:        "map-reduce/map",
247      model:            model_name,
248      max_tokens:       1500,
249      temperature:      config.temperature,
250      prompt_family:    "map",
251      prompt_variant:   "default",
252      system_prompt:    &parts.system,
253      user_prompt:      &parts.user,
254      tool_name:        "create_file_observation",
255      tool_description: "Extract observations from a single file's changes",
256      schema:           &observation_schema,
257      debug:            None,
258      cacheable:        true,
259   })
260   .await?;
261
262   let mut observations = response.output.observations;
263   if observations.is_empty() {
264      let text_observations = response
265         .text_content
266         .as_deref()
267         .map(parse_observations_from_text)
268         .unwrap_or_default();
269
270      if !text_observations.is_empty() {
271         observations = text_observations;
272      } else if response.stop_reason.as_deref() == Some("max_tokens") {
273         crate::style::warn(
274            "Anthropic stopped at max_tokens with empty observations; using fallback observation.",
275         );
276         let fallback_target = Path::new(filename)
277            .file_name()
278            .and_then(|name| name.to_str())
279            .unwrap_or(filename);
280         observations = vec![format!("Updated {fallback_target}.")];
281      } else {
282         crate::style::warn("Model returned empty observations; continuing with no observations.");
283      }
284   }
285
286   Ok(FileObservation { file: filename.to_string(), observations, additions: 0, deletions: 0 })
287}
288
289/// Reduce phase: synthesize all observations into final analysis
290pub async fn reduce_phase(
291   observations: &[FileObservation],
292   stat: &str,
293   scope_candidates: &str,
294   model_name: &str,
295   config: &CommitConfig,
296) -> Result<ConventionalAnalysis> {
297   let type_enum: Vec<&str> = config.types.keys().map(|s| s.as_str()).collect();
298   let observations_json =
299      serde_json::to_string_pretty(observations).unwrap_or_else(|_| "[]".to_string());
300
301   let types_description = crate::api::format_types_description(config);
302   let parts = templates::render_reduce_prompt(
303      "default",
304      &observations_json,
305      stat,
306      scope_candidates,
307      Some(&types_description),
308   )?;
309
310   let analysis_schema = build_analysis_schema(&type_enum);
311   let response = run_oneshot::<ConventionalAnalysis>(config, &OneShotSpec {
312      operation:        "map-reduce/reduce",
313      model:            model_name,
314      max_tokens:       1500,
315      temperature:      config.temperature,
316      prompt_family:    "reduce",
317      prompt_variant:   "default",
318      system_prompt:    &parts.system,
319      user_prompt:      &parts.user,
320      tool_name:        "create_conventional_analysis",
321      tool_description: "Analyze changes and classify as conventional commit with type, scope, \
322                         details, and metadata",
323      schema:           &analysis_schema,
324      debug:            None,
325      cacheable:        true,
326   })
327   .await?;
328
329   Ok(response.output)
330}
331
332/// Run full map-reduce pipeline for large diffs
333pub async fn run_map_reduce(
334   diff: &str,
335   stat: &str,
336   scope_candidates: &str,
337   model_name: &str,
338   config: &CommitConfig,
339   counter: &TokenCounter,
340) -> Result<ConventionalAnalysis> {
341   let observations = observe_diff_files(diff, model_name, config, counter).await?;
342
343   let file_count = observations.len();
344   crate::style::print_info(&format!("Running map-reduce on {file_count} files..."));
345
346   // Reduce phase
347   reduce_phase(&observations, stat, scope_candidates, model_name, config).await
348}
349
350fn parse_observations_from_text(text: &str) -> Vec<String> {
351   let trimmed = text.trim();
352   if trimmed.is_empty() {
353      return Vec::new();
354   }
355
356   if let Ok(obs) = serde_json::from_str::<FileObservationResponse>(trimmed) {
357      return obs.observations;
358   }
359
360   trimmed
361      .lines()
362      .map(str::trim)
363      .filter(|line| !line.is_empty())
364      .map(|line| {
365         line
366            .strip_prefix("- ")
367            .or_else(|| line.strip_prefix("* "))
368            .unwrap_or(line)
369            .trim()
370      })
371      .filter(|line| !line.is_empty())
372      .map(str::to_string)
373      .collect()
374}
375
/// Payload returned by the model for a single file in the map phase.
#[derive(Debug, Deserialize, Serialize)]
struct FileObservationResponse {
   /// Observations for the file. Deserialized leniently via
   /// `deserialize_observations`, which accepts either a sequence of strings
   /// or a single string that is then split into observations.
   #[serde(deserialize_with = "deserialize_observations")]
   observations: Vec<String>,
}
381
382/// Deserialize observations flexibly: handles array, stringified array, or
383/// bullet string
384fn deserialize_observations<'de, D>(deserializer: D) -> std::result::Result<Vec<String>, D::Error>
385where
386   D: serde::Deserializer<'de>,
387{
388   use std::fmt;
389
390   use serde::de::{self, Visitor};
391
392   struct ObservationsVisitor;
393
394   impl<'de> Visitor<'de> for ObservationsVisitor {
395      type Value = Vec<String>;
396
397      fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
398         formatter.write_str("an array of strings, a JSON array string, or a bullet-point string")
399      }
400
401      fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
402      where
403         A: de::SeqAccess<'de>,
404      {
405         let mut vec = Vec::new();
406         while let Some(item) = seq.next_element::<String>()? {
407            vec.push(item);
408         }
409         Ok(vec)
410      }
411
412      fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
413      where
414         E: de::Error,
415      {
416         Ok(parse_string_to_observations(s))
417      }
418   }
419
420   deserializer.deserialize_any(ObservationsVisitor)
421}
422
423/// Parse a string into observations: handles JSON array string or bullet-point
424/// string
425fn parse_string_to_observations(s: &str) -> Vec<String> {
426   let trimmed = s.trim();
427   if trimmed.is_empty() {
428      return Vec::new();
429   }
430
431   // Try parsing as JSON array first
432   if trimmed.starts_with('[')
433      && let Ok(arr) = serde_json::from_str::<Vec<String>>(trimmed)
434   {
435      return arr;
436   }
437
438   // Fall back to bullet-point parsing
439   trimmed
440      .lines()
441      .map(str::trim)
442      .filter(|line| !line.is_empty())
443      .map(|line| {
444         line
445            .strip_prefix("- ")
446            .or_else(|| line.strip_prefix("* "))
447            .or_else(|| line.strip_prefix("• "))
448            .unwrap_or(line)
449            .trim()
450            .to_string()
451      })
452      .filter(|line| !line.is_empty())
453      .collect()
454}
455
456fn build_observation_schema() -> serde_json::Value {
457   strict_json_schema(
458      serde_json::json!({
459         "observations": {
460            "type": "array",
461            "description": "List of factual observations about what changed in this file",
462            "items": {
463               "type": "string"
464            }
465         }
466      }),
467      &["observations"],
468   )
469}
470
471fn build_analysis_schema(type_enum: &[&str]) -> serde_json::Value {
472   strict_json_schema(
473      serde_json::json!({
474         "type": {
475            "type": "string",
476            "enum": type_enum,
477            "description": "Commit type based on combined changes"
478         },
479         "scope": {
480            "type": "string",
481            "description": "Optional scope (module/component). Omit if unclear or multi-component."
482         },
483         "details": {
484            "type": "array",
485            "description": "Array of 0-6 detail items with changelog metadata.",
486            "items": {
487               "type": "object",
488               "properties": {
489                  "text": {
490                     "type": "string",
491                     "description": "Detail about change, starting with past-tense verb, ending with period"
492                  },
493                  "changelog_category": {
494                     "type": "string",
495                     "enum": ["Added", "Changed", "Fixed", "Deprecated", "Removed", "Security"],
496                     "description": "Changelog category if user-visible. Omit for internal changes."
497                  },
498                  "user_visible": {
499                     "type": "boolean",
500                     "description": "True if this change affects users/API and should appear in changelog"
501                  }
502               },
503               "required": ["text", "user_visible"]
504            }
505         },
506         "issue_refs": {
507            "type": "array",
508            "description": "Issue numbers from context (e.g., ['#123', '#456']). Empty if none.",
509            "items": {
510               "type": "string"
511            }
512         }
513      }),
514      &["type", "details", "issue_refs"],
515   )
516}
517
/// Unit tests for the strategy decision, the context header, the filename
/// heuristics and the lenient observation parsing.
#[cfg(test)]
mod tests {
   use super::*;
   use crate::tokens::TokenCounter;

   /// Token counter used by the `should_use_map_reduce` tests.
   fn test_counter() -> TokenCounter {
      TokenCounter::new("http://localhost:4000", None, "claude-sonnet-4.5")
   }

   #[test]
   fn test_should_use_map_reduce_disabled() {
      let config = CommitConfig { map_reduce_enabled: false, ..Default::default() };
      let counter = test_counter();
      // Even with many files, disabled means no map-reduce
      let diff = r"diff --git a/a.rs b/a.rs
@@ -0,0 +1 @@
+a
diff --git a/b.rs b/b.rs
@@ -0,0 +1 @@
+b
diff --git a/c.rs b/c.rs
@@ -0,0 +1 @@
+c
diff --git a/d.rs b/d.rs
@@ -0,0 +1 @@
+d";
      assert!(!should_use_map_reduce(diff, &config, &counter));
   }

   #[test]
   fn test_should_use_map_reduce_few_files() {
      let config = CommitConfig::default();
      let counter = test_counter();
      // Only 2 files - below threshold
      let diff = r"diff --git a/a.rs b/a.rs
@@ -0,0 +1 @@
+a
diff --git a/b.rs b/b.rs
@@ -0,0 +1 @@
+b";
      assert!(!should_use_map_reduce(diff, &config, &counter));
   }

   #[test]
   fn test_should_use_map_reduce_many_files() {
      let config = CommitConfig::default();
      let counter = test_counter();
      // 5 files - above threshold. Every header must be a well-formed
      // `diff --git a/<path> b/<path>` line so that all five files count.
      let diff = r"diff --git a/a.rs b/a.rs
@@ -0,0 +1 @@
+a
diff --git a/b.rs b/b.rs
@@ -0,0 +1 @@
+b
diff --git a/c.rs b/c.rs
@@ -0,0 +1 @@
+c
diff --git a/d.rs b/d.rs
@@ -0,0 +1 @@
+d
diff --git a/e.rs b/e.rs
@@ -0,0 +1 @@
+e";
      assert!(should_use_map_reduce(diff, &config, &counter));
   }

   #[test]
   fn test_generate_context_header_empty() {
      let files = vec![FileDiff {
         filename:  "only.rs".to_string(),
         header:    String::new(),
         content:   String::new(),
         additions: 10,
         deletions: 5,
         is_binary: false,
      }];
      let header = generate_context_header(&files, "only.rs");
      assert!(header.is_empty());
   }

   #[test]
   fn test_generate_context_header_multiple() {
      let files = vec![
         FileDiff {
            filename:  "src/main.rs".to_string(),
            header:    String::new(),
            content:   "fn main() {}".to_string(),
            additions: 10,
            deletions: 5,
            is_binary: false,
         },
         FileDiff {
            filename:  "src/lib.rs".to_string(),
            header:    String::new(),
            content:   "mod test;".to_string(),
            additions: 3,
            deletions: 1,
            is_binary: false,
         },
         FileDiff {
            filename:  "tests/test.rs".to_string(),
            header:    String::new(),
            content:   "#[test]".to_string(),
            additions: 20,
            deletions: 0,
            is_binary: false,
         },
      ];

      let header = generate_context_header(&files, "src/main.rs");
      assert!(header.contains("OTHER FILES IN THIS CHANGE:"));
      assert!(header.contains("src/lib.rs"));
      assert!(header.contains("tests/test.rs"));
      assert!(!header.contains("src/main.rs")); // Current file excluded
   }

   #[test]
   fn test_infer_file_description() {
      assert_eq!(infer_file_description("src/test_utils.rs", ""), "test file");
      assert_eq!(infer_file_description("README.md", ""), "documentation");
      assert_eq!(infer_file_description("config.toml", ""), "configuration");
      assert_eq!(infer_file_description("src/error.rs", ""), "error definitions");
      assert_eq!(infer_file_description("src/types.rs", ""), "type definitions");
      assert_eq!(infer_file_description("src/mod.rs", ""), "module exports");
      assert_eq!(infer_file_description("src/main.rs", ""), "entry point");
      assert_eq!(infer_file_description("src/api.rs", "fn call()"), "implementation");
      assert_eq!(infer_file_description("src/models.rs", "struct Foo"), "type definitions");
      assert_eq!(infer_file_description("src/unknown.xyz", ""), "source code");
   }

   #[test]
   fn test_parse_string_to_observations_json_array() {
      let input = r#"["item one", "item two", "item three"]"#;
      let result = parse_string_to_observations(input);
      assert_eq!(result, vec!["item one", "item two", "item three"]);
   }

   #[test]
   fn test_parse_string_to_observations_bullet_points() {
      let input = "- added new function\n- fixed bug in parser\n- updated tests";
      let result = parse_string_to_observations(input);
      assert_eq!(result, vec!["added new function", "fixed bug in parser", "updated tests"]);
   }

   #[test]
   fn test_parse_string_to_observations_asterisk_bullets() {
      let input = "* first change\n* second change";
      let result = parse_string_to_observations(input);
      assert_eq!(result, vec!["first change", "second change"]);
   }

   #[test]
   fn test_parse_string_to_observations_empty() {
      assert!(parse_string_to_observations("").is_empty());
      assert!(parse_string_to_observations("   ").is_empty());
   }

   #[test]
   fn test_deserialize_observations_array() {
      let json = r#"{"observations": ["a", "b", "c"]}"#;
      let result: FileObservationResponse =
         serde_json::from_str(json).expect("valid observation array JSON should deserialize");
      assert_eq!(result.observations, vec!["a", "b", "c"]);
   }

   #[test]
   fn test_deserialize_observations_stringified_array() {
      let json = r#"{"observations": "[\"a\", \"b\", \"c\"]"}"#;
      let result: FileObservationResponse = serde_json::from_str(json)
         .expect("valid stringified observation array JSON should deserialize");
      assert_eq!(result.observations, vec!["a", "b", "c"]);
   }

   #[test]
   fn test_deserialize_observations_bullet_string() {
      let json = r#"{"observations": "- updated function\n- fixed bug"}"#;
      let result: FileObservationResponse =
         serde_json::from_str(json).expect("valid bullet observation JSON should deserialize");
      assert_eq!(result.observations, vec!["updated function", "fixed bug"]);
   }
}