Skip to main content

lean_ctx/core/neural/
context_reorder.rs

1use super::attention_learned::LearnedAttention;
2
/// Coarse classification of a single source line, used to rank lines when
/// reordering context.
///
/// Declaration order is significant: the derived `PartialOrd`/`Ord` compare
/// variants by position, so `ErrorHandling` sorts first and `Empty` last.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum LineCategory {
    /// Error construction, propagation or handling (`Err(..)`, `raise`, `throw`, ...).
    ErrorHandling,
    /// Module imports and includes (`use`, `import`, `#include`, `require(..)`, ...).
    Import,
    /// Type declarations (`struct`, `enum`, `trait`, `class`, `interface`, ...).
    TypeDefinition,
    /// Function definition headers (`fn`, `def`, `function`, `func`, ...).
    FunctionSignature,
    /// Any non-blank line not matched by a more specific category.
    Logic,
    /// A line that only closes delimiters, such as `}` or `});`.
    ClosingBrace,
    /// A blank or whitespace-only line.
    Empty,
}
13
14pub struct CategorizedLine<'a> {
15    pub line: &'a str,
16    pub category: LineCategory,
17    pub original_index: usize,
18}
19
20pub fn categorize_line(line: &str) -> LineCategory {
21    let trimmed = line.trim();
22    if trimmed.is_empty() {
23        return LineCategory::Empty;
24    }
25
26    if is_error_handling(trimmed) {
27        return LineCategory::ErrorHandling;
28    }
29
30    if is_import(trimmed) {
31        return LineCategory::Import;
32    }
33
34    if is_type_def(trimmed) {
35        return LineCategory::TypeDefinition;
36    }
37
38    if is_fn_signature(trimmed) {
39        return LineCategory::FunctionSignature;
40    }
41
42    if is_closing(trimmed) {
43        return LineCategory::ClosingBrace;
44    }
45
46    LineCategory::Logic
47}
48
49pub fn reorder_for_lcurve(content: &str, task_keywords: &[String]) -> String {
50    let lines: Vec<&str> = content.lines().collect();
51    if lines.len() <= 5 {
52        return content.to_string();
53    }
54
55    let categorized: Vec<CategorizedLine> = lines
56        .iter()
57        .enumerate()
58        .map(|(i, line)| CategorizedLine {
59            line,
60            category: categorize_line(line),
61            original_index: i,
62        })
63        .collect();
64
65    let attention = LearnedAttention::with_defaults();
66    let kw_lower: Vec<String> = task_keywords.iter().map(|k| k.to_lowercase()).collect();
67
68    let mut scored: Vec<(&CategorizedLine, f64)> = categorized
69        .iter()
70        .map(|cl| {
71            let base = category_priority(cl.category);
72            let kw_boost: f64 = if !kw_lower.is_empty() {
73                let line_lower = cl.line.to_lowercase();
74                kw_lower
75                    .iter()
76                    .filter(|kw| line_lower.contains(kw.as_str()))
77                    .count() as f64
78                    * 0.5
79            } else {
80                0.0
81            };
82            let n = lines.len().max(1) as f64;
83            let orig_pos = cl.original_index as f64 / n;
84            let orig_attention = attention.weight(orig_pos);
85            let score = base + kw_boost + orig_attention * 0.1;
86            (cl, score)
87        })
88        .collect();
89
90    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
91
92    scored
93        .iter()
94        .filter(|(cl, _)| cl.category != LineCategory::Empty || cl.original_index == 0)
95        .map(|(cl, _)| cl.line)
96        .collect::<Vec<_>>()
97        .join("\n")
98}
99
100fn category_priority(cat: LineCategory) -> f64 {
101    match cat {
102        LineCategory::ErrorHandling => 5.0,
103        LineCategory::Import => 4.0,
104        LineCategory::TypeDefinition => 3.5,
105        LineCategory::FunctionSignature => 3.0,
106        LineCategory::Logic => 1.0,
107        LineCategory::ClosingBrace => 0.2,
108        LineCategory::Empty => 0.1,
109    }
110}
111
/// Returns `true` if the (already trimmed) `line` looks like error-handling code.
///
/// Recognised forms:
/// - Rust: lines starting with `return Err(`, `Err(`, `bail!(` or `panic!(`,
///   and lines containing `.map_err(` or `Error::`;
/// - Python: `raise ...` and `except` clauses, including a bare `except:`;
/// - `throw ...`, and `catch` clauses, including the common `} catch (e) {`
///   form where the clause follows the closing brace of a `try` block.
fn is_error_handling(line: &str) -> bool {
    const PREFIXES: [&str; 6] = [
        "return Err(",
        "Err(",
        "bail!(",
        "raise ",
        "throw ",
        "panic!(",
    ];
    const FRAGMENTS: [&str; 2] = [".map_err(", "Error::"];

    if PREFIXES.iter().any(|&p| line.starts_with(p))
        || FRAGMENTS.iter().any(|&f| line.contains(f))
    {
        return true;
    }

    // `catch`/`except` only count as whole words, so identifiers such as
    // `catcher` or `exception_count` are not mistaken for handler clauses.
    let after_brace = line.strip_prefix('}').map_or(line, str::trim_start);
    starts_with_keyword(after_brace, "catch") || starts_with_keyword(line, "except")
}

