// lean_ctx/core/neural/line_scorer.rs
1//! Neural line importance scorer using ONNX inference via rten.
2//!
3//! Replaces the heuristic IB-Filter with a trained model that predicts
4//! per-line importance based on structural features.
5//!
6//! When no ONNX model is available, falls back to the decision-tree
7//! implementation (static rules generated by distill.py).
8
9use std::path::Path;
10
11#[cfg(feature = "neural")]
12use rten::Model;
13
/// Scores the importance of individual source lines.
///
/// With the `neural` feature enabled this wraps a loaded rten ONNX model;
/// without it, the struct is an empty stand-in and the scoring methods fall
/// back to the static decision-tree rules (see `decision_tree_score`).
pub struct NeuralLineScorer {
    // Loaded ONNX model used for inference (only built with --features neural).
    #[cfg(feature = "neural")]
    model: Model,
    // Placeholder field so the struct has a body when the feature is off.
    #[cfg(not(feature = "neural"))]
    _phantom: (),
}
20
/// Per-line feature vector fed to the importance model.
///
/// All fields are `f64` so they can be packed directly into the model input;
/// boolean features are encoded as 0.0 / 1.0. The field order here is
/// load-bearing: `to_array` flattens the struct into a fixed 13-element
/// vector in declaration order.
#[derive(Debug, Clone)]
pub struct LineFeatures {
    /// Byte length of the trimmed line.
    pub line_length: f64,
    /// Leading whitespace bytes divided by 4 (assumes 4-space indents — TODO confirm for tabs).
    pub indentation_level: f64,
    /// Unique tokens / total tokens on the line (0.0 for blank lines).
    pub token_diversity: f64,
    /// 1.0 if the line starts a definition (fn/class/struct/… prefixes).
    pub is_definition: f64,
    /// 1.0 if the line is an import/use/include statement.
    pub is_import: f64,
    /// 1.0 if the line looks like a comment.
    pub is_comment: f64,
    /// 1.0 if the line is a bare closing delimiter (`}`, `);`, `end`, …).
    pub is_closing: f64,
    /// Fraction of tokens that are language keywords.
    pub keyword_density: f64,
    /// Position of the line in the file, normalized to 0.0..=1.0.
    pub position_normalized: f64,
    /// 1.0 if the line appears to carry a type annotation (`->`, `: `, …).
    pub has_type_annotation: f64,
    /// Brace-nesting depth supplied by the caller's `LineContext`.
    pub nesting_depth: f64,
    /// Type code of the previous line (see `classify_type` encoding).
    pub prev_line_type: f64,
    /// Type code of the next line (same encoding).
    pub next_line_type: f64,
}
37
38impl LineFeatures {
39    pub fn from_line(line: &str, position: f64, context: &LineContext) -> Self {
40        let trimmed = line.trim();
41        let leading = (line.len() - line.trim_start().len()) as f64;
42
43        Self {
44            line_length: trimmed.len() as f64,
45            indentation_level: leading / 4.0,
46            token_diversity: Self::compute_token_diversity(trimmed),
47            is_definition: if Self::check_definition(trimmed) {
48                1.0
49            } else {
50                0.0
51            },
52            is_import: if Self::check_import(trimmed) {
53                1.0
54            } else {
55                0.0
56            },
57            is_comment: if Self::check_comment(trimmed) {
58                1.0
59            } else {
60                0.0
61            },
62            is_closing: if Self::check_closing(trimmed) {
63                1.0
64            } else {
65                0.0
66            },
67            keyword_density: Self::compute_keyword_density(trimmed),
68            position_normalized: position,
69            has_type_annotation: if Self::check_type_annotation(trimmed) {
70                1.0
71            } else {
72                0.0
73            },
74            nesting_depth: context.nesting_depth as f64,
75            prev_line_type: context.prev_line_type as f64,
76            next_line_type: context.next_line_type as f64,
77        }
78    }
79
80    pub fn to_array(&self) -> [f64; 13] {
81        [
82            self.line_length,
83            self.indentation_level,
84            self.token_diversity,
85            self.is_definition,
86            self.is_import,
87            self.is_comment,
88            self.is_closing,
89            self.keyword_density,
90            self.position_normalized,
91            self.has_type_annotation,
92            self.nesting_depth,
93            self.prev_line_type,
94            self.next_line_type,
95        ]
96    }
97
98    fn compute_token_diversity(line: &str) -> f64 {
99        let tokens: Vec<&str> = line.split_whitespace().collect();
100        if tokens.is_empty() {
101            return 0.0;
102        }
103        let unique: std::collections::HashSet<&str> = tokens.iter().copied().collect();
104        unique.len() as f64 / tokens.len() as f64
105    }
106
107    fn check_definition(line: &str) -> bool {
108        const STARTERS: &[&str] = &[
109            "fn ",
110            "pub fn ",
111            "async fn ",
112            "pub async fn ",
113            "def ",
114            "async def ",
115            "function ",
116            "export function ",
117            "async function ",
118            "class ",
119            "export class ",
120            "struct ",
121            "pub struct ",
122            "enum ",
123            "pub enum ",
124            "trait ",
125            "pub trait ",
126            "impl ",
127            "type ",
128            "pub type ",
129            "interface ",
130            "export interface ",
131        ];
132        STARTERS.iter().any(|s| line.starts_with(s))
133    }
134
135    fn check_import(line: &str) -> bool {
136        line.starts_with("import ")
137            || line.starts_with("use ")
138            || line.starts_with("from ")
139            || line.starts_with("#include")
140            || line.starts_with("require(")
141    }
142
143    fn check_comment(line: &str) -> bool {
144        line.starts_with("//")
145            || line.starts_with('#')
146            || line.starts_with("/*")
147            || line.starts_with('*')
148            || line.starts_with("///")
149    }
150
151    fn check_closing(line: &str) -> bool {
152        matches!(line, "}" | "};" | "})" | "]" | ");" | "end")
153    }
154
155    fn check_type_annotation(line: &str) -> bool {
156        line.contains("->")
157            || line.contains("=>")
158            || line.contains(": ")
159            || line.contains("Result<")
160            || line.contains("Option<")
161    }
162
163    fn compute_keyword_density(line: &str) -> f64 {
164        const KEYWORDS: &[&str] = &[
165            "fn",
166            "let",
167            "mut",
168            "pub",
169            "use",
170            "impl",
171            "struct",
172            "enum",
173            "match",
174            "if",
175            "else",
176            "for",
177            "while",
178            "return",
179            "async",
180            "await",
181            "trait",
182            "where",
183            "def",
184            "class",
185            "import",
186            "from",
187            "function",
188            "export",
189            "const",
190            "var",
191            "type",
192            "interface",
193            "try",
194            "catch",
195            "throw",
196            "yield",
197            "raise",
198        ];
199        let tokens: Vec<&str> = line.split_whitespace().collect();
200        if tokens.is_empty() {
201            return 0.0;
202        }
203        let hits = tokens
204            .iter()
205            .filter(|t| {
206                let clean = t.trim_end_matches(|c: char| !c.is_alphanumeric());
207                KEYWORDS.contains(&clean)
208            })
209            .count();
210        hits as f64 / tokens.len() as f64
211    }
212}
213
/// Context about a line's surroundings, used when building `LineFeatures`.
///
/// The `Default` impl (all zeros) means "no context": zero nesting and
/// empty/none neighbour types — this is what `score_line` uses.
#[derive(Debug, Clone, Default)]
pub struct LineContext {
    /// Brace-nesting depth at this line (heuristic `{`/`}` count).
    pub nesting_depth: usize,
    /// Type code of the previous line, per `classify_type`: 0 = empty/none,
    /// 1 = definition, 2 = import, 3 = comment, 4 = logic, 5 = closing.
    pub prev_line_type: u8,
    /// Type code of the next line (same encoding; 0 when there is none).
    pub next_line_type: u8,
}
220
impl NeuralLineScorer {
    /// Loads an ONNX model from disk for neural scoring.
    ///
    /// # Errors
    /// Returns an error if rten fails to read or parse the model file.
    #[cfg(feature = "neural")]
    pub fn load(model_path: &Path) -> anyhow::Result<Self> {
        let model = Model::load_file(model_path)?;
        Ok(Self { model })
    }

