#![allow(unused_imports)]
#[cfg(feature = "rograg")]
use crate::core::{Entity, KnowledgeGraph};
#[cfg(feature = "rograg")]
use crate::retrieval::causal_analysis::CausalAnalyzer;
#[cfg(feature = "rograg")]
use crate::Result;
#[cfg(feature = "rograg")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "rograg")]
use std::collections::HashSet;
#[cfg(feature = "rograg")]
use std::sync::Arc;
#[cfg(feature = "rograg")]
use strum::{Display as StrumDisplay, EnumString};
#[cfg(feature = "rograg")]
use thiserror::Error;
#[cfg(feature = "rograg")]
use super::*;
#[cfg(feature = "rograg")]
/// Answers queries by parsing them into a logic form (via a list of
/// [`LogicFormParser`]s) and executing that form against a [`KnowledgeGraph`]
/// with a [`LogicFormExecutor`].
pub struct LogicFormRetriever {
/// Parsers consulted in insertion order by `retrieve`; the first one that
/// returns `Some(..)` determines the logic form.
parsers: Vec<Box<dyn LogicFormParser>>,
/// Executes the parsed logic form against a graph to produce variable bindings.
executor: LogicFormExecutor,
}
#[cfg(feature = "rograg")]
impl Default for LogicFormRetriever {
fn default() -> Self {
Self::new()
}
}
#[cfg(feature = "rograg")]
impl LogicFormRetriever {
    /// Creates a retriever whose parser list contains a single [`PatternBasedParser`].
    ///
    /// Additional parsers can be appended with [`LogicFormRetriever::add_parser`].
    ///
    /// # Panics
    ///
    /// Panics if `PatternBasedParser::new()` returns an error. Its pattern set is
    /// expected to be static, so such a failure is treated as a programming bug.
    pub fn new() -> Self {
        let parsers: Vec<Box<dyn LogicFormParser>> = vec![Box::new(
            PatternBasedParser::new().expect("static pattern set"),
        )];
        Self {
            parsers,
            executor: LogicFormExecutor::new(),
        }
    }

    /// Answers `query` against `graph`.
    ///
    /// The query is parsed into a logic form (first matching parser wins) and
    /// executed against `graph`. The resulting bindings are summarised into:
    /// - an answer string,
    /// - the mean binding confidence,
    /// - the entity ids the bindings refer to.
    ///
    /// Timing and counting statistics are attached to the result.
    ///
    /// # Errors
    ///
    /// * Any error returned by a parser or by the executor is propagated.
    /// * [`LogicFormError::ParseError`] if no parser recognises `query`.
    /// * [`LogicFormError::NoResults`] if execution yields no bindings.
    pub async fn retrieve(&self, query: &str, graph: &KnowledgeGraph) -> Result<LogicFormResult> {
        let parse_start = std::time::Instant::now();
        let logic_form = self.parse_query(query)?;
        let parsing_time = parse_start.elapsed().as_millis() as u64;

        let execution_start = std::time::Instant::now();
        let bindings = self.executor.execute(&logic_form, graph)?;
        let execution_time = execution_start.elapsed().as_millis() as u64;

        if bindings.is_empty() {
            return Err(LogicFormError::NoResults.into());
        }

        let answer = self.generate_answer(&logic_form, &bindings);
        let confidence = self.calculate_overall_confidence(&bindings);
        let sources = self.extract_sources(&bindings);

        // The stats report the graph's relationship count only for these
        // predicates; every other predicate reports 0.
        let relationships_examined = if matches!(
            logic_form.predicate,
            Predicate::Related | Predicate::Caused | Predicate::Compare
        ) {
            graph.relationships().count()
        } else {
            0
        };

        // Read the length before `bindings` is moved into the result, so the
        // vector does not have to be cloned.
        let bindings_found = bindings.len();

        Ok(LogicFormResult {
            query: query.to_string(),
            logic_form,
            bindings,
            answer,
            confidence,
            sources,
            execution_stats: LogicExecutionStats {
                parsing_time_ms: parsing_time,
                execution_time_ms: execution_time,
                entities_examined: graph.entities().count(),
                relationships_examined,
                bindings_found,
            },
        })
    }

    /// Runs the registered parsers in insertion order and returns the first
    /// logic form produced.
    ///
    /// Parser errors are propagated immediately. If no parser produces a logic
    /// form, this returns [`LogicFormError::ParseError`].
    fn parse_query(&self, query: &str) -> Result<LogicFormQuery> {
        for parser in &self.parsers {
            if let Some(parsed) = parser.parse(query)? {
                return Ok(parsed);
            }
        }
        Err(LogicFormError::ParseError {
            query: query.to_string(),
        }
        .into())
    }

    /// Builds the answer text for `logic_form` from its `bindings`.
    ///
    /// For `Is`, `Related` and `Compare`, the answer is the first binding's value.
    /// If there are no bindings, a predicate-specific fallback message is used
    /// instead. For every other predicate, all binding values are joined with `"; "`.
    fn generate_answer(&self, logic_form: &LogicFormQuery, bindings: &[VariableBinding]) -> String {
        let single_value_fallback = match logic_form.predicate {
            Predicate::Is => Some("No information found."),
            Predicate::Related => Some("No relationship found."),
            Predicate::Compare => Some("Cannot compare the specified entities."),
            _ => None,
        };

        match single_value_fallback {
            Some(fallback) => bindings
                .first()
                .map_or_else(|| fallback.to_string(), |b| b.value.clone()),
            None => bindings
                .iter()
                .map(|b| b.value.as_str())
                .collect::<Vec<&str>>()
                .join("; "),
        }
    }

    /// Returns the arithmetic mean of the bindings' confidences, or `0.0` when
    /// `bindings` is empty.
    fn calculate_overall_confidence(&self, bindings: &[VariableBinding]) -> f32 {
        if bindings.is_empty() {
            return 0.0;
        }
        let sum: f32 = bindings.iter().map(|b| b.confidence).sum();
        sum / bindings.len() as f32
    }

    /// Collects the entity ids of all bindings that have one.
    ///
    /// Ids are returned in binding order, and duplicates are not removed.
    fn extract_sources(&self, bindings: &[VariableBinding]) -> Vec<String> {
        bindings
            .iter()
            .filter_map(|b| b.entity_id.clone())
            .collect()
    }

    /// Appends `parser` to the parser list.
    ///
    /// Parsers are consulted in insertion order, so this parser is only used
    /// when all previously registered parsers return no logic form.
    pub fn add_parser(&mut self, parser: Box<dyn LogicFormParser>) {
        self.parsers.push(parser);
    }

    /// Returns the fixed list of predicates advertised as supported.
    ///
    /// NOTE(review): `retrieve` also names `Exists`, `Similar` and `Located`,
    /// which are not listed here. Confirm against the parser/executor whether
    /// they should be advertised.
    pub fn get_supported_predicates(&self) -> Vec<Predicate> {
        vec![
            Predicate::Is,
            Predicate::Related,
            Predicate::Has,
            Predicate::Compare,
            Predicate::Happened,
            Predicate::Caused,
        ]
    }
}