1use std::collections::HashMap;
2use std::sync::{Mutex, OnceLock};
3
/// Maximum distance, in sequence steps (inclusive), between a compressed read
/// and a later full read or shell access of the same file for the pair to be
/// counted as a bounce.
const BOUNCE_WINDOW: u64 = 5;
/// Per-extension bounce rate (bounces / recorded reads) at or above which
/// `BounceTracker::should_force_full` requests a full read.
const BOUNCE_RATE_THRESHOLD: f64 = 0.30;
6
/// One read of a file, as stored by `BounceTracker::record_read`.
#[derive(Debug, Clone)]
struct ReadEvent {
    /// Mode string the caller passed for this read (e.g. `"full"`, `"map"`).
    // Nothing in this file reads it back; presumably kept for `Debug` output.
    #[allow(dead_code)]
    mode: String,
    /// Tokens actually sent for this read; charged as waste if it bounces.
    tokens_sent: usize,
    /// Token count the caller reported for the uncompressed file.
    // Nothing in this file reads it back; presumably kept for `Debug` output.
    #[allow(dead_code)]
    original_tokens: usize,
    /// Tracker sequence number at the time of the read.
    seq: u64,
    /// Set when the read used a compressed mode; bounce detection only
    /// considers events with this flag set.
    was_compressed: bool,
}
17
/// Read and bounce counters for a single file extension.
#[derive(Debug, Default)]
struct BounceStats {
    /// Reads recorded for files with this extension (compressed and full).
    total_reads: u64,
    /// Bounces attributed to this extension.
    bounces: u64,
    /// Sum of `tokens_sent` of the compressed reads that bounced.
    wasted_tokens: usize,
}
24
/// Detects "bounces": a compressed read of a file followed, within
/// `BOUNCE_WINDOW` sequence steps, by a full read or shell access of the same
/// file — a sign the compressed view was insufficient and its tokens wasted.
///
/// Time is a logical sequence counter. The recording methods do not advance
/// it; callers move it with `next_seq` / `set_seq`.
#[derive(Debug, Default)]
pub struct BounceTracker {
    /// Most recent read events per normalized path, newest last.
    recent_reads: HashMap<String, Vec<ReadEvent>>,
    /// Counters keyed by lowercase extension with a leading dot (e.g. `.rs`).
    per_extension: HashMap<String, BounceStats>,
    /// Normalized path -> sequence number of its most recent edit.
    recently_edited: HashMap<String, u64>,
    /// Current logical time.
    seq_counter: u64,
    /// Bounces across all files.
    total_bounces: u64,
    /// Tokens wasted on bounced compressed reads across all files.
    total_wasted_tokens: usize,
}
34
/// Reports whether `mode` sends a reduced view of the file.
///
/// Only the exact strings `"full"` and `"diff"` are uncompressed; every other
/// mode string, including unknown ones, counts as compressed.
fn is_compressed_mode(mode: &str) -> bool {
    const UNCOMPRESSED_MODES: [&str; 2] = ["full", "diff"];
    !UNCOMPRESSED_MODES.contains(&mode)
}
38
/// Returns the lowercase extension of the final path component, with a
/// leading dot (`"src/Main.RS"` -> `".rs"`).
///
/// Both `/` and `\` are treated as separators. An empty string is returned
/// when the file name has no `.` or ends with one, so extensionless files
/// (`Makefile`, `LICENSE`) and dotted directory names do not produce bogus
/// buckets; callers skip empty extensions. A dot-prefixed name such as
/// `.gitignore` yields `".gitignore"`.
fn extension_of(path: &str) -> String {
    let file_name = path
        .rsplit_once(|c: char| c == '/' || c == '\\')
        .map_or(path, |(_, name)| name);
    match file_name.rsplit_once('.') {
        Some((_, ext)) if !ext.is_empty() => format!(".{}", ext.to_ascii_lowercase()),
        _ => String::new(),
    }
}
45
46impl BounceTracker {
47 pub fn new() -> Self {
48 Self::default()
49 }
50
51 pub fn next_seq(&mut self) -> u64 {
52 self.seq_counter += 1;
53 self.seq_counter
54 }
55
56 pub fn set_seq(&mut self, seq: u64) {
57 self.seq_counter = seq;
58 }
59
60 pub fn record_read(
61 &mut self,
62 path: &str,
63 mode: &str,
64 tokens_sent: usize,
65 original_tokens: usize,
66 ) {
67 let norm = crate::core::pathutil::normalize_tool_path(path);
68 let seq = self.seq_counter;
69 let compressed = is_compressed_mode(mode);
70
71 if !compressed {
72 self.detect_bounce(&norm, seq);
73 }
74
75 let events = self.recent_reads.entry(norm).or_default();
76 events.push(ReadEvent {
77 mode: mode.to_string(),
78 tokens_sent,
79 original_tokens,
80 seq,
81 was_compressed: compressed,
82 });
83
84 if events.len() > 10 {
85 events.drain(..events.len() - 10);
86 }
87
88 let ext = extension_of(path);
89 if !ext.is_empty() {
90 let stats = self.per_extension.entry(ext).or_default();
91 stats.total_reads += 1;
92 }
93 }
94
95 fn detect_bounce(&mut self, norm_path: &str, full_seq: u64) {
96 let Some(events) = self.recent_reads.get(norm_path) else {
97 return;
98 };
99
100 if let Some(ev) = events.iter().next_back() {
101 if ev.was_compressed && full_seq.saturating_sub(ev.seq) <= BOUNCE_WINDOW {
102 self.total_bounces += 1;
103 self.total_wasted_tokens += ev.tokens_sent;
104
105 let ext = extension_of(norm_path);
106 if !ext.is_empty() {
107 let stats = self.per_extension.entry(ext).or_default();
108 stats.bounces += 1;
109 stats.wasted_tokens += ev.tokens_sent;
110 }
111 }
112 }
113 }
114
115 pub fn record_shell_file_access(&mut self, path: &str) {
116 let norm = crate::core::pathutil::normalize_tool_path(path);
117 let seq = self.seq_counter;
118 self.detect_bounce(&norm, seq);
119 }
120
121 pub fn record_edit(&mut self, path: &str) {
122 let norm = crate::core::pathutil::normalize_tool_path(path);
123 self.recently_edited.insert(norm, self.seq_counter);
124 }
125
126 pub fn should_force_full(&self, path: &str) -> bool {
127 let norm = crate::core::pathutil::normalize_tool_path(path);
128
129 if let Some(&edit_seq) = self.recently_edited.get(&norm) {
130 if self.seq_counter.saturating_sub(edit_seq) <= 10 {
131 return true;
132 }
133 }
134
135 let ext = extension_of(path);
136 if !ext.is_empty() {
137 if let Some(stats) = self.per_extension.get(&ext) {
138 if stats.total_reads >= 3 {
139 let rate = stats.bounces as f64 / stats.total_reads as f64;
140 if rate >= BOUNCE_RATE_THRESHOLD {
141 return true;
142 }
143 }
144 }
145 }
146
147 false
148 }
149
150 pub fn bounce_rate_for_extension(&self, path: &str) -> Option<f64> {
151 let ext = extension_of(path);
152 self.per_extension.get(&ext).and_then(|s| {
153 if s.total_reads >= 3 {
154 Some(s.bounces as f64 / s.total_reads as f64)
155 } else {
156 None
157 }
158 })
159 }
160
161 pub fn total_bounces(&self) -> u64 {
162 self.total_bounces
163 }
164
165 pub fn total_wasted_tokens(&self) -> usize {
166 self.total_wasted_tokens
167 }
168
169 pub fn adjusted_savings(&self, raw_savings: usize) -> isize {
170 raw_savings as isize - self.total_wasted_tokens as isize
171 }
172
173 pub fn format_summary(&self) -> String {
174 if self.total_bounces == 0 {
175 return "Bounces: 0".to_string();
176 }
177 let mut lines = vec![format!(
178 "Bounces: {} ({} wasted tokens)",
179 self.total_bounces, self.total_wasted_tokens
180 )];
181 let mut exts: Vec<_> = self
182 .per_extension
183 .iter()
184 .filter(|(_, s)| s.bounces > 0)
185 .collect();
186 exts.sort_by_key(|a| std::cmp::Reverse(a.1.bounces));
187 for (ext, stats) in exts.iter().take(5) {
188 let rate = if stats.total_reads > 0 {
189 stats.bounces as f64 / stats.total_reads as f64 * 100.0
190 } else {
191 0.0
192 };
193 lines.push(format!(
194 " {ext}: {}/{} reads bounced ({rate:.0}%), {} tok wasted",
195 stats.bounces, stats.total_reads, stats.wasted_tokens,
196 ));
197 }
198 lines.join("\n")
199 }
200}
201
202static GLOBAL_TRACKER: OnceLock<Mutex<BounceTracker>> = OnceLock::new();
203
204pub fn global() -> &'static Mutex<BounceTracker> {
205 GLOBAL_TRACKER.get_or_init(|| Mutex::new(BounceTracker::new()))
206}
207
// Unit tests for `BounceTracker`. Most tests assign `seq_counter` directly to
// simulate sequence steps passing between calls.
#[cfg(test)]
mod tests {
    use super::*;

