agcodex_ast/
compactor.rs

1//! AI Distiller-style code compaction for 70-95% compression
2
3use crate::error::AstError;
4use crate::error::AstResult;
5use crate::types::ParsedAst;
6// use crate::language_registry::Language; // unused
7use std::collections::HashSet;
8use tree_sitter::Node;
9use tree_sitter::TreeCursor;
10
/// How aggressively an AST is compacted.
///
/// `Medium` and `Hard` are aliases that behave exactly like `Standard` and
/// `Maximum` respectively.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionLevel {
    /// Targets a 70% size reduction and keeps the most detail.
    Light,
    /// Targets an 85% size reduction; the balanced default.
    Standard,
    /// Same behavior as [`CompressionLevel::Standard`].
    Medium,
    /// Targets a 95% size reduction; signatures only.
    Maximum,
    /// Same behavior as [`CompressionLevel::Maximum`].
    Hard,
}

impl CompressionLevel {
    /// Fraction of the original size this level aims to remove (0.0..=1.0).
    pub const fn target_ratio(&self) -> f64 {
        match *self {
            Self::Maximum | Self::Hard => 0.95,
            Self::Standard | Self::Medium => 0.85,
            Self::Light => 0.70,
        }
    }

    /// Whether comments survive compaction (only at [`CompressionLevel::Light`]).
    pub const fn preserve_comments(&self) -> bool {
        matches!(*self, Self::Light)
    }

    /// Whether implementation details are meant to survive compaction
    /// (only at [`CompressionLevel::Light`]).
    pub const fn preserve_implementation(&self) -> bool {
        matches!(*self, Self::Light)
    }

