use crate::GraphRAGResult;
use serde::{Deserialize, Serialize};
/// Structured result produced by `QueryParser::parse` for one query string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParsedQuery {
/// The query text exactly as it was handed to the parser.
pub original: String,
/// Lowercased terms that remain after short tokens and stop words are dropped.
pub keywords: Vec<String>,
/// Category of question detected by the parser's phrase heuristics.
pub intent: QueryIntent,
/// Capitalized tokens picked out of the query as candidate entities.
pub entities: Vec<ExtractedEntity>,
/// Temporal cue found in the query, or `None` when no cue word is present.
pub temporal: Option<TemporalConstraint>,
}
/// Kind of question a query asks, as detected by `QueryParser` from fixed
/// English phrases (matching is case-insensitive).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum QueryIntent {
/// A "what" / "who" / "when" / "where" question that matched no more specific intent.
Factual,
/// A "why" or "how does" question.
Explanation,
/// The query mentions "compare" or "difference between".
Comparison,
/// The query starts with "list" or contains "what are".
List,
/// The query starts with "what is" or "define".
Definition,
/// The query mentions "related to", "connected to" or "relationship".
Relationship,
/// None of the known phrases matched.
Unknown,
}
/// A candidate entity mention located inside a query string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedEntity {
/// The entity text as it appears in the query.
pub text: String,
/// Entity category label; the parser in this file always sets `"Unknown"`.
pub entity_type: String,
/// Byte offset into the original query at which the entity text was located.
pub start: usize,
/// Exclusive end byte offset (`start` plus the byte length of `text`).
pub end: usize,
/// Confidence score; the parser in this file always assigns `0.5`.
/// NOTE(review): presumably on a 0.0..=1.0 scale — confirm with consumers.
pub confidence: f32,
}
/// A temporal restriction detected in a query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalConstraint {
/// Which kind of temporal cue was found.
pub constraint_type: TemporalType,
/// Optional start bound. The parser in this file always leaves it `None`.
/// NOTE(review): the expected string format (date? timestamp?) is not
/// defined here — confirm with whoever populates it.
pub start: Option<String>,
/// Optional end bound. The parser in this file always leaves it `None`.
pub end: Option<String>,
}
/// Kind of temporal cue recognized by `QueryParser`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum TemporalType {
/// The query contains "before".
Before,
/// The query contains "after".
After,
/// The query contains "during" or "in".
During,
/// The query contains "between".
Between,
}
/// Rule-based parser that turns a free-text query into keywords, an intent,
/// candidate entities and an optional temporal cue.
///
/// The parser holds only its stop-word set, so it is cheap to clone and can be
/// printed with `{:?}` for diagnostics.
#[derive(Debug, Clone)]
pub struct QueryParser {
    /// Lowercase words that are discarded during keyword extraction.
    stop_words: std::collections::HashSet<String>,
}
impl Default for QueryParser {
fn default() -> Self {
Self::new()
}
}
impl QueryParser {
    /// Creates a parser preloaded with a fixed list of common English stop words.
    pub fn new() -> Self {
        let stop_words: std::collections::HashSet<String> = [
            "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has",
            "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
            "shall", "can", "of", "at", "by", "for", "with", "about", "against", "between", "into",
            "through", "during", "before", "after", "above", "below", "to", "from", "up", "down",
            "in", "out", "on", "off", "over", "under", "again", "further", "then", "once", "here",
            "there", "when", "where", "why", "how", "all", "each", "few", "more", "most", "other",
            "some", "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very",
            "just", "and", "but", "if", "or", "because", "as", "until", "while", "what", "which",
            "who", "whom", "this", "that", "these", "those", "am", "it",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();
        Self { stop_words }
    }

    /// Parses `query` into keywords, intent, candidate entities and a temporal cue.
    ///
    /// # Errors
    ///
    /// The current implementation never fails; the `Result` return type is
    /// kept so callers do not have to change if parsing becomes fallible.
    pub fn parse(&self, query: &str) -> GraphRAGResult<ParsedQuery> {
        let keywords = self.extract_keywords(query);
        let intent = self.detect_intent(query);
        let entities = self.extract_entities(query);
        let temporal = self.extract_temporal(query);
        Ok(ParsedQuery {
            original: query.to_string(),
            keywords,
            intent,
            entities,
            temporal,
        })
    }

    /// Returns the lowercased content words of `query`.
    ///
    /// The query is split on every non-alphanumeric character; tokens of two
    /// bytes or fewer and tokens in the stop-word set are dropped. Order and
    /// duplicates are preserved.
    fn extract_keywords(&self, query: &str) -> Vec<String> {
        query
            .to_lowercase()
            .split(|c: char| !c.is_alphanumeric())
            // `len() > 2` already rules out the empty tokens produced by
            // consecutive separators.
            .filter(|word| word.len() > 2 && !self.stop_words.contains(*word))
            .map(String::from)
            .collect()
    }

    /// Returns `true` when `text` begins with `phrase` and the phrase ends on a
    /// word boundary (end of text or a non-alphanumeric character).
    fn starts_with_phrase(text: &str, phrase: &str) -> bool {
        text.strip_prefix(phrase)
            .map_or(false, |rest| !rest.starts_with(char::is_alphanumeric))
    }

    /// Classifies `query` by looking for fixed English phrases, case-insensitively.
    ///
    /// Branches are checked in priority order, so a query matching several
    /// phrases gets the first listed intent. Leading whitespace is ignored.
    fn detect_intent(&self, query: &str) -> QueryIntent {
        let lower = query.trim_start().to_lowercase();
        // The multi-word prefixes must end on a word boundary so that e.g.
        // "what issues ..." is not mistaken for "what is ...". Single-word
        // stems stay plain prefix matches so inflections ("lists", "defines",
        // "whose") are still recognized.
        if Self::starts_with_phrase(&lower, "what is") || lower.starts_with("define") {
            QueryIntent::Definition
        } else if lower.starts_with("why") || Self::starts_with_phrase(&lower, "how does") {
            QueryIntent::Explanation
        } else if lower.contains("compare") || lower.contains("difference between") {
            QueryIntent::Comparison
        } else if lower.starts_with("list") || lower.contains("what are") {
            QueryIntent::List
        } else if lower.contains("related to")
            || lower.contains("connected to")
            || lower.contains("relationship")
        {
            QueryIntent::Relationship
        } else if lower.starts_with("what")
            || lower.starts_with("who")
            || lower.starts_with("when")
            || lower.starts_with("where")
        {
            QueryIntent::Factual
        } else {
            QueryIntent::Unknown
        }
    }

    /// Picks capitalized words out of `query` as candidate entities.
    ///
    /// A whitespace-separated token becomes an entity when, after stripping
    /// leading and trailing punctuation, it is longer than one byte, starts
    /// with an uppercase letter and is not a common question or function word.
    /// `start` and `end` are byte offsets of that exact occurrence in `query`,
    /// so `&query[start..end]` equals the entity text.
    fn extract_entities(&self, query: &str) -> Vec<ExtractedEntity> {
        const NON_ENTITY_WORDS: [&str; 10] = [
            "What", "Who", "When", "Where", "Why", "How", "Is", "Are", "The", "A",
        ];

        fn is_edge_punctuation(c: char) -> bool {
            !c.is_alphanumeric()
        }

        let mut entities = Vec::new();
        // Byte position just past the previously visited token. Tokens come
        // back in order, so searching from here locates the current
        // occurrence rather than an earlier identical or containing one.
        let mut cursor = 0;
        for token in query.split_whitespace() {
            let token_start = match query[cursor..].find(token) {
                Some(offset) => cursor + offset,
                None => continue,
            };
            cursor = token_start + token.len();

            // Strip punctuation such as "Tesla?" or "(NASA)" so the entity
            // text is clean and question words followed by punctuation are
            // still filtered out.
            let leading = token.len() - token.trim_start_matches(is_edge_punctuation).len();
            let word = token.trim_matches(is_edge_punctuation);
            let is_capitalized = word.chars().next().map_or(false, char::is_uppercase);

            if word.len() > 1 && is_capitalized && !NON_ENTITY_WORDS.contains(&word) {
                let start = token_start + leading;
                entities.push(ExtractedEntity {
                    text: word.to_string(),
                    entity_type: "Unknown".to_string(),
                    start,
                    end: start + word.len(),
                    confidence: 0.5,
                });
            }
        }
        entities
    }

    /// Detects a temporal cue word in `query`, case-insensitively.
    ///
    /// Cues are matched as whole words, so "explain" or "main" do not count as
    /// "in" and "afterwards" does not count as "after". Priority order is
    /// before, after, during/in, between. Only the cue type is detected;
    /// `start` and `end` are always `None`.
    fn extract_temporal(&self, query: &str) -> Option<TemporalConstraint> {
        let lower = query.to_lowercase();
        let words: Vec<&str> = lower
            .split(|c: char| !c.is_alphanumeric())
            .filter(|w| !w.is_empty())
            .collect();
        let has_word = |target: &str| words.iter().any(|&w| w == target);

        let constraint_type = if has_word("before") {
            TemporalType::Before
        } else if has_word("after") {
            TemporalType::After
        } else if has_word("during") || has_word("in") {
            TemporalType::During
        } else if has_word("between") {
            TemporalType::Between
        } else {
            return None;
        };

        Some(TemporalConstraint {
            constraint_type,
            start: None,
            end: None,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Content words of a question survive stop-word filtering.
    #[test]
    fn test_keyword_extraction() {
        let extracted =
            QueryParser::new().extract_keywords("What are the safety issues with battery cells?");
        for expected in &["safety", "issues", "battery", "cells"] {
            assert!(extracted.iter().any(|kw| kw == expected));
        }
    }

    /// Each sample query maps to the intent its leading phrase implies.
    #[test]
    fn test_intent_detection() {
        let parser = QueryParser::new();
        let cases = [
            ("What is a battery?", QueryIntent::Definition),
            ("Why does the battery overheat?", QueryIntent::Explanation),
            ("Compare lithium and nickel batteries", QueryIntent::Comparison),
            ("List all safety hazards", QueryIntent::List),
        ];
        for (query, expected) in cases.iter() {
            assert_eq!(parser.detect_intent(query), *expected);
        }
    }
}