use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info};
use super::types::SPARQLGenerationResult;
/// Tunable settings for [`ContextAwareGenerator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextAwareConfig {
/// Cap checked after each new message is appended to the conversation history.
pub max_history: usize,
/// When `true`, entities extracted from each query are recorded in the context.
pub track_entities: bool,
/// When `true`, the context's variable bindings are fetched as hints during generation.
pub reuse_variables: bool,
/// When `true`, each generated query is scanned to update the discovered schema.
pub learn_schema: bool,
/// Base of the exponential relevance decay; the exponent is the message age in minutes.
pub decay_factor: f32,
}
impl Default for ContextAwareConfig {
fn default() -> Self {
Self {
max_history: 10,
track_entities: true,
reuse_variables: true,
learn_schema: true,
decay_factor: 0.9,
}
}
}
/// Mutable state carried across the queries of one conversation.
#[derive(Debug, Clone, Default)]
pub struct ConversationContext {
/// Identifier of the conversation session (not read by the visible generator code).
pub session_id: String,
/// Previously processed messages; new ones are appended at the end and
/// low-relevance ones are pruned before each generation.
pub history: Vec<ContextMessage>,
/// Entities mentioned so far, keyed by the entity text.
pub tracked_entities: HashMap<String, TrackedEntity>,
/// Variable bindings handed to query generation as hints.
pub variable_bindings: HashMap<String, String>,
/// Schema knowledge accumulated from generated queries.
pub discovered_schema: DiscoveredSchema,
/// Text of the most recently processed query, if any.
pub current_topic: Option<String>,
}
/// One processed query together with what was generated for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextMessage {
/// Unique message identifier (a freshly generated UUID string when created by the generator).
pub id: String,
/// The natural-language query text.
pub content: String,
/// SPARQL generated for this message, if any.
pub sparql: Option<String>,
/// Entities extracted from the query text.
pub entities: Vec<String>,
/// UTC time at which the message was recorded.
pub timestamp: chrono::DateTime<chrono::Utc>,
/// Relevance score; starts at 1.0 and is multiplied by a decay factor over time.
pub relevance: f32,
}
/// Bookkeeping for a single entity mentioned during the conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrackedEntity {
/// The entity text as extracted from a query.
pub entity: String,
/// Entity category; set to `"Unknown"` when the entity is first tracked.
pub entity_type: String,
/// Identifier of the message in which the entity first appeared.
pub first_mention: String,
/// Identifier of the message in which the entity most recently appeared.
pub last_mention: String,
/// Number of times the entity has been recorded.
pub mention_count: usize,
/// Resolved URI for the entity, if known; `None` when first tracked.
pub resolved_uri: Option<String>,
}
/// Schema knowledge gathered while generating queries.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DiscoveredSchema {
/// Class names seen so far (kept sorted and de-duplicated by the generator).
pub classes: Vec<String>,
/// Property names seen so far (kept sorted and de-duplicated by the generator).
pub properties: Vec<String>,
/// Prefix label -> namespace URI; emitted as `PREFIX label: <uri>` lines.
pub prefixes: HashMap<String, String>,
/// Recorded query patterns (not read by the visible generator code).
pub patterns: Vec<String>,
}
/// Generates SPARQL from natural-language queries while consulting and
/// updating a [`ConversationContext`].
pub struct ContextAwareGenerator {
// Feature switches and limits applied during generation.
config: ContextAwareConfig,
}
impl ContextAwareGenerator {
/// Builds a generator that operates under the supplied `config`.
///
/// Emits one info-level log line on construction.
pub fn new(config: ContextAwareConfig) -> Self {
    let generator = Self { config };
    info!("Initialized context-aware query generator");
    generator
}
/// Generates a SPARQL query for `query`, using and updating `context`.
///
/// Steps, in order: decay/prune history relevance, extract entities, track
/// them (if enabled), resolve pronouns against the context, build the base
/// query, enrich it with context (prefixes, ordering), learn schema from the
/// result (if enabled), and finally append the exchange to the history.
///
/// # Errors
/// Propagates any error from entity extraction, anaphora resolution or
/// query construction.
pub fn generate_with_context(
    &self,
    query: &str,
    context: &mut ConversationContext,
) -> Result<SPARQLGenerationResult> {
    debug!("Generating context-aware query for: {}", query);

    // Age the existing history first so stale messages are pruned before use.
    self.update_relevance_scores(context);

    let current_entities = self.extract_entities(query)?;
    if self.config.track_entities {
        // NOTE(review): the raw query text doubles as the message id here,
        // because the history entry (and its UUID) is only created at the end.
        self.update_tracked_entities(context, &current_entities, query);
    }

    let resolved_query = self.resolve_anaphora(query, context)?;

    let variable_hints = if self.config.reuse_variables {
        self.get_variable_hints(context)
    } else {
        HashMap::new()
    };

    let base_sparql = self.generate_base_sparql(&resolved_query, context)?;
    let sparql = self.enhance_with_context(base_sparql, context, &variable_hints)?;

    if self.config.learn_schema {
        self.learn_from_query(&sparql, context);
    }

    // The history stores the original (unresolved) query text.
    self.add_to_history(context, query, &sparql, current_entities);

    Ok(SPARQLGenerationResult {
        // `sparql` is not needed after this point, so it is moved rather than cloned.
        query: sparql,
        confidence: 0.85,
        generation_method: crate::nl2sparql::types::GenerationMethod::RuleBased,
        parameters: HashMap::new(),
        explanation: Some(crate::nl2sparql::types::QueryExplanation {
            natural_language: "Generated based on conversation context".to_string(),
            reasoning_steps: vec![],
            parameter_mapping: HashMap::new(),
            alternatives: Vec::new(),
        }),
        validation_result: crate::nl2sparql::types::ValidationResult {
            is_valid: true,
            syntax_errors: Vec::new(),
            semantic_warnings: Vec::new(),
            schema_issues: Vec::new(),
            suggestions: Vec::new(),
        },
        optimization_hints: Vec::new(),
        metadata: crate::nl2sparql::types::GenerationMetadata {
            generation_time_ms: 0,
            template_used: None,
            llm_model_used: None,
            iterations: 1,
            fallback_used: false,
        },
    })
}
/// Decays the relevance of every history message by its age and drops
/// messages whose relevance is no longer above 0.1.
///
/// The multiplier is `decay_factor ^ (age in minutes)`.
fn update_relevance_scores(&self, context: &mut ConversationContext) {
    let now = chrono::Utc::now();
    let decay_factor = self.config.decay_factor;

    for message in &mut context.history {
        // A timestamp ahead of `now` (clock adjustment, imported history)
        // would yield a negative age and therefore a multiplier above 1,
        // inflating relevance; treat such messages as brand new instead.
        let age_seconds = (now - message.timestamp).num_seconds().max(0) as f32;

        // NOTE(review): the multiplier is derived from the message's total age
        // yet applied to the already-decayed score, so decay compounds with
        // the number of calls. Confirm whether that is intended.
        message.relevance *= decay_factor.powf(age_seconds / 60.0);
    }

    context.history.retain(|m| m.relevance > 0.1);
}
/// Extracts candidate entities from `query`.
///
/// Three sources are combined, then sorted and de-duplicated:
/// 1. capitalised words (longer than one byte) that are not stop words,
///    with surrounding punctuation removed;
/// 2. the contents of `<...>` spans (explicit URIs);
/// 3. prefixed names such as `dbo:Film`.
///
/// # Errors
/// Currently never fails; the `Result` is kept for interface stability.
fn extract_entities(&self, query: &str) -> Result<Vec<String>> {
    /// Capitalised words that open a question or command rather than name an
    /// entity. The imperative verbs matter because results are sorted and
    /// callers take the first element: without them "Describe Inception"
    /// would yield `Describe` as the primary entity.
    const STOP_WORDS: &[&str] = &[
        "How", "What", "Where", "When", "Who", "Which", "Why", "Is", "Are", "Do", "Does",
        "Did", "Can", "Could", "Would", "Should", "Will", "The", "A", "An", "Of", "In", "On",
        "List", "Show", "Find", "Describe", "Count", "Tell", "Give", "Get",
    ];

    /// Characters stripped from both ends of a token (quotes, brackets, `?`, ...).
    fn is_edge_punctuation(c: char) -> bool {
        !c.is_alphanumeric()
    }

    let mut entities: Vec<String> = query
        .split_whitespace()
        .filter_map(|token| {
            // Trim both ends so `"Inception"` and `(Paris)` are recognised too.
            let word = token.trim_matches(is_edge_punctuation);
            let capitalized = word.chars().next().map_or(false, char::is_uppercase);
            if capitalized && word.len() > 1 && !STOP_WORDS.contains(&word) {
                Some(word.to_string())
            } else {
                None
            }
        })
        .collect();

    let uri_pattern = regex::Regex::new(r"<([^>]+)>").expect("regex pattern should be valid");
    entities.extend(
        uri_pattern
            .captures_iter(query)
            .filter_map(|capture| capture.get(1))
            .map(|uri| uri.as_str().to_string()),
    );

    let prefixed_pattern = regex::Regex::new(r"\b([a-z]+):([A-Za-z0-9_-]+)\b")
        .expect("regex pattern should be valid");
    entities.extend(
        prefixed_pattern
            .captures_iter(query)
            .filter_map(|capture| capture.get(0))
            .map(|name| name.as_str().to_string()),
    );

    entities.sort();
    entities.dedup();
    Ok(entities)
}
/// Records a mention of each entity in `entities` for message `message_id`.
///
/// Known entities get their mention count bumped and `last_mention`
/// updated; unknown ones are inserted with type `"Unknown"`, a count of 1
/// and no resolved URI.
fn update_tracked_entities(
    &self,
    context: &mut ConversationContext,
    entities: &[String],
    message_id: &str,
) {
    for entity in entities {
        context
            .tracked_entities
            .entry(entity.clone())
            .and_modify(|tracked| {
                tracked.mention_count += 1;
                tracked.last_mention = message_id.to_string();
            })
            // Lazily built: the record (and its string allocations) is only
            // created when the entity is genuinely new.
            .or_insert_with(|| TrackedEntity {
                entity: entity.clone(),
                entity_type: "Unknown".to_string(),
                first_mention: message_id.to_string(),
                last_mention: message_id.to_string(),
                mention_count: 1,
                resolved_uri: None,
            });
    }
}
/// Replaces the pronouns "it", "them" and "that" in `query` with referents
/// taken from `context`.
///
/// * "it"   -> the most recent tracked entity
/// * "them" -> up to three recent tracked entities joined by " and "
/// * "that" -> the current topic
///
/// Matching is per whitespace-separated word: surrounding punctuation is
/// ignored and kept ("it?" becomes "Inception?"), a capitalised first letter
/// is accepted ("It"), and longer words that merely start with a pronoun
/// ("item", "its") are left untouched. All-caps tokens such as "IT" are
/// treated as acronyms and not replaced. A pronoun with no available
/// referent is left as is. Original whitespace is preserved.
///
/// # Errors
/// Currently never fails; the `Result` is kept for interface stability.
fn resolve_anaphora(&self, query: &str, context: &ConversationContext) -> Result<String> {
    /// `true` when `word` is `pronoun` in lowercase or with only its first
    /// letter capitalised. `pronoun` must be lowercase ASCII.
    fn is_pronoun(word: &str, pronoun: &str) -> bool {
        let (w, p) = (word.as_bytes(), pronoun.as_bytes());
        !w.is_empty()
            && w.len() == p.len()
            && w[0].eq_ignore_ascii_case(&p[0])
            && w[1..] == p[1..]
    }

    /// Characters ignored at either end of a token when looking for a pronoun.
    fn is_edge_punctuation(c: char) -> bool {
        !c.is_alphanumeric()
    }

    // Look referents up only for pronouns that actually occur in the query.
    let mentions = |pronoun: &str| {
        query
            .split_whitespace()
            .any(|token| is_pronoun(token.trim_matches(is_edge_punctuation), pronoun))
    };
    let it_referent = if mentions("it") {
        self.get_most_recent_entity(context)
    } else {
        None
    };
    let them_referent = if mentions("them") {
        self.get_recent_entities(context, 3)
            .map(|entities| entities.join(" and "))
    } else {
        None
    };
    // NOTE(review): "that" is substituted even when it is used as a relative
    // pronoun ("movies that star ..."), mirroring the previous behaviour.
    let that_referent = context.current_topic.as_deref();

    // Single pass over the original text, so inserted referents are never
    // themselves scanned for pronouns.
    let mut resolved = String::with_capacity(query.len());
    let mut rest = query;
    while !rest.is_empty() {
        // Copy the whitespace run verbatim.
        let gap = rest
            .find(|c: char| !c.is_whitespace())
            .unwrap_or(rest.len());
        resolved.push_str(&rest[..gap]);
        rest = &rest[gap..];

        // Take the next whitespace-delimited token.
        let token_len = rest.find(char::is_whitespace).unwrap_or(rest.len());
        let (token, tail) = rest.split_at(token_len);
        rest = tail;

        let word = token.trim_matches(is_edge_punctuation);
        let referent = if is_pronoun(word, "it") {
            it_referent.as_deref()
        } else if is_pronoun(word, "them") {
            them_referent.as_deref()
        } else if is_pronoun(word, "that") {
            that_referent
        } else {
            None
        };

        match referent {
            Some(replacement) => {
                // Keep the punctuation that surrounded the pronoun.
                let lead = token.len() - token.trim_start_matches(is_edge_punctuation).len();
                resolved.push_str(&token[..lead]);
                resolved.push_str(replacement);
                resolved.push_str(&token[lead + word.len()..]);
            }
            None => resolved.push_str(token),
        }
    }

    debug!("Resolved query: {} -> {}", query, resolved);
    Ok(resolved)
}
/// Returns an owned copy of the variable bindings stored in `context`.
fn get_variable_hints(&self, context: &ConversationContext) -> HashMap<String, String> {
    context
        .variable_bindings
        .iter()
        .map(|(variable, binding)| (variable.clone(), binding.clone()))
        .collect()
}
/// Picks a query shape from keywords in `query` and builds it.
///
/// * "count" or the phrase "how many"                      -> COUNT query
/// * "describe" (unless "list"/"show"/"find" also appear)  -> DESCRIBE query
/// * anything else                                         -> SELECT query
///
/// Keywords are matched case-insensitively against whole words (an optional
/// trailing "s" is accepted), so "country" or "account" no longer count as
/// "count" and "show many" is not read as "how many".
///
/// # Errors
/// Propagates errors from the selected query builder.
fn generate_base_sparql(&self, query: &str, context: &ConversationContext) -> Result<String> {
    let lowercase = query.to_lowercase();
    let words: Vec<&str> = lowercase
        .split(|c: char| !c.is_alphanumeric())
        .filter(|word| !word.is_empty())
        .collect();

    // Whole-word match, tolerating a plural / third-person "s".
    let mentions = |keyword: &str| {
        words.iter().any(|word| {
            word.strip_prefix(keyword)
                .map_or(false, |suffix| suffix.is_empty() || suffix == "s")
        })
    };
    let asks_how_many = words
        .windows(2)
        .any(|pair| pair[0] == "how" && pair[1] == "many");

    if mentions("count") || asks_how_many {
        self.generate_count_query(query, context)
    } else if mentions("list") || mentions("show") || mentions("find") {
        // Listing verbs take precedence over "describe".
        self.generate_select_query(query, context)
    } else if mentions("describe") {
        self.generate_describe_query(query, context)
    } else {
        self.generate_select_query(query, context)
    }
}
/// Builds a `SELECT (COUNT(?s) AS ?count)` query that counts subjects whose
/// `rdf:type` IRI contains the first entity extracted from `query`
/// (or `"thing"` when no entity is found).
///
/// # Errors
/// Propagates errors from entity extraction.
fn generate_count_query(&self, query: &str, _context: &ConversationContext) -> Result<String> {
    /// Escapes `raw` for use inside a double-quoted SPARQL string literal,
    /// so user-supplied text cannot terminate the literal early.
    fn escape_literal(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '\\' => escaped.push_str("\\\\"),
                '"' => escaped.push_str("\\\""),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    let entities = self.extract_entities(query)?;
    let primary_entity = entities.first().map(String::as_str).unwrap_or("thing");

    let mut sparql =
        String::from("PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>\n");
    sparql.push_str("PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>\n\n");
    sparql.push_str("SELECT (COUNT(?s) AS ?count) WHERE {\n");
    sparql.push_str(" ?s rdf:type ?type .\n");
    sparql.push_str(&format!(
        " FILTER (contains(str(?type), \"{}\"))\n",
        escape_literal(primary_entity)
    ));
    sparql.push_str("}\n");
    Ok(sparql)
}
/// Builds a `SELECT DISTINCT ?s ?p ?o` query limited to 100 rows.
///
/// When `query` yields at least one entity, results are filtered to triples
/// whose subject or object contains the first entity; otherwise the query
/// is unfiltered.
///
/// # Errors
/// Propagates errors from entity extraction.
fn generate_select_query(&self, query: &str, _context: &ConversationContext) -> Result<String> {
    /// Escapes `raw` for use inside a double-quoted SPARQL string literal,
    /// so user-supplied text cannot terminate the literal early.
    fn escape_literal(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '\\' => escaped.push_str("\\\\"),
                '"' => escaped.push_str("\\\""),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                other => escaped.push(other),
            }
        }
        escaped
    }

