lean_ctx/core/neural/
context_reorder.rs1use super::attention_learned::LearnedAttention;
2
/// Coarse classification of a single source line, used to rank lines when
/// reordering context.
///
/// Variants are declared from highest to lowest ranking priority (the same
/// order as the weights in `category_priority`); the derived `PartialOrd` /
/// `Ord` follow this declaration order.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum LineCategory {
    /// Lines that construct, propagate, or catch errors (`Err(`, `raise`, `catch`, ...).
    ErrorHandling,
    /// Import-like lines (`use`, `import`, `from`, `#include`, `require(`, ...).
    Import,
    /// Type declarations (`struct`, `enum`, `trait`, `class`, `interface`, ...).
    TypeDefinition,
    /// Function headers (`fn`, `def`, `function`, `func`, ...).
    FunctionSignature,
    /// Any non-empty line not matched by a more specific category.
    Logic,
    /// Lines consisting of closing delimiters such as `}` or `});`.
    ClosingBrace,
    /// Blank or whitespace-only lines.
    Empty,
}
13
14pub struct CategorizedLine<'a> {
15 pub line: &'a str,
16 pub category: LineCategory,
17 pub original_index: usize,
18}
19
20pub fn categorize_line(line: &str) -> LineCategory {
21 let trimmed = line.trim();
22 if trimmed.is_empty() {
23 return LineCategory::Empty;
24 }
25
26 if is_error_handling(trimmed) {
27 return LineCategory::ErrorHandling;
28 }
29
30 if is_import(trimmed) {
31 return LineCategory::Import;
32 }
33
34 if is_type_def(trimmed) {
35 return LineCategory::TypeDefinition;
36 }
37
38 if is_fn_signature(trimmed) {
39 return LineCategory::FunctionSignature;
40 }
41
42 if is_closing(trimmed) {
43 return LineCategory::ClosingBrace;
44 }
45
46 LineCategory::Logic
47}
48
49pub fn reorder_for_lcurve(content: &str, task_keywords: &[String]) -> String {
50 let lines: Vec<&str> = content.lines().collect();
51 if lines.len() <= 5 {
52 return content.to_string();
53 }
54
55 if lines.len() >= 15 {
57 let chunks = crate::core::semantic_chunks::detect_chunks(content);
58 if chunks.len() >= 3 {
59 let ordered = crate::core::semantic_chunks::order_for_attention(chunks, task_keywords);
60 return crate::core::semantic_chunks::render_with_bridges(&ordered);
61 }
62 }
63
64 let categorized: Vec<CategorizedLine> = lines
66 .iter()
67 .enumerate()
68 .map(|(i, line)| CategorizedLine {
69 line,
70 category: categorize_line(line),
71 original_index: i,
72 })
73 .collect();
74
75 let attention = LearnedAttention::with_defaults();
76 let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
77
78 let mut scored: Vec<(&CategorizedLine, f64)> = categorized
79 .iter()
80 .map(|cl| {
81 let base = category_priority(cl.category);
82 let kw_boost: f64 = if kw_lower.is_empty() {
83 0.0
84 } else {
85 let line_lower = cl.line.to_lowercase();
86 kw_lower
87 .iter()
88 .filter(|kw| line_lower.contains(kw.as_str()))
89 .count() as f64
90 * 0.5
91 };
92 let n = lines.len().max(1) as f64;
93 let orig_pos = cl.original_index as f64 / n;
94 let orig_attention = attention.weight(orig_pos);
95 let score = base + kw_boost + orig_attention * 0.1;
96 (cl, score)
97 })
98 .collect();
99
100 scored.sort_by(|a, b| {
101 b.1.partial_cmp(&a.1)
102 .unwrap_or(std::cmp::Ordering::Equal)
103 .then_with(|| a.0.original_index.cmp(&b.0.original_index))
104 });
105
106 scored
107 .iter()
108 .filter(|(cl, _)| cl.category != LineCategory::Empty || cl.original_index == 0)
109 .map(|(cl, _)| cl.line)
110 .collect::<Vec<_>>()
111 .join("\n")
112}
113
114fn category_priority(cat: LineCategory) -> f64 {
115 match cat {
116 LineCategory::ErrorHandling => 5.0,
117 LineCategory::Import => 4.0,
118 LineCategory::TypeDefinition => 3.5,
119 LineCategory::FunctionSignature => 3.0,
120 LineCategory::Logic => 1.0,
121 LineCategory::ClosingBrace => 0.2,
122 LineCategory::Empty => 0.1,
123 }
124}
125
/// Returns `true` if the (already trimmed) `line` looks like error-handling code.
///
/// Recognised forms:
/// - Rust error construction and propagation: `return Err(`, `Err(`, `bail!(`,
///   `panic!(`, `.map_err(`, `Error::`.
/// - Raise/throw statements, including bare re-raises (`raise`, `throw;`).
/// - Catch clauses, including the `} catch (e) {` form.
/// - Python `except` clauses with a space, a colon (bare `except:`), or a
///   parenthesised exception list.
fn is_error_handling(line: &str) -> bool {
    const PREFIXES: &[&str] = &[
        "return Err(",
        "Err(",
        "bail!(",
        "panic!(",
        "raise ",
        "throw ",
        "catch ",
        "catch(",
        // `line` is trimmed, so a catch that follows the try block's closing
        // brace on the same line starts with `}`.
        "} catch ",
        "} catch(",
        "except ",
        "except:",
        "except(",
    ];
    const FRAGMENTS: &[&str] = &[".map_err(", "Error::"];
    // Bare re-raise / re-throw statements have no trailing space.
    const EXACT: &[&str] = &["raise", "throw;"];

    PREFIXES.iter().any(|p| line.starts_with(*p))
        || FRAGMENTS.iter().any(|f| line.contains(*f))
        || EXACT.contains(&line)
}
138
/// Returns `true` if the (already trimmed) `line` looks like an import.
///
/// Recognised forms:
/// - Rust `use` items, including `pub` / `pub(crate)` / `pub(super)` re-exports.
/// - `extern crate`.
/// - `import` / `from` statements and C `#include`.
/// - CommonJS `require(` calls, either bare or bound with `const` / `var`.
fn is_import(line: &str) -> bool {
    const PREFIXES: &[&str] = &[
        "use ",
        "pub use ",
        "pub(crate) use ",
        "pub(super) use ",
        "extern crate ",
        "import ",
        "from ",
        "#include",
        "require(",
    ];
    // Declarations that count as imports only when they contain a `require(`.
    const REQUIRE_BINDINGS: &[&str] = &["const ", "var "];

    PREFIXES.iter().any(|p| line.starts_with(*p))
        || (line.contains("require(") && REQUIRE_BINDINGS.iter().any(|b| line.starts_with(*b)))
}
147
/// Returns `true` if the (already trimmed) `line` opens a type declaration.
///
/// Recognised forms:
/// - Rust `struct` / `enum` / `trait` / `type`, either private or with any of
///   the `pub`, `pub(crate)`, `pub(super)` visibilities. These are the same
///   visibilities accepted for function signatures.
/// - `interface` / `class` / `type` / `enum`, with or without `export`.
/// - Lines starting with `typedef ` or `data `.
fn is_type_def(line: &str) -> bool {
    const RUST_VISIBILITY: &[&str] = &["", "pub ", "pub(crate) ", "pub(super) "];
    const RUST_KEYWORDS: &[&str] = &["struct ", "enum ", "trait ", "type "];
    const TS_EXPORT: &[&str] = &["", "export "];
    const TS_KEYWORDS: &[&str] = &["interface ", "class ", "type ", "enum "];
    const OTHER_PREFIXES: &[&str] = &["typedef ", "data "];

    // `true` if `line` is `<one of modifiers><one of keywords>...`.
    fn modifier_then_keyword(line: &str, modifiers: &[&str], keywords: &[&str]) -> bool {
        modifiers
            .iter()
            .filter_map(|m| line.strip_prefix(*m))
            .any(|rest| keywords.iter().any(|k| rest.starts_with(*k)))
    }

    modifier_then_keyword(line, RUST_VISIBILITY, RUST_KEYWORDS)
        || modifier_then_keyword(line, TS_EXPORT, TS_KEYWORDS)
        || OTHER_PREFIXES.iter().any(|p| line.starts_with(*p))
}
167
/// Returns `true` if the (already trimmed) `line` opens a function definition.
///
/// Recognised forms:
/// - Rust `fn`, optionally preceded by a visibility (`pub`, `pub(crate)`,
///   `pub(super)`) and a qualifier (`async`, `const`, `unsafe`,
///   `const unsafe`, `async unsafe`).
/// - JavaScript/TypeScript `function`, optionally preceded by `export` or
///   `export default` and by `async`.
/// - Python `def` / `async def` and Go `func`.
fn is_fn_signature(line: &str) -> bool {
    const RUST_VISIBILITY: &[&str] = &["", "pub ", "pub(crate) ", "pub(super) "];
    const RUST_QUALIFIERS: &[&str] = &[
        "",
        "async ",
        "const ",
        "unsafe ",
        "const unsafe ",
        "async unsafe ",
    ];
    const JS_EXPORT: &[&str] = &["", "export ", "export default "];
    const JS_QUALIFIERS: &[&str] = &["", "async "];
    const OTHER_PREFIXES: &[&str] = &["def ", "async def ", "func "];

    // `true` if `line` is `<one of outer><one of inner><keyword>...`.
    fn has_header(line: &str, outer: &[&str], inner: &[&str], keyword: &str) -> bool {
        outer
            .iter()
            .filter_map(|o| line.strip_prefix(*o))
            .any(|rest| {
                inner
                    .iter()
                    .filter_map(|i| rest.strip_prefix(*i))
                    .any(|tail| tail.starts_with(keyword))
            })
    }

    has_header(line, RUST_VISIBILITY, RUST_QUALIFIERS, "fn ")
        || has_header(line, JS_EXPORT, JS_QUALIFIERS, "function ")
        || OTHER_PREFIXES.iter().any(|p| line.starts_with(*p))
}
185
/// Returns `true` if the (already trimmed) `line` only closes a block.
///
/// A closing line starts with a closing delimiter (`}`, `)` or `]`) and
/// contains nothing but closing delimiters and `;` / `,`. This covers `}`,
/// `});`, `},` and `]);`. The keyword-terminated forms `end` and `end;` also
/// count.
fn is_closing(line: &str) -> bool {
    let starts_with_closer = matches!(line.chars().next(), Some('}' | ')' | ']'));
    let only_closers = line
        .chars()
        .all(|c| matches!(c, '}' | ')' | ']' | ';' | ','));

    (starts_with_closer && only_closers) || matches!(line, "end" | "end;")
}
189
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn categorize_lines_correctly() {
        assert_eq!(categorize_line("use std::io;"), LineCategory::Import);
        assert_eq!(
            categorize_line("pub struct Foo {"),
            LineCategory::TypeDefinition
        );
        assert_eq!(
            categorize_line("fn main() {"),
            LineCategory::FunctionSignature
        );
        assert_eq!(
            categorize_line("return Err(e);"),
            LineCategory::ErrorHandling
        );
        assert_eq!(categorize_line("}"), LineCategory::ClosingBrace);
        assert_eq!(categorize_line("let x = 1;"), LineCategory::Logic);
        assert_eq!(categorize_line(""), LineCategory::Empty);
    }

    #[test]
    fn reorder_puts_errors_and_imports_first() {
        let content = "let x = 1;\nuse std::io;\n}\nreturn Err(e);\npub struct Foo {\nfn main() {";
        let result = reorder_for_lcurve(content, &[]);
        let lines: Vec<&str> = result.lines().collect();
        assert!(
            lines[0].contains("Err") || lines[0].contains("use "),
            "first line should be error handling or import, got: {}",
            lines[0]
        );
    }

    #[test]
    fn task_keywords_boost_relevant_lines() {
        let content = "fn unrelated() {\nlet x = 1;\n}\nfn validate_token() {\nlet y = 2;\n}";
        let result = reorder_for_lcurve(content, &["validate".to_string()]);
        let lines: Vec<&str> = result.lines().collect();
        // Neither line is blank, so both must survive the reorder. A missing
        // line is a failure, not a reason to skip the ordering check.
        let validate_pos = lines
            .iter()
            .position(|l| l.contains("validate"))
            .expect("validate line missing from output");
        let unrelated_pos = lines
            .iter()
            .position(|l| l.contains("unrelated"))
            .expect("unrelated line missing from output");
        assert!(
            validate_pos < unrelated_pos,
            "validate should appear before unrelated"
        );
    }

    #[test]
    fn short_content_is_returned_unchanged() {
        let content = "use std::io;\n\nfn main() {\n}\n";
        assert_eq!(reorder_for_lcurve(content, &[]), content);
    }

    #[test]
    fn blank_lines_are_dropped_from_reordered_output() {
        let content = "fn a() {\n\nlet x = 1;\n\n}\nlet y = 2;";
        let result = reorder_for_lcurve(content, &[]);
        assert!(
            result.lines().all(|l| !l.trim().is_empty()),
            "unexpected blank line in: {result:?}"
        );
        assert_eq!(result.lines().count(), 4);
    }
}