Skip to main content

ralph_workflow/phases/commit/
diff_truncation.rs

/// Maximum safe prompt size in bytes before pre-truncation.
///
/// This is the fallback budget: `model_budget_bytes_for_agent_name` returns it
/// for agent names that match no known model family, and
/// `effective_model_budget_bytes` returns it when given no agent names at all.
pub const MAX_SAFE_PROMPT_SIZE: u64 = 200_000;
3
4use itertools::Itertools;
5
/// Maximum prompt size for GLM-like agents (GLM, Zhipu, Qwen, `DeepSeek`).
///
/// Expressed in bytes. `model_budget_bytes_for_agent_name` selects it when the
/// lowercased agent name contains `glm`, `zhipuai`, `zai`, `qwen`, or
/// `deepseek`.
pub const GLM_MAX_PROMPT_SIZE: u64 = 100_000;
8
/// Maximum prompt size for Claude-based agents.
///
/// Expressed in bytes. `model_budget_bytes_for_agent_name` selects it when the
/// lowercased agent name contains `claude`, `ccs`, or `anthropic` and none of
/// the GLM-family markers.
pub const CLAUDE_MAX_PROMPT_SIZE: u64 = 300_000;
11
12/// Get the maximum safe prompt size for a specific agent.
13#[must_use]
14pub fn model_budget_bytes_for_agent_name(commit_agent: &str) -> u64 {
15    let agent_lower = commit_agent.to_lowercase();
16
17    if agent_lower.contains("glm")
18        || agent_lower.contains("zhipuai")
19        || agent_lower.contains("zai")
20        || agent_lower.contains("qwen")
21        || agent_lower.contains("deepseek")
22    {
23        GLM_MAX_PROMPT_SIZE
24    } else if agent_lower.contains("claude")
25        || agent_lower.contains("ccs")
26        || agent_lower.contains("anthropic")
27    {
28        CLAUDE_MAX_PROMPT_SIZE
29    } else {
30        MAX_SAFE_PROMPT_SIZE
31    }
32}
33
34#[must_use]
35pub fn effective_model_budget_bytes(agent_names: &[String]) -> u64 {
36    agent_names
37        .iter()
38        .map(|name| model_budget_bytes_for_agent_name(name))
39        .min()
40        .unwrap_or(MAX_SAFE_PROMPT_SIZE)
41}
42
43/// Truncate diff if it's too large for agents with small context windows.
44pub fn truncate_diff_if_large(diff: &str, max_size: usize) -> String {
45    if diff.len() <= max_size {
46        return diff.to_string();
47    }
48
49    #[derive(Default)]
50    struct Accumulator {
51        files: Vec<DiffFile>,
52        current_file: Option<DiffFile>,
53    }
54
55    impl Accumulator {
56        fn process_line(self, line: &str) -> Self {
57            if line.starts_with("diff --git ") {
58                let priority = line
59                    .split_once(" b/")
60                    .map(|(_, after)| prioritize_file_path(after))
61                    .unwrap_or_default();
62                let new_file = DiffFile {
63                    priority,
64                    lines: vec![line.to_string()],
65                };
66                // Finalize the previous file and start a new one.
67                let finalized_files: Vec<DiffFile> = self
68                    .files
69                    .into_iter()
70                    .chain(
71                        self.current_file
72                            .into_iter()
73                            .filter(|f| !f.lines.is_empty()),
74                    )
75                    .collect();
76                Self {
77                    files: finalized_files,
78                    current_file: Some(new_file),
79                }
80            } else {
81                match self.current_file {
82                    Some(current) => Self {
83                        files: self.files,
84                        current_file: Some(DiffFile {
85                            priority: current.priority,
86                            lines: current
87                                .lines
88                                .into_iter()
89                                .chain(std::iter::once(line.to_string()))
90                                .collect(),
91                        }),
92                    },
93                    None => self,
94                }
95            }
96        }
97    }
98
99    let final_acc = diff
100        .lines()
101        .fold(Accumulator::default(), |acc, line| acc.process_line(line));
102
103    let sorted_files: Vec<_> = {
104        let Accumulator {
105            files,
106            current_file,
107        } = final_acc;
108        files
109            .into_iter()
110            .chain(current_file.into_iter().filter(|f| !f.lines.is_empty()))
111            .sorted_by_key(|f| -f.priority)
112            .collect()
113    };
114
115    let total_files = sorted_files.len();
116
117    let (result, files_included) = {
118        let file_data: Vec<_> = sorted_files
119            .iter()
120            .map(|file| {
121                let lines_text: String = file.lines.iter().map(|l| format!("{l}\n")).collect();
122                (lines_text.clone(), lines_text.len())
123            })
124            .collect();
125
126        let total_chunks_len: usize = file_data.iter().map(|(_, len)| *len).sum();
127        if total_chunks_len <= max_size {
128            let result = file_data
129                .iter()
130                .map(|(t, _)| t.clone())
131                .collect::<Vec<_>>()
132                .join("");
133            return result;
134        }
135
136        let cumulative_sizes: Vec<_> = file_data
137            .iter()
138            .scan(0usize, |acc, (_, len)| {
139                *acc += len;
140                Some(*acc)
141            })
142            .collect();
143
144        let last_fitting_index = cumulative_sizes
145            .iter()
146            .position(|&size| size > max_size)
147            .unwrap_or(file_data.len());
148
149        if last_fitting_index == 0 {
150            let truncated = truncate_lines_to_fit(&sorted_files[0].lines, max_size);
151            let truncated_text: String = truncated.iter().map(|l| format!("{l}\n")).collect();
152            (truncated_text, 1)
153        } else {
154            let included: Vec<_> = file_data
155                .iter()
156                .take(last_fitting_index)
157                .map(|(t, _)| t.clone())
158                .collect();
159            (included.join(""), included.len())
160        }
161    };
162
163    let was_truncated = result.len() < diff.len();
164
165    if files_included < total_files || was_truncated {
166        let summary = format!("\n[Truncated: {files_included} of {total_files} files shown]\n");
167        if summary.len() <= max_size {
168            if result.len() + summary.len() > max_size {
169                let target_bytes = max_size.saturating_sub(summary.len());
170                if target_bytes < result.len() {
171                    let cut = result
172                        .char_indices()
173                        .take_while(|(idx, _)| *idx <= target_bytes)
174                        .last()
175                        .map(|(idx, _)| idx)
176                        .unwrap_or(0);
177                    return format!("{}{}", &result[..cut], summary);
178                }
179            }
180            return format!("{result}{summary}");
181        }
182    }
183
184    result
185}
186
187#[must_use]
188pub fn truncate_diff_to_model_budget(diff: &str, max_size_bytes: u64) -> (String, bool) {
189    let max_size = usize::try_from(max_size_bytes).unwrap_or(usize::MAX);
190    if diff.len() <= max_size {
191        (diff.to_string(), false)
192    } else {
193        (truncate_diff_if_large(diff, max_size), true)
194    }
195}
196
/// One file's section of a diff, as collected by `truncate_diff_if_large`.
#[derive(Default)]
struct DiffFile {
    /// Score from `prioritize_file_path`; sections with higher scores are
    /// emitted first.
    priority: i32,
    /// The `diff --git` header line followed by the section's remaining lines,
    /// stored without line terminators.
    lines: Vec<String>,
}
202
/// Score a file path by how important its diff is to keep.
///
/// Backslashes are treated as path separators. A path with a `src` component
/// scores 100, otherwise one with a `tests` component scores 50, otherwise a
/// path whose extension is `md` or `txt` (ASCII case-insensitive) scores 10,
/// and anything else scores 0.
fn prioritize_file_path(path: &str) -> i32 {
    let unix_style = path.replace('\\', "/");
    let has_component = |wanted: &str| unix_style.split('/').any(|segment| segment == wanted);

    if has_component("src") {
        return 100;
    }
    if has_component("tests") {
        return 50;
    }

    let is_text_document = std::path::Path::new(&unix_style)
        .extension()
        .is_some_and(|ext| ext.eq_ignore_ascii_case("md") || ext.eq_ignore_ascii_case("txt"));

    if is_text_document {
        10
    } else {
        0
    }
}
220
/// Return the longest prefix of `s` that is at most `max_bytes` long and ends
/// on a UTF-8 character boundary.
///
/// Strings that already fit are returned whole.
fn truncate_to_utf8_boundary(s: &str, max_bytes: usize) -> String {
    let keep = if s.len() <= max_bytes {
        s.len()
    } else {
        // Walk back from the limit to the nearest character boundary; index 0
        // always qualifies, so the search cannot fail.
        (0..=max_bytes)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0)
    };
    s[..keep].to_string()
}
233
234pub fn truncate_lines_to_fit(lines: &[String], max_size: usize) -> Vec<String> {
235    let suffix = " [truncated...]";
236    let suffix_len = suffix.len();
237
238    if lines.is_empty() {
239        return Vec::new();
240    }
241
242    let line_sizes: Vec<usize> = lines.iter().map(|l| l.len() + 1).collect();
243    let total_size: usize = line_sizes.iter().sum();
244
245    if total_size <= max_size {
246        return lines.to_vec();
247    }
248
249    let available_for_lines = max_size.saturating_sub(suffix_len);
250
251    let result: Vec<_> = lines
252        .iter()
253        .scan(0usize, |size, line| {
254            let line_size = line.len() + 1;
255            if *size + line_size <= available_for_lines {
256                *size += line_size;
257                Some(line.clone())
258            } else {
259                None
260            }
261        })
262        .collect();
263
264    if result.is_empty() {
265        return result;
266    }
267
268    let current_size: usize = result.iter().map(|l| l.len() + 1).sum();
269
270    let adjusted: Vec<String> = if current_size + suffix_len > max_size {
271        let target_bytes = max_size.saturating_sub(suffix_len);
272        // Scan in reverse to collect lines that fit within target_bytes.
273        // The scan state is (accumulated_so_far, still_accepting).
274        // When a line doesn't fit and nothing has been accumulated yet,
275        // we truncate that line to the UTF-8 boundary and stop.
276        result
277            .iter()
278            .rev()
279            .scan((0usize, true), |(accumulated, accepting), line| {
280                if !*accepting {
281                    return None;
282                }
283                let line_size = line.len() + 1;
284                if *accumulated + line_size <= target_bytes {
285                    *accumulated += line_size;
286                    Some(Some(line.clone()))
287                } else if *accumulated == 0 {
288                    // First line already too big: truncate it.
289                    *accepting = false;
290                    let max_for_line = target_bytes.saturating_sub(1);
291                    let new_line = truncate_to_utf8_boundary(line, max_for_line);
292                    if new_line.is_empty() {
293                        Some(None)
294                    } else {
295                        Some(Some(new_line))
296                    }
297                } else {
298                    *accepting = false;
299                    None
300                }
301            })
302            .flatten()
303            .collect::<Vec<_>>()
304            .into_iter()
305            .rev()
306            .collect()
307    } else {
308        result
309    };
310
311    if adjusted.is_empty() {
312        adjusted
313    } else {
314        let last_idx = adjusted.len() - 1;
315        let (init, last) = adjusted.split_at(last_idx);
316        let last_formatted = format!("{}{}", last[0], suffix);
317        init.iter()
318            .map(String::clone)
319            .chain(std::iter::once(last_formatted))
320            .collect()
321    }
322}
323
// Unit tests for the diff truncation helpers. Most of them check the core
// invariant that truncated output never exceeds the requested byte budget.
#[cfg(test)]
mod diff_truncation_tests {
    use super::*;

