1use chrono::{DateTime, Utc};
15use once_cell::sync::Lazy;
16use regex::Regex;
17use serde::{Deserialize, Serialize};
18use std::collections::{HashMap, HashSet};
19
20use crate::types::MemoryId;
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
28#[serde(rename_all = "lowercase")]
29pub enum EntityType {
30 Person,
32 Organization,
34 Project,
36 Concept,
38 Location,
40 DateTime,
42 Reference,
44 Other,
46}
47
48impl EntityType {
49 pub fn as_str(&self) -> &'static str {
50 match self {
51 EntityType::Person => "person",
52 EntityType::Organization => "organization",
53 EntityType::Project => "project",
54 EntityType::Concept => "concept",
55 EntityType::Location => "location",
56 EntityType::DateTime => "datetime",
57 EntityType::Reference => "reference",
58 EntityType::Other => "other",
59 }
60 }
61}
62
63impl std::str::FromStr for EntityType {
64 type Err = String;
65
66 fn from_str(s: &str) -> Result<Self, Self::Err> {
67 match s.to_lowercase().as_str() {
68 "person" => Ok(EntityType::Person),
69 "organization" | "org" | "company" => Ok(EntityType::Organization),
70 "project" | "repo" | "repository" => Ok(EntityType::Project),
71 "concept" | "term" | "topic" => Ok(EntityType::Concept),
72 "location" | "place" | "geo" => Ok(EntityType::Location),
73 "datetime" | "date" | "time" => Ok(EntityType::DateTime),
74 "reference" | "url" | "path" => Ok(EntityType::Reference),
75 "other" => Ok(EntityType::Other),
76 _ => Err(format!("Unknown entity type: {}", s)),
77 }
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct Entity {
84 pub id: i64,
86 pub name: String,
88 pub normalized_name: String,
90 pub entity_type: EntityType,
92 #[serde(default)]
94 pub aliases: Vec<String>,
95 #[serde(default)]
97 pub metadata: HashMap<String, serde_json::Value>,
98 pub created_at: DateTime<Utc>,
100 pub updated_at: DateTime<Utc>,
102 #[serde(default)]
104 pub mention_count: i32,
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct MemoryEntity {
110 pub memory_id: MemoryId,
112 pub entity_id: i64,
114 pub relation: EntityRelation,
116 pub confidence: f32,
118 pub offset: Option<usize>,
120 pub created_at: DateTime<Utc>,
122}
123
124#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
126#[serde(rename_all = "lowercase")]
127pub enum EntityRelation {
128 Mentions,
130 Defines,
132 References,
134 About,
136 CreatedBy,
138}
139
140impl EntityRelation {
141 pub fn as_str(&self) -> &'static str {
142 match self {
143 EntityRelation::Mentions => "mentions",
144 EntityRelation::Defines => "defines",
145 EntityRelation::References => "references",
146 EntityRelation::About => "about",
147 EntityRelation::CreatedBy => "created_by",
148 }
149 }
150}
151
152impl std::str::FromStr for EntityRelation {
153 type Err = String;
154
155 fn from_str(s: &str) -> Result<Self, Self::Err> {
156 match s.to_lowercase().as_str() {
157 "mentions" => Ok(EntityRelation::Mentions),
158 "defines" => Ok(EntityRelation::Defines),
159 "references" => Ok(EntityRelation::References),
160 "about" => Ok(EntityRelation::About),
161 "created_by" | "createdby" => Ok(EntityRelation::CreatedBy),
162 _ => Err(format!("Unknown entity relation: {}", s)),
163 }
164 }
165}
166
167#[derive(Debug, Clone)]
169pub struct ExtractionResult {
170 pub entities: Vec<ExtractedEntity>,
172 pub extraction_time_ms: u64,
174}
175
176#[derive(Debug, Clone)]
178pub struct ExtractedEntity {
179 pub text: String,
181 pub normalized: String,
183 pub entity_type: EntityType,
185 pub confidence: f32,
187 pub offset: usize,
189 pub length: usize,
191 pub suggested_relation: EntityRelation,
193}
194
195#[derive(Debug, Clone)]
201pub struct EntityExtractionConfig {
202 pub min_confidence: f32,
204 pub extract_people: bool,
206 pub extract_organizations: bool,
208 pub extract_projects: bool,
210 pub extract_concepts: bool,
212 pub extract_locations: bool,
214 pub extract_datetime: bool,
216 pub extract_references: bool,
218 pub custom_patterns: HashMap<String, EntityType>,
220}
221
222impl Default for EntityExtractionConfig {
223 fn default() -> Self {
224 Self {
225 min_confidence: 0.5,
226 extract_people: true,
227 extract_organizations: true,
228 extract_projects: true,
229 extract_concepts: true,
230 extract_locations: true,
231 extract_datetime: true,
232 extract_references: true,
233 custom_patterns: HashMap::new(),
234 }
235 }
236}
237
238pub struct EntityExtractor {
240 config: EntityExtractionConfig,
241 person_pattern: Regex,
243 org_pattern: Regex,
244 project_pattern: Regex,
245 url_pattern: Regex,
246 path_pattern: Regex,
247 datetime_pattern: Regex,
248 mention_pattern: Regex,
249 known_organizations: HashSet<String>,
251 known_concepts: HashSet<String>,
252}
253
254static PERSON_PATTERN: Lazy<Regex> = Lazy::new(|| {
256 Regex::new(
257 r"(?x)
258 @[\w-]+ # @username mentions
259 |(?:Mr\.|Mrs\.|Ms\.|Dr\.|Prof\.)\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+)? # Title + name
260 |[A-Z][a-z]+\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+)? # First Last (Middle)
261 ",
262 )
263 .unwrap()
264});
265
266static ORG_PATTERN: Lazy<Regex> = Lazy::new(|| {
267 Regex::new(
268 r"(?x)
269 [A-Z][A-Za-z]*(?:\s+[A-Z][A-Za-z]*)*\s+(?:Inc\.?|Corp\.?|LLC|Ltd\.?|Co\.?|Team|Group|Labs?)
270 |(?:The\s+)?[A-Z][A-Za-z]+(?:\s+[A-Z][A-Za-z]+)*\s+(?:Company|Organization|Foundation|Institute)
271 ",
272 )
273 .unwrap()
274});
275
276static PROJECT_PATTERN: Lazy<Regex> = Lazy::new(|| {
277 Regex::new(
278 r"(?x)
279 [a-z][a-z0-9]*(?:-[a-z0-9]+)+ # kebab-case project names
280 |[a-z][a-z0-9]*(?:_[a-z0-9]+)+ # snake_case project names
281 |[A-Z][a-z]+(?:[A-Z][a-z]+)+ # PascalCase project names
282 |v?\d+\.\d+(?:\.\d+)?(?:-[a-z]+)? # version numbers
283 ",
284 )
285 .unwrap()
286});
287
288static URL_PATTERN: Lazy<Regex> =
289 Lazy::new(|| Regex::new(r"https?://[^\s<>\[\]()]+|www\.[^\s<>\[\]]+").unwrap());
290
291static PATH_PATTERN: Lazy<Regex> = Lazy::new(|| {
292 Regex::new(
293 r"(?x)
294 (?:/[\w.-]+)+ # Unix paths
295 |[A-Z]:\\(?:[\w.-]+\\)+[\w.-]* # Windows paths
296 |\.{1,2}/[\w.-/]+ # Relative paths
297 ",
298 )
299 .unwrap()
300});
301
302static DATETIME_PATTERN: Lazy<Regex> = Lazy::new(|| {
303 Regex::new(
304 r"(?x)
305 \d{4}-\d{2}-\d{2}(?:T\d{2}:\d{2}(?::\d{2})?)? # ISO dates
306 |\d{1,2}/\d{1,2}/\d{2,4} # MM/DD/YYYY
307 |(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\.?\s+\d{1,2}(?:,?\s+\d{4})?
308 |Q[1-4]\s+\d{4} # Quarters
309 |(?:yesterday|today|tomorrow|last\s+week|next\s+month)
310 ",
311 )
312 .unwrap()
313});
314
315static MENTION_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"@[\w-]+").unwrap());
316
317static KNOWN_ORGANIZATIONS: Lazy<HashSet<String>> = Lazy::new(|| {
318 [
319 "Anthropic",
320 "OpenAI",
321 "Google",
322 "Microsoft",
323 "Meta",
324 "Amazon",
325 "Apple",
326 "GitHub",
327 "GitLab",
328 "Vercel",
329 "Cloudflare",
330 "AWS",
331 "Azure",
332 "GCP",
333 "Stripe",
334 "Supabase",
335 "Neon",
336 "PlanetScale",
337 "MongoDB",
338 "Redis",
339 ]
340 .iter()
341 .map(|s| s.to_lowercase())
342 .collect()
343});
344
345static KNOWN_CONCEPTS: Lazy<HashSet<String>> = Lazy::new(|| {
346 [
347 "machine learning",
348 "deep learning",
349 "neural network",
350 "transformer",
351 "embedding",
352 "vector database",
353 "semantic search",
354 "rag",
355 "llm",
356 "api",
357 "rest",
358 "graphql",
359 "grpc",
360 "websocket",
361 "microservices",
362 "kubernetes",
363 "docker",
364 "ci/cd",
365 "devops",
366 "serverless",
367 "authentication",
368 "authorization",
369 "oauth",
370 "jwt",
371 "session",
372 "database",
373 "sql",
374 "nosql",
375 "postgresql",
376 "sqlite",
377 "redis",
378 "rust",
379 "python",
380 "typescript",
381 "javascript",
382 "go",
383 "java",
384 ]
385 .iter()
386 .map(|s| s.to_string())
387 .collect()
388});
389
390impl EntityExtractor {
391 pub fn new(config: EntityExtractionConfig) -> Self {
392 Self {
393 config,
394 person_pattern: PERSON_PATTERN.clone(),
395 org_pattern: ORG_PATTERN.clone(),
396 project_pattern: PROJECT_PATTERN.clone(),
397 url_pattern: URL_PATTERN.clone(),
398 path_pattern: PATH_PATTERN.clone(),
399 datetime_pattern: DATETIME_PATTERN.clone(),
400 mention_pattern: MENTION_PATTERN.clone(),
401 known_organizations: KNOWN_ORGANIZATIONS.clone(),
402 known_concepts: KNOWN_CONCEPTS.clone(),
403 }
404 }
405
406 pub fn extract(&self, text: &str) -> ExtractionResult {
408 let start = std::time::Instant::now();
409 let mut entities = Vec::new();
410 let text_lower = text.to_lowercase();
411
412 if self.config.extract_people {
414 for cap in self.mention_pattern.find_iter(text) {
415 entities.push(ExtractedEntity {
416 text: cap.as_str().to_string(),
417 normalized: cap.as_str().to_lowercase(),
418 entity_type: EntityType::Person,
419 confidence: 0.95,
420 offset: cap.start(),
421 length: cap.len(),
422 suggested_relation: EntityRelation::Mentions,
423 });
424 }
425
426 for cap in self.person_pattern.find_iter(text) {
428 if cap.as_str().starts_with('@') {
430 continue;
431 }
432 entities.push(ExtractedEntity {
433 text: cap.as_str().to_string(),
434 normalized: normalize_name(cap.as_str()),
435 entity_type: EntityType::Person,
436 confidence: 0.7,
437 offset: cap.start(),
438 length: cap.len(),
439 suggested_relation: EntityRelation::Mentions,
440 });
441 }
442 }
443
444 if self.config.extract_organizations {
446 for cap in self.org_pattern.find_iter(text) {
447 entities.push(ExtractedEntity {
448 text: cap.as_str().to_string(),
449 normalized: normalize_name(cap.as_str()),
450 entity_type: EntityType::Organization,
451 confidence: 0.8,
452 offset: cap.start(),
453 length: cap.len(),
454 suggested_relation: EntityRelation::Mentions,
455 });
456 }
457
458 for org in &self.known_organizations {
460 if let Some(pos) = text_lower.find(org) {
461 let original = &text[pos..pos + org.len()];
463 if !entities.iter().any(|e| e.offset == pos) {
465 entities.push(ExtractedEntity {
466 text: original.to_string(),
467 normalized: org.clone(),
468 entity_type: EntityType::Organization,
469 confidence: 0.9,
470 offset: pos,
471 length: org.len(),
472 suggested_relation: EntityRelation::Mentions,
473 });
474 }
475 }
476 }
477 }
478
479 if self.config.extract_references {
481 for cap in self.url_pattern.find_iter(text) {
482 entities.push(ExtractedEntity {
483 text: cap.as_str().to_string(),
484 normalized: cap.as_str().to_lowercase(),
485 entity_type: EntityType::Reference,
486 confidence: 0.99,
487 offset: cap.start(),
488 length: cap.len(),
489 suggested_relation: EntityRelation::References,
490 });
491 }
492
493 for cap in self.path_pattern.find_iter(text) {
494 entities.push(ExtractedEntity {
495 text: cap.as_str().to_string(),
496 normalized: cap.as_str().to_string(),
497 entity_type: EntityType::Reference,
498 confidence: 0.85,
499 offset: cap.start(),
500 length: cap.len(),
501 suggested_relation: EntityRelation::References,
502 });
503 }
504 }
505
506 if self.config.extract_datetime {
508 for cap in self.datetime_pattern.find_iter(text) {
509 entities.push(ExtractedEntity {
510 text: cap.as_str().to_string(),
511 normalized: cap.as_str().to_lowercase(),
512 entity_type: EntityType::DateTime,
513 confidence: 0.9,
514 offset: cap.start(),
515 length: cap.len(),
516 suggested_relation: EntityRelation::Mentions,
517 });
518 }
519 }
520
521 if self.config.extract_concepts {
523 for concept in &self.known_concepts {
524 if let Some(pos) = text_lower.find(concept) {
525 let original = &text[pos..pos + concept.len()];
526 entities.push(ExtractedEntity {
527 text: original.to_string(),
528 normalized: concept.clone(),
529 entity_type: EntityType::Concept,
530 confidence: 0.85,
531 offset: pos,
532 length: concept.len(),
533 suggested_relation: EntityRelation::About,
534 });
535 }
536 }
537 }
538
539 if self.config.extract_projects {
541 for cap in self.project_pattern.find_iter(text) {
542 let matched = cap.as_str();
543 if matched.len() < 3
545 || matched
546 .chars()
547 .all(|c| c.is_numeric() || c == '.' || c == '-' || c == 'v')
548 {
549 continue;
550 }
551 entities.push(ExtractedEntity {
552 text: matched.to_string(),
553 normalized: matched.to_lowercase(),
554 entity_type: EntityType::Project,
555 confidence: 0.6,
556 offset: cap.start(),
557 length: cap.len(),
558 suggested_relation: EntityRelation::Mentions,
559 });
560 }
561 }
562
563 entities.retain(|e| e.confidence >= self.config.min_confidence);
565 deduplicate_entities(&mut entities);
566
567 let extraction_time_ms = start.elapsed().as_millis() as u64;
568
569 ExtractionResult {
570 entities,
571 extraction_time_ms,
572 }
573 }
574
575 pub fn add_custom_pattern(&mut self, pattern: &str, entity_type: EntityType) {
577 self.config
578 .custom_patterns
579 .insert(pattern.to_string(), entity_type);
580 }
581
582 pub fn config(&self) -> &EntityExtractionConfig {
584 &self.config
585 }
586}
587
588impl Default for EntityExtractor {
589 fn default() -> Self {
590 Self::new(EntityExtractionConfig::default())
591 }
592}
593
/// Normalizes a name for comparison: lowercases it, drops leading and trailing
/// whitespace, and collapses every inner run of whitespace to a single space.
fn normalize_name(name: &str) -> String {
    let lowered = name.to_lowercase();
    let mut normalized = String::with_capacity(lowered.len());

    // `split_whitespace` already skips leading, trailing and repeated whitespace.
    for word in lowered.split_whitespace() {
        if !normalized.is_empty() {
            normalized.push(' ');
        }
        normalized.push_str(word);
    }

    normalized
}
606
607fn deduplicate_entities(entities: &mut Vec<ExtractedEntity>) {
609 entities.sort_by(|a, b| {
611 a.offset.cmp(&b.offset).then(
612 b.confidence
613 .partial_cmp(&a.confidence)
614 .unwrap_or(std::cmp::Ordering::Equal),
615 )
616 });
617
618 let mut i = 0;
620 while i < entities.len() {
621 let current_end = entities[i].offset + entities[i].length;
622 let mut j = i + 1;
623 while j < entities.len() {
624 if entities[j].offset < current_end {
625 if entities[j].confidence > entities[i].confidence {
627 entities.remove(i);
628 continue;
630 } else {
631 entities.remove(j);
632 continue;
634 }
635 }
636 j += 1;
637 }
638 i += 1;
639 }
640}
641
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `extractor` over `text` and returns only the entities of type `wanted`.
    fn extract_by_type(
        extractor: &EntityExtractor,
        text: &str,
        wanted: EntityType,
    ) -> Vec<ExtractedEntity> {
        extractor
            .extract(text)
            .entities
            .into_iter()
            .filter(|entity| entity.entity_type == wanted)
            .collect()
    }

    #[test]
    fn test_extract_mentions() {
        let people = extract_by_type(
            &EntityExtractor::default(),
            "Hey @john-doe, can you review this with @alice?",
            EntityType::Person,
        );

        assert_eq!(people.len(), 2);
        assert!(people.iter().any(|e| e.text == "@john-doe"));
        assert!(people.iter().any(|e| e.text == "@alice"));
    }

    #[test]
    fn test_extract_urls() {
        let refs = extract_by_type(
            &EntityExtractor::default(),
            "Check out https://github.com/engram/engram for more info.",
            EntityType::Reference,
        );

        assert_eq!(refs.len(), 1);
        assert!(refs[0].text.contains("github.com"));
    }

    #[test]
    fn test_extract_organizations() {
        let orgs = extract_by_type(
            &EntityExtractor::default(),
            "We're using Anthropic's Claude and OpenAI's GPT-4.",
            EntityType::Organization,
        );

        assert!(orgs.len() >= 2);
    }

    #[test]
    fn test_extract_concepts() {
        let concepts = extract_by_type(
            &EntityExtractor::default(),
            "We need to implement semantic search with embeddings.",
            EntityType::Concept,
        );

        assert!(concepts
            .iter()
            .any(|e| e.normalized.contains("semantic search")));
        assert!(concepts.iter().any(|e| e.normalized.contains("embedding")));
    }

    #[test]
    fn test_extract_dates() {
        let dates = extract_by_type(
            &EntityExtractor::default(),
            "Meeting scheduled for 2024-01-15. Let's discuss yesterday's issues.",
            EntityType::DateTime,
        );

        assert!(dates.len() >= 2);
    }

    #[test]
    fn test_entity_type_parsing() {
        assert_eq!("person".parse::<EntityType>().unwrap(), EntityType::Person);
        assert_eq!(
            "org".parse::<EntityType>().unwrap(),
            EntityType::Organization
        );
        assert_eq!("repo".parse::<EntityType>().unwrap(), EntityType::Project);
    }

    #[test]
    fn test_confidence_threshold() {
        // Name matches score 0.7, so a 0.9 threshold must filter them out.
        let strict = EntityExtractor::new(EntityExtractionConfig {
            min_confidence: 0.9,
            ..Default::default()
        });

        let named_people: Vec<_> = extract_by_type(
            &strict,
            "Some random text with John Smith mentioned.",
            EntityType::Person,
        )
        .into_iter()
        .filter(|e| !e.text.starts_with('@'))
        .collect();

        assert!(named_people.is_empty());
    }
}