use std::collections::HashSet;
use std::path::Path;
use graphify_core::graph::KnowledgeGraph;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::info;
#[derive(Debug, Error)]
pub enum BenchmarkError {
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("graph load error: {0}")]
GraphLoad(String),
#[error("serialization error: {0}")]
Serialization(#[from] serde_json::Error),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
pub graph_nodes: usize,
pub graph_edges: usize,
pub graph_tokens: usize,
pub corpus_words: Option<usize>,
pub corpus_tokens: Option<usize>,
pub compression_ratio: Option<f64>,
pub community_count: usize,
pub sample_queries: Vec<QuerySample>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuerySample {
pub question: String,
pub context_tokens: usize,
pub full_corpus_tokens: usize,
pub reduction: f64,
}
const SAMPLE_QUESTIONS: &[&str] = &[
"What are the main components?",
"How does authentication work?",
"What are the key abstractions?",
"How do components communicate?",
"What are the entry points?",
];
fn estimate_tokens(text: &str) -> usize {
text.len().div_ceil(4)
}
fn tokens_from_words(words: usize) -> usize {
((words as f64) * 1.3).ceil() as usize
}
fn simulate_query(graph: &KnowledgeGraph, question: &str) -> usize {
let terms: Vec<String> = question
.to_lowercase()
.split_whitespace()
.filter(|w| w.len() > 3) .map(String::from)
.collect();
let mut matched_nodes: Vec<(f64, String)> = Vec::new();
for node_id in graph.node_ids() {
if let Some(node) = graph.get_node(&node_id) {
let label_lower = node.label.to_lowercase();
let score: f64 = terms
.iter()
.map(|t| {
if label_lower.contains(t.as_str()) {
1.0
} else {
0.0
}
})
.sum();
if score > 0.0 {
matched_nodes.push((score, node_id));
}
}
}
matched_nodes.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let top_nodes: Vec<String> = matched_nodes
.into_iter()
.take(5)
.map(|(_, id)| id)
.collect();
let mut context_parts: Vec<String> = Vec::new();
let mut seen = HashSet::new();
for node_id in &top_nodes {
if seen.insert(node_id.clone())
&& let Some(node) = graph.get_node(node_id)
{
context_parts.push(format!(
"{} [{}] (type: {:?}, file: {})",
node.label, node.id, node.node_type, node.source_file
));
}
for neighbor in graph.neighbor_ids(node_id) {
if seen.insert(neighbor.clone())
&& let Some(node) = graph.get_node(&neighbor)
{
context_parts.push(format!(
" -> {} [{}] (type: {:?})",
node.label, node.id, node.node_type
));
}
}
}
if context_parts.is_empty() {
let json = graph.to_node_link_json();
let total = estimate_tokens(&json.to_string());
return total / 5; }
let context_text = context_parts.join("\n");
estimate_tokens(&context_text)
}
pub fn run_benchmark(
graph_path: &Path,
corpus_words: Option<usize>,
) -> Result<BenchmarkResult, BenchmarkError> {
let content = std::fs::read_to_string(graph_path)?;
let value: serde_json::Value = serde_json::from_str(&content)?;
let graph = KnowledgeGraph::from_node_link_json(&value)
.map_err(|e| BenchmarkError::GraphLoad(e.to_string()))?;
let graph_tokens = estimate_tokens(&content);
let corpus_tokens = corpus_words.map(tokens_from_words);
let compression_ratio = corpus_tokens.map(|ct| {
if graph_tokens > 0 {
ct as f64 / graph_tokens as f64
} else {
0.0
}
});
let full_corpus_tokens = corpus_tokens.unwrap_or(graph_tokens);
let sample_queries: Vec<QuerySample> = SAMPLE_QUESTIONS
.iter()
.map(|q| {
let context_tokens = simulate_query(&graph, q);
let reduction = if context_tokens > 0 {
full_corpus_tokens as f64 / context_tokens as f64
} else {
0.0
};
QuerySample {
question: q.to_string(),
context_tokens,
full_corpus_tokens,
reduction,
}
})
.collect();
let result = BenchmarkResult {
graph_nodes: graph.node_count(),
graph_edges: graph.edge_count(),
graph_tokens,
corpus_words,
corpus_tokens,
compression_ratio,
community_count: graph.communities.len(),
sample_queries,
};
info!(
"Benchmark complete: {} nodes, {} edges, {} tokens",
result.graph_nodes, result.graph_edges, result.graph_tokens
);
Ok(result)
}
pub fn print_benchmark(result: &BenchmarkResult) {
println!("=== graphify-rs Benchmark ===");
println!();
println!(
"Graph: {} nodes, {} edges, {} communities",
result.graph_nodes, result.graph_edges, result.community_count
);
println!("Graph tokens: {}", result.graph_tokens);
if let Some(words) = result.corpus_words {
println!("Corpus words: {}", words);
}
if let Some(tokens) = result.corpus_tokens {
println!("Corpus tokens (est.): {}", tokens);
}
if let Some(ratio) = result.compression_ratio {
println!("Compression: {:.1}x", ratio);
}
println!();
println!("Sample queries:");
for q in &result.sample_queries {
println!(" Q: {}", q.question);
println!(
" Context: {} tokens (vs {} full) = {:.1}x reduction",
q.context_tokens, q.full_corpus_tokens, q.reduction
);
}
}
#[cfg(test)]
mod tests {
use super::*;
use graphify_core::confidence::Confidence;
use graphify_core::model::{GraphEdge, GraphNode, NodeType};
use std::collections::HashMap;
fn make_node(id: &str, label: &str) -> GraphNode {
GraphNode {
id: id.into(),
label: label.into(),
source_file: "test.rs".into(),
source_location: None,
node_type: NodeType::Class,
community: None,
extra: HashMap::new(),
}
}
fn make_edge(src: &str, tgt: &str) -> GraphEdge {
GraphEdge {
source: src.into(),
target: tgt.into(),
relation: "calls".into(),
confidence: Confidence::Extracted,
confidence_score: 1.0,
source_file: "test.rs".into(),
source_location: None,
weight: 1.0,
extra: HashMap::new(),
}
}
#[test]
fn test_estimate_tokens() {
assert_eq!(estimate_tokens(""), 0);
assert_eq!(estimate_tokens("hello world"), 3); assert!(estimate_tokens(&"a".repeat(100)) >= 25);
}
#[test]
fn test_tokens_from_words() {
assert_eq!(tokens_from_words(100), 130);
assert_eq!(tokens_from_words(0), 0);
assert_eq!(tokens_from_words(1), 2); }
#[test]
fn test_simulate_query() {
let mut g = KnowledgeGraph::new();
g.add_node(make_node("auth", "AuthService")).unwrap();
g.add_node(make_node("user", "UserManager")).unwrap();
g.add_node(make_node("db", "Database")).unwrap();
g.add_edge(make_edge("auth", "user")).unwrap();
g.add_edge(make_edge("auth", "db")).unwrap();
let tokens = simulate_query(&g, "How does authentication work?");
assert!(tokens > 0, "Query should produce some context tokens");
}
#[test]
fn test_simulate_query_no_match() {
let mut g = KnowledgeGraph::new();
g.add_node(make_node("auth", "AuthService")).unwrap();
let tokens = simulate_query(&g, "zzzzz qqqqq");
assert!(
tokens > 0,
"Even with no matches, should return fallback tokens"
);
}
#[test]
fn test_run_benchmark_from_file() {
let mut g = KnowledgeGraph::new();
g.add_node(make_node("auth", "AuthService")).unwrap();
g.add_node(make_node("user", "UserManager")).unwrap();
g.add_node(make_node("db", "Database")).unwrap();
g.add_edge(make_edge("auth", "user")).unwrap();
g.add_edge(make_edge("user", "db")).unwrap();
let json = g.to_node_link_json();
let tmp = tempfile::NamedTempFile::new().unwrap();
std::fs::write(tmp.path(), serde_json::to_string_pretty(&json).unwrap()).unwrap();
let result = run_benchmark(tmp.path(), Some(10000)).unwrap();
assert_eq!(result.graph_nodes, 3);
assert_eq!(result.graph_edges, 2);
assert!(result.graph_tokens > 0);
assert_eq!(result.corpus_words, Some(10000));
assert_eq!(result.corpus_tokens, Some(13000));
assert!(result.compression_ratio.unwrap() > 0.0);
assert_eq!(result.sample_queries.len(), SAMPLE_QUESTIONS.len());
}
#[test]
fn test_run_benchmark_no_corpus() {
let mut g = KnowledgeGraph::new();
g.add_node(make_node("a", "Alpha")).unwrap();
let json = g.to_node_link_json();
let tmp = tempfile::NamedTempFile::new().unwrap();
std::fs::write(tmp.path(), serde_json::to_string(&json).unwrap()).unwrap();
let result = run_benchmark(tmp.path(), None).unwrap();
assert!(result.compression_ratio.is_none());
assert!(result.corpus_words.is_none());
}
#[test]
fn test_print_benchmark_no_panic() {
let result = BenchmarkResult {
graph_nodes: 10,
graph_edges: 15,
graph_tokens: 500,
corpus_words: Some(5000),
corpus_tokens: Some(6500),
compression_ratio: Some(13.0),
community_count: 3,
sample_queries: vec![QuerySample {
question: "Test?".to_string(),
context_tokens: 50,
full_corpus_tokens: 6500,
reduction: 130.0,
}],
};
print_benchmark(&result);
}
}