Skip to main content

graphify_benchmark/
lib.rs

1//! Token efficiency benchmarking for graphify.
2//!
3//! Measures graph quality, compression ratio, and query performance to
4//! validate that the graph representation is efficient for LLM consumption.
5//! Port of Python `benchmark.py`.
6
7use std::collections::HashSet;
8use std::path::Path;
9
10use graphify_core::graph::KnowledgeGraph;
11use serde::{Deserialize, Serialize};
12use thiserror::Error;
13use tracing::info;
14
/// Errors from the benchmark runner.
#[derive(Debug, Error)]
pub enum BenchmarkError {
    /// An I/O operation failed, e.g. reading the graph file in `run_benchmark`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// The parsed JSON could not be loaded as a `KnowledgeGraph`; carries the
    /// loader's error rendered as text.
    #[error("graph load error: {0}")]
    GraphLoad(String),

    /// JSON (de)serialization failed, e.g. the graph file is not valid JSON.
    #[error("serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
}
27
28/// Benchmark result metrics.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct BenchmarkResult {
31    pub graph_nodes: usize,
32    pub graph_edges: usize,
33    pub graph_tokens: usize,
34    pub corpus_words: Option<usize>,
35    pub corpus_tokens: Option<usize>,
36    pub compression_ratio: Option<f64>,
37    pub community_count: usize,
38    pub sample_queries: Vec<QuerySample>,
39}
40
41/// A single sample query benchmark.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct QuerySample {
44    pub question: String,
45    pub context_tokens: usize,
46    pub full_corpus_tokens: usize,
47    pub reduction: f64,
48}
49
/// Sample questions used for benchmarking query efficiency.
///
/// `run_benchmark` passes each question to `simulate_query`, which matches the
/// longer words of the question against node labels, and records one
/// `QuerySample` per entry.
const SAMPLE_QUESTIONS: &[&str] = &[
    "What are the main components?",
    "How does authentication work?",
    "What are the key abstractions?",
    "How do components communicate?",
    "What are the entry points?",
];
58
/// Estimate how many LLM tokens a piece of text would consume.
///
/// Applies the usual heuristic of one token per four bytes of UTF-8 text and
/// counts any leftover partial group as a whole token, so `""` costs 0 tokens
/// and `"hello"` (5 bytes) costs 2.
fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;

    let byte_len = text.len();
    let full_groups = byte_len / BYTES_PER_TOKEN;
    // A trailing partial group still costs a token (round up).
    let has_remainder = byte_len % BYTES_PER_TOKEN != 0;
    full_groups + usize::from(has_remainder)
}
65
/// Convert a corpus word count into an estimated token count.
///
/// Uses the rule of thumb of 1.3 tokens per word and rounds up, so a single
/// word counts as 2 tokens and 100 words as 130.
fn tokens_from_words(words: usize) -> usize {
    const TOKENS_PER_WORD: f64 = 1.3;

    let estimate = TOKENS_PER_WORD * words as f64;
    estimate.ceil() as usize
}
70
71/// Simulate a query against the graph and estimate context tokens needed.
72///
73/// For each query, we find matching nodes and gather their neighborhood,
74/// then measure how many tokens the resulting context would consume.
75fn simulate_query(graph: &KnowledgeGraph, question: &str) -> usize {
76    let terms: Vec<String> = question
77        .to_lowercase()
78        .split_whitespace()
79        .filter(|w| w.len() > 3) // skip short words
80        .map(String::from)
81        .collect();
82
83    let mut matched_nodes: Vec<(f64, String)> = Vec::new();
84    for node_id in graph.node_ids() {
85        if let Some(node) = graph.get_node(&node_id) {
86            let label_lower = node.label.to_lowercase();
87            let score: f64 = terms
88                .iter()
89                .map(|t| {
90                    if label_lower.contains(t.as_str()) {
91                        1.0
92                    } else {
93                        0.0
94                    }
95                })
96                .sum();
97            if score > 0.0 {
98                matched_nodes.push((score, node_id));
99            }
100        }
101    }
102
103    matched_nodes.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
104
105    let top_nodes: Vec<String> = matched_nodes
106        .into_iter()
107        .take(5)
108        .map(|(_, id)| id)
109        .collect();
110
111    let mut context_parts: Vec<String> = Vec::new();
112    let mut seen = HashSet::new();
113
114    for node_id in &top_nodes {
115        if seen.insert(node_id.clone())
116            && let Some(node) = graph.get_node(node_id)
117        {
118            context_parts.push(format!(
119                "{} [{}] (type: {:?}, file: {})",
120                node.label, node.id, node.node_type, node.source_file
121            ));
122        }
123
124        for neighbor in graph.neighbor_ids(node_id) {
125            if seen.insert(neighbor.clone())
126                && let Some(node) = graph.get_node(&neighbor)
127            {
128                context_parts.push(format!(
129                    "  -> {} [{}] (type: {:?})",
130                    node.label, node.id, node.node_type
131                ));
132            }
133        }
134    }
135
136    if context_parts.is_empty() {
137        let json = graph.to_node_link_json();
138        let total = estimate_tokens(&json.to_string());
139        return total / 5; // ~20% of graph
140    }
141
142    let context_text = context_parts.join("\n");
143    estimate_tokens(&context_text)
144}
145
146/// Run the benchmark suite on the graph at `graph_path`.
147///
148/// # Arguments
149/// * `graph_path` - Path to the graph JSON file.
150/// * `corpus_words` - Optional word count of the original corpus for compression ratio.
151pub fn run_benchmark(
152    graph_path: &Path,
153    corpus_words: Option<usize>,
154) -> Result<BenchmarkResult, BenchmarkError> {
155    let content = std::fs::read_to_string(graph_path)?;
156    let value: serde_json::Value = serde_json::from_str(&content)?;
157    let graph = KnowledgeGraph::from_node_link_json(&value)
158        .map_err(|e| BenchmarkError::GraphLoad(e.to_string()))?;
159
160    let graph_tokens = estimate_tokens(&content);
161    let corpus_tokens = corpus_words.map(tokens_from_words);
162
163    let compression_ratio = corpus_tokens.map(|ct| {
164        if graph_tokens > 0 {
165            ct as f64 / graph_tokens as f64
166        } else {
167            0.0
168        }
169    });
170
171    let full_corpus_tokens = corpus_tokens.unwrap_or(graph_tokens);
172    let sample_queries: Vec<QuerySample> = SAMPLE_QUESTIONS
173        .iter()
174        .map(|q| {
175            let context_tokens = simulate_query(&graph, q);
176            let reduction = if context_tokens > 0 {
177                full_corpus_tokens as f64 / context_tokens as f64
178            } else {
179                0.0
180            };
181            QuerySample {
182                question: q.to_string(),
183                context_tokens,
184                full_corpus_tokens,
185                reduction,
186            }
187        })
188        .collect();
189
190    let result = BenchmarkResult {
191        graph_nodes: graph.node_count(),
192        graph_edges: graph.edge_count(),
193        graph_tokens,
194        corpus_words,
195        corpus_tokens,
196        compression_ratio,
197        community_count: graph.communities.len(),
198        sample_queries,
199    };
200
201    info!(
202        "Benchmark complete: {} nodes, {} edges, {} tokens",
203        result.graph_nodes, result.graph_edges, result.graph_tokens
204    );
205
206    Ok(result)
207}
208
209/// Print a human-readable benchmark report.
210pub fn print_benchmark(result: &BenchmarkResult) {
211    println!("=== graphify-rs Benchmark ===");
212    println!();
213    println!(
214        "Graph: {} nodes, {} edges, {} communities",
215        result.graph_nodes, result.graph_edges, result.community_count
216    );
217    println!("Graph tokens: {}", result.graph_tokens);
218
219    if let Some(words) = result.corpus_words {
220        println!("Corpus words: {words}");
221    }
222    if let Some(tokens) = result.corpus_tokens {
223        println!("Corpus tokens (est.): {tokens}");
224    }
225    if let Some(ratio) = result.compression_ratio {
226        println!("Compression: {ratio:.1}x");
227    }
228
229    println!();
230    println!("Sample queries:");
231    for q in &result.sample_queries {
232        println!("  Q: {}", q.question);
233        println!(
234            "    Context: {} tokens (vs {} full) = {:.1}x reduction",
235            q.context_tokens, q.full_corpus_tokens, q.reduction
236        );
237    }
238}
239
#[cfg(test)]
mod tests {
    // Unit tests for the token estimators, the query simulation, the end-to-end
    // benchmark runner (driven through a temporary graph file) and the report
    // printer.
    use super::*;
    use graphify_core::confidence::Confidence;
    use graphify_core::model::{GraphEdge, GraphNode, NodeType};
    use std::collections::HashMap;

