Skip to main content

lean_ctx/core/
bounce_tracker.rs

1use std::collections::HashMap;
2use std::sync::{Mutex, OnceLock};
3
/// Maximum distance, in tracker sequence steps, between a compressed read and a later
/// uncompressed access of the same path for the pair to be counted as a bounce.
const BOUNCE_WINDOW: u64 = 5;
/// Per-extension bounce rate (bounces / total reads) at or above which
/// `BounceTracker::should_force_full` returns `true` for paths with that extension.
const BOUNCE_RATE_THRESHOLD: f64 = 0.30;
6
/// One recorded read of a file, kept per normalised path in `BounceTracker::recent_reads`.
#[derive(Debug, Clone)]
struct ReadEvent {
    /// Mode string the read was recorded with; stored but not read back in this file.
    _mode: String,
    /// Tokens sent for this read; this is the amount reported as wasted if the read bounces.
    tokens_sent: usize,
    /// Token count of the original content; stored but not read back in this file.
    _original_tokens: usize,
    /// Value of the tracker's sequence counter when the read was recorded.
    seq: u64,
    /// Set from `is_compressed_mode(mode)` when the read is recorded; consulted by
    /// bounce detection.
    was_compressed: bool,
}
15
/// Aggregated read/bounce counters for a single file extension.
#[derive(Debug, Default)]
struct BounceStats {
    /// Number of reads recorded for the extension (incremented by `record_read`).
    total_reads: u64,
    /// Number of bounces detected for the extension.
    bounces: u64,
    /// Sum of `tokens_sent` of the compressed reads that bounced.
    wasted_tokens: usize,
}
22
/// Tracks "bounces": a compressed read of a file that is followed, within
/// `BOUNCE_WINDOW` sequence steps, by an uncompressed access of the same file.
#[derive(Debug, Default)]
pub struct BounceTracker {
    /// Most recent read events, keyed by normalised path.
    recent_reads: HashMap<String, Vec<ReadEvent>>,
    /// Read/bounce counters keyed by lowercase extension (including the leading dot).
    per_extension: HashMap<String, BounceStats>,
    /// Sequence number at which each normalised path was last reported as edited.
    recently_edited: HashMap<String, u64>,
    /// Current sequence number; advanced by `next_seq` or overwritten by `set_seq`.
    seq_counter: u64,
    /// Total number of bounces counted by this tracker.
    total_bounces: u64,
    /// Total tokens sent by compressed reads that were counted as bounces.
    total_wasted_tokens: usize,
    /// When true (only for the process-global tracker), detected bounces are appended to
    /// the persistent savings ledger so a fresh `gain` process sees historical bounce.
    /// Local trackers in unit tests leave this `false` to avoid touching the real ledger.
    persist: bool,
}
36
/// Returns `true` for every read mode except the two uncompressed ones, `"full"` and
/// `"diff"`. The comparison is exact and case-sensitive.
fn is_compressed_mode(mode: &str) -> bool {
    match mode {
        "full" | "diff" => false,
        _ => true,
    }
}
40
/// Returns the lowercase extension of `path` including the leading dot (e.g. `".rs"`),
/// or an empty string when the file name has no extension.
///
/// Only the final path component is inspected, so a dot in a directory name
/// (`"src.d/Makefile"`) is never mistaken for an extension. Both `/` and `\` are
/// treated as separators. A dot-file such as `".gitignore"` yields `".gitignore"`;
/// a trailing dot (`"name."`) yields an empty string.
fn extension_of(path: &str) -> String {
    // `rsplit` always yields at least one item, so the fallback is purely defensive.
    let file_name = path.rsplit(['/', '\\']).next().unwrap_or(path);
    match file_name.rsplit_once('.') {
        Some((_, ext)) if !ext.is_empty() => format!(".{}", ext.to_ascii_lowercase()),
        _ => String::new(),
    }
}
47
48impl BounceTracker {
49    pub fn new() -> Self {
50        Self::default()
51    }
52
53    pub fn next_seq(&mut self) -> u64 {
54        self.seq_counter += 1;
55        self.seq_counter
56    }
57
58    pub fn set_seq(&mut self, seq: u64) {
59        self.seq_counter = seq;
60    }
61
62    pub fn record_read(
63        &mut self,
64        path: &str,
65        mode: &str,
66        tokens_sent: usize,
67        original_tokens: usize,
68    ) {
69        let norm = crate::core::pathutil::normalize_tool_path(path);
70        let seq = self.seq_counter;
71        let compressed = is_compressed_mode(mode);
72
73        if !compressed {
74            self.detect_bounce(&norm, seq);
75        }
76
77        let events = self.recent_reads.entry(norm).or_default();
78        events.push(ReadEvent {
79            _mode: mode.to_string(),
80            tokens_sent,
81            _original_tokens: original_tokens,
82            seq,
83            was_compressed: compressed,
84        });
85
86        if events.len() > 10 {
87            events.drain(..events.len() - 10);
88        }
89
90        let ext = extension_of(path);
91        if !ext.is_empty() {
92            let stats = self.per_extension.entry(ext).or_default();
93            stats.total_reads += 1;
94        }
95    }
96
97    fn detect_bounce(&mut self, norm_path: &str, full_seq: u64) {
98        let Some(events) = self.recent_reads.get(norm_path) else {
99            return;
100        };
101
102        if let Some(ev) = events.iter().next_back() {
103            if ev.was_compressed && full_seq.saturating_sub(ev.seq) <= BOUNCE_WINDOW {
104                let wasted = ev.tokens_sent;
105                self.total_bounces += 1;
106                self.total_wasted_tokens += wasted;
107
108                let ext = extension_of(norm_path);
109                if !ext.is_empty() {
110                    let stats = self.per_extension.entry(ext).or_default();
111                    stats.bounces += 1;
112                    stats.wasted_tokens += wasted;
113                }
114
115                if self.persist {
116                    crate::core::savings_ledger::record_bounce_event(wasted);
117                }
118            }
119        }
120    }
121
122    pub fn record_shell_file_access(&mut self, path: &str) {
123        let norm = crate::core::pathutil::normalize_tool_path(path);
124        let seq = self.seq_counter;
125        self.detect_bounce(&norm, seq);
126    }
127
128    pub fn record_edit(&mut self, path: &str) {
129        let norm = crate::core::pathutil::normalize_tool_path(path);
130        self.recently_edited.insert(norm, self.seq_counter);
131    }
132
133    pub fn should_force_full(&self, path: &str) -> bool {
134        let norm = crate::core::pathutil::normalize_tool_path(path);
135
136        if let Some(&edit_seq) = self.recently_edited.get(&norm) {
137            if self.seq_counter.saturating_sub(edit_seq) <= 10 {
138                return true;
139            }
140        }
141
142        let ext = extension_of(path);
143        if !ext.is_empty() {
144            if let Some(stats) = self.per_extension.get(&ext) {
145                if stats.total_reads >= 3 {
146                    let rate = stats.bounces as f64 / stats.total_reads as f64;
147                    if rate >= BOUNCE_RATE_THRESHOLD {
148                        return true;
149                    }
150                }
151            }
152        }
153
154        false
155    }
156
157    pub fn bounce_rate_for_extension(&self, path: &str) -> Option<f64> {
158        let ext = extension_of(path);
159        self.per_extension.get(&ext).and_then(|s| {
160            if s.total_reads >= 3 {
161                Some(s.bounces as f64 / s.total_reads as f64)
162            } else {
163                None
164            }
165        })
166    }
167
168    pub fn total_bounces(&self) -> u64 {
169        self.total_bounces
170    }
171
172    pub fn total_wasted_tokens(&self) -> usize {
173        self.total_wasted_tokens
174    }
175
176    pub fn adjusted_savings(&self, raw_savings: usize) -> isize {
177        raw_savings as isize - self.total_wasted_tokens as isize
178    }
179
180    pub fn per_extension_json(&self) -> Vec<serde_json::Value> {
181        let mut exts: Vec<_> = self
182            .per_extension
183            .iter()
184            .filter(|(_, s)| s.total_reads > 0)
185            .collect();
186        exts.sort_by_key(|a| std::cmp::Reverse(a.1.bounces));
187        exts.iter()
188            .take(10)
189            .map(|(ext, stats)| {
190                let rate = if stats.total_reads > 0 {
191                    stats.bounces as f64 / stats.total_reads as f64
192                } else {
193                    0.0
194                };
195                serde_json::json!({
196                    "ext": ext,
197                    "reads": stats.total_reads,
198                    "bounces": stats.bounces,
199                    "wasted_tokens": stats.wasted_tokens,
200                    "rate": (rate * 1000.0).round() / 1000.0,
201                })
202            })
203            .collect()
204    }
205
206    pub fn format_summary(&self) -> String {
207        if self.total_bounces == 0 {
208            return "Bounces: 0".to_string();
209        }
210        let mut lines = vec![format!(
211            "Bounces: {} ({} wasted tokens)",
212            self.total_bounces, self.total_wasted_tokens
213        )];
214        let mut exts: Vec<_> = self
215            .per_extension
216            .iter()
217            .filter(|(_, s)| s.bounces > 0)
218            .collect();
219        exts.sort_by_key(|a| std::cmp::Reverse(a.1.bounces));
220        for (ext, stats) in exts.iter().take(5) {
221            let rate = if stats.total_reads > 0 {
222                stats.bounces as f64 / stats.total_reads as f64 * 100.0
223            } else {
224                0.0
225            };
226            lines.push(format!(
227                "  {ext}: {}/{} reads bounced ({rate:.0}%), {} tok wasted",
228                stats.bounces, stats.total_reads, stats.wasted_tokens,
229            ));
230        }
231        lines.join("\n")
232    }
233}
234
235static GLOBAL_TRACKER: OnceLock<Mutex<BounceTracker>> = OnceLock::new();
236
237pub fn global() -> &'static Mutex<BounceTracker> {
238    GLOBAL_TRACKER.get_or_init(|| {
239        // Seed from the persistent ledger so every process (including a fresh `gain`)
240        // accounts for historical bounce, then mark this tracker as the persisting one.
241        let summary = crate::core::savings_ledger::summary();
242        let mut bt = BounceTracker::new();
243        bt.total_wasted_tokens = summary.bounce_tokens as usize;
244        bt.total_bounces = summary.bounce_events as u64;
245        bt.persist = true;
246        Mutex::new(bt)
247    })
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the tracker's sequence counter to `seq`, then records a read.
    fn read_at(
        tracker: &mut BounceTracker,
        seq: u64,
        path: &str,
        mode: &str,
        sent: usize,
        original: usize,
    ) {
        tracker.set_seq(seq);
        tracker.record_read(path, mode, sent, original);
    }

    #[test]
    fn no_bounce_when_first_read_is_full() {
        let mut tracker = BounceTracker::new();
        read_at(&mut tracker, 1, "src/main.rs", "full", 500, 500);
        assert_eq!(tracker.total_bounces(), 0);
        assert_eq!(tracker.total_wasted_tokens(), 0);
    }

    #[test]
    fn bounce_detected_on_compressed_then_full() {
        let mut tracker = BounceTracker::new();
        read_at(&mut tracker, 1, "src/main.rs", "map", 50, 500);
        read_at(&mut tracker, 2, "src/main.rs", "full", 500, 500);
        assert_eq!(tracker.total_bounces(), 1);
        assert_eq!(tracker.total_wasted_tokens(), 50);
    }

    #[test]
    fn no_bounce_outside_window() {
        let mut tracker = BounceTracker::new();
        read_at(&mut tracker, 1, "src/main.rs", "map", 50, 500);
        read_at(&mut tracker, 10, "src/main.rs", "full", 500, 500);
        assert_eq!(tracker.total_bounces(), 0);
    }

    #[test]
    fn shell_access_triggers_bounce() {
        let mut tracker = BounceTracker::new();
        read_at(&mut tracker, 1, "config.yml", "signatures", 30, 400);
        tracker.set_seq(3);
        tracker.record_shell_file_access("config.yml");
        assert_eq!(tracker.total_bounces(), 1);
        assert_eq!(tracker.total_wasted_tokens(), 30);
    }

    #[test]
    fn should_force_full_after_edit() {
        let mut tracker = BounceTracker::new();
        tracker.set_seq(5);
        tracker.record_edit("src/lib.rs");
        tracker.set_seq(8);
        assert!(tracker.should_force_full("src/lib.rs"));
        tracker.set_seq(20);
        assert!(!tracker.should_force_full("src/lib.rs"));
    }

    #[test]
    fn should_force_full_by_extension_bounce_rate() {
        let mut tracker = BounceTracker::new();
        for i in 1..=6u64 {
            let file = format!("f{i}.yml");
            read_at(&mut tracker, i * 2 - 1, &file, "map", 30, 400);
            read_at(&mut tracker, i * 2, &file, "full", 400, 400);
        }
        assert!(tracker.should_force_full("new.yml"));
    }

    #[test]
    fn adjusted_savings_subtracts_waste() {
        let mut tracker = BounceTracker::new();
        read_at(&mut tracker, 1, "a.rs", "map", 50, 500);
        read_at(&mut tracker, 2, "a.rs", "full", 500, 500);
        assert_eq!(tracker.adjusted_savings(1000), 950);
    }

    #[test]
    fn bounce_rate_for_extension_below_minimum() {
        let tracker = BounceTracker::new();
        assert!(tracker.bounce_rate_for_extension("test.rs").is_none());
    }

    #[test]
    fn format_summary_empty() {
        let tracker = BounceTracker::new();
        assert_eq!(tracker.format_summary(), "Bounces: 0");
    }
}
339}