    // `src` and `tests` are recognised as path components even when the path
    // starts with a crate directory; documentation files score by extension.
    #[test]
    fn prioritize_file_path_handles_crate_prefixed_paths() {
        assert_eq!(prioritize_file_path("ralph-workflow/src/lib.rs"), 100);
        assert_eq!(
            prioritize_file_path("ralph-workflow/tests/integration.rs"),
            50
        );
        assert_eq!(prioritize_file_path("README.md"), 10);
    }

    // The first file is sized so that it fits on its own but exceeds the
    // budget by exactly one byte once the summary is appended.
    #[test]
    fn truncate_diff_to_model_budget_never_exceeds_max_size() {
        let files_included = 1;
        let total_files = 2;
        let summary = format!("\n[Truncated: {files_included} of {total_files} files shown]\n");

        let max_size = 1_000usize;

        let file1_header = "diff --git a/src/a.rs b/src/a.rs";
        let desired_file1_size = max_size - summary.len() + 1;
        // Subtract the header plus the "\n+" that separates it from the filler.
        let filler_line_len = desired_file1_size.saturating_sub(file1_header.len() + 2);
        let file1 = format!(
            "{file1_header}\n+{}\n",
            "x".repeat(filler_line_len.saturating_sub(1))
        );

        let file2 = "diff --git a/tests/b.rs b/tests/b.rs\n+small\n";
        let diff = format!("{file1}{file2}");

        let (truncated, was_truncated) = truncate_diff_to_model_budget(&diff, max_size as u64);
        assert!(
            was_truncated,
            "expected truncation when diff exceeds max size"
        );
        assert!(
            truncated.len() <= max_size,
            "truncated diff must not exceed max_size (got {} > {})",
            truncated.len(),
            max_size
        );
    }