    /// Build a `NodeType::Class` node with the given id and label, attributed
    /// to `test.rs`, with no community and no extra attributes.
    fn make_node(id: &str, label: &str) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: label.into(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Class,
            community: None,
            extra: HashMap::new(),
        }
    }

    /// Build an extracted `calls` edge from `src` to `tgt` with confidence
    /// score and weight both set to 1.0.
    fn make_edge(src: &str, tgt: &str) -> GraphEdge {
        GraphEdge {
            source: src.into(),
            target: tgt.into(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            extra: HashMap::new(),
        }
    }

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("hello world"), 3); // (11+3)/4 = 3
        assert!(estimate_tokens(&"a".repeat(100)) >= 25);
    }

    #[test]
    fn test_tokens_from_words() {
        assert_eq!(tokens_from_words(100), 130);
        assert_eq!(tokens_from_words(0), 0);
        assert_eq!(tokens_from_words(1), 2); // ceil(1.3)
    }

    // NOTE(review): none of the lowercased labels below ("authservice",
    // "usermanager", "database") contains "does", "authentication" or "work",
    // so this most likely exercises the no-match fallback rather than the
    // neighbourhood expansion. Consider a question that actually hits a label.
    #[test]
    fn test_simulate_query() {
        let mut g = KnowledgeGraph::new();
        g.add_node(make_node("auth", "AuthService")).unwrap();
        g.add_node(make_node("user", "UserManager")).unwrap();
        g.add_node(make_node("db", "Database")).unwrap();
        g.add_edge(make_edge("auth", "user")).unwrap();
        g.add_edge(make_edge("auth", "db")).unwrap();

        let tokens = simulate_query(&g, "How does authentication work?");
        assert!(tokens > 0, "Query should produce some context tokens");
    }

    // A question with no matching terms must still report a non-zero cost via
    // the whole-graph fallback.
    #[test]
    fn test_simulate_query_no_match() {
        let mut g = KnowledgeGraph::new();
        g.add_node(make_node("auth", "AuthService")).unwrap();

        let tokens = simulate_query(&g, "zzzzz qqqqq");
        assert!(
            tokens > 0,
            "Even with no matches, should return fallback tokens"
        );
    }

    // Round-trips a small graph through a temp file and checks the counts and
    // corpus-derived metrics reported by `run_benchmark`.
    #[test]
    fn test_run_benchmark_from_file() {
        let mut g = KnowledgeGraph::new();
        g.add_node(make_node("auth", "AuthService")).unwrap();
        g.add_node(make_node("user", "UserManager")).unwrap();
        g.add_node(make_node("db", "Database")).unwrap();
        g.add_edge(make_edge("auth", "user")).unwrap();
        g.add_edge(make_edge("user", "db")).unwrap();

        let json = g.to_node_link_json();
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(tmp.path(), serde_json::to_string_pretty(&json).unwrap()).unwrap();

        let result = run_benchmark(tmp.path(), Some(10000)).unwrap();
        assert_eq!(result.graph_nodes, 3);
        assert_eq!(result.graph_edges, 2);
        assert!(result.graph_tokens > 0);
        assert_eq!(result.corpus_words, Some(10000));
        // 10_000 words at 1.3 tokens/word.
        assert_eq!(result.corpus_tokens, Some(13000));
        assert!(result.compression_ratio.unwrap() > 0.0);
        assert_eq!(result.sample_queries.len(), SAMPLE_QUESTIONS.len());
    }

    // Without a corpus size, no corpus-derived metrics are reported.
    #[test]
    fn test_run_benchmark_no_corpus() {
        let mut g = KnowledgeGraph::new();
        g.add_node(make_node("a", "Alpha")).unwrap();
        let json = g.to_node_link_json();
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(tmp.path(), serde_json::to_string(&json).unwrap()).unwrap();

        let result = run_benchmark(tmp.path(), None).unwrap();
        assert!(result.compression_ratio.is_none());
        assert!(result.corpus_words.is_none());
    }

    // Smoke test: only checks that printing a fully populated result does not
    // panic; the output itself is not captured or asserted.
    #[test]
    fn test_print_benchmark_no_panic() {
        let result = BenchmarkResult {
            graph_nodes: 10,
            graph_edges: 15,
            graph_tokens: 500,
            corpus_words: Some(5000),
            corpus_tokens: Some(6500),
            compression_ratio: Some(13.0),
            community_count: 3,
            sample_queries: vec![QuerySample {
                question: "Test?".to_string(),
                context_tokens: 50,
                full_corpus_tokens: 6500,
                reduction: 130.0,
            }],
        };
        print_benchmark(&result);
    }
}