1use std::collections::HashMap;
2use std::sync::{Mutex, OnceLock};
3
/// Maximum sequence distance between a compressed read and a later full read
/// (or shell access) of the same file for the pair to count as a bounce.
const BOUNCE_WINDOW: u64 = 5;
/// Per-extension bounce rate at or above which reads are forced to full mode.
const BOUNCE_RATE_THRESHOLD: f64 = 0.3;
6
/// A single recorded read of one file.
#[derive(Clone, Debug)]
struct ReadEvent {
    /// Mode string the read was made with (e.g. "full", "map").
    // Not read by the tracker logic; kept only for `Debug` output, hence the allow.
    #[allow(dead_code)]
    mode: String,
    /// Tokens actually sent for this read.
    tokens_sent: usize,
    /// NOTE(review): presumably the token count of the uncompressed content.
    // Not read by the tracker logic; kept only for `Debug` output, hence the allow.
    #[allow(dead_code)]
    original_tokens: usize,
    /// Tracker sequence number at the time of the read.
    seq: u64,
    /// Whether the read used a compressed mode (anything but "full"/"diff").
    was_compressed: bool,
}
17
/// Read and bounce counters accumulated for one file extension.
#[derive(Default, Debug)]
struct BounceStats {
    /// Number of reads recorded for this extension.
    total_reads: u64,
    /// Number of bounces attributed to this extension.
    bounces: u64,
    /// Tokens sent by compressed reads of this extension that later bounced.
    wasted_tokens: usize,
}
24
25#[derive(Debug, Default)]
26pub struct BounceTracker {
27 recent_reads: HashMap<String, Vec<ReadEvent>>,
28 per_extension: HashMap<String, BounceStats>,
29 recently_edited: HashMap<String, u64>,
30 seq_counter: u64,
31 total_bounces: u64,
32 total_wasted_tokens: usize,
33}
34
/// Returns `true` for every read mode except the uncompressed `"full"` and `"diff"`.
fn is_compressed_mode(mode: &str) -> bool {
    mode != "full" && mode != "diff"
}
38
/// Returns the lowercase extension of `path` with a leading dot (e.g. `".rs"`),
/// or an empty string when the file name has no extension.
///
/// Only the final path component is examined (both `/` and `\` are treated as
/// separators), so a dot inside a directory name is never mistaken for an
/// extension. A name ending in a dot (`"foo."`) has no extension. A dotfile
/// such as `".env"` yields `".env"`.
fn extension_of(path: &str) -> String {
    // `rsplit` always yields at least one item, so this is the file name.
    let file_name = path.rsplit(|c| c == '/' || c == '\\').next().unwrap_or(path);
    match file_name.rsplit_once('.') {
        Some((_, ext)) if !ext.is_empty() => format!(".{}", ext.to_ascii_lowercase()),
        _ => String::new(),
    }
}
45
46impl BounceTracker {
47 pub fn new() -> Self {
48 Self::default()
49 }
50
51 pub fn next_seq(&mut self) -> u64 {
52 self.seq_counter += 1;
53 self.seq_counter
54 }
55
56 pub fn set_seq(&mut self, seq: u64) {
57 self.seq_counter = seq;
58 }
59
60 pub fn record_read(
61 &mut self,
62 path: &str,
63 mode: &str,
64 tokens_sent: usize,
65 original_tokens: usize,
66 ) {
67 let norm = crate::core::pathutil::normalize_tool_path(path);
68 let seq = self.seq_counter;
69 let compressed = is_compressed_mode(mode);
70
71 if !compressed {
72 self.detect_bounce(&norm, seq);
73 }
74
75 let events = self.recent_reads.entry(norm).or_default();
76 events.push(ReadEvent {
77 mode: mode.to_string(),
78 tokens_sent,
79 original_tokens,
80 seq,
81 was_compressed: compressed,
82 });
83
84 if events.len() > 10 {
85 events.drain(..events.len() - 10);
86 }
87
88 let ext = extension_of(path);
89 if !ext.is_empty() {
90 let stats = self.per_extension.entry(ext).or_default();
91 stats.total_reads += 1;
92 }
93 }
94
95 fn detect_bounce(&mut self, norm_path: &str, full_seq: u64) {
96 let Some(events) = self.recent_reads.get(norm_path) else {
97 return;
98 };
99
100 if let Some(ev) = events.iter().next_back() {
101 if ev.was_compressed && full_seq.saturating_sub(ev.seq) <= BOUNCE_WINDOW {
102 self.total_bounces += 1;
103 self.total_wasted_tokens += ev.tokens_sent;
104
105 let ext = extension_of(norm_path);
106 if !ext.is_empty() {
107 let stats = self.per_extension.entry(ext).or_default();
108 stats.bounces += 1;
109 stats.wasted_tokens += ev.tokens_sent;
110 }
111 }
112 }
113 }
114
115 pub fn record_shell_file_access(&mut self, path: &str) {
116 let norm = crate::core::pathutil::normalize_tool_path(path);
117 let seq = self.seq_counter;
118 self.detect_bounce(&norm, seq);
119 }
120
121 pub fn record_edit(&mut self, path: &str) {
122 let norm = crate::core::pathutil::normalize_tool_path(path);
123 self.recently_edited.insert(norm, self.seq_counter);
124 }
125
126 pub fn should_force_full(&self, path: &str) -> bool {
127 let norm = crate::core::pathutil::normalize_tool_path(path);
128
129 if let Some(&edit_seq) = self.recently_edited.get(&norm) {
130 if self.seq_counter.saturating_sub(edit_seq) <= 10 {
131 return true;
132 }
133 }
134
135 let ext = extension_of(path);
136 if !ext.is_empty() {
137 if let Some(stats) = self.per_extension.get(&ext) {
138 if stats.total_reads >= 3 {
139 let rate = stats.bounces as f64 / stats.total_reads as f64;
140 if rate >= BOUNCE_RATE_THRESHOLD {
141 return true;
142 }
143 }
144 }
145 }
146
147 false
148 }
149
150 pub fn bounce_rate_for_extension(&self, path: &str) -> Option<f64> {
151 let ext = extension_of(path);
152 self.per_extension.get(&ext).and_then(|s| {
153 if s.total_reads >= 3 {
154 Some(s.bounces as f64 / s.total_reads as f64)
155 } else {
156 None
157 }
158 })
159 }
160
161 pub fn total_bounces(&self) -> u64 {
162 self.total_bounces
163 }
164
165 pub fn total_wasted_tokens(&self) -> usize {
166 self.total_wasted_tokens
167 }
168
169 pub fn adjusted_savings(&self, raw_savings: usize) -> isize {
170 raw_savings as isize - self.total_wasted_tokens as isize
171 }
172
173 pub fn per_extension_json(&self) -> Vec<serde_json::Value> {
174 let mut exts: Vec<_> = self
175 .per_extension
176 .iter()
177 .filter(|(_, s)| s.total_reads > 0)
178 .collect();
179 exts.sort_by_key(|a| std::cmp::Reverse(a.1.bounces));
180 exts.iter()
181 .take(10)
182 .map(|(ext, stats)| {
183 let rate = if stats.total_reads > 0 {
184 stats.bounces as f64 / stats.total_reads as f64
185 } else {
186 0.0
187 };
188 serde_json::json!({
189 "ext": ext,
190 "reads": stats.total_reads,
191 "bounces": stats.bounces,
192 "wasted_tokens": stats.wasted_tokens,
193 "rate": (rate * 1000.0).round() / 1000.0,
194 })
195 })
196 .collect()
197 }
198
199 pub fn format_summary(&self) -> String {
200 if self.total_bounces == 0 {
201 return "Bounces: 0".to_string();
202 }
203 let mut lines = vec![format!(
204 "Bounces: {} ({} wasted tokens)",
205 self.total_bounces, self.total_wasted_tokens
206 )];
207 let mut exts: Vec<_> = self
208 .per_extension
209 .iter()
210 .filter(|(_, s)| s.bounces > 0)
211 .collect();
212 exts.sort_by_key(|a| std::cmp::Reverse(a.1.bounces));
213 for (ext, stats) in exts.iter().take(5) {
214 let rate = if stats.total_reads > 0 {
215 stats.bounces as f64 / stats.total_reads as f64 * 100.0
216 } else {
217 0.0
218 };
219 lines.push(format!(
220 " {ext}: {}/{} reads bounced ({rate:.0}%), {} tok wasted",
221 stats.bounces, stats.total_reads, stats.wasted_tokens,
222 ));
223 }
224 lines.join("\n")
225 }
226}
227
228static GLOBAL_TRACKER: OnceLock<Mutex<BounceTracker>> = OnceLock::new();
229
230pub fn global() -> &'static Mutex<BounceTracker> {
231 GLOBAL_TRACKER.get_or_init(|| Mutex::new(BounceTracker::new()))
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Moves the tracker clock to `seq`, then records a read at that point.
    fn read_at(
        tracker: &mut BounceTracker,
        seq: u64,
        path: &str,
        mode: &str,
        sent: usize,
        original: usize,
    ) {
        tracker.seq_counter = seq;
        tracker.record_read(path, mode, sent, original);
    }

    #[test]
    fn no_bounce_when_first_read_is_full() {
        let mut tracker = BounceTracker::new();
        read_at(&mut tracker, 1, "src/main.rs", "full", 500, 500);
        assert_eq!(tracker.total_bounces(), 0);
        assert_eq!(tracker.total_wasted_tokens(), 0);
    }

    #[test]
    fn bounce_detected_on_compressed_then_full() {
        let mut tracker = BounceTracker::new();
        read_at(&mut tracker, 1, "src/main.rs", "map", 50, 500);
        read_at(&mut tracker, 2, "src/main.rs", "full", 500, 500);
        assert_eq!(tracker.total_bounces(), 1);
        assert_eq!(tracker.total_wasted_tokens(), 50);
    }

    #[test]
    fn no_bounce_outside_window() {
        let mut tracker = BounceTracker::new();
        read_at(&mut tracker, 1, "src/main.rs", "map", 50, 500);
        read_at(&mut tracker, 10, "src/main.rs", "full", 500, 500);
        assert_eq!(tracker.total_bounces(), 0);
    }

    #[test]
    fn shell_access_triggers_bounce() {
        let mut tracker = BounceTracker::new();
        read_at(&mut tracker, 1, "config.yml", "signatures", 30, 400);
        tracker.seq_counter = 3;
        tracker.record_shell_file_access("config.yml");
        assert_eq!(tracker.total_bounces(), 1);
        assert_eq!(tracker.total_wasted_tokens(), 30);
    }

    #[test]
    fn should_force_full_after_edit() {
        let mut tracker = BounceTracker::new();
        tracker.seq_counter = 5;
        tracker.record_edit("src/lib.rs");

        tracker.seq_counter = 8;
        assert!(tracker.should_force_full("src/lib.rs"));

        tracker.seq_counter = 20;
        assert!(!tracker.should_force_full("src/lib.rs"));
    }

    #[test]
    fn should_force_full_by_extension_bounce_rate() {
        let mut tracker = BounceTracker::new();
        for i in 1..=6u64 {
            let path = format!("f{i}.yml");
            read_at(&mut tracker, 2 * i - 1, &path, "map", 30, 400);
            read_at(&mut tracker, 2 * i, &path, "full", 400, 400);
        }
        assert!(tracker.should_force_full("new.yml"));
    }

    #[test]
    fn adjusted_savings_subtracts_waste() {
        let mut tracker = BounceTracker::new();
        read_at(&mut tracker, 1, "a.rs", "map", 50, 500);
        read_at(&mut tracker, 2, "a.rs", "full", 500, 500);
        assert_eq!(tracker.adjusted_savings(1000), 950);
    }

    #[test]
    fn bounce_rate_for_extension_below_minimum() {
        let tracker = BounceTracker::new();
        assert!(tracker.bounce_rate_for_extension("test.rs").is_none());
    }

    #[test]
    fn format_summary_empty() {
        let tracker = BounceTracker::new();
        assert_eq!(tracker.format_summary(), "Bounces: 0");
    }
}