Skip to main content

engram/intelligence/
entity_extraction.rs

//! Entity extraction for automatic identity linking
//!
//! Provides lightweight Named Entity Recognition (NER) for:
//! - @mentions (e.g., @ronaldo, @acme-corp)
//! - Email addresses
//! - URLs with domain extraction
//! - Capitalized names (simple heuristic)
//! - Known identity aliases (database lookup)
//!
//! ## Invariants
//!
//! - Extraction never panics on any input
//! - Empty/whitespace input returns empty results
//! - Duplicate mentions are deduplicated with count
//! - Results are sorted by first occurrence position
//!
//! ## Performance
//!
//! - Regex patterns are compiled once on first use (`once_cell::sync::Lazy`) and reused
//! - One linear regex pass over the text per enabled entity kind
//! - Bounded output: at most `max_entities` entities per text (default 100)
22
23use std::collections::HashMap;
24
25use once_cell::sync::Lazy;
26use regex::Regex;
27use rusqlite::Connection;
28use serde::{Deserialize, Serialize};
29use tracing::{debug, instrument, warn};
30
31use crate::error::Result;
32use crate::storage::identity_links::{normalize_alias, resolve_alias};
33
/// Default cap on the number of distinct entities kept for a single text.
/// Keeps the output bounded no matter how adversarial the input is.
const MAX_ENTITIES_PER_TEXT: usize = 100;

