// lean_ctx/core/neural/line_scorer.rs

//! Neural line importance scorer using ONNX inference via rten.
//!
//! Replaces the heuristic IB-Filter with a trained model that predicts
//! per-line importance based on structural features.
//!
//! When the crate is built without the `neural` feature, scoring falls back
//! to a decision-tree implementation (static rules generated by distill.py)
//! instead of ONNX inference.

9use std::path::Path;
10
11#[cfg(feature = "neural")]
12use rten::Model;
13
/// Scores how important each source line is.
///
/// With the `neural` feature enabled this holds a loaded `rten` model. Without
/// it the struct carries no data, and scoring goes through the built-in
/// decision-tree rules instead.
pub struct NeuralLineScorer {
    // Zero-sized placeholder so the struct keeps a private field when the
    // model is compiled out.
    #[cfg(not(feature = "neural"))]
    _phantom: (),
    #[cfg(feature = "neural")]
    model: Model,
}
20
/// Structural features describing one source line, each encoded as `f64`.
///
/// Flag-style features are `1.0` when the condition holds and `0.0`
/// otherwise. The per-field notes describe how `LineFeatures::from_line`
/// fills each value. Fields are declared in the same order that
/// `LineFeatures::to_array` emits them.
#[derive(Debug, Clone)]
pub struct LineFeatures {
    /// Byte length of the line after trimming surrounding whitespace.
    pub line_length: f64,
    /// Leading-whitespace byte count divided by 4.
    pub indentation_level: f64,
    /// Distinct whitespace-separated tokens divided by total tokens (0 if none).
    pub token_diversity: f64,
    /// Flag: the line starts with a recognised definition prefix (`fn `, `class `, ...).
    pub is_definition: f64,
    /// Flag: the line starts with an import-like prefix (`use `, `import `, ...).
    pub is_import: f64,
    /// Flag: the line starts with a comment marker (`//`, `#`, `/*`, `*`).
    pub is_comment: f64,
    /// Flag: the line is a lone closing token such as `}` or `end`.
    pub is_closing: f64,
    /// Fraction of tokens found in a fixed keyword list.
    pub keyword_density: f64,
    /// Relative position of the line, exactly as supplied by the caller.
    pub position_normalized: f64,
    /// Flag: the line contains a type-annotation marker (`->`, `: `, `Option<`, ...).
    pub has_type_annotation: f64,
    /// Nesting depth copied from the line's `LineContext`.
    pub nesting_depth: f64,
    /// Type code of the previous line, copied from the `LineContext`.
    pub prev_line_type: f64,
    /// Type code of the next line, copied from the `LineContext`.
    pub next_line_type: f64,
}
37
38impl LineFeatures {
39    pub fn from_line(line: &str, position: f64, context: &LineContext) -> Self {
40        let trimmed = line.trim();
41        let leading = (line.len() - line.trim_start().len()) as f64;
42
43        Self {
44            line_length: trimmed.len() as f64,
45            indentation_level: leading / 4.0,
46            token_diversity: Self::compute_token_diversity(trimmed),
47            is_definition: if Self::check_definition(trimmed) {
48                1.0
49            } else {
50                0.0
51            },
52            is_import: if Self::check_import(trimmed) {
53                1.0
54            } else {
55                0.0
56            },
57            is_comment: if Self::check_comment(trimmed) {
58                1.0
59            } else {
60                0.0
61            },
62            is_closing: if Self::check_closing(trimmed) {
63                1.0
64            } else {
65                0.0
66            },
67            keyword_density: Self::compute_keyword_density(trimmed),
68            position_normalized: position,
69            has_type_annotation: if Self::check_type_annotation(trimmed) {
70                1.0
71            } else {
72                0.0
73            },
74            nesting_depth: context.nesting_depth as f64,
75            prev_line_type: context.prev_line_type as f64,
76            next_line_type: context.next_line_type as f64,
77        }
78    }
79
80    pub fn to_array(&self) -> [f64; 13] {
81        [
82            self.line_length,
83            self.indentation_level,
84            self.token_diversity,
85            self.is_definition,
86            self.is_import,
87            self.is_comment,
88            self.is_closing,
89            self.keyword_density,
90            self.position_normalized,
91            self.has_type_annotation,
92            self.nesting_depth,
93            self.prev_line_type,
94            self.next_line_type,
95        ]
96    }
97
98    fn compute_token_diversity(line: &str) -> f64 {
99        let tokens: Vec<&str> = line.split_whitespace().collect();
100        if tokens.is_empty() {
101            return 0.0;
102        }
103        let unique: std::collections::HashSet<&str> = tokens.iter().copied().collect();
104        unique.len() as f64 / tokens.len() as f64
105    }
106
107    fn check_definition(line: &str) -> bool {
108        const STARTERS: &[&str] = &[
109            "fn ",
110            "pub fn ",
111            "async fn ",
112            "pub async fn ",
113            "def ",
114            "async def ",
115            "function ",
116            "export function ",
117            "async function ",
118            "class ",
119            "export class ",
120            "struct ",
121            "pub struct ",
122            "enum ",
123            "pub enum ",
124            "trait ",
125            "pub trait ",
126            "impl ",
127            "type ",
128            "pub type ",
129            "interface ",
130            "export interface ",
131        ];
132        STARTERS.iter().any(|s| line.starts_with(s))
133    }
134
135    fn check_import(line: &str) -> bool {
136        line.starts_with("import ")
137            || line.starts_with("use ")
138            || line.starts_with("from ")
139            || line.starts_with("#include")
140            || line.starts_with("require(")
141    }
142
143    fn check_comment(line: &str) -> bool {
144        line.starts_with("//")
145            || line.starts_with('#')
146            || line.starts_with("/*")
147            || line.starts_with('*')
148            || line.starts_with("///")
149    }
150
151    fn check_closing(line: &str) -> bool {
152        matches!(line, "}" | "};" | "})" | "]" | ");" | "end")
153    }
154
155    fn check_type_annotation(line: &str) -> bool {
156        line.contains("->")
157            || line.contains("=>")
158            || line.contains(": ")
159            || line.contains("Result<")
160            || line.contains("Option<")
161    }
162
163    fn compute_keyword_density(line: &str) -> f64 {
164        const KEYWORDS: &[&str] = &[
165            "fn",
166            "let",
167            "mut",
168            "pub",
169            "use",
170            "impl",
171            "struct",
172            "enum",
173            "match",
174            "if",
175            "else",
176            "for",
177            "while",
178            "return",
179            "async",
180            "await",
181            "trait",
182            "where",
183            "def",
184            "class",
185            "import",
186            "from",
187            "function",
188            "export",
189            "const",
190            "var",
191            "type",
192            "interface",
193            "try",
194            "catch",
195            "throw",
196            "yield",
197            "raise",
198        ];
199        let tokens: Vec<&str> = line.split_whitespace().collect();
200        if tokens.is_empty() {
201            return 0.0;
202        }
203        let hits = tokens
204            .iter()
205            .filter(|t| {
206                let clean = t.trim_end_matches(|c: char| !c.is_alphanumeric());
207                KEYWORDS.contains(&clean)
208            })
209            .count();
210        hits as f64 / tokens.len() as f64
211    }
212}
213
/// Neighbourhood information about a line that its own text cannot provide.
///
/// `Default` yields a depth of 0 and type code 0 for both neighbours.
#[derive(Debug, Clone, Default)]
pub struct LineContext {
    /// Nesting depth of the line. `score_all_lines` derives it from a running
    /// count of `{` and `}`.
    pub nesting_depth: usize,
    /// Type code of the preceding line. `score_all_lines` uses `classify_type`,
    /// or 0 for the first line.
    pub prev_line_type: u8,
    /// Type code of the following line. `score_all_lines` uses `classify_type`,
    /// or 0 for the last line.
    pub next_line_type: u8,
}
220
221impl NeuralLineScorer {
222    #[cfg(feature = "neural")]
223    pub fn load(model_path: &Path) -> anyhow::Result<Self> {
224        let model = Model::load_file(model_path)?;
225        Ok(Self { model })
226    }
227
228    #[cfg(not(feature = "neural"))]
229    pub fn load(_model_path: &Path) -> anyhow::Result<Self> {
230        anyhow::bail!("Neural feature not enabled. Compile with --features neural")
231    }
232
233    pub fn score_line(&self, line: &str, position: f64, task_keywords: &[String]) -> f64 {
234        let context = LineContext::default();
235        let features = LineFeatures::from_line(line, position, &context);
236        self.score_from_features(&features, task_keywords)
237    }
238
239    pub fn score_from_features(&self, features: &LineFeatures, _task_keywords: &[String]) -> f64 {
240        #[cfg(feature = "neural")]
241        {
242            self.neural_score(features)
243        }
244        #[cfg(not(feature = "neural"))]
245        {
246            self.decision_tree_score(features)
247        }
248    }
249
250    #[cfg(feature = "neural")]
251    fn neural_score(&self, features: &LineFeatures) -> f64 {
252        use rten_tensor::{AsView, NdTensor};
253
254        let input_data = features.to_array();
255        let float_data: Vec<f32> = input_data.iter().map(|&x| x as f32).collect();
256        let input = NdTensor::from_data([1, 13], float_data);
257
258        match self.model.run_one(input.into(), None) {
259            Ok(output) => {
260                let tensor: Vec<f32> = output
261                    .into_tensor::<f32>()
262                    .map(|t| t.to_vec())
263                    .unwrap_or_default();
264                tensor.first().copied().unwrap_or(0.5) as f64
265            }
266            Err(_) => 0.5,
267        }
268    }
269
270    #[cfg(not(feature = "neural"))]
271    #[allow(clippy::unused_self)]
272    fn decision_tree_score(&self, features: &LineFeatures) -> f64 {
273        let f = features.to_array();
274
275        let mut score = 0.5;
276
277        if f[3] > 0.5 {
278            score += 0.3; // is_definition
279        }
280        if f[5] > 0.5 {
281            score -= 0.2; // is_comment
282        }
283        if f[6] > 0.5 {
284            score -= 0.3; // is_closing
285        }
286        if f[4] > 0.5 {
287            score -= 0.1; // is_import
288        }
289        if f[9] > 0.5 {
290            score += 0.15; // has_type_annotation
291        }
292
293        let pos = f[8];
294        let u_curve = if pos <= 0.5 {
295            1.0 - 0.6 * (2.0 * pos).powi(2)
296        } else {
297            1.0 - 0.6 * (2.0 * (1.0 - pos)).powi(2)
298        };
299        score *= u_curve;
300
301        score.clamp(0.0, 1.0)
302    }
303}
304
305pub fn score_all_lines(
306    lines: &[&str],
307    scorer: &NeuralLineScorer,
308    task_keywords: &[String],
309) -> Vec<f64> {
310    let n = lines.len();
311    let mut nesting_depth: usize = 0;
312
313    lines
314        .iter()
315        .enumerate()
316        .map(|(i, line)| {
317            let trimmed = line.trim();
318            nesting_depth = nesting_depth
319                .saturating_add(trimmed.matches('{').count())
320                .saturating_sub(trimmed.matches('}').count());
321
322            let prev_type = if i > 0 {
323                classify_type(lines[i - 1].trim())
324            } else {
325                0
326            };
327            let next_type = if i + 1 < n {
328                classify_type(lines[i + 1].trim())
329            } else {
330                0
331            };
332            let position = i as f64 / (n.max(1) - 1).max(1) as f64;
333
334            let context = LineContext {
335                nesting_depth,
336                prev_line_type: prev_type,
337                next_line_type: next_type,
338            };
339            let features = LineFeatures::from_line(line, position, &context);
340            scorer.score_from_features(&features, task_keywords)
341        })
342        .collect()
343}
344
345fn classify_type(line: &str) -> u8 {
346    if line.is_empty() {
347        return 0;
348    }
349    if LineFeatures::check_definition(line) {
350        return 1;
351    }
352    if LineFeatures::check_import(line) {
353        return 2;
354    }
355    if LineFeatures::check_comment(line) {
356        return 3;
357    }
358    if LineFeatures::check_closing(line) {
359        return 5;
360    }
361    4 // logic
362}