    let entities = self.extract_entities(query)?;

    let mut sparql =
        String::from("PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>\n");
    sparql.push_str("PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>\n\n");
    sparql.push_str("SELECT DISTINCT ?s ?p ?o WHERE {\n");
    // The triple pattern is the same with or without an entity filter.
    sparql.push_str(" ?s ?p ?o .\n");
    if let Some(entity) = entities.first() {
        let needle = escape_literal(entity);
        sparql.push_str(&format!(
            " FILTER (contains(str(?s), \"{}\") || contains(str(?o), \"{}\"))\n",
            needle, needle
        ));
    }
    sparql.push_str("}\n");
    sparql.push_str("LIMIT 100\n");
    Ok(sparql)
}
/// Builds a `DESCRIBE` query for the first entity extracted from `query`.
///
/// An entity that is already an absolute IRI (contains `://`) is described
/// directly; any other entity is appended to the `http://example.org/`
/// base. Characters that are not allowed inside `<...>` are
/// percent-encoded. With no entity, an arbitrary single subject is
/// described.
///
/// # Errors
/// Propagates errors from entity extraction.
fn generate_describe_query(
    &self,
    query: &str,
    _context: &ConversationContext,
) -> Result<String> {
    /// Percent-encodes the characters the SPARQL grammar forbids inside an
    /// IRI reference (`<>"{}|^` backtick, backslash, and U+0000..=U+0020).
    fn encode_for_iri(raw: &str) -> String {
        let mut encoded = String::with_capacity(raw.len());
        for ch in raw.chars() {
            let forbidden = ch <= ' '
                || matches!(ch, '<' | '>' | '"' | '{' | '}' | '|' | '^' | '`' | '\\');
            if forbidden {
                // Every forbidden character is ASCII, so two hex digits suffice.
                encoded.push_str(&format!("%{:02X}", ch as u32));
            } else {
                encoded.push(ch);
            }
        }
        encoded
    }

    let entities = self.extract_entities(query)?;

    let mut sparql =
        String::from("PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>\n");
    sparql.push_str("PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>\n\n");

    match entities.first() {
        // Already a full IRI (e.g. taken from `<http://...>` in the query):
        // prefixing it with the example base would produce a bogus IRI.
        Some(entity) if entity.contains("://") => {
            sparql.push_str(&format!("DESCRIBE <{}>\n", encode_for_iri(entity)));
        }
        Some(entity) => {
            sparql.push_str(&format!(
                "DESCRIBE <http://example.org/{}>\n",
                encode_for_iri(entity)
            ));
        }
        None => sparql.push_str("DESCRIBE ?s WHERE { ?s ?p ?o } LIMIT 1\n"),
    }

    Ok(sparql)
}
/// Enriches `sparql` with information from `context`.
///
/// * Every discovered prefix that the query does not already declare is
///   prepended as a `PREFIX label: <uri>` line, in alphabetical order.
/// * If the current topic mentions "sorted" or "ordered" and the query is a
///   `SELECT` without `ORDER BY`, `ORDER BY ?s` is inserted before the last
///   `LIMIT` (or appended when there is none).
///
/// `_variable_hints` is accepted for interface stability but not used.
///
/// # Errors
/// Currently never fails; the `Result` is kept for interface stability.
fn enhance_with_context(
    &self,
    mut sparql: String,
    context: &ConversationContext,
    _variable_hints: &HashMap<String, String>,
) -> Result<String> {
    // The trailing ':' matters: without it the label `rd` would be considered
    // declared as soon as `PREFIX rdf:` is present.
    let mut missing: Vec<(&String, &String)> = context
        .discovered_schema
        .prefixes
        .iter()
        .filter(|(prefix, _)| !sparql.contains(&format!("PREFIX {}:", prefix)))
        .collect();

    if !missing.is_empty() {
        // HashMap iteration order is arbitrary; sort for reproducible output.
        missing.sort();
        let mut with_prefixes = String::with_capacity(sparql.len() + missing.len() * 48);
        for (prefix, uri) in missing {
            with_prefixes.push_str(&format!("PREFIX {}: <{}>\n", prefix, uri));
        }
        with_prefixes.push_str(&sparql);
        sparql = with_prefixes;
    }

    let wants_ordering = context
        .current_topic
        .as_deref()
        .map_or(false, |topic| topic.contains("sorted") || topic.contains("ordered"));

    if wants_ordering && !sparql.contains("ORDER BY") && sparql.contains("SELECT") {
        // Use the last occurrence: the solution modifier sits at the end of
        // the query, whereas an earlier "LIMIT" may be part of a literal.
        match sparql.rfind("LIMIT") {
            Some(limit_pos) => sparql.insert_str(limit_pos, "ORDER BY ?s\n"),
            None => sparql.push_str("ORDER BY ?s\n"),
        }
    }

    Ok(sparql)
}
/// Updates the discovered schema in `context` from a generated query.
///
/// Whenever `sparql` mentions `rdf:type`, the placeholder class name
/// `"NewClass"` is recorded. Nothing is learned about properties yet.
/// Both the class and property lists are left sorted and de-duplicated.
fn learn_from_query(&self, sparql: &str, context: &mut ConversationContext) {
    let schema = &mut context.discovered_schema;

    if sparql.contains("rdf:type") {
        schema.classes.push("NewClass".to_string());
    }

    // Property learning is not implemented: a `?s ?p ?o` pattern in the
    // query currently has no effect on the schema.

    schema.classes.sort();
    schema.classes.dedup();

    schema.properties.sort();
    schema.properties.dedup();
}
/// Appends the exchange (`query`, generated `sparql`, extracted `entities`)
/// to the history with a fresh UUID, the current UTC time and relevance 1.0,
/// trims the history to `max_history` by discarding the oldest messages, and
/// makes `query` the current topic.
fn add_to_history(
    &self,
    context: &mut ConversationContext,
    query: &str,
    sparql: &str,
    entities: Vec<String>,
) {
    context.history.push(ContextMessage {
        id: uuid::Uuid::new_v4().to_string(),
        content: query.to_string(),
        sparql: Some(sparql.to_string()),
        entities,
        timestamp: chrono::Utc::now(),
        relevance: 1.0,
    });

    // Drop however many of the oldest messages exceed the cap. Removing just
    // one would leave an already oversized history (pre-populated context,
    // lowered `max_history`) permanently above the limit.
    let overflow = context
        .history
        .len()
        .saturating_sub(self.config.max_history);
    if overflow > 0 {
        context.history.drain(..overflow);
    }

    context.current_topic = Some(query.to_string());
}
fn get_most_recent_entity(&self, context: &ConversationContext) -> Option<String> {
context
.tracked_entities
.values()
.max_by_key(|e| &e.last_mention)
.map(|e| e.entity.clone())
}
fn get_recent_entities(
&self,
context: &ConversationContext,
count: usize,
) -> Option<Vec<String>> {
let mut entities: Vec<_> = context.tracked_entities.values().collect();
entities.sort_by_key(|e| &e.last_mention);
entities.reverse();
let recent: Vec<String> = entities
.iter()
.take(count)
.map(|e| e.entity.clone())
.collect();
if recent.is_empty() {
None
} else {
Some(recent)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generator under test, built with the default configuration.
    fn default_generator() -> ContextAwareGenerator {
        ContextAwareGenerator::new(ContextAwareConfig::default())
    }

    /// "it" at the end of a query is replaced by the only tracked entity.
    #[test]
    fn test_anaphora_resolution() {
        let generator = default_generator();

        let inception = TrackedEntity {
            entity: "Inception".to_string(),
            entity_type: "Movie".to_string(),
            first_mention: "msg1".to_string(),
            last_mention: "msg1".to_string(),
            mention_count: 1,
            resolved_uri: None,
        };
        let mut context = ConversationContext::default();
        context
            .tracked_entities
            .insert("Inception".to_string(), inception);

        let resolved = generator
            .resolve_anaphora("Tell me more about it", &context)
            .expect("should succeed");

        assert!(resolved.contains("Inception"));
    }

    /// Each newly seen entity is tracked once with a mention count of one.
    #[test]
    fn test_entity_tracking() {
        let generator = default_generator();
        let mut context = ConversationContext::default();
        let entities = vec!["Movie".to_string(), "Director".to_string()];

        generator.update_tracked_entities(&mut context, &entities, "msg1");

        assert_eq!(context.tracked_entities.len(), 2);
        let movie = context
            .tracked_entities
            .get("Movie")
            .expect("should succeed");
        assert_eq!(movie.mention_count, 1);
    }

    /// A "how many" question yields a COUNT query mentioning the entity.
    #[test]
    fn test_count_query_generation() {
        let generator = default_generator();
        let context = ConversationContext::default();

        let sparql = generator
            .generate_count_query("How many Movies are there?", &context)
            .expect("should succeed");
        println!("Generated SPARQL: {}", sparql);

        assert!(sparql.contains("COUNT"));
        assert!(sparql.to_lowercase().contains("movie"));
    }
}