Skip to main content

lean_ctx/core/neural/
context_reorder.rs

1use super::attention_learned::LearnedAttention;
2
/// Coarse syntactic role assigned to a single source line by [`categorize_line`].
///
/// Variants are declared in descending order of their base priority in
/// `category_priority`, so the derived `Ord` ranks `ErrorHandling` lowest and
/// `Empty` highest. Do not reorder variants without reviewing code that compares them.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum LineCategory {
    /// Error construction, propagation, or handling (e.g. `return Err(..)`, `raise ..`).
    ErrorHandling,
    /// Module / package import lines (e.g. `use ..`, `import ..`, `#include`).
    Import,
    /// Struct, enum, trait, class, interface, or type-alias declarations.
    TypeDefinition,
    /// Function or method declaration lines.
    FunctionSignature,
    /// Any non-empty line not matched by a more specific category.
    Logic,
    /// A line that only closes a delimiter (e.g. `}`, `});`).
    ClosingBrace,
    /// A blank or whitespace-only line.
    Empty,
}
13
14pub struct CategorizedLine<'a> {
15    pub line: &'a str,
16    pub category: LineCategory,
17    pub original_index: usize,
18}
19
20pub fn categorize_line(line: &str) -> LineCategory {
21    let trimmed = line.trim();
22    if trimmed.is_empty() {
23        return LineCategory::Empty;
24    }
25
26    if is_error_handling(trimmed) {
27        return LineCategory::ErrorHandling;
28    }
29
30    if is_import(trimmed) {
31        return LineCategory::Import;
32    }
33
34    if is_type_def(trimmed) {
35        return LineCategory::TypeDefinition;
36    }
37
38    if is_fn_signature(trimmed) {
39        return LineCategory::FunctionSignature;
40    }
41
42    if is_closing(trimmed) {
43        return LineCategory::ClosingBrace;
44    }
45
46    LineCategory::Logic
47}
48
49pub fn reorder_for_lcurve(content: &str, task_keywords: &[String]) -> String {
50    let lines: Vec<&str> = content.lines().collect();
51    if lines.len() <= 5 {
52        return content.to_string();
53    }
54
55    // Try semantic chunk-based reordering for larger content
56    if lines.len() >= 15 {
57        let chunks = crate::core::semantic_chunks::detect_chunks(content);
58        if chunks.len() >= 3 {
59            let ordered = crate::core::semantic_chunks::order_for_attention(chunks, task_keywords);
60            return crate::core::semantic_chunks::render_with_bridges(&ordered);
61        }
62    }
63
64    // Fall back to line-level reordering for small content
65    let categorized: Vec<CategorizedLine> = lines
66        .iter()
67        .enumerate()
68        .map(|(i, line)| CategorizedLine {
69            line,
70            category: categorize_line(line),
71            original_index: i,
72        })
73        .collect();
74
75    let attention = LearnedAttention::with_defaults();
76    let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
77
78    let mut scored: Vec<(&CategorizedLine, f64)> = categorized
79        .iter()
80        .map(|cl| {
81            let base = category_priority(cl.category);
82            let kw_boost: f64 = if kw_lower.is_empty() {
83                0.0
84            } else {
85                let line_lower = cl.line.to_lowercase();
86                kw_lower
87                    .iter()
88                    .filter(|kw| line_lower.contains(kw.as_str()))
89                    .count() as f64
90                    * 0.5
91            };
92            let n = lines.len().max(1) as f64;
93            let orig_pos = cl.original_index as f64 / n;
94            let orig_attention = attention.weight(orig_pos);
95            let score = base + kw_boost + orig_attention * 0.1;
96            (cl, score)
97        })
98        .collect();
99
100    scored.sort_by(|a, b| {
101        b.1.partial_cmp(&a.1)
102            .unwrap_or(std::cmp::Ordering::Equal)
103            .then_with(|| a.0.original_index.cmp(&b.0.original_index))
104    });
105
106    scored
107        .iter()
108        .filter(|(cl, _)| cl.category != LineCategory::Empty || cl.original_index == 0)
109        .map(|(cl, _)| cl.line)
110        .collect::<Vec<_>>()
111        .join("\n")
112}
113
114fn category_priority(cat: LineCategory) -> f64 {
115    match cat {
116        LineCategory::ErrorHandling => 5.0,
117        LineCategory::Import => 4.0,
118        LineCategory::TypeDefinition => 3.5,
119        LineCategory::FunctionSignature => 3.0,
120        LineCategory::Logic => 1.0,
121        LineCategory::ClosingBrace => 0.2,
122        LineCategory::Empty => 0.1,
123    }
124}
125
/// Returns `true` if the already-trimmed `line` looks like error construction,
/// propagation, or handling.
///
/// Covers:
/// * Rust: `Err(..)`, `bail!`, `panic!`, `.map_err(`, and any `Error::` path.
/// * Python: `raise`, `except`.
/// * JS / Java / C++: `throw`, `catch`.
fn is_error_handling(line: &str) -> bool {
    // Lines that start an error construct.
    const PREFIXES: &[&str] = &[
        "return Err(",
        "Err(",
        "bail!(",
        "panic!(",
        "raise ",
        "throw ",
        "catch ",
        "catch(",
        // Common `} catch (e) {` brace style.
        "} catch ",
        "} catch(",
        "except ",
        // Python bare `except:` and `except(A, B):` forms.
        "except:",
        "except(",
    ];
    // Error-related fragments that may appear anywhere in the line.
    const FRAGMENTS: &[&str] = &[".map_err(", "Error::"];

    // Bare re-raise statements carry no trailing space: Python `raise`, C++ `throw;`.
    line == "raise"
        || line == "throw;"
        || PREFIXES.iter().any(|p| line.starts_with(p))
        || FRAGMENTS.iter().any(|f| line.contains(f))
}
138
/// Returns `true` if the already-trimmed `line` is an import / include / re-export.
///
/// Covers:
/// * Rust `use`, `pub use`, `pub(crate) use`, `extern crate`.
/// * Python `import` / `from`.
/// * C-family `#include` / `#import`.
/// * CommonJS `require(` calls, either bare or bound with `const` / `let` / `var`.
fn is_import(line: &str) -> bool {
    const PREFIXES: &[&str] = &[
        "use ",
        "pub use ",
        "pub(crate) use ",
        "extern crate ",
        "import ",
        "from ",
        "#include",
        "#import",
        "require(",
    ];
    // Declaration keywords that may bind a CommonJS `require(...)` result.
    const BINDINGS: &[&str] = &["const ", "let ", "var "];

    PREFIXES.iter().any(|p| line.starts_with(p))
        || (BINDINGS.iter().any(|b| line.starts_with(b)) && line.contains("require("))
}
147
/// Returns `true` if the already-trimmed `line` opens a type declaration.
///
/// Covers:
/// * Rust `struct` / `enum` / `trait` / `type`, with `pub`, `pub(crate)` or
///   `pub(super)` visibility.
/// * TypeScript / JavaScript `interface`, `class` and exported types.
/// * C `typedef` and Haskell `data`.
fn is_type_def(line: &str) -> bool {
    const STARTERS: &[&str] = &[
        "struct ",
        "pub struct ",
        "pub(crate) struct ",
        "pub(super) struct ",
        "enum ",
        "pub enum ",
        "pub(crate) enum ",
        "pub(super) enum ",
        "trait ",
        "pub trait ",
        "pub(crate) trait ",
        "pub(super) trait ",
        "type ",
        "pub type ",
        "pub(crate) type ",
        "pub(super) type ",
        "interface ",
        "export interface ",
        "class ",
        "export class ",
        "export default class ",
        "abstract class ",
        "export abstract class ",
        "export type ",
        "export enum ",
        "typedef ",
        "data ",
    ];
    STARTERS.iter().any(|s| line.starts_with(s))
}
167
/// Returns `true` if the already-trimmed `line` opens a function declaration.
///
/// Covers:
/// * Rust `fn`, with `pub` / `pub(crate)` / `pub(super)` visibility and
///   `async` / `const` / `unsafe` qualifiers.
/// * JavaScript `function` (including exported and async forms).
/// * Python `def` / `async def`.
/// * Go `func`.
fn is_fn_signature(line: &str) -> bool {
    const STARTERS: &[&str] = &[
        "fn ",
        "pub fn ",
        "pub(crate) fn ",
        "pub(super) fn ",
        "async fn ",
        "pub async fn ",
        "pub(crate) async fn ",
        "pub(super) async fn ",
        "const fn ",
        "pub const fn ",
        "pub(crate) const fn ",
        "unsafe fn ",
        "pub unsafe fn ",
        "pub(crate) unsafe fn ",
        "function ",
        "async function ",
        "export function ",
        "export async function ",
        "export default function ",
        "export default async function ",
        "def ",
        "async def ",
        "func ",
    ];
    STARTERS.iter().any(|s| line.starts_with(s))
}
185
/// Returns `true` if the already-trimmed `line` does nothing but close delimiters.
///
/// The line must start with `}`, `)` or `]`, and every character must be one of
/// those closers or a trailing separator (`;` / `,`). This accepts `}`, `});` and
/// `)`, and also common forms such as `},`, `]`, `]);` and `))`.
fn is_closing(line: &str) -> bool {
    matches!(line.chars().next(), Some('}' | ')' | ']'))
        && line.chars().all(|c| matches!(c, '}' | ')' | ']' | ';' | ','))
}
189
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn categorize_lines_correctly() {
        assert_eq!(categorize_line("use std::io;"), LineCategory::Import);
        assert_eq!(
            categorize_line("pub struct Foo {"),
            LineCategory::TypeDefinition
        );
        assert_eq!(
            categorize_line("fn main() {"),
            LineCategory::FunctionSignature
        );
        assert_eq!(
            categorize_line("return Err(e);"),
            LineCategory::ErrorHandling
        );
        assert_eq!(categorize_line("}"), LineCategory::ClosingBrace);
        assert_eq!(categorize_line("let x = 1;"), LineCategory::Logic);
        assert_eq!(categorize_line(""), LineCategory::Empty);
    }

    #[test]
    fn short_input_is_returned_unchanged() {
        // Five lines or fewer: passthrough, including blank lines and trailing newline.
        let content = "b\na\n\nc\n";
        assert_eq!(reorder_for_lcurve(content, &[]), content);
    }

    #[test]
    fn line_level_reorder_drops_interior_blank_lines() {
        // 7 lines (< 15), so the line-level path is used.
        let content = "let a = 1;\n\nlet b = 2;\n\nlet c = 3;\n\nlet d = 4;";
        let result = reorder_for_lcurve(content, &[]);
        assert!(result.lines().all(|l| !l.trim().is_empty()));
        assert_eq!(result.lines().count(), 4);
    }

    #[test]
    fn reorder_puts_errors_and_imports_first() {
        let content = "let x = 1;\nuse std::io;\n}\nreturn Err(e);\npub struct Foo {\nfn main() {";
        let result = reorder_for_lcurve(content, &[]);
        let lines: Vec<&str> = result.lines().collect();
        assert!(
            lines[0].contains("Err") || lines[0].contains("use "),
            "first line should be error handling or import, got: {}",
            lines[0]
        );
    }

    #[test]
    fn task_keywords_boost_relevant_lines() {
        let content = "fn unrelated() {\nlet x = 1;\n}\nfn validate_token() {\nlet y = 2;\n}";
        let result = reorder_for_lcurve(content, &["validate".to_string()]);
        let lines: Vec<&str> = result.lines().collect();
        // Both signature lines are non-empty, so both must survive the reorder.
        let validate_pos = lines
            .iter()
            .position(|l| l.contains("validate"))
            .expect("validate line must be kept");
        let unrelated_pos = lines
            .iter()
            .position(|l| l.contains("unrelated"))
            .expect("unrelated line must be kept");
        assert!(
            validate_pos < unrelated_pos,
            "validate should appear before unrelated"
        );
    }
}
237}