Skip to main content

lean_ctx/core/
feedback.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::Instant;
4
5use serde::{Deserialize, Serialize};
6
/// Minimum number of seconds between disk writes triggered by
/// `FeedbackStore::record_outcome`; outcomes recorded in between only update the
/// in-memory buffer.
const FEEDBACK_FLUSH_SECS: u64 = 60;
8
9static FEEDBACK_BUFFER: Mutex<Option<(FeedbackStore, Instant)>> = Mutex::new(None);
10
11/// Feedback loop for learning optimal compression parameters.
12///
13/// Tracks compression outcomes per session and learns which
14/// threshold combinations lead to fewer turns and higher success rates.
15
16#[derive(Debug, Clone, Serialize, Deserialize, Default)]
17pub struct CompressionOutcome {
18    pub session_id: String,
19    pub language: String,
20    pub entropy_threshold: f64,
21    pub jaccard_threshold: f64,
22    pub total_turns: u32,
23    pub tokens_saved: u64,
24    pub tokens_original: u64,
25    pub cache_hits: u32,
26    pub total_reads: u32,
27    pub task_completed: bool,
28    pub timestamp: String,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize, Default)]
32pub struct FeedbackStore {
33    pub outcomes: Vec<CompressionOutcome>,
34    pub learned_thresholds: HashMap<String, LearnedThresholds>,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct LearnedThresholds {
39    pub entropy: f64,
40    pub jaccard: f64,
41    pub sample_count: u32,
42    pub avg_efficiency: f64,
43}
44
45impl FeedbackStore {
46    pub fn load() -> Self {
47        let guard = FEEDBACK_BUFFER
48            .lock()
49            .unwrap_or_else(std::sync::PoisonError::into_inner);
50        if let Some((ref store, _)) = *guard {
51            return store.clone();
52        }
53        drop(guard);
54
55        let path = feedback_path();
56        if path.exists() {
57            if let Ok(content) = std::fs::read_to_string(&path) {
58                if let Ok(store) = serde_json::from_str::<FeedbackStore>(&content) {
59                    return store;
60                }
61            }
62        }
63        Self::default()
64    }
65
66    fn save_to_disk(&self) {
67        let path = feedback_path();
68        if let Some(parent) = path.parent() {
69            let _ = std::fs::create_dir_all(parent);
70        }
71        if let Ok(json) = serde_json::to_string_pretty(self) {
72            let _ = std::fs::write(path, json);
73        }
74    }
75
76    pub fn save(&self) {
77        self.save_to_disk();
78    }
79
80    pub fn flush() {
81        let guard = FEEDBACK_BUFFER
82            .lock()
83            .unwrap_or_else(std::sync::PoisonError::into_inner);
84        if let Some((ref store, _)) = *guard {
85            store.save_to_disk();
86        }
87    }
88
89    pub fn record_outcome(&mut self, outcome: CompressionOutcome) {
90        let lang = outcome.language.clone();
91        self.outcomes.push(outcome);
92
93        if self.outcomes.len() > 200 {
94            self.outcomes.drain(0..self.outcomes.len() - 200);
95        }
96
97        self.update_learned_thresholds(&lang);
98
99        let mut guard = FEEDBACK_BUFFER
100            .lock()
101            .unwrap_or_else(std::sync::PoisonError::into_inner);
102        let should_flush = match *guard {
103            Some((_, ref last)) => last.elapsed().as_secs() >= FEEDBACK_FLUSH_SECS,
104            None => true,
105        };
106        *guard = Some((
107            self.clone(),
108            guard.as_ref().map_or_else(Instant::now, |(_, t)| *t),
109        ));
110        if should_flush {
111            self.save_to_disk();
112            if let Some((_, ref mut t)) = *guard {
113                *t = Instant::now();
114            }
115        }
116    }
117
118    fn update_learned_thresholds(&mut self, language: &str) {
119        let relevant: Vec<&CompressionOutcome> = self
120            .outcomes
121            .iter()
122            .filter(|o| o.language == language && o.task_completed)
123            .collect();
124
125        if relevant.len() < 5 {
126            return; // not enough data to learn
127        }
128
129        // Find the threshold combination that maximizes efficiency
130        // Efficiency = tokens_saved / tokens_original * (1 / total_turns)
131        let mut best_entropy = 1.0;
132        let mut best_jaccard = 0.7;
133        let mut best_efficiency = 0.0;
134
135        for outcome in &relevant {
136            let compression_ratio = if outcome.tokens_original > 0 {
137                outcome.tokens_saved as f64 / outcome.tokens_original as f64
138            } else {
139                0.0
140            };
141            let turn_efficiency = 1.0 / (outcome.total_turns.max(1) as f64);
142            let efficiency = compression_ratio * 0.6 + turn_efficiency * 0.4;
143
144            if efficiency > best_efficiency {
145                best_efficiency = efficiency;
146                best_entropy = outcome.entropy_threshold;
147                best_jaccard = outcome.jaccard_threshold;
148            }
149        }
150
151        // Weighted average with current learned values for stability
152        let entry = self
153            .learned_thresholds
154            .entry(language.to_string())
155            .or_insert(LearnedThresholds {
156                entropy: best_entropy,
157                jaccard: best_jaccard,
158                sample_count: 0,
159                avg_efficiency: 0.0,
160            });
161
162        let momentum = 0.7;
163        let old_entropy = entry.entropy;
164        let old_jaccard = entry.jaccard;
165        entry.entropy = entry.entropy * momentum + best_entropy * (1.0 - momentum);
166        entry.jaccard = entry.jaccard * momentum + best_jaccard * (1.0 - momentum);
167        entry.sample_count = relevant.len() as u32;
168        entry.avg_efficiency = best_efficiency;
169
170        if (old_entropy - entry.entropy).abs() > 0.01 || (old_jaccard - entry.jaccard).abs() > 0.01
171        {
172            crate::core::events::emit(crate::core::events::EventKind::ThresholdShift {
173                language: language.to_string(),
174                old_entropy,
175                new_entropy: entry.entropy,
176                old_jaccard,
177                new_jaccard: entry.jaccard,
178            });
179        }
180    }
181
182    pub fn get_learned_entropy(&self, language: &str) -> Option<f64> {
183        self.learned_thresholds.get(language).map(|t| t.entropy)
184    }
185
186    pub fn get_learned_jaccard(&self, language: &str) -> Option<f64> {
187        self.learned_thresholds.get(language).map(|t| t.jaccard)
188    }
189
190    pub fn format_report(&self) -> String {
191        let mut lines = vec![String::from("Feedback Loop Report")];
192        lines.push(format!("Total outcomes tracked: {}", self.outcomes.len()));
193        lines.push(String::new());
194
195        if self.learned_thresholds.is_empty() {
196            lines.push(
197                "No learned thresholds yet (need 5+ completed sessions per language).".to_string(),
198            );
199        } else {
200            lines.push("Learned Thresholds:".to_string());
201            for (lang, t) in &self.learned_thresholds {
202                lines.push(format!(
203                    "  {lang}: entropy={:.2} jaccard={:.2} (n={}, eff={:.1}%)",
204                    t.entropy,
205                    t.jaccard,
206                    t.sample_count,
207                    t.avg_efficiency * 100.0
208                ));
209            }
210        }
211
212        lines.join("\n")
213    }
214}
215
216fn feedback_path() -> std::path::PathBuf {
217    crate::core::data_dir::lean_ctx_data_dir()
218        .unwrap_or_else(|_| std::path::PathBuf::from("."))
219        .join("feedback.json")
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a Rust outcome with fixed thresholds and token counts.
    fn rs_outcome(i: usize, task_completed: bool) -> CompressionOutcome {
        CompressionOutcome {
            session_id: format!("s{i}"),
            language: "rs".to_string(),
            entropy_threshold: 0.85,
            jaccard_threshold: 0.72,
            total_turns: 5,
            tokens_saved: 1000,
            tokens_original: 2000,
            cache_hits: 3,
            total_reads: 10,
            task_completed,
            timestamp: String::new(),
        }
    }

    #[test]
    fn default_store_is_empty() {
        let store = FeedbackStore::default();
        assert!(store.outcomes.is_empty());
        assert!(store.learned_thresholds.is_empty());
    }

    // These tests push outcomes and call the learner directly instead of going through
    // `record_outcome`. That method writes `feedback.json` into the real data directory
    // and mutates the process-global buffer.

    #[test]
    fn learned_thresholds_need_minimum_samples() {
        let mut store = FeedbackStore::default();
        for i in 0..3 {
            store.outcomes.push(rs_outcome(i, true));
            store.update_learned_thresholds("rs");
        }
        assert!(store.get_learned_entropy("rs").is_none()); // only 3, need 5
    }

    #[test]
    fn learned_thresholds_after_minimum_samples() {
        let mut store = FeedbackStore::default();
        for i in 0..5 {
            store.outcomes.push(rs_outcome(i, true));
        }
        store.update_learned_thresholds("rs");

        let entropy = store
            .get_learned_entropy("rs")
            .expect("5 completed samples are enough to learn");
        let jaccard = store
            .get_learned_jaccard("rs")
            .expect("5 completed samples are enough to learn");
        assert!((entropy - 0.85).abs() < 1e-9);
        assert!((jaccard - 0.72).abs() < 1e-9);
        assert_eq!(store.learned_thresholds["rs"].sample_count, 5);
    }

    #[test]
    fn incomplete_outcomes_are_ignored() {
        let mut store = FeedbackStore::default();
        for i in 0..10 {
            store.outcomes.push(rs_outcome(i, false));
        }
        store.update_learned_thresholds("rs");
        assert!(store.get_learned_entropy("rs").is_none());
    }
}