1use std::path::Path;
10
11#[cfg(feature = "neural")]
12use rten::Model;
13
/// Scores source lines for relevance to a task.
///
/// With `--features neural` this wraps a loaded rten `Model`; without the
/// feature the struct is an empty placeholder and `load` always fails (see
/// the `#[cfg(not(feature = "neural"))]` `load` below), so the
/// decision-tree fallback path is only reachable if a value is constructed
/// some other way.
pub struct NeuralLineScorer {
    /// Loaded inference model (present only when `neural` is enabled).
    #[cfg(feature = "neural")]
    model: Model,
    /// Zero-sized placeholder so the struct has a field in the
    /// non-neural configuration.
    #[cfg(not(feature = "neural"))]
    _phantom: (),
}
20
/// Per-line feature vector fed to the scorer.
///
/// All fields are `f64` so the struct flattens into a fixed-order
/// `[f64; 13]` via `to_array`. Boolean properties are encoded as 0.0/1.0.
#[derive(Debug, Clone)]
pub struct LineFeatures {
    /// Byte length of the trimmed line.
    pub line_length: f64,
    /// Leading-whitespace byte count divided by 4 (assumes 4-space indents).
    pub indentation_level: f64,
    /// Unique tokens / total tokens (0.0 for an empty line).
    pub token_diversity: f64,
    /// 1.0 if the line starts a fn/class/struct/etc. definition.
    pub is_definition: f64,
    /// 1.0 if the line is an import/use/include statement.
    pub is_import: f64,
    /// 1.0 if the line looks like a comment.
    pub is_comment: f64,
    /// 1.0 if the line is a bare closing delimiter (`}`, `);`, `end`, …).
    pub is_closing: f64,
    /// Keyword tokens / total tokens.
    pub keyword_density: f64,
    /// Line position in the file, normalized to [0, 1].
    pub position_normalized: f64,
    /// 1.0 if the line contains a type-annotation marker (`->`, `: `, …).
    pub has_type_annotation: f64,
    /// Brace-nesting depth at this line (from `LineContext`).
    pub nesting_depth: f64,
    /// Type code of the previous line (see `classify_type`; 0 when first line).
    pub prev_line_type: f64,
    /// Type code of the next line (see `classify_type`; 0 when last line).
    pub next_line_type: f64,
}
37
38impl LineFeatures {
39 pub fn from_line(line: &str, position: f64, context: &LineContext) -> Self {
40 let trimmed = line.trim();
41 let leading = (line.len() - line.trim_start().len()) as f64;
42
43 Self {
44 line_length: trimmed.len() as f64,
45 indentation_level: leading / 4.0,
46 token_diversity: Self::compute_token_diversity(trimmed),
47 is_definition: if Self::check_definition(trimmed) {
48 1.0
49 } else {
50 0.0
51 },
52 is_import: if Self::check_import(trimmed) {
53 1.0
54 } else {
55 0.0
56 },
57 is_comment: if Self::check_comment(trimmed) {
58 1.0
59 } else {
60 0.0
61 },
62 is_closing: if Self::check_closing(trimmed) {
63 1.0
64 } else {
65 0.0
66 },
67 keyword_density: Self::compute_keyword_density(trimmed),
68 position_normalized: position,
69 has_type_annotation: if Self::check_type_annotation(trimmed) {
70 1.0
71 } else {
72 0.0
73 },
74 nesting_depth: context.nesting_depth as f64,
75 prev_line_type: context.prev_line_type as f64,
76 next_line_type: context.next_line_type as f64,
77 }
78 }
79
80 pub fn to_array(&self) -> [f64; 13] {
81 [
82 self.line_length,
83 self.indentation_level,
84 self.token_diversity,
85 self.is_definition,
86 self.is_import,
87 self.is_comment,
88 self.is_closing,
89 self.keyword_density,
90 self.position_normalized,
91 self.has_type_annotation,
92 self.nesting_depth,
93 self.prev_line_type,
94 self.next_line_type,
95 ]
96 }
97
98 fn compute_token_diversity(line: &str) -> f64 {
99 let tokens: Vec<&str> = line.split_whitespace().collect();
100 if tokens.is_empty() {
101 return 0.0;
102 }
103 let unique: std::collections::HashSet<&str> = tokens.iter().copied().collect();
104 unique.len() as f64 / tokens.len() as f64
105 }
106
107 fn check_definition(line: &str) -> bool {
108 const STARTERS: &[&str] = &[
109 "fn ",
110 "pub fn ",
111 "async fn ",
112 "pub async fn ",
113 "def ",
114 "async def ",
115 "function ",
116 "export function ",
117 "async function ",
118 "class ",
119 "export class ",
120 "struct ",
121 "pub struct ",
122 "enum ",
123 "pub enum ",
124 "trait ",
125 "pub trait ",
126 "impl ",
127 "type ",
128 "pub type ",
129 "interface ",
130 "export interface ",
131 ];
132 STARTERS.iter().any(|s| line.starts_with(s))
133 }
134
135 fn check_import(line: &str) -> bool {
136 line.starts_with("import ")
137 || line.starts_with("use ")
138 || line.starts_with("from ")
139 || line.starts_with("#include")
140 || line.starts_with("require(")
141 }
142
143 fn check_comment(line: &str) -> bool {
144 line.starts_with("//")
145 || line.starts_with('#')
146 || line.starts_with("/*")
147 || line.starts_with('*')
148 || line.starts_with("///")
149 }
150
151 fn check_closing(line: &str) -> bool {
152 matches!(line, "}" | "};" | "})" | "]" | ");" | "end")
153 }
154
155 fn check_type_annotation(line: &str) -> bool {
156 line.contains("->")
157 || line.contains("=>")
158 || line.contains(": ")
159 || line.contains("Result<")
160 || line.contains("Option<")
161 }
162
163 fn compute_keyword_density(line: &str) -> f64 {
164 const KEYWORDS: &[&str] = &[
165 "fn",
166 "let",
167 "mut",
168 "pub",
169 "use",
170 "impl",
171 "struct",
172 "enum",
173 "match",
174 "if",
175 "else",
176 "for",
177 "while",
178 "return",
179 "async",
180 "await",
181 "trait",
182 "where",
183 "def",
184 "class",
185 "import",
186 "from",
187 "function",
188 "export",
189 "const",
190 "var",
191 "type",
192 "interface",
193 "try",
194 "catch",
195 "throw",
196 "yield",
197 "raise",
198 ];
199 let tokens: Vec<&str> = line.split_whitespace().collect();
200 if tokens.is_empty() {
201 return 0.0;
202 }
203 let hits = tokens
204 .iter()
205 .filter(|t| {
206 let clean = t.trim_end_matches(|c: char| !c.is_alphanumeric());
207 KEYWORDS.contains(&clean)
208 })
209 .count();
210 hits as f64 / tokens.len() as f64
211 }
212}
213
/// Contextual signals about a line's surroundings, supplied by
/// `score_all_lines` (or defaulted to zeros by `score_line`).
#[derive(Debug, Clone, Default)]
pub struct LineContext {
    /// Brace-nesting depth at this line.
    pub nesting_depth: usize,
    /// Type code of the previous line, per `classify_type`:
    /// 0 = empty/absent, 1 = definition, 2 = import, 3 = comment,
    /// 4 = other code, 5 = closing delimiter.
    pub prev_line_type: u8,
    /// Type code of the next line (same encoding; 0 when absent).
    pub next_line_type: u8,
}
220
impl NeuralLineScorer {
    /// Loads an rten model from `model_path`.
    ///
    /// # Errors
    /// Propagates any failure from `Model::load_file`.
    #[cfg(feature = "neural")]
    pub fn load(model_path: &Path) -> anyhow::Result<Self> {
        let model = Model::load_file(model_path)?;
        Ok(Self { model })
    }

