Skip to main content

lean_ctx/core/neural/
context_reorder.rs

1use super::attention_learned::LearnedAttention;
2
/// Coarse syntactic role of one source line, as decided by `categorize_line`.
///
/// Declaration order matters: the derived `PartialOrd`/`Ord` compare variants
/// in the order they are listed here, so `ErrorHandling` is the smallest and
/// `Empty` the largest.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum LineCategory {
    /// A line that returns, raises, maps, or catches an error.
    ErrorHandling,
    /// An import, include, or module-loading directive.
    Import,
    /// The opening line of a struct, enum, trait, class, interface, or alias.
    TypeDefinition,
    /// The opening line of a function or method definition.
    FunctionSignature,
    /// Any non-empty line that fits no other category.
    Logic,
    /// A line made of closing delimiters only, such as `}` or `);`.
    ClosingBrace,
    /// A blank or whitespace-only line.
    Empty,
}
13
14pub struct CategorizedLine<'a> {
15    pub line: &'a str,
16    pub category: LineCategory,
17    pub original_index: usize,
18}
19
20pub fn categorize_line(line: &str) -> LineCategory {
21    let trimmed = line.trim();
22    if trimmed.is_empty() {
23        return LineCategory::Empty;
24    }
25
26    if is_error_handling(trimmed) {
27        return LineCategory::ErrorHandling;
28    }
29
30    if is_import(trimmed) {
31        return LineCategory::Import;
32    }
33
34    if is_type_def(trimmed) {
35        return LineCategory::TypeDefinition;
36    }
37
38    if is_fn_signature(trimmed) {
39        return LineCategory::FunctionSignature;
40    }
41
42    if is_closing(trimmed) {
43        return LineCategory::ClosingBrace;
44    }
45
46    LineCategory::Logic
47}
48
49pub fn reorder_for_lcurve(content: &str, task_keywords: &[String]) -> String {
50    let lines: Vec<&str> = content.lines().collect();
51    if lines.len() <= 5 {
52        return content.to_string();
53    }
54
55    // Try semantic chunk-based reordering for larger content
56    if lines.len() >= 15 {
57        let chunks = crate::core::semantic_chunks::detect_chunks(content);
58        if chunks.len() >= 3 {
59            let ordered = crate::core::semantic_chunks::order_for_attention(chunks, task_keywords);
60            return crate::core::semantic_chunks::render_with_bridges(&ordered);
61        }
62    }
63
64    // Fall back to line-level reordering for small content
65    let categorized: Vec<CategorizedLine> = lines
66        .iter()
67        .enumerate()
68        .map(|(i, line)| CategorizedLine {
69            line,
70            category: categorize_line(line),
71            original_index: i,
72        })
73        .collect();
74
75    let attention = LearnedAttention::with_defaults();
76    let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
77
78    let mut scored: Vec<(&CategorizedLine, f64)> = categorized
79        .iter()
80        .map(|cl| {
81            let base = category_priority(cl.category);
82            let kw_boost: f64 = if !kw_lower.is_empty() {
83                let line_lower = cl.line.to_lowercase();
84                kw_lower
85                    .iter()
86                    .filter(|kw| line_lower.contains(kw.as_str()))
87                    .count() as f64
88                    * 0.5
89            } else {
90                0.0
91            };
92            let n = lines.len().max(1) as f64;
93            let orig_pos = cl.original_index as f64 / n;
94            let orig_attention = attention.weight(orig_pos);
95            let score = base + kw_boost + orig_attention * 0.1;
96            (cl, score)
97        })
98        .collect();
99
100    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
101
102    scored
103        .iter()
104        .filter(|(cl, _)| cl.category != LineCategory::Empty || cl.original_index == 0)
105        .map(|(cl, _)| cl.line)
106        .collect::<Vec<_>>()
107        .join("\n")
108}
109
110fn category_priority(cat: LineCategory) -> f64 {
111    match cat {
112        LineCategory::ErrorHandling => 5.0,
113        LineCategory::Import => 4.0,
114        LineCategory::TypeDefinition => 3.5,
115        LineCategory::FunctionSignature => 3.0,
116        LineCategory::Logic => 1.0,
117        LineCategory::ClosingBrace => 0.2,
118        LineCategory::Empty => 0.1,
119    }
120}
121
/// Returns `true` if `line` (already trimmed) looks like it returns, raises,
/// maps, or catches an error.
///
/// Beyond the Rust forms (`Err(`, `bail!(`, `panic!(`, `.map_err(`,
/// `Error::`), this recognises several other shapes:
/// * Python bare `except:` / `except(...)` and a bare re-raising `raise`.
/// * `catch(` with no space.
/// * `} catch ...`, the usual JS/Java/C++ layout, where the keyword follows a
///   closing brace.
fn is_error_handling(line: &str) -> bool {
    const PREFIXES: &[&str] = &[
        "return Err(",
        "Err(",
        "bail!(",
        "panic!(",
        "raise ",
        "throw ",
        "catch ",
        "catch(",
        "} catch ",
        "} catch(",
        "except ",
        "except:",
        "except(",
    ];
    const FRAGMENTS: &[&str] = &[".map_err(", "Error::"];

    line == "raise"
        || PREFIXES.iter().any(|p| line.starts_with(p))
        || FRAGMENTS.iter().any(|f| line.contains(f))
}
134
/// Returns `true` if `line` (already trimmed) looks like an import, include,
/// or module re-export.
///
/// Rust re-exports (`pub use`, `pub(crate) use`, `pub(super) use`) count as
/// imports. CommonJS bindings are recognised when the line starts with
/// `const`, `let`, or `var` and contains `require(`, e.g.
/// `const fs = require("fs");`.
fn is_import(line: &str) -> bool {
    const PREFIXES: &[&str] = &[
        "use ",
        "pub use ",
        "pub(crate) use ",
        "pub(super) use ",
        "import ",
        "from ",
        "#include",
        "require(",
    ];
    const BINDINGS: &[&str] = &["const ", "let ", "var "];

    PREFIXES.iter().any(|p| line.starts_with(p))
        || (BINDINGS.iter().any(|b| line.starts_with(b)) && line.contains("require("))
}
143
/// Returns `true` if `line` (already trimmed) opens a type declaration:
/// struct, enum, trait, type alias, interface, class, `typedef`, or a
/// Haskell-style `data` declaration.
///
/// Restricted Rust visibilities (`pub(crate)`, `pub(super)`) are accepted, as
/// they already are for function signatures. A `data` line only counts when
/// the next word starts with an uppercase letter (`data Maybe a = ...`). This
/// keeps ordinary assignments such as `data = json.loads(raw)` out of this
/// category.
fn is_type_def(line: &str) -> bool {
    const STARTERS: &[&str] = &[
        "struct ",
        "pub struct ",
        "pub(crate) struct ",
        "pub(super) struct ",
        "enum ",
        "pub enum ",
        "pub(crate) enum ",
        "pub(super) enum ",
        "trait ",
        "pub trait ",
        "pub(crate) trait ",
        "pub(super) trait ",
        "type ",
        "pub type ",
        "pub(crate) type ",
        "pub(super) type ",
        "interface ",
        "export interface ",
        "export type ",
        "export enum ",
        "class ",
        "export class ",
        "typedef ",
    ];
    if STARTERS.iter().any(|s| line.starts_with(s)) {
        return true;
    }

    matches!(
        line.strip_prefix("data ")
            .and_then(|rest| rest.trim_start().chars().next()),
        Some(c) if c.is_uppercase()
    )
}
163
/// Returns `true` if `line` (already trimmed) opens a function or method
/// definition.
///
/// The recognised prefixes cover Rust, JavaScript/TypeScript, Python and
/// Go-style declarations. Visibility and qualifier combinations are listed
/// consistently: for example, `pub(crate) async fn` and
/// `export async function` are accepted alongside their plain forms.
fn is_fn_signature(line: &str) -> bool {
    const STARTERS: &[&str] = &[
        // Rust
        "fn ",
        "pub fn ",
        "pub(crate) fn ",
        "pub(super) fn ",
        "async fn ",
        "pub async fn ",
        "pub(crate) async fn ",
        "pub(super) async fn ",
        "const fn ",
        "pub const fn ",
        "unsafe fn ",
        "pub unsafe fn ",
        // JavaScript / TypeScript
        "function ",
        "async function ",
        "export function ",
        "export async function ",
        "export default function ",
        // Python
        "def ",
        "async def ",
        // Go
        "func ",
    ];
    STARTERS.iter().any(|s| line.starts_with(s))
}
181
/// Returns `true` if `line` (already trimmed) consists solely of closing
/// delimiters, optionally mixed with `,` / `;`.
///
/// The line must start with `}`, `)` or `]`, so a lone `;` or `,` does not
/// count. This covers the common trailing forms `}`, `};`, `)`, `);`, `})`,
/// `});`, as well as `},`, `),`, `]`, `],`, `];` and deeper nests like `}));`.
fn is_closing(line: &str) -> bool {
    line.starts_with(|c: char| matches!(c, '}' | ')' | ']'))
        && line
            .chars()
            .all(|c| matches!(c, '}' | ')' | ']' | ';' | ','))
}
185
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn categorize_lines_correctly() {
        assert_eq!(categorize_line("use std::io;"), LineCategory::Import);
        assert_eq!(
            categorize_line("pub struct Foo {"),
            LineCategory::TypeDefinition
        );
        assert_eq!(
            categorize_line("fn main() {"),
            LineCategory::FunctionSignature
        );
        assert_eq!(
            categorize_line("return Err(e);"),
            LineCategory::ErrorHandling
        );
        assert_eq!(categorize_line("}"), LineCategory::ClosingBrace);
        assert_eq!(categorize_line("let x = 1;"), LineCategory::Logic);
        assert_eq!(categorize_line(""), LineCategory::Empty);
    }

    #[test]
    fn reorder_puts_errors_and_imports_first() {
        let content = "let x = 1;\nuse std::io;\n}\nreturn Err(e);\npub struct Foo {\nfn main() {";
        let result = reorder_for_lcurve(content, &[]);
        let lines: Vec<&str> = result.lines().collect();
        assert!(
            lines[0].contains("Err") || lines[0].contains("use "),
            "first line should be error handling or import, got: {}",
            lines[0]
        );
    }

    #[test]
    fn task_keywords_boost_relevant_lines() {
        let content = "fn unrelated() {\nlet x = 1;\n}\nfn validate_token() {\nlet y = 2;\n}";
        let result = reorder_for_lcurve(content, &["validate".to_string()]);
        let lines: Vec<&str> = result.lines().collect();
        // Both lines are non-empty, so they must survive reordering; a missing
        // line is a failure, not a reason to skip the check.
        let validate_pos = lines
            .iter()
            .position(|l| l.contains("validate"))
            .expect("validate line should be present in the output");
        let unrelated_pos = lines
            .iter()
            .position(|l| l.contains("unrelated"))
            .expect("unrelated line should be present in the output");
        assert!(
            validate_pos < unrelated_pos,
            "validate should appear before unrelated, got: {:?}",
            lines
        );
    }
}