Skip to main content

lean_ctx/core/neural/
mod.rs

1//! Neural context compression — trained models replacing heuristic filters.
2//!
3//! Feature-gated under `#[cfg(feature = "neural")]`.
4//! When an ONNX model is present, switches from heuristic to neural scoring.
5//! Falls back gracefully to heuristic mode when no model is available.
6
7pub mod attention_learned;
8pub mod cache_alignment;
9pub mod context_reorder;
10pub mod line_scorer;
11pub mod token_optimizer;
12
13use std::path::PathBuf;
14
15use attention_learned::LearnedAttention;
16use line_scorer::NeuralLineScorer;
17use token_optimizer::TokenOptimizer;
18
19pub struct NeuralEngine {
20    line_scorer: Option<NeuralLineScorer>,
21    token_optimizer: TokenOptimizer,
22    attention: LearnedAttention,
23}
24
25impl NeuralEngine {
26    pub fn load() -> Self {
27        let model_dir = Self::model_directory();
28
29        let line_scorer = if model_dir.join("line_importance.onnx").exists() {
30            match NeuralLineScorer::load(&model_dir.join("line_importance.onnx")) {
31                Ok(scorer) => {
32                    tracing::info!("Neural line scorer loaded from {:?}", model_dir);
33                    Some(scorer)
34                }
35                Err(e) => {
36                    tracing::warn!(
37                        "Failed to load neural line scorer: {e}. Using heuristic fallback."
38                    );
39                    None
40                }
41            }
42        } else {
43            tracing::debug!("No ONNX model found, using heuristic line scoring");
44            None
45        };
46
47        let token_optimizer = TokenOptimizer::load_or_default(&model_dir);
48        let attention = LearnedAttention::load_or_default(&model_dir);
49
50        Self {
51            line_scorer,
52            token_optimizer,
53            attention,
54        }
55    }
56
57    pub fn score_line(&self, line: &str, position: f64, task_keywords: &[String]) -> f64 {
58        if let Some(ref scorer) = self.line_scorer {
59            scorer.score_line(line, position, task_keywords)
60        } else {
61            self.heuristic_score(line, position)
62        }
63    }
64
65    pub fn optimize_line(&self, line: &str) -> String {
66        self.token_optimizer.optimize_line(line)
67    }
68
69    pub fn attention_weight(&self, position: f64) -> f64 {
70        self.attention.weight(position)
71    }
72
73    pub fn has_neural_model(&self) -> bool {
74        self.line_scorer.is_some()
75    }
76
77    fn heuristic_score(&self, line: &str, position: f64) -> f64 {
78        let structural = super::attention_model::structural_importance(line);
79        let positional = self.attention.weight(position);
80        (structural * positional).sqrt()
81    }
82
83    fn model_directory() -> PathBuf {
84        if let Ok(dir) = std::env::var("LEAN_CTX_MODELS_DIR") {
85            return PathBuf::from(dir);
86        }
87
88        if let Some(data_dir) = dirs::data_dir() {
89            return data_dir.join("lean-ctx").join("models");
90        }
91
92        PathBuf::from("models")
93    }
94}