1use crate::error::AstError;
4use crate::error::AstResult;
5use crate::types::ParsedAst;
6use std::collections::HashSet;
8use tree_sitter::Node;
9use tree_sitter::TreeCursor;
10
/// Selects how aggressively source is stripped during compaction.
///
/// The methods on this type treat `Medium` exactly like `Standard` and
/// `Hard` exactly like `Maximum`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionLevel {
    /// Mildest level; the only one for which `preserve_comments` and
    /// `preserve_implementation` return `true`.
    Light,
    /// Middle tier (target ratio `0.85`).
    Standard,
    /// Handled identically to `Standard`.
    Medium,
    /// Top tier (target ratio `0.95`); `preserve_private` returns `false`.
    Maximum,
    /// Handled identically to `Maximum`.
    Hard,
}

impl CompressionLevel {
    /// Returns the compression ratio this level aims for:
    /// `0.70`, `0.85` or `0.95` depending on the tier.
    pub const fn target_ratio(&self) -> f64 {
        match *self {
            Self::Light => 0.70,
            Self::Standard | Self::Medium => 0.85,
            Self::Maximum | Self::Hard => 0.95,
        }
    }

    /// Reports whether comments are kept; `true` only for `Light`.
    pub const fn preserve_comments(&self) -> bool {
        // Spelled out per variant so a newly added level forces a decision here.
        match *self {
            Self::Light => true,
            Self::Standard | Self::Medium | Self::Maximum | Self::Hard => false,
        }
    }

    /// Reports whether implementations are kept; `true` only for `Light`.
    pub const fn preserve_implementation(&self) -> bool {
        match *self {
            Self::Light => true,
            Self::Standard | Self::Medium | Self::Maximum | Self::Hard => false,
        }
    }