    // A single line that exactly fills the budget (line plus newline) must not
    // grow past it.
    #[test]
    fn truncate_lines_to_fit_reserves_space_for_truncation_suffix() {
        let max_size = 20usize;
        let lines = vec!["x".repeat(max_size - 1)];

        let truncated = truncate_lines_to_fit(&lines, max_size);

        let total_size: usize = truncated.iter().map(|l| l.len() + 1).sum();
        assert!(
            total_size <= max_size,
            "truncate_lines_to_fit must not exceed max_size after adding suffix (got {total_size} > {max_size})"
        );
    }

    // Budgets are chosen around the summary length, where the summary either
    // does not fit at all or leaves little or no room for content.
    #[test]
    fn truncate_diff_invariant_never_exceeds_max_size_edge_cases() {
        let summary_len = "\n[Truncated: 1 of 2 files shown]\n".len();

        [
            10,
            summary_len - 1,
            summary_len,
            summary_len + 1,
            summary_len + 10,
            100,
            1000,
        ]
        .into_iter()
        .for_each(|max_size| {
            let file1 = format!(
                "diff --git a/src/a.rs b/src/a.rs\n+{}\n",
                "x".repeat(max_size)
            );
            let file2 = "diff --git a/tests/b.rs b/tests/b.rs\n+extra\n";
            let diff = format!("{file1}{file2}");

            let (truncated, _) = truncate_diff_to_model_budget(&diff, max_size as u64);
            assert!(
                truncated.len() <= max_size,
                "truncated diff exceeded max_size {} (got {}): {:?}",
                max_size,
                truncated.len(),
                &truncated[..truncated.len().min(100)]
            );
        });
    }

