Skip to main content

graphify_benchmark/
lib.rs

1//! Token efficiency benchmarking for graphify.
2//!
3//! Measures graph quality, compression ratio, and query performance to
4//! validate that the graph representation is efficient for LLM consumption.
5//! Port of Python `benchmark.py`.
6
7use std::collections::HashSet;
8use std::path::Path;
9
10use graphify_core::graph::KnowledgeGraph;
11use serde::{Deserialize, Serialize};
12use thiserror::Error;
13use tracing::info;
14
/// Errors from the benchmark runner.
///
/// Returned by [`run_benchmark`]; each variant corresponds to a different
/// stage of loading the graph file.
#[derive(Debug, Error)]
pub enum BenchmarkError {
    /// An I/O failure, converted automatically via `?` (in this module: reading
    /// the graph file from disk).
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// The file parsed as JSON but could not be turned into a
    /// `KnowledgeGraph`; holds the stringified underlying error.
    #[error("graph load error: {0}")]
    GraphLoad(String),

    /// A `serde_json` failure, converted automatically via `?` (in this
    /// module: the graph file was not valid JSON).
    #[error("serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
}
27
28/// Benchmark result metrics.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct BenchmarkResult {
31    pub graph_nodes: usize,
32    pub graph_edges: usize,
33    pub graph_tokens: usize,
34    pub corpus_words: Option<usize>,
35    pub corpus_tokens: Option<usize>,
36    pub compression_ratio: Option<f64>,
37    pub community_count: usize,
38    pub sample_queries: Vec<QuerySample>,
39}
40
41/// A single sample query benchmark.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct QuerySample {
44    pub question: String,
45    pub context_tokens: usize,
46    pub full_corpus_tokens: usize,
47    pub reduction: f64,
48}
49
50/// Sample questions used for benchmarking query efficiency.
51const SAMPLE_QUESTIONS: &[&str] = &[
52    "What are the main components?",
53    "How does authentication work?",
54    "What are the key abstractions?",
55    "How do components communicate?",
56    "What are the entry points?",
57];
58
59/// Estimate the number of tokens in a string.
60///
61/// Uses the approximation: 1 token ≈ 4 characters.
62fn estimate_tokens(text: &str) -> usize {
63    text.len().div_ceil(4)
64}
65
66/// Estimate tokens from word count.
67fn tokens_from_words(words: usize) -> usize {
68    ((words as f64) * 1.3).ceil() as usize
69}
70
71/// Simulate a query against the graph and estimate context tokens needed.
72///
73/// For each query, we find matching nodes and gather their neighborhood,
74/// then measure how many tokens the resulting context would consume.
75fn simulate_query(graph: &KnowledgeGraph, question: &str) -> usize {
76    let terms: Vec<String> = question
77        .to_lowercase()
78        .split_whitespace()
79        .filter(|w| w.len() > 3) // skip short words
80        .map(String::from)
81        .collect();
82
83    // Score nodes by term overlap
84    let mut matched_nodes: Vec<(f64, String)> = Vec::new();
85    for node_id in graph.node_ids() {
86        if let Some(node) = graph.get_node(&node_id) {
87            let label_lower = node.label.to_lowercase();
88            let score: f64 = terms
89                .iter()
90                .map(|t| {
91                    if label_lower.contains(t.as_str()) {
92                        1.0
93                    } else {
94                        0.0
95                    }
96                })
97                .sum();
98            if score > 0.0 {
99                matched_nodes.push((score, node_id));
100            }
101        }
102    }
103
104    matched_nodes.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
105
106    // Take top-5 matches and their 1-hop neighbors
107    let top_nodes: Vec<String> = matched_nodes
108        .into_iter()
109        .take(5)
110        .map(|(_, id)| id)
111        .collect();
112
113    let mut context_parts: Vec<String> = Vec::new();
114    let mut seen = HashSet::new();
115
116    for node_id in &top_nodes {
117        if seen.insert(node_id.clone())
118            && let Some(node) = graph.get_node(node_id)
119        {
120            context_parts.push(format!(
121                "{} [{}] (type: {:?}, file: {})",
122                node.label, node.id, node.node_type, node.source_file
123            ));
124        }
125
126        // Add 1-hop neighbors
127        for neighbor in graph.neighbor_ids(node_id) {
128            if seen.insert(neighbor.clone())
129                && let Some(node) = graph.get_node(&neighbor)
130            {
131                context_parts.push(format!(
132                    "  -> {} [{}] (type: {:?})",
133                    node.label, node.id, node.node_type
134                ));
135            }
136        }
137    }
138
139    // If no matches, estimate a minimum context
140    if context_parts.is_empty() {
141        let json = graph.to_node_link_json();
142        let total = estimate_tokens(&json.to_string());
143        return total / 5; // ~20% of graph
144    }
145
146    let context_text = context_parts.join("\n");
147    estimate_tokens(&context_text)
148}
149
150/// Run the benchmark suite on the graph at `graph_path`.
151///
152/// # Arguments
153/// * `graph_path` - Path to the graph JSON file.
154/// * `corpus_words` - Optional word count of the original corpus for compression ratio.
155pub fn run_benchmark(
156    graph_path: &Path,
157    corpus_words: Option<usize>,
158) -> Result<BenchmarkResult, BenchmarkError> {
159    // Load graph JSON
160    let content = std::fs::read_to_string(graph_path)?;
161    let value: serde_json::Value = serde_json::from_str(&content)?;
162    let graph = KnowledgeGraph::from_node_link_json(&value)
163        .map_err(|e| BenchmarkError::GraphLoad(e.to_string()))?;
164
165    let graph_tokens = estimate_tokens(&content);
166    let corpus_tokens = corpus_words.map(tokens_from_words);
167
168    let compression_ratio = corpus_tokens.map(|ct| {
169        if graph_tokens > 0 {
170            ct as f64 / graph_tokens as f64
171        } else {
172            0.0
173        }
174    });
175
176    // Run sample queries
177    let full_corpus_tokens = corpus_tokens.unwrap_or(graph_tokens);
178    let sample_queries: Vec<QuerySample> = SAMPLE_QUESTIONS
179        .iter()
180        .map(|q| {
181            let context_tokens = simulate_query(&graph, q);
182            let reduction = if context_tokens > 0 {
183                full_corpus_tokens as f64 / context_tokens as f64
184            } else {
185                0.0
186            };
187            QuerySample {
188                question: q.to_string(),
189                context_tokens,
190                full_corpus_tokens,
191                reduction,
192            }
193        })
194        .collect();
195
196    let result = BenchmarkResult {
197        graph_nodes: graph.node_count(),
198        graph_edges: graph.edge_count(),
199        graph_tokens,
200        corpus_words,
201        corpus_tokens,
202        compression_ratio,
203        community_count: graph.communities.len(),
204        sample_queries,
205    };
206
207    info!(
208        "Benchmark complete: {} nodes, {} edges, {} tokens",
209        result.graph_nodes, result.graph_edges, result.graph_tokens
210    );
211
212    Ok(result)
213}
214
215/// Print a human-readable benchmark report.
216pub fn print_benchmark(result: &BenchmarkResult) {
217    println!("=== graphify Benchmark ===");
218    println!();
219    println!(
220        "Graph: {} nodes, {} edges, {} communities",
221        result.graph_nodes, result.graph_edges, result.community_count
222    );
223    println!("Graph tokens: {}", result.graph_tokens);
224
225    if let Some(words) = result.corpus_words {
226        println!("Corpus words: {}", words);
227    }
228    if let Some(tokens) = result.corpus_tokens {
229        println!("Corpus tokens (est.): {}", tokens);
230    }
231    if let Some(ratio) = result.compression_ratio {
232        println!("Compression: {:.1}x", ratio);
233    }
234
235    println!();
236    println!("Sample queries:");
237    for q in &result.sample_queries {
238        println!("  Q: {}", q.question);
239        println!(
240            "    Context: {} tokens (vs {} full) = {:.1}x reduction",
241            q.context_tokens, q.full_corpus_tokens, q.reduction
242        );
243    }
244}
245
246// ---------------------------------------------------------------------------
247// Tests
248// ---------------------------------------------------------------------------
249
250#[cfg(test)]
251mod tests {
252    use super::*;
253    use graphify_core::confidence::Confidence;
254    use graphify_core::model::{GraphEdge, GraphNode, NodeType};
255    use std::collections::HashMap;
256
257    fn make_node(id: &str, label: &str) -> GraphNode {
258        GraphNode {
259            id: id.into(),
260            label: label.into(),
261            source_file: "test.rs".into(),
262            source_location: None,
263            node_type: NodeType::Class,
264            community: None,
265            extra: HashMap::new(),
266        }
267    }
268
269    fn make_edge(src: &str, tgt: &str) -> GraphEdge {
270        GraphEdge {
271            source: src.into(),
272            target: tgt.into(),
273            relation: "calls".into(),
274            confidence: Confidence::Extracted,
275            confidence_score: 1.0,
276            source_file: "test.rs".into(),
277            source_location: None,
278            weight: 1.0,
279            extra: HashMap::new(),
280        }
281    }
282
283    #[test]
284    fn test_estimate_tokens() {
285        assert_eq!(estimate_tokens(""), 0);
286        assert_eq!(estimate_tokens("hello world"), 3); // (11+3)/4 = 3
287        assert!(estimate_tokens(&"a".repeat(100)) >= 25);
288    }
289
290    #[test]
291    fn test_tokens_from_words() {
292        assert_eq!(tokens_from_words(100), 130);
293        assert_eq!(tokens_from_words(0), 0);
294        assert_eq!(tokens_from_words(1), 2); // ceil(1.3)
295    }
296
297    #[test]
298    fn test_simulate_query() {
299        let mut g = KnowledgeGraph::new();
300        g.add_node(make_node("auth", "AuthService")).unwrap();
301        g.add_node(make_node("user", "UserManager")).unwrap();
302        g.add_node(make_node("db", "Database")).unwrap();
303        g.add_edge(make_edge("auth", "user")).unwrap();
304        g.add_edge(make_edge("auth", "db")).unwrap();
305
306        let tokens = simulate_query(&g, "How does authentication work?");
307        assert!(tokens > 0, "Query should produce some context tokens");
308    }
309
310    #[test]
311    fn test_simulate_query_no_match() {
312        let mut g = KnowledgeGraph::new();
313        g.add_node(make_node("auth", "AuthService")).unwrap();
314
315        let tokens = simulate_query(&g, "zzzzz qqqqq");
316        assert!(
317            tokens > 0,
318            "Even with no matches, should return fallback tokens"
319        );
320    }
321
322    #[test]
323    fn test_run_benchmark_from_file() {
324        let mut g = KnowledgeGraph::new();
325        g.add_node(make_node("auth", "AuthService")).unwrap();
326        g.add_node(make_node("user", "UserManager")).unwrap();
327        g.add_node(make_node("db", "Database")).unwrap();
328        g.add_edge(make_edge("auth", "user")).unwrap();
329        g.add_edge(make_edge("user", "db")).unwrap();
330
331        let json = g.to_node_link_json();
332        let tmp = tempfile::NamedTempFile::new().unwrap();
333        std::fs::write(tmp.path(), serde_json::to_string_pretty(&json).unwrap()).unwrap();
334
335        let result = run_benchmark(tmp.path(), Some(10000)).unwrap();
336        assert_eq!(result.graph_nodes, 3);
337        assert_eq!(result.graph_edges, 2);
338        assert!(result.graph_tokens > 0);
339        assert_eq!(result.corpus_words, Some(10000));
340        assert_eq!(result.corpus_tokens, Some(13000));
341        assert!(result.compression_ratio.unwrap() > 0.0);
342        assert_eq!(result.sample_queries.len(), SAMPLE_QUESTIONS.len());
343    }
344
345    #[test]
346    fn test_run_benchmark_no_corpus() {
347        let mut g = KnowledgeGraph::new();
348        g.add_node(make_node("a", "Alpha")).unwrap();
349        let json = g.to_node_link_json();
350        let tmp = tempfile::NamedTempFile::new().unwrap();
351        std::fs::write(tmp.path(), serde_json::to_string(&json).unwrap()).unwrap();
352
353        let result = run_benchmark(tmp.path(), None).unwrap();
354        assert!(result.compression_ratio.is_none());
355        assert!(result.corpus_words.is_none());
356    }
357
358    #[test]
359    fn test_print_benchmark_no_panic() {
360        let result = BenchmarkResult {
361            graph_nodes: 10,
362            graph_edges: 15,
363            graph_tokens: 500,
364            corpus_words: Some(5000),
365            corpus_tokens: Some(6500),
366            compression_ratio: Some(13.0),
367            community_count: 3,
368            sample_queries: vec![QuerySample {
369                question: "Test?".to_string(),
370                context_tokens: 50,
371                full_corpus_tokens: 6500,
372                reduction: 130.0,
373            }],
374        };
375        // Should not panic
376        print_benchmark(&result);
377    }
378}