/// Returns `true` if `line` starts with `keyword` and the keyword is not
/// immediately followed by an identifier character (letter, digit or `_`).
fn starts_with_keyword(line: &str, keyword: &str) -> bool {
    match line.strip_prefix(keyword) {
        Some(rest) => !rest.starts_with(|c: char| c.is_alphanumeric() || c == '_'),
        None => false,
    }
}
124
/// Returns `true` if the (already trimmed) `line` looks like an import.
///
/// Recognised forms:
/// - Rust `use`, including `pub use` / `pub(crate) use` re-exports, and `extern crate`;
/// - `import ...` (Python, JS/TS, Go) and Python `from ... import ...`;
/// - C/C++ `#include`;
/// - CommonJS `require(...)`, either bare or bound with `const`, `let` or `var`.
fn is_import(line: &str) -> bool {
    const PREFIXES: [&str; 8] = [
        "use ",
        "pub use ",
        "pub(crate) use ",
        "extern crate ",
        "import ",
        "from ",
        "#include",
        "require(",
    ];
    // Declarations that only count as imports when they pull in a CommonJS module.
    const REQUIRE_BINDINGS: [&str; 3] = ["const ", "let ", "var "];

    PREFIXES.iter().any(|&p| line.starts_with(p))
        || (REQUIRE_BINDINGS.iter().any(|&b| line.starts_with(b)) && line.contains("require("))
}
133
/// Returns `true` if the (already trimmed) `line` opens a type declaration.
///
/// The declaration keyword (`struct`, `enum`, `trait`, `type`, `interface`,
/// `class`, `typedef`, `data`) may be preceded by one Rust visibility modifier
/// (`pub`, `pub(crate)`, `pub(super)`) or by a JS/TS `export` / `export default`.
fn is_type_def(line: &str) -> bool {
    const KEYWORDS: [&str; 8] = [
        "struct ",
        "enum ",
        "trait ",
        "type ",
        "interface ",
        "class ",
        "typedef ",
        "data ",
    ];
    // `export default ` must be tried before `export `.
    const QUALIFIERS: [&str; 5] = [
        "pub(crate) ",
        "pub(super) ",
        "pub ",
        "export default ",
        "export ",
    ];

    let decl = QUALIFIERS
        .iter()
        .find_map(|&q| line.strip_prefix(q))
        .unwrap_or(line);
    KEYWORDS.iter().any(|&k| decl.starts_with(k))
}
153
/// Returns `true` if the (already trimmed) `line` opens a function definition.
///
/// Recognised keywords:
/// - Rust `fn`, optionally qualified as `async`, `const`, `unsafe`,
///   `const unsafe` or `async unsafe`;
/// - JS/TS `function` / `async function`;
/// - Python `def` / `async def`;
/// - Go `func`.
///
/// The keyword may be preceded by one Rust visibility modifier (`pub`,
/// `pub(crate)`, `pub(super)`) or by a JS/TS `export` / `export default`.
/// A keyword immediately followed by `=` (e.g. `const fn = ...`) is an
/// assignment to a variable with that name, not a definition, and is rejected.
fn is_fn_signature(line: &str) -> bool {
    const STARTERS: [&str; 11] = [
        "fn ",
        "async fn ",
        "const fn ",
        "unsafe fn ",
        "const unsafe fn ",
        "async unsafe fn ",
        "function ",
        "async function ",
        "def ",
        "async def ",
        "func ",
    ];
    // `export default ` must be tried before `export `.
    const QUALIFIERS: [&str; 5] = [
        "pub(crate) ",
        "pub(super) ",
        "pub ",
        "export default ",
        "export ",
    ];

    let decl = QUALIFIERS
        .iter()
        .find_map(|&q| line.strip_prefix(q))
        .unwrap_or(line);
    STARTERS.iter().any(|&s| match decl.strip_prefix(s) {
        Some(rest) => !rest.trim_start().starts_with('='),
        None => false,
    })
}
171
/// Returns `true` if the (already trimmed) `line` only closes open delimiters.
///
/// Such a line starts with `}`, `)` or `]` and contains nothing except closing
/// delimiters and the punctuation that commonly trails them (`,`, `;`, `?`).
/// Examples: `}`, `};`, `});`, `},`, `]);`, `)?;`, `}))`.
fn is_closing(line: &str) -> bool {
    const CLOSERS: [char; 3] = ['}', ')', ']'];
    const TRAILERS: [char; 3] = [',', ';', '?'];

    line.starts_with(|c: char| CLOSERS.contains(&c))
        && line.chars().all(|c| CLOSERS.contains(&c) || TRAILERS.contains(&c))
}
175
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn categorize_lines_correctly() {
        assert_eq!(categorize_line("use std::io;"), LineCategory::Import);
        assert_eq!(
            categorize_line("pub struct Foo {"),
            LineCategory::TypeDefinition
        );
        assert_eq!(
            categorize_line("fn main() {"),
            LineCategory::FunctionSignature
        );
        assert_eq!(
            categorize_line("return Err(e);"),
            LineCategory::ErrorHandling
        );
        assert_eq!(categorize_line("}"), LineCategory::ClosingBrace);
        assert_eq!(categorize_line("let x = 1;"), LineCategory::Logic);
        assert_eq!(categorize_line(""), LineCategory::Empty);
    }

    #[test]
    fn reorder_puts_errors_and_imports_first() {
        let content = "let x = 1;\nuse std::io;\n}\nreturn Err(e);\npub struct Foo {\nfn main() {";
        let result = reorder_for_lcurve(content, &[]);
        let lines: Vec<&str> = result.lines().collect();
        assert!(
            lines[0].contains("Err") || lines[0].contains("use "),
            "first line should be error handling or import, got: {}",
            lines[0]
        );
    }

    #[test]
    fn task_keywords_boost_relevant_lines() {
        let content = "fn unrelated() {\nlet x = 1;\n}\nfn validate_token() {\nlet y = 2;\n}";
        let result = reorder_for_lcurve(content, &["validate".to_string()]);
        let lines: Vec<&str> = result.lines().collect();
        // Both lines are non-blank and must survive; a missing line is a failure,
        // not a reason to skip the ordering check.
        let validate_pos = lines
            .iter()
            .position(|l| l.contains("validate"))
            .expect("validate line missing from output");
        let unrelated_pos = lines
            .iter()
            .position(|l| l.contains("unrelated"))
            .expect("unrelated line missing from output");
        assert!(
            validate_pos < unrelated_pos,
            "validate should appear before unrelated"
        );
    }

    #[test]
    fn short_input_is_returned_unchanged() {
        let content = "b\n\na\n";
        assert_eq!(reorder_for_lcurve(content, &[]), content);
    }

    #[test]
    fn reorder_keeps_every_non_blank_line_and_drops_blank_ones() {
        let content = "let a = 1;\n\nuse std::io;\n\nfn f() {\n}\nlet b = 2;";
        let result = reorder_for_lcurve(content, &[]);
        let mut got: Vec<&str> = result.lines().collect();
        assert!(got.iter().all(|l| !l.trim().is_empty()));

        let mut expected: Vec<&str> = content.lines().filter(|l| !l.trim().is_empty()).collect();
        got.sort_unstable();
        expected.sort_unstable();
        assert_eq!(got, expected);
    }
}