    /// Stub used when built without the `neural` feature: always fails.
    /// NOTE(review): because this is the only constructor, the
    /// `decision_tree_score` fallback below appears unreachable in the
    /// non-neural build — confirm whether `load` should instead return
    /// `Ok(Self { _phantom: () })`.
    #[cfg(not(feature = "neural"))]
    pub fn load(_model_path: &Path) -> anyhow::Result<Self> {
        anyhow::bail!("Neural feature not enabled. Compile with --features neural")
    }

    /// Scores a single line using a default (all-zero) context.
    /// Returns a relevance score; both backends produce values near [0, 1].
    pub fn score_line(&self, line: &str, position: f64, task_keywords: &[String]) -> f64 {
        let context = LineContext::default();
        let features = LineFeatures::from_line(line, position, &context);
        self.score_from_features(&features, task_keywords)
    }

    /// Dispatches to the neural model or the hand-written decision-tree
    /// fallback depending on the `neural` feature.
    /// `_task_keywords` is currently unused by either backend.
    pub fn score_from_features(&self, features: &LineFeatures, _task_keywords: &[String]) -> f64 {
        #[cfg(feature = "neural")]
        {
            self.neural_score(features)
        }
        #[cfg(not(feature = "neural"))]
        {
            self.decision_tree_score(features)
        }
    }

