1use std::path::Path;
7
8use futures::stream::{self, StreamExt};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 api::{OneShotSpec, run_oneshot, strict_json_schema},
13 config::CommitConfig,
14 diff::{FileDiff, parse_diff, reconstruct_diff},
15 error::{CommitGenError, Result},
16 templates,
17 tokens::TokenCounter,
18 types::ConventionalAnalysis,
19};
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct FileObservation {
24 pub file: String,
25 pub observations: Vec<String>,
26 pub additions: usize,
27 pub deletions: usize,
28}
29
/// Number of non-excluded files at which map-reduce is selected on file count alone.
const MIN_FILES_FOR_MAP_REDUCE: usize = 4;

/// Per-file token budget: a file estimated above this both triggers map-reduce and is
/// truncated before it is sent to the model.
const MAX_FILE_TOKENS: usize = 50_000;
36
37pub fn should_use_map_reduce(diff: &str, config: &CommitConfig, counter: &TokenCounter) -> bool {
42 if !config.map_reduce_enabled {
43 return false;
44 }
45
46 let files = parse_diff(diff);
47 let file_count = files
48 .iter()
49 .filter(|f| {
50 !config
51 .excluded_files
52 .iter()
53 .any(|ex| f.filename.ends_with(ex))
54 })
55 .count();
56
57 file_count >= MIN_FILES_FOR_MAP_REDUCE
59 || files
60 .iter()
61 .any(|f| f.token_estimate(counter) > MAX_FILE_TOKENS)
62}
63
/// Maximum number of sibling files listed individually in a context header.
const MAX_CONTEXT_FILES: usize = 20;
66
67fn generate_context_header(files: &[FileDiff], current_file: &str) -> String {
69 if files.len() > 100 {
71 return format!("(Large commit with {} total files)", files.len());
72 }
73
74 let mut lines = vec!["OTHER FILES IN THIS CHANGE:".to_string()];
75
76 let other_files: Vec<_> = files
77 .iter()
78 .filter(|f| f.filename != current_file)
79 .collect();
80
81 let total_other = other_files.len();
82
83 let to_show: Vec<&FileDiff> = if total_other > MAX_CONTEXT_FILES {
85 let mut sorted = other_files;
86 sorted.sort_by_key(|f| std::cmp::Reverse(f.additions + f.deletions));
87 sorted.truncate(MAX_CONTEXT_FILES);
88 sorted
89 } else {
90 other_files
91 };
92
93 for file in &to_show {
94 let line_count = file.additions + file.deletions;
95 let description = infer_file_description(&file.filename, &file.content);
96 lines.push(format!("- {} ({} lines): {}", file.filename, line_count, description));
97 }
98
99 if to_show.len() < total_other {
100 lines.push(format!("... and {} more files", total_other - to_show.len()));
101 }
102
103 if lines.len() == 1 {
104 return String::new(); }
106
107 lines.join("\n")
108}
109
/// Guesses a short label describing a file's role from its path and diff content.
///
/// Rules are evaluated in order and the first match wins. Path rules come first (matched
/// against the lower-cased path, extensions compared ASCII case-insensitively); the content
/// rules are substring checks that only apply when no path rule matched. Falls back to
/// `"source code"`.
fn infer_file_description(filename: &str, content: &str) -> &'static str {
    let lowered = filename.to_lowercase();
    let extension = Path::new(filename).extension();
    let has_extension =
        |wanted: &str| extension.is_some_and(|ext| ext.eq_ignore_ascii_case(wanted));

    if lowered.contains("test") {
        "test file"
    } else if has_extension("md") {
        "documentation"
    } else if lowered.contains("config")
        || has_extension("toml")
        || has_extension("yaml")
        || has_extension("yml")
    {
        "configuration"
    } else if lowered.contains("error") {
        "error definitions"
    } else if lowered.contains("type") {
        "type definitions"
    } else if lowered.ends_with("mod.rs") || lowered.ends_with("lib.rs") {
        "module exports"
    } else if lowered.ends_with("main.rs")
        || lowered.ends_with("main.go")
        || lowered.ends_with("main.py")
    {
        "entry point"
    } else if content.contains("impl ") || content.contains("fn ") {
        // From here on the path gave no hint; fall back to keywords in the diff content.
        "implementation"
    } else if content.contains("struct ") || content.contains("enum ") {
        "type definitions"
    } else if content.contains("async ") || content.contains("await") {
        "async code"
    } else {
        "source code"
    }
}
162
163async fn map_phase(
165 files: &[FileDiff],
166 model_name: &str,
167 config: &CommitConfig,
168 counter: &TokenCounter,
169) -> Result<Vec<FileObservation>> {
170 let observations: Vec<Result<FileObservation>> = stream::iter(files.iter())
172 .map(|file| async {
173 if file.is_binary {
174 return Ok(FileObservation {
175 file: file.filename.clone(),
176 observations: vec!["Binary file changed.".to_string()],
177 additions: 0,
178 deletions: 0,
179 });
180 }
181
182 let context_header = generate_context_header(files, &file.filename);
183 let mut file_clone = file.clone();
185 let file_tokens = file_clone.token_estimate(counter);
186 if file_tokens > MAX_FILE_TOKENS {
187 let target_size = MAX_FILE_TOKENS * 4; file_clone.truncate(target_size);
189 eprintln!(
190 " {} truncated {} ({} \u{2192} {} tokens)",
191 crate::style::icons::WARNING,
192 file.filename,
193 file_tokens,
194 file_clone.token_estimate(counter)
195 );
196 }
197
198 let file_diff = reconstruct_diff(&[file_clone]);
199
200 map_single_file(&file.filename, &file_diff, &context_header, model_name, config).await
201 })
202 .buffer_unordered(8)
203 .collect()
204 .await;
205
206 observations.into_iter().collect()
208}
209
210pub async fn observe_diff_files(
211 diff: &str,
212 model_name: &str,
213 config: &CommitConfig,
214 counter: &TokenCounter,
215) -> Result<Vec<FileObservation>> {
216 let mut files = parse_diff(diff);
217
218 files.retain(|file| {
219 !config
220 .excluded_files
221 .iter()
222 .any(|excluded| file.filename.ends_with(excluded))
223 });
224
225 if files.is_empty() {
226 return Err(CommitGenError::Other(
227 "No relevant files to summarize after filtering".to_string(),
228 ));
229 }
230
231 map_phase(&files, model_name, config, counter).await
232}
233
234async fn map_single_file(
236 filename: &str,
237 file_diff: &str,
238 context_header: &str,
239 model_name: &str,
240 config: &CommitConfig,
241) -> Result<FileObservation> {
242 let parts = templates::render_map_prompt("default", filename, file_diff, context_header)?;
243 let observation_schema = build_observation_schema();
244
245 let response = run_oneshot::<FileObservationResponse>(config, &OneShotSpec {
246 operation: "map-reduce/map",
247 model: model_name,
248 max_tokens: 1500,
249 temperature: config.temperature,
250 prompt_family: "map",
251 prompt_variant: "default",
252 system_prompt: &parts.system,
253 user_prompt: &parts.user,
254 tool_name: "create_file_observation",
255 tool_description: "Extract observations from a single file's changes",
256 schema: &observation_schema,
257 debug: None,
258 cacheable: true,
259 })
260 .await?;
261
262 let mut observations = response.output.observations;
263 if observations.is_empty() {
264 let text_observations = response
265 .text_content
266 .as_deref()
267 .map(parse_observations_from_text)
268 .unwrap_or_default();
269
270 if !text_observations.is_empty() {
271 observations = text_observations;
272 } else if response.stop_reason.as_deref() == Some("max_tokens") {
273 crate::style::warn(
274 "Anthropic stopped at max_tokens with empty observations; using fallback observation.",
275 );
276 let fallback_target = Path::new(filename)
277 .file_name()
278 .and_then(|name| name.to_str())
279 .unwrap_or(filename);
280 observations = vec![format!("Updated {fallback_target}.")];
281 } else {
282 crate::style::warn("Model returned empty observations; continuing with no observations.");
283 }
284 }
285
286 Ok(FileObservation { file: filename.to_string(), observations, additions: 0, deletions: 0 })
287}
288
289pub async fn reduce_phase(
291 observations: &[FileObservation],
292 stat: &str,
293 scope_candidates: &str,
294 model_name: &str,
295 config: &CommitConfig,
296) -> Result<ConventionalAnalysis> {
297 let type_enum: Vec<&str> = config.types.keys().map(|s| s.as_str()).collect();
298 let observations_json =
299 serde_json::to_string_pretty(observations).unwrap_or_else(|_| "[]".to_string());
300
301 let types_description = crate::api::format_types_description(config);
302 let parts = templates::render_reduce_prompt(
303 "default",
304 &observations_json,
305 stat,
306 scope_candidates,
307 Some(&types_description),
308 )?;
309
310 let analysis_schema = build_analysis_schema(&type_enum);
311 let response = run_oneshot::<ConventionalAnalysis>(config, &OneShotSpec {
312 operation: "map-reduce/reduce",
313 model: model_name,
314 max_tokens: 1500,
315 temperature: config.temperature,
316 prompt_family: "reduce",
317 prompt_variant: "default",
318 system_prompt: &parts.system,
319 user_prompt: &parts.user,
320 tool_name: "create_conventional_analysis",
321 tool_description: "Analyze changes and classify as conventional commit with type, scope, \
322 details, and metadata",
323 schema: &analysis_schema,
324 debug: None,
325 cacheable: true,
326 })
327 .await?;
328
329 Ok(response.output)
330}
331
332pub async fn run_map_reduce(
334 diff: &str,
335 stat: &str,
336 scope_candidates: &str,
337 model_name: &str,
338 config: &CommitConfig,
339 counter: &TokenCounter,
340) -> Result<ConventionalAnalysis> {
341 let observations = observe_diff_files(diff, model_name, config, counter).await?;
342
343 let file_count = observations.len();
344 crate::style::print_info(&format!("Running map-reduce on {file_count} files..."));
345
346 reduce_phase(&observations, stat, scope_candidates, model_name, config).await
348}
349
350fn parse_observations_from_text(text: &str) -> Vec<String> {
351 let trimmed = text.trim();
352 if trimmed.is_empty() {
353 return Vec::new();
354 }
355
356 if let Ok(obs) = serde_json::from_str::<FileObservationResponse>(trimmed) {
357 return obs.observations;
358 }
359
360 trimmed
361 .lines()
362 .map(str::trim)
363 .filter(|line| !line.is_empty())
364 .map(|line| {
365 line
366 .strip_prefix("- ")
367 .or_else(|| line.strip_prefix("* "))
368 .unwrap_or(line)
369 .trim()
370 })
371 .filter(|line| !line.is_empty())
372 .map(str::to_string)
373 .collect()
374}
375
376#[derive(Debug, Deserialize, Serialize)]
377struct FileObservationResponse {
378 #[serde(deserialize_with = "deserialize_observations")]
379 observations: Vec<String>,
380}
381
382fn deserialize_observations<'de, D>(deserializer: D) -> std::result::Result<Vec<String>, D::Error>
385where
386 D: serde::Deserializer<'de>,
387{
388 use std::fmt;
389
390 use serde::de::{self, Visitor};
391
392 struct ObservationsVisitor;
393
394 impl<'de> Visitor<'de> for ObservationsVisitor {
395 type Value = Vec<String>;
396
397 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
398 formatter.write_str("an array of strings, a JSON array string, or a bullet-point string")
399 }
400
401 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
402 where
403 A: de::SeqAccess<'de>,
404 {
405 let mut vec = Vec::new();
406 while let Some(item) = seq.next_element::<String>()? {
407 vec.push(item);
408 }
409 Ok(vec)
410 }
411
412 fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
413 where
414 E: de::Error,
415 {
416 Ok(parse_string_to_observations(s))
417 }
418 }
419
420 deserializer.deserialize_any(ObservationsVisitor)
421}
422
423fn parse_string_to_observations(s: &str) -> Vec<String> {
426 let trimmed = s.trim();
427 if trimmed.is_empty() {
428 return Vec::new();
429 }
430
431 if trimmed.starts_with('[')
433 && let Ok(arr) = serde_json::from_str::<Vec<String>>(trimmed)
434 {
435 return arr;
436 }
437
438 trimmed
440 .lines()
441 .map(str::trim)
442 .filter(|line| !line.is_empty())
443 .map(|line| {
444 line
445 .strip_prefix("- ")
446 .or_else(|| line.strip_prefix("* "))
447 .or_else(|| line.strip_prefix("• "))
448 .unwrap_or(line)
449 .trim()
450 .to_string()
451 })
452 .filter(|line| !line.is_empty())
453 .collect()
454}
455
456fn build_observation_schema() -> serde_json::Value {
457 strict_json_schema(
458 serde_json::json!({
459 "observations": {
460 "type": "array",
461 "description": "List of factual observations about what changed in this file",
462 "items": {
463 "type": "string"
464 }
465 }
466 }),
467 &["observations"],
468 )
469}
470
471fn build_analysis_schema(type_enum: &[&str]) -> serde_json::Value {
472 strict_json_schema(
473 serde_json::json!({
474 "type": {
475 "type": "string",
476 "enum": type_enum,
477 "description": "Commit type based on combined changes"
478 },
479 "scope": {
480 "type": "string",
481 "description": "Optional scope (module/component). Omit if unclear or multi-component."
482 },
483 "details": {
484 "type": "array",
485 "description": "Array of 0-6 detail items with changelog metadata.",
486 "items": {
487 "type": "object",
488 "properties": {
489 "text": {
490 "type": "string",
491 "description": "Detail about change, starting with past-tense verb, ending with period"
492 },
493 "changelog_category": {
494 "type": "string",
495 "enum": ["Added", "Changed", "Fixed", "Deprecated", "Removed", "Security"],
496 "description": "Changelog category if user-visible. Omit for internal changes."
497 },
498 "user_visible": {
499 "type": "boolean",
500 "description": "True if this change affects users/API and should appear in changelog"
501 }
502 },
503 "required": ["text", "user_visible"]
504 }
505 },
506 "issue_refs": {
507 "type": "array",
508 "description": "Issue numbers from context (e.g., ['#123', '#456']). Empty if none.",
509 "items": {
510 "type": "string"
511 }
512 }
513 }),
514 &["type", "details", "issue_refs"],
515 )
516}
517
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokens::TokenCounter;

    /// Builds the token counter shared by the `should_use_map_reduce` tests.
    fn test_counter() -> TokenCounter {
        TokenCounter::new("http://localhost:4000", None, "claude-sonnet-4.5")
    }

    /// Convenience constructor for a non-binary `FileDiff` with an empty header.
    fn file_diff(filename: &str, content: &str, additions: usize, deletions: usize) -> FileDiff {
        FileDiff {
            filename: filename.to_string(),
            header: String::new(),
            content: content.to_string(),
            additions,
            deletions,
            is_binary: false,
        }
    }

    // Four files would normally qualify, but the feature flag wins.
    #[test]
    fn test_should_use_map_reduce_disabled() {
        let config = CommitConfig { map_reduce_enabled: false, ..Default::default() };
        let counter = test_counter();
        let diff = r"diff --git a/a.rs b/a.rs
@@ -0,0 +1 @@
+a
diff --git a/b.rs b/b.rs
@@ -0,0 +1 @@
+b
diff --git a/c.rs b/c.rs
@@ -0,0 +1 @@
+c
diff --git a/d.rs b/d.rs
@@ -0,0 +1 @@
+d";
        assert!(!should_use_map_reduce(diff, &config, &counter));
    }

    // Two small files stay below both thresholds.
    #[test]
    fn test_should_use_map_reduce_few_files() {
        let config = CommitConfig::default();
        let counter = test_counter();
        let diff = r"diff --git a/a.rs b/a.rs
@@ -0,0 +1 @@
+a
diff --git a/b.rs b/b.rs
@@ -0,0 +1 @@
+b";
        assert!(!should_use_map_reduce(diff, &config, &counter));
    }

    // Five files exceed the file-count threshold.
    #[test]
    fn test_should_use_map_reduce_many_files() {
        let config = CommitConfig::default();
        let counter = test_counter();
        let diff = r"diff --git a/a.rs b/a.rs
@@ -0,0 +1 @@
+a
diff --git a/b.rs b/b.rs
@@ -0,0 +1 @@
+b
diff --git a/c.rs b/c.rs
@@ -0,0 +1 @@
+c
diff --git a/d.rs b/d.rs
@@ -0,0 +1 @@
+d
diff --git a/e.rs b/e.rs
@@ -0,0 +1 @@
+e";
        assert!(should_use_map_reduce(diff, &config, &counter));
    }

    // With no sibling files there is nothing to list, so the header is empty.
    #[test]
    fn test_generate_context_header_empty() {
        let files = vec![file_diff("only.rs", "", 10, 5)];
        let header = generate_context_header(&files, "only.rs");
        assert!(header.is_empty());
    }

    // The header lists every sibling but never the file being described.
    #[test]
    fn test_generate_context_header_multiple() {
        let files = vec![
            file_diff("src/main.rs", "fn main() {}", 10, 5),
            file_diff("src/lib.rs", "mod test;", 3, 1),
            file_diff("tests/test.rs", "#[test]", 20, 0),
        ];

        let header = generate_context_header(&files, "src/main.rs");
        assert!(header.contains("OTHER FILES IN THIS CHANGE:"));
        assert!(header.contains("src/lib.rs"));
        assert!(header.contains("tests/test.rs"));
        assert!(!header.contains("src/main.rs"));
    }

    #[test]
    fn test_infer_file_description() {
        assert_eq!(infer_file_description("src/test_utils.rs", ""), "test file");
        assert_eq!(infer_file_description("README.md", ""), "documentation");
        assert_eq!(infer_file_description("config.toml", ""), "configuration");
        assert_eq!(infer_file_description("src/error.rs", ""), "error definitions");
        assert_eq!(infer_file_description("src/types.rs", ""), "type definitions");
        assert_eq!(infer_file_description("src/mod.rs", ""), "module exports");
        assert_eq!(infer_file_description("src/main.rs", ""), "entry point");
        assert_eq!(infer_file_description("src/api.rs", "fn call()"), "implementation");
        assert_eq!(infer_file_description("src/models.rs", "struct Foo"), "type definitions");
        assert_eq!(infer_file_description("src/unknown.xyz", ""), "source code");
    }

    #[test]
    fn test_parse_string_to_observations_json_array() {
        let input = r#"["item one", "item two", "item three"]"#;
        let result = parse_string_to_observations(input);
        assert_eq!(result, vec!["item one", "item two", "item three"]);
    }

    #[test]
    fn test_parse_string_to_observations_bullet_points() {
        let input = "- added new function\n- fixed bug in parser\n- updated tests";
        let result = parse_string_to_observations(input);
        assert_eq!(result, vec!["added new function", "fixed bug in parser", "updated tests"]);
    }

    #[test]
    fn test_parse_string_to_observations_asterisk_bullets() {
        let input = "* first change\n* second change";
        let result = parse_string_to_observations(input);
        assert_eq!(result, vec!["first change", "second change"]);
    }

    #[test]
    fn test_parse_string_to_observations_unicode_bullets() {
        let input = "• first change\n• second change";
        let result = parse_string_to_observations(input);
        assert_eq!(result, vec!["first change", "second change"]);
    }

    #[test]
    fn test_parse_string_to_observations_empty() {
        assert!(parse_string_to_observations("").is_empty());
        assert!(parse_string_to_observations(" ").is_empty());
    }

    #[test]
    fn test_parse_observations_from_text_object_and_bullets() {
        let from_object = parse_observations_from_text(r#"{"observations": ["a", "b"]}"#);
        assert_eq!(from_object, vec!["a", "b"]);

        let from_bullets = parse_observations_from_text("- a\n\n* b\n");
        assert_eq!(from_bullets, vec!["a", "b"]);

        assert!(parse_observations_from_text("   ").is_empty());
    }

    #[test]
    fn test_deserialize_observations_array() {
        let json = r#"{"observations": ["a", "b", "c"]}"#;
        let result: FileObservationResponse =
            serde_json::from_str(json).expect("valid observation array JSON should deserialize");
        assert_eq!(result.observations, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_deserialize_observations_stringified_array() {
        let json = r#"{"observations": "[\"a\", \"b\", \"c\"]"}"#;
        let result: FileObservationResponse = serde_json::from_str(json)
            .expect("valid stringified observation array JSON should deserialize");
        assert_eq!(result.observations, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_deserialize_observations_bullet_string() {
        let json = r#"{"observations": "- updated function\n- fixed bug"}"#;
        let result: FileObservationResponse =
            serde_json::from_str(json).expect("valid bullet observation JSON should deserialize");
        assert_eq!(result.observations, vec!["updated function", "fixed bug"]);
    }
}