Skip to main content

sqz_engine/
structural_summary.rs

//! Structural summary extraction for source code files.
//!
//! Instead of dumping entire files into LLM context, this module extracts
//! just the structural skeleton: imports, function/method signatures, class
//! and type definitions, and call relationships. The model sees the
//! architecture without the implementation noise. This typically costs a
//! fraction of the original tokens (around 70% fewer is common, though it
//! varies by file) and can make the code easier to navigate.
//!
//! Builds on top of `AstParser` (signature extraction) and `DependencyMapper`
//! (import graph), adding **call graph extraction** — which functions call
//! which other functions — to complete the structural picture. Call edges are
//! found by textual matching. They only name functions defined in the same
//! file (or `Type::method` calls on types defined in the same file, rendered
//! as `Type.method`).
//!
//! Illustrative output. The exact import, signature and dependency lines come
//! from `AstParser` and `DependencyMapper`:
//! ```text
//! # file: src/engine.rs
//! ## imports
//! use crate::pipeline::CompressionPipeline
//! use crate::cache_manager::CacheManager
//! ## types
//! pub struct SqzEngine { ... }
//! ## functions
//! pub fn compress(&self, input: &str) -> Result<CompressedContent>
//!   → calls: verify
//! pub fn compress_with_mode(&self, input: &str, mode: CompressionMode) -> Result<CompressedContent>
//!   → calls: compress
//! fn verify(&self, result: &CompressedContent) -> bool
//! ## dependencies
//! imports: pipeline, cache_manager, verifier
//! imported by: main, cli_proxy
//! ```
30
31use std::collections::{HashMap, HashSet};
32
33use crate::ast_parser::{AstParser, CodeSummary};
34use crate::dependency_mapper::DependencyMapper;
35use crate::error::Result;
36
/// Configuration for structural summary generation.
///
/// Each `include_*` flag toggles one part of the summary produced by
/// `summarize`. The numeric fields bound how much detail is emitted.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SummaryConfig {
    /// Emit the `## imports` section.
    pub include_imports: bool,
    /// Emit the `## functions` section (function/method signatures).
    pub include_functions: bool,
    /// Emit class/struct/interface definitions under `## types`.
    pub include_types: bool,
    /// Emit type-alias signatures.
    pub include_type_aliases: bool,
    /// Extract call relationships and print a `→ calls:` line under each
    /// function that has any.
    pub include_calls: bool,
    /// Emit the `## dependencies` section when a dependency mapper is
    /// supplied and reports something for the file.
    pub include_dep_graph: bool,
    /// Maximum number of call targets listed per function. Any remaining
    /// targets are collapsed into a ` +N more` suffix.
    pub max_calls_per_function: usize,
    /// Minimum source length, in **bytes** (as measured by `str::len`), that
    /// triggers summarization. Shorter sources are returned unchanged.
    pub min_file_size: usize,
}

