Skip to main content

lean_ctx/core/bounce_tracker.rs

1use std::collections::HashMap;
2use std::sync::{Mutex, OnceLock};
3
/// Maximum distance, in sequence numbers, between a compressed read and a later
/// full access of the same path for the pair to be counted as a bounce.
const BOUNCE_WINDOW: u64 = 5;
/// Per-extension bounce rate (bounces / reads) at or above which
/// `BounceTracker::should_force_full` returns `true` for that extension.
const BOUNCE_RATE_THRESHOLD: f64 = 0.30;
6
/// One recorded read of a file, kept in the tracker's per-path history.
#[derive(Debug, Clone)]
struct ReadEvent {
    /// Read mode string as passed to `record_read`; stored but not read back in this file.
    _mode: String,
    /// Tokens sent for this read; added to the wasted-token totals when the read bounces.
    tokens_sent: usize,
    /// Original token count as passed to `record_read`; stored but not read back in this file.
    _original_tokens: usize,
    /// Value of the tracker's sequence counter at the time the read was recorded.
    seq: u64,
    /// Set from `is_compressed_mode(mode)` when the read is recorded.
    was_compressed: bool,
}
15
/// Read and bounce counters aggregated per file extension.
#[derive(Debug, Default)]
struct BounceStats {
    /// Number of reads recorded for this extension (every mode counts).
    total_reads: u64,
    /// Number of bounces detected for this extension.
    bounces: u64,
    /// Sum of `tokens_sent` of the compressed reads that bounced.
    wasted_tokens: usize,
}
22
/// Tracks "bounces": compressed reads of a file that are followed shortly
/// afterwards by a full access of the same file.
#[derive(Debug, Default)]
pub struct BounceTracker {
    /// Recent reads keyed by normalized path; `record_read` keeps at most the last 10 per path.
    recent_reads: HashMap<String, Vec<ReadEvent>>,
    /// Aggregated counters keyed by dotted, lower-cased extension (e.g. ".rs").
    per_extension: HashMap<String, BounceStats>,
    /// Normalized path -> sequence counter value at the most recent recorded edit.
    recently_edited: HashMap<String, u64>,
    /// Current sequence number; advanced by `next_seq` or overwritten by `set_seq`.
    seq_counter: u64,
    /// Total number of bounces detected across all files.
    total_bounces: u64,
    /// Total `tokens_sent` of all compressed reads that bounced.
    total_wasted_tokens: usize,
}
32
/// Returns `true` for every read mode except the two uncompressed ones,
/// `"full"` and `"diff"`. The comparison is exact and case-sensitive.
fn is_compressed_mode(mode: &str) -> bool {
    match mode {
        "full" | "diff" => false,
        _ => true,
    }
}
36
/// Returns the dotted, ASCII-lower-cased extension of `path` (e.g. `".rs"`),
/// or an empty string when the final path component has no extension.
///
/// Only the last path component is inspected (both `/` and `\` are treated as
/// separators), so a dot inside a directory name is never mistaken for an
/// extension separator. A file name without any dot (`Makefile`) or ending in
/// a dot (`notes.`) yields `""`. Dotfiles keep their historical behaviour:
/// `.gitignore` yields `".gitignore"`.
fn extension_of(path: &str) -> String {
    // `rsplit` always yields at least one item, so the fallback is never used;
    // it only avoids an `unwrap`.
    let file_name = path
        .rsplit(|c: char| c == '/' || c == '\\')
        .next()
        .unwrap_or(path);

    match file_name.rsplit_once('.') {
        Some((_, ext)) if !ext.is_empty() => format!(".{}", ext.to_ascii_lowercase()),
        _ => String::new(),
    }
}
43
44impl BounceTracker {
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    pub fn next_seq(&mut self) -> u64 {
50        self.seq_counter += 1;
51        self.seq_counter
52    }
53
54    pub fn set_seq(&mut self, seq: u64) {
55        self.seq_counter = seq;
56    }
57
58    pub fn record_read(
59        &mut self,
60        path: &str,
61        mode: &str,
62        tokens_sent: usize,
63        original_tokens: usize,
64    ) {
65        let norm = crate::core::pathutil::normalize_tool_path(path);
66        let seq = self.seq_counter;
67        let compressed = is_compressed_mode(mode);
68
69        if !compressed {
70            self.detect_bounce(&norm, seq);
71        }
72
73        let events = self.recent_reads.entry(norm).or_default();
74        events.push(ReadEvent {
75            _mode: mode.to_string(),
76            tokens_sent,
77            _original_tokens: original_tokens,
78            seq,
79            was_compressed: compressed,
80        });
81
82        if events.len() > 10 {
83            events.drain(..events.len() - 10);
84        }
85
86        let ext = extension_of(path);
87        if !ext.is_empty() {
88            let stats = self.per_extension.entry(ext).or_default();
89            stats.total_reads += 1;
90        }
91    }
92
93    fn detect_bounce(&mut self, norm_path: &str, full_seq: u64) {
94        let Some(events) = self.recent_reads.get(norm_path) else {
95            return;
96        };
97
98        if let Some(ev) = events.iter().next_back() {
99            if ev.was_compressed && full_seq.saturating_sub(ev.seq) <= BOUNCE_WINDOW {
100                self.total_bounces += 1;
101                self.total_wasted_tokens += ev.tokens_sent;
102
103                let ext = extension_of(norm_path);
104                if !ext.is_empty() {
105                    let stats = self.per_extension.entry(ext).or_default();
106                    stats.bounces += 1;
107                    stats.wasted_tokens += ev.tokens_sent;
108                }
109            }
110        }
111    }
112
113    pub fn record_shell_file_access(&mut self, path: &str) {
114        let norm = crate::core::pathutil::normalize_tool_path(path);
115        let seq = self.seq_counter;
116        self.detect_bounce(&norm, seq);
117    }
118
119    pub fn record_edit(&mut self, path: &str) {
120        let norm = crate::core::pathutil::normalize_tool_path(path);
121        self.recently_edited.insert(norm, self.seq_counter);
122    }
123
124    pub fn should_force_full(&self, path: &str) -> bool {
125        let norm = crate::core::pathutil::normalize_tool_path(path);
126
127        if let Some(&edit_seq) = self.recently_edited.get(&norm) {
128            if self.seq_counter.saturating_sub(edit_seq) <= 10 {
129                return true;
130            }
131        }
132
133        let ext = extension_of(path);
134        if !ext.is_empty() {
135            if let Some(stats) = self.per_extension.get(&ext) {
136                if stats.total_reads >= 3 {
137                    let rate = stats.bounces as f64 / stats.total_reads as f64;
138                    if rate >= BOUNCE_RATE_THRESHOLD {
139                        return true;
140                    }
141                }
142            }
143        }
144
145        false
146    }
147
148    pub fn bounce_rate_for_extension(&self, path: &str) -> Option<f64> {
149        let ext = extension_of(path);
150        self.per_extension.get(&ext).and_then(|s| {
151            if s.total_reads >= 3 {
152                Some(s.bounces as f64 / s.total_reads as f64)
153            } else {
154                None
155            }
156        })
157    }
158
159    pub fn total_bounces(&self) -> u64 {
160        self.total_bounces
161    }
162
163    pub fn total_wasted_tokens(&self) -> usize {
164        self.total_wasted_tokens
165    }
166
167    pub fn adjusted_savings(&self, raw_savings: usize) -> isize {
168        raw_savings as isize - self.total_wasted_tokens as isize
169    }
170
171    pub fn per_extension_json(&self) -> Vec<serde_json::Value> {
172        let mut exts: Vec<_> = self
173            .per_extension
174            .iter()
175            .filter(|(_, s)| s.total_reads > 0)
176            .collect();
177        exts.sort_by_key(|a| std::cmp::Reverse(a.1.bounces));
178        exts.iter()
179            .take(10)
180            .map(|(ext, stats)| {
181                let rate = if stats.total_reads > 0 {
182                    stats.bounces as f64 / stats.total_reads as f64
183                } else {
184                    0.0
185                };
186                serde_json::json!({
187                    "ext": ext,
188                    "reads": stats.total_reads,
189                    "bounces": stats.bounces,
190                    "wasted_tokens": stats.wasted_tokens,
191                    "rate": (rate * 1000.0).round() / 1000.0,
192                })
193            })
194            .collect()
195    }
196
197    pub fn format_summary(&self) -> String {
198        if self.total_bounces == 0 {
199            return "Bounces: 0".to_string();
200        }
201        let mut lines = vec![format!(
202            "Bounces: {} ({} wasted tokens)",
203            self.total_bounces, self.total_wasted_tokens
204        )];
205        let mut exts: Vec<_> = self
206            .per_extension
207            .iter()
208            .filter(|(_, s)| s.bounces > 0)
209            .collect();
210        exts.sort_by_key(|a| std::cmp::Reverse(a.1.bounces));
211        for (ext, stats) in exts.iter().take(5) {
212            let rate = if stats.total_reads > 0 {
213                stats.bounces as f64 / stats.total_reads as f64 * 100.0
214            } else {
215                0.0
216            };
217            lines.push(format!(
218                "  {ext}: {}/{} reads bounced ({rate:.0}%), {} tok wasted",
219                stats.bounces, stats.total_reads, stats.wasted_tokens,
220            ));
221        }
222        lines.join("\n")
223    }
224}
225
226static GLOBAL_TRACKER: OnceLock<Mutex<BounceTracker>> = OnceLock::new();
227
228pub fn global() -> &'static Mutex<BounceTracker> {
229    GLOBAL_TRACKER.get_or_init(|| Mutex::new(BounceTracker::new()))
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tracker and replays `(seq, path, mode, tokens_sent, original_tokens)`
    /// reads in order, setting the sequence counter before each one.
    fn replay(reads: &[(u64, &str, &str, usize, usize)]) -> BounceTracker {
        let mut tracker = BounceTracker::new();
        for &(seq, path, mode, sent, original) in reads {
            tracker.set_seq(seq);
            tracker.record_read(path, mode, sent, original);
        }
        tracker
    }

    #[test]
    fn no_bounce_when_first_read_is_full() {
        let tracker = replay(&[(1, "src/main.rs", "full", 500, 500)]);
        assert_eq!(tracker.total_bounces(), 0);
        assert_eq!(tracker.total_wasted_tokens(), 0);
    }

    #[test]
    fn bounce_detected_on_compressed_then_full() {
        let tracker = replay(&[
            (1, "src/main.rs", "map", 50, 500),
            (2, "src/main.rs", "full", 500, 500),
        ]);
        assert_eq!(tracker.total_bounces(), 1);
        assert_eq!(tracker.total_wasted_tokens(), 50);
    }

    #[test]
    fn no_bounce_outside_window() {
        let tracker = replay(&[
            (1, "src/main.rs", "map", 50, 500),
            (10, "src/main.rs", "full", 500, 500),
        ]);
        assert_eq!(tracker.total_bounces(), 0);
    }

    #[test]
    fn shell_access_triggers_bounce() {
        let mut tracker = replay(&[(1, "config.yml", "signatures", 30, 400)]);
        tracker.set_seq(3);
        tracker.record_shell_file_access("config.yml");
        assert_eq!(tracker.total_bounces(), 1);
        assert_eq!(tracker.total_wasted_tokens(), 30);
    }

    #[test]
    fn should_force_full_after_edit() {
        let mut tracker = BounceTracker::new();
        tracker.set_seq(5);
        tracker.record_edit("src/lib.rs");

        // Three steps after the edit: still inside the force-full window.
        tracker.set_seq(8);
        assert!(tracker.should_force_full("src/lib.rs"));

        // Fifteen steps after the edit: the window has passed.
        tracker.set_seq(20);
        assert!(!tracker.should_force_full("src/lib.rs"));
    }

    #[test]
    fn should_force_full_by_extension_bounce_rate() {
        let mut tracker = BounceTracker::new();
        for i in 1..=6u64 {
            let file = format!("f{i}.yml");
            tracker.set_seq(i * 2 - 1);
            tracker.record_read(&file, "map", 30, 400);
            tracker.set_seq(i * 2);
            tracker.record_read(&file, "full", 400, 400);
        }
        assert!(tracker.should_force_full("new.yml"));
    }

    #[test]
    fn adjusted_savings_subtracts_waste() {
        let tracker = replay(&[(1, "a.rs", "map", 50, 500), (2, "a.rs", "full", 500, 500)]);
        assert_eq!(tracker.adjusted_savings(1000), 950);
    }

    #[test]
    fn bounce_rate_for_extension_below_minimum() {
        let tracker = BounceTracker::new();
        assert!(tracker.bounce_rate_for_extension("test.rs").is_none());
    }

    #[test]
    fn format_summary_empty() {
        let tracker = BounceTracker::new();
        assert_eq!(tracker.format_summary(), "Bounces: 0");
    }
}
321}