1use std::collections::HashMap;
2use std::sync::{Mutex, OnceLock};
3
/// Maximum number of sequence steps between a compressed read and a later
/// full/diff read or shell access of the same file for it to count as a bounce.
const BOUNCE_WINDOW: u64 = 5;

/// Per-extension bounce rate (bounces / reads) at or above which
/// `BounceTracker::should_force_full` returns `true`.
const BOUNCE_RATE_THRESHOLD: f64 = 0.3;
6
/// One recorded read of a file, kept per normalized path in
/// `BounceTracker::recent_reads`.
#[derive(Debug, Clone)]
struct ReadEvent {
    /// Read mode string as passed to `record_read` (stored, not read in this file).
    _mode: String,
    /// Tokens actually sent for this read; this is the amount charged as
    /// wasted if the read is later counted as a bounce.
    tokens_sent: usize,
    /// Token count passed as `original_tokens` (stored, not read in this file).
    _original_tokens: usize,
    /// Sequence number at which the read was recorded.
    seq: u64,
    /// Set from `is_compressed_mode(mode)` when recorded; only events with
    /// this flag set can be charged as a bounce.
    was_compressed: bool,
}
15
/// Aggregated bounce counters for one file extension.
#[derive(Debug, Default)]
struct BounceStats {
    /// Reads recorded via `record_read` for files with this extension.
    total_reads: u64,
    /// Bounces attributed to this extension.
    bounces: u64,
    /// Sum of `tokens_sent` of the compressed reads that were counted as bounces.
    wasted_tokens: usize,
}
22
/// Tracks "bounces": a compressed read of a file followed, within
/// `BOUNCE_WINDOW` sequence steps, by a full/diff read or a shell access of
/// the same file (presumably meaning the compressed view was insufficient).
///
/// Sequence numbers are supplied by the caller via `next_seq` / `set_seq`.
#[derive(Debug, Default)]
pub struct BounceTracker {
    /// Recent read events per normalized path (only the last 10 are kept).
    recent_reads: HashMap<String, Vec<ReadEvent>>,
    /// Counters keyed by lowercase extension including the leading dot.
    per_extension: HashMap<String, BounceStats>,
    /// Sequence number of the last recorded edit per normalized path.
    recently_edited: HashMap<String, u64>,
    /// Current sequence number stamped onto recorded reads and edits.
    seq_counter: u64,
    /// Total bounces across all files.
    total_bounces: u64,
    /// Total `tokens_sent` of all compressed reads that were counted as bounces.
    total_wasted_tokens: usize,
}
32
/// Returns `true` for every read mode except `"full"` and `"diff"`.
///
/// Unknown or empty mode strings are therefore treated as compressed.
fn is_compressed_mode(mode: &str) -> bool {
    mode != "full" && mode != "diff"
}
36
/// Returns the lowercase extension of the file name in `path`, including the
/// leading dot (e.g. `"src/Main.RS"` -> `".rs"`), or an empty string when the
/// file name has no extension.
///
/// Only the final path component is inspected (both `/` and `\` count as
/// separators), so dots in directory names are ignored. Dotfiles such as
/// `.gitignore`, names ending in a dot, and extension-less names like
/// `Makefile` all yield `""`, which callers use to skip per-extension stats.
fn extension_of(path: &str) -> String {
    // `rsplit` always yields at least one item; the fallback is only for clarity.
    let file_name = path
        .rsplit(|c: char| c == '/' || c == '\\')
        .next()
        .unwrap_or(path);
    match file_name.rsplit_once('.') {
        // A non-empty stem excludes dotfiles; a non-empty ext excludes "name.".
        Some((stem, ext)) if !stem.is_empty() && !ext.is_empty() => {
            format!(".{}", ext.to_ascii_lowercase())
        }
        _ => String::new(),
    }
}
43
44impl BounceTracker {
45 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn next_seq(&mut self) -> u64 {
50 self.seq_counter += 1;
51 self.seq_counter
52 }
53
54 pub fn set_seq(&mut self, seq: u64) {
55 self.seq_counter = seq;
56 }
57
58 pub fn record_read(
59 &mut self,
60 path: &str,
61 mode: &str,
62 tokens_sent: usize,
63 original_tokens: usize,
64 ) {
65 let norm = crate::core::pathutil::normalize_tool_path(path);
66 let seq = self.seq_counter;
67 let compressed = is_compressed_mode(mode);
68
69 if !compressed {
70 self.detect_bounce(&norm, seq);
71 }
72
73 let events = self.recent_reads.entry(norm).or_default();
74 events.push(ReadEvent {
75 _mode: mode.to_string(),
76 tokens_sent,
77 _original_tokens: original_tokens,
78 seq,
79 was_compressed: compressed,
80 });
81
82 if events.len() > 10 {
83 events.drain(..events.len() - 10);
84 }
85
86 let ext = extension_of(path);
87 if !ext.is_empty() {
88 let stats = self.per_extension.entry(ext).or_default();
89 stats.total_reads += 1;
90 }
91 }
92
93 fn detect_bounce(&mut self, norm_path: &str, full_seq: u64) {
94 let Some(events) = self.recent_reads.get(norm_path) else {
95 return;
96 };
97
98 if let Some(ev) = events.iter().next_back() {
99 if ev.was_compressed && full_seq.saturating_sub(ev.seq) <= BOUNCE_WINDOW {
100 self.total_bounces += 1;
101 self.total_wasted_tokens += ev.tokens_sent;
102
103 let ext = extension_of(norm_path);
104 if !ext.is_empty() {
105 let stats = self.per_extension.entry(ext).or_default();
106 stats.bounces += 1;
107 stats.wasted_tokens += ev.tokens_sent;
108 }
109 }
110 }
111 }
112
113 pub fn record_shell_file_access(&mut self, path: &str) {
114 let norm = crate::core::pathutil::normalize_tool_path(path);
115 let seq = self.seq_counter;
116 self.detect_bounce(&norm, seq);
117 }
118
119 pub fn record_edit(&mut self, path: &str) {
120 let norm = crate::core::pathutil::normalize_tool_path(path);
121 self.recently_edited.insert(norm, self.seq_counter);
122 }
123
124 pub fn should_force_full(&self, path: &str) -> bool {
125 let norm = crate::core::pathutil::normalize_tool_path(path);
126
127 if let Some(&edit_seq) = self.recently_edited.get(&norm) {
128 if self.seq_counter.saturating_sub(edit_seq) <= 10 {
129 return true;
130 }
131 }
132
133 let ext = extension_of(path);
134 if !ext.is_empty() {
135 if let Some(stats) = self.per_extension.get(&ext) {
136 if stats.total_reads >= 3 {
137 let rate = stats.bounces as f64 / stats.total_reads as f64;
138 if rate >= BOUNCE_RATE_THRESHOLD {
139 return true;
140 }
141 }
142 }
143 }
144
145 false
146 }
147
148 pub fn bounce_rate_for_extension(&self, path: &str) -> Option<f64> {
149 let ext = extension_of(path);
150 self.per_extension.get(&ext).and_then(|s| {
151 if s.total_reads >= 3 {
152 Some(s.bounces as f64 / s.total_reads as f64)
153 } else {
154 None
155 }
156 })
157 }
158
159 pub fn total_bounces(&self) -> u64 {
160 self.total_bounces
161 }
162
163 pub fn total_wasted_tokens(&self) -> usize {
164 self.total_wasted_tokens
165 }
166
167 pub fn adjusted_savings(&self, raw_savings: usize) -> isize {
168 raw_savings as isize - self.total_wasted_tokens as isize
169 }
170
171 pub fn per_extension_json(&self) -> Vec<serde_json::Value> {
172 let mut exts: Vec<_> = self
173 .per_extension
174 .iter()
175 .filter(|(_, s)| s.total_reads > 0)
176 .collect();
177 exts.sort_by_key(|a| std::cmp::Reverse(a.1.bounces));
178 exts.iter()
179 .take(10)
180 .map(|(ext, stats)| {
181 let rate = if stats.total_reads > 0 {
182 stats.bounces as f64 / stats.total_reads as f64
183 } else {
184 0.0
185 };
186 serde_json::json!({
187 "ext": ext,
188 "reads": stats.total_reads,
189 "bounces": stats.bounces,
190 "wasted_tokens": stats.wasted_tokens,
191 "rate": (rate * 1000.0).round() / 1000.0,
192 })
193 })
194 .collect()
195 }
196
197 pub fn format_summary(&self) -> String {
198 if self.total_bounces == 0 {
199 return "Bounces: 0".to_string();
200 }
201 let mut lines = vec![format!(
202 "Bounces: {} ({} wasted tokens)",
203 self.total_bounces, self.total_wasted_tokens
204 )];
205 let mut exts: Vec<_> = self
206 .per_extension
207 .iter()
208 .filter(|(_, s)| s.bounces > 0)
209 .collect();
210 exts.sort_by_key(|a| std::cmp::Reverse(a.1.bounces));
211 for (ext, stats) in exts.iter().take(5) {
212 let rate = if stats.total_reads > 0 {
213 stats.bounces as f64 / stats.total_reads as f64 * 100.0
214 } else {
215 0.0
216 };
217 lines.push(format!(
218 " {ext}: {}/{} reads bounced ({rate:.0}%), {} tok wasted",
219 stats.bounces, stats.total_reads, stats.wasted_tokens,
220 ));
221 }
222 lines.join("\n")
223 }
224}
225
226static GLOBAL_TRACKER: OnceLock<Mutex<BounceTracker>> = OnceLock::new();
227
228pub fn global() -> &'static Mutex<BounceTracker> {
229 GLOBAL_TRACKER.get_or_init(|| Mutex::new(BounceTracker::new()))
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty tracker whose sequence counter is already at `seq`.
    fn tracker_at(seq: u64) -> BounceTracker {
        let mut tracker = BounceTracker::default();
        tracker.set_seq(seq);
        tracker
    }

    #[test]
    fn no_bounce_when_first_read_is_full() {
        let mut tracker = tracker_at(1);
        tracker.record_read("src/main.rs", "full", 500, 500);

        assert_eq!(tracker.total_bounces(), 0);
        assert_eq!(tracker.total_wasted_tokens(), 0);
    }

    #[test]
    fn bounce_detected_on_compressed_then_full() {
        let mut tracker = tracker_at(1);
        tracker.record_read("src/main.rs", "map", 50, 500);
        tracker.set_seq(2);
        tracker.record_read("src/main.rs", "full", 500, 500);

        assert_eq!(tracker.total_bounces(), 1);
        assert_eq!(tracker.total_wasted_tokens(), 50);
    }

    #[test]
    fn no_bounce_outside_window() {
        let mut tracker = tracker_at(1);
        tracker.record_read("src/main.rs", "map", 50, 500);
        // 9 steps later: beyond BOUNCE_WINDOW.
        tracker.set_seq(10);
        tracker.record_read("src/main.rs", "full", 500, 500);

        assert_eq!(tracker.total_bounces(), 0);
    }

    #[test]
    fn shell_access_triggers_bounce() {
        let mut tracker = tracker_at(1);
        tracker.record_read("config.yml", "signatures", 30, 400);
        tracker.set_seq(3);
        tracker.record_shell_file_access("config.yml");

        assert_eq!(tracker.total_bounces(), 1);
        assert_eq!(tracker.total_wasted_tokens(), 30);
    }

    #[test]
    fn should_force_full_after_edit() {
        let mut tracker = tracker_at(5);
        tracker.record_edit("src/lib.rs");

        tracker.set_seq(8);
        assert!(tracker.should_force_full("src/lib.rs"));

        tracker.set_seq(20);
        assert!(!tracker.should_force_full("src/lib.rs"));
    }

    #[test]
    fn should_force_full_by_extension_bounce_rate() {
        let mut tracker = BounceTracker::default();
        for i in 1..=6u64 {
            let file = format!("f{i}.yml");
            tracker.set_seq(2 * i - 1);
            tracker.record_read(&file, "map", 30, 400);
            tracker.set_seq(2 * i);
            tracker.record_read(&file, "full", 400, 400);
        }

        assert!(tracker.should_force_full("new.yml"));
    }

    #[test]
    fn adjusted_savings_subtracts_waste() {
        let mut tracker = tracker_at(1);
        tracker.record_read("a.rs", "map", 50, 500);
        tracker.set_seq(2);
        tracker.record_read("a.rs", "full", 500, 500);

        assert_eq!(tracker.adjusted_savings(1000), 950);
    }

    #[test]
    fn bounce_rate_for_extension_below_minimum() {
        let tracker = BounceTracker::default();
        assert!(tracker.bounce_rate_for_extension("test.rs").is_none());
    }

    #[test]
    fn format_summary_empty() {
        let tracker = BounceTracker::default();
        assert_eq!(tracker.format_summary(), "Bounces: 0");
    }
}