1use regex::Regex;
11use std::collections::HashMap;
12
/// The kind of question a query is asking, as decided by [`IntentClassifier::classify`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum QueryIntent {
    /// The query uses time-related wording such as "yesterday", "last week" or "3 days ago".
    Temporal,
    /// The query asks about causes or effects, e.g. "why ...", "what caused ...".
    Causal,
    /// The query is centred on a named person or thing, e.g. "about Alice".
    Entity,
    /// Fallback intent: it carries a fixed baseline score and wins when nothing else scores higher.
    Factual,
}
25
26#[derive(Debug, Clone)]
28pub struct IntentClassification {
29 pub intent: QueryIntent,
31 pub confidence: f32,
33 pub secondary: Vec<(QueryIntent, f32)>,
35 pub entity_focus: Option<String>,
37}
38
39pub struct IntentClassifier {
44 temporal_patterns: Vec<Regex>,
45 causal_patterns: Vec<Regex>,
46 entity_patterns: Vec<Regex>,
47 entity_list_patterns: Vec<Regex>,
49}
50
51impl IntentClassifier {
52 pub fn new() -> Self {
57 Self {
58 temporal_patterns: vec![
59 Regex::new(r"(?i)\b(yesterday|today|tomorrow)\b").unwrap(),
61 Regex::new(r"(?i)\b(recent|recently|latest|newest|oldest|earlier)\b").unwrap(),
62 Regex::new(r"(?i)\b(when|since|until|before|after)\b").unwrap(),
63
64 Regex::new(r"(?i)\b(last\s+week|next\s+week|this\s+week)\b").unwrap(),
66 Regex::new(r"(?i)\b(last\s+month|next\s+month|this\s+month)\b").unwrap(),
67 Regex::new(r"(?i)\b(last\s+year|next\s+year|this\s+year)\b").unwrap(),
68
69 Regex::new(r"(?i)\b(january|february|march|april|may|june|july|august|september|october|november|december)\b").unwrap(),
71
72 Regex::new(r"(?i)\b(monday|tuesday|wednesday|thursday|friday|saturday|sunday)\b").unwrap(),
74
75 Regex::new(r"(?i)\b(this\s+morning|this\s+afternoon|this\s+evening|tonight)\b").unwrap(),
77 Regex::new(r"(?i)\b(last\s+night|earlier\s+today)\b").unwrap(),
78
79 Regex::new(r"(?i)\b(\d+\s+(hours?|days?|weeks?|months?|years?)\s+ago)\b").unwrap(),
81
82 Regex::new(r"(?i)\b(past\s+(few|couple|several))\b").unwrap(),
84
85 Regex::new(r"(?i)^(show|list|get)\s+(recent|latest|newest)").unwrap(),
87 ],
88 causal_patterns: vec![
89 Regex::new(r"(?i)\b(why|because|cause[ds]?|reason|led\s+to|result\s+in)\b").unwrap(),
91 Regex::new(r"(?i)\b(what\s+caused|what\s+led\s+to|what.*resulted\s+in)\b").unwrap(),
92 Regex::new(r"(?i)\b(consequences?|impacts?|effects?|outcomes?)\b").unwrap(),
93 Regex::new(r"(?i)^why\s+").unwrap(),
94 ],
95 entity_patterns: vec![
96 Regex::new(r"(?i)\b(about|regarding|concerning|related\s+to)\s+[A-Z]").unwrap(),
98 Regex::new(r"(?i)\b(with|involving|mention|mentioning)\s+[A-Z]").unwrap(),
99 Regex::new(r"\b[A-Z][a-z]+\b").unwrap(), ],
101 entity_list_patterns: vec![
102 Regex::new(r"(?i)^what\s+does\s+(\w+)\s+(like|enjoy|prefer|want|need|love|hate|dislike)").unwrap(),
104 Regex::new(r"(?i)^what\s+(are|were)\s+(\w+)'?s\s+(hobbies|interests|activities|preferences|habits|routines)").unwrap(),
106 Regex::new(r"(?i)^(list|show|get|find)\s+(all|everything)\s+(about|for|regarding)\s+(\w+)").unwrap(),
108 Regex::new(r"(?i)^tell\s+me\s+(everything\s+)?(about|regarding)\s+(\w+)").unwrap(),
110 Regex::new(r"(?i)^what\s+do\s+(i|we)\s+know\s+about\s+(\w+)").unwrap(),
112 Regex::new(r"(?i)^(\w+)'?s\s+(hobbies|interests|activities|preferences|habits)").unwrap(),
114 ],
115 }
116 }
117
118 pub fn classify(&self, query: &str) -> IntentClassification {
138 let mut scores: HashMap<QueryIntent, f32> = HashMap::new();
139 scores.insert(QueryIntent::Temporal, 0.0);
140 scores.insert(QueryIntent::Causal, 0.0);
141 scores.insert(QueryIntent::Entity, 0.0);
142 scores.insert(QueryIntent::Factual, 0.3); let temporal_matches = self.count_matches(&self.temporal_patterns, query);
146 let causal_matches = self.count_matches(&self.causal_patterns, query);
147 let entity_matches = self.count_matches(&self.entity_patterns, query);
148 let entity_list_matches = self.count_matches(&self.entity_list_patterns, query);
149
150 if temporal_matches > 0 {
152 *scores.get_mut(&QueryIntent::Temporal).unwrap() =
153 (temporal_matches as f32 * 0.4).min(1.0);
154 }
155
156 if causal_matches > 0 {
157 *scores.get_mut(&QueryIntent::Causal).unwrap() = (causal_matches as f32 * 0.5).min(1.0);
158 }
159
160 if entity_matches > 0 {
161 *scores.get_mut(&QueryIntent::Entity).unwrap() = (entity_matches as f32 * 0.2).min(0.8);
163 }
164
165 if entity_list_matches > 0 {
167 let current_score = *scores.get(&QueryIntent::Entity).unwrap();
169 *scores.get_mut(&QueryIntent::Entity).unwrap() =
170 (current_score + entity_list_matches as f32 * 0.6).min(1.0);
171 }
172
173 let mut intent_vec: Vec<_> = scores.iter().collect();
175 intent_vec.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
176
177 let primary_intent = *intent_vec[0].0;
178 let primary_confidence = *intent_vec[0].1;
179
180 let mut secondary = Vec::new();
182 for (intent, score) in intent_vec.iter().skip(1) {
183 if **score > 0.3 {
184 secondary.push((**intent, **score));
185 }
186 }
187
188 let entity_focus = if primary_intent == QueryIntent::Entity {
190 self.extract_entity_from_query(query)
191 } else {
192 None
193 };
194
195 IntentClassification {
196 intent: primary_intent,
197 confidence: primary_confidence,
198 secondary,
199 entity_focus,
200 }
201 }
202
203 fn count_matches(&self, patterns: &[Regex], query: &str) -> usize {
205 patterns.iter().filter(|p| p.is_match(query)).count()
206 }
207
208 pub fn extract_entity_from_query(&self, query: &str) -> Option<String> {
231 for pattern in &self.entity_list_patterns {
232 if let Some(captures) = pattern.captures(query) {
233 for i in 1..captures.len() {
236 if let Some(capture) = captures.get(i) {
237 let text = capture.as_str().trim();
238 if !text.is_empty() && !Self::is_common_word(text) && text.len() > 1 {
240 let mut chars = text.chars();
242 if let Some(first) = chars.next() {
243 let capitalized =
244 first.to_uppercase().collect::<String>() + chars.as_str();
245 return Some(capitalized);
246 }
247 }
248 }
249 }
250 }
251 }
252 None
253 }
254
255 fn is_common_word(word: &str) -> bool {
257 const COMMON_WORDS: &[&str] = &[
258 "what",
259 "does",
260 "do",
261 "are",
262 "were",
263 "is",
264 "was",
265 "the",
266 "a",
267 "an",
268 "like",
269 "enjoy",
270 "prefer",
271 "want",
272 "need",
273 "love",
274 "hate",
275 "dislike",
276 "all",
277 "everything",
278 "about",
279 "for",
280 "regarding",
281 "list",
282 "show",
283 "get",
284 "find",
285 "tell",
286 "me",
287 "i",
288 "we",
289 "know",
290 "hobbies",
291 "interests",
292 "activities",
293 "preferences",
294 "habits",
295 "routines",
296 "everything ",
297 ];
298 COMMON_WORDS.contains(&word.to_lowercase().as_str())
299 }
300}
301
302impl Default for IntentClassifier {
303 fn default() -> Self {
304 Self::new()
305 }
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Time-related wording makes Temporal the primary intent.
    #[test]
    fn test_temporal_intent() {
        let classifier = IntentClassifier::new();

        let result = classifier.classify("What happened yesterday?");
        assert_eq!(result.intent, QueryIntent::Temporal);
        assert!(result.confidence > 0.3);

        let result = classifier.classify("Show me recent memories");
        assert_eq!(result.intent, QueryIntent::Temporal);

        let result = classifier.classify("What did I do last week?");
        assert_eq!(result.intent, QueryIntent::Temporal);
    }

    /// Cause/effect wording makes Causal the primary intent.
    #[test]
    fn test_causal_intent() {
        let classifier = IntentClassifier::new();

        let result = classifier.classify("Why was the meeting cancelled?");
        assert_eq!(result.intent, QueryIntent::Causal);
        assert!(result.confidence > 0.4);

        let result = classifier.classify("What caused the server crash?");
        assert_eq!(result.intent, QueryIntent::Causal);

        let result = classifier.classify("What led to the project delay?");
        assert_eq!(result.intent, QueryIntent::Causal);
    }

    /// Queries about a capitalised name are classified as Entity.
    #[test]
    fn test_entity_intent() {
        let classifier = IntentClassifier::new();

        let result = classifier.classify("Show me memories about Alice");
        assert_eq!(result.intent, QueryIntent::Entity);
        assert!(result.confidence > 0.1);

        let result = classifier.classify("What do I know about Project Alpha?");
        assert_eq!(result.intent, QueryIntent::Entity);
    }

    /// Plain lowercase keyword queries fall back to Factual.
    #[test]
    fn test_factual_intent() {
        let classifier = IntentClassifier::new();

        let result = classifier.classify("machine learning techniques");
        assert_eq!(result.intent, QueryIntent::Factual);

        let result = classifier.classify("database optimization");
        assert_eq!(result.intent, QueryIntent::Factual);
    }

    /// Queries carrying several signals still resolve to the strongest one.
    #[test]
    fn test_mixed_intent() {
        let classifier = IntentClassifier::new();

        // Entity and temporal signals together: either a runner-up is
        // reported or Temporal wins outright.
        let result = classifier.classify("What did Alice do yesterday?");
        assert!(!result.secondary.is_empty() || result.intent == QueryIntent::Temporal);

        // Two causal hits outweigh the single temporal hit.
        let result = classifier.classify("Why did the meeting get cancelled last week?");
        assert_eq!(result.intent, QueryIntent::Causal);
    }

    /// More matching patterns produce a higher confidence.
    #[test]
    fn test_confidence_scores() {
        let classifier = IntentClassifier::new();

        let result = classifier.classify("Why why why");
        assert!(result.confidence > 0.5);

        let result = classifier.classify("yesterday recent latest");
        assert_eq!(result.intent, QueryIntent::Temporal);
        assert!(result.confidence > 0.4);
    }

    /// "What does X <verb>" yields X, capitalised.
    #[test]
    fn test_entity_extraction_what_does_pattern() {
        let classifier = IntentClassifier::new();

        let entity = classifier.extract_entity_from_query("What does Alice like?");
        assert_eq!(entity, Some("Alice".to_string()));

        let entity = classifier.extract_entity_from_query("what does bob enjoy doing");
        assert_eq!(entity, Some("Bob".to_string()));

        let entity = classifier.extract_entity_from_query("What does Charlie prefer?");
        assert_eq!(entity, Some("Charlie".to_string()));
    }

    /// Possessive forms ("X's hobbies") yield X.
    #[test]
    fn test_entity_extraction_possessive_pattern() {
        let classifier = IntentClassifier::new();

        let entity = classifier.extract_entity_from_query("What are Alice's hobbies?");
        assert_eq!(entity, Some("Alice".to_string()));

        let entity = classifier.extract_entity_from_query("what were bob's interests");
        assert_eq!(entity, Some("Bob".to_string()));

        let entity = classifier.extract_entity_from_query("Charlie's activities");
        assert_eq!(entity, Some("Charlie".to_string()));
    }

    /// "List/show/tell me everything about X" yields X.
    #[test]
    fn test_entity_extraction_list_pattern() {
        let classifier = IntentClassifier::new();

        let entity = classifier.extract_entity_from_query("List all about Alice");
        assert_eq!(entity, Some("Alice".to_string()));

        let entity = classifier.extract_entity_from_query("show everything about Project");
        assert_eq!(entity, Some("Project".to_string()));

        let entity = classifier.extract_entity_from_query("Tell me everything about Bob");
        assert_eq!(entity, Some("Bob".to_string()));
    }

    /// "What do I/we know about X" yields X.
    #[test]
    fn test_entity_extraction_know_pattern() {
        let classifier = IntentClassifier::new();

        let entity = classifier.extract_entity_from_query("What do I know about Alice?");
        assert_eq!(entity, Some("Alice".to_string()));

        let entity = classifier.extract_entity_from_query("What do we know about system");
        assert_eq!(entity, Some("System".to_string()));
    }

    /// Queries that match no entity-list pattern yield no entity.
    #[test]
    fn test_entity_extraction_no_match() {
        let classifier = IntentClassifier::new();

        let entity = classifier.extract_entity_from_query("What happened yesterday?");
        assert_eq!(entity, None);

        let entity = classifier.extract_entity_from_query("Why was it cancelled?");
        assert_eq!(entity, None);

        let entity = classifier.extract_entity_from_query("machine learning techniques");
        assert_eq!(entity, None);
    }

    /// `entity_focus` is set for Entity results and left empty otherwise.
    #[test]
    fn test_entity_focus_in_classification() {
        let classifier = IntentClassifier::new();

        let result = classifier.classify("What does Alice like?");
        assert_eq!(result.intent, QueryIntent::Entity);
        assert_eq!(result.entity_focus, Some("Alice".to_string()));

        let result = classifier.classify("What happened yesterday?");
        assert_eq!(result.intent, QueryIntent::Temporal);
        assert_eq!(result.entity_focus, None);
    }
}