1use std::collections::{HashMap, HashSet};
12use serde::{Deserialize, Serialize};
13
14pub struct QueryIntelligence {
16 synonyms: HashMap<String, Vec<String>>,
17 templates: Vec<QueryTemplate>,
18 stop_words: HashSet<String>,
19 relevance_scores: HashMap<String, f32>,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct QueryTemplate {
25 pub pattern: String,
27 pub query_type: QueryType,
29 pub rewrite: String,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
35pub enum QueryType {
36 EntityLookup,
38 Relationship,
40 Aggregation,
42 Comparison,
44 Temporal,
46 Causal,
48 General,
50}
51
52#[derive(Debug, Clone)]
54pub struct RewrittenQuery {
55 pub original: String,
57 pub rewritten: String,
59 pub query_type: QueryType,
61 pub expanded_terms: Vec<String>,
63 pub confidence: f32,
65}
66
67impl QueryIntelligence {
68 pub fn new() -> Self {
70 let mut engine = Self {
71 synonyms: HashMap::new(),
72 templates: Vec::new(),
73 stop_words: HashSet::new(),
74 relevance_scores: HashMap::new(),
75 };
76
77 engine.load_default_synonyms();
79 engine.load_default_templates();
80 engine.load_default_stop_words();
81
82 engine
83 }
84
85 pub fn rewrite_query(&self, query: &str) -> RewrittenQuery {
93 let normalized = self.normalize_query(query);
95
96 let query_type = self.detect_query_type(&normalized);
98
99 let template_rewritten = self.apply_templates(&normalized, &query_type);
101
102 let expanded = self.expand_synonyms(&template_rewritten);
104
105 let expanded_terms = self.extract_key_terms(&expanded);
107
108 let confidence = self.calculate_confidence(&normalized, &expanded_terms);
110
111 RewrittenQuery {
112 original: query.to_string(),
113 rewritten: expanded,
114 query_type,
115 expanded_terms,
116 confidence,
117 }
118 }
119
120 pub fn add_synonym(&mut self, term: impl Into<String>, synonyms: Vec<String>) {
126 self.synonyms.insert(term.into().to_lowercase(), synonyms);
128 }
129
130 pub fn add_template(&mut self, template: QueryTemplate) {
135 self.templates.push(template);
136 }
137
138 pub fn record_feedback(&mut self, term: impl Into<String>, score: f32) {
144 let term = term.into();
145 let current_score = self.relevance_scores.get(&term).unwrap_or(&0.5);
146 let new_score = current_score * 0.5 + score * 0.5;
148 self.relevance_scores.insert(term, new_score);
149 }
150
151 pub fn get_relevance(&self, term: &str) -> f32 {
159 *self.relevance_scores.get(term).unwrap_or(&0.5)
160 }
161
162 fn normalize_query(&self, query: &str) -> String {
166 query.trim().to_lowercase()
167 }
168
169 fn detect_query_type(&self, query: &str) -> QueryType {
171 let query_lower = query.to_lowercase();
172
173 if query_lower.contains("relationship between")
175 || query_lower.contains("how does")
176 || query_lower.contains("related to")
177 || query_lower.contains("connection between")
178 {
179 return QueryType::Relationship;
180 }
181
182 if query_lower.starts_with("who is")
184 || query_lower.starts_with("what is")
185 || query_lower.starts_with("define")
186 {
187 return QueryType::EntityLookup;
188 }
189
190 if query_lower.starts_with("how many")
192 || query_lower.starts_with("count")
193 || query_lower.contains("total")
194 || query_lower.contains("sum")
195 || query_lower.contains("average")
196 {
197 return QueryType::Aggregation;
198 }
199
200 if query_lower.contains("compare")
202 || query_lower.contains("difference between")
203 || query_lower.contains("versus")
204 || query_lower.contains("vs")
205 {
206 return QueryType::Comparison;
207 }
208
209 if query_lower.contains("when")
211 || query_lower.contains("before")
212 || query_lower.contains("after")
213 || query_lower.contains("during")
214 || query_lower.contains("timeline")
215 {
216 return QueryType::Temporal;
217 }
218
219 if query_lower.contains("why")
221 || query_lower.contains("because")
222 || query_lower.contains("cause")
223 || query_lower.contains("reason")
224 || query_lower.contains("led to")
225 {
226 return QueryType::Causal;
227 }
228
229 QueryType::General
230 }
231
232 fn apply_templates(&self, query: &str, query_type: &QueryType) -> String {
234 for template in &self.templates {
235 if &template.query_type == query_type && query.contains(&template.pattern) {
236 return query.replace(&template.pattern, &template.rewrite);
237 }
238 }
239 query.to_string()
240 }
241
242 fn expand_synonyms(&self, query: &str) -> String {
244 let words: Vec<&str> = query.split_whitespace().collect();
245 let mut expanded_words = Vec::new();
246
247 for word in words {
248 expanded_words.push(word.to_string());
249
250 if let Some(synonyms) = self.synonyms.get(word) {
252 for synonym in synonyms {
253 if !expanded_words.contains(synonym) {
254 expanded_words.push(synonym.clone());
255 }
256 }
257 }
258 }
259
260 expanded_words.join(" ")
261 }
262
263 fn extract_key_terms(&self, query: &str) -> Vec<String> {
265 query
266 .split_whitespace()
267 .filter(|word| !self.stop_words.contains(*word))
268 .map(|s| s.to_string())
269 .collect()
270 }
271
272 fn calculate_confidence(&self, query: &str, expanded_terms: &[String]) -> f32 {
274 if expanded_terms.is_empty() {
275 return 0.5;
276 }
277
278 let word_count = query.split_whitespace().count() as f32;
280 let term_count = expanded_terms.len() as f32;
281
282 let specificity_score = (term_count / (word_count + 1.0)).min(1.0);
284
285 let relevance_score: f32 = expanded_terms
287 .iter()
288 .map(|t| self.get_relevance(t))
289 .sum::<f32>()
290 / term_count;
291
292 specificity_score * 0.6 + relevance_score * 0.4
294 }
295
296 fn load_default_synonyms(&mut self) {
298 self.add_synonym("find", vec!["search".to_string(), "locate".to_string()]);
300 self.add_synonym("person", vec!["individual".to_string(), "people".to_string()]);
301 self.add_synonym("company", vec!["organization".to_string(), "business".to_string(), "firm".to_string()]);
302 self.add_synonym("show", vec!["display".to_string(), "present".to_string()]);
303 self.add_synonym("get", vec!["retrieve".to_string(), "fetch".to_string()]);
304 self.add_synonym("large", vec!["big".to_string(), "huge".to_string(), "significant".to_string()]);
305 self.add_synonym("small", vec!["tiny".to_string(), "minor".to_string()]);
306 self.add_synonym("important", vec!["significant".to_string(), "critical".to_string(), "key".to_string()]);
307 }
308
309 fn load_default_templates(&mut self) {
311 self.add_template(QueryTemplate {
312 pattern: "who is".to_string(),
313 query_type: QueryType::EntityLookup,
314 rewrite: "entity:".to_string(),
315 });
316
317 self.add_template(QueryTemplate {
318 pattern: "what is".to_string(),
319 query_type: QueryType::EntityLookup,
320 rewrite: "define:".to_string(),
321 });
322
323 self.add_template(QueryTemplate {
324 pattern: "how many".to_string(),
325 query_type: QueryType::Aggregation,
326 rewrite: "count:".to_string(),
327 });
328
329 self.add_template(QueryTemplate {
330 pattern: "compare".to_string(),
331 query_type: QueryType::Comparison,
332 rewrite: "compare:".to_string(),
333 });
334 }
335
336 fn load_default_stop_words(&mut self) {
338 let stop_words = vec![
339 "a", "an", "and", "are", "as", "at", "be", "by", "for",
340 "from", "has", "he", "in", "is", "it", "its", "of", "on",
341 "that", "the", "to", "was", "will", "with",
342 ];
343
344 for word in stop_words {
345 self.stop_words.insert(word.to_string());
346 }
347 }
348}
349
350impl Default for QueryIntelligence {
351 fn default() -> Self {
352 Self::new()
353 }
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Opening phrases and the "relationship between" phrase select the
    /// expected categories.
    #[test]
    fn test_query_type_detection() {
        let engine = QueryIntelligence::new();

        let result = engine.rewrite_query("who is the CEO of OpenAI?");
        assert_eq!(result.query_type, QueryType::EntityLookup);

        let result = engine.rewrite_query("how many employees work at Google?");
        assert_eq!(result.query_type, QueryType::Aggregation);

        let result = engine.rewrite_query("what is the relationship between Apple and Microsoft?");
        assert_eq!(result.query_type, QueryType::Relationship);
    }

    /// Both "find" and "large" have default synonyms, so both expansions
    /// must show up; checking only one of them would hide a missing table
    /// entry.
    #[test]
    fn test_synonym_expansion() {
        let engine = QueryIntelligence::new();

        let result = engine.rewrite_query("find large companies");

        assert!(result.expanded_terms.contains(&"search".to_string()));
        assert!(result.expanded_terms.contains(&"big".to_string()));
    }

    /// Default stop words never appear among the key terms.
    #[test]
    fn test_stop_word_removal() {
        let engine = QueryIntelligence::new();

        let result = engine.rewrite_query("what is the best approach");

        assert!(!result.expanded_terms.contains(&"the".to_string()));
        assert!(!result.expanded_terms.contains(&"is".to_string()));
    }

    /// Two positive feedback rounds move the score from the 0.5 default
    /// to 0.75.
    #[test]
    fn test_relevance_feedback() {
        let mut engine = QueryIntelligence::new();

        engine.record_feedback("artificial_intelligence", 0.9);
        engine.record_feedback("artificial_intelligence", 0.8);

        let score = engine.get_relevance("artificial_intelligence");
        assert!(score > 0.7);
    }

    /// A synonym registered under an upper-case key still expands the
    /// lower-cased query, and every registered alternative is appended.
    #[test]
    fn test_custom_synonyms() {
        let mut engine = QueryIntelligence::new();
        engine.add_synonym(
            "AI",
            vec!["artificial intelligence".to_string(), "machine learning".to_string()],
        );

        let result = engine.rewrite_query("AI applications");

        assert!(result.rewritten.contains("artificial intelligence"));
        assert!(result.rewritten.contains("machine learning"));
    }
}