    /// Reports whether private items are kept; `false` only for
    /// `Maximum` and `Hard`.
    pub const fn preserve_private(&self) -> bool {
        match *self {
            Self::Light | Self::Standard | Self::Medium => true,
            Self::Maximum | Self::Hard => false,
        }
    }
}
51
/// Renders a parsed syntax tree as compacted text, driven by a
/// [`CompressionLevel`] and two sets of node-kind names.
#[derive(Debug)]
pub struct AstCompactor {
    /// Level chosen at construction; decides how the two sets below are filled.
    compression_level: CompressionLevel,
    /// Node kinds that `should_preserve_node` accepts unconditionally.
    preserved_node_types: HashSet<String>,
    /// Node kinds that `compact_node` skips together with their whole subtree.
    removed_node_types: HashSet<String>,
}
59
60impl AstCompactor {
61 pub fn new(compression_level: CompressionLevel) -> Self {
63 let mut compactor = Self {
64 compression_level,
65 preserved_node_types: HashSet::new(),
66 removed_node_types: HashSet::new(),
67 };
68
69 compactor.configure_for_level();
71 compactor
72 }
73
74 fn configure_for_level(&mut self) {
76 self.preserved_node_types
78 .insert("function_declaration".to_string());
79 self.preserved_node_types
80 .insert("function_item".to_string());
81 self.preserved_node_types
82 .insert("class_declaration".to_string());
83 self.preserved_node_types.insert("struct_item".to_string());
84 self.preserved_node_types.insert("impl_item".to_string());
85 self.preserved_node_types.insert("enum_item".to_string());
86 self.preserved_node_types
87 .insert("interface_declaration".to_string());
88 self.preserved_node_types.insert("type_alias".to_string());
89 self.preserved_node_types
90 .insert("import_statement".to_string());
91 self.preserved_node_types
92 .insert("use_declaration".to_string());
93
94 match self.compression_level {
95 CompressionLevel::Light => {
96 self.preserved_node_types
98 .insert("method_declaration".to_string());
99 self.preserved_node_types
100 .insert("function_definition".to_string());
101 self.preserved_node_types.insert("comment".to_string());
102 self.preserved_node_types.insert("doc_comment".to_string());
103 }
104 CompressionLevel::Standard | CompressionLevel::Medium => {
105 self.removed_node_types.insert("block".to_string());
107 self.removed_node_types
108 .insert("compound_statement".to_string());
109 self.removed_node_types.insert("comment".to_string());
110 self.removed_node_types.insert("line_comment".to_string());
111 }
112 CompressionLevel::Maximum | CompressionLevel::Hard => {
113 self.removed_node_types.insert("block".to_string());
115 self.removed_node_types
116 .insert("compound_statement".to_string());
117 self.removed_node_types.insert("comment".to_string());
118 self.removed_node_types.insert("doc_comment".to_string());
119 self.removed_node_types.insert("private".to_string());
120 self.removed_node_types.insert("protected".to_string());
121 self.removed_node_types.insert("internal".to_string());
122 }
123 }
124 }
125
126 pub fn compact(&self, ast: &ParsedAst) -> AstResult<String> {
128 let mut output = String::new();
129 let source = ast.source.as_bytes();
130
131 output.push_str(&format!("// Language: {}\n", ast.language.name()));
133 output.push_str("// Compacted representation\n\n");
134
135 let mut cursor = ast.tree.root_node().walk();
137 self.compact_node(&mut cursor, source, &mut output, 0)?;
138
139 let original_size = ast.source.len();
141 let compressed_size = output.len();
142 let ratio = 1.0 - (compressed_size as f64 / original_size as f64);
143
144 output.push_str(&format!(
146 "\n// Compression: {:.1}% ({}→{} bytes)",
147 ratio * 100.0,
148 original_size,
149 compressed_size
150 ));
151
152 Ok(output)
153 }
154
155 fn compact_node(
157 &self,
158 cursor: &mut TreeCursor,
159 source: &[u8],
160 output: &mut String,
161 depth: usize,
162 ) -> AstResult<()> {
163 let node = cursor.node();
164 let node_type = node.kind();
165
166 if self.removed_node_types.contains(node_type) {
168 return Ok(());
169 }
170
171 if self.should_preserve_node(&node) {
173 let indent = " ".repeat(depth);
174
175 match node_type {
177 "function_declaration" | "function_definition" | "function_item" => {
178 self.extract_function_signature(&node, source, output, &indent)?;
179 }
180 "class_declaration" | "struct_item" | "impl_item" => {
181 self.extract_class_structure(&node, source, output, &indent)?;
182 }
183 "import_statement" | "use_declaration" => {
184 let text = std::str::from_utf8(&source[node.byte_range()])
185 .map_err(|e| AstError::ParserError(e.to_string()))?;
186 output.push_str(&indent);
187 output.push_str(text.trim());
188 output.push('\n');
189 }
190 _ => {
191 if node.child_count() == 0 {
193 if self.is_significant_leaf(&node) {
195 let text = std::str::from_utf8(&source[node.byte_range()])
196 .map_err(|e| AstError::ParserError(e.to_string()))?;
197 output.push_str(&indent);
198 output.push_str(text.trim());
199 output.push('\n');
200 }
201 }
202 }
203 }
204 }
205
206 if cursor.goto_first_child() {
208 loop {
209 self.compact_node(cursor, source, output, depth + 1)?;
210 if !cursor.goto_next_sibling() {
211 break;
212 }
213 }
214 cursor.goto_parent();
215 }
216
217 Ok(())
218 }
219
220 fn should_preserve_node(&self, node: &Node) -> bool {
222 let node_type = node.kind();
223
224 if self.preserved_node_types.contains(node_type) {
226 return true;
227 }
228
229 if matches!(
231 self.compression_level,
232 CompressionLevel::Maximum | CompressionLevel::Hard
233 ) {
234 if let Some(parent) = node.parent() {
236 let parent_text = parent.kind();
237 if parent_text.contains("private") || parent_text.contains("protected") {
238 return false;
239 }
240 }
241 }
242
243 matches!(
245 node_type,
246 "module"
247 | "namespace"
248 | "package_declaration"
249 | "trait_item"
250 | "interface_declaration"
251 | "protocol_declaration"
252 )
253 }
254
255 fn extract_function_signature(
257 &self,
258 node: &Node,
259 source: &[u8],
260 output: &mut String,
261 indent: &str,
262 ) -> AstResult<()> {
263 let mut sig_end = node.start_byte();
265
266 for i in 0..node.child_count() {
267 if let Some(child) = node.child(i) {
268 let child_type = child.kind();
269 if child_type == "block"
270 || child_type == "compound_statement"
271 || child_type == "function_body"
272 {
273 sig_end = child.start_byte();
274 break;
275 }
276 }
277 }
278
279 if sig_end == node.start_byte() {
280 sig_end = node.end_byte();
281 }
282
283 let signature = std::str::from_utf8(&source[node.start_byte()..sig_end])
284 .map_err(|e| AstError::ParserError(e.to_string()))?;
285
286 output.push_str(indent);
287 output.push_str(signature.trim());
288
289 if !self.compression_level.preserve_implementation() && !signature.trim().ends_with(';') {
291 output.push(';');
292 }
293 output.push('\n');
294
295 Ok(())
296 }
297
298 fn extract_class_structure(
300 &self,
301 node: &Node,
302 source: &[u8],
303 output: &mut String,
304 indent: &str,
305 ) -> AstResult<()> {
306 let mut header_end = node.start_byte();
308 let mut found_body = false;
309
310 for i in 0..node.child_count() {
311 if let Some(child) = node.child(i) {
312 let child_type = child.kind();
313 if child_type == "field_declaration_list"
314 || child_type == "declaration_list"
315 || child_type == "class_body"
316 || child_type == "{"
317 {
318 header_end = child.start_byte();
319 found_body = true;
320 break;
321 }
322 }
323 }
324
325 if !found_body {
326 let text = std::str::from_utf8(&source[node.byte_range()])
328 .map_err(|e| AstError::ParserError(e.to_string()))?;
329 output.push_str(indent);
330 output.push_str(text.trim());
331 output.push('\n');
332 return Ok(());
333 }
334
335 let header = std::str::from_utf8(&source[node.start_byte()..header_end])
336 .map_err(|e| AstError::ParserError(e.to_string()))?;
337
338 output.push_str(indent);
339 output.push_str(header.trim());
340 output.push_str(" {\n");
341
342 if !matches!(
344 self.compression_level,
345 CompressionLevel::Maximum | CompressionLevel::Hard
346 ) {
347 self.extract_class_members(node, source, output, &format!("{} ", indent))?;
348 }
349
350 output.push_str(indent);
351 output.push_str("}\n");
352
353 Ok(())
354 }
355
356 fn extract_class_members(
358 &self,
359 node: &Node,
360 source: &[u8],
361 output: &mut String,
362 indent: &str,
363 ) -> AstResult<()> {
364 let mut cursor = node.walk();
365
366 if cursor.goto_first_child() {
367 loop {
368 let child = cursor.node();
369 let child_type = child.kind();
370
371 if matches!(
373 child_type,
374 "field_declaration"
375 | "method_declaration"
376 | "function_declaration"
377 | "property_declaration"
378 | "field"
379 | "method"
380 ) {
381 if matches!(
383 self.compression_level,
384 CompressionLevel::Maximum | CompressionLevel::Hard
385 ) {
386 let child_text =
388 std::str::from_utf8(&source[child.byte_range()]).unwrap_or("");
389 if child_text.contains("private") || child_text.contains("protected") {
390 continue;
391 }
392 }
393
394 self.extract_member_signature(&child, source, output, indent)?;
396 }
397
398 if !cursor.goto_next_sibling() {
399 break;
400 }
401 }
402 }
403
404 Ok(())
405 }
406
407 fn extract_member_signature(
409 &self,
410 node: &Node,
411 source: &[u8],
412 output: &mut String,
413 indent: &str,
414 ) -> AstResult<()> {
415 let mut sig_end = node.end_byte();
417
418 for i in 0..node.child_count() {
419 if let Some(child) = node.child(i)
420 && (child.kind() == "block" || child.kind() == "compound_statement")
421 {
422 sig_end = child.start_byte();
423 break;
424 }
425 }
426
427 let signature = std::str::from_utf8(&source[node.start_byte()..sig_end])
428 .map_err(|e| AstError::ParserError(e.to_string()))?;
429
430 output.push_str(indent);
431 output.push_str(signature.trim());
432 if !signature.trim().ends_with(';') {
433 output.push(';');
434 }
435 output.push('\n');
436
437 Ok(())
438 }
439
440 fn is_significant_leaf(&self, node: &Node) -> bool {
442 let node_type = node.kind();
443 matches!(
444 node_type,
445 "identifier"
446 | "type_identifier"
447 | "string_literal"
448 | "number_literal"
449 | "boolean_literal"
450 ) && node.parent().is_some_and(|p| self.should_preserve_node(&p))
451 }
452
453 pub fn calculate_stats(&self, original: &str, compressed: &str) -> CompressionStats {
455 let original_size = original.len();
456 let compressed_size = compressed.len();
457 let original_lines = original.lines().count();
458 let compressed_lines = compressed.lines().count();
459
460 CompressionStats {
461 original_bytes: original_size,
462 compressed_bytes: compressed_size,
463 compression_ratio: 1.0 - (compressed_size as f64 / original_size as f64),
464 original_lines,
465 compressed_lines,
466 line_reduction: 1.0 - (compressed_lines as f64 / original_lines as f64),
467 }
468 }
469}
470
/// Size comparison between an original text and its compacted form, as
/// produced by `AstCompactor::calculate_stats`.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Byte length of the original text.
    pub original_bytes: usize,
    /// Byte length of the compacted text.
    pub compressed_bytes: usize,
    /// `1 - compressed_bytes / original_bytes`.
    pub compression_ratio: f64,
    /// Number of lines in the original text.
    pub original_lines: usize,
    /// Number of lines in the compacted text.
    pub compressed_lines: usize,
    /// `1 - compressed_lines / original_lines`.
    pub line_reduction: f64,
}
481
482impl CompressionStats {
483 pub fn meets_target(&self, level: CompressionLevel) -> bool {
485 self.compression_ratio >= level.target_ratio()
486 }
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use crate::language_registry::LanguageRegistry;

    /// Checks the target ratio and comment policy of three representative levels.
    #[test]
    fn test_compression_levels() {
        // (level, expected target ratio, expected preserve_comments)
        let expectations = [
            (CompressionLevel::Light, 0.70, true),
            (CompressionLevel::Standard, 0.85, false),
            (CompressionLevel::Maximum, 0.95, false),
        ];

        for (level, ratio, keeps_comments) in expectations {
            assert_eq!(level.target_ratio(), ratio, "target ratio of {level:?}");
            assert_eq!(
                level.preserve_comments(),
                keeps_comments,
                "comment policy of {level:?}"
            );
        }
    }

    /// Compacts a small Rust snippet at `Standard` and checks that signatures
    /// survive while function bodies do not.
    #[test]
    fn test_compaction() {
        let registry = LanguageRegistry::new();
        let compactor = AstCompactor::new(CompressionLevel::Standard);

        // NOTE(review): `private fn` is not valid Rust; presumably the parser
        // recovers around it — confirm the fixture is intentional.
        let code = r#"
// This is a comment
fn calculate_fibonacci(n: u32) -> u32 {
 // Implementation details
 if n <= 1 {
 return n;
 }
 calculate_fibonacci(n - 1) + calculate_fibonacci(n - 2)
}

pub struct Calculator {
 value: i32,
}

impl Calculator {
 pub fn new() -> Self {
 Self { value: 0 }
 }

 pub fn add(&mut self, x: i32) {
 self.value += x;
 }

 private fn reset(&mut self) {
 self.value = 0;
 }
}
"#;

        let ast = registry.parse(&crate::Language::Rust, code).unwrap();
        let compressed = compactor.compact(&ast).unwrap();

        assert!(compressed.len() < code.len());

        let kept = [
            "fn calculate_fibonacci",
            "pub struct Calculator",
            "pub fn new",
            "pub fn add",
        ];
        for fragment in kept {
            assert!(compressed.contains(fragment), "missing {fragment:?}");
        }

        let dropped = ["if n <= 1", "self.value += x"];
        for fragment in dropped {
            assert!(!compressed.contains(fragment), "unexpected {fragment:?}");
        }
    }
}