Skip to main content

engram/intelligence/
natural_language.rs

//! Natural Language Commands (RML-893)
//!
//! Parses natural language input into structured commands.
4
5use crate::types::{EdgeType, MemoryType};
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Types of commands that can be parsed
11#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum CommandType {
14    /// Create a new memory
15    Create,
16    /// Search for memories
17    Search,
18    /// Update a memory
19    Update,
20    /// Delete a memory
21    Delete,
22    /// Link two memories
23    Link,
24    /// List memories
25    List,
26    /// Show statistics
27    Stats,
28    /// Get help
29    Help,
30    /// Unknown command
31    Unknown,
32}
33
34/// A parsed command with extracted parameters
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct ParsedCommand {
37    /// Type of command
38    pub command_type: CommandType,
39    /// Main content or query
40    pub content: Option<String>,
41    /// Target memory ID (for update/delete)
42    pub target_id: Option<i64>,
43    /// Memory type (for create/search)
44    pub memory_type: Option<MemoryType>,
45    /// Tags extracted
46    pub tags: Vec<String>,
47    /// Edge type (for link)
48    pub edge_type: Option<EdgeType>,
49    /// Related memory ID (for link)
50    pub related_id: Option<i64>,
51    /// Date/time filter
52    pub date_filter: Option<DateFilter>,
53    /// Limit for results
54    pub limit: Option<i64>,
55    /// Original input
56    pub original_input: String,
57    /// Confidence in parsing (0.0 - 1.0)
58    pub confidence: f32,
59    /// Additional parameters
60    pub params: HashMap<String, String>,
61}
62
63/// Date filter for queries
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct DateFilter {
66    pub after: Option<DateTime<Utc>>,
67    pub before: Option<DateTime<Utc>>,
68}
69
/// Keyword-driven parser that turns free-form text into a [`ParsedCommand`].
///
/// Each field holds the lowercase trigger phrases for one command intent.
pub struct NaturalLanguageParser {
    /// Phrases signalling that a memory should be created.
    create_keywords: Vec<&'static str>,
    /// Phrases signalling a search.
    search_keywords: Vec<&'static str>,
    /// Phrases signalling a deletion.
    delete_keywords: Vec<&'static str>,
    /// Phrases signalling that two memories should be linked.
    link_keywords: Vec<&'static str>,
    /// Phrases signalling a listing.
    list_keywords: Vec<&'static str>,
}
83
84impl Default for NaturalLanguageParser {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
90impl NaturalLanguageParser {
91    /// Create a new parser
92    pub fn new() -> Self {
93        Self {
94            create_keywords: vec![
95                "remember",
96                "save",
97                "store",
98                "create",
99                "add",
100                "note",
101                "record",
102                "keep",
103                "memorize",
104                "write down",
105                "jot down",
106                "make a note",
107            ],
108            search_keywords: vec![
109                "find", "search", "look for", "what", "where", "when", "show me", "get",
110                "retrieve", "recall", "fetch", "query", "lookup",
111            ],
112            delete_keywords: vec!["delete", "remove", "forget", "erase", "discard", "drop"],
113            link_keywords: vec!["link", "connect", "relate", "associate", "reference"],
114            list_keywords: vec!["list", "show all", "display", "enumerate", "browse"],
115        }
116    }
117
118    /// Parse a natural language input into a command
119    pub fn parse(&self, input: &str) -> ParsedCommand {
120        let input_lower = input.to_lowercase();
121        let input_trimmed = input.trim();
122
123        // Detect command type
124        let (command_type, confidence) = self.detect_command_type(&input_lower);
125
126        // Extract content
127        let content = self.extract_content(input_trimmed, &command_type);
128
129        // Extract tags
130        let tags = self.extract_tags(&input_lower);
131
132        // Extract memory type
133        let memory_type = self.extract_memory_type(&input_lower);
134
135        // Extract IDs
136        let (target_id, related_id) = self.extract_ids(&input_lower);
137
138        // Extract edge type
139        let edge_type = self.extract_edge_type(&input_lower);
140
141        // Extract date filter
142        let date_filter = self.extract_date_filter(&input_lower);
143
144        // Extract limit
145        let limit = self.extract_limit(&input_lower);
146
147        ParsedCommand {
148            command_type,
149            content,
150            target_id,
151            memory_type,
152            tags,
153            edge_type,
154            related_id,
155            date_filter,
156            limit,
157            original_input: input.to_string(),
158            confidence,
159            params: HashMap::new(),
160        }
161    }
162
163    /// Detect the type of command from input
164    fn detect_command_type(&self, input: &str) -> (CommandType, f32) {
165        // Check for create intent
166        for keyword in &self.create_keywords {
167            if input.contains(keyword) {
168                return (CommandType::Create, 0.9);
169            }
170        }
171
172        // Check for search intent
173        for keyword in &self.search_keywords {
174            if input.contains(keyword) {
175                return (CommandType::Search, 0.85);
176            }
177        }
178
179        // Check for delete intent
180        for keyword in &self.delete_keywords {
181            if input.contains(keyword) {
182                return (CommandType::Delete, 0.9);
183            }
184        }
185
186        // Check for link intent
187        for keyword in &self.link_keywords {
188            if input.contains(keyword) {
189                return (CommandType::Link, 0.85);
190            }
191        }
192
193        // Check for list intent
194        for keyword in &self.list_keywords {
195            if input.contains(keyword) {
196                return (CommandType::List, 0.85);
197            }
198        }
199
200        // Check for stats
201        if input.contains("stat") || input.contains("count") || input.contains("how many") {
202            return (CommandType::Stats, 0.8);
203        }
204
205        // Check for help
206        if input.contains("help") || input.contains("how to") || input.contains("usage") {
207            return (CommandType::Help, 0.9);
208        }
209
210        // Default to search if it looks like a question
211        if input.ends_with('?') || input.starts_with("what") || input.starts_with("how") {
212            return (CommandType::Search, 0.6);
213        }
214
215        // Unknown
216        (CommandType::Unknown, 0.3)
217    }
218
219    /// Extract main content from input
220    fn extract_content(&self, input: &str, command_type: &CommandType) -> Option<String> {
221        // Remove command keywords to get content
222        let patterns_to_remove: &[&str] = match command_type {
223            CommandType::Create => &[
224                "remember that",
225                "remember:",
226                "save:",
227                "note:",
228                "add:",
229                "create:",
230                "remember",
231                "save",
232                "note",
233                "add",
234                "create",
235                "please",
236                "can you",
237            ],
238            CommandType::Search => &[
239                "find",
240                "search for",
241                "search",
242                "look for",
243                "show me",
244                "get",
245                "what is",
246                "what are",
247                "where is",
248                "when did",
249                "please",
250                "can you",
251            ],
252            CommandType::Delete => &["delete", "remove", "forget", "erase", "please", "can you"],
253            _ => &["please", "can you"],
254        };
255
256        let mut content = input.to_string();
257        for pattern in patterns_to_remove {
258            content = content.replace(pattern, "");
259            // Also try with capital first letter
260            let capitalized = pattern
261                .chars()
262                .next()
263                .map(|c| c.to_uppercase().to_string() + &pattern[1..])
264                .unwrap_or_default();
265            content = content.replace(&capitalized, "");
266        }
267
268        let content = content.trim().to_string();
269        if content.is_empty() {
270            None
271        } else {
272            Some(content)
273        }
274    }
275
276    /// Extract tags from input
277    fn extract_tags(&self, input: &str) -> Vec<String> {
278        let mut tags = Vec::new();
279
280        // Look for #hashtags
281        for word in input.split_whitespace() {
282            if word.starts_with('#') {
283                let tag = word
284                    .trim_start_matches('#')
285                    .trim_matches(|c: char| !c.is_alphanumeric());
286                if !tag.is_empty() {
287                    tags.push(tag.to_string());
288                }
289            }
290        }
291
292        // Look for "tag:" or "tags:" pattern
293        if let Some(pos) = input.find("tag:") {
294            let rest = &input[pos + 4..];
295            for word in rest.split_whitespace() {
296                if word.chars().all(|c| c.is_alphanumeric() || c == ',') {
297                    for tag in word.split(',') {
298                        let tag = tag.trim();
299                        if !tag.is_empty() {
300                            tags.push(tag.to_string());
301                        }
302                    }
303                    break;
304                }
305            }
306        }
307
308        tags
309    }
310
311    /// Extract memory type from input
312    fn extract_memory_type(&self, input: &str) -> Option<MemoryType> {
313        if input.contains("todo") || input.contains("task") {
314            Some(MemoryType::Todo)
315        } else if input.contains("decision") || input.contains("decided") {
316            Some(MemoryType::Decision)
317        } else if input.contains("issue") || input.contains("bug") || input.contains("problem") {
318            Some(MemoryType::Issue)
319        } else if input.contains("preference") || input.contains("prefer") {
320            Some(MemoryType::Preference)
321        } else if input.contains("learn") || input.contains("til") {
322            Some(MemoryType::Learning)
323        } else if input.contains("context") || input.contains("background") {
324            Some(MemoryType::Context)
325        } else {
326            None
327        }
328    }
329
330    /// Extract memory IDs from input
331    fn extract_ids(&self, input: &str) -> (Option<i64>, Option<i64>) {
332        let mut ids: Vec<i64> = Vec::new();
333
334        // Look for patterns like "memory 123", "#123", "id 123", or "id:123"
335        let patterns = ["memory ", "id ", "id:", "#"];
336
337        for pattern in patterns {
338            if let Some(pos) = input.find(pattern) {
339                let rest = &input[pos + pattern.len()..];
340                let num_str: String = rest.chars().take_while(|c| c.is_ascii_digit()).collect();
341                if let Ok(id) = num_str.parse::<i64>() {
342                    ids.push(id);
343                }
344            }
345        }
346
347        // Also look for standalone numbers that might be IDs
348        for word in input.split_whitespace() {
349            if let Ok(id) = word.parse::<i64>() {
350                if id > 0 && !ids.contains(&id) {
351                    ids.push(id);
352                }
353            }
354        }
355
356        match ids.len() {
357            0 => (None, None),
358            1 => (Some(ids[0]), None),
359            _ => (Some(ids[0]), Some(ids[1])),
360        }
361    }
362
363    /// Extract edge type from input
364    fn extract_edge_type(&self, input: &str) -> Option<EdgeType> {
365        if input.contains("supersede") || input.contains("replace") {
366            Some(EdgeType::Supersedes)
367        } else if input.contains("contradict") || input.contains("conflict") {
368            Some(EdgeType::Contradicts)
369        } else if input.contains("implement") {
370            Some(EdgeType::Implements)
371        } else if input.contains("extend") {
372            Some(EdgeType::Extends)
373        } else if input.contains("reference") || input.contains("refer") {
374            Some(EdgeType::References)
375        } else if input.contains("depend") || input.contains("require") {
376            Some(EdgeType::DependsOn)
377        } else if input.contains("block") {
378            Some(EdgeType::Blocks)
379        } else if input.contains("follow") {
380            Some(EdgeType::FollowsUp)
381        } else if input.contains("relate") || input.contains("link") {
382            Some(EdgeType::RelatedTo)
383        } else {
384            None
385        }
386    }
387
388    /// Extract date filter from input
389    fn extract_date_filter(&self, input: &str) -> Option<DateFilter> {
390        let mut after = None;
391        let mut before = None;
392
393        // Look for "last X days/weeks"
394        if input.contains("last") {
395            if let Some(days) = self.extract_duration_days(input) {
396                after = Some(Utc::now() - chrono::Duration::days(days));
397            }
398        }
399
400        // Look for "today", "yesterday", "this week"
401        if input.contains("today") {
402            let today = Utc::now().date_naive();
403            after = Some(today.and_hms_opt(0, 0, 0).unwrap().and_utc());
404        } else if input.contains("yesterday") {
405            let yesterday = Utc::now().date_naive() - chrono::Duration::days(1);
406            after = Some(yesterday.and_hms_opt(0, 0, 0).unwrap().and_utc());
407            before = Some(
408                Utc::now()
409                    .date_naive()
410                    .and_hms_opt(0, 0, 0)
411                    .unwrap()
412                    .and_utc(),
413            );
414        } else if input.contains("this week") {
415            after = Some(Utc::now() - chrono::Duration::days(7));
416        } else if input.contains("this month") {
417            after = Some(Utc::now() - chrono::Duration::days(30));
418        }
419
420        if after.is_some() || before.is_some() {
421            Some(DateFilter { after, before })
422        } else {
423            None
424        }
425    }
426
427    /// Extract duration in days from phrases like "last 7 days"
428    fn extract_duration_days(&self, input: &str) -> Option<i64> {
429        // Look for patterns like "last 7 days", "last week", "last month"
430        for word in input.split_whitespace() {
431            if let Ok(num) = word.parse::<i64>() {
432                if input.contains("day") {
433                    return Some(num);
434                } else if input.contains("week") {
435                    return Some(num * 7);
436                } else if input.contains("month") {
437                    return Some(num * 30);
438                }
439            }
440        }
441
442        // Handle special cases
443        if input.contains("last week") {
444            Some(7)
445        } else if input.contains("last month") {
446            Some(30)
447        } else if input.contains("last year") {
448            Some(365)
449        } else {
450            None
451        }
452    }
453
454    /// Extract result limit from input
455    fn extract_limit(&self, input: &str) -> Option<i64> {
456        // Look for patterns like "top 10", "first 5", "limit 20"
457        let patterns = ["top ", "first ", "limit "];
458
459        for pattern in patterns {
460            if let Some(pos) = input.find(pattern) {
461                let rest = &input[pos + pattern.len()..];
462                let num_str: String = rest.chars().take_while(|c| c.is_ascii_digit()).collect();
463                if let Ok(limit) = num_str.parse::<i64>() {
464                    return Some(limit);
465                }
466            }
467        }
468
469        None
470    }
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `input` through a freshly constructed parser.
    fn parse(input: &str) -> ParsedCommand {
        NaturalLanguageParser::new().parse(input)
    }

    #[test]
    fn test_detect_create() {
        let parsed = parse("Remember that the API key is abc123");

        assert_eq!(parsed.command_type, CommandType::Create);
        assert!(parsed.content.is_some());
        assert!(parsed.confidence > 0.8);
    }

    #[test]
    fn test_detect_search() {
        let parsed = parse("Find all memories about authentication");

        assert_eq!(parsed.command_type, CommandType::Search);
        let content = parsed.content.unwrap();
        assert!(content.contains("authentication"));
    }

    #[test]
    fn test_extract_tags() {
        let parsed = parse("Save this note #important #work");

        for expected in ["important", "work"] {
            assert!(parsed.tags.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_extract_memory_type() {
        let todo = parse("Add a todo: fix the bug");
        assert_eq!(todo.memory_type, Some(MemoryType::Todo));

        let decision = parse("Record this decision: use JWT");
        assert_eq!(decision.memory_type, Some(MemoryType::Decision));
    }

    #[test]
    fn test_extract_ids() {
        let parsed = parse("Link memory 123 to memory 456");

        assert_eq!(parsed.target_id, Some(123));
        assert_eq!(parsed.related_id, Some(456));
    }

    #[test]
    fn test_extract_date_filter() {
        let parsed = parse("Find memories from last week");

        assert!(parsed.date_filter.is_some());
        let filter = parsed.date_filter.unwrap();
        assert!(filter.after.is_some());
    }

    #[test]
    fn test_extract_limit() {
        let parsed = parse("Show top 10 recent memories");

        assert_eq!(parsed.limit, Some(10));
    }

    #[test]
    fn test_question_as_search() {
        let parsed = parse("What is the database password?");

        assert_eq!(parsed.command_type, CommandType::Search);
    }
}
549}