    // A full read with no earlier compressed read is not a bounce.
    #[test]
    fn no_bounce_when_first_read_is_full() {
        let mut bt = BounceTracker::new();
        bt.seq_counter = 1;
        bt.record_read("src/main.rs", "full", 500, 500);
        assert_eq!(bt.total_bounces(), 0);
        assert_eq!(bt.total_wasted_tokens(), 0);
    }

    // Compressed read (50 tokens) at seq 1, full read at seq 2: one bounce
    // wasting the 50 compressed tokens.
    #[test]
    fn bounce_detected_on_compressed_then_full() {
        let mut bt = BounceTracker::new();
        bt.seq_counter = 1;
        bt.record_read("src/main.rs", "map", 50, 500);
        bt.seq_counter = 2;
        bt.record_read("src/main.rs", "full", 500, 500);
        assert_eq!(bt.total_bounces(), 1);
        assert_eq!(bt.total_wasted_tokens(), 50);
    }

    // A 9-step gap exceeds BOUNCE_WINDOW, so no bounce is counted.
    #[test]
    fn no_bounce_outside_window() {
        let mut bt = BounceTracker::new();
        bt.seq_counter = 1;
        bt.record_read("src/main.rs", "map", 50, 500);
        bt.seq_counter = 10;
        bt.record_read("src/main.rs", "full", 500, 500);
        assert_eq!(bt.total_bounces(), 0);
    }

