Skip to main content

lean_ctx/core/
feedback.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::Instant;
4
5use serde::{Deserialize, Serialize};
6
/// Minimum number of seconds between disk writes triggered by
/// `FeedbackStore::record_outcome`.
const FEEDBACK_FLUSH_SECS: u64 = 60;

/// Process-wide copy of the most recently recorded store, paired with the
/// `Instant` of the last flush performed by `record_outcome`.
/// Within this file it stays `None` until `record_outcome` is called.
static FEEDBACK_BUFFER: Mutex<Option<(FeedbackStore, Instant)>> = Mutex::new(None);
10
/// Feedback loop for learning optimal compression parameters.
///
/// Tracks compression outcomes per session and learns which
/// threshold combinations lead to fewer turns and higher success rates.
///
/// A `CompressionOutcome` is the single-session record fed into that loop
/// through `FeedbackStore::record_outcome`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CompressionOutcome {
    /// Session identifier. Stored only; not read by the code in this file.
    pub session_id: String,
    /// Language label. Keys `FeedbackStore::learned_thresholds` and, with a
    /// `_feedback` suffix, the bandit entry updated for this outcome.
    pub language: String,
    /// Entropy threshold used for the session. Also selects the bandit arm:
    /// `>= 1.0` conservative, `>= 0.7` balanced, otherwise aggressive.
    pub entropy_threshold: f64,
    /// Jaccard threshold used for the session.
    pub jaccard_threshold: f64,
    /// Number of turns; scored as `1 / max(total_turns, 1)`, so fewer is better.
    pub total_turns: u32,
    /// Numerator of the compression ratio `tokens_saved / tokens_original`.
    pub tokens_saved: u64,
    /// Denominator of the compression ratio; `0` makes the ratio `0.0`.
    pub tokens_original: u64,
    /// Stored only; not read by the code in this file.
    pub cache_hits: u32,
    /// Stored only; not read by the code in this file.
    pub total_reads: u32,
    /// Only completed outcomes feed threshold learning, and completion is
    /// required for a bandit success.
    pub task_completed: bool,
    /// Stored as an opaque string; never parsed in this file.
    pub timestamp: String,
}
30
/// Persisted feedback state: the recent outcome history plus the thresholds
/// learned from it, serialized as JSON to `feedback.json`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct FeedbackStore {
    /// Recorded outcomes, oldest first. `record_outcome` trims this to the
    /// newest 200 entries.
    pub outcomes: Vec<CompressionOutcome>,
    /// Learned thresholds keyed by `CompressionOutcome::language`.
    pub learned_thresholds: HashMap<String, LearnedThresholds>,
    /// Directory passed to the bandit store's load/save calls (falls back to
    /// `"."` when `None`). Never serialized; `load` fills it from the current
    /// working directory.
    #[serde(skip)]
    pub project_root: Option<String>,
}
38
/// Smoothed threshold pair learned for one language.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearnedThresholds {
    /// Entropy threshold, blended 70% previous / 30% newest best on each update.
    pub entropy: f64,
    /// Jaccard threshold, blended the same way as `entropy`.
    pub jaccard: f64,
    /// Number of completed outcomes for the language at the last update.
    pub sample_count: u32,
    /// NOTE: despite the name this holds the *highest* score among those
    /// outcomes (`0.6 * compression ratio + 0.4 / turns`), not an average.
    pub avg_efficiency: f64,
}
46
47impl FeedbackStore {
48    pub fn load() -> Self {
49        let guard = FEEDBACK_BUFFER
50            .lock()
51            .unwrap_or_else(std::sync::PoisonError::into_inner);
52        if let Some((ref store, _)) = *guard {
53            let mut s = store.clone();
54            if s.project_root.is_none() {
55                s.project_root = std::env::current_dir()
56                    .ok()
57                    .map(|p| p.to_string_lossy().to_string());
58            }
59            return s;
60        }
61        drop(guard);
62
63        let path = feedback_path();
64        if path.exists() {
65            if let Ok(content) = std::fs::read_to_string(&path) {
66                if let Ok(mut store) = serde_json::from_str::<FeedbackStore>(&content) {
67                    store.project_root = std::env::current_dir()
68                        .ok()
69                        .map(|p| p.to_string_lossy().to_string());
70                    return store;
71                }
72            }
73        }
74        Self {
75            project_root: std::env::current_dir()
76                .ok()
77                .map(|p| p.to_string_lossy().to_string()),
78            ..Self::default()
79        }
80    }
81
82    fn save_to_disk(&self) {
83        let path = feedback_path();
84        if let Some(parent) = path.parent() {
85            let _ = std::fs::create_dir_all(parent);
86        }
87        if let Ok(json) = serde_json::to_string_pretty(self) {
88            let _ = std::fs::write(path, json);
89        }
90    }
91
92    pub fn save(&self) {
93        self.save_to_disk();
94    }
95
96    pub fn flush() {
97        let guard = FEEDBACK_BUFFER
98            .lock()
99            .unwrap_or_else(std::sync::PoisonError::into_inner);
100        if let Some((ref store, _)) = *guard {
101            store.save_to_disk();
102        }
103    }
104
105    pub fn record_outcome(&mut self, outcome: CompressionOutcome) {
106        let lang = outcome.language.clone();
107        self.update_bandit(&outcome);
108        self.outcomes.push(outcome);
109
110        if self.outcomes.len() > 200 {
111            self.outcomes.drain(0..self.outcomes.len() - 200);
112        }
113
114        self.update_learned_thresholds(&lang);
115
116        let mut guard = FEEDBACK_BUFFER
117            .lock()
118            .unwrap_or_else(std::sync::PoisonError::into_inner);
119        let should_flush = match *guard {
120            Some((_, ref last)) => last.elapsed().as_secs() >= FEEDBACK_FLUSH_SECS,
121            None => true,
122        };
123        *guard = Some((
124            self.clone(),
125            guard.as_ref().map_or_else(Instant::now, |(_, t)| *t),
126        ));
127        if should_flush {
128            self.save_to_disk();
129            if let Some((_, ref mut t)) = *guard {
130                *t = Instant::now();
131            }
132        }
133    }
134
135    fn update_bandit(&self, outcome: &CompressionOutcome) {
136        let key = format!("{}_feedback", outcome.language);
137        let project_root = self.project_root.as_deref().unwrap_or(".");
138        let mut store = crate::core::bandit::BanditStore::load(project_root);
139        let bandit = store.get_or_create(&key);
140        bandit.total_pulls = bandit.total_pulls.saturating_add(1);
141
142        let efficiency = if outcome.tokens_original > 0 {
143            outcome.tokens_saved as f64 / outcome.tokens_original as f64
144        } else {
145            0.0
146        };
147        let success = efficiency > 0.3 && outcome.task_completed;
148
149        let arm_name = if outcome.entropy_threshold >= 1.0 {
150            "conservative"
151        } else if outcome.entropy_threshold >= 0.7 {
152            "balanced"
153        } else {
154            "aggressive"
155        };
156
157        let old_mean = bandit
158            .arms
159            .iter()
160            .find(|a| a.name == arm_name)
161            .map_or(0.5, super::bandit::BanditArm::mean);
162
163        bandit.update(arm_name, success);
164
165        let new_mean = bandit
166            .arms
167            .iter()
168            .find(|a| a.name == arm_name)
169            .map_or(0.5, super::bandit::BanditArm::mean);
170
171        if (new_mean - old_mean).abs() > 0.05 {
172            crate::core::events::emit_threshold_adapted(
173                &outcome.language,
174                arm_name,
175                old_mean,
176                new_mean,
177            );
178        }
179
180        if bandit.total_pulls > 0 && bandit.total_pulls.is_multiple_of(50) {
181            bandit.decay_all(0.95);
182        }
183
184        let _ = store.save(project_root);
185    }
186
187    fn update_learned_thresholds(&mut self, language: &str) {
188        let relevant: Vec<&CompressionOutcome> = self
189            .outcomes
190            .iter()
191            .filter(|o| o.language == language && o.task_completed)
192            .collect();
193
194        if relevant.len() < 5 {
195            return; // not enough data to learn
196        }
197
198        // Find the threshold combination that maximizes efficiency
199        // Efficiency = tokens_saved / tokens_original * (1 / total_turns)
200        let mut best_entropy = 1.0;
201        let mut best_jaccard = 0.7;
202        let mut best_efficiency = 0.0;
203
204        for outcome in &relevant {
205            let compression_ratio = if outcome.tokens_original > 0 {
206                outcome.tokens_saved as f64 / outcome.tokens_original as f64
207            } else {
208                0.0
209            };
210            let turn_efficiency = 1.0 / (outcome.total_turns.max(1) as f64);
211            let efficiency = compression_ratio * 0.6 + turn_efficiency * 0.4;
212
213            if efficiency > best_efficiency {
214                best_efficiency = efficiency;
215                best_entropy = outcome.entropy_threshold;
216                best_jaccard = outcome.jaccard_threshold;
217            }
218        }
219
220        // Weighted average with current learned values for stability
221        let entry = self
222            .learned_thresholds
223            .entry(language.to_string())
224            .or_insert(LearnedThresholds {
225                entropy: best_entropy,
226                jaccard: best_jaccard,
227                sample_count: 0,
228                avg_efficiency: 0.0,
229            });
230
231        let momentum = 0.7;
232        let old_entropy = entry.entropy;
233        let old_jaccard = entry.jaccard;
234        entry.entropy = entry.entropy * momentum + best_entropy * (1.0 - momentum);
235        entry.jaccard = entry.jaccard * momentum + best_jaccard * (1.0 - momentum);
236        entry.sample_count = relevant.len() as u32;
237        entry.avg_efficiency = best_efficiency;
238
239        if (old_entropy - entry.entropy).abs() > 0.01 || (old_jaccard - entry.jaccard).abs() > 0.01
240        {
241            crate::core::events::emit(crate::core::events::EventKind::ThresholdShift {
242                language: language.to_string(),
243                old_entropy,
244                new_entropy: entry.entropy,
245                old_jaccard,
246                new_jaccard: entry.jaccard,
247            });
248        }
249    }
250
251    pub fn get_learned_entropy(&self, language: &str) -> Option<f64> {
252        self.learned_thresholds.get(language).map(|t| t.entropy)
253    }
254
255    pub fn get_learned_jaccard(&self, language: &str) -> Option<f64> {
256        self.learned_thresholds.get(language).map(|t| t.jaccard)
257    }
258
259    pub fn format_report(&self) -> String {
260        let mut lines = vec![String::from("Feedback Loop Report")];
261        lines.push(format!("Total outcomes tracked: {}", self.outcomes.len()));
262        lines.push(String::new());
263
264        if self.learned_thresholds.is_empty() {
265            lines.push(
266                "No learned thresholds yet (need 5+ completed sessions per language).".to_string(),
267            );
268        } else {
269            lines.push("Learned Thresholds:".to_string());
270            for (lang, t) in &self.learned_thresholds {
271                lines.push(format!(
272                    "  {lang}: entropy={:.2} jaccard={:.2} (n={}, eff={:.1}%)",
273                    t.entropy,
274                    t.jaccard,
275                    t.sample_count,
276                    t.avg_efficiency * 100.0
277                ));
278            }
279        }
280
281        lines.push(String::new());
282        let project_root = self.project_root.as_deref().unwrap_or(".");
283        let store = crate::core::bandit::BanditStore::load(project_root);
284        lines.push(store.format_report());
285
286        lines.join("\n")
287    }
288}
289
290fn feedback_path() -> std::path::PathBuf {
291    crate::core::data_dir::lean_ctx_data_dir()
292        .unwrap_or_else(|_| std::path::PathBuf::from("."))
293        .join("feedback.json")
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a completed "rs" outcome with fixed thresholds and token counts.
    fn sample_outcome(index: usize) -> CompressionOutcome {
        CompressionOutcome {
            session_id: format!("s{index}"),
            language: "rs".to_string(),
            entropy_threshold: 0.85,
            jaccard_threshold: 0.72,
            total_turns: 5,
            tokens_saved: 1000,
            tokens_original: 2000,
            cache_hits: 3,
            total_reads: 10,
            task_completed: true,
            timestamp: String::new(),
        }
    }

    #[test]
    fn empty_store_loads() {
        let store = FeedbackStore::default();
        assert!(store.outcomes.is_empty());
        assert!(store.learned_thresholds.is_empty());
    }

    // These tests drive `update_learned_thresholds` directly instead of going
    // through `record_outcome`: the latter writes `feedback.json` to the data
    // directory, loads and saves the bandit store under ".", and replaces the
    // process-wide buffer, none of which a unit test should touch.

    #[test]
    fn learned_thresholds_need_minimum_samples() {
        let mut store = FeedbackStore::default();
        store.outcomes.extend((0..3).map(sample_outcome));
        store.update_learned_thresholds("rs");
        assert!(store.get_learned_entropy("rs").is_none()); // only 3, need 5
        assert!(store.get_learned_jaccard("rs").is_none());
    }

    #[test]
    fn learned_thresholds_appear_after_five_completed_sessions() {
        let mut store = FeedbackStore::default();
        store.outcomes.extend((0..5).map(sample_outcome));
        store.update_learned_thresholds("rs");

        let entropy = store
            .get_learned_entropy("rs")
            .expect("five completed outcomes should produce a learned entropy");
        let jaccard = store
            .get_learned_jaccard("rs")
            .expect("five completed outcomes should produce a learned jaccard");
        assert!((entropy - 0.85).abs() < 1e-9);
        assert!((jaccard - 0.72).abs() < 1e-9);
        assert_eq!(store.learned_thresholds["rs"].sample_count, 5);
    }
}