impl Default for SummaryConfig {
    /// Every section is enabled, at most 10 callees are listed per function,
    /// and sources under 500 bytes are passed through untouched.
    fn default() -> Self {
        Self {
            include_imports: true,
            include_functions: true,
            include_types: true,
            include_type_aliases: true,
            include_calls: true,
            include_dep_graph: true,
            max_calls_per_function: 10,
            min_file_size: 500,
        }
    }
}
73
/// Outcome of a structural summary: the rendered text plus statistics
/// describing how much it shrank the input.
#[derive(Clone, Debug)]
pub struct StructuralSummaryResult {
    /// Rendered summary text. For inputs below the size threshold this is
    /// the untouched source.
    pub summary: String,
    /// Approximate token count of the input source.
    pub tokens_original: u32,
    /// Approximate token count of `summary`.
    pub tokens_summary: u32,
    /// How many functions the parser reported.
    pub functions_count: usize,
    /// How many classes/structs/interfaces the parser reported.
    pub types_count: usize,
    /// Total number of caller → callee edges discovered.
    pub call_edges: usize,
    /// `tokens_summary / tokens_original`. This is `1.0` for pass-through
    /// inputs and for empty sources.
    pub compression_ratio: f64,
}
92
93/// Extract a structural summary from source code.
94///
95/// This is the main entry point. Given source code and its language,
96/// produces a compact summary containing imports, signatures, call
97/// relationships, and dependency info.
98pub fn summarize(
99    source: &str,
100    language: &str,
101    file_path: &str,
102    config: &SummaryConfig,
103    dep_mapper: Option<&DependencyMapper>,
104) -> Result<StructuralSummaryResult> {
105    let tokens_original = approx_tokens(source);
106
107    // Small files aren't worth summarizing
108    if source.len() < config.min_file_size {
109        return Ok(StructuralSummaryResult {
110            summary: source.to_string(),
111            tokens_original,
112            tokens_summary: tokens_original,
113            functions_count: 0,
114            types_count: 0,
115            call_edges: 0,
116            compression_ratio: 1.0,
117        });
118    }
119
120    let parser = AstParser::new();
121    let code_summary = parser.extract_signatures(source, language).unwrap_or_else(|_| {
122        // Fallback: regex-based extraction for unsupported languages
123        CodeSummary {
124            imports: Vec::new(),
125            functions: Vec::new(),
126            classes: Vec::new(),
127            types: Vec::new(),
128            tokens_original,
129            tokens_summary: tokens_original,
130        }
131    });
132
133    // Extract call relationships
134    let call_graph = if config.include_calls {
135        extract_call_graph(source, &code_summary)
136    } else {
137        HashMap::new()
138    };
139
140    let call_edges: usize = call_graph.values().map(|v| v.len()).sum();
141
142    // Build the summary text
143    let mut parts: Vec<String> = Vec::new();
144
145    parts.push(format!("# file: {file_path}"));
146
147    // Imports section
148    if config.include_imports && !code_summary.imports.is_empty() {
149        parts.push("## imports".to_string());
150        for imp in &code_summary.imports {
151            parts.push(imp.text.clone());
152        }
153    }
154
155    // Types section (classes, structs, interfaces)
156    if config.include_types && !code_summary.classes.is_empty() {
157        parts.push("## types".to_string());
158        for cls in &code_summary.classes {
159            parts.push(cls.signature.clone());
160        }
161    }
162
163    // Type aliases
164    if config.include_type_aliases && !code_summary.types.is_empty() {
165        for ty in &code_summary.types {
166            parts.push(ty.signature.clone());
167        }
168    }
169
170    // Functions section with call relationships
171    if config.include_functions && !code_summary.functions.is_empty() {
172        parts.push("## functions".to_string());
173        for func in &code_summary.functions {
174            parts.push(func.signature.clone());
175            if config.include_calls {
176                if let Some(calls) = call_graph.get(&func.name) {
177                    if !calls.is_empty() {
178                        let display_calls: Vec<&str> = calls
179                            .iter()
180                            .take(config.max_calls_per_function)
181                            .map(|s| s.as_str())
182                            .collect();
183                        let suffix = if calls.len() > config.max_calls_per_function {
184                            format!(" +{} more", calls.len() - config.max_calls_per_function)
185                        } else {
186                            String::new()
187                        };
188                        parts.push(format!(
189                            "  \u{2192} calls: {}{}",
190                            display_calls.join(", "),
191                            suffix
192                        ));
193                    }
194                }
195            }
196        }
197    }
198
199    // Dependency graph section
200    if config.include_dep_graph {
201        if let Some(mapper) = dep_mapper {
202            let path = std::path::Path::new(file_path);
203            let dep_summary = mapper.summary(path);
204            if !dep_summary.is_empty() {
205                parts.push("## dependencies".to_string());
206                parts.push(dep_summary);
207            }
208        }
209    }
210
211    let summary = parts.join("\n");
212    let tokens_summary = approx_tokens(&summary);
213    let compression_ratio = if tokens_original > 0 {
214        tokens_summary as f64 / tokens_original as f64
215    } else {
216        1.0
217    };
218
219    Ok(StructuralSummaryResult {
220        summary,
221        tokens_original,
222        tokens_summary,
223        functions_count: code_summary.functions.len(),
224        types_count: code_summary.classes.len(),
225        call_edges,
226        compression_ratio,
227    })
228}
229
230/// Summarize multiple files into a single structural map.
231///
232/// Useful for giving the model an overview of an entire module or package.
233pub fn summarize_multi(
234    files: &[(&str, &str, &str)], // (source, language, file_path)
235    config: &SummaryConfig,
236    dep_mapper: Option<&DependencyMapper>,
237) -> Result<StructuralSummaryResult> {
238    let mut all_parts: Vec<String> = Vec::new();
239    let mut total_original: u32 = 0;
240    let mut total_functions: usize = 0;
241    let mut total_types: usize = 0;
242    let mut total_edges: usize = 0;
243
244    for (source, language, file_path) in files {
245        let result = summarize(source, language, file_path, config, dep_mapper)?;
246        total_original += result.tokens_original;
247        total_functions += result.functions_count;
248        total_types += result.types_count;
249        total_edges += result.call_edges;
250        all_parts.push(result.summary);
251    }
252
253    let summary = all_parts.join("\n---\n");
254    let tokens_summary = approx_tokens(&summary);
255    let compression_ratio = if total_original > 0 {
256        tokens_summary as f64 / total_original as f64
257    } else {
258        1.0
259    };
260
261    Ok(StructuralSummaryResult {
262        summary,
263        tokens_original: total_original,
264        tokens_summary,
265        functions_count: total_functions,
266        types_count: total_types,
267        call_edges: total_edges,
268        compression_ratio,
269    })
270}
271
272// ── Call graph extraction ─────────────────────────────────────────────────
273
274/// Extract call relationships from source code.
275///
276/// For each function in the code summary, scans its body to find calls to
277/// other known functions. Uses a combination of:
278/// 1. Direct name matching against known function names
279/// 2. Method call pattern matching (`.method_name(`)
280/// 3. Qualified call matching (`module::function(`)
281///
282/// Returns a map: caller_name → [callee_names]
283fn extract_call_graph(
284    source: &str,
285    code_summary: &CodeSummary,
286) -> HashMap<String, Vec<String>> {
287    let mut graph: HashMap<String, Vec<String>> = HashMap::new();
288
289    // Build a set of known function/method names for matching
290    let known_names: HashSet<&str> = code_summary
291        .functions
292        .iter()
293        .map(|f| f.name.as_str())
294        .collect();
295
296    // Also collect class names for qualified calls
297    let known_classes: HashSet<&str> = code_summary
298        .classes
299        .iter()
300        .map(|c| c.name.as_str())
301        .collect();
302
303    let lines: Vec<&str> = source.lines().collect();
304
305    // Find function boundaries (start line → end line)
306    let boundaries = find_function_boundaries(source, code_summary);
307
308    for (func_name, start, end) in &boundaries {
309        let mut calls: Vec<String> = Vec::new();
310        let mut seen: HashSet<String> = HashSet::new();
311
312        // Don't include self-references
313        seen.insert(func_name.clone());
314
315        for line_idx in *start..*end.min(&lines.len()) {
316            let line = lines[line_idx].trim();
317
318            // Skip comments
319            if line.starts_with("//") || line.starts_with('#') || line.starts_with("/*") {
320                continue;
321            }
322
323            // Check for direct function calls: `function_name(`
324            for name in &known_names {
325                if seen.contains(*name) {
326                    continue;
327                }
328                // Match `name(` but not `some_name(` (word boundary)
329                if contains_call(line, name) {
330                    calls.push(name.to_string());
331                    seen.insert(name.to_string());
332                }
333            }
334
335            // Check for method calls: `.method_name(`
336            for name in &known_names {
337                if seen.contains(*name) {
338                    continue;
339                }
340                let pattern = format!(".{}(", name);
341                if line.contains(&pattern) {
342                    calls.push(name.to_string());
343                    seen.insert(name.to_string());
344                }
345            }
346
347            // Check for qualified calls: `ClassName::method(` or `module::func(`
348            for class_name in &known_classes {
349                let pattern = format!("{}::", class_name);
350                if line.contains(&pattern) {
351                    // Extract the method name after ::
352                    if let Some(rest) = line.split(&pattern).nth(1) {
353                        let method = rest
354                            .split(|c: char| !c.is_alphanumeric() && c != '_')
355                            .next()
356                            .unwrap_or("");
357                        if !method.is_empty() && !seen.contains(method) {
358                            let qualified = format!("{}.{}", class_name, method);
359                            calls.push(qualified);
360                            seen.insert(method.to_string());
361                        }
362                    }
363                }
364            }
365        }
366
367        calls.sort();
368        if !calls.is_empty() {
369            graph.insert(func_name.clone(), calls);
370        }
371    }
372
373    graph
374}
375
/// Check whether `line` contains a call to `name` that starts on an
/// identifier boundary.
///
/// Matches `name(`, including `.name(` and `path::name(`, but not
/// `some_name(` or `xname(`. A longer identifier such as `name_suffix(` never
/// matches, because the pattern requires `(` right after `name`. The left
/// boundary uses Unicode identifier rules (alphanumeric or `_`), so
/// `décompress(` is not a call to `compress`. An empty `name` never matches.
///
/// Occurrences are found with `match_indices`, which only yields char
/// boundaries, so non-ASCII lines and names cannot cause a slicing panic.
fn contains_call(line: &str, name: &str) -> bool {
    if name.is_empty() {
        return false;
    }
    let pattern = format!("{name}(");
    line.match_indices(pattern.as_str()).any(|(pos, _)| {
        line[..pos]
            .chars()
            .next_back()
            .map_or(true, |prev| !(prev.is_alphanumeric() || prev == '_'))
    })
}
395
396/// Find approximate line boundaries for each function in the source.
397/// Returns (function_name, start_line, end_line).
398fn find_function_boundaries(
399    source: &str,
400    code_summary: &CodeSummary,
401) -> Vec<(String, usize, usize)> {
402    let lines: Vec<&str> = source.lines().collect();
403    let mut boundaries = Vec::new();
404
405    // Find the start line of each function by matching its signature
406    let mut func_starts: Vec<(String, usize)> = Vec::new();
407
408    for func in &code_summary.functions {
409        // Find the line containing this function's signature
410        let sig_prefix = if func.signature.len() > 20 {
411            &func.signature[..20]
412        } else {
413            &func.signature
414        };
415
416        for (i, line) in lines.iter().enumerate() {
417            if line.trim().starts_with(sig_prefix.trim()) || line.contains(&format!("fn {}", func.name)) || line.contains(&format!("def {}", func.name)) || line.contains(&format!("function {}", func.name)) {
418                func_starts.push((func.name.clone(), i));
419                break;
420            }
421        }
422    }
423
424    // Sort by start line
425    func_starts.sort_by_key(|(_, line)| *line);
426
427    // Each function ends where the next one starts (or at EOF)
428    for i in 0..func_starts.len() {
429        let (ref name, start) = func_starts[i];
430        let end = if i + 1 < func_starts.len() {
431            func_starts[i + 1].1
432        } else {
433            lines.len()
434        };
435        boundaries.push((name.clone(), start, end));
436    }
437
438    boundaries
439}
440
/// Approximate token count: the UTF-8 **byte** length divided by 4, rounded
/// up.
///
/// Computed with integer arithmetic. Results that would not fit in a `u32`
/// saturate at `u32::MAX`.
fn approx_tokens(s: &str) -> u32 {
    let bytes = s.len();
    let quarters = bytes / 4 + usize::from(bytes % 4 != 0);
    u32::try_from(quarters).unwrap_or(u32::MAX)
}
445
446// ── Tests ─────────────────────────────────────────────────────────────────
447
// Unit tests for structural summarization, call-graph extraction and the
// line-matching helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: a small Rust file with three `use` imports, one struct,
    // an `impl` block whose methods call each other (`compress_with_mode` ->
    // `compress` -> `verify`), and a free function. It is long enough to
    // exceed the default `min_file_size` (500), so default configs summarize
    // it instead of passing it through.
    const RUST_SOURCE: &str = r#"
