lean_ctx/core/neural/
mod.rs1pub mod attention_learned;
8pub mod cache_alignment;
9pub mod context_reorder;
10pub mod line_scorer;
11pub mod token_optimizer;
12
13use std::path::PathBuf;
14
15use attention_learned::LearnedAttention;
16use line_scorer::NeuralLineScorer;
17use token_optimizer::TokenOptimizer;
18
19pub struct NeuralEngine {
20 line_scorer: Option<NeuralLineScorer>,
21 token_optimizer: TokenOptimizer,
22 attention: LearnedAttention,
23}
24
25impl NeuralEngine {
26 pub fn load() -> Self {
27 let model_dir = Self::model_directory();
28
29 let line_scorer = if model_dir.join("line_importance.onnx").exists() {
30 match NeuralLineScorer::load(&model_dir.join("line_importance.onnx")) {
31 Ok(scorer) => {
32 tracing::info!("Neural line scorer loaded from {:?}", model_dir);
33 Some(scorer)
34 }
35 Err(e) => {
36 tracing::warn!(
37 "Failed to load neural line scorer: {e}. Using heuristic fallback."
38 );
39 None
40 }
41 }
42 } else {
43 tracing::debug!("No ONNX model found, using heuristic line scoring");
44 None
45 };
46
47 let token_optimizer = TokenOptimizer::load_or_default(&model_dir);
48 let attention = LearnedAttention::load_or_default(&model_dir);
49
50 Self {
51 line_scorer,
52 token_optimizer,
53 attention,
54 }
55 }
56
57 pub fn score_line(&self, line: &str, position: f64, task_keywords: &[String]) -> f64 {
58 if let Some(ref scorer) = self.line_scorer {
59 scorer.score_line(line, position, task_keywords)
60 } else {
61 self.heuristic_score(line, position)
62 }
63 }
64
65 pub fn optimize_line(&self, line: &str) -> String {
66 self.token_optimizer.optimize_line(line)
67 }
68
69 pub fn attention_weight(&self, position: f64) -> f64 {
70 self.attention.weight(position)
71 }
72
73 pub fn has_neural_model(&self) -> bool {
74 self.line_scorer.is_some()
75 }
76
77 fn heuristic_score(&self, line: &str, position: f64) -> f64 {
78 let structural = super::attention_model::structural_importance(line);
79 let positional = self.attention.weight(position);
80 (structural * positional).sqrt()
81 }
82
83 fn model_directory() -> PathBuf {
84 if let Ok(dir) = std::env::var("LEAN_CTX_MODELS_DIR") {
85 return PathBuf::from(dir);
86 }
87
88 if let Some(data_dir) = dirs::data_dir() {
89 return data_dir.join("lean-ctx").join("models");
90 }
91
92 PathBuf::from("models")
93 }
94}