lean_ctx/core/neural/
context_reorder.rs1use super::attention_learned::LearnedAttention;
2
/// Coarse classification of a single source line.
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so
/// `ErrorHandling` compares lowest and `Empty` compares highest.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum LineCategory {
    /// Error propagation or raising (e.g. `return Err(..)`, `raise`, `throw`).
    ErrorHandling,
    /// Import / include statement (e.g. `use`, `import`, `#include`).
    Import,
    /// Start of a type declaration (e.g. `struct`, `enum`, `class`).
    TypeDefinition,
    /// Start of a function declaration (e.g. `fn`, `def`, `function`).
    FunctionSignature,
    /// Any non-blank line that matches none of the other categories.
    Logic,
    /// A line consisting solely of closing delimiters such as `}` or `});`.
    ClosingBrace,
    /// A line that is empty after trimming whitespace.
    Empty,
}
13
14pub struct CategorizedLine<'a> {
15 pub line: &'a str,
16 pub category: LineCategory,
17 pub original_index: usize,
18}
19
20pub fn categorize_line(line: &str) -> LineCategory {
21 let trimmed = line.trim();
22 if trimmed.is_empty() {
23 return LineCategory::Empty;
24 }
25
26 if is_error_handling(trimmed) {
27 return LineCategory::ErrorHandling;
28 }
29
30 if is_import(trimmed) {
31 return LineCategory::Import;
32 }
33
34 if is_type_def(trimmed) {
35 return LineCategory::TypeDefinition;
36 }
37
38 if is_fn_signature(trimmed) {
39 return LineCategory::FunctionSignature;
40 }
41
42 if is_closing(trimmed) {
43 return LineCategory::ClosingBrace;
44 }
45
46 LineCategory::Logic
47}
48
49pub fn reorder_for_lcurve(content: &str, task_keywords: &[String]) -> String {
50 let lines: Vec<&str> = content.lines().collect();
51 if lines.len() <= 5 {
52 return content.to_string();
53 }
54
55 let categorized: Vec<CategorizedLine> = lines
56 .iter()
57 .enumerate()
58 .map(|(i, line)| CategorizedLine {
59 line,
60 category: categorize_line(line),
61 original_index: i,
62 })
63 .collect();
64
65 let attention = LearnedAttention::with_defaults();
66 let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
67
68 let mut scored: Vec<(&CategorizedLine, f64)> = categorized
69 .iter()
70 .map(|cl| {
71 let base = category_priority(cl.category);
72 let kw_boost: f64 = if !kw_lower.is_empty() {
73 let line_lower = cl.line.to_lowercase();
74 kw_lower
75 .iter()
76 .filter(|kw| line_lower.contains(kw.as_str()))
77 .count() as f64
78 * 0.5
79 } else {
80 0.0
81 };
82 let n = lines.len().max(1) as f64;
83 let orig_pos = cl.original_index as f64 / n;
84 let orig_attention = attention.weight(orig_pos);
85 let score = base + kw_boost + orig_attention * 0.1;
86 (cl, score)
87 })
88 .collect();
89
90 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
91
92 scored
93 .iter()
94 .filter(|(cl, _)| cl.category != LineCategory::Empty || cl.original_index == 0)
95 .map(|(cl, _)| cl.line)
96 .collect::<Vec<_>>()
97 .join("\n")
98}
99
100fn category_priority(cat: LineCategory) -> f64 {
101 match cat {
102 LineCategory::ErrorHandling => 5.0,
103 LineCategory::Import => 4.0,
104 LineCategory::TypeDefinition => 3.5,
105 LineCategory::FunctionSignature => 3.0,
106 LineCategory::Logic => 1.0,
107 LineCategory::ClosingBrace => 0.2,
108 LineCategory::Empty => 0.1,
109 }
110}
111
/// Heuristically reports whether `line` is an error-handling statement.
///
/// `line` is expected to be already trimmed. A line matches when it starts
/// with one of the known error-raising or error-catching prefixes, is a bare
/// `raise`, or contains `.map_err(` / `Error::` anywhere.
fn is_error_handling(line: &str) -> bool {
    const PREFIXES: &[&str] = &[
        "return Err(",
        "Err(",
        "bail!(",
        "raise ",
        "throw ",
        "catch ",
        // `catch` written without a space before the parenthesis.
        "catch(",
        // `catch` on the same line as the closing brace of the `try` body.
        "} catch",
        "except ",
        // Bare `except:` clause with no exception type.
        "except:",
        "panic!(",
    ];
    const FRAGMENTS: &[&str] = &[".map_err(", "Error::"];

    // A bare re-raise has no trailing space once trimmed, so the `"raise "`
    // prefix cannot match it.
    line == "raise"
        || PREFIXES.iter().any(|&prefix| line.starts_with(prefix))
        || FRAGMENTS.iter().any(|&fragment| line.contains(fragment))
}
124
/// Heuristically reports whether `line` is an import / include statement.
///
/// `line` is expected to be already trimmed. Besides the plain prefixes, a
/// `const ...` line that contains a `require(` call also counts.
fn is_import(line: &str) -> bool {
    const PREFIXES: [&str; 5] = ["use ", "import ", "from ", "#include", "require("];

    if PREFIXES.iter().any(|&prefix| line.starts_with(prefix)) {
        return true;
    }

    // e.g. `const fs = require('fs');`
    line.starts_with("const ") && line.contains("require(")
}
133
/// Heuristically reports whether `line` starts a type definition.
///
/// `line` is expected to be already trimmed. A line matches when it starts
/// with one of the type-declaring keywords, unless the keyword is immediately
/// followed by an assignment-like operator. That exclusion keeps ordinary
/// statements such as `data = load()`, `type = "x"`, `data := f()` or
/// `data => data.id`, where the "keyword" is really a variable name, from
/// being classified as type definitions.
fn is_type_def(line: &str) -> bool {
    /// True when `rest` begins with `=`, `==`, `=>`, `:=` or a compound
    /// assignment such as `+=` / `<<=`.
    fn begins_with_assignment(rest: &str) -> bool {
        rest.trim_start()
            .trim_start_matches(|c: char| "+-*/%&|^<>!:".contains(c))
            .starts_with('=')
    }

    const STARTERS: [&str; 14] = [
        "struct ",
        "pub struct ",
        "enum ",
        "pub enum ",
        "trait ",
        "pub trait ",
        "type ",
        "pub type ",
        "interface ",
        "export interface ",
        "class ",
        "export class ",
        "typedef ",
        "data ",
    ];

    STARTERS.iter().any(|&starter| match line.strip_prefix(starter) {
        Some(rest) => !begins_with_assignment(rest),
        None => false,
    })
}
153
/// Heuristically reports whether `line` starts a function declaration.
///
/// `line` is expected to be already trimmed; it matches when it begins with
/// one of the known function-declaring keyword sequences.
fn is_fn_signature(line: &str) -> bool {
    const PREFIXES: &[&str] = &[
        "fn ",
        "pub fn ",
        "async fn ",
        "pub async fn ",
        "function ",
        "export function ",
        "async function ",
        "def ",
        "async def ",
        "func ",
        "pub(crate) fn ",
        "pub(super) fn ",
    ];

    for &prefix in PREFIXES {
        if line.starts_with(prefix) {
            return true;
        }
    }
    false
}
171
/// Reports whether `line` consists solely of closing delimiters.
///
/// `line` is expected to be already trimmed; only exact matches count.
fn is_closing(line: &str) -> bool {
    const CLOSERS: [&str; 6] = ["}", "};", ");", "});", ")", "})"];
    CLOSERS.iter().any(|&closer| closer == line)
}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// One representative line per category.
    #[test]
    fn categorize_lines_correctly() {
        assert_eq!(categorize_line("use std::io;"), LineCategory::Import);
        assert_eq!(
            categorize_line("pub struct Foo {"),
            LineCategory::TypeDefinition
        );
        assert_eq!(
            categorize_line("fn main() {"),
            LineCategory::FunctionSignature
        );
        assert_eq!(
            categorize_line("return Err(e);"),
            LineCategory::ErrorHandling
        );
        assert_eq!(categorize_line("}"), LineCategory::ClosingBrace);
        assert_eq!(categorize_line("let x = 1;"), LineCategory::Logic);
        assert_eq!(categorize_line(""), LineCategory::Empty);
    }

    /// Six lines, so the reordering path (more than five lines) is exercised.
    #[test]
    fn reorder_puts_errors_and_imports_first() {
        let content = "let x = 1;\nuse std::io;\n}\nreturn Err(e);\npub struct Foo {\nfn main() {";
        let result = reorder_for_lcurve(content, &[]);
        let lines: Vec<&str> = result.lines().collect();
        assert!(
            lines[0].contains("Err") || lines[0].contains("use "),
            "first line should be error handling or import, got: {}",
            lines[0]
        );
    }

    #[test]
    fn task_keywords_boost_relevant_lines() {
        let content = "fn unrelated() {\nlet x = 1;\n}\nfn validate_token() {\nlet y = 2;\n}";
        let result = reorder_for_lcurve(content, &["validate".to_string()]);
        let lines: Vec<&str> = result.lines().collect();
        // Non-blank lines are never dropped, so both must be present; fail
        // loudly instead of silently skipping the assertion if they are not.
        let validate_pos = lines
            .iter()
            .position(|l| l.contains("validate"))
            .expect("validate line should be present in the output");
        let unrelated_pos = lines
            .iter()
            .position(|l| l.contains("unrelated"))
            .expect("unrelated line should be present in the output");
        assert!(
            validate_pos < unrelated_pos,
            "validate should appear before unrelated"
        );
    }

    /// Five or fewer lines are returned verbatim, trailing newline included.
    #[test]
    fn short_content_is_returned_unchanged() {
        let content = "let x = 1;\n\nuse std::io;\n";
        assert_eq!(reorder_for_lcurve(content, &[]), content);
    }

    /// Blank lines that are not the first line are removed when reordering.
    #[test]
    fn blank_lines_are_dropped_when_reordering() {
        let content = "use a;\n\nfn f() {\nlet x = 1;\n}\n\nreturn Err(e);";
        let result = reorder_for_lcurve(content, &[]);
        assert_eq!(result.lines().count(), 5);
        assert!(result.lines().all(|l| !l.trim().is_empty()));
    }
}