1use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13
14pub struct QueryIntelligence {
16 synonyms: HashMap<String, Vec<String>>,
17 templates: Vec<QueryTemplate>,
18 stop_words: HashSet<String>,
19 relevance_scores: HashMap<String, f32>,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct QueryTemplate {
25 pub pattern: String,
27 pub query_type: QueryType,
29 pub rewrite: String,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
35pub enum QueryType {
36 EntityLookup,
38 Relationship,
40 Aggregation,
42 Comparison,
44 Temporal,
46 Causal,
48 General,
50}
51
52#[derive(Debug, Clone)]
54pub struct RewrittenQuery {
55 pub original: String,
57 pub rewritten: String,
59 pub query_type: QueryType,
61 pub expanded_terms: Vec<String>,
63 pub confidence: f32,
65}
66
67impl QueryIntelligence {
68 pub fn new() -> Self {
70 let mut engine = Self {
71 synonyms: HashMap::new(),
72 templates: Vec::new(),
73 stop_words: HashSet::new(),
74 relevance_scores: HashMap::new(),
75 };
76
77 engine.load_default_synonyms();
79 engine.load_default_templates();
80 engine.load_default_stop_words();
81
82 engine
83 }
84
85 pub fn rewrite_query(&self, query: &str) -> RewrittenQuery {
93 let normalized = self.normalize_query(query);
95
96 let query_type = self.detect_query_type(&normalized);
98
99 let template_rewritten = self.apply_templates(&normalized, &query_type);
101
102 let expanded = self.expand_synonyms(&template_rewritten);
104
105 let expanded_terms = self.extract_key_terms(&expanded);
107
108 let confidence = self.calculate_confidence(&normalized, &expanded_terms);
110
111 RewrittenQuery {
112 original: query.to_string(),
113 rewritten: expanded,
114 query_type,
115 expanded_terms,
116 confidence,
117 }
118 }
119
120 pub fn add_synonym(&mut self, term: impl Into<String>, synonyms: Vec<String>) {
126 self.synonyms.insert(term.into().to_lowercase(), synonyms);
128 }
129
130 pub fn add_template(&mut self, template: QueryTemplate) {
135 self.templates.push(template);
136 }
137
138 pub fn record_feedback(&mut self, term: impl Into<String>, score: f32) {
144 let term = term.into();
145 let current_score = self.relevance_scores.get(&term).unwrap_or(&0.5);
146 let new_score = current_score * 0.5 + score * 0.5;
148 self.relevance_scores.insert(term, new_score);
149 }
150
151 pub fn get_relevance(&self, term: &str) -> f32 {
159 *self.relevance_scores.get(term).unwrap_or(&0.5)
160 }
161
162 fn normalize_query(&self, query: &str) -> String {
166 query.trim().to_lowercase()
167 }
168
169 fn detect_query_type(&self, query: &str) -> QueryType {
171 let query_lower = query.to_lowercase();
172
173 if query_lower.contains("relationship between")
175 || query_lower.contains("how does")
176 || query_lower.contains("related to")
177 || query_lower.contains("connection between")
178 {
179 return QueryType::Relationship;
180 }
181
182 if query_lower.starts_with("who is")
184 || query_lower.starts_with("what is")
185 || query_lower.starts_with("define")
186 {
187 return QueryType::EntityLookup;
188 }
189
190 if query_lower.starts_with("how many")
192 || query_lower.starts_with("count")
193 || query_lower.contains("total")
194 || query_lower.contains("sum")
195 || query_lower.contains("average")
196 {
197 return QueryType::Aggregation;
198 }
199
200 if query_lower.contains("compare")
202 || query_lower.contains("difference between")
203 || query_lower.contains("versus")
204 || query_lower.contains("vs")
205 {
206 return QueryType::Comparison;
207 }
208
209 if query_lower.contains("when")
211 || query_lower.contains("before")
212 || query_lower.contains("after")
213 || query_lower.contains("during")
214 || query_lower.contains("timeline")
215 {
216 return QueryType::Temporal;
217 }
218
219 if query_lower.contains("why")
221 || query_lower.contains("because")
222 || query_lower.contains("cause")
223 || query_lower.contains("reason")
224 || query_lower.contains("led to")
225 {
226 return QueryType::Causal;
227 }
228
229 QueryType::General
230 }
231
232 fn apply_templates(&self, query: &str, query_type: &QueryType) -> String {
234 for template in &self.templates {
235 if &template.query_type == query_type && query.contains(&template.pattern) {
236 return query.replace(&template.pattern, &template.rewrite);
237 }
238 }
239 query.to_string()
240 }
241
242 fn expand_synonyms(&self, query: &str) -> String {
244 let words: Vec<&str> = query.split_whitespace().collect();
245 let mut expanded_words = Vec::new();
246
247 for word in words {
248 expanded_words.push(word.to_string());
249
250 if let Some(synonyms) = self.synonyms.get(word) {
252 for synonym in synonyms {
253 if !expanded_words.contains(synonym) {
254 expanded_words.push(synonym.clone());
255 }
256 }
257 }
258 }
259
260 expanded_words.join(" ")
261 }
262
263 fn extract_key_terms(&self, query: &str) -> Vec<String> {
265 query
266 .split_whitespace()
267 .filter(|word| !self.stop_words.contains(*word))
268 .map(|s| s.to_string())
269 .collect()
270 }
271
272 fn calculate_confidence(&self, query: &str, expanded_terms: &[String]) -> f32 {
274 if expanded_terms.is_empty() {
275 return 0.5;
276 }
277
278 let word_count = query.split_whitespace().count() as f32;
280 let term_count = expanded_terms.len() as f32;
281
282 let specificity_score = (term_count / (word_count + 1.0)).min(1.0);
284
285 let relevance_score: f32 = expanded_terms
287 .iter()
288 .map(|t| self.get_relevance(t))
289 .sum::<f32>()
290 / term_count;
291
292 specificity_score * 0.6 + relevance_score * 0.4
294 }
295
296 fn load_default_synonyms(&mut self) {
298 self.add_synonym("find", vec!["search".to_string(), "locate".to_string()]);
300 self.add_synonym(
301 "person",
302 vec!["individual".to_string(), "people".to_string()],
303 );
304 self.add_synonym(
305 "company",
306 vec![
307 "organization".to_string(),
308 "business".to_string(),
309 "firm".to_string(),
310 ],
311 );
312 self.add_synonym("show", vec!["display".to_string(), "present".to_string()]);
313 self.add_synonym("get", vec!["retrieve".to_string(), "fetch".to_string()]);
314 self.add_synonym(
315 "large",
316 vec![
317 "big".to_string(),
318 "huge".to_string(),
319 "significant".to_string(),
320 ],
321 );
322 self.add_synonym("small", vec!["tiny".to_string(), "minor".to_string()]);
323 self.add_synonym(
324 "important",
325 vec![
326 "significant".to_string(),
327 "critical".to_string(),
328 "key".to_string(),
329 ],
330 );
331 }
332
333 fn load_default_templates(&mut self) {
335 self.add_template(QueryTemplate {
336 pattern: "who is".to_string(),
337 query_type: QueryType::EntityLookup,
338 rewrite: "entity:".to_string(),
339 });
340
341 self.add_template(QueryTemplate {
342 pattern: "what is".to_string(),
343 query_type: QueryType::EntityLookup,
344 rewrite: "define:".to_string(),
345 });
346
347 self.add_template(QueryTemplate {
348 pattern: "how many".to_string(),
349 query_type: QueryType::Aggregation,
350 rewrite: "count:".to_string(),
351 });
352
353 self.add_template(QueryTemplate {
354 pattern: "compare".to_string(),
355 query_type: QueryType::Comparison,
356 rewrite: "compare:".to_string(),
357 });
358 }
359
360 fn load_default_stop_words(&mut self) {
362 let stop_words = vec![
363 "a", "an", "and", "are", "as", "at", "be", "by", "for", "from", "has", "he", "in",
364 "is", "it", "its", "of", "on", "that", "the", "to", "was", "will", "with",
365 ];
366
367 for word in stop_words {
368 self.stop_words.insert(word.to_string());
369 }
370 }
371}
372
373impl Default for QueryIntelligence {
374 fn default() -> Self {
375 Self::new()
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Each sample query must be classified as the paired query type.
    #[test]
    fn test_query_type_detection() {
        let engine = QueryIntelligence::new();

        let cases = [
            ("who is the CEO of OpenAI?", QueryType::EntityLookup),
            ("how many employees work at Google?", QueryType::Aggregation),
            (
                "what is the relationship between Apple and Microsoft?",
                QueryType::Relationship,
            ),
        ];

        for (query, expected) in cases.iter() {
            let result = engine.rewrite_query(query);
            assert_eq!(&result.query_type, expected);
        }
    }

    /// At least one default synonym of "find" or "large" must show up in the
    /// extracted terms.
    #[test]
    fn test_synonym_expansion() {
        let engine = QueryIntelligence::new();

        let terms = engine.rewrite_query("find large companies").expanded_terms;

        let has_find_synonym = terms.contains(&"search".to_string());
        let has_large_synonym = terms.contains(&"big".to_string());
        assert!(has_find_synonym || has_large_synonym);
    }

    /// Stop words must not survive key-term extraction.
    #[test]
    fn test_stop_word_removal() {
        let engine = QueryIntelligence::new();

        let terms = engine
            .rewrite_query("what is the best approach")
            .expanded_terms;

        for stop_word in ["the", "is"].iter() {
            assert!(!terms.contains(&stop_word.to_string()));
        }
    }

    /// Two rounds of high feedback must lift a term above 0.7.
    #[test]
    fn test_relevance_feedback() {
        let mut engine = QueryIntelligence::new();

        for score in [0.9_f32, 0.8_f32].iter() {
            engine.record_feedback("artificial_intelligence", *score);
        }

        assert!(engine.get_relevance("artificial_intelligence") > 0.7);
    }

    /// Synonyms registered at runtime must appear in the rewritten query,
    /// regardless of the case used when registering them.
    #[test]
    fn test_custom_synonyms() {
        let mut engine = QueryIntelligence::new();
        let ai_synonyms = vec![
            "artificial intelligence".to_string(),
            "machine learning".to_string(),
        ];
        engine.add_synonym("AI", ai_synonyms);

        let rewritten = engine.rewrite_query("AI applications").rewritten;

        assert!(rewritten.contains("artificial") || rewritten.contains("machine"));
    }
}