    // A diff of exactly `max_size` bytes passes through untouched; one byte
    // more triggers truncation.
    #[test]
    fn truncate_diff_boundary_content_sizes() {
        [50usize, 100, 200, 500].into_iter().for_each(|max_size| {
            let header = "diff --git a/a b/a\n+";
            let exact_diff = format!(
                "{}{}",
                header,
                "x".repeat(max_size.saturating_sub(header.len()))
            );
            if exact_diff.len() == max_size {
                let (result, was_truncated) =
                    truncate_diff_to_model_budget(&exact_diff, max_size as u64);
                assert!(!was_truncated, "exact size should not trigger truncation");
                assert_eq!(result.len(), max_size);
            }

            let over_diff = format!("{}{}", header, "x".repeat(max_size + 1 - header.len()));
            let (result, was_truncated) =
                truncate_diff_to_model_budget(&over_diff, max_size as u64);
            assert!(was_truncated, "over size should trigger truncation");
            assert!(
                result.len() <= max_size,
                "truncated result {} should not exceed max_size {}",
                result.len(),
                max_size
            );
        });
    }

    // With only one file and a body three times the budget, truncation has to
    // happen inside that file.
    #[test]
    fn truncate_single_large_file_stays_within_budget() {
        let max_size = 100usize;

        let large_file = format!(
            "diff --git a/src/big.rs b/src/big.rs\n+{}\n",
            "x".repeat(max_size * 3)
        );

        let (truncated, was_truncated) =
            truncate_diff_to_model_budget(&large_file, max_size as u64);
        assert!(was_truncated, "large file should be truncated");
        assert!(
            truncated.len() <= max_size,
            "single large file truncation {} exceeded max_size {}",
            truncated.len(),
            max_size
        );
    }

