Skip to main content

context_footprint/
cli.rs

1use crate::adapters::policy::academic::AcademicBaseline;
2use crate::adapters::test_detector::UniversalTestDetector;
3use crate::domain::graph::ContextGraph;
4use crate::domain::node::Node;
5use crate::domain::ports::SourceReader;
6use crate::domain::solver::CfSolver;
7use anyhow::Result;
8use petgraph::graph::NodeIndex;
9
10pub fn compute_cf_for_symbol(graph: &ContextGraph, symbol: &str) -> Result<()> {
11    println!("Computing CF for symbol: {}", symbol);
12
13    let node_idx = graph
14        .get_node_by_symbol(symbol)
15        .ok_or_else(|| anyhow::anyhow!("Symbol not found: {}", symbol))?;
16
17    let policy = AcademicBaseline::default();
18    let solver = CfSolver::new();
19    let result = solver.compute_cf(graph, node_idx, &policy, None);
20
21    println!("CF Result:");
22    println!("  Total context size: {} tokens", result.total_context_size);
23    println!("  Reachable nodes: {}", result.reachable_set.len());
24
25    Ok(())
26}
27
28pub fn display_top_cf_nodes(
29    graph: &ContextGraph,
30    limit: usize,
31    node_type: &str,
32    include_tests: bool,
33) -> Result<()> {
34    let policy = AcademicBaseline::default();
35    let solver = CfSolver::new();
36    let test_detector = UniversalTestDetector::new();
37
38    println!("Computing CF for all nodes...");
39    let mut cf_results: Vec<(String, u32, &str)> = Vec::new();
40
41    for (symbol, &node_idx) in &graph.symbol_to_node {
42        let node = graph.node(node_idx);
43
44        let type_str = match node {
45            Node::Function(_) => "function",
46            Node::Type(_) => "type",
47            Node::Variable(_) => "variable",
48        };
49
50        // Filter by node type if specified
51        if node_type != "all" && node_type != type_str {
52            continue;
53        }
54
55        // Filter out test code if requested (default is to exclude)
56        if !include_tests && test_detector.is_test_code(symbol, &node.core().file_path) {
57            continue;
58        }
59
60        let result = solver.compute_cf(graph, node_idx, &policy, None);
61        cf_results.push((symbol.clone(), result.total_context_size, type_str));
62    }
63
64    // Sort by CF (descending)
65    cf_results.sort_by(|a, b| b.1.cmp(&a.1));
66
67    let filter_msg = if !include_tests {
68        " (excluding tests)"
69    } else {
70        ""
71    };
72    println!("\nTop {} nodes by Context Footprint{}:", limit, filter_msg);
73    println!("{}", "=".repeat(80));
74
75    for (i, (symbol, cf, node_type)) in cf_results.iter().take(limit).enumerate() {
76        println!("{}. [{}] {} tokens", i + 1, node_type, cf);
77        println!("   {}", symbol);
78        println!();
79    }
80
81    Ok(())
82}
83
84pub fn search_symbols(
85    graph: &ContextGraph,
86    pattern: &str,
87    with_cf: bool,
88    limit: Option<usize>,
89    include_tests: bool,
90) -> Result<()> {
91    let policy = AcademicBaseline::default();
92    let solver = CfSolver::new();
93    let test_detector = UniversalTestDetector::new();
94
95    println!("Searching for symbols matching: \"{}\"", pattern);
96    println!("{}", "=".repeat(80));
97
98    let pattern_lower = pattern.to_lowercase();
99    let mut matches: Vec<(String, &str, u32)> = Vec::new();
100
101    for (symbol, &node_idx) in &graph.symbol_to_node {
102        let node = graph.node(node_idx);
103
104        let type_str = match node {
105            Node::Function(_) => "function",
106            Node::Type(_) => "type",
107            Node::Variable(_) => "variable",
108        };
109
110        // Simple substring match (case-insensitive)
111        if symbol.to_lowercase().contains(&pattern_lower) {
112            // Filter out test code if requested (default is to exclude)
113            if !include_tests && test_detector.is_test_code(symbol, &node.core().file_path) {
114                continue;
115            }
116
117            // Always compute CF for sorting, even if not displaying
118            let result = solver.compute_cf(graph, node_idx, &policy, None);
119            matches.push((symbol.clone(), type_str, result.total_context_size));
120        }
121    }
122
123    // Sort by CF (descending)
124    matches.sort_by(|a, b| b.2.cmp(&a.2));
125
126    // Apply limit if specified
127    let display_count = limit.unwrap_or(matches.len());
128    let matches_to_show = &matches[..matches.len().min(display_count)];
129
130    let filter_msg = if !include_tests {
131        " (excluding tests)"
132    } else {
133        ""
134    };
135    println!(
136        "Found {} matching symbol(s){}:\n",
137        matches.len(),
138        filter_msg
139    );
140
141    if let Some(lim) = limit.filter(|&lim| matches.len() > lim) {
142        println!("Showing top {} by CF:\n", lim);
143    }
144
145    for (i, (symbol, node_type, cf)) in matches_to_show.iter().enumerate() {
146        print!("{}. [{}] ", i + 1, node_type);
147        if with_cf || limit.is_some() {
148            print!("CF: {} tokens", cf);
149        }
150        println!("\n   {}", symbol);
151        println!();
152    }
153
154    Ok(())
155}
156
157fn symbol_is_parameter(graph: &ContextGraph, node_idx: NodeIndex) -> bool {
158    let symbol = graph
159        .symbol_to_node
160        .iter()
161        .find(|&(_, &idx)| idx == node_idx)
162        .map(|(s, _)| s.as_str())
163        .unwrap_or("");
164
165    // Python parameter pattern: .../func().(param)
166    if symbol.contains("().(") && symbol.ends_with(')') {
167        return true;
168    }
169
170    false
171}
172
173pub fn display_context_code(
174    graph: &ContextGraph,
175    symbol: &str,
176    _show_boundaries: bool,
177    source_reader: &dyn SourceReader,
178    project_root: &str,
179    max_tokens: Option<u32>,
180) -> Result<()> {
181    println!("Computing context for symbol: {}", symbol);
182
183    let node_idx = graph
184        .get_node_by_symbol(symbol)
185        .ok_or_else(|| anyhow::anyhow!("Symbol not found: {}", symbol))?;
186
187    let policy = AcademicBaseline::default();
188    let solver = CfSolver::new();
189    let result = solver.compute_cf(graph, node_idx, &policy, max_tokens);
190
191    println!("\nContext Summary:");
192    println!("  Total size: {} tokens", result.total_context_size);
193    println!("  Reachable nodes: {}", result.reachable_set.len());
194    if let Some(limit) = max_tokens {
195        println!("  Max tokens: {}", limit);
196    }
197    println!("{}", "=".repeat(80));
198
199    for (depth, layer) in result.reachable_nodes_by_layer.iter().enumerate() {
200        if layer.is_empty() {
201            continue;
202        }
203
204        println!(
205            "\n\u{1F310} Layer {}: {}",
206            depth,
207            if depth == 0 {
208                "Observed Symbol"
209            } else {
210                "Direct Dependencies"
211            }
212        );
213        println!("{}", "=".repeat(40));
214
215        // Group nodes by file within this layer
216        let mut files_map: std::collections::HashMap<String, Vec<NodeIndex>> =
217            std::collections::HashMap::new();
218
219        for &node_id in layer {
220            // Find the NodeIndex for this node_id
221            let idx = graph
222                .graph
223                .node_indices()
224                .find(|&idx| graph.node(idx).core().id == node_id)
225                .unwrap();
226
227            let node = graph.node(idx);
228            let file_path = &node.core().file_path;
229
230            files_map
231                .entry(file_path.clone())
232                .or_insert_with(Vec::new)
233                .push(idx);
234        }
235
236        // Sort files for consistent output within layer
237        let mut file_list: Vec<_> = files_map.iter().collect();
238        file_list.sort_by_key(|(path, _)| *path);
239
240        for (file_path, nodes) in file_list {
241            let full_path = std::path::Path::new(project_root).join(file_path);
242            let full_path_str = full_path.to_string_lossy();
243
244            println!("\n  \u{1F4C4} File: {}", file_path);
245
246            // Sort nodes by start line within file
247            let mut sorted_nodes = nodes.clone();
248            sorted_nodes.sort_by_key(|&idx| graph.node(idx).core().span.start_line);
249
250            // Filter out nodes that are contained within another node (e.g. nested functions)
251            // unless we want to see everything.
252            let mut top_level_nodes = Vec::new();
253            for &idx in &sorted_nodes {
254                let core = graph.node(idx).core();
255
256                let is_sub_node = symbol_is_parameter(graph, idx);
257
258                let is_contained = top_level_nodes.iter().any(|&prev_idx| {
259                    let prev_core = graph.node(prev_idx).core();
260                    core.span.start_line >= prev_core.span.start_line
261                        && core.span.end_line <= prev_core.span.end_line
262                        && idx != prev_idx
263                });
264
265                if !is_contained && !is_sub_node {
266                    top_level_nodes.push(idx);
267                }
268            }
269
270            for node_idx in top_level_nodes {
271                let node = graph.node(node_idx);
272                let core = node.core();
273
274                // Get the symbol for this node
275                let node_symbol = graph
276                    .symbol_to_node
277                    .iter()
278                    .find(|&(_, &idx)| idx == node_idx)
279                    .map(|(s, _)| s.as_str())
280                    .unwrap_or(&core.name);
281
282                println!(
283                    "    Symbol: {} ({} tokens)",
284                    node_symbol.split('/').next_back().unwrap_or(node_symbol),
285                    core.context_size
286                );
287                println!(
288                    "    Lines: {}-{}",
289                    core.span.start_line + 1,
290                    core.span.end_line + 1
291                );
292
293                // Read and display the code
294                match source_reader.read_lines(
295                    &full_path_str,
296                    core.span.start_line as usize,
297                    core.span.end_line as usize,
298                ) {
299                    Ok(lines) => {
300                        println!("    Code:");
301                        for (i, line) in lines.iter().enumerate() {
302                            let line_num = core.span.start_line as usize + i + 1;
303                            println!("      {:4} | {}", line_num, line);
304                        }
305                    }
306                    Err(e) => {
307                        println!("      [Error reading code: {}]", e);
308                    }
309                }
310            }
311        }
312    }
313
314    Ok(())
315}
316
317pub fn compute_and_display_cf_stats(graph: &ContextGraph, include_tests: bool) -> Result<()> {
318    let policy = AcademicBaseline::default();
319    let solver = CfSolver::new();
320    let test_detector = UniversalTestDetector::new();
321    let node_count = graph.graph.node_count();
322
323    let mut function_cf: Vec<u32> = Vec::new();
324    let mut type_cf: Vec<u32> = Vec::new();
325
326    let filter_msg = if !include_tests {
327        " (excluding tests)"
328    } else {
329        ""
330    };
331    println!("Calculating CF for {} nodes{}...", node_count, filter_msg);
332
333    for (idx, node_idx) in graph.graph.node_indices().enumerate() {
334        let node = graph.node(node_idx);
335
336        if !include_tests {
337            let symbol = graph
338                .symbol_to_node
339                .iter()
340                .find(|&(_, &i)| i == node_idx)
341                .map(|(s, _)| s.as_str())
342                .unwrap_or("");
343
344            if test_detector.is_test_code(symbol, &node.core().file_path) {
345                continue;
346            }
347        }
348
349        let result = solver.compute_cf(graph, node_idx, &policy, None);
350        let cf = result.total_context_size;
351
352        match node {
353            Node::Function(_) => function_cf.push(cf),
354            Node::Type(_) => type_cf.push(cf),
355            Node::Variable(_) => {}
356        }
357
358        if (idx + 1) % 1000 == 0 {
359            println!("  Processed {}/{} nodes...", idx + 1, node_count);
360        }
361    }
362
363    println!("\n{}", "=".repeat(60));
364    print_cf_distribution(&format!("Functions{}", filter_msg), &mut function_cf);
365    println!("{}", "=".repeat(60));
366    print_cf_distribution(&format!("Types{}", filter_msg), &mut type_cf);
367    println!("{}", "=".repeat(60));
368
369    Ok(())
370}
371
/// Prints percentile and summary statistics for a set of CF values.
///
/// `sizes` is sorted in place (ascending) as a side effect. For an empty slice
/// a single "No nodes found" line is printed and nothing else happens.
///
/// Percentiles are reported in 5% steps by picking the element at
/// `pct * (len - 1) / 100`; the average uses integer division and the median
/// is the element at `len / 2`.
fn print_cf_distribution(name: &str, sizes: &mut [u32]) {
    let count = sizes.len();
    if count == 0 {
        println!("\n{}: No nodes found", name);
        return;
    }

    sizes.sort_unstable();
    let last = count - 1;

    println!("\n{} - Context Footprint Distribution:", name);
    println!("  Total count: {}", count);

    println!("\n  Percentiles:");
    for pct in (1..=20usize).map(|step| step * 5) {
        // Clamp defensively; at 100% the position is exactly `last`.
        let position = (pct * last / 100).min(last);
        println!("    {:>3}%: {:>8} tokens", pct, sizes[position]);
    }

    // Accumulate in u64 so many large u32 values cannot overflow the total.
    let total = sizes.iter().fold(0u64, |acc, &value| acc + u64::from(value));
    let mean = total / count as u64;
    let middle = sizes[count / 2];

    println!("\n  Summary:");
    println!("    Average: {:>8} tokens", mean);
    println!("    Median:  {:>8} tokens", middle);
    println!("    Min:     {:>8} tokens", sizes[0]);
    println!("    Max:     {:>8} tokens", sizes[last]);
}
400
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::graph::ContextGraph;
    use crate::domain::node::{
        FunctionNode, Mutability, NodeCore, SourceSpan, TypeKind, TypeNode, VariableKind,
        VariableNode, Visibility,
    };
    use crate::domain::ports::SourceReader;
    use std::path::Path;

    /// Source reader stub that serves the same four-line file for every path.
    struct MockReader;

    impl SourceReader for MockReader {
        fn read(&self, _path: &Path) -> Result<String> {
            Ok(String::from("line1\nline2\nline3\nline4\n"))
        }

        fn read_lines(&self, _path: &str, start: usize, end: usize) -> Result<Vec<String>> {
            let fixture = ["line1", "line2", "line3", "line4"];
            // `end` is inclusive and clamped to the last available line.
            let last = end.min(fixture.len() - 1);
            Ok(fixture[start..=last]
                .iter()
                .map(|line| line.to_string())
                .collect())
        }
    }

    /// Wraps `core` in a public function node with no parameters, no return
    /// type, and neither async nor generator semantics.
    fn plain_function(core: NodeCore) -> Node {
        Node::Function(FunctionNode {
            core,
            param_count: 0,
            typed_param_count: 0,
            has_return_type: false,
            is_async: false,
            is_generator: false,
            visibility: Visibility::Public,
        })
    }

    /// Builds a graph with one function and one type in `file1.py`, plus one
    /// function located under `tests/`.
    fn create_test_graph() -> ContextGraph {
        let mut graph = ContextGraph::new();

        let func_core = NodeCore::new(
            0,
            "func1".into(),
            None,
            10,
            SourceSpan {
                start_line: 0,
                start_column: 0,
                end_line: 1,
                end_column: 0,
            },
            1.0,
            false,
            "file1.py".into(),
        );
        graph.add_node("sym/func1().".into(), plain_function(func_core));

        let type_core = NodeCore::new(
            1,
            "Type1".into(),
            None,
            20,
            SourceSpan {
                start_line: 2,
                start_column: 0,
                end_line: 3,
                end_column: 0,
            },
            0.8,
            false,
            "file1.py".into(),
        );
        let type_node = Node::Type(TypeNode {
            core: type_core,
            type_kind: TypeKind::Class,
            is_abstract: false,
            type_param_count: 0,
        });
        graph.add_node("sym/Type1#".into(), type_node);

        let test_func_core = NodeCore::new(
            2,
            "test_func".into(),
            None,
            5,
            SourceSpan {
                start_line: 0,
                start_column: 0,
                end_line: 1,
                end_column: 0,
            },
            1.0,
            false,
            "tests/test_file.py".into(),
        );
        graph.add_node("sym/test_func().".into(), plain_function(test_func_core));

        graph
    }

    #[test]
    fn test_compute_cf_for_symbol_basic() {
        let graph = create_test_graph();

        let known = compute_cf_for_symbol(&graph, "sym/func1().");
        assert!(known.is_ok());

        let unknown = compute_cf_for_symbol(&graph, "nonexistent");
        assert!(unknown.is_err());
    }

    #[test]
    fn test_display_top_cf_nodes_filtering() {
        let graph = create_test_graph();

        // (node_type, include_tests): all nodes, only functions, only types,
        // and finally all nodes with tests excluded.
        let cases = [
            ("all", true),
            ("function", true),
            ("type", true),
            ("all", false),
        ];
        for &(node_type, include_tests) in &cases {
            assert!(display_top_cf_nodes(&graph, 10, node_type, include_tests).is_ok());
        }
    }

    #[test]
    fn test_search_symbols_variations() {
        let graph = create_test_graph();

        // (pattern, with_cf, limit, include_tests): search with CF, search with
        // a limit, search excluding tests, and a search with no results.
        let cases: [(&str, bool, Option<usize>, bool); 4] = [
            ("func", true, None, true),
            ("func", false, Some(1), true),
            ("test", false, None, false),
            ("zzz", false, None, true),
        ];
        for &(pattern, with_cf, limit, include_tests) in &cases {
            assert!(search_symbols(&graph, pattern, with_cf, limit, include_tests).is_ok());
        }
    }

    #[test]
    fn test_display_context_code_basic() {
        let graph = create_test_graph();
        let reader = MockReader;

        let known = display_context_code(&graph, "sym/func1().", false, &reader, "/root", None);
        assert!(known.is_ok());

        let unknown = display_context_code(&graph, "nonexistent", false, &reader, "/root", None);
        assert!(unknown.is_err());
    }

    #[test]
    fn test_symbol_is_parameter_logic() {
        let mut graph = ContextGraph::new();

        // A variable registered under a parameter-shaped symbol.
        let param_core = NodeCore::new(
            0,
            "param".into(),
            None,
            1,
            SourceSpan {
                start_line: 0,
                start_column: 10,
                end_line: 0,
                end_column: 15,
            },
            0.0,
            false,
            "file.py".into(),
        );
        let param_node = Node::Variable(VariableNode {
            core: param_core,
            has_type_annotation: false,
            mutability: Mutability::Immutable,
            variable_kind: VariableKind::Global,
        });
        graph.add_node("sym/func().(param)".into(), param_node);

        let param_idx = graph.get_node_by_symbol("sym/func().(param)").unwrap();
        assert!(symbol_is_parameter(&graph, param_idx));

        // Non-parameter
        let func_core = NodeCore::new(
            1,
            "func".into(),
            None,
            10,
            SourceSpan {
                start_line: 0,
                start_column: 0,
                end_line: 1,
                end_column: 0,
            },
            1.0,
            false,
            "file.py".into(),
        );
        graph.add_node("sym/func().".into(), plain_function(func_core));

        let func_idx = graph.get_node_by_symbol("sym/func().").unwrap();
        assert!(!symbol_is_parameter(&graph, func_idx));
    }

    #[test]
    fn test_compute_and_display_cf_stats_variations() {
        let graph = create_test_graph();

        for &include_tests in &[true, false] {
            assert!(compute_and_display_cf_stats(&graph, include_tests).is_ok());
        }
    }

    #[test]
    fn test_print_cf_distribution_logic() {
        // Smoke test: neither the empty nor the populated case may panic.
        let mut no_values: Vec<u32> = Vec::new();
        print_cf_distribution("Empty", &mut no_values);

        let mut values = vec![10, 20, 30, 40, 50, 100];
        print_cf_distribution("Test", &mut values);
    }
}