1use crate::types::{EdgeType, MemoryType};
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum CommandType {
14 Create,
16 Search,
18 Update,
20 Delete,
22 Link,
24 List,
26 Stats,
28 Help,
30 Unknown,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct ParsedCommand {
37 pub command_type: CommandType,
39 pub content: Option<String>,
41 pub target_id: Option<i64>,
43 pub memory_type: Option<MemoryType>,
45 pub tags: Vec<String>,
47 pub edge_type: Option<EdgeType>,
49 pub related_id: Option<i64>,
51 pub date_filter: Option<DateFilter>,
53 pub limit: Option<i64>,
55 pub original_input: String,
57 pub confidence: f32,
59 pub params: HashMap<String, String>,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct DateFilter {
66 pub after: Option<DateTime<Utc>>,
67 pub before: Option<DateTime<Utc>>,
68}
69
/// Keyword-driven parser that turns free-form text into a [`ParsedCommand`].
///
/// Each field holds the lowercase phrases that select one [`CommandType`].
pub struct NaturalLanguageParser {
    /// Phrases that select [`CommandType::Create`].
    create_keywords: Vec<&'static str>,
    /// Phrases that select [`CommandType::Search`].
    search_keywords: Vec<&'static str>,
    /// Phrases that select [`CommandType::Delete`].
    delete_keywords: Vec<&'static str>,
    /// Phrases that select [`CommandType::Link`].
    link_keywords: Vec<&'static str>,
    /// Phrases that select [`CommandType::List`].
    list_keywords: Vec<&'static str>,
}
83
84impl Default for NaturalLanguageParser {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl NaturalLanguageParser {
91 pub fn new() -> Self {
93 Self {
94 create_keywords: vec![
95 "remember",
96 "save",
97 "store",
98 "create",
99 "add",
100 "note",
101 "record",
102 "keep",
103 "memorize",
104 "write down",
105 "jot down",
106 "make a note",
107 ],
108 search_keywords: vec![
109 "find", "search", "look for", "what", "where", "when", "show me", "get",
110 "retrieve", "recall", "fetch", "query", "lookup",
111 ],
112 delete_keywords: vec!["delete", "remove", "forget", "erase", "discard", "drop"],
113 link_keywords: vec!["link", "connect", "relate", "associate", "reference"],
114 list_keywords: vec!["list", "show all", "display", "enumerate", "browse"],
115 }
116 }
117
118 pub fn parse(&self, input: &str) -> ParsedCommand {
120 let input_lower = input.to_lowercase();
121 let input_trimmed = input.trim();
122
123 let (command_type, confidence) = self.detect_command_type(&input_lower);
125
126 let content = self.extract_content(input_trimmed, &command_type);
128
129 let tags = self.extract_tags(&input_lower);
131
132 let memory_type = self.extract_memory_type(&input_lower);
134
135 let (target_id, related_id) = self.extract_ids(&input_lower);
137
138 let edge_type = self.extract_edge_type(&input_lower);
140
141 let date_filter = self.extract_date_filter(&input_lower);
143
144 let limit = self.extract_limit(&input_lower);
146
147 ParsedCommand {
148 command_type,
149 content,
150 target_id,
151 memory_type,
152 tags,
153 edge_type,
154 related_id,
155 date_filter,
156 limit,
157 original_input: input.to_string(),
158 confidence,
159 params: HashMap::new(),
160 }
161 }
162
163 fn detect_command_type(&self, input: &str) -> (CommandType, f32) {
165 for keyword in &self.create_keywords {
167 if input.contains(keyword) {
168 return (CommandType::Create, 0.9);
169 }
170 }
171
172 for keyword in &self.search_keywords {
174 if input.contains(keyword) {
175 return (CommandType::Search, 0.85);
176 }
177 }
178
179 for keyword in &self.delete_keywords {
181 if input.contains(keyword) {
182 return (CommandType::Delete, 0.9);
183 }
184 }
185
186 for keyword in &self.link_keywords {
188 if input.contains(keyword) {
189 return (CommandType::Link, 0.85);
190 }
191 }
192
193 for keyword in &self.list_keywords {
195 if input.contains(keyword) {
196 return (CommandType::List, 0.85);
197 }
198 }
199
200 if input.contains("stat") || input.contains("count") || input.contains("how many") {
202 return (CommandType::Stats, 0.8);
203 }
204
205 if input.contains("help") || input.contains("how to") || input.contains("usage") {
207 return (CommandType::Help, 0.9);
208 }
209
210 if input.ends_with('?') || input.starts_with("what") || input.starts_with("how") {
212 return (CommandType::Search, 0.6);
213 }
214
215 (CommandType::Unknown, 0.3)
217 }
218
219 fn extract_content(&self, input: &str, command_type: &CommandType) -> Option<String> {
221 let patterns_to_remove: &[&str] = match command_type {
223 CommandType::Create => &[
224 "remember that",
225 "remember:",
226 "save:",
227 "note:",
228 "add:",
229 "create:",
230 "remember",
231 "save",
232 "note",
233 "add",
234 "create",
235 "please",
236 "can you",
237 ],
238 CommandType::Search => &[
239 "find",
240 "search for",
241 "search",
242 "look for",
243 "show me",
244 "get",
245 "what is",
246 "what are",
247 "where is",
248 "when did",
249 "please",
250 "can you",
251 ],
252 CommandType::Delete => &["delete", "remove", "forget", "erase", "please", "can you"],
253 _ => &["please", "can you"],
254 };
255
256 let mut content = input.to_string();
257 for pattern in patterns_to_remove {
258 content = content.replace(pattern, "");
259 let capitalized = pattern
261 .chars()
262 .next()
263 .map(|c| c.to_uppercase().to_string() + &pattern[1..])
264 .unwrap_or_default();
265 content = content.replace(&capitalized, "");
266 }
267
268 let content = content.trim().to_string();
269 if content.is_empty() {
270 None
271 } else {
272 Some(content)
273 }
274 }
275
276 fn extract_tags(&self, input: &str) -> Vec<String> {
278 let mut tags = Vec::new();
279
280 for word in input.split_whitespace() {
282 if word.starts_with('#') {
283 let tag = word
284 .trim_start_matches('#')
285 .trim_matches(|c: char| !c.is_alphanumeric());
286 if !tag.is_empty() {
287 tags.push(tag.to_string());
288 }
289 }
290 }
291
292 if let Some(pos) = input.find("tag:") {
294 let rest = &input[pos + 4..];
295 for word in rest.split_whitespace() {
296 if word.chars().all(|c| c.is_alphanumeric() || c == ',') {
297 for tag in word.split(',') {
298 let tag = tag.trim();
299 if !tag.is_empty() {
300 tags.push(tag.to_string());
301 }
302 }
303 break;
304 }
305 }
306 }
307
308 tags
309 }
310
311 fn extract_memory_type(&self, input: &str) -> Option<MemoryType> {
313 if input.contains("todo") || input.contains("task") {
314 Some(MemoryType::Todo)
315 } else if input.contains("decision") || input.contains("decided") {
316 Some(MemoryType::Decision)
317 } else if input.contains("issue") || input.contains("bug") || input.contains("problem") {
318 Some(MemoryType::Issue)
319 } else if input.contains("preference") || input.contains("prefer") {
320 Some(MemoryType::Preference)
321 } else if input.contains("learn") || input.contains("til") {
322 Some(MemoryType::Learning)
323 } else if input.contains("context") || input.contains("background") {
324 Some(MemoryType::Context)
325 } else {
326 None
327 }
328 }
329
330 fn extract_ids(&self, input: &str) -> (Option<i64>, Option<i64>) {
332 let mut ids: Vec<i64> = Vec::new();
333
334 let patterns = ["memory ", "id ", "id:", "#"];
336
337 for pattern in patterns {
338 if let Some(pos) = input.find(pattern) {
339 let rest = &input[pos + pattern.len()..];
340 let num_str: String = rest.chars().take_while(|c| c.is_ascii_digit()).collect();
341 if let Ok(id) = num_str.parse::<i64>() {
342 ids.push(id);
343 }
344 }
345 }
346
347 for word in input.split_whitespace() {
349 if let Ok(id) = word.parse::<i64>() {
350 if id > 0 && !ids.contains(&id) {
351 ids.push(id);
352 }
353 }
354 }
355
356 match ids.len() {
357 0 => (None, None),
358 1 => (Some(ids[0]), None),
359 _ => (Some(ids[0]), Some(ids[1])),
360 }
361 }
362
363 fn extract_edge_type(&self, input: &str) -> Option<EdgeType> {
365 if input.contains("supersede") || input.contains("replace") {
366 Some(EdgeType::Supersedes)
367 } else if input.contains("contradict") || input.contains("conflict") {
368 Some(EdgeType::Contradicts)
369 } else if input.contains("implement") {
370 Some(EdgeType::Implements)
371 } else if input.contains("extend") {
372 Some(EdgeType::Extends)
373 } else if input.contains("reference") || input.contains("refer") {
374 Some(EdgeType::References)
375 } else if input.contains("depend") || input.contains("require") {
376 Some(EdgeType::DependsOn)
377 } else if input.contains("block") {
378 Some(EdgeType::Blocks)
379 } else if input.contains("follow") {
380 Some(EdgeType::FollowsUp)
381 } else if input.contains("relate") || input.contains("link") {
382 Some(EdgeType::RelatedTo)
383 } else {
384 None
385 }
386 }
387
388 fn extract_date_filter(&self, input: &str) -> Option<DateFilter> {
390 let mut after = None;
391 let mut before = None;
392
393 if input.contains("last") {
395 if let Some(days) = self.extract_duration_days(input) {
396 after = Some(Utc::now() - chrono::Duration::days(days));
397 }
398 }
399
400 if input.contains("today") {
402 let today = Utc::now().date_naive();
403 after = Some(today.and_hms_opt(0, 0, 0).unwrap().and_utc());
404 } else if input.contains("yesterday") {
405 let yesterday = Utc::now().date_naive() - chrono::Duration::days(1);
406 after = Some(yesterday.and_hms_opt(0, 0, 0).unwrap().and_utc());
407 before = Some(
408 Utc::now()
409 .date_naive()
410 .and_hms_opt(0, 0, 0)
411 .unwrap()
412 .and_utc(),
413 );
414 } else if input.contains("this week") {
415 after = Some(Utc::now() - chrono::Duration::days(7));
416 } else if input.contains("this month") {
417 after = Some(Utc::now() - chrono::Duration::days(30));
418 }
419
420 if after.is_some() || before.is_some() {
421 Some(DateFilter { after, before })
422 } else {
423 None
424 }
425 }
426
427 fn extract_duration_days(&self, input: &str) -> Option<i64> {
429 for word in input.split_whitespace() {
431 if let Ok(num) = word.parse::<i64>() {
432 if input.contains("day") {
433 return Some(num);
434 } else if input.contains("week") {
435 return Some(num * 7);
436 } else if input.contains("month") {
437 return Some(num * 30);
438 }
439 }
440 }
441
442 if input.contains("last week") {
444 Some(7)
445 } else if input.contains("last month") {
446 Some(30)
447 } else if input.contains("last year") {
448 Some(365)
449 } else {
450 None
451 }
452 }
453
454 fn extract_limit(&self, input: &str) -> Option<i64> {
456 let patterns = ["top ", "first ", "limit "];
458
459 for pattern in patterns {
460 if let Some(pos) = input.find(pattern) {
461 let rest = &input[pos + pattern.len()..];
462 let num_str: String = rest.chars().take_while(|c| c.is_ascii_digit()).collect();
463 if let Ok(limit) = num_str.parse::<i64>() {
464 return Some(limit);
465 }
466 }
467 }
468
469 None
470 }
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` with a freshly constructed parser.
    fn parse(input: &str) -> ParsedCommand {
        NaturalLanguageParser::new().parse(input)
    }

    #[test]
    fn test_detect_create() {
        let parsed = parse("Remember that the API key is abc123");

        assert_eq!(parsed.command_type, CommandType::Create);
        assert!(parsed.content.is_some());
        assert!(parsed.confidence > 0.8);
    }

    #[test]
    fn test_detect_search() {
        let parsed = parse("Find all memories about authentication");

        assert_eq!(parsed.command_type, CommandType::Search);
        let content = parsed.content.expect("search content should be present");
        assert!(content.contains("authentication"));
    }

    #[test]
    fn test_extract_tags() {
        let parsed = parse("Save this note #important #work");

        for expected in ["important", "work"] {
            assert!(parsed.tags.iter().any(|tag| tag == expected));
        }
    }

    #[test]
    fn test_extract_memory_type() {
        let todo = parse("Add a todo: fix the bug");
        assert_eq!(todo.memory_type, Some(MemoryType::Todo));

        let decision = parse("Record this decision: use JWT");
        assert_eq!(decision.memory_type, Some(MemoryType::Decision));
    }

    #[test]
    fn test_extract_ids() {
        let parsed = parse("Link memory 123 to memory 456");

        assert_eq!(parsed.target_id, Some(123));
        assert_eq!(parsed.related_id, Some(456));
    }

    #[test]
    fn test_extract_date_filter() {
        let parsed = parse("Find memories from last week");

        let filter = parsed.date_filter.expect("a date filter should be derived");
        assert!(filter.after.is_some());
    }

    #[test]
    fn test_extract_limit() {
        let parsed = parse("Show top 10 recent memories");

        assert_eq!(parsed.limit, Some(10));
    }

    #[test]
    fn test_question_as_search() {
        let parsed = parse("What is the database password?");

        assert_eq!(parsed.command_type, CommandType::Search);
    }
}