1use std::path::Path;
10
11#[cfg(feature = "neural")]
12use rten::Model;
13
/// Scores individual source lines.
///
/// The contents depend on the `neural` cargo feature: with it enabled the scorer
/// owns a loaded `rten` model, without it the struct holds no data at all.
pub struct NeuralLineScorer {
    /// Inference model; only present when the `neural` feature is enabled.
    #[cfg(feature = "neural")]
    model: Model,
    /// Zero-sized placeholder field used when the `neural` feature is disabled.
    #[cfg(not(feature = "neural"))]
    _phantom: (),
}
20
/// Numeric feature vector describing a single source line.
///
/// Every field is an `f64` so the struct can be flattened into a fixed-size
/// array; boolean properties are encoded as `1.0` (true) or `0.0` (false).
#[derive(Debug, Clone)]
pub struct LineFeatures {
    /// Byte length of the line after trimming surrounding whitespace.
    pub line_length: f64,
    /// Number of leading-whitespace bytes divided by 4.
    pub indentation_level: f64,
    /// Distinct whitespace-separated tokens divided by total tokens (0 for no tokens).
    pub token_diversity: f64,
    /// `1.0` when the trimmed line starts with a known definition prefix.
    pub is_definition: f64,
    /// `1.0` when the trimmed line starts with a known import prefix.
    pub is_import: f64,
    /// `1.0` when the trimmed line starts with a known comment marker.
    pub is_comment: f64,
    /// `1.0` when the trimmed line is exactly one of the known closing tokens.
    pub is_closing: f64,
    /// Share of whitespace-separated tokens that are recognised keywords.
    pub keyword_density: f64,
    /// Caller-supplied position of the line, stored unchanged.
    pub position_normalized: f64,
    /// `1.0` when the trimmed line contains a type-annotation-like marker.
    pub has_type_annotation: f64,
    /// Nesting depth taken from the supplied `LineContext`.
    pub nesting_depth: f64,
    /// Category code of the previous line, taken from the supplied `LineContext`.
    pub prev_line_type: f64,
    /// Category code of the next line, taken from the supplied `LineContext`.
    pub next_line_type: f64,
}
37
38impl LineFeatures {
39 pub fn from_line(line: &str, position: f64, context: &LineContext) -> Self {
40 let trimmed = line.trim();
41 let leading = (line.len() - line.trim_start().len()) as f64;
42
43 Self {
44 line_length: trimmed.len() as f64,
45 indentation_level: leading / 4.0,
46 token_diversity: Self::compute_token_diversity(trimmed),
47 is_definition: if Self::check_definition(trimmed) {
48 1.0
49 } else {
50 0.0
51 },
52 is_import: if Self::check_import(trimmed) {
53 1.0
54 } else {
55 0.0
56 },
57 is_comment: if Self::check_comment(trimmed) {
58 1.0
59 } else {
60 0.0
61 },
62 is_closing: if Self::check_closing(trimmed) {
63 1.0
64 } else {
65 0.0
66 },
67 keyword_density: Self::compute_keyword_density(trimmed),
68 position_normalized: position,
69 has_type_annotation: if Self::check_type_annotation(trimmed) {
70 1.0
71 } else {
72 0.0
73 },
74 nesting_depth: context.nesting_depth as f64,
75 prev_line_type: context.prev_line_type as f64,
76 next_line_type: context.next_line_type as f64,
77 }
78 }
79
80 pub fn to_array(&self) -> [f64; 13] {
81 [
82 self.line_length,
83 self.indentation_level,
84 self.token_diversity,
85 self.is_definition,
86 self.is_import,
87 self.is_comment,
88 self.is_closing,
89 self.keyword_density,
90 self.position_normalized,
91 self.has_type_annotation,
92 self.nesting_depth,
93 self.prev_line_type,
94 self.next_line_type,
95 ]
96 }
97
98 fn compute_token_diversity(line: &str) -> f64 {
99 let tokens: Vec<&str> = line.split_whitespace().collect();
100 if tokens.is_empty() {
101 return 0.0;
102 }
103 let unique: std::collections::HashSet<&str> = tokens.iter().copied().collect();
104 unique.len() as f64 / tokens.len() as f64
105 }
106
107 fn check_definition(line: &str) -> bool {
108 const STARTERS: &[&str] = &[
109 "fn ",
110 "pub fn ",
111 "async fn ",
112 "pub async fn ",
113 "def ",
114 "async def ",
115 "function ",
116 "export function ",
117 "async function ",
118 "class ",
119 "export class ",
120 "struct ",
121 "pub struct ",
122 "enum ",
123 "pub enum ",
124 "trait ",
125 "pub trait ",
126 "impl ",
127 "type ",
128 "pub type ",
129 "interface ",
130 "export interface ",
131 ];
132 STARTERS.iter().any(|s| line.starts_with(s))
133 }
134
135 fn check_import(line: &str) -> bool {
136 line.starts_with("import ")
137 || line.starts_with("use ")
138 || line.starts_with("from ")
139 || line.starts_with("#include")
140 || line.starts_with("require(")
141 }
142
143 fn check_comment(line: &str) -> bool {
144 line.starts_with("//")
145 || line.starts_with('#')
146 || line.starts_with("/*")
147 || line.starts_with('*')
148 || line.starts_with("///")
149 }
150
151 fn check_closing(line: &str) -> bool {
152 matches!(line, "}" | "};" | "})" | "]" | ");" | "end")
153 }
154
155 fn check_type_annotation(line: &str) -> bool {
156 line.contains("->")
157 || line.contains("=>")
158 || line.contains(": ")
159 || line.contains("Result<")
160 || line.contains("Option<")
161 }
162
163 fn compute_keyword_density(line: &str) -> f64 {
164 const KEYWORDS: &[&str] = &[
165 "fn",
166 "let",
167 "mut",
168 "pub",
169 "use",
170 "impl",
171 "struct",
172 "enum",
173 "match",
174 "if",
175 "else",
176 "for",
177 "while",
178 "return",
179 "async",
180 "await",
181 "trait",
182 "where",
183 "def",
184 "class",
185 "import",
186 "from",
187 "function",
188 "export",
189 "const",
190 "var",
191 "type",
192 "interface",
193 "try",
194 "catch",
195 "throw",
196 "yield",
197 "raise",
198 ];
199 let tokens: Vec<&str> = line.split_whitespace().collect();
200 if tokens.is_empty() {
201 return 0.0;
202 }
203 let hits = tokens
204 .iter()
205 .filter(|t| {
206 let clean = t.trim_end_matches(|c: char| !c.is_alphanumeric());
207 KEYWORDS.contains(&clean)
208 })
209 .count();
210 hits as f64 / tokens.len() as f64
211 }
212}
213
/// Information about a line's surroundings that cannot be derived from the
/// line text alone.
///
/// The default value is all zeros: no nesting and no neighbouring lines.
#[derive(Debug, Clone, Default)]
pub struct LineContext {
    /// Brace nesting depth associated with the line.
    pub nesting_depth: usize,
    /// Category code of the preceding line (`0` when there is none).
    pub prev_line_type: u8,
    /// Category code of the following line (`0` when there is none).
    pub next_line_type: u8,
}
220
221impl NeuralLineScorer {
222 #[cfg(feature = "neural")]
223 pub fn load(model_path: &Path) -> anyhow::Result<Self> {
224 let model = Model::load_file(model_path)?;
225 Ok(Self { model })
226 }
227
228 #[cfg(not(feature = "neural"))]
229 pub fn load(_model_path: &Path) -> anyhow::Result<Self> {
230 anyhow::bail!("Neural feature not enabled. Compile with --features neural")
231 }
232
233 pub fn score_line(&self, line: &str, position: f64, task_keywords: &[String]) -> f64 {
234 let context = LineContext::default();
235 let features = LineFeatures::from_line(line, position, &context);
236 self.score_from_features(&features, task_keywords)
237 }
238
239 pub fn score_from_features(&self, features: &LineFeatures, _task_keywords: &[String]) -> f64 {
240 #[cfg(feature = "neural")]
241 {
242 self.neural_score(features)
243 }
244 #[cfg(not(feature = "neural"))]
245 {
246 self.decision_tree_score(features)
247 }
248 }
249
250 #[cfg(feature = "neural")]
251 fn neural_score(&self, features: &LineFeatures) -> f64 {
252 use rten_tensor::{AsView, NdTensor};
253
254 let input_data = features.to_array();
255 let float_data: Vec<f32> = input_data.iter().map(|&x| x as f32).collect();
256 let input = NdTensor::from_data([1, 13], float_data);
257
258 match self.model.run_one(input.into(), None) {
259 Ok(output) => {
260 let tensor: Vec<f32> = output
261 .into_tensor::<f32>()
262 .map(|t| t.to_vec())
263 .unwrap_or_default();
264 tensor.first().copied().unwrap_or(0.5) as f64
265 }
266 Err(_) => 0.5,
267 }
268 }
269
270 #[cfg(not(feature = "neural"))]
271 #[allow(clippy::unused_self)]
272 fn decision_tree_score(&self, features: &LineFeatures) -> f64 {
273 let f = features.to_array();
274
275 let mut score = 0.5;
276
277 if f[3] > 0.5 {
278 score += 0.3; }
280 if f[5] > 0.5 {
281 score -= 0.2; }
283 if f[6] > 0.5 {
284 score -= 0.3; }
286 if f[4] > 0.5 {
287 score -= 0.1; }
289 if f[9] > 0.5 {
290 score += 0.15; }
292
293 let pos = f[8];
294 let u_curve = if pos <= 0.5 {
295 1.0 - 0.6 * (2.0 * pos).powi(2)
296 } else {
297 1.0 - 0.6 * (2.0 * (1.0 - pos)).powi(2)
298 };
299 score *= u_curve;
300
301 score.clamp(0.0, 1.0)
302 }
303}
304
305pub fn score_all_lines(
306 lines: &[&str],
307 scorer: &NeuralLineScorer,
308 task_keywords: &[String],
309) -> Vec<f64> {
310 let n = lines.len();
311 let mut nesting_depth: usize = 0;
312
313 lines
314 .iter()
315 .enumerate()
316 .map(|(i, line)| {
317 let trimmed = line.trim();
318 nesting_depth = nesting_depth
319 .saturating_add(trimmed.matches('{').count())
320 .saturating_sub(trimmed.matches('}').count());
321
322 let prev_type = if i > 0 {
323 classify_type(lines[i - 1].trim())
324 } else {
325 0
326 };
327 let next_type = if i + 1 < n {
328 classify_type(lines[i + 1].trim())
329 } else {
330 0
331 };
332 let position = i as f64 / (n.max(1) - 1).max(1) as f64;
333
334 let context = LineContext {
335 nesting_depth,
336 prev_line_type: prev_type,
337 next_line_type: next_type,
338 };
339 let features = LineFeatures::from_line(line, position, &context);
340 scorer.score_from_features(&features, task_keywords)
341 })
342 .collect()
343}
344
345fn classify_type(line: &str) -> u8 {
346 if line.is_empty() {
347 return 0;
348 }
349 if LineFeatures::check_definition(line) {
350 return 1;
351 }
352 if LineFeatures::check_import(line) {
353 return 2;
354 }
355 if LineFeatures::check_comment(line) {
356 return 3;
357 }
358 if LineFeatures::check_closing(line) {
359 return 5;
360 }
361 4 }