1use std::{borrow::Cow, cmp::Reverse, fmt::Write as _, path::Path};
7
8use futures::stream::{self, StreamExt};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 api::{OneShotSpec, run_oneshot, strict_json_schema},
13 config::CommitConfig,
14 diff::{FileDiff, parse_diff, reconstruct_diff},
15 error::{CommitGenError, Result},
16 templates,
17 tokens::TokenCounter,
18 types::ConventionalAnalysis,
19};
20
/// Observations gathered for a single file during the map phase.
///
/// A slice of these is serialized to JSON and embedded in the reduce prompt
/// (see `reduce_phase`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileObservation {
    /// Path of the file, copied from `FileDiff::filename`.
    pub file: String,
    /// Observation strings describing what changed in the file.
    pub observations: Vec<String>,
    /// Added-line count. Every constructor in this module currently stores 0.
    pub additions: usize,
    /// Deleted-line count. Every constructor in this module currently stores 0.
    pub deletions: usize,
}
29
/// Per-file token ceiling: a file above it forces map-reduce in
/// `should_use_map_reduce` and is truncated in `render_file_diff_for_batch`.
const MAX_FILE_TOKENS: usize = 50_000;
/// Maximum number of map-batch requests kept in flight at once by `map_phase`.
const MAP_PHASE_CONCURRENCY: usize = 16;
34
/// Returns the model name used for the map phase, which is the configured
/// `summary_model`.
const fn map_phase_model(config: &CommitConfig) -> &str {
    config.summary_model.as_str()
}
38
39fn build_file_batches(
40 files: &[FileDiff],
41 counter: &TokenCounter,
42 budget: usize,
43) -> Vec<Vec<usize>> {
44 build_file_batches_for_indices(files, 0..files.len(), counter, budget)
45}
46
47fn build_llm_file_batches(
48 files: &[FileDiff],
49 counter: &TokenCounter,
50 budget: usize,
51) -> Vec<Vec<usize>> {
52 if files.iter().all(|file| !file.is_binary) {
53 return build_file_batches(files, counter, budget);
54 }
55
56 build_file_batches_for_indices(
57 files,
58 files
59 .iter()
60 .enumerate()
61 .filter_map(|(idx, file)| (!file.is_binary).then_some(idx)),
62 counter,
63 budget,
64 )
65}
66
67fn build_file_batches_for_indices<I>(
68 files: &[FileDiff],
69 indices: I,
70 counter: &TokenCounter,
71 budget: usize,
72) -> Vec<Vec<usize>>
73where
74 I: IntoIterator<Item = usize>,
75{
76 let budget = budget.max(1);
77 let mut batches = Vec::new();
78 let mut current_batch = Vec::new();
79 let mut current_tokens = 0usize;
80
81 for idx in indices {
82 let file_tokens = files[idx].token_estimate(counter);
83 if file_tokens > budget {
84 if !current_batch.is_empty() {
85 batches.push(std::mem::take(&mut current_batch));
86 current_tokens = 0;
87 }
88 batches.push(vec![idx]);
89 continue;
90 }
91
92 if !current_batch.is_empty() && current_tokens.saturating_add(file_tokens) > budget {
93 batches.push(std::mem::take(&mut current_batch));
94 current_tokens = 0;
95 }
96
97 current_batch.push(idx);
98 current_tokens = current_tokens.saturating_add(file_tokens);
99 }
100
101 if !current_batch.is_empty() {
102 batches.push(current_batch);
103 }
104
105 batches
106}
107
108#[tracing::instrument(target = "lgit", name = "map_reduce.should_use", skip_all, fields(diff_bytes = diff.len(), threshold = config.map_reduce_threshold))]
114pub fn should_use_map_reduce(diff: &str, config: &CommitConfig, counter: &TokenCounter) -> bool {
115 if !config.map_reduce_enabled {
116 return false;
117 }
118
119 let files = parse_diff(diff);
120 let mut has_included_file = false;
121 let mut total_tokens = 0usize;
122
123 for file in files.iter().filter(|file| {
124 !config
125 .excluded_files
126 .iter()
127 .any(|excluded| file.filename.ends_with(excluded))
128 }) {
129 has_included_file = true;
130
131 let file_tokens = file.token_estimate(counter);
132 if file_tokens > MAX_FILE_TOKENS {
133 return true;
134 }
135
136 total_tokens = total_tokens.saturating_add(file_tokens);
137 if total_tokens >= config.map_reduce_threshold {
138 return true;
139 }
140 }
141
142 has_included_file && total_tokens >= config.map_reduce_threshold
143}
144
/// Maximum number of other files listed individually in a context header.
const MAX_CONTEXT_FILES: usize = 20;
147
/// Precomputed context-header entry for one file of the change.
struct ContextFile<'a> {
    /// File path, borrowed from the originating `FileDiff`.
    filename: &'a str,
    /// Ready-to-print bullet line: path, changed-line count and inferred description.
    summary_line: String,
    /// Additions plus deletions; used to rank files when the list must be capped.
    change_size: usize,
}
154
/// Shared data from which each map batch derives its "other files" header.
struct ContextHeaders<'a> {
    /// One entry per file, in input order (empty for very large commits).
    files: Vec<ContextFile<'a>>,
    /// Indices into `files` sorted by descending `change_size`; only populated
    /// when there are more than `MAX_CONTEXT_FILES` files.
    ranked_indices: Vec<usize>,
    /// Fixed one-line header used instead of a file list for very large commits.
    large_commit_header: Option<String>,
}
160
161impl<'a> ContextHeaders<'a> {
162 fn new(files: &'a [FileDiff]) -> Self {
163 if files.len() > 100 {
165 return Self {
166 files: Vec::new(),
167 ranked_indices: Vec::new(),
168 large_commit_header: Some(format!("(Large commit with {} total files)", files.len())),
169 };
170 }
171
172 let files: Vec<_> = files
173 .iter()
174 .map(|file| {
175 let change_size = file.additions + file.deletions;
176 let description = infer_file_description(&file.filename, &file.content);
177 ContextFile {
178 filename: &file.filename,
179 summary_line: format!(
180 "- {} ({} lines): {}",
181 file.filename, change_size, description
182 ),
183 change_size,
184 }
185 })
186 .collect();
187
188 let mut ranked_indices = Vec::new();
189 if files.len() > MAX_CONTEXT_FILES {
190 ranked_indices = (0..files.len()).collect();
191 ranked_indices.sort_by_key(|&idx| Reverse(files[idx].change_size));
192 }
193
194 Self { files, ranked_indices, large_commit_header: None }
195 }
196
197 fn header_for_files(&self, current_files: &[&str]) -> Cow<'_, str> {
198 if let Some(header) = &self.large_commit_header {
199 return Cow::Borrowed(header.as_str());
200 }
201
202 let current_count = self
203 .files
204 .iter()
205 .filter(|file| is_current_context_file(file.filename, current_files))
206 .count();
207 let total_other = self.files.len().saturating_sub(current_count);
208
209 if total_other == 0 {
210 return Cow::Borrowed("");
211 }
212
213 let mut header = String::with_capacity(32 + total_other.min(MAX_CONTEXT_FILES) * 80);
214 header.push_str("OTHER FILES IN THIS CHANGE:");
215
216 let mut shown = 0usize;
217 if total_other > MAX_CONTEXT_FILES {
218 for &idx in &self.ranked_indices {
219 let file = &self.files[idx];
220 if is_current_context_file(file.filename, current_files) {
221 continue;
222 }
223
224 header.push('\n');
225 header.push_str(&file.summary_line);
226 shown += 1;
227
228 if shown == MAX_CONTEXT_FILES {
229 break;
230 }
231 }
232 } else {
233 for file in &self.files {
234 if is_current_context_file(file.filename, current_files) {
235 continue;
236 }
237
238 header.push('\n');
239 header.push_str(&file.summary_line);
240 shown += 1;
241 }
242 }
243
244 if shown < total_other {
245 write!(&mut header, "\n... and {} more files", total_other - shown)
246 .expect("writing to a string cannot fail");
247 }
248
249 Cow::Owned(header)
250 }
251}
252
/// Returns `true` when `filename` is exactly one of `current_files`.
fn is_current_context_file(filename: &str, current_files: &[&str]) -> bool {
    current_files.iter().any(|&candidate| candidate == filename)
}
256
/// Guesses a short, human-readable role for a file.
///
/// Rules are tried in order and the first hit wins: filename-based rules
/// first (case-insensitive), then substring checks on `content`. Falls back
/// to `"source code"` when nothing matches.
fn infer_file_description(filename: &str, content: &str) -> &'static str {
    let lowered = filename.to_lowercase();
    let extension = Path::new(filename).extension();

    let name_has = |needle: &str| lowered.contains(needle);
    let name_ends = |tail: &str| lowered.ends_with(tail);
    let ext_is = |wanted: &str| extension.is_some_and(|ext| ext.eq_ignore_ascii_case(wanted));
    let content_has_any =
        |needles: &[&str]| needles.iter().any(|needle| content.contains(*needle));

    if name_has("test") {
        "test file"
    } else if name_has("prompt") || name_has("system") {
        "prompt template"
    } else if ext_is("md") {
        "documentation"
    } else if name_has("config") || ext_is("toml") || ext_is("yaml") || ext_is("yml") {
        "configuration"
    } else if name_has("error") {
        "error definitions"
    } else if name_has("type") {
        "type definitions"
    } else if name_ends("mod.rs") || name_ends("lib.rs") {
        "module exports"
    } else if name_ends("main.rs") || name_ends("main.go") || name_ends("main.py") {
        "entry point"
    } else if content_has_any(&["impl ", "fn "]) {
        // From here on the filename gave no hint; fall back to the content.
        "implementation"
    } else if content_has_any(&["struct ", "enum "]) {
        "type definitions"
    } else if content_has_any(&["async ", "await"]) {
        "async code"
    } else {
        "source code"
    }
}
312
313#[tracing::instrument(target = "lgit", name = "map_reduce.map_phase", skip_all, fields(file_count = files.len(), model = map_model_name))]
315async fn map_phase(
316 files: &[FileDiff],
317 map_model_name: &str,
318 config: &CommitConfig,
319 counter: &TokenCounter,
320) -> Result<Vec<FileObservation>> {
321 let context_headers = ContextHeaders::new(files);
322 let llm_batches = build_llm_file_batches(files, counter, config.map_batch_token_budget);
323 let total_batches = llm_batches.len();
324
325 let mut observations_by_index = vec![None; files.len()];
326 for (idx, file) in files.iter().enumerate().filter(|(_, file)| file.is_binary) {
327 observations_by_index[idx] = Some(FileObservation {
328 file: file.filename.clone(),
329 observations: vec!["Binary file changed.".to_string()],
330 additions: 0,
331 deletions: 0,
332 });
333 }
334
335 let batch_results: Vec<Result<Vec<(usize, FileObservation)>>> =
336 stream::iter(llm_batches.into_iter().enumerate())
337 .map(|(batch_idx, batch_indices)| {
338 let context_headers = &context_headers;
339 async move {
340 let batch_files: Vec<&FileDiff> =
341 batch_indices.iter().map(|&idx| &files[idx]).collect();
342 let current_paths: Vec<&str> = batch_files
343 .iter()
344 .map(|file| file.filename.as_str())
345 .collect();
346 let context_header = context_headers.header_for_files(¤t_paths);
347 let progress_label = format!(
348 "map batch {}/{} ({} files)",
349 batch_idx + 1,
350 total_batches,
351 batch_files.len()
352 );
353 let observations = map_file_batch(
354 &batch_files,
355 &context_header,
356 map_model_name,
357 config,
358 counter,
359 &progress_label,
360 )
361 .await?;
362
363 Ok(batch_indices.into_iter().zip(observations).collect())
364 }
365 })
366 .buffer_unordered(MAP_PHASE_CONCURRENCY)
367 .collect()
368 .await;
369
370 for result in batch_results {
371 for (idx, observation) in result? {
372 observations_by_index[idx] = Some(observation);
373 }
374 }
375
376 let mut observations = Vec::with_capacity(files.len());
377 for (idx, observation) in observations_by_index.into_iter().enumerate() {
378 let observation = observation.ok_or_else(|| {
379 CommitGenError::Other(format!("Missing map observation for {}", files[idx].filename))
380 })?;
381 observations.push(observation);
382 }
383
384 Ok(observations)
385}
386
387#[tracing::instrument(target = "lgit", name = "map_reduce.observe_diff_files", skip_all, fields(diff_bytes = diff.len(), model = map_model_name))]
388pub async fn observe_diff_files(
389 diff: &str,
390 map_model_name: &str,
391 config: &CommitConfig,
392 counter: &TokenCounter,
393) -> Result<Vec<FileObservation>> {
394 let mut files = parse_diff(diff);
395
396 files.retain(|file| {
397 !config
398 .excluded_files
399 .iter()
400 .any(|excluded| file.filename.ends_with(excluded))
401 });
402
403 if files.is_empty() {
404 return Err(CommitGenError::Other(
405 "No relevant files to summarize after filtering".to_string(),
406 ));
407 }
408
409 let llm_file_count = files.iter().filter(|file| !file.is_binary).count();
410 let batch_count = build_llm_file_batches(&files, counter, config.map_batch_token_budget).len();
411 crate::api::print_llm_progress(|| {
412 format!(
413 "Map-reduce map phase: {batch_count} batch LLM queries for {llm_file_count} files queued \
414 on {map_model_name} (max {MAP_PHASE_CONCURRENCY} parallel)"
415 )
416 });
417
418 map_phase(&files, map_model_name, config, counter).await
419}
420
421#[tracing::instrument(target = "lgit", name = "map_reduce.map_file_batch", skip_all, fields(file_count = files.len(), model = model_name))]
423async fn map_file_batch(
424 files: &[&FileDiff],
425 context_header: &str,
426 model_name: &str,
427 config: &CommitConfig,
428 counter: &TokenCounter,
429 progress_label: &str,
430) -> Result<Vec<FileObservation>> {
431 let rendered_diffs: Vec<String> = files
432 .iter()
433 .map(|file| render_file_diff_for_batch(file, counter))
434 .collect();
435 let prompt_files: Vec<templates::MapFile<'_>> = files
436 .iter()
437 .zip(&rendered_diffs)
438 .map(|(file, diff)| templates::MapFile { path: file.filename.as_str(), diff })
439 .collect();
440 let parts = templates::render_map_prompt("default", &prompt_files, context_header)?;
441 let observation_schema = build_batch_observation_schema();
442 let max_tokens = u32::try_from((files.len() * 250).clamp(1500, 4000))
443 .expect("batch output token cap fits in u32");
444
445 let response = run_oneshot::<BatchObservationResponse>(config, &OneShotSpec {
446 operation: "map-reduce/map",
447 model: model_name,
448 max_tokens,
449 temperature: config.temperature,
450 prompt_family: "map",
451 prompt_variant: "default",
452 system_prompt: &parts.system,
453 user_prompt: &parts.user,
454 tool_name: "create_file_observations",
455 tool_description: "Extract observations from a batch of file changes",
456 schema: &observation_schema,
457 progress_label: Some(progress_label),
458 debug: None,
459 cacheable: true,
460 })
461 .await?;
462
463 Ok(map_batch_response_to_observations(
464 files,
465 &response.output,
466 response.text_content.as_deref(),
467 response.stop_reason.as_deref(),
468 ))
469}
470
471fn render_file_diff_for_batch(file: &FileDiff, counter: &TokenCounter) -> String {
472 let file_tokens = file.token_estimate(counter);
473 if file_tokens > MAX_FILE_TOKENS {
474 let mut file_clone = file.clone();
475 let target_size = MAX_FILE_TOKENS * 4; file_clone.truncate(target_size);
477 eprintln!(
478 " {} truncated {} ({} \u{2192} {} tokens)",
479 crate::style::icons::WARNING,
480 file.filename,
481 file_tokens,
482 file_clone.token_estimate(counter)
483 );
484 return reconstruct_diff(&[file_clone]);
485 }
486
487 reconstruct_single_file_diff(file)
488}
489
490fn reconstruct_single_file_diff(file: &FileDiff) -> String {
491 let mut diff = String::with_capacity(file.size());
492 diff.push_str(&file.header);
493 if !file.content.is_empty() {
494 diff.push('\n');
495 diff.push_str(&file.content);
496 }
497 diff
498}
499
500fn map_batch_response_to_observations(
501 files: &[&FileDiff],
502 response: &BatchObservationResponse,
503 text_content: Option<&str>,
504 stop_reason: Option<&str>,
505) -> Vec<FileObservation> {
506 if response.files.is_empty() && text_content.is_some_and(|text| !text.trim().is_empty()) {
507 crate::style::warn(
508 "Model returned batch observations as text; using fallback observations for every file.",
509 );
510 return files
511 .iter()
512 .map(|file| fallback_file_observation(file))
513 .collect();
514 }
515
516 let stopped_at_max_tokens = stop_reason == Some("max_tokens");
517 let mut used_entries = vec![false; response.files.len()];
518 files
519 .iter()
520 .map(|file| {
521 let Some(entry_idx) =
522 find_observation_entry(file.filename.as_str(), &response.files, &used_entries, files)
523 else {
524 return fallback_file_observation(file);
525 };
526
527 used_entries[entry_idx] = true;
528 let entry = &response.files[entry_idx];
529 let observations = if entry.observations.is_empty() && stopped_at_max_tokens {
530 vec![fallback_observation_text(&file.filename)]
531 } else {
532 entry.observations.clone()
533 };
534
535 FileObservation { file: file.filename.clone(), observations, additions: 0, deletions: 0 }
536 })
537 .collect()
538}
539
540fn find_observation_entry(
541 filename: &str,
542 entries: &[FileObservationEntry],
543 used_entries: &[bool],
544 batch_files: &[&FileDiff],
545) -> Option<usize> {
546 find_entry_by(entries, used_entries, |entry| entry.path == filename)
547 .or_else(|| {
548 let filename_basename = path_basename(filename);
549 let basename_is_unique = batch_files
550 .iter()
551 .filter(|file| path_basename(file.filename.as_str()) == filename_basename)
552 .count()
553 == 1;
554 basename_is_unique
555 .then(|| {
556 find_entry_by(entries, used_entries, |entry| {
557 path_basename(&entry.path) == filename_basename
558 })
559 })
560 .flatten()
561 })
562 .or_else(|| {
563 find_entry_by(entries, used_entries, |entry| path_suffix_matches(&entry.path, filename))
564 })
565}
566
567fn find_entry_by<F>(
568 entries: &[FileObservationEntry],
569 used_entries: &[bool],
570 mut matches: F,
571) -> Option<usize>
572where
573 F: FnMut(&FileObservationEntry) -> bool,
574{
575 entries
576 .iter()
577 .enumerate()
578 .find_map(|(idx, entry)| (!used_entries[idx] && matches(entry)).then_some(idx))
579}
580
/// Returns the final component of `path`, or `path` itself when it has no
/// file-name component.
fn path_basename(path: &str) -> &str {
    match Path::new(path).file_name() {
        Some(name) => name.to_str().unwrap_or(path),
        None => path,
    }
}
587
588fn path_suffix_matches(left: &str, right: &str) -> bool {
589 path_has_suffix(left, right) || path_has_suffix(right, left)
590}
591
/// Returns `true` when `path` equals `suffix`, or ends with `suffix` right
/// after a `/` or `\` separator.
fn path_has_suffix(path: &str, suffix: &str) -> bool {
    match path.strip_suffix(suffix) {
        None => false,
        // Nothing left in front of the suffix: the two paths are identical.
        Some("") => true,
        Some(prefix) => matches!(prefix.chars().next_back(), Some('/' | '\\')),
    }
}
601
602fn fallback_file_observation(file: &FileDiff) -> FileObservation {
603 FileObservation {
604 file: file.filename.clone(),
605 observations: vec![fallback_observation_text(&file.filename)],
606 additions: 0,
607 deletions: 0,
608 }
609}
610
611fn fallback_observation_text(filename: &str) -> String {
612 let fallback_target = path_basename(filename);
613 format!("Updated {fallback_target}.")
614}
615
616#[tracing::instrument(target = "lgit", name = "map_reduce.reduce_phase", skip_all, fields(observation_count = observations.len(), model = model_name))]
618pub async fn reduce_phase(
619 observations: &[FileObservation],
620 stat: &str,
621 scope_candidates: &str,
622 model_name: &str,
623 config: &CommitConfig,
624) -> Result<ConventionalAnalysis> {
625 let type_enum: Vec<&str> = config.types.keys().map(|s| s.as_str()).collect();
626 let observations_json =
627 serde_json::to_string_pretty(observations).unwrap_or_else(|_| "[]".to_string());
628
629 let types_description = crate::api::format_types_description(config);
630 let parts = templates::render_reduce_prompt(
631 "default",
632 &observations_json,
633 stat,
634 scope_candidates,
635 Some(&types_description),
636 )?;
637
638 let analysis_schema = build_analysis_schema(&type_enum, config);
639 let response = run_oneshot::<ConventionalAnalysis>(config, &OneShotSpec {
640 operation: "map-reduce/reduce",
641 model: model_name,
642 max_tokens: 1500,
643 temperature: config.temperature,
644 prompt_family: "reduce",
645 prompt_variant: "default",
646 system_prompt: &parts.system,
647 user_prompt: &parts.user,
648 tool_name: "create_conventional_analysis",
649 tool_description: "Analyze changes and classify as conventional commit with type, scope, \
650 summary, details, and metadata",
651 schema: &analysis_schema,
652 progress_label: Some("reduce file observations"),
653 debug: None,
654 cacheable: true,
655 })
656 .await?;
657
658 Ok(response.output)
659}
660
661#[tracing::instrument(target = "lgit", name = "map_reduce.run", skip_all, fields(diff_bytes = diff.len(), model = model_name))]
663pub async fn run_map_reduce(
664 diff: &str,
665 stat: &str,
666 scope_candidates: &str,
667 model_name: &str,
668 config: &CommitConfig,
669 counter: &TokenCounter,
670) -> Result<ConventionalAnalysis> {
671 let map_model_name = map_phase_model(config);
672 let observations = observe_diff_files(diff, map_model_name, config, counter).await?;
673 let file_count = observations.len();
674 crate::api::print_llm_progress(|| {
675 format!("Map-reduce reduce phase: synthesizing {file_count} file observations")
676 });
677
678 reduce_phase(&observations, stat, scope_candidates, model_name, config).await
679}
680
/// Structured tool output of one map batch request.
#[derive(Debug, Deserialize, Serialize)]
struct BatchObservationResponse {
    /// Per-file entries returned by the model; defaults to empty when the
    /// field is missing.
    #[serde(default)]
    files: Vec<FileObservationEntry>,
}
686
/// One file's entry inside a `BatchObservationResponse`.
#[derive(Debug, Deserialize, Serialize)]
struct FileObservationEntry {
    /// File path as reported by the model; matched back to the batch files by
    /// `find_observation_entry`.
    path: String,
    /// Observations for the file. Accepts an array of strings or a single
    /// string (see `deserialize_observations`); defaults to empty when missing.
    #[serde(default, deserialize_with = "deserialize_observations")]
    observations: Vec<String>,
}
693
694fn deserialize_observations<'de, D>(deserializer: D) -> std::result::Result<Vec<String>, D::Error>
697where
698 D: serde::Deserializer<'de>,
699{
700 use std::fmt;
701
702 use serde::de::{self, Visitor};
703
704 struct ObservationsVisitor;
705
706 impl<'de> Visitor<'de> for ObservationsVisitor {
707 type Value = Vec<String>;
708
709 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
710 formatter.write_str("an array of strings, a JSON array string, or a bullet-point string")
711 }
712
713 fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
714 where
715 A: de::SeqAccess<'de>,
716 {
717 let mut vec = Vec::new();
718 while let Some(item) = seq.next_element::<String>()? {
719 vec.push(item);
720 }
721 Ok(vec)
722 }
723
724 fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
725 where
726 E: de::Error,
727 {
728 Ok(parse_string_to_observations(s))
729 }
730 }
731
732 deserializer.deserialize_any(ObservationsVisitor)
733}
734
735fn parse_string_to_observations(s: &str) -> Vec<String> {
738 let trimmed = s.trim();
739 if trimmed.is_empty() {
740 return Vec::new();
741 }
742
743 if trimmed.starts_with('[')
745 && let Ok(arr) = serde_json::from_str::<Vec<String>>(trimmed)
746 {
747 return arr;
748 }
749
750 trimmed
752 .lines()
753 .map(str::trim)
754 .filter(|line| !line.is_empty())
755 .map(|line| {
756 line
757 .strip_prefix("- ")
758 .or_else(|| line.strip_prefix("* "))
759 .or_else(|| line.strip_prefix("• "))
760 .unwrap_or(line)
761 .trim()
762 .to_string()
763 })
764 .filter(|line| !line.is_empty())
765 .collect()
766}
767
768fn build_batch_observation_schema() -> serde_json::Value {
769 strict_json_schema(
770 serde_json::json!({
771 "files": {
772 "type": "array",
773 "description": "Per-file observations for every file in the map batch.",
774 "items": {
775 "type": "object",
776 "properties": {
777 "path": {
778 "type": "string",
779 "description": "The exact input file path this observation set describes."
780 },
781 "observations": {
782 "type": "array",
783 "description": "Factual observations about what changed in this file.",
784 "items": {
785 "type": "string"
786 }
787 }
788 },
789 "required": ["path", "observations"],
790 "additionalProperties": false
791 }
792 }
793 }),
794 &["files"],
795 )
796}
797
/// Builds the JSON schema for the reduce-phase tool output.
///
/// `type_enum` becomes the allowed values of the `type` property. The summary
/// description and its `maxLength` are taken from `config.summary_guideline`
/// and `config.summary_hard_limit`. Only `type`, `details` and `issue_refs`
/// are passed to `strict_json_schema` as required.
fn build_analysis_schema(type_enum: &[&str], config: &CommitConfig) -> serde_json::Value {
    strict_json_schema(
        serde_json::json!({
            "type": {
                "type": "string",
                "enum": type_enum,
                "description": "Commit type based on combined changes"
            },
            "scope": {
                "type": "string",
                "description": "Optional scope (module/component). Omit if unclear or multi-component."
            },
            "summary": {
                "type": "string",
                "description": format!(
                    "Concise past-tense commit summary without type/scope prefix or trailing period; target {} chars, hard limit {}.",
                    config.summary_guideline,
                    config.summary_hard_limit
                ),
                "maxLength": config.summary_hard_limit
            },
            "details": {
                "type": "array",
                "description": "Array of 0-6 detail items with changelog metadata.",
                "items": {
                    "type": "object",
                    "properties": {
                        "text": {
                            "type": "string",
                            "description": "Detail about change, starting with past-tense verb, ending with period"
                        },
                        "changelog_category": {
                            "type": "string",
                            "enum": ["Added", "Changed", "Fixed", "Deprecated", "Removed", "Security"],
                            "description": "Changelog category if user-visible. Omit for internal changes."
                        },
                        "user_visible": {
                            "type": "boolean",
                            "description": "True if this change affects users/API and should appear in changelog"
                        }
                    },
                    "required": ["text", "user_visible"]
                }
            },
            "issue_refs": {
                "type": "array",
                "description": "Issue numbers from context (e.g., ['#123', '#456']). Empty if none.",
                "items": {
                    "type": "string"
                }
            }
        }),
        // Names of the top-level properties the model must always provide.
        &["type", "details", "issue_refs"],
    )
}
853
/// Unit tests for batching, response mapping, the map-reduce trigger, context
/// headers, file-description inference and lenient observation parsing.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokens::TokenCounter;

    /// Builds the `TokenCounter` shared by the tests.
    fn test_counter() -> TokenCounter {
        TokenCounter::new("http://localhost:4000", None, "claude-sonnet-4.5")
    }

    /// Builds a non-binary `FileDiff` whose content is `token_estimate * 4`
    /// characters long.
    // NOTE(review): presumably the counter estimates ~4 characters per token,
    // so the content is worth `token_estimate` tokens — confirm against TokenCounter.
    fn file_with_tokens(filename: &str, token_estimate: usize) -> FileDiff {
        FileDiff {
            filename: filename.to_string(),
            header: String::new(),
            content: "x".repeat(token_estimate * 4),
            additions: 0,
            deletions: 0,
            is_binary: false,
        }
    }

    /// The map phase uses the summary model, and the concurrency cap is 16.
    #[test]
    fn test_map_phase_model_uses_summary_model() {
        let config = CommitConfig {
            summary_model: "claude-haiku-4-5".to_string(),
            analysis_model: "claude-opus-4.1".to_string(),
            ..Default::default()
        };

        assert_eq!(map_phase_model(&config), "claude-haiku-4-5");
        assert_eq!(MAP_PHASE_CONCURRENCY, 16);
    }

    /// Files that fit the budget together end up in a single batch.
    #[test]
    fn test_build_file_batches_single_batch_when_under_budget() {
        let counter = test_counter();
        let files = vec![
            file_with_tokens("a.rs", 4),
            file_with_tokens("b.rs", 4),
            file_with_tokens("c.rs", 1),
        ];

        assert_eq!(build_file_batches(&files, &counter, 10), vec![vec![0, 1, 2]]);
    }

    /// A new batch is started once adding a file would exceed the budget.
    #[test]
    fn test_build_file_batches_splits_when_budget_exceeded() {
        let counter = test_counter();
        let files = vec![
            file_with_tokens("a.rs", 4),
            file_with_tokens("b.rs", 4),
            file_with_tokens("c.rs", 4),
        ];

        assert_eq!(build_file_batches(&files, &counter, 10), vec![vec![0, 1], vec![2]]);
    }

    /// Flattening the batches yields every index exactly once, in input order.
    #[test]
    fn test_build_file_batches_preserves_order_and_every_file_once() {
        let counter = test_counter();
        let files = vec![
            file_with_tokens("a.rs", 3),
            file_with_tokens("b.rs", 8),
            file_with_tokens("c.rs", 2),
            file_with_tokens("d.rs", 9),
            file_with_tokens("e.rs", 1),
        ];

        let batches = build_file_batches(&files, &counter, 10);
        let flattened: Vec<usize> = batches.into_iter().flatten().collect();
        assert_eq!(flattened, vec![0, 1, 2, 3, 4]);
    }

    /// A file larger than the budget is placed alone in its own batch.
    #[test]
    fn test_build_file_batches_isolates_oversized_files() {
        let counter = test_counter();
        let files = vec![
            file_with_tokens("a.rs", 2),
            file_with_tokens("b.rs", 2),
            file_with_tokens("huge.rs", 12),
            file_with_tokens("c.rs", 2),
        ];

        assert_eq!(build_file_batches(&files, &counter, 10), vec![vec![0, 1], vec![2], vec![3]]);
    }

    /// Entries are matched by exact path, then by basename; a file the model
    /// omitted gets the fallback observation.
    #[test]
    fn test_batch_response_mapping_matches_paths_and_falls_back_for_omissions() {
        let exact = file_with_tokens("src/lib.rs", 1);
        let basename = file_with_tokens("src/main.rs", 1);
        let omitted = file_with_tokens("crates/core/mod.rs", 1);
        let files = [&exact, &basename, &omitted];
        let response = BatchObservationResponse {
            files: vec![
                FileObservationEntry {
                    path: "src/lib.rs".to_string(),
                    observations: vec!["updated library entrypoint".to_string()],
                },
                FileObservationEntry {
                    path: "main.rs".to_string(),
                    observations: vec!["changed CLI wiring".to_string()],
                },
            ],
        };

        let result = map_batch_response_to_observations(&files, &response, None, None);

        assert_eq!(result[0].file, "src/lib.rs");
        assert_eq!(result[0].observations, vec!["updated library entrypoint".to_string()]);
        assert_eq!(result[1].file, "src/main.rs");
        assert_eq!(result[1].observations, vec!["changed CLI wiring".to_string()]);
        assert_eq!(result[2].file, "crates/core/mod.rs");
        assert_eq!(result[2].observations, vec!["Updated mod.rs.".to_string()]);
    }

    /// A response with text but no structured entries yields fallback
    /// observations for every file.
    #[test]
    fn test_batch_response_mapping_falls_back_for_text_only_response() {
        let first = file_with_tokens("src/lib.rs", 1);
        let second = file_with_tokens("src/main.rs", 1);
        let files = [&first, &second];
        let response = BatchObservationResponse { files: Vec::new() };

        let result = map_batch_response_to_observations(
            &files,
            &response,
            Some("- unstructured observation"),
            None,
        );

        assert_eq!(result[0].observations, vec!["Updated lib.rs.".to_string()]);
        assert_eq!(result[1].observations, vec!["Updated main.rs.".to_string()]);
    }

    /// Map-reduce is never selected when it is disabled in the configuration.
    #[test]
    fn test_should_use_map_reduce_disabled() {
        let config = CommitConfig { map_reduce_enabled: false, ..Default::default() };
        let counter = test_counter();
        let diff = r"diff --git a/a.rs b/a.rs
@@ -0,0 +1 @@
+a
diff --git a/b.rs b/b.rs
@@ -0,0 +1 @@
+b
diff --git a/c.rs b/c.rs
@@ -0,0 +1 @@
+c
diff --git a/d.rs b/d.rs
@@ -0,0 +1 @@
+d";
        assert!(!should_use_map_reduce(diff, &config, &counter));
    }

    /// A tiny two-file diff does not trigger map-reduce with the default config.
    #[test]
    fn test_should_use_map_reduce_few_files() {
        let config = CommitConfig::default();
        let counter = test_counter();
        let diff = r"diff --git a/a.rs b/a.rs
@@ -0,0 +1 @@
+a
diff --git a/b.rs b/b.rs
@@ -0,0 +1 @@
+b";
        assert!(!should_use_map_reduce(diff, &config, &counter));
    }

    /// Many tiny files still do not trigger map-reduce while the token total
    /// stays below the threshold.
    #[test]
    fn test_should_use_map_reduce_many_tiny_files_below_threshold() {
        let config = CommitConfig { map_reduce_threshold: 1_000, ..Default::default() };
        let counter = test_counter();
        let diff = r"diff --git a/a.rs b/a.rs
@@ -0,0 +1 @@
+a
diff --git a/b.rs b/b.rs
@@ -0,0 +1 @@
+b
diff --git a/c.rs b/c.rs
@@ -0,0 +1 @@
+c
diff --git a/d.rs b/d.rs
@@ -0,0 +1 @@
+d
diff --git a/e.rs b/e.rs
@@ -0,0 +1 @@
+e";
        assert!(!should_use_map_reduce(diff, &config, &counter));
    }

    /// A diff whose token total reaches the threshold triggers map-reduce.
    #[test]
    fn test_should_use_map_reduce_large_total_diff() {
        let config = CommitConfig { map_reduce_threshold: 20, ..Default::default() };
        let counter = test_counter();
        let payload = "a".repeat(200);
        let diff = format!("diff --git a/a.rs b/a.rs\n@@ -0,0 +1 @@\n+{payload}");

        assert!(should_use_map_reduce(&diff, &config, &counter));
    }

    /// A single file above `MAX_FILE_TOKENS` triggers map-reduce even with an
    /// unreachable total threshold.
    #[test]
    fn test_should_use_map_reduce_single_oversized_file() {
        let config = CommitConfig { map_reduce_threshold: usize::MAX, ..Default::default() };
        let counter = test_counter();
        let payload = "a".repeat((MAX_FILE_TOKENS + 1) * 4);
        let diff = format!("diff --git a/a.rs b/a.rs\n@@ -0,0 +1 @@\n+{payload}");

        assert!(should_use_map_reduce(&diff, &config, &counter));
    }

    /// The header is empty when the batch contains the only file of the change.
    #[test]
    fn test_generate_context_header_empty() {
        let files = vec![FileDiff {
            filename: "only.rs".to_string(),
            header: String::new(),
            content: String::new(),
            additions: 10,
            deletions: 5,
            is_binary: false,
        }];
        let context_headers = ContextHeaders::new(&files);
        let header = context_headers.header_for_files(&["only.rs"]);
        assert!(header.is_empty());
    }

    /// The header lists the other files of the change but not the current one.
    #[test]
    fn test_generate_context_header_multiple() {
        let files = vec![
            FileDiff {
                filename: "src/main.rs".to_string(),
                header: String::new(),
                content: "fn main() {}".to_string(),
                additions: 10,
                deletions: 5,
                is_binary: false,
            },
            FileDiff {
                filename: "src/lib.rs".to_string(),
                header: String::new(),
                content: "mod test;".to_string(),
                additions: 3,
                deletions: 1,
                is_binary: false,
            },
            FileDiff {
                filename: "tests/test.rs".to_string(),
                header: String::new(),
                content: "#[test]".to_string(),
                additions: 20,
                deletions: 0,
                is_binary: false,
            },
        ];

        let context_headers = ContextHeaders::new(&files);
        let header = context_headers.header_for_files(&["src/main.rs"]);
        assert!(header.contains("OTHER FILES IN THIS CHANGE:"));
        assert!(header.contains("src/lib.rs"));
        assert!(header.contains("tests/test.rs"));
        // The file being summarized must not be listed as context for itself.
        assert!(!header.contains("src/main.rs"));
    }

    /// Pins the description returned for each filename/content heuristic.
    #[test]
    fn test_infer_file_description() {
        assert_eq!(infer_file_description("src/test_utils.rs", ""), "test file");
        assert_eq!(infer_file_description("README.md", ""), "documentation");
        assert_eq!(infer_file_description("prompts/analysis/default.md", ""), "prompt template");
        assert_eq!(infer_file_description("system/analysis/default.md", ""), "prompt template");
        assert_eq!(infer_file_description("config.toml", ""), "configuration");
        assert_eq!(infer_file_description("src/error.rs", ""), "error definitions");
        assert_eq!(infer_file_description("src/types.rs", ""), "type definitions");
        assert_eq!(infer_file_description("src/mod.rs", ""), "module exports");
        assert_eq!(infer_file_description("src/main.rs", ""), "entry point");
        assert_eq!(infer_file_description("src/api.rs", "fn call()"), "implementation");
        assert_eq!(infer_file_description("src/models.rs", "struct Foo"), "type definitions");
        assert_eq!(infer_file_description("src/unknown.xyz", ""), "source code");
    }

    /// A stringified JSON array is parsed into its elements.
    #[test]
    fn test_parse_string_to_observations_json_array() {
        let input = r#"["item one", "item two", "item three"]"#;
        let result = parse_string_to_observations(input);
        assert_eq!(result, vec!["item one", "item two", "item three"]);
    }

    /// Dash-prefixed lines become one observation each, without the marker.
    #[test]
    fn test_parse_string_to_observations_bullet_points() {
        let input = "- added new function\n- fixed bug in parser\n- updated tests";
        let result = parse_string_to_observations(input);
        assert_eq!(result, vec!["added new function", "fixed bug in parser", "updated tests"]);
    }

    /// Asterisk-prefixed lines are handled like dash-prefixed ones.
    #[test]
    fn test_parse_string_to_observations_asterisk_bullets() {
        let input = "* first change\n* second change";
        let result = parse_string_to_observations(input);
        assert_eq!(result, vec!["first change", "second change"]);
    }

    /// Empty and whitespace-only input yields no observations.
    #[test]
    fn test_parse_string_to_observations_empty() {
        assert!(parse_string_to_observations("").is_empty());
        assert!(parse_string_to_observations(" ").is_empty());
    }

    /// A regular JSON array of strings deserializes unchanged.
    #[test]
    fn test_deserialize_observations_array() {
        let json = r#"{"path": "src/lib.rs", "observations": ["a", "b", "c"]}"#;
        let result: FileObservationEntry =
            serde_json::from_str(json).expect("valid observation array JSON should deserialize");
        assert_eq!(result.path, "src/lib.rs");
        assert_eq!(result.observations, vec!["a", "b", "c"]);
    }

    /// An array encoded as a JSON string is unpacked into its elements.
    #[test]
    fn test_deserialize_observations_stringified_array() {
        let json = r#"{"path": "src/lib.rs", "observations": "[\"a\", \"b\", \"c\"]"}"#;
        let result: FileObservationEntry = serde_json::from_str(json)
            .expect("valid stringified observation array JSON should deserialize");
        assert_eq!(result.path, "src/lib.rs");
        assert_eq!(result.observations, vec!["a", "b", "c"]);
    }

    /// A bullet-point string is split into one observation per line.
    #[test]
    fn test_deserialize_observations_bullet_string() {
        let json = r#"{"path": "src/lib.rs", "observations": "- updated function\n- fixed bug"}"#;
        let result: FileObservationEntry =
            serde_json::from_str(json).expect("valid bullet observation JSON should deserialize");
        assert_eq!(result.path, "src/lib.rs");
        assert_eq!(result.observations, vec!["updated function", "fixed bug"]);
    }
}