    /// Runs the model on the 13-element feature vector (as f32, shape
    /// [1, 13]) and returns the first output value. Any inference failure
    /// or empty output yields a neutral 0.5 rather than an error.
    #[cfg(feature = "neural")]
    fn neural_score(&self, features: &LineFeatures) -> f64 {
        use rten_tensor::{AsView, NdTensor};

        let input_data = features.to_array();
        // The model consumes f32; features are computed as f64.
        let float_data: Vec<f32> = input_data.iter().map(|&x| x as f32).collect();
        let input = NdTensor::from_data([1, 13], float_data);

        match self.model.run_one(input.into(), None) {
            Ok(output) => {
                let tensor: Vec<f32> = output
                    .into_tensor::<f32>()
                    .map(|t| t.to_vec())
                    .unwrap_or_default();
                // Empty output falls back to the neutral score.
                tensor.first().copied().unwrap_or(0.5) as f64
            }
            // Inference errors are deliberately swallowed: score neutrally
            // instead of failing the whole scan.
            Err(_) => 0.5,
        }
    }

    /// Heuristic fallback: start from a neutral 0.5, apply fixed
    /// bumps/penalties for binary features, then damp by a U-shaped curve
    /// over file position (score is highest at the file's ends, scaled by
    /// 0.4 at the middle). Indices follow `LineFeatures::to_array` order.
    #[cfg(not(feature = "neural"))]
    fn decision_tree_score(&self, features: &LineFeatures) -> f64 {
        let f = features.to_array();

        let mut score = 0.5;

        // f[3] = is_definition: strongest positive signal.
        if f[3] > 0.5 {
            score += 0.3;
        }
        // f[5] = is_comment
        if f[5] > 0.5 {
            score -= 0.2;
        }
        // f[6] = is_closing: bare delimiters carry little information.
        if f[6] > 0.5 {
            score -= 0.3;
        }
        // f[4] = is_import
        if f[4] > 0.5 {
            score -= 0.1;
        }
        // f[9] = has_type_annotation
        if f[9] > 0.5 {
            score += 0.15;
        }

        // f[8] = position_normalized; curve is 1.0 at pos 0 or 1,
        // 0.4 at pos 0.5, symmetric about the middle.
        let pos = f[8];
        let u_curve = if pos <= 0.5 {
            1.0 - 0.6 * (2.0 * pos).powi(2)
        } else {
            1.0 - 0.6 * (2.0 * (1.0 - pos)).powi(2)
        };
        score *= u_curve;

        score.clamp(0.0, 1.0)
    }
}
303
304pub fn score_all_lines(
305 lines: &[&str],
306 scorer: &NeuralLineScorer,
307 task_keywords: &[String],
308) -> Vec<f64> {
309 let n = lines.len();
310 let mut nesting_depth: usize = 0;
311
312 lines
313 .iter()
314 .enumerate()
315 .map(|(i, line)| {
316 let trimmed = line.trim();
317 nesting_depth = nesting_depth
318 .saturating_add(trimmed.matches('{').count())
319 .saturating_sub(trimmed.matches('}').count());
320
321 let prev_type = if i > 0 {
322 classify_type(lines[i - 1].trim())
323 } else {
324 0
325 };
326 let next_type = if i + 1 < n {
327 classify_type(lines[i + 1].trim())
328 } else {
329 0
330 };
331 let position = i as f64 / (n.max(1) - 1).max(1) as f64;
332
333 let context = LineContext {
334 nesting_depth,
335 prev_line_type: prev_type,
336 next_line_type: next_type,
337 };
338 let features = LineFeatures::from_line(line, position, &context);
339 scorer.score_from_features(&features, task_keywords)
340 })
341 .collect()
342}
343
344fn classify_type(line: &str) -> u8 {
345 if line.is_empty() {
346 return 0;
347 }
348 if LineFeatures::check_definition(line) {
349 return 1;
350 }
351 if LineFeatures::check_import(line) {
352 return 2;
353 }
354 if LineFeatures::check_comment(line) {
355 return 3;
356 }
357 if LineFeatures::check_closing(line) {
358 return 5;
359 }
360 4 }