Skip to main content

lean_ctx/core/
bounce_tracker.rs

1use std::collections::HashMap;
2use std::sync::{Mutex, OnceLock};
3
/// Largest gap, in sequence steps, between a compressed read and a later full
/// access of the same file for the pair to be counted as a bounce.
const BOUNCE_WINDOW: u64 = 5;
/// Per-extension bounce rate at or above which full reads are forced.
const BOUNCE_RATE_THRESHOLD: f64 = 0.3;
6
/// One recorded read of a file, kept in the per-path read history.
#[derive(Clone, Debug)]
struct ReadEvent {
    /// Mode string the read was made with (e.g. `"full"`, `"map"`).
    // Never read by the tracker logic; kept for diagnostics via `Debug`.
    #[allow(dead_code)]
    mode: String,
    /// Tokens actually sent for this read; charged as waste if it bounces.
    tokens_sent: usize,
    /// Token count reported for the original (uncompressed) content.
    // Never read by the tracker logic; kept for diagnostics via `Debug`.
    #[allow(dead_code)]
    original_tokens: usize,
    /// Sequence number that was current when the read was recorded.
    seq: u64,
    /// Whether the read used a compressed mode.
    was_compressed: bool,
}
17
/// Aggregated read and bounce counters for a single file extension.
#[derive(Default, Debug)]
struct BounceStats {
    /// Reads recorded for files with this extension.
    total_reads: u64,
    /// Bounces attributed to files with this extension.
    bounces: u64,
    /// Sum of `tokens_sent` of the compressed reads that bounced.
    wasted_tokens: usize,
}
24
25#[derive(Debug, Default)]
26pub struct BounceTracker {
27    recent_reads: HashMap<String, Vec<ReadEvent>>,
28    per_extension: HashMap<String, BounceStats>,
29    recently_edited: HashMap<String, u64>,
30    seq_counter: u64,
31    total_bounces: u64,
32    total_wasted_tokens: usize,
33}
34
/// Reports whether a read `mode` delivers a reduced view of the file.
///
/// Exactly `"full"` and `"diff"` are uncompressed; any other mode string
/// (e.g. `"map"`, `"signatures"`) counts as compressed.
fn is_compressed_mode(mode: &str) -> bool {
    mode != "full" && mode != "diff"
}
38
/// Returns the lower-cased extension of the file named by `path`, including the
/// leading dot (e.g. `"src/Main.RS"` -> `".rs"`).
///
/// Only the final path component is inspected, so dots in directory names are
/// ignored. Both `/` and `\` are treated as separators. An empty string is
/// returned when the file name has no real extension:
/// - no dot at all (`Makefile`);
/// - a name that only starts with a dot (`.gitignore`);
/// - a trailing dot (`name.`).
///
/// Callers use the empty string to mean "no extension".
fn extension_of(path: &str) -> String {
    // `rsplit` always yields at least one item (the last component).
    let file_name = path.rsplit(['/', '\\']).next().unwrap_or(path);
    match file_name.rsplit_once('.') {
        Some((stem, ext)) if !stem.is_empty() && !ext.is_empty() => {
            format!(".{}", ext.to_ascii_lowercase())
        }
        _ => String::new(),
    }
}
45
46impl BounceTracker {
47    pub fn new() -> Self {
48        Self::default()
49    }
50
51    pub fn next_seq(&mut self) -> u64 {
52        self.seq_counter += 1;
53        self.seq_counter
54    }
55
56    pub fn set_seq(&mut self, seq: u64) {
57        self.seq_counter = seq;
58    }
59
60    pub fn record_read(
61        &mut self,
62        path: &str,
63        mode: &str,
64        tokens_sent: usize,
65        original_tokens: usize,
66    ) {
67        let norm = crate::core::pathutil::normalize_tool_path(path);
68        let seq = self.seq_counter;
69        let compressed = is_compressed_mode(mode);
70
71        if !compressed {
72            self.detect_bounce(&norm, seq);
73        }
74
75        let events = self.recent_reads.entry(norm).or_default();
76        events.push(ReadEvent {
77            mode: mode.to_string(),
78            tokens_sent,
79            original_tokens,
80            seq,
81            was_compressed: compressed,
82        });
83
84        if events.len() > 10 {
85            events.drain(..events.len() - 10);
86        }
87
88        let ext = extension_of(path);
89        if !ext.is_empty() {
90            let stats = self.per_extension.entry(ext).or_default();
91            stats.total_reads += 1;
92        }
93    }
94
95    fn detect_bounce(&mut self, norm_path: &str, full_seq: u64) {
96        let Some(events) = self.recent_reads.get(norm_path) else {
97            return;
98        };
99
100        if let Some(ev) = events.iter().next_back() {
101            if ev.was_compressed && full_seq.saturating_sub(ev.seq) <= BOUNCE_WINDOW {
102                self.total_bounces += 1;
103                self.total_wasted_tokens += ev.tokens_sent;
104
105                let ext = extension_of(norm_path);
106                if !ext.is_empty() {
107                    let stats = self.per_extension.entry(ext).or_default();
108                    stats.bounces += 1;
109                    stats.wasted_tokens += ev.tokens_sent;
110                }
111            }
112        }
113    }
114
115    pub fn record_shell_file_access(&mut self, path: &str) {
116        let norm = crate::core::pathutil::normalize_tool_path(path);
117        let seq = self.seq_counter;
118        self.detect_bounce(&norm, seq);
119    }
120
121    pub fn record_edit(&mut self, path: &str) {
122        let norm = crate::core::pathutil::normalize_tool_path(path);
123        self.recently_edited.insert(norm, self.seq_counter);
124    }
125
126    pub fn should_force_full(&self, path: &str) -> bool {
127        let norm = crate::core::pathutil::normalize_tool_path(path);
128
129        if let Some(&edit_seq) = self.recently_edited.get(&norm) {
130            if self.seq_counter.saturating_sub(edit_seq) <= 10 {
131                return true;
132            }
133        }
134
135        let ext = extension_of(path);
136        if !ext.is_empty() {
137            if let Some(stats) = self.per_extension.get(&ext) {
138                if stats.total_reads >= 3 {
139                    let rate = stats.bounces as f64 / stats.total_reads as f64;
140                    if rate >= BOUNCE_RATE_THRESHOLD {
141                        return true;
142                    }
143                }
144            }
145        }
146
147        false
148    }
149
150    pub fn bounce_rate_for_extension(&self, path: &str) -> Option<f64> {
151        let ext = extension_of(path);
152        self.per_extension.get(&ext).and_then(|s| {
153            if s.total_reads >= 3 {
154                Some(s.bounces as f64 / s.total_reads as f64)
155            } else {
156                None
157            }
158        })
159    }
160
161    pub fn total_bounces(&self) -> u64 {
162        self.total_bounces
163    }
164
165    pub fn total_wasted_tokens(&self) -> usize {
166        self.total_wasted_tokens
167    }
168
169    pub fn adjusted_savings(&self, raw_savings: usize) -> isize {
170        raw_savings as isize - self.total_wasted_tokens as isize
171    }
172
173    pub fn per_extension_json(&self) -> Vec<serde_json::Value> {
174        let mut exts: Vec<_> = self
175            .per_extension
176            .iter()
177            .filter(|(_, s)| s.total_reads > 0)
178            .collect();
179        exts.sort_by_key(|a| std::cmp::Reverse(a.1.bounces));
180        exts.iter()
181            .take(10)
182            .map(|(ext, stats)| {
183                let rate = if stats.total_reads > 0 {
184                    stats.bounces as f64 / stats.total_reads as f64
185                } else {
186                    0.0
187                };
188                serde_json::json!({
189                    "ext": ext,
190                    "reads": stats.total_reads,
191                    "bounces": stats.bounces,
192                    "wasted_tokens": stats.wasted_tokens,
193                    "rate": (rate * 1000.0).round() / 1000.0,
194                })
195            })
196            .collect()
197    }
198
199    pub fn format_summary(&self) -> String {
200        if self.total_bounces == 0 {
201            return "Bounces: 0".to_string();
202        }
203        let mut lines = vec![format!(
204            "Bounces: {} ({} wasted tokens)",
205            self.total_bounces, self.total_wasted_tokens
206        )];
207        let mut exts: Vec<_> = self
208            .per_extension
209            .iter()
210            .filter(|(_, s)| s.bounces > 0)
211            .collect();
212        exts.sort_by_key(|a| std::cmp::Reverse(a.1.bounces));
213        for (ext, stats) in exts.iter().take(5) {
214            let rate = if stats.total_reads > 0 {
215                stats.bounces as f64 / stats.total_reads as f64 * 100.0
216            } else {
217                0.0
218            };
219            lines.push(format!(
220                "  {ext}: {}/{} reads bounced ({rate:.0}%), {} tok wasted",
221                stats.bounces, stats.total_reads, stats.wasted_tokens,
222            ));
223        }
224        lines.join("\n")
225    }
226}
227
228static GLOBAL_TRACKER: OnceLock<Mutex<BounceTracker>> = OnceLock::new();
229
230pub fn global() -> &'static Mutex<BounceTracker> {
231    GLOBAL_TRACKER.get_or_init(|| Mutex::new(BounceTracker::new()))
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Moves the tracker clock to `seq`, then records a read of `path`.
    fn read_at(
        tracker: &mut BounceTracker,
        seq: u64,
        path: &str,
        mode: &str,
        sent: usize,
        original: usize,
    ) {
        tracker.set_seq(seq);
        tracker.record_read(path, mode, sent, original);
    }

    #[test]
    fn no_bounce_when_first_read_is_full() {
        let mut tracker = BounceTracker::new();
        read_at(&mut tracker, 1, "src/main.rs", "full", 500, 500);
        assert_eq!(tracker.total_bounces(), 0);
        assert_eq!(tracker.total_wasted_tokens(), 0);
    }

    #[test]
    fn bounce_detected_on_compressed_then_full() {
        let mut tracker = BounceTracker::new();
        read_at(&mut tracker, 1, "src/main.rs", "map", 50, 500);
        read_at(&mut tracker, 2, "src/main.rs", "full", 500, 500);
        assert_eq!(tracker.total_bounces(), 1);
        assert_eq!(tracker.total_wasted_tokens(), 50);
    }

    #[test]
    fn no_bounce_outside_window() {
        let mut tracker = BounceTracker::new();
        read_at(&mut tracker, 1, "src/main.rs", "map", 50, 500);
        read_at(&mut tracker, 10, "src/main.rs", "full", 500, 500);
        assert_eq!(tracker.total_bounces(), 0);
    }

    #[test]
    fn shell_access_triggers_bounce() {
        let mut tracker = BounceTracker::new();
        read_at(&mut tracker, 1, "config.yml", "signatures", 30, 400);
        tracker.set_seq(3);
        tracker.record_shell_file_access("config.yml");
        assert_eq!(tracker.total_bounces(), 1);
        assert_eq!(tracker.total_wasted_tokens(), 30);
    }

    #[test]
    fn should_force_full_after_edit() {
        let mut tracker = BounceTracker::new();
        tracker.set_seq(5);
        tracker.record_edit("src/lib.rs");
        tracker.set_seq(8);
        assert!(tracker.should_force_full("src/lib.rs"));
        tracker.set_seq(20);
        assert!(!tracker.should_force_full("src/lib.rs"));
    }

    #[test]
    fn should_force_full_by_extension_bounce_rate() {
        let mut tracker = BounceTracker::new();
        for i in 1..=6u64 {
            let file = format!("f{i}.yml");
            read_at(&mut tracker, 2 * i - 1, &file, "map", 30, 400);
            read_at(&mut tracker, 2 * i, &file, "full", 400, 400);
        }
        assert!(tracker.should_force_full("new.yml"));
    }

    #[test]
    fn adjusted_savings_subtracts_waste() {
        let mut tracker = BounceTracker::new();
        read_at(&mut tracker, 1, "a.rs", "map", 50, 500);
        read_at(&mut tracker, 2, "a.rs", "full", 500, 500);
        assert_eq!(tracker.adjusted_savings(1000), 950);
    }

    #[test]
    fn bounce_rate_for_extension_below_minimum() {
        let tracker = BounceTracker::new();
        assert!(tracker.bounce_rate_for_extension("test.rs").is_none());
    }

    #[test]
    fn format_summary_empty() {
        let tracker = BounceTracker::new();
        assert_eq!(tracker.format_summary(), "Bounces: 0");
    }
}