/// Default confidence floor: entities scoring below this are discarded.
const MIN_CONFIDENCE: f32 = 0.3;
39
40/// Compiled regex patterns (compiled once, reused)
41static MENTION_PATTERN: Lazy<Regex> =
42    Lazy::new(|| Regex::new(r"@([a-zA-Z][a-zA-Z0-9_-]{1,30})").expect("valid regex"));
43
44static EMAIL_PATTERN: Lazy<Regex> = Lazy::new(|| {
45    Regex::new(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}").expect("valid regex")
46});
47
48static URL_PATTERN: Lazy<Regex> =
49    Lazy::new(|| Regex::new(r"https?://([a-zA-Z0-9.-]+)(?:/[^\s]*)?").expect("valid regex"));
50
51/// Pattern for capitalized names (2+ words starting with capitals)
52static NAME_PATTERN: Lazy<Regex> =
53    Lazy::new(|| Regex::new(r"\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)\b").expect("valid regex"));
54
55/// An extracted entity from text
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct ExtractedEntity {
58    /// The raw text as found in content
59    pub mention_text: String,
60    /// Normalized form for matching
61    pub normalized: String,
62    /// Type of entity detected
63    pub entity_type: ExtractedEntityType,
64    /// Confidence score (0.0 - 1.0)
65    pub confidence: f32,
66    /// Position in text (byte offset)
67    pub position: usize,
68    /// Number of times this entity appears
69    pub count: usize,
70    /// Resolved canonical ID if matched to existing identity
71    pub resolved_id: Option<String>,
72}
73
74/// Type of extracted entity
75#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
76#[serde(rename_all = "snake_case")]
77pub enum ExtractedEntityType {
78    /// @mention style reference
79    Mention,
80    /// Email address
81    Email,
82    /// URL/domain
83    Url,
84    /// Capitalized name pattern
85    Name,
86    /// Matched existing alias
87    KnownAlias,
88}
89
90impl ExtractedEntityType {
91    /// Default confidence for this entity type
92    fn default_confidence(&self) -> f32 {
93        match self {
94            ExtractedEntityType::Mention => 0.9,
95            ExtractedEntityType::Email => 0.95,
96            ExtractedEntityType::Url => 0.7,
97            ExtractedEntityType::Name => 0.5,
98            ExtractedEntityType::KnownAlias => 1.0,
99        }
100    }
101}
102
/// Tunable settings for [`extract_entities`].
#[derive(Debug, Clone)]
pub struct ExtractionConfig {
    /// Look for `@handle` mentions.
    pub extract_mentions: bool,
    /// Look for email addresses.
    pub extract_emails: bool,
    /// Look for URLs (only the domain is kept).
    pub extract_urls: bool,
    /// Look for capitalized multi-word names.
    pub extract_names: bool,
    /// Resolve entities against stored identity aliases (needs a connection).
    pub lookup_aliases: bool,
    /// Entities scoring below this confidence are dropped.
    pub min_confidence: f32,
    /// Upper bound on how many entities are returned.
    pub max_entities: usize,
}
121
122impl Default for ExtractionConfig {
123    fn default() -> Self {
124        Self {
125            extract_mentions: true,
126            extract_emails: true,
127            extract_urls: true,
128            extract_names: true,
129            lookup_aliases: true,
130            min_confidence: MIN_CONFIDENCE,
131            max_entities: MAX_ENTITIES_PER_TEXT,
132        }
133    }
134}
135
136/// Result of entity extraction
137#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct ExtractionResult {
139    /// Extracted entities, deduplicated and sorted by position
140    pub entities: Vec<ExtractedEntity>,
141    /// Total mentions found (before dedup)
142    pub total_mentions: usize,
143    /// Number of entities resolved to existing identities
144    pub resolved_count: usize,
145}
146
147/// Extract entities from text content.
148///
149/// This function never panics. Invalid input returns empty results.
150///
151/// # Arguments
152/// * `content` - Text to extract entities from
153/// * `config` - Extraction configuration
154/// * `conn` - Optional database connection for alias lookup
155///
156/// # Returns
157/// Extraction result with deduplicated, sorted entities
158#[instrument(skip(content, config, conn), fields(content_len = content.len()))]
159pub fn extract_entities(
160    content: &str,
161    config: &ExtractionConfig,
162    conn: Option<&Connection>,
163) -> ExtractionResult {
164    // Handle empty/whitespace input
165    let content = content.trim();
166    if content.is_empty() {
167        return ExtractionResult {
168            entities: vec![],
169            total_mentions: 0,
170            resolved_count: 0,
171        };
172    }
173
174    // Use HashMap for deduplication (normalized -> entity)
175    let mut entities_map: HashMap<String, ExtractedEntity> = HashMap::new();
176    let mut total_mentions = 0;
177
178    // Extract @mentions
179    if config.extract_mentions {
180        for cap in MENTION_PATTERN.captures_iter(content) {
181            if let Some(m) = cap.get(1) {
182                let mention_text = format!("@{}", m.as_str());
183                let normalized = normalize_alias(&mention_text);
184                let position = cap.get(0).map(|c| c.start()).unwrap_or(0);
185
186                add_or_increment(
187                    &mut entities_map,
188                    mention_text,
189                    normalized,
190                    ExtractedEntityType::Mention,
191                    position,
192                );
193                total_mentions += 1;
194            }
195
196            // Bound check
197            if entities_map.len() >= config.max_entities {
198                break;
199            }
200        }
201    }
202
203    // Extract emails
204    if config.extract_emails && entities_map.len() < config.max_entities {
205        for cap in EMAIL_PATTERN.find_iter(content) {
206            let email = cap.as_str();
207            let normalized = normalize_alias(email);
208
209            add_or_increment(
210                &mut entities_map,
211                email.to_string(),
212                normalized,
213                ExtractedEntityType::Email,
214                cap.start(),
215            );
216            total_mentions += 1;
217
218            if entities_map.len() >= config.max_entities {
219                break;
220            }
221        }
222    }
223
224    // Extract URLs (just domain part)
225    if config.extract_urls && entities_map.len() < config.max_entities {
226        for cap in URL_PATTERN.captures_iter(content) {
227            if let Some(domain) = cap.get(1) {
228                let domain_str = domain.as_str();
229                // Skip common domains
230                if !is_common_domain(domain_str) {
231                    let normalized = normalize_alias(domain_str);
232
233                    add_or_increment(
234                        &mut entities_map,
235                        domain_str.to_string(),
236                        normalized,
237                        ExtractedEntityType::Url,
238                        cap.get(0).map(|c| c.start()).unwrap_or(0),
239                    );
240                    total_mentions += 1;
241                }
242            }
243
244            if entities_map.len() >= config.max_entities {
245                break;
246            }
247        }
248    }
249
250    // Extract capitalized names
251    if config.extract_names && entities_map.len() < config.max_entities {
252        for cap in NAME_PATTERN.find_iter(content) {
253            let name = cap.as_str();
254            // Skip common phrases
255            if !is_common_phrase(name) {
256                let normalized = normalize_alias(name);
257
258                add_or_increment(
259                    &mut entities_map,
260                    name.to_string(),
261                    normalized,
262                    ExtractedEntityType::Name,
263                    cap.start(),
264                );
265                total_mentions += 1;
266            }
267
268            if entities_map.len() >= config.max_entities {
269                break;
270            }
271        }
272    }
273
274    // Resolve entities against existing identities
275    let mut resolved_count = 0;
276    if config.lookup_aliases {
277        if let Some(conn) = conn {
278            for entity in entities_map.values_mut() {
279                if let Ok(Some(identity)) = resolve_alias(conn, &entity.normalized) {
280                    entity.resolved_id = Some(identity.canonical_id);
281                    entity.entity_type = ExtractedEntityType::KnownAlias;
282                    entity.confidence = 1.0;
283                    resolved_count += 1;
284                }
285            }
286        }
287    }
288
289    // Filter by confidence and convert to vec
290    let mut entities: Vec<ExtractedEntity> = entities_map
291        .into_values()
292        .filter(|e| e.confidence >= config.min_confidence)
293        .collect();
294
295    // Sort by position for stable output
296    entities.sort_by_key(|e| e.position);
297
298    // Truncate to max
299    entities.truncate(config.max_entities);
300
301    debug!(
302        entity_count = entities.len(),
303        total_mentions, resolved_count, "Entity extraction complete"
304    );
305
306    ExtractionResult {
307        entities,
308        total_mentions,
309        resolved_count,
310    }
311}
312
313/// Add entity or increment count if exists
314fn add_or_increment(
315    map: &mut HashMap<String, ExtractedEntity>,
316    mention_text: String,
317    normalized: String,
318    entity_type: ExtractedEntityType,
319    position: usize,
320) {
321    if let Some(existing) = map.get_mut(&normalized) {
322        existing.count += 1;
323    } else {
324        map.insert(
325            normalized.clone(),
326            ExtractedEntity {
327                mention_text,
328                normalized,
329                entity_type,
330                confidence: entity_type.default_confidence(),
331                position,
332                count: 1,
333                resolved_id: None,
334            },
335        );
336    }
337}
338
/// Check if a domain is too common to be a meaningful entity.
///
/// Matching is ASCII case-insensitive and also covers subdomains of the listed
/// domains (`www.youtube.com`, `en.wikipedia.org`, `gist.github.com`) and a
/// fully-qualified trailing dot (`google.com.`). Hosts that merely end in the
/// same letters without a label boundary (`notgithub.com`) are not treated as common.
fn is_common_domain(domain: &str) -> bool {
    const COMMON: &[&str] = &[
        "google.com",
        "github.com",
        "stackoverflow.com",
        "wikipedia.org",
        "twitter.com",
        "x.com",
        "facebook.com",
        "youtube.com",
        "linkedin.com",
        "medium.com",
        "docs.rs",
        "crates.io",
        "rust-lang.org",
    ];

    // Compare as bytes so arbitrary (non-ASCII) input can never hit a
    // char-boundary panic when slicing.
    let host = domain.trim_end_matches('.').as_bytes();
    COMMON.iter().any(|common| {
        let common = common.as_bytes();
        if host.len() == common.len() {
            host.eq_ignore_ascii_case(common)
        } else if host.len() > common.len() {
            // Subdomain: must be `<label>.` immediately followed by the common domain.
            let split = host.len() - common.len();
            host[split - 1] == b'.' && host[split..].eq_ignore_ascii_case(common)
        } else {
            false
        }
    })
}
358
/// Check if a phrase is too common to be treated as a person's name.
///
/// Comparison is ASCII case-insensitive and word-by-word, so a stock phrase is
/// recognized regardless of the whitespace between its words. This matters because
/// `NAME_PATTERN` joins capitalized words with any `\s+`, including line breaks
/// and repeated spaces.
fn is_common_phrase(phrase: &str) -> bool {
    const COMMON: &[&str] = &[
        "New York",
        "Los Angeles",
        "San Francisco",
        "United States",
        "Open Source",
        "Machine Learning",
        "Artificial Intelligence",
        "The End",
        "The Start",
    ];

    COMMON.iter().any(|common| {
        let mut actual = phrase.split_whitespace();
        let mut expected = common.split_whitespace();
        loop {
            match (actual.next(), expected.next()) {
                (Some(a), Some(e)) if a.eq_ignore_ascii_case(e) => {}
                (None, None) => return true,
                _ => return false,
            }
        }
    })
}
374
375/// Auto-link entities found in a memory's content to identities.
376///
377/// This function:
378/// 1. Extracts entities from content
379/// 2. Creates new identities for unresolved entities (optional)
380/// 3. Links all resolved entities to the memory
381///
382/// # Arguments
383/// * `conn` - Database connection
384/// * `memory_id` - Memory to link entities to
385/// * `content` - Memory content to extract from
386/// * `auto_create` - Whether to create new identities for unresolved entities
387///
388/// # Returns
389/// Number of entities linked
390#[instrument(skip(conn, content), fields(memory_id, auto_create, content_len = content.len()))]
391pub fn auto_link_memory(
392    conn: &Connection,
393    memory_id: i64,
394    content: &str,
395    auto_create: bool,
396) -> Result<usize> {
397    use crate::storage::identity_links::{
398        create_identity, link_identity_to_memory, CreateIdentityInput, IdentityType,
399    };
400
401    let config = ExtractionConfig::default();
402    let result = extract_entities(content, &config, Some(conn));
403
404    let mut linked_count = 0;
405
406    for entity in result.entities {
407        let canonical_id = if let Some(id) = entity.resolved_id {
408            // Already resolved to existing identity
409            id
410        } else if auto_create {
411            // Create new identity for this entity
412            let entity_type = match entity.entity_type {
413                ExtractedEntityType::Email => IdentityType::Person,
414                ExtractedEntityType::Mention => IdentityType::Person,
415                ExtractedEntityType::Url => IdentityType::Organization,
416                ExtractedEntityType::Name => IdentityType::Person,
417                ExtractedEntityType::KnownAlias => IdentityType::Other,
418            };
419
420            let input = CreateIdentityInput {
421                canonical_id: format!("auto:{}", entity.normalized),
422                display_name: entity.mention_text.clone(),
423                entity_type,
424                description: Some("Auto-created from entity extraction".to_string()),
425                metadata: HashMap::new(),
426                aliases: vec![entity.mention_text.clone()],
427            };
428
429            match create_identity(conn, &input) {
430                Ok(identity) => identity.canonical_id,
431                Err(_) => continue, // Skip if creation fails (e.g., already exists)
432            }
433        } else {
434            continue; // Skip unresolved entities
435        };
436
437        // Link to memory
438        if link_identity_to_memory(conn, memory_id, &canonical_id, Some(&entity.mention_text))
439            .is_ok()
440        {
441            linked_count += 1;
442        }
443    }
444
445    Ok(linked_count)
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_mentions() {
        let config = ExtractionConfig {
            lookup_aliases: false,
            ..Default::default()
        };

        let result = extract_entities("Hello @alice and @bob-smith!", &config, None);

        assert_eq!(result.entities.len(), 2);
        assert_eq!(result.entities[0].mention_text, "@alice");
        assert_eq!(result.entities[1].mention_text, "@bob-smith");
    }

    #[test]
    fn test_extract_emails() {
        let config = ExtractionConfig {
            lookup_aliases: false,
            extract_names: false,
            extract_mentions: false,
            extract_urls: false,
            extract_emails: true,
            ..Default::default()
        };

        let result = extract_entities("Contact us at hello@example.com", &config, None);

        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.entities[0].mention_text, "hello@example.com");
        assert_eq!(result.entities[0].entity_type, ExtractedEntityType::Email);
    }

    #[test]
    fn test_extract_names() {
        let config = ExtractionConfig {
            lookup_aliases: false,
            ..Default::default()
        };

        let result = extract_entities("I met John Smith yesterday", &config, None);

        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.entities[0].mention_text, "John Smith");
        assert_eq!(result.entities[0].entity_type, ExtractedEntityType::Name);
    }

    #[test]
    fn test_empty_input() {
        let config = ExtractionConfig::default();
        let result = extract_entities("", &config, None);
        assert!(result.entities.is_empty());

        let result = extract_entities("   ", &config, None);
        assert!(result.entities.is_empty());
    }

    #[test]
    fn test_deduplication() {
        let config = ExtractionConfig {
            lookup_aliases: false,
            ..Default::default()
        };

        let result = extract_entities("@alice said hello. @alice waved.", &config, None);

        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.entities[0].count, 2);
        assert_eq!(result.total_mentions, 2);
    }

    #[test]
    fn test_max_entities_bound() {
        let config = ExtractionConfig {
            lookup_aliases: false,
            max_entities: 2,
            ..Default::default()
        };

        // MENTION_PATTERN needs handles of at least two characters, so use
        // two-letter handles to actually exceed the bound.
        let result = extract_entities("@aa @bb @cc @dd @ee", &config, None);

        assert_eq!(result.entities.len(), 2);
        assert_eq!(result.entities[0].mention_text, "@aa");
        assert_eq!(result.entities[1].mention_text, "@bb");
    }

    #[test]
    fn test_normalization_invariant() {
        // Invariant: normalize_alias is idempotent
        let inputs = vec![
            "@Alice",
            "  bob  ",
            "@CHARLIE",
            "user@email.com",
            "  @mixed  CASE  ",
        ];

        for input in inputs {
            let once = normalize_alias(input);
            let twice = normalize_alias(&once);
            assert_eq!(
                once, twice,
                "Normalization should be idempotent for: {}",
                input
            );
        }
    }

    #[test]
    fn test_never_panics_on_bad_input() {
        let config = ExtractionConfig {
            lookup_aliases: false,
            ..Default::default()
        };

        // Pre-allocate strings that need longer lifetime
        let long_a = "a".repeat(10000);
        let long_at = "@".repeat(1000);

        // Various edge cases that shouldn't panic
        let inputs: Vec<&str> = vec![
            "",
            "   ",
            "@",
            "@@@@",
            "@a",
            "a@",
            "http://",
            "https://",
            &long_a,
            &long_at,
            "\0\0\0",
            "emoji: 🎉🎊🎁",
            "unicode: 日本語 中文 한국어",
        ];

        for input in inputs {
            let result = extract_entities(input, &config, None);
            // Just verify no panic
            let _ = result.entities.len();
        }
    }
}