    /// Stub used when the crate is built without the `neural` feature.
    ///
    /// # Errors
    /// Always fails, telling the caller to rebuild with `--features neural`.
    /// (The decision-tree fallback is still reachable via the other methods
    /// when a scorer instance exists.)
    #[cfg(not(feature = "neural"))]
    pub fn load(_model_path: &Path) -> anyhow::Result<Self> {
        anyhow::bail!("Neural feature not enabled. Compile with --features neural")
    }

    /// Scores a single line without surrounding context: a default
    /// `LineContext` is used, so neighbour-type and nesting features are 0.
    /// Prefer `score_all_lines` for whole-file scoring with real context.
    pub fn score_line(&self, line: &str, position: f64, task_keywords: &[String]) -> f64 {
        let context = LineContext::default();
        let features = LineFeatures::from_line(line, position, &context);
        self.score_from_features(&features, task_keywords)
    }

    /// Dispatches to the ONNX model or the decision-tree fallback depending
    /// on how the crate was compiled. `_task_keywords` is currently unused
    /// by both paths.
    pub fn score_from_features(&self, features: &LineFeatures, _task_keywords: &[String]) -> f64 {
        #[cfg(feature = "neural")]
        {
            self.neural_score(features)
        }
        #[cfg(not(feature = "neural"))]
        {
            self.decision_tree_score(features)
        }
    }