    // A shell access two steps after a compressed read counts as a bounce.
    #[test]
    fn shell_access_triggers_bounce() {
        let mut bt = BounceTracker::new();
        bt.seq_counter = 1;
        bt.record_read("config.yml", "signatures", 30, 400);
        bt.seq_counter = 3;
        bt.record_shell_file_access("config.yml");
        assert_eq!(bt.total_bounces(), 1);
        assert_eq!(bt.total_wasted_tokens(), 30);
    }

    // An edit forces full reads 3 steps later but not 15 steps later.
    #[test]
    fn should_force_full_after_edit() {
        let mut bt = BounceTracker::new();
        bt.seq_counter = 5;
        bt.record_edit("src/lib.rs");
        bt.seq_counter = 8;
        assert!(bt.should_force_full("src/lib.rs"));
        bt.seq_counter = 20;
        assert!(!bt.should_force_full("src/lib.rs"));
    }

    // Six .yml files each bounce once (6 bounces / 12 reads = 50%), so an
    // unseen .yml file is forced to full.
    #[test]
    fn should_force_full_by_extension_bounce_rate() {
        let mut bt = BounceTracker::new();
        for i in 1..=6 {
            bt.seq_counter = i * 2 - 1;
            bt.record_read(&format!("f{i}.yml"), "map", 30, 400);
            bt.seq_counter = i * 2;
            bt.record_read(&format!("f{i}.yml"), "full", 400, 400);
        }
        assert!(bt.should_force_full("new.yml"));
    }

    // 1000 raw savings minus 50 wasted tokens.
    #[test]
    fn adjusted_savings_subtracts_waste() {
        let mut bt = BounceTracker::new();
        bt.seq_counter = 1;
        bt.record_read("a.rs", "map", 50, 500);
        bt.seq_counter = 2;
        bt.record_read("a.rs", "full", 500, 500);
        assert_eq!(bt.adjusted_savings(1000), 950);
    }

    // With no recorded reads there is no rate to report.
    #[test]
    fn bounce_rate_for_extension_below_minimum() {
        let bt = BounceTracker::new();
        assert!(bt.bounce_rate_for_extension("test.rs").is_none());
    }

    // A fresh tracker summarizes to the zero-bounce line only.
    #[test]
    fn format_summary_empty() {
        let bt = BounceTracker::new();
        assert_eq!(bt.format_summary(), "Bounces: 0");
    }
}