    // Multi-byte characters must not break the size invariant or the validity
    // of the output.
    #[test]
    fn truncate_diff_handles_unicode_boundaries() {
        let max_size = 50usize;

        let emoji_line = "🎉".repeat(20);
        let diff = format!("diff --git a/a b/a\n+{emoji_line}\n");

        let (truncated, was_truncated) = truncate_diff_to_model_budget(&diff, max_size as u64);
        assert!(was_truncated, "unicode diff should be truncated");
        assert!(
            truncated.len() <= max_size,
            "unicode truncation {} exceeded max_size {}",
            truncated.len(),
            max_size
        );
        assert!(
            std::str::from_utf8(truncated.as_bytes()).is_ok(),
            "truncated output should be valid UTF-8"
        );
    }

    // A file section consisting only of its header line survives when the
    // whole diff fits within the budget.
    #[test]
    fn truncate_diff_preserves_header_only_file() {
        let diff = "diff --git a/src/existing.rs b/src/existing.rs\n+line\n".to_string();
        let header_only = "diff --git a/src/header_only.rs b/src/header_only.rs\n";
        let combined = format!("{diff}{header_only}");

        let (result, was_truncated) = truncate_diff_to_model_budget(&combined, 512);
        assert!(!was_truncated, "diff should fit within budget");
        assert!(
            result.contains(header_only),
            "header-only file should still be present"
        );
    }

    // An empty diff is returned as-is and is not reported as truncated.
    #[test]
    fn truncate_empty_diff() {
        let (result, was_truncated) = truncate_diff_to_model_budget("", 100);
        assert!(!was_truncated, "empty diff should not be truncated");
        assert_eq!(result, "");
    }

    // Files are supplied in low-to-high priority order, so the output only
    // favours `src`/`tests` content if sections are reordered by priority.
    #[test]
    fn truncate_multiple_small_files_prefers_high_priority() {
        let max_size = 200usize;

        let src_file = "diff --git a/src/main.rs b/src/main.rs\n+high priority\n";
        let test_file = "diff --git a/tests/test.rs b/tests/test.rs\n+medium priority\n";
        let doc_file = "diff --git a/README.md b/README.md\n+low priority docs\n";
        let extra = "diff --git a/extra.rs b/extra.rs\n+extra content that exceeds budget\n";

        let diff = format!("{doc_file}{test_file}{src_file}{extra}");

        let (truncated, was_truncated) = truncate_diff_to_model_budget(&diff, max_size as u64);
        assert!(was_truncated, "should truncate when files exceed budget");
        assert!(
            truncated.len() <= max_size,
            "truncated {} exceeded max_size {}",
            truncated.len(),
            max_size
        );
        if truncated.contains("priority") {
            assert!(
                truncated.contains("high priority") || truncated.contains("medium priority"),
                "should prioritize src/tests over docs"
            );
        }
    }
}