    /// Runs the 13-element feature vector through the ONNX model.
    /// Returns a neutral 0.5 if inference fails or yields no output, so
    /// scoring never aborts the caller's pipeline.
    #[cfg(feature = "neural")]
    fn neural_score(&self, features: &LineFeatures) -> f64 {
        use rten_tensor::{AsView, NdTensor};

        let input_data = features.to_array();
        // Model input is f32 with shape [1, 13] — a batch of one line.
        let float_data: Vec<f32> = input_data.iter().map(|&x| x as f32).collect();
        let input = NdTensor::from_data([1, 13], float_data);

        match self.model.run_one(input.into(), None) {
            Ok(output) => {
                let tensor: Vec<f32> = output
                    .into_tensor::<f32>()
                    .map(|t| t.to_vec())
                    .unwrap_or_default();
                // NOTE(review): assumes the first output element is the
                // importance score — confirm against the exported model.
                tensor.first().copied().unwrap_or(0.5) as f64
            }
            Err(_) => 0.5,
        }
    }

    /// Static decision-tree fallback (rules distilled by distill.py, per the
    /// module docs). Starts at a neutral 0.5, nudges the score by structural
    /// flags, then damps mid-file lines with a U-shaped positional curve.
    #[cfg(not(feature = "neural"))]
    fn decision_tree_score(&self, features: &LineFeatures) -> f64 {
        // Indices below follow the LineFeatures::to_array ordering.
        let f = features.to_array();

        let mut score = 0.5;

        if f[3] > 0.5 {
            score += 0.3; // is_definition
        }
        if f[5] > 0.5 {
            score -= 0.2; // is_comment
        }
        if f[6] > 0.5 {
            score -= 0.3; // is_closing
        }
        if f[4] > 0.5 {
            score -= 0.1; // is_import
        }
        if f[9] > 0.5 {
            score += 0.15; // has_type_annotation
        }

        // U-shaped weight over position (f[8]): lines at the very start or
        // end keep their score (factor 1.0); the file's midpoint is scaled
        // down by up to 60% (factor 0.4).
        let pos = f[8];
        let u_curve = if pos <= 0.5 {
            1.0 - 0.6 * (2.0 * pos).powi(2)
        } else {
            1.0 - 0.6 * (2.0 * (1.0 - pos)).powi(2)
        };
        score *= u_curve;

        score.clamp(0.0, 1.0)
    }
}
303
304pub fn score_all_lines(
305    lines: &[&str],
306    scorer: &NeuralLineScorer,
307    task_keywords: &[String],
308) -> Vec<f64> {
309    let n = lines.len();
310    let mut nesting_depth: usize = 0;
311
312    lines
313        .iter()
314        .enumerate()
315        .map(|(i, line)| {
316            let trimmed = line.trim();
317            nesting_depth = nesting_depth
318                .saturating_add(trimmed.matches('{').count())
319                .saturating_sub(trimmed.matches('}').count());
320
321            let prev_type = if i > 0 {
322                classify_type(lines[i - 1].trim())
323            } else {
324                0
325            };
326            let next_type = if i + 1 < n {
327                classify_type(lines[i + 1].trim())
328            } else {
329                0
330            };
331            let position = i as f64 / (n.max(1) - 1).max(1) as f64;
332
333            let context = LineContext {
334                nesting_depth,
335                prev_line_type: prev_type,
336                next_line_type: next_type,
337            };
338            let features = LineFeatures::from_line(line, position, &context);
339            scorer.score_from_features(&features, task_keywords)
340        })
341        .collect()
342}
343
344fn classify_type(line: &str) -> u8 {
345    if line.is_empty() {
346        return 0;
347    }
348    if LineFeatures::check_definition(line) {
349        return 1;
350    }
351    if LineFeatures::check_import(line) {
352        return 2;
353    }
354    if LineFeatures::check_comment(line) {
355        return 3;
356    }
357    if LineFeatures::check_closing(line) {
358        return 5;
359    }
360    4 // logic
361}