Skip to main content

engram/intelligence/
entities.rs

1//! Entity Extraction for Engram (RML-925)
2//!
3//! Provides automatic Named Entity Recognition (NER) to extract:
4//! - People (names, roles, mentions)
5//! - Organizations (companies, teams)
6//! - Projects (repos, products)
7//! - Concepts (technical terms, patterns)
8//! - Locations (places, regions)
9//! - Dates/Times (temporal references)
10//!
11//! Uses pattern-based extraction (fast, no dependencies) with optional
12//! LLM-enhanced extraction for higher quality.
13
14use chrono::{DateTime, Utc};
15use once_cell::sync::Lazy;
16use regex::Regex;
17use serde::{Deserialize, Serialize};
18use std::collections::{HashMap, HashSet};
19
20use crate::types::MemoryId;
21
22// =============================================================================
23// Types
24// =============================================================================
25
26/// Type of entity extracted from text
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
28#[serde(rename_all = "lowercase")]
29pub enum EntityType {
30    /// Person name (e.g., "John Smith", "@username")
31    Person,
32    /// Organization or company (e.g., "Anthropic", "OpenAI")
33    Organization,
34    /// Project or repository (e.g., "engram", "rust-analyzer")
35    Project,
36    /// Technical concept or term (e.g., "vector database", "embeddings")
37    Concept,
38    /// Geographic location (e.g., "San Francisco", "AWS us-east-1")
39    Location,
40    /// Date or time reference (e.g., "yesterday", "Q4 2024")
41    DateTime,
42    /// URL or file path
43    Reference,
44    /// Generic/unknown entity type
45    Other,
46}
47
48impl EntityType {
49    pub fn as_str(&self) -> &'static str {
50        match self {
51            EntityType::Person => "person",
52            EntityType::Organization => "organization",
53            EntityType::Project => "project",
54            EntityType::Concept => "concept",
55            EntityType::Location => "location",
56            EntityType::DateTime => "datetime",
57            EntityType::Reference => "reference",
58            EntityType::Other => "other",
59        }
60    }
61}
62
63impl std::str::FromStr for EntityType {
64    type Err = String;
65
66    fn from_str(s: &str) -> Result<Self, Self::Err> {
67        match s.to_lowercase().as_str() {
68            "person" => Ok(EntityType::Person),
69            "organization" | "org" | "company" => Ok(EntityType::Organization),
70            "project" | "repo" | "repository" => Ok(EntityType::Project),
71            "concept" | "term" | "topic" => Ok(EntityType::Concept),
72            "location" | "place" | "geo" => Ok(EntityType::Location),
73            "datetime" | "date" | "time" => Ok(EntityType::DateTime),
74            "reference" | "url" | "path" => Ok(EntityType::Reference),
75            "other" => Ok(EntityType::Other),
76            _ => Err(format!("Unknown entity type: {}", s)),
77        }
78    }
79}
80
81/// An extracted entity
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct Entity {
84    /// Unique identifier
85    pub id: i64,
86    /// Canonical name of the entity
87    pub name: String,
88    /// Normalized name for matching (lowercase, trimmed)
89    pub normalized_name: String,
90    /// Type of entity
91    pub entity_type: EntityType,
92    /// Aliases (other names this entity is known by)
93    #[serde(default)]
94    pub aliases: Vec<String>,
95    /// Additional metadata
96    #[serde(default)]
97    pub metadata: HashMap<String, serde_json::Value>,
98    /// When first seen
99    pub created_at: DateTime<Utc>,
100    /// When last referenced
101    pub updated_at: DateTime<Utc>,
102    /// Number of times referenced
103    #[serde(default)]
104    pub mention_count: i32,
105}
106
107/// Relationship between a memory and an entity
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct MemoryEntity {
110    /// Memory ID
111    pub memory_id: MemoryId,
112    /// Entity ID
113    pub entity_id: i64,
114    /// Type of relation (mentions, defines, references, etc.)
115    pub relation: EntityRelation,
116    /// Confidence score (0.0 - 1.0)
117    pub confidence: f32,
118    /// Character offset where entity appears in content
119    pub offset: Option<usize>,
120    /// When the link was created
121    pub created_at: DateTime<Utc>,
122}
123
124/// Type of relationship between memory and entity
125#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
126#[serde(rename_all = "lowercase")]
127pub enum EntityRelation {
128    /// Entity is mentioned in the memory
129    Mentions,
130    /// Memory defines or describes the entity
131    Defines,
132    /// Memory references the entity (e.g., link, citation)
133    References,
134    /// Memory is about/focuses on the entity
135    About,
136    /// Memory was created by the entity (for Person type)
137    CreatedBy,
138}
139
140impl EntityRelation {
141    pub fn as_str(&self) -> &'static str {
142        match self {
143            EntityRelation::Mentions => "mentions",
144            EntityRelation::Defines => "defines",
145            EntityRelation::References => "references",
146            EntityRelation::About => "about",
147            EntityRelation::CreatedBy => "created_by",
148        }
149    }
150}
151
152impl std::str::FromStr for EntityRelation {
153    type Err = String;
154
155    fn from_str(s: &str) -> Result<Self, Self::Err> {
156        match s.to_lowercase().as_str() {
157            "mentions" => Ok(EntityRelation::Mentions),
158            "defines" => Ok(EntityRelation::Defines),
159            "references" => Ok(EntityRelation::References),
160            "about" => Ok(EntityRelation::About),
161            "created_by" | "createdby" => Ok(EntityRelation::CreatedBy),
162            _ => Err(format!("Unknown entity relation: {}", s)),
163        }
164    }
165}
166
167/// Result of entity extraction from text
168#[derive(Debug, Clone)]
169pub struct ExtractionResult {
170    /// Extracted entities with their positions
171    pub entities: Vec<ExtractedEntity>,
172    /// Total extraction time in milliseconds
173    pub extraction_time_ms: u64,
174}
175
176/// A single extracted entity from text
177#[derive(Debug, Clone)]
178pub struct ExtractedEntity {
179    /// The extracted text
180    pub text: String,
181    /// Normalized form
182    pub normalized: String,
183    /// Entity type
184    pub entity_type: EntityType,
185    /// Confidence score (0.0 - 1.0)
186    pub confidence: f32,
187    /// Character offset in source text
188    pub offset: usize,
189    /// Length of the match
190    pub length: usize,
191    /// Suggested relation type
192    pub suggested_relation: EntityRelation,
193}
194
195// =============================================================================
196// Entity Extraction Engine
197// =============================================================================
198
199/// Configuration for entity extraction
200#[derive(Debug, Clone)]
201pub struct EntityExtractionConfig {
202    /// Minimum confidence threshold for extraction
203    pub min_confidence: f32,
204    /// Extract people names
205    pub extract_people: bool,
206    /// Extract organizations
207    pub extract_organizations: bool,
208    /// Extract projects
209    pub extract_projects: bool,
210    /// Extract concepts
211    pub extract_concepts: bool,
212    /// Extract locations
213    pub extract_locations: bool,
214    /// Extract datetime references
215    pub extract_datetime: bool,
216    /// Extract URLs and paths
217    pub extract_references: bool,
218    /// Custom patterns to match (name -> entity_type)
219    pub custom_patterns: HashMap<String, EntityType>,
220}
221
222impl Default for EntityExtractionConfig {
223    fn default() -> Self {
224        Self {
225            min_confidence: 0.5,
226            extract_people: true,
227            extract_organizations: true,
228            extract_projects: true,
229            extract_concepts: true,
230            extract_locations: true,
231            extract_datetime: true,
232            extract_references: true,
233            custom_patterns: HashMap::new(),
234        }
235    }
236}
237
238/// Entity extraction engine using pattern matching
239pub struct EntityExtractor {
240    config: EntityExtractionConfig,
241    // Compiled regex patterns
242    person_pattern: Regex,
243    org_pattern: Regex,
244    project_pattern: Regex,
245    url_pattern: Regex,
246    path_pattern: Regex,
247    datetime_pattern: Regex,
248    mention_pattern: Regex,
249    // Known entities for matching
250    known_organizations: HashSet<String>,
251    known_concepts: HashSet<String>,
252}
253
254// Compiled regex patterns
255static PERSON_PATTERN: Lazy<Regex> = Lazy::new(|| {
256    Regex::new(
257        r"(?x)
258        @[\w-]+                           # @username mentions
259        |(?:Mr\.|Mrs\.|Ms\.|Dr\.|Prof\.)\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+)?  # Title + name
260        |[A-Z][a-z]+\s+[A-Z][a-z]+(?:\s+[A-Z][a-z]+)?  # First Last (Middle)
261        ",
262    )
263    .unwrap()
264});
265
266static ORG_PATTERN: Lazy<Regex> = Lazy::new(|| {
267    Regex::new(
268        r"(?x)
269        [A-Z][A-Za-z]*(?:\s+[A-Z][A-Za-z]*)*\s+(?:Inc\.?|Corp\.?|LLC|Ltd\.?|Co\.?|Team|Group|Labs?)
270        |(?:The\s+)?[A-Z][A-Za-z]+(?:\s+[A-Z][A-Za-z]+)*\s+(?:Company|Organization|Foundation|Institute)
271        ",
272    )
273    .unwrap()
274});
275
276static PROJECT_PATTERN: Lazy<Regex> = Lazy::new(|| {
277    Regex::new(
278        r"(?x)
279        [a-z][a-z0-9]*(?:-[a-z0-9]+)+     # kebab-case project names
280        |[a-z][a-z0-9]*(?:_[a-z0-9]+)+    # snake_case project names
281        |[A-Z][a-z]+(?:[A-Z][a-z]+)+      # PascalCase project names
282        |v?\d+\.\d+(?:\.\d+)?(?:-[a-z]+)? # version numbers
283        ",
284    )
285    .unwrap()
286});
287
288static URL_PATTERN: Lazy<Regex> =
289    Lazy::new(|| Regex::new(r"https?://[^\s<>\[\]()]+|www\.[^\s<>\[\]]+").unwrap());
290
291static PATH_PATTERN: Lazy<Regex> = Lazy::new(|| {
292    Regex::new(
293        r"(?x)
294        (?:/[\w.-]+)+                     # Unix paths
295        |[A-Z]:\\(?:[\w.-]+\\)+[\w.-]*    # Windows paths
296        |\.{1,2}/[\w.-/]+                 # Relative paths
297        ",
298    )
299    .unwrap()
300});
301
302static DATETIME_PATTERN: Lazy<Regex> = Lazy::new(|| {
303    Regex::new(
304        r"(?x)
305        \d{4}-\d{2}-\d{2}(?:T\d{2}:\d{2}(?::\d{2})?)?  # ISO dates
306        |\d{1,2}/\d{1,2}/\d{2,4}          # MM/DD/YYYY
307        |(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\.?\s+\d{1,2}(?:,?\s+\d{4})?
308        |Q[1-4]\s+\d{4}                   # Quarters
309        |(?:yesterday|today|tomorrow|last\s+week|next\s+month)
310        ",
311    )
312    .unwrap()
313});
314
315static MENTION_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"@[\w-]+").unwrap());
316
317static KNOWN_ORGANIZATIONS: Lazy<HashSet<String>> = Lazy::new(|| {
318    [
319        "Anthropic",
320        "OpenAI",
321        "Google",
322        "Microsoft",
323        "Meta",
324        "Amazon",
325        "Apple",
326        "GitHub",
327        "GitLab",
328        "Vercel",
329        "Cloudflare",
330        "AWS",
331        "Azure",
332        "GCP",
333        "Stripe",
334        "Supabase",
335        "Neon",
336        "PlanetScale",
337        "MongoDB",
338        "Redis",
339    ]
340    .iter()
341    .map(|s| s.to_lowercase())
342    .collect()
343});
344
345static KNOWN_CONCEPTS: Lazy<HashSet<String>> = Lazy::new(|| {
346    [
347        "machine learning",
348        "deep learning",
349        "neural network",
350        "transformer",
351        "embedding",
352        "vector database",
353        "semantic search",
354        "rag",
355        "llm",
356        "api",
357        "rest",
358        "graphql",
359        "grpc",
360        "websocket",
361        "microservices",
362        "kubernetes",
363        "docker",
364        "ci/cd",
365        "devops",
366        "serverless",
367        "authentication",
368        "authorization",
369        "oauth",
370        "jwt",
371        "session",
372        "database",
373        "sql",
374        "nosql",
375        "postgresql",
376        "sqlite",
377        "redis",
378        "rust",
379        "python",
380        "typescript",
381        "javascript",
382        "go",
383        "java",
384    ]
385    .iter()
386    .map(|s| s.to_string())
387    .collect()
388});
389
390impl EntityExtractor {
391    pub fn new(config: EntityExtractionConfig) -> Self {
392        Self {
393            config,
394            person_pattern: PERSON_PATTERN.clone(),
395            org_pattern: ORG_PATTERN.clone(),
396            project_pattern: PROJECT_PATTERN.clone(),
397            url_pattern: URL_PATTERN.clone(),
398            path_pattern: PATH_PATTERN.clone(),
399            datetime_pattern: DATETIME_PATTERN.clone(),
400            mention_pattern: MENTION_PATTERN.clone(),
401            known_organizations: KNOWN_ORGANIZATIONS.clone(),
402            known_concepts: KNOWN_CONCEPTS.clone(),
403        }
404    }
405
406    /// Extract entities from text
407    pub fn extract(&self, text: &str) -> ExtractionResult {
408        let start = std::time::Instant::now();
409        let mut entities = Vec::new();
410        let text_lower = text.to_lowercase();
411
412        // Extract @mentions (high confidence)
413        if self.config.extract_people {
414            for cap in self.mention_pattern.find_iter(text) {
415                entities.push(ExtractedEntity {
416                    text: cap.as_str().to_string(),
417                    normalized: cap.as_str().to_lowercase(),
418                    entity_type: EntityType::Person,
419                    confidence: 0.95,
420                    offset: cap.start(),
421                    length: cap.len(),
422                    suggested_relation: EntityRelation::Mentions,
423                });
424            }
425
426            // Extract person names
427            for cap in self.person_pattern.find_iter(text) {
428                // Skip if already captured as @mention
429                if cap.as_str().starts_with('@') {
430                    continue;
431                }
432                entities.push(ExtractedEntity {
433                    text: cap.as_str().to_string(),
434                    normalized: normalize_name(cap.as_str()),
435                    entity_type: EntityType::Person,
436                    confidence: 0.7,
437                    offset: cap.start(),
438                    length: cap.len(),
439                    suggested_relation: EntityRelation::Mentions,
440                });
441            }
442        }
443
444        // Extract organizations
445        if self.config.extract_organizations {
446            for cap in self.org_pattern.find_iter(text) {
447                entities.push(ExtractedEntity {
448                    text: cap.as_str().to_string(),
449                    normalized: normalize_name(cap.as_str()),
450                    entity_type: EntityType::Organization,
451                    confidence: 0.8,
452                    offset: cap.start(),
453                    length: cap.len(),
454                    suggested_relation: EntityRelation::Mentions,
455                });
456            }
457
458            // Check for known organizations
459            for org in &self.known_organizations {
460                if let Some(pos) = text_lower.find(org) {
461                    // Get the original case version
462                    let original = &text[pos..pos + org.len()];
463                    // Avoid duplicates
464                    if !entities.iter().any(|e| e.offset == pos) {
465                        entities.push(ExtractedEntity {
466                            text: original.to_string(),
467                            normalized: org.clone(),
468                            entity_type: EntityType::Organization,
469                            confidence: 0.9,
470                            offset: pos,
471                            length: org.len(),
472                            suggested_relation: EntityRelation::Mentions,
473                        });
474                    }
475                }
476            }
477        }
478
479        // Extract URLs
480        if self.config.extract_references {
481            for cap in self.url_pattern.find_iter(text) {
482                entities.push(ExtractedEntity {
483                    text: cap.as_str().to_string(),
484                    normalized: cap.as_str().to_lowercase(),
485                    entity_type: EntityType::Reference,
486                    confidence: 0.99,
487                    offset: cap.start(),
488                    length: cap.len(),
489                    suggested_relation: EntityRelation::References,
490                });
491            }
492
493            for cap in self.path_pattern.find_iter(text) {
494                entities.push(ExtractedEntity {
495                    text: cap.as_str().to_string(),
496                    normalized: cap.as_str().to_string(),
497                    entity_type: EntityType::Reference,
498                    confidence: 0.85,
499                    offset: cap.start(),
500                    length: cap.len(),
501                    suggested_relation: EntityRelation::References,
502                });
503            }
504        }
505
506        // Extract datetime
507        if self.config.extract_datetime {
508            for cap in self.datetime_pattern.find_iter(text) {
509                entities.push(ExtractedEntity {
510                    text: cap.as_str().to_string(),
511                    normalized: cap.as_str().to_lowercase(),
512                    entity_type: EntityType::DateTime,
513                    confidence: 0.9,
514                    offset: cap.start(),
515                    length: cap.len(),
516                    suggested_relation: EntityRelation::Mentions,
517                });
518            }
519        }
520
521        // Extract concepts
522        if self.config.extract_concepts {
523            for concept in &self.known_concepts {
524                if let Some(pos) = text_lower.find(concept) {
525                    let original = &text[pos..pos + concept.len()];
526                    entities.push(ExtractedEntity {
527                        text: original.to_string(),
528                        normalized: concept.clone(),
529                        entity_type: EntityType::Concept,
530                        confidence: 0.85,
531                        offset: pos,
532                        length: concept.len(),
533                        suggested_relation: EntityRelation::About,
534                    });
535                }
536            }
537        }
538
539        // Extract project names
540        if self.config.extract_projects {
541            for cap in self.project_pattern.find_iter(text) {
542                let matched = cap.as_str();
543                // Skip very short matches and pure version numbers
544                if matched.len() < 3
545                    || matched
546                        .chars()
547                        .all(|c| c.is_numeric() || c == '.' || c == '-' || c == 'v')
548                {
549                    continue;
550                }
551                entities.push(ExtractedEntity {
552                    text: matched.to_string(),
553                    normalized: matched.to_lowercase(),
554                    entity_type: EntityType::Project,
555                    confidence: 0.6,
556                    offset: cap.start(),
557                    length: cap.len(),
558                    suggested_relation: EntityRelation::Mentions,
559                });
560            }
561        }
562
563        // Filter by confidence threshold and deduplicate
564        entities.retain(|e| e.confidence >= self.config.min_confidence);
565        deduplicate_entities(&mut entities);
566
567        let extraction_time_ms = start.elapsed().as_millis() as u64;
568
569        ExtractionResult {
570            entities,
571            extraction_time_ms,
572        }
573    }
574
575    /// Add a custom pattern for entity extraction
576    pub fn add_custom_pattern(&mut self, pattern: &str, entity_type: EntityType) {
577        self.config
578            .custom_patterns
579            .insert(pattern.to_string(), entity_type);
580    }
581
582    /// Get configuration
583    pub fn config(&self) -> &EntityExtractionConfig {
584        &self.config
585    }
586}
587
588impl Default for EntityExtractor {
589    fn default() -> Self {
590        Self::new(EntityExtractionConfig::default())
591    }
592}
593
594// =============================================================================
595// Helper Functions
596// =============================================================================
597
/// Produces the matching key for a name: surrounding whitespace removed,
/// everything lowercased, and each run of inner whitespace collapsed to a
/// single space.
fn normalize_name(name: &str) -> String {
    let lowered = name.trim().to_lowercase();
    let mut normalized = String::with_capacity(lowered.len());
    for word in lowered.split_whitespace() {
        if !normalized.is_empty() {
            normalized.push(' ');
        }
        normalized.push_str(word);
    }
    normalized
}
606
607/// Deduplicate entities, keeping the highest confidence match
608fn deduplicate_entities(entities: &mut Vec<ExtractedEntity>) {
609    // Sort by offset, then by confidence (descending)
610    entities.sort_by(|a, b| {
611        a.offset.cmp(&b.offset).then(
612            b.confidence
613                .partial_cmp(&a.confidence)
614                .unwrap_or(std::cmp::Ordering::Equal),
615        )
616    });
617
618    // Remove overlapping entities, keeping higher confidence
619    let mut i = 0;
620    while i < entities.len() {
621        let current_end = entities[i].offset + entities[i].length;
622        let mut j = i + 1;
623        while j < entities.len() {
624            if entities[j].offset < current_end {
625                // Overlapping - remove the lower confidence one
626                if entities[j].confidence > entities[i].confidence {
627                    entities.remove(i);
628                    // Don't increment i, check the new element at position i
629                    continue;
630                } else {
631                    entities.remove(j);
632                    // Don't increment j, check the new element at position j
633                    continue;
634                }
635            }
636            j += 1;
637        }
638        i += 1;
639    }
640}
641
642// =============================================================================
643// Tests
644// =============================================================================
645
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `extractor` over `text` and keeps only the entities of type `wanted`.
    fn entities_of_type(
        extractor: &EntityExtractor,
        text: &str,
        wanted: EntityType,
    ) -> Vec<ExtractedEntity> {
        extractor
            .extract(text)
            .entities
            .into_iter()
            .filter(|entity| entity.entity_type == wanted)
            .collect()
    }

    #[test]
    fn test_extract_mentions() {
        let extractor = EntityExtractor::default();
        let people = entities_of_type(
            &extractor,
            "Hey @john-doe, can you review this with @alice?",
            EntityType::Person,
        );

        assert_eq!(people.len(), 2);
        for handle in ["@john-doe", "@alice"] {
            assert!(people.iter().any(|e| e.text == handle));
        }
    }

    #[test]
    fn test_extract_urls() {
        let extractor = EntityExtractor::default();
        let refs = entities_of_type(
            &extractor,
            "Check out https://github.com/engram/engram for more info.",
            EntityType::Reference,
        );

        assert_eq!(refs.len(), 1);
        assert!(refs[0].text.contains("github.com"));
    }

    #[test]
    fn test_extract_organizations() {
        let extractor = EntityExtractor::default();
        let orgs = entities_of_type(
            &extractor,
            "We're using Anthropic's Claude and OpenAI's GPT-4.",
            EntityType::Organization,
        );

        assert!(orgs.len() >= 2);
    }

    #[test]
    fn test_extract_concepts() {
        let extractor = EntityExtractor::default();
        let concepts = entities_of_type(
            &extractor,
            "We need to implement semantic search with embeddings.",
            EntityType::Concept,
        );

        assert!(concepts
            .iter()
            .any(|e| e.normalized.contains("semantic search")));
        assert!(concepts.iter().any(|e| e.normalized.contains("embedding")));
    }

    #[test]
    fn test_extract_dates() {
        let extractor = EntityExtractor::default();
        let dates = entities_of_type(
            &extractor,
            "Meeting scheduled for 2024-01-15. Let's discuss yesterday's issues.",
            EntityType::DateTime,
        );

        assert!(dates.len() >= 2);
    }

    #[test]
    fn test_entity_type_parsing() {
        let person: EntityType = "person".parse().unwrap();
        let organization: EntityType = "org".parse().unwrap();
        let project: EntityType = "repo".parse().unwrap();

        assert_eq!(person, EntityType::Person);
        assert_eq!(organization, EntityType::Organization);
        assert_eq!(project, EntityType::Project);
    }

    #[test]
    fn test_confidence_threshold() {
        // Person names are emitted with 0.7 confidence, so a 0.9 floor must
        // filter them out.
        let extractor = EntityExtractor::new(EntityExtractionConfig {
            min_confidence: 0.9,
            ..Default::default()
        });

        let named_people: Vec<_> = entities_of_type(
            &extractor,
            "Some random text with John Smith mentioned.",
            EntityType::Person,
        )
        .into_iter()
        .filter(|e| !e.text.starts_with('@'))
        .collect();

        assert!(named_people.is_empty());
    }
}