use std::collections::HashMap;
use crate::pipeline::CompressionPipeline;
use crate::cache_manager::CacheManager;

/// The main engine.
pub struct SqzEngine {
    pipeline: CompressionPipeline,
    cache: CacheManager,
    config: HashMap<String, String>,
}

impl SqzEngine {
    pub fn new() -> Result<Self> {
        let pipeline = CompressionPipeline::new();
        let cache = CacheManager::new();
        let config = HashMap::new();
        Ok(Self { pipeline, cache, config })
    }

    pub fn compress(&self, input: &str) -> Result<CompressedContent> {
        let cached = self.cache.get(input);
        if let Some(hit) = cached {
            return Ok(hit);
        }
        let result = self.pipeline.compress(input);
        self.cache.insert(input, &result);
        self.verify(&result);
        result
    }

    pub fn compress_with_mode(&self, input: &str, mode: CompressionMode) -> Result<CompressedContent> {
        self.compress(input)
    }

    fn verify(&self, result: &CompressedContent) -> bool {
        result.compression_ratio < 0.95
    }

    pub fn status(&self) -> String {
        format!("cache: {} entries", self.cache.len())
    }
}

pub fn standalone_helper(x: i32) -> i32 {
    x + 1
}
"#;

    // End-to-end: header, imports and functions sections are present and the
    // summary is smaller than the source.
    #[test]
    fn test_summarize_rust_file() {
        let config = SummaryConfig::default();
        let result = summarize(RUST_SOURCE, "rust", "src/engine.rs", &config, None).unwrap();

        assert!(result.summary.contains("# file: src/engine.rs"));
        assert!(result.summary.contains("## imports"));
        assert!(result.summary.contains("## functions"));
        assert!(result.tokens_summary < result.tokens_original);
        assert!(result.compression_ratio < 1.0);
        assert!(result.functions_count > 0);
    }

    // Call extraction is heuristic, so this only checks that the rendered
    // arrows are consistent with `call_edges`.
    #[test]
    fn test_summarize_extracts_calls() {
        let config = SummaryConfig::default();
        let result = summarize(RUST_SOURCE, "rust", "src/engine.rs", &config, None).unwrap();

        // The call graph should find at least some relationships
        // (compress calls verify, compress_with_mode calls compress, etc.)
        // If no calls found, the summary still works — just without call arrows
        if result.call_edges > 0 {
            assert!(
                result.summary.contains("\u{2192} calls:"),
                "summary should show call arrows when edges exist"
            );
        }
        // Verify the summary is still smaller than the original regardless
        assert!(result.tokens_summary < result.tokens_original);
    }

    #[test]
    fn test_summarize_compression_ratio() {
        let config = SummaryConfig::default();
        let result = summarize(RUST_SOURCE, "rust", "src/engine.rs", &config, None).unwrap();

        // Structural summary should be significantly smaller
        assert!(
            result.compression_ratio < 0.8,
            "compression ratio {} should be < 0.8",
            result.compression_ratio
        );
    }

    // Inputs below `min_file_size` bytes are returned verbatim.
    #[test]
    fn test_summarize_small_file_passthrough() {
        let small = "fn main() {}";
        let config = SummaryConfig::default();
        let result = summarize(small, "rust", "main.rs", &config, None).unwrap();

        assert_eq!(result.summary, small);
        assert_eq!(result.compression_ratio, 1.0);
    }

    // `min_file_size` is lowered so the short Python fixture is summarized
    // rather than passed through.
    #[test]
    fn test_summarize_python() {
        let source = r#"
import os
from typing import List, Dict
from .utils import helper

class UserService:
    def __init__(self, db):
        self.db = db

    def create(self, name: str) -> Dict:
        user = self.db.insert(name)
        self.notify(user)
        return user

    def notify(self, user: Dict) -> None:
        print(f"Created {user}")

def standalone(x: int) -> int:
    result = helper(x)
    return result + 1
"#;
        let config = SummaryConfig {
            min_file_size: 50,
            ..Default::default()
        };
        let result = summarize(source, "python", "services/user.py", &config, None).unwrap();

        assert!(result.summary.contains("## imports"));
        assert!(result.summary.contains("## types") || result.summary.contains("class UserService"));
        assert!(result.summary.contains("## functions"));
        assert!(result.tokens_summary < result.tokens_original);
    }

    #[test]
    fn test_summarize_javascript() {
        let source = r#"
import React from 'react';
import { useState, useEffect } from 'react';
import { fetchUsers } from './api';

class UserList extends React.Component {
    constructor(props) {
        super(props);
        this.state = { users: [] };
    }

    componentDidMount() {
        fetchUsers().then(users => {
            this.setState({ users });
        });
    }

    render() {
        return this.state.users.map(u => (
            <div key={u.id}>{u.name}</div>
        ));
    }
}

function formatUser(user) {
    const name = user.firstName + ' ' + user.lastName;
    return { ...user, displayName: name };
}

export default UserList;
"#;
        let config = SummaryConfig::default();
        let result = summarize(source, "javascript", "src/UserList.js", &config, None).unwrap();

        assert!(result.summary.contains("## imports"));
        assert!(result.functions_count > 0);
        assert!(result.tokens_summary < result.tokens_original);
    }

    // With a mapper that knows about `src/engine.rs`, the dependencies
    // section must appear.
    #[test]
    fn test_summarize_with_dep_mapper() {
        let mut mapper = DependencyMapper::new();
        mapper.add_file(
            std::path::Path::new("src/engine.rs"),
            "use crate::pipeline;\nuse crate::cache;\n",
        );
        mapper.add_file(
            std::path::Path::new("src/main.rs"),
            "use crate::engine;\n",
        );

        let config = SummaryConfig::default();
        let result = summarize(
            RUST_SOURCE,
            "rust",
            "src/engine.rs",
            &config,
            Some(&mapper),
        )
        .unwrap();

        assert!(result.summary.contains("## dependencies"));
    }

    #[test]
    fn test_summarize_config_disable_calls() {
        let config = SummaryConfig {
            include_calls: false,
            ..Default::default()
        };
        let result = summarize(RUST_SOURCE, "rust", "src/engine.rs", &config, None).unwrap();

        assert!(
            !result.summary.contains("\u{2192} calls:"),
            "should not show calls when disabled"
        );
        assert_eq!(result.call_edges, 0);
    }

    #[test]
    fn test_summarize_config_disable_imports() {
        let config = SummaryConfig {
            include_imports: false,
            ..Default::default()
        };
        let result = summarize(RUST_SOURCE, "rust", "src/engine.rs", &config, None).unwrap();

        assert!(
            !result.summary.contains("## imports"),
            "should not show imports when disabled"
        );
    }

    // A tiny `min_file_size` makes sure the short second file is summarized
    // too. Per-file summaries are joined with a `---` separator.
    #[test]
    fn test_summarize_multi() {
        let files: Vec<(&str, &str, &str)> = vec![
            (RUST_SOURCE, "rust", "src/engine.rs"),
            (
                "use crate::engine;\n\nfn main() {\n    let e = SqzEngine::new();\n    e.compress(\"hello\");\n}\n",
                "rust",
                "src/main.rs",
            ),
        ];
        let config = SummaryConfig {
            min_file_size: 10,
            ..Default::default()
        };
        let result = summarize_multi(&files, &config, None).unwrap();

        assert!(result.summary.contains("src/engine.rs"));
        assert!(result.summary.contains("src/main.rs"));
        assert!(result.summary.contains("---")); // separator
    }

    #[test]
    fn test_contains_call_word_boundary() {
        assert!(contains_call("    cache.get(input)", "get"));
        assert!(contains_call("result = compress(data)", "compress"));
        assert!(!contains_call("decompressor(data)", "compress"));
        assert!(!contains_call("get_all()", "get"));
    }

    // Guarded with `if let` because whether `compress` gets a graph entry
    // depends on which functions `AstParser` extracts.
    #[test]
    fn test_extract_call_graph_finds_known_calls() {
        let parser = AstParser::new();
        let summary = parser.extract_signatures(RUST_SOURCE, "rust").unwrap();
        let graph = extract_call_graph(RUST_SOURCE, &summary);

        // compress() should call verify()
        if let Some(calls) = graph.get("compress") {
            assert!(
                calls.iter().any(|c| c == "verify"),
                "compress should call verify, got: {:?}",
                calls
            );
        }
    }

    #[test]
    fn test_find_function_boundaries() {
        let parser = AstParser::new();
        let summary = parser.extract_signatures(RUST_SOURCE, "rust").unwrap();
        let boundaries = find_function_boundaries(RUST_SOURCE, &summary);

        assert!(
            !boundaries.is_empty(),
            "should find function boundaries"
        );
        // Each boundary should have start < end
        for (name, start, end) in &boundaries {
            assert!(
                start < end,
                "function {} should have start ({}) < end ({})",
                name,
                start,
                end
            );
        }
    }

    // The filler text is deliberately longer than the default 500-byte
    // threshold, so the unsupported-language fallback path is exercised.
    #[test]
    fn test_summarize_unsupported_language() {
        let source = "some random content that is long enough to trigger summarization, we need at least 500 characters of content here to pass the minimum file size threshold so let me add more text to make this work properly and ensure we get a valid result back from the summarize function even for unsupported languages like COBOL or Fortran or whatever else might come through the pipeline in production use cases where we cannot predict what files will be processed by the engine and need graceful fallback behavior that does not crash or panic but instead returns a reasonable default result that the caller can work with safely and reliably in all circumstances";
        let config = SummaryConfig::default();
        let result = summarize(source, "cobol", "main.cob", &config, None).unwrap();

        // Should still produce a result (with file header at minimum)
        assert!(result.summary.contains("# file: main.cob"));
    }
}