1#![allow(unused_imports)]
2
3use crate::config::Config;
4use crate::core::{
5 ChunkId, Document, DocumentId, Entity, EntityId, GraphRAGError, KnowledgeGraph, Relationship,
6 Result, TextChunk,
7};
8use crate::{critic, ollama, persistence, query, retrieval};
9
10#[cfg(feature = "parallel-processing")]
11#[allow(unused_imports)]
12use crate::parallel;
13
14use super::GraphRAG;
15
16impl GraphRAG {
17 #[cfg(feature = "async")]
20 pub async fn ask_with_reasoning(&mut self, query: &str) -> Result<String> {
21 if self.query_planner.is_none() {
23 return self.ask(query).await;
24 }
25
26 self.ensure_initialized()?;
27 if self.has_documents() && !self.has_graph() {
28 self.build_graph().await?;
29 }
30
31 let planner = self.query_planner.as_ref().expect("checked above");
32 tracing::info!("Decomposing query: {}", query);
33
34 let sub_queries = match planner.decompose(query).await {
36 Ok(sq) => sq,
37 Err(e) => {
38 tracing::warn!(
39 "Query decomposition failed, falling back to standard query: {}",
40 e
41 );
42 vec![query.to_string()]
43 },
44 };
45
46 tracing::info!("Sub-queries: {:?}", sub_queries);
47
48 let mut all_results = Vec::new();
50 for sub_query in sub_queries {
51 match self.query_internal_with_results(&sub_query).await {
52 Ok(results) => all_results.extend(results),
53 Err(e) => tracing::warn!("Failed to execute sub-query '{}': {}", sub_query, e),
54 }
55 }
56
57 if all_results.is_empty() {
58 return Ok("No relevant information found for the decomposed queries.".to_string());
59 }
60
61 all_results.sort_by(|a, b| {
64 b.score
65 .partial_cmp(&a.score)
66 .unwrap_or(std::cmp::Ordering::Equal)
67 });
68 let mut unique_results = Vec::new();
69 let mut seen_ids = std::collections::HashSet::new();
70
71 for result in all_results {
72 if !seen_ids.contains(&result.id) {
73 seen_ids.insert(result.id.clone());
74 unique_results.push(result);
75 }
76 }
77
78 if self.config.ollama.enabled {
79 let mut current_answer = self
81 .generate_semantic_answer_from_results(query, &unique_results)
82 .await?;
83
84 if let Some(critic) = &self.critic {
86 let mut attempts = 0;
87 let max_retries = 3;
88
89 while attempts < max_retries {
90 let context_strings: Vec<String> =
91 unique_results.iter().map(|r| r.content.clone()).collect();
92
93 let evaluation = match critic
94 .evaluate(query, &context_strings, ¤t_answer)
95 .await
96 {
97 Ok(eval) => eval,
98 Err(e) => {
99 tracing::warn!("Critic evaluation failed: {}", e);
100 break;
101 },
102 };
103
104 tracing::info!(
105 "Critic Evaluation (Attempt {}): Score={:.2}, Grounded={}, Feedback='{}'",
106 attempts + 1,
107 evaluation.score,
108 evaluation.grounded,
109 evaluation.feedback
110 );
111
112 if evaluation.score >= 0.7 && evaluation.grounded {
113 tracing::info!("Answer accepted by critic.");
114 break;
115 }
116
117 tracing::warn!("Answer rejected by critic. Refining...");
118
119 current_answer = critic
121 .refine(query, ¤t_answer, &evaluation.feedback)
122 .await?;
123 attempts += 1;
124 }
125 }
126
127 return Ok(current_answer);
128 }
129
130 let formatted: Vec<String> = unique_results
132 .into_iter()
133 .take(10)
134 .map(|r| format!("{} (score: {:.2})", r.content, r.score))
135 .collect();
136 Ok(formatted.join("\n"))
137 }
138
139 #[cfg(feature = "async")]
141 pub async fn ask(&mut self, query: &str) -> Result<String> {
142 self.ensure_initialized()?;
143
144 if self.has_documents() && !self.has_graph() {
145 self.build_graph().await?;
146 }
147
148 let search_results = self.query_internal_with_results(query).await?;
150
151 if self.config.ollama.enabled {
153 return self
154 .generate_semantic_answer_from_results(query, &search_results)
155 .await;
156 }
157
158 let formatted: Vec<String> = search_results
160 .into_iter()
161 .map(|r| format!("{} (score: {:.2})", r.content, r.score))
162 .collect();
163 Ok(formatted.join("\n"))
164 }
165
166 #[cfg(not(feature = "async"))]
168 pub fn ask(&mut self, query: &str) -> Result<String> {
169 self.ensure_initialized()?;
170
171 if self.has_documents() && !self.has_graph() {
172 self.build_graph()?;
173 }
174
175 let retrieval = self
176 .retrieval_system
177 .as_ref()
178 .ok_or_else(|| GraphRAGError::Config {
179 message: "Retrieval system not initialized".to_string(),
180 })?;
181
182 let results = retrieval.query(query)?;
183 Ok(results.join("\n"))
184 }
185
186 #[cfg(feature = "async")]
217 pub async fn ask_explained(&mut self, query: &str) -> Result<retrieval::ExplainedAnswer> {
218 self.ensure_initialized()?;
219
220 if self.has_documents() && !self.has_graph() {
221 self.build_graph().await?;
222 }
223
224 let search_results = self.query_internal_with_results(query).await?;
226
227 let answer = if self.config.ollama.enabled {
229 self.generate_semantic_answer_from_results(query, &search_results)
230 .await?
231 } else {
232 search_results
234 .iter()
235 .take(3)
236 .map(|r| r.content.clone())
237 .collect::<Vec<_>>()
238 .join(" ")
239 };
240
241 let explained = retrieval::ExplainedAnswer::from_results(answer, &search_results, query);
243
244 Ok(explained)
245 }
246
247 pub async fn query_internal(&mut self, query: &str) -> Result<Vec<String>> {
249 let retrieval = self
250 .retrieval_system
251 .as_mut()
252 .ok_or_else(|| GraphRAGError::Config {
253 message: "Retrieval system not initialized".to_string(),
254 })?;
255
256 let graph = self
257 .knowledge_graph
258 .as_mut()
259 .ok_or_else(|| GraphRAGError::Config {
260 message: "Knowledge graph not initialized".to_string(),
261 })?;
262
263 retrieval.add_embeddings_to_graph(graph).await?;
265
266 let search_results = retrieval.hybrid_query(query, graph).await?;
268
269 let result_strings: Vec<String> = search_results
271 .into_iter()
272 .map(|r| format!("{} (score: {:.2})", r.content, r.score))
273 .collect();
274
275 Ok(result_strings)
276 }
277
278 #[cfg(feature = "async")]
280 async fn query_internal_with_results(
281 &mut self,
282 query: &str,
283 ) -> Result<Vec<retrieval::SearchResult>> {
284 let retrieval = self
285 .retrieval_system
286 .as_mut()
287 .ok_or_else(|| GraphRAGError::Config {
288 message: "Retrieval system not initialized".to_string(),
289 })?;
290
291 let graph = self
292 .knowledge_graph
293 .as_mut()
294 .ok_or_else(|| GraphRAGError::Config {
295 message: "Knowledge graph not initialized".to_string(),
296 })?;
297
298 retrieval.add_embeddings_to_graph(graph).await?;
300
301 retrieval.hybrid_query(query, graph).await
303 }
304
305 #[cfg(feature = "async")]
307 async fn generate_semantic_answer_from_results(
308 &self,
309 query: &str,
310 search_results: &[retrieval::SearchResult],
311 ) -> Result<String> {
312 use crate::ollama::OllamaClient;
313
314 let graph = self
315 .knowledge_graph
316 .as_ref()
317 .ok_or_else(|| GraphRAGError::Config {
318 message: "Knowledge graph not initialized".to_string(),
319 })?;
320
321 let mut context_parts = Vec::new();
324 let mut seen_chunk_ids = std::collections::HashSet::new();
325
326 for result in search_results.iter() {
327 if result.result_type == retrieval::ResultType::Entity
329 && !result.source_chunks.is_empty()
330 {
331 let entity_label = result
332 .content
333 .split(" (score:")
334 .next()
335 .unwrap_or(&result.content);
336 for chunk_id_str in &result.source_chunks {
337 if seen_chunk_ids.contains(chunk_id_str) {
338 continue;
339 }
340 let chunk_id = ChunkId::new(chunk_id_str.clone());
341 if let Some(chunk) = graph.chunks().find(|c| c.id == chunk_id) {
342 seen_chunk_ids.insert(chunk_id_str.clone());
343 context_parts.push((
344 result.score,
345 format!(
346 "[Entity: {} | Relevance: {:.2}]\n{}",
347 entity_label, result.score, chunk.content
348 ),
349 ));
350 }
351 }
352 }
353 else if result.result_type == retrieval::ResultType::Chunk {
355 if !seen_chunk_ids.contains(&result.id) {
356 seen_chunk_ids.insert(result.id.clone());
357 context_parts.push((
358 result.score,
359 format!(
360 "[Chunk | Relevance: {:.2}]\n{}",
361 result.score, result.content
362 ),
363 ));
364 }
365 }
366 else {
368 context_parts.push((
369 result.score,
370 format!(
371 "[{:?} | Relevance: {:.2}]\n{}",
372 result.result_type, result.score, result.content
373 ),
374 ));
375 }
376 }
377
378 context_parts.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
380 let context = context_parts
381 .into_iter()
382 .map(|(_, text)| text)
383 .collect::<Vec<_>>()
384 .join("\n\n---\n\n");
385
386 if context.trim().is_empty() {
387 return Ok("No relevant information found in the knowledge graph.".to_string());
388 }
389
390 let client = OllamaClient::new(self.config.ollama.clone());
392
393 let prompt = format!(
395 "You are a knowledgeable assistant specialized in answering questions based on a knowledge graph.\n\n\
396 IMPORTANT INSTRUCTIONS:\n\
397 - Answer ONLY using information from the provided context below\n\
398 - Synthesize information from ALL context sections to give a comprehensive answer\n\
399 - Provide direct, conversational, and natural responses\n\
400 - Do NOT show your reasoning process or use <think> tags\n\
401 - If the context lacks sufficient information, clearly state: \"I don't have enough information to answer this question.\"\n\
402 - Aim for a complete answer (3-6 sentences) that covers different aspects found across the context\n\
403 - Use a natural, helpful tone as if speaking to a person\n\n\
404 CONTEXT:\n\
405 {}\n\n\
406 QUESTION: {}\n\n\
407 ANSWER (direct response only, no reasoning):",
408 context, query
409 );
410
411 let max_answer_tokens: u32 = 800;
413 let prompt_tokens = (prompt.len() / 4) as u32;
414 let total = prompt_tokens + max_answer_tokens;
415 let with_margin = (total as f32 * 1.20) as u32;
416 let num_ctx = (((with_margin + 1023) / 1024) * 1024).clamp(4096, 131_072);
417
418 let params = crate::ollama::OllamaGenerationParams {
419 num_predict: Some(max_answer_tokens),
420 temperature: self.config.ollama.temperature,
421 num_ctx: Some(num_ctx),
422 keep_alive: self.config.ollama.keep_alive.clone(),
423 ..Default::default()
424 };
425
426 match client.generate_with_params(&prompt, params).await {
428 Ok(answer) => {
429 let cleaned_answer = Self::remove_thinking_tags(&answer);
431 Ok(cleaned_answer.trim().to_string())
432 },
433 Err(e) => {
434 #[cfg(feature = "tracing")]
435 tracing::warn!(
436 "LLM generation failed: {}. Falling back to search results.",
437 e
438 );
439
440 Ok(format!(
442 "Relevant information from knowledge graph:\n\n{}",
443 context
444 ))
445 },
446 }
447 }
448
449 #[cfg(feature = "async")]
454 fn remove_thinking_tags(text: &str) -> String {
455 let mut result = text.to_string();
458
459 while let Some(start) = result.find("<think>") {
460 if let Some(end) = result[start..].find("</think>") {
462 let end_pos = start + end + "</think>".len();
464 result.replace_range(start..end_pos, "");
465 } else {
466 result.replace_range(start..start + "<think>".len(), "");
468 break;
469 }
470 }
471
472 result.trim().to_string()
473 }
474
475 #[cfg(all(feature = "pagerank", feature = "async"))]
477 pub async fn ask_with_pagerank(
478 &mut self,
479 query: &str,
480 ) -> Result<Vec<retrieval::pagerank_retrieval::ScoredResult>> {
481 use crate::retrieval::pagerank_retrieval::PageRankRetrievalSystem;
482
483 self.ensure_initialized()?;
484
485 if self.has_documents() && !self.has_graph() {
486 self.build_graph().await?;
487 }
488
489 let graph = self
490 .knowledge_graph
491 .as_ref()
492 .ok_or_else(|| GraphRAGError::Config {
493 message: "Knowledge graph not initialized".to_string(),
494 })?;
495
496 let pagerank_system = PageRankRetrievalSystem::new(10);
497 pagerank_system.search_with_pagerank(query, graph, Some(5))
498 }
499
500 #[cfg(all(feature = "pagerank", not(feature = "async")))]
502 pub fn ask_with_pagerank(
503 &mut self,
504 query: &str,
505 ) -> Result<Vec<retrieval::pagerank_retrieval::ScoredResult>> {
506 use crate::retrieval::pagerank_retrieval::PageRankRetrievalSystem;
507
508 self.ensure_initialized()?;
509
510 if self.has_documents() && !self.has_graph() {
511 self.build_graph()?;
512 }
513
514 let graph = self
515 .knowledge_graph
516 .as_ref()
517 .ok_or_else(|| GraphRAGError::Config {
518 message: "Knowledge graph not initialized".to_string(),
519 })?;
520
521 let pagerank_system = PageRankRetrievalSystem::new(10);
522 pagerank_system.search_with_pagerank(query, graph, Some(5))
523 }
524}