lean_ctx/core/neural/
context_reorder.rs1use super::attention_learned::LearnedAttention;
2
/// Coarse role of a single source line, used to decide where the line lands
/// when content is reordered.
///
/// Declaration order is significant: the derived `PartialOrd`/`Ord` follow it,
/// so `ErrorHandling` compares lowest and `Empty` highest.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum LineCategory {
    /// Lines that produce or handle errors (error returns, raise/throw, catch/except, panics).
    ErrorHandling,
    /// Import / include / require lines.
    Import,
    /// Struct, enum, trait, class, interface and similar type declarations.
    TypeDefinition,
    /// Function or method declaration lines.
    FunctionSignature,
    /// Any other non-empty line.
    Logic,
    /// Lines that only close a block, such as `}` or `});`.
    ClosingBrace,
    /// Blank or whitespace-only lines.
    Empty,
}
13
14pub struct CategorizedLine<'a> {
15 pub line: &'a str,
16 pub category: LineCategory,
17 pub original_index: usize,
18}
19
20pub fn categorize_line(line: &str) -> LineCategory {
21 let trimmed = line.trim();
22 if trimmed.is_empty() {
23 return LineCategory::Empty;
24 }
25
26 if is_error_handling(trimmed) {
27 return LineCategory::ErrorHandling;
28 }
29
30 if is_import(trimmed) {
31 return LineCategory::Import;
32 }
33
34 if is_type_def(trimmed) {
35 return LineCategory::TypeDefinition;
36 }
37
38 if is_fn_signature(trimmed) {
39 return LineCategory::FunctionSignature;
40 }
41
42 if is_closing(trimmed) {
43 return LineCategory::ClosingBrace;
44 }
45
46 LineCategory::Logic
47}
48
49pub fn reorder_for_lcurve(content: &str, task_keywords: &[String]) -> String {
50 let lines: Vec<&str> = content.lines().collect();
51 if lines.len() <= 5 {
52 return content.to_string();
53 }
54
55 if lines.len() >= 15 {
57 let chunks = crate::core::semantic_chunks::detect_chunks(content);
58 if chunks.len() >= 3 {
59 let ordered = crate::core::semantic_chunks::order_for_attention(chunks, task_keywords);
60 return crate::core::semantic_chunks::render_with_bridges(&ordered);
61 }
62 }
63
64 let categorized: Vec<CategorizedLine> = lines
66 .iter()
67 .enumerate()
68 .map(|(i, line)| CategorizedLine {
69 line,
70 category: categorize_line(line),
71 original_index: i,
72 })
73 .collect();
74
75 let attention = LearnedAttention::with_defaults();
76 let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
77
78 let mut scored: Vec<(&CategorizedLine, f64)> = categorized
79 .iter()
80 .map(|cl| {
81 let base = category_priority(cl.category);
82 let kw_boost: f64 = if !kw_lower.is_empty() {
83 let line_lower = cl.line.to_lowercase();
84 kw_lower
85 .iter()
86 .filter(|kw| line_lower.contains(kw.as_str()))
87 .count() as f64
88 * 0.5
89 } else {
90 0.0
91 };
92 let n = lines.len().max(1) as f64;
93 let orig_pos = cl.original_index as f64 / n;
94 let orig_attention = attention.weight(orig_pos);
95 let score = base + kw_boost + orig_attention * 0.1;
96 (cl, score)
97 })
98 .collect();
99
100 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
101
102 scored
103 .iter()
104 .filter(|(cl, _)| cl.category != LineCategory::Empty || cl.original_index == 0)
105 .map(|(cl, _)| cl.line)
106 .collect::<Vec<_>>()
107 .join("\n")
108}
109
110fn category_priority(cat: LineCategory) -> f64 {
111 match cat {
112 LineCategory::ErrorHandling => 5.0,
113 LineCategory::Import => 4.0,
114 LineCategory::TypeDefinition => 3.5,
115 LineCategory::FunctionSignature => 3.0,
116 LineCategory::Logic => 1.0,
117 LineCategory::ClosingBrace => 0.2,
118 LineCategory::Empty => 0.1,
119 }
120}
121
/// Returns `true` if the (already trimmed) `line` produces or handles an error.
///
/// Recognises:
/// - Rust error returns and macros (`return Err(`, `Err(`, `bail!(`, `panic!(`);
/// - Python/JS/Java/C++ raise, throw, catch and except clauses, including a
///   `catch` that follows a closing brace (`} catch (e) {`);
/// - bare re-raise (`raise`) and re-throw (`throw;`) statements;
/// - any line containing `.map_err(` or `Error::`.
fn is_error_handling(line: &str) -> bool {
    const PREFIXES: [&str; 11] = [
        "return Err(",
        "Err(",
        "bail!(",
        "raise ",
        "throw ",
        "catch ",
        "catch(",
        "except ",
        "except:",
        "except(",
        "panic!(",
    ];

    // In brace languages `catch` usually shares a line with the `}` of the
    // preceding `try` block, so look past one leading closing brace.
    let head = line.strip_prefix('}').map_or(line, str::trim_start);

    matches!(head, "raise" | "throw;")
        || PREFIXES.iter().any(|p| head.starts_with(p))
        || line.contains(".map_err(")
        || line.contains("Error::")
}
134
/// Returns `true` if the (already trimmed) `line` brings names into scope.
///
/// Covers:
/// - Rust `use` (including `pub`, `pub(crate)` and `pub(super)` re-exports)
///   and `extern crate`;
/// - Python/JS `import` and `from`;
/// - C/C++ `#include`;
/// - CommonJS `require(`, either bare or bound with `const`/`let`/`var`.
fn is_import(line: &str) -> bool {
    const PREFIXES: [&str; 9] = [
        "use ",
        "pub use ",
        "pub(crate) use ",
        "pub(super) use ",
        "extern crate ",
        "import ",
        "from ",
        "#include",
        "require(",
    ];
    // Binding keywords for CommonJS `<kw> x = require("...")`.
    const BINDINGS: [&str; 3] = ["const ", "let ", "var "];

    PREFIXES.iter().any(|p| line.starts_with(p))
        || (BINDINGS.iter().any(|b| line.starts_with(b)) && line.contains("require("))
}
143
/// Returns `true` if the (already trimmed) `line` starts a type declaration.
///
/// Accepts an optional visibility/export prefix (`pub`, `pub(crate)`,
/// `pub(super)`, `export`, `export default`) before one of the keywords
/// `struct`, `enum`, `trait`, `type`, `interface`, `class`, `typedef`, `data`.
/// A keyword followed directly by `=` is a plain assignment to a variable with
/// that name (e.g. Python `data = load()`) and is rejected.
fn is_type_def(line: &str) -> bool {
    // `export default ` must precede `export ` because the first match is used.
    const VISIBILITY: [&str; 5] = ["pub(crate) ", "pub(super) ", "pub ", "export default ", "export "];
    const KEYWORDS: [&str; 8] = [
        "struct ",
        "enum ",
        "trait ",
        "type ",
        "interface ",
        "class ",
        "typedef ",
        "data ",
    ];

    let rest = VISIBILITY
        .iter()
        .find_map(|prefix| line.strip_prefix(prefix))
        .unwrap_or(line);

    KEYWORDS.iter().any(|kw| match rest.strip_prefix(kw) {
        Some(after) => !after.trim_start().starts_with('='),
        None => false,
    })
}
163
/// Returns `true` if the (already trimmed) `line` starts a function declaration.
///
/// Accepts an optional visibility/export prefix (`pub`, `pub(crate)`,
/// `pub(super)`, `export`, `export default`). That prefix may be followed by
/// any sequence of the qualifiers `async`, `const` and `unsafe`, and then one
/// of the keywords `fn`, `function`, `def` or `func`. A keyword followed
/// directly by `=` is an assignment (e.g. `const fn = () => 1;`,
/// `func = lambda x: x`) and is rejected.
fn is_fn_signature(line: &str) -> bool {
    // `export default ` must precede `export ` because the first match is used.
    const VISIBILITY: [&str; 5] = ["pub(crate) ", "pub(super) ", "pub ", "export default ", "export "];
    const QUALIFIERS: [&str; 3] = ["async ", "const ", "unsafe "];
    const KEYWORDS: [&str; 4] = ["fn ", "function ", "def ", "func "];

    let mut rest = VISIBILITY
        .iter()
        .find_map(|prefix| line.strip_prefix(prefix))
        .unwrap_or(line);
    // Each iteration strictly shortens `rest`, so this terminates.
    while let Some(stripped) = QUALIFIERS.iter().find_map(|q| rest.strip_prefix(q)) {
        rest = stripped;
    }

    KEYWORDS.iter().any(|kw| match rest.strip_prefix(kw) {
        Some(after) => !after.trim_start().starts_with('='),
        None => false,
    })
}
181
/// Returns `true` if the (already trimmed) `line` only closes a block or group.
///
/// The line must start with `}`, `)` or `]` and contain nothing but those
/// closers plus `;` and `,`. This covers `}`, `};`, `});`, `})`, and also
/// `},`, `]`, `];` and `),`, which are common in match arms, struct literals
/// and array or JSON data.
fn is_closing(line: &str) -> bool {
    let starts_with_closer = matches!(line.chars().next(), Some('}') | Some(')') | Some(']'));
    starts_with_closer && line.chars().all(|c| matches!(c, '}' | ')' | ']' | ';' | ','))
}
185
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn categorize_lines_correctly() {
        assert_eq!(categorize_line("use std::io;"), LineCategory::Import);
        assert_eq!(
            categorize_line("pub struct Foo {"),
            LineCategory::TypeDefinition
        );
        assert_eq!(
            categorize_line("fn main() {"),
            LineCategory::FunctionSignature
        );
        assert_eq!(
            categorize_line("return Err(e);"),
            LineCategory::ErrorHandling
        );
        assert_eq!(categorize_line("}"), LineCategory::ClosingBrace);
        assert_eq!(categorize_line("let x = 1;"), LineCategory::Logic);
        assert_eq!(categorize_line(""), LineCategory::Empty);
    }

    #[test]
    fn reorder_puts_errors_and_imports_first() {
        let content = "let x = 1;\nuse std::io;\n}\nreturn Err(e);\npub struct Foo {\nfn main() {";
        let result = reorder_for_lcurve(content, &[]);
        let lines: Vec<&str> = result.lines().collect();
        assert!(
            lines[0].contains("Err") || lines[0].contains("use "),
            "first line should be error handling or import, got: {}",
            lines[0]
        );
    }

    #[test]
    fn task_keywords_boost_relevant_lines() {
        let content = "fn unrelated() {\nlet x = 1;\n}\nfn validate_token() {\nlet y = 2;\n}";
        let result = reorder_for_lcurve(content, &["validate".to_string()]);
        let lines: Vec<&str> = result.lines().collect();
        // Non-empty lines are never dropped, so both signatures must be present;
        // failing loudly here avoids the assertion silently not running.
        let validate_pos = lines
            .iter()
            .position(|l| l.contains("validate"))
            .expect("validate line missing from output");
        let unrelated_pos = lines
            .iter()
            .position(|l| l.contains("unrelated"))
            .expect("unrelated line missing from output");
        assert!(
            validate_pos < unrelated_pos,
            "validate should appear before unrelated"
        );
    }

    #[test]
    fn short_content_is_returned_unchanged() {
        let content = "b\na\n\nc\n";
        assert_eq!(reorder_for_lcurve(content, &[]), content);
    }

    #[test]
    fn blank_lines_are_dropped_from_reordered_output() {
        let content = "use a;\n\nlet x = 1;\n\nfn f() {\n}";
        let result = reorder_for_lcurve(content, &[]);
        assert_eq!(result.lines().count(), 4);
        assert!(result.lines().all(|l| !l.trim().is_empty()));
    }
}