1use crate::GraphRAGResult;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct ParsedQuery {
9 pub original: String,
11 pub keywords: Vec<String>,
13 pub intent: QueryIntent,
15 pub entities: Vec<ExtractedEntity>,
17 pub temporal: Option<TemporalConstraint>,
19}
20
21#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
23pub enum QueryIntent {
24 Factual,
26 Explanation,
28 Comparison,
30 List,
32 Definition,
34 Relationship,
36 Unknown,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ExtractedEntity {
43 pub text: String,
45 pub entity_type: String,
47 pub start: usize,
49 pub end: usize,
51 pub confidence: f32,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct TemporalConstraint {
58 pub constraint_type: TemporalType,
60 pub start: Option<String>,
62 pub end: Option<String>,
64}
65
66#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
67pub enum TemporalType {
68 Before,
69 After,
70 During,
71 Between,
72}
73
/// Rule-based parser that turns a raw query string into a structured form.
///
/// Derives `Debug` so the parser can appear in diagnostics and inside other
/// `Debug` types, and `Clone` so a configured parser can be duplicated
/// instead of rebuilt.
#[derive(Debug, Clone)]
pub struct QueryParser {
    /// Lower-case words that are dropped during keyword extraction.
    stop_words: std::collections::HashSet<String>,
}
79
80impl Default for QueryParser {
81 fn default() -> Self {
82 Self::new()
83 }
84}
85
86impl QueryParser {
87 pub fn new() -> Self {
88 let stop_words: std::collections::HashSet<String> = [
89 "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has",
90 "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
91 "shall", "can", "of", "at", "by", "for", "with", "about", "against", "between", "into",
92 "through", "during", "before", "after", "above", "below", "to", "from", "up", "down",
93 "in", "out", "on", "off", "over", "under", "again", "further", "then", "once", "here",
94 "there", "when", "where", "why", "how", "all", "each", "few", "more", "most", "other",
95 "some", "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very",
96 "just", "and", "but", "if", "or", "because", "as", "until", "while", "what", "which",
97 "who", "whom", "this", "that", "these", "those", "am", "it",
98 ]
99 .iter()
100 .map(|s| s.to_string())
101 .collect();
102
103 Self { stop_words }
104 }
105
106 pub fn parse(&self, query: &str) -> GraphRAGResult<ParsedQuery> {
108 let keywords = self.extract_keywords(query);
109 let intent = self.detect_intent(query);
110 let entities = self.extract_entities(query);
111 let temporal = self.extract_temporal(query);
112
113 Ok(ParsedQuery {
114 original: query.to_string(),
115 keywords,
116 intent,
117 entities,
118 temporal,
119 })
120 }
121
122 fn extract_keywords(&self, query: &str) -> Vec<String> {
124 query
125 .to_lowercase()
126 .split(|c: char| !c.is_alphanumeric())
127 .filter(|word| !word.is_empty() && word.len() > 2 && !self.stop_words.contains(*word))
128 .map(String::from)
129 .collect()
130 }
131
132 fn detect_intent(&self, query: &str) -> QueryIntent {
134 let lower = query.to_lowercase();
135
136 if lower.starts_with("what is") || lower.starts_with("define") {
137 QueryIntent::Definition
138 } else if lower.starts_with("why") || lower.starts_with("how does") {
139 QueryIntent::Explanation
140 } else if lower.contains("compare") || lower.contains("difference between") {
141 QueryIntent::Comparison
142 } else if lower.starts_with("list") || lower.contains("what are") {
143 QueryIntent::List
144 } else if lower.contains("related to")
145 || lower.contains("connected to")
146 || lower.contains("relationship")
147 {
148 QueryIntent::Relationship
149 } else if lower.starts_with("what")
150 || lower.starts_with("who")
151 || lower.starts_with("when")
152 || lower.starts_with("where")
153 {
154 QueryIntent::Factual
155 } else {
156 QueryIntent::Unknown
157 }
158 }
159
160 fn extract_entities(&self, query: &str) -> Vec<ExtractedEntity> {
162 let mut entities = Vec::new();
163
164 for word in query.split_whitespace() {
166 if word.len() > 1
167 && word
168 .chars()
169 .next()
170 .map(|c| c.is_uppercase())
171 .unwrap_or(false)
172 && ![
173 "What", "Who", "When", "Where", "Why", "How", "Is", "Are", "The", "A",
174 ]
175 .contains(&word)
176 {
177 if let Some(start) = query.find(word) {
178 entities.push(ExtractedEntity {
179 text: word.to_string(),
180 entity_type: "Unknown".to_string(),
181 start,
182 end: start + word.len(),
183 confidence: 0.5,
184 });
185 }
186 }
187 }
188
189 entities
190 }
191
192 fn extract_temporal(&self, query: &str) -> Option<TemporalConstraint> {
194 let lower = query.to_lowercase();
195
196 if lower.contains("before") {
197 Some(TemporalConstraint {
198 constraint_type: TemporalType::Before,
199 start: None,
200 end: None,
201 })
202 } else if lower.contains("after") {
203 Some(TemporalConstraint {
204 constraint_type: TemporalType::After,
205 start: None,
206 end: None,
207 })
208 } else if lower.contains("during") || lower.contains("in ") {
209 Some(TemporalConstraint {
210 constraint_type: TemporalType::During,
211 start: None,
212 end: None,
213 })
214 } else if lower.contains("between") {
215 Some(TemporalConstraint {
216 constraint_type: TemporalType::Between,
217 start: None,
218 end: None,
219 })
220 } else {
221 None
222 }
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Content words of a typical question survive keyword extraction.
    #[test]
    fn test_keyword_extraction() {
        let keywords =
            QueryParser::new().extract_keywords("What are the safety issues with battery cells?");

        for expected in &["safety", "issues", "battery", "cells"] {
            assert!(keywords.contains(&expected.to_string()));
        }
    }

    /// Each sample query maps to the intent its leading phrase signals.
    #[test]
    fn test_intent_detection() {
        let parser = QueryParser::new();
        let cases = [
            ("What is a battery?", QueryIntent::Definition),
            ("Why does the battery overheat?", QueryIntent::Explanation),
            ("Compare lithium and nickel batteries", QueryIntent::Comparison),
            ("List all safety hazards", QueryIntent::List),
        ];

        for (query, expected) in cases.iter() {
            assert_eq!(parser.detect_intent(query), *expected);
        }
    }
}