//! graphrag_cli/handlers/bench.rs
//!
//! Benchmark handler: ingests a book file into a GraphRAG knowledge graph,
//! answers a list of questions, and emits timings plus results as JSON.

use color_eyre::eyre::{eyre, Result};
use graphrag_core::GraphRAG;
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::time::Instant;
use tracing::info;

8#[derive(Serialize, Deserialize)]
9struct BenchResult {
10    config_file: String,
11    book_file: String,
12    timing: BenchTiming,
13    stats: BenchStats,
14    questions_and_answers: Vec<QAItem>,
15}
16
17#[derive(Serialize, Deserialize)]
18struct BenchTiming {
19    init_ms: u128,
20    build_ms: u128,
21    total_query_ms: u128,
22    total_ms: u128,
23}
24
25#[derive(Serialize, Deserialize, Default)]
26struct BenchStats {
27    entities: usize,
28    relationships: usize,
29    chunks: usize,
30}
31
32#[derive(Serialize, Deserialize)]
33struct QAItem {
34    index: usize,
35    question: String,
36    answer: String,
37    confidence: Option<f32>,
38    sources: Vec<SourceInfo>,
39    query_time_ms: u128,
40}
41
42#[derive(Serialize, Deserialize)]
43struct SourceInfo {
44    id: String,
45    excerpt: String,
46    relevance_score: f32,
47}
48
49pub async fn run_benchmark(
50    config_path: &Path,
51    book_path: &Path,
52    questions: Vec<String>,
53) -> Result<()> {
54    // 1. Load config
55    let config = crate::config::load_config(config_path).await?;
56    let config_file_str = config_path.to_string_lossy().to_string();
57    let book_file_str = book_path.to_string_lossy().to_string();
58
59    // 2. Init
60    let start_all = Instant::now();
61    let start_init = Instant::now();
62
63    // We instantiate GraphRAG directly, not via the handler, to keep it simple and local
64    let mut graphrag = GraphRAG::new(config)?;
65    graphrag.initialize()?;
66
67    let init_ms = start_init.elapsed().as_millis();
68    info!("Init done in {}ms", init_ms);
69
70    // 3. Load & Build
71    let start_build = Instant::now();
72    let content = tokio::fs::read_to_string(book_path)
73        .await
74        .map_err(|e| eyre!("Failed to read book file: {}", e))?;
75
76    graphrag.add_document_from_text(&content)?;
77    graphrag.build_graph().await?;
78
79    let build_ms = start_build.elapsed().as_millis();
80    info!("Build done in {}ms", build_ms);
81
82    // Get stats
83    let kg = graphrag
84        .knowledge_graph()
85        .ok_or(eyre!("Knowledge graph not initialized"))?;
86    let stats = BenchStats {
87        entities: kg.entities().count(),
88        relationships: kg.relationships().count(),
89        chunks: kg.chunks().count(),
90    };
91
92    // 4. Query
93    let mut qa_results = Vec::new();
94    let start_query_total = Instant::now();
95
96    for (i, q) in questions.iter().enumerate() {
97        let q_start = Instant::now();
98        // Use ask_explained() which returns answer + source references
99        let (answer, confidence, sources) = match graphrag.ask_explained(q).await {
100            Ok(explained) => {
101                let source_infos: Vec<SourceInfo> = explained
102                    .sources
103                    .iter()
104                    .map(|s| SourceInfo {
105                        id: s.id.clone(),
106                        excerpt: s.excerpt.clone(),
107                        relevance_score: s.relevance_score,
108                    })
109                    .collect();
110                (explained.answer, Some(explained.confidence), source_infos)
111            },
112            Err(e) => (format!("Error: {}", e), None, vec![]),
113        };
114        let q_ms = q_start.elapsed().as_millis();
115
116        qa_results.push(QAItem {
117            index: i + 1,
118            question: q.clone(),
119            answer,
120            confidence,
121            sources,
122            query_time_ms: q_ms,
123        });
124        info!("Q{} done in {}ms", i + 1, q_ms);
125    }
126
127    let total_query_ms = start_query_total.elapsed().as_millis();
128    let total_ms = start_all.elapsed().as_millis();
129
130    // Output JSON
131    let result = BenchResult {
132        config_file: config_file_str,
133        book_file: book_file_str,
134        timing: BenchTiming {
135            init_ms,
136            build_ms,
137            total_query_ms,
138            total_ms,
139        },
140        stats,
141        questions_and_answers: qa_results,
142    };
143
144    println!("{}", serde_json::to_string(&result)?);
145
146    Ok(())
147}