1use std::collections::HashSet;
8use std::path::Path;
9
10use graphify_core::graph::KnowledgeGraph;
11use serde::{Deserialize, Serialize};
12use thiserror::Error;
13use tracing::info;
14
/// Errors produced while loading and benchmarking a graph (see [`run_benchmark`]).
#[derive(Debug, Error)]
pub enum BenchmarkError {
    /// An I/O failure, such as being unable to read the graph file.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// The JSON could not be turned into a `KnowledgeGraph`; holds the
    /// loader's error message.
    #[error("graph load error: {0}")]
    GraphLoad(String),

    /// A `serde_json` error, e.g. the graph file is not valid JSON.
    #[error("serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
}
27
/// Summary of a benchmark run over one serialized graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
    /// Number of nodes in the loaded graph.
    pub graph_nodes: usize,
    /// Number of edges in the loaded graph.
    pub graph_edges: usize,
    /// Estimated token count of the serialized graph (`run_benchmark` measures
    /// the file contents as read from disk).
    pub graph_tokens: usize,
    /// Word count of the source corpus, if the caller supplied one.
    pub corpus_words: Option<usize>,
    /// Estimated token count derived from `corpus_words`, if present.
    pub corpus_tokens: Option<usize>,
    /// `corpus_tokens / graph_tokens`, present only when `corpus_tokens` is;
    /// `run_benchmark` stores `0.0` when `graph_tokens` is zero.
    pub compression_ratio: Option<f64>,
    /// Number of communities recorded on the graph.
    pub community_count: usize,
    /// Per-question context-size measurements.
    pub sample_queries: Vec<QuerySample>,
}
40
/// Token cost of one simulated query, compared with a full-corpus baseline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuerySample {
    /// The question that was simulated.
    pub question: String,
    /// Estimated tokens of the graph context retrieved for the question.
    pub context_tokens: usize,
    /// Baseline the context is compared against. `run_benchmark` uses the
    /// corpus token estimate when known, otherwise the graph's own estimate.
    pub full_corpus_tokens: usize,
    /// `full_corpus_tokens / context_tokens`; `run_benchmark` stores `0.0`
    /// when `context_tokens` is zero.
    pub reduction: f64,
}
49
/// Fixed, generic questions that `run_benchmark` simulates against every graph
/// to produce its `sample_queries`.
const SAMPLE_QUESTIONS: &[&str] = &[
    "What are the main components?",
    "How does authentication work?",
    "What are the key abstractions?",
    "How do components communicate?",
    "What are the entry points?",
];
58
/// Roughly estimates the number of tokens in `text`.
///
/// Heuristic: about four bytes of UTF-8 text per token, rounded up, so any
/// non-empty text counts as at least one token and `""` counts as zero.
fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    let byte_len = text.len();
    byte_len.div_ceil(BYTES_PER_TOKEN)
}
65
/// Converts a word count into an estimated token count.
///
/// Assumes roughly 1.3 tokens per word and rounds the result up.
fn tokens_from_words(words: usize) -> usize {
    const TOKENS_PER_WORD: f64 = 1.3;
    let estimate = words as f64 * TOKENS_PER_WORD;
    estimate.ceil() as usize
}
70
71fn simulate_query(graph: &KnowledgeGraph, question: &str) -> usize {
76 let terms: Vec<String> = question
77 .to_lowercase()
78 .split_whitespace()
79 .filter(|w| w.len() > 3) .map(String::from)
81 .collect();
82
83 let mut matched_nodes: Vec<(f64, String)> = Vec::new();
85 for node_id in graph.node_ids() {
86 if let Some(node) = graph.get_node(&node_id) {
87 let label_lower = node.label.to_lowercase();
88 let score: f64 = terms
89 .iter()
90 .map(|t| {
91 if label_lower.contains(t.as_str()) {
92 1.0
93 } else {
94 0.0
95 }
96 })
97 .sum();
98 if score > 0.0 {
99 matched_nodes.push((score, node_id));
100 }
101 }
102 }
103
104 matched_nodes.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
105
106 let top_nodes: Vec<String> = matched_nodes
108 .into_iter()
109 .take(5)
110 .map(|(_, id)| id)
111 .collect();
112
113 let mut context_parts: Vec<String> = Vec::new();
114 let mut seen = HashSet::new();
115
116 for node_id in &top_nodes {
117 if seen.insert(node_id.clone())
118 && let Some(node) = graph.get_node(node_id)
119 {
120 context_parts.push(format!(
121 "{} [{}] (type: {:?}, file: {})",
122 node.label, node.id, node.node_type, node.source_file
123 ));
124 }
125
126 for neighbor in graph.neighbor_ids(node_id) {
128 if seen.insert(neighbor.clone())
129 && let Some(node) = graph.get_node(&neighbor)
130 {
131 context_parts.push(format!(
132 " -> {} [{}] (type: {:?})",
133 node.label, node.id, node.node_type
134 ));
135 }
136 }
137 }
138
139 if context_parts.is_empty() {
141 let json = graph.to_node_link_json();
142 let total = estimate_tokens(&json.to_string());
143 return total / 5; }
145
146 let context_text = context_parts.join("\n");
147 estimate_tokens(&context_text)
148}
149
150pub fn run_benchmark(
156 graph_path: &Path,
157 corpus_words: Option<usize>,
158) -> Result<BenchmarkResult, BenchmarkError> {
159 let content = std::fs::read_to_string(graph_path)?;
161 let value: serde_json::Value = serde_json::from_str(&content)?;
162 let graph = KnowledgeGraph::from_node_link_json(&value)
163 .map_err(|e| BenchmarkError::GraphLoad(e.to_string()))?;
164
165 let graph_tokens = estimate_tokens(&content);
166 let corpus_tokens = corpus_words.map(tokens_from_words);
167
168 let compression_ratio = corpus_tokens.map(|ct| {
169 if graph_tokens > 0 {
170 ct as f64 / graph_tokens as f64
171 } else {
172 0.0
173 }
174 });
175
176 let full_corpus_tokens = corpus_tokens.unwrap_or(graph_tokens);
178 let sample_queries: Vec<QuerySample> = SAMPLE_QUESTIONS
179 .iter()
180 .map(|q| {
181 let context_tokens = simulate_query(&graph, q);
182 let reduction = if context_tokens > 0 {
183 full_corpus_tokens as f64 / context_tokens as f64
184 } else {
185 0.0
186 };
187 QuerySample {
188 question: q.to_string(),
189 context_tokens,
190 full_corpus_tokens,
191 reduction,
192 }
193 })
194 .collect();
195
196 let result = BenchmarkResult {
197 graph_nodes: graph.node_count(),
198 graph_edges: graph.edge_count(),
199 graph_tokens,
200 corpus_words,
201 corpus_tokens,
202 compression_ratio,
203 community_count: graph.communities.len(),
204 sample_queries,
205 };
206
207 info!(
208 "Benchmark complete: {} nodes, {} edges, {} tokens",
209 result.graph_nodes, result.graph_edges, result.graph_tokens
210 );
211
212 Ok(result)
213}
214
215pub fn print_benchmark(result: &BenchmarkResult) {
217 println!("=== graphify Benchmark ===");
218 println!();
219 println!(
220 "Graph: {} nodes, {} edges, {} communities",
221 result.graph_nodes, result.graph_edges, result.community_count
222 );
223 println!("Graph tokens: {}", result.graph_tokens);
224
225 if let Some(words) = result.corpus_words {
226 println!("Corpus words: {}", words);
227 }
228 if let Some(tokens) = result.corpus_tokens {
229 println!("Corpus tokens (est.): {}", tokens);
230 }
231 if let Some(ratio) = result.compression_ratio {
232 println!("Compression: {:.1}x", ratio);
233 }
234
235 println!();
236 println!("Sample queries:");
237 for q in &result.sample_queries {
238 println!(" Q: {}", q.question);
239 println!(
240 " Context: {} tokens (vs {} full) = {:.1}x reduction",
241 q.context_tokens, q.full_corpus_tokens, q.reduction
242 );
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use graphify_core::confidence::Confidence;
    use graphify_core::model::{GraphEdge, GraphNode, NodeType};
    use std::collections::HashMap;

    /// Builds a `Class` node with the given id and label, attributed to `test.rs`.
    fn make_node(id: &str, label: &str) -> GraphNode {
        GraphNode {
            id: id.into(),
            label: label.into(),
            source_file: "test.rs".into(),
            source_location: None,
            node_type: NodeType::Class,
            community: None,
            extra: HashMap::new(),
        }
    }

    /// Builds an `Extracted` "calls" edge from `src` to `tgt` with score and weight 1.0.
    fn make_edge(src: &str, tgt: &str) -> GraphEdge {
        GraphEdge {
            source: src.into(),
            target: tgt.into(),
            relation: "calls".into(),
            confidence: Confidence::Extracted,
            confidence_score: 1.0,
            source_file: "test.rs".into(),
            source_location: None,
            weight: 1.0,
            extra: HashMap::new(),
        }
    }

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens(""), 0);
        // 11 bytes, rounded up to 3 tokens at 4 bytes per token.
        assert_eq!(estimate_tokens("hello world"), 3);
        assert!(estimate_tokens(&"a".repeat(100)) >= 25);
    }

    #[test]
    fn test_tokens_from_words() {
        assert_eq!(tokens_from_words(100), 130);
        assert_eq!(tokens_from_words(0), 0);
        // 1.3 tokens rounds up to 2.
        assert_eq!(tokens_from_words(1), 2);
    }

    #[test]
    fn test_simulate_query() {
        let mut g = KnowledgeGraph::new();
        g.add_node(make_node("auth", "AuthService")).unwrap();
        g.add_node(make_node("user", "UserManager")).unwrap();
        g.add_node(make_node("db", "Database")).unwrap();
        g.add_edge(make_edge("auth", "user")).unwrap();
        g.add_edge(make_edge("auth", "db")).unwrap();

        // NOTE(review): none of the labels contain "does", "authentication"
        // or "work", so this currently exercises the no-match fallback path
        // too; a query containing e.g. "auth" would cover the matching path.
        let tokens = simulate_query(&g, "How does authentication work?");
        assert!(tokens > 0, "Query should produce some context tokens");
    }

    #[test]
    fn test_simulate_query_no_match() {
        let mut g = KnowledgeGraph::new();
        g.add_node(make_node("auth", "AuthService")).unwrap();

        // Both terms pass the length filter but match no label.
        let tokens = simulate_query(&g, "zzzzz qqqqq");
        assert!(
            tokens > 0,
            "Even with no matches, should return fallback tokens"
        );
    }

    #[test]
    fn test_run_benchmark_from_file() {
        let mut g = KnowledgeGraph::new();
        g.add_node(make_node("auth", "AuthService")).unwrap();
        g.add_node(make_node("user", "UserManager")).unwrap();
        g.add_node(make_node("db", "Database")).unwrap();
        g.add_edge(make_edge("auth", "user")).unwrap();
        g.add_edge(make_edge("user", "db")).unwrap();

        // Round-trip the graph through a temporary node-link JSON file.
        let json = g.to_node_link_json();
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(tmp.path(), serde_json::to_string_pretty(&json).unwrap()).unwrap();

        let result = run_benchmark(tmp.path(), Some(10000)).unwrap();
        assert_eq!(result.graph_nodes, 3);
        assert_eq!(result.graph_edges, 2);
        assert!(result.graph_tokens > 0);
        assert_eq!(result.corpus_words, Some(10000));
        // 10000 words * 1.3 tokens per word.
        assert_eq!(result.corpus_tokens, Some(13000));
        assert!(result.compression_ratio.unwrap() > 0.0);
        assert_eq!(result.sample_queries.len(), SAMPLE_QUESTIONS.len());
    }

    #[test]
    fn test_run_benchmark_no_corpus() {
        let mut g = KnowledgeGraph::new();
        g.add_node(make_node("a", "Alpha")).unwrap();
        let json = g.to_node_link_json();
        let tmp = tempfile::NamedTempFile::new().unwrap();
        std::fs::write(tmp.path(), serde_json::to_string(&json).unwrap()).unwrap();

        // Without a corpus size, no corpus-derived figures are produced.
        let result = run_benchmark(tmp.path(), None).unwrap();
        assert!(result.compression_ratio.is_none());
        assert!(result.corpus_words.is_none());
    }

    #[test]
    fn test_print_benchmark_no_panic() {
        // Smoke test: printing a fully populated result must not panic.
        let result = BenchmarkResult {
            graph_nodes: 10,
            graph_edges: 15,
            graph_tokens: 500,
            corpus_words: Some(5000),
            corpus_tokens: Some(6500),
            compression_ratio: Some(13.0),
            community_count: 3,
            sample_queries: vec![QuerySample {
                question: "Test?".to_string(),
                context_tokens: 50,
                full_corpus_tokens: 6500,
                reduction: 130.0,
            }],
        };
        print_benchmark(&result);
    }
}