    /// Whether private/protected members survive compaction
    /// (every level except `Maximum`/`Hard`).
    pub const fn preserve_private(&self) -> bool {
        matches!(*self, Self::Light | Self::Standard | Self::Medium)
    }
}
51
52/// AST compactor for code compression
53#[derive(Debug)]
54pub struct AstCompactor {
55    compression_level: CompressionLevel,
56    preserved_node_types: HashSet<String>,
57    removed_node_types: HashSet<String>,
58}
59
60impl AstCompactor {
61    /// Create a new compactor with specified compression level
62    pub fn new(compression_level: CompressionLevel) -> Self {
63        let mut compactor = Self {
64            compression_level,
65            preserved_node_types: HashSet::new(),
66            removed_node_types: HashSet::new(),
67        };
68
69        // Configure based on compression level
70        compactor.configure_for_level();
71        compactor
72    }
73
74    /// Configure node types based on compression level
75    fn configure_for_level(&mut self) {
76        // Always preserve these structural elements
77        self.preserved_node_types
78            .insert("function_declaration".to_string());
79        self.preserved_node_types
80            .insert("function_item".to_string());
81        self.preserved_node_types
82            .insert("class_declaration".to_string());
83        self.preserved_node_types.insert("struct_item".to_string());
84        self.preserved_node_types.insert("impl_item".to_string());
85        self.preserved_node_types.insert("enum_item".to_string());
86        self.preserved_node_types
87            .insert("interface_declaration".to_string());
88        self.preserved_node_types.insert("type_alias".to_string());
89        self.preserved_node_types
90            .insert("import_statement".to_string());
91        self.preserved_node_types
92            .insert("use_declaration".to_string());
93
94        match self.compression_level {
95            CompressionLevel::Light => {
96                // Preserve more implementation details
97                self.preserved_node_types
98                    .insert("method_declaration".to_string());
99                self.preserved_node_types
100                    .insert("function_definition".to_string());
101                self.preserved_node_types.insert("comment".to_string());
102                self.preserved_node_types.insert("doc_comment".to_string());
103            }
104            CompressionLevel::Standard | CompressionLevel::Medium => {
105                // Remove most implementation but keep public API
106                self.removed_node_types.insert("block".to_string());
107                self.removed_node_types
108                    .insert("compound_statement".to_string());
109                self.removed_node_types.insert("comment".to_string());
110                self.removed_node_types.insert("line_comment".to_string());
111            }
112            CompressionLevel::Maximum | CompressionLevel::Hard => {
113                // Extreme compaction - only signatures
114                self.removed_node_types.insert("block".to_string());
115                self.removed_node_types
116                    .insert("compound_statement".to_string());
117                self.removed_node_types.insert("comment".to_string());
118                self.removed_node_types.insert("doc_comment".to_string());
119                self.removed_node_types.insert("private".to_string());
120                self.removed_node_types.insert("protected".to_string());
121                self.removed_node_types.insert("internal".to_string());
122            }
123        }
124    }
125
126    /// Compact an AST to compressed representation
127    pub fn compact(&self, ast: &ParsedAst) -> AstResult<String> {
128        let mut output = String::new();
129        let source = ast.source.as_bytes();
130
131        // Add file header with language
132        output.push_str(&format!("// Language: {}\n", ast.language.name()));
133        output.push_str("// Compacted representation\n\n");
134
135        // Walk the tree and extract relevant nodes
136        let mut cursor = ast.tree.root_node().walk();
137        self.compact_node(&mut cursor, source, &mut output, 0)?;
138
139        // Calculate compression ratio
140        let original_size = ast.source.len();
141        let compressed_size = output.len();
142        let ratio = 1.0 - (compressed_size as f64 / original_size as f64);
143
144        // Add compression stats
145        output.push_str(&format!(
146            "\n// Compression: {:.1}% ({}→{} bytes)",
147            ratio * 100.0,
148            original_size,
149            compressed_size
150        ));
151
152        Ok(output)
153    }
154
155    /// Recursively compact a node
156    fn compact_node(
157        &self,
158        cursor: &mut TreeCursor,
159        source: &[u8],
160        output: &mut String,
161        depth: usize,
162    ) -> AstResult<()> {
163        let node = cursor.node();
164        let node_type = node.kind();
165
166        // Skip removed node types
167        if self.removed_node_types.contains(node_type) {
168            return Ok(());
169        }
170
171        // Check if this is a significant node to preserve
172        if self.should_preserve_node(&node) {
173            let indent = "  ".repeat(depth);
174
175            // Extract node text based on type
176            match node_type {
177                "function_declaration" | "function_definition" | "function_item" => {
178                    self.extract_function_signature(&node, source, output, &indent)?;
179                }
180                "class_declaration" | "struct_item" | "impl_item" => {
181                    self.extract_class_structure(&node, source, output, &indent)?;
182                }
183                "import_statement" | "use_declaration" => {
184                    let text = std::str::from_utf8(&source[node.byte_range()])
185                        .map_err(|e| AstError::ParserError(e.to_string()))?;
186                    output.push_str(&indent);
187                    output.push_str(text.trim());
188                    output.push('\n');
189                }
190                _ => {
191                    // For other preserved nodes, extract simplified representation
192                    if node.child_count() == 0 {
193                        // Leaf node - include text if significant
194                        if self.is_significant_leaf(&node) {
195                            let text = std::str::from_utf8(&source[node.byte_range()])
196                                .map_err(|e| AstError::ParserError(e.to_string()))?;
197                            output.push_str(&indent);
198                            output.push_str(text.trim());
199                            output.push('\n');
200                        }
201                    }
202                }
203            }
204        }
205
206        // Process children
207        if cursor.goto_first_child() {
208            loop {
209                self.compact_node(cursor, source, output, depth + 1)?;
210                if !cursor.goto_next_sibling() {
211                    break;
212                }
213            }
214            cursor.goto_parent();
215        }
216
217        Ok(())
218    }
219
220    /// Check if a node should be preserved
221    fn should_preserve_node(&self, node: &Node) -> bool {
222        let node_type = node.kind();
223
224        // Check preserved types
225        if self.preserved_node_types.contains(node_type) {
226            return true;
227        }
228
229        // Check visibility for Maximum compression
230        if matches!(
231            self.compression_level,
232            CompressionLevel::Maximum | CompressionLevel::Hard
233        ) {
234            // Only preserve public members
235            if let Some(parent) = node.parent() {
236                let parent_text = parent.kind();
237                if parent_text.contains("private") || parent_text.contains("protected") {
238                    return false;
239                }
240            }
241        }
242
243        // Check for significant structural nodes
244        matches!(
245            node_type,
246            "module"
247                | "namespace"
248                | "package_declaration"
249                | "trait_item"
250                | "interface_declaration"
251                | "protocol_declaration"
252        )
253    }
254
255    /// Extract function signature without implementation
256    fn extract_function_signature(
257        &self,
258        node: &Node,
259        source: &[u8],
260        output: &mut String,
261        indent: &str,
262    ) -> AstResult<()> {
263        // Find the signature part (before the body)
264        let mut sig_end = node.start_byte();
265
266        for i in 0..node.child_count() {
267            if let Some(child) = node.child(i) {
268                let child_type = child.kind();
269                if child_type == "block"
270                    || child_type == "compound_statement"
271                    || child_type == "function_body"
272                {
273                    sig_end = child.start_byte();
274                    break;
275                }
276            }
277        }
278
279        if sig_end == node.start_byte() {
280            sig_end = node.end_byte();
281        }
282
283        let signature = std::str::from_utf8(&source[node.start_byte()..sig_end])
284            .map_err(|e| AstError::ParserError(e.to_string()))?;
285
286        output.push_str(indent);
287        output.push_str(signature.trim());
288
289        // Add semicolon if not present and not preserving implementation
290        if !self.compression_level.preserve_implementation() && !signature.trim().ends_with(';') {
291            output.push(';');
292        }
293        output.push('\n');
294
295        Ok(())
296    }
297
298    /// Extract class/struct structure
299    fn extract_class_structure(
300        &self,
301        node: &Node,
302        source: &[u8],
303        output: &mut String,
304        indent: &str,
305    ) -> AstResult<()> {
306        // Get class/struct header
307        let mut header_end = node.start_byte();
308        let mut found_body = false;
309
310        for i in 0..node.child_count() {
311            if let Some(child) = node.child(i) {
312                let child_type = child.kind();
313                if child_type == "field_declaration_list"
314                    || child_type == "declaration_list"
315                    || child_type == "class_body"
316                    || child_type == "{"
317                {
318                    header_end = child.start_byte();
319                    found_body = true;
320                    break;
321                }
322            }
323        }
324
325        if !found_body {
326            // No body found, include entire node
327            let text = std::str::from_utf8(&source[node.byte_range()])
328                .map_err(|e| AstError::ParserError(e.to_string()))?;
329            output.push_str(indent);
330            output.push_str(text.trim());
331            output.push('\n');
332            return Ok(());
333        }
334
335        let header = std::str::from_utf8(&source[node.start_byte()..header_end])
336            .map_err(|e| AstError::ParserError(e.to_string()))?;
337
338        output.push_str(indent);
339        output.push_str(header.trim());
340        output.push_str(" {\n");
341
342        // Extract members based on compression level
343        if !matches!(
344            self.compression_level,
345            CompressionLevel::Maximum | CompressionLevel::Hard
346        ) {
347            self.extract_class_members(node, source, output, &format!("{}  ", indent))?;
348        }
349
350        output.push_str(indent);
351        output.push_str("}\n");
352
353        Ok(())
354    }
355
356    /// Extract class members
357    fn extract_class_members(
358        &self,
359        node: &Node,
360        source: &[u8],
361        output: &mut String,
362        indent: &str,
363    ) -> AstResult<()> {
364        let mut cursor = node.walk();
365
366        if cursor.goto_first_child() {
367            loop {
368                let child = cursor.node();
369                let child_type = child.kind();
370
371                // Check for member nodes
372                if matches!(
373                    child_type,
374                    "field_declaration"
375                        | "method_declaration"
376                        | "function_declaration"
377                        | "property_declaration"
378                        | "field"
379                        | "method"
380                ) {
381                    // Check visibility
382                    if matches!(
383                        self.compression_level,
384                        CompressionLevel::Maximum | CompressionLevel::Hard
385                    ) {
386                        // Skip private/protected members
387                        let child_text =
388                            std::str::from_utf8(&source[child.byte_range()]).unwrap_or("");
389                        if child_text.contains("private") || child_text.contains("protected") {
390                            continue;
391                        }
392                    }
393
394                    // Extract member signature
395                    self.extract_member_signature(&child, source, output, indent)?;
396                }
397
398                if !cursor.goto_next_sibling() {
399                    break;
400                }
401            }
402        }
403
404        Ok(())
405    }
406
407    /// Extract member signature
408    fn extract_member_signature(
409        &self,
410        node: &Node,
411        source: &[u8],
412        output: &mut String,
413        indent: &str,
414    ) -> AstResult<()> {
415        // Find signature without body
416        let mut sig_end = node.end_byte();
417
418        for i in 0..node.child_count() {
419            if let Some(child) = node.child(i)
420                && (child.kind() == "block" || child.kind() == "compound_statement")
421            {
422                sig_end = child.start_byte();
423                break;
424            }
425        }
426
427        let signature = std::str::from_utf8(&source[node.start_byte()..sig_end])
428            .map_err(|e| AstError::ParserError(e.to_string()))?;
429
430        output.push_str(indent);
431        output.push_str(signature.trim());
432        if !signature.trim().ends_with(';') {
433            output.push(';');
434        }
435        output.push('\n');
436
437        Ok(())
438    }
439
440    /// Check if a leaf node is significant
441    fn is_significant_leaf(&self, node: &Node) -> bool {
442        let node_type = node.kind();
443        matches!(
444            node_type,
445            "identifier"
446                | "type_identifier"
447                | "string_literal"
448                | "number_literal"
449                | "boolean_literal"
450        ) && node.parent().is_some_and(|p| self.should_preserve_node(&p))
451    }
452
453    /// Calculate compression statistics
454    pub fn calculate_stats(&self, original: &str, compressed: &str) -> CompressionStats {
455        let original_size = original.len();
456        let compressed_size = compressed.len();
457        let original_lines = original.lines().count();
458        let compressed_lines = compressed.lines().count();
459
460        CompressionStats {
461            original_bytes: original_size,
462            compressed_bytes: compressed_size,
463            compression_ratio: 1.0 - (compressed_size as f64 / original_size as f64),
464            original_lines,
465            compressed_lines,
466            line_reduction: 1.0 - (compressed_lines as f64 / original_lines as f64),
467        }
468    }
469}
470
471/// Compression statistics
472#[derive(Debug, Clone)]
473pub struct CompressionStats {
474    pub original_bytes: usize,
475    pub compressed_bytes: usize,
476    pub compression_ratio: f64,
477    pub original_lines: usize,
478    pub compressed_lines: usize,
479    pub line_reduction: f64,
480}
481
482impl CompressionStats {
483    /// Check if target compression was achieved
484    pub fn meets_target(&self, level: CompressionLevel) -> bool {
485        self.compression_ratio >= level.target_ratio()
486    }
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use crate::language_registry::LanguageRegistry;

    #[test]
    fn test_compression_levels() {
        // (level, expected target ratio, expected comment preservation)
        let expectations = [
            (CompressionLevel::Light, 0.70, true),
            (CompressionLevel::Standard, 0.85, false),
            (CompressionLevel::Maximum, 0.95, false),
        ];

        for (level, ratio, keeps_comments) in expectations {
            assert_eq!(level.target_ratio(), ratio);
            assert_eq!(level.preserve_comments(), keeps_comments);
        }
    }

    #[test]
    fn test_compaction() {
        let code = r#"
// This is a comment
fn calculate_fibonacci(n: u32) -> u32 {
    // Implementation details
    if n <= 1 {
        return n;
    }
    calculate_fibonacci(n - 1) + calculate_fibonacci(n - 2)
}

pub struct Calculator {
    value: i32,
}

impl Calculator {
    pub fn new() -> Self {
        Self { value: 0 }
    }
    
    pub fn add(&mut self, x: i32) {
        self.value += x;
    }
    
    private fn reset(&mut self) {
        self.value = 0;
    }
}
"#;

        let registry = LanguageRegistry::new();
        let ast = registry.parse(&crate::Language::Rust, code).unwrap();
        let compressed = AstCompactor::new(CompressionLevel::Standard)
            .compact(&ast)
            .unwrap();

        // The skeleton must be smaller than the input.
        assert!(compressed.len() < code.len());

        // Structure survives compaction.
        for expected in [
            "fn calculate_fibonacci",
            "pub struct Calculator",
            "pub fn new",
            "pub fn add",
        ] {
            assert!(
                compressed.contains(expected),
                "missing `{expected}` in:\n{compressed}"
            );
        }

        // Standard level drops implementation bodies.
        for elided in ["if n <= 1", "self.value += x"] {
            assert!(
                !compressed.contains(elided),
                "unexpected `{elided}` in:\n{compressed}"
            );
        }
    }
}
559}