1use std::collections::HashMap;
24
25use once_cell::sync::Lazy;
26use regex::Regex;
27use rusqlite::Connection;
28use serde::{Deserialize, Serialize};
29use tracing::{debug, instrument, warn};
30
31use crate::error::Result;
32use crate::storage::identity_links::{normalize_alias, resolve_alias};
33
/// Default cap on the number of distinct entities collected from one text
/// (used as `ExtractionConfig::max_entities` by `ExtractionConfig::default`).
const MAX_ENTITIES_PER_TEXT: usize = 100;

/// Default minimum confidence an entity needs to be returned
/// (used as `ExtractionConfig::min_confidence` by `ExtractionConfig::default`).
const MIN_CONFIDENCE: f32 = 0.3;
39
/// `@handle` mentions: `@`, an ASCII letter, then 1-30 more ASCII letters, digits, `_` or `-`.
/// Group 1 is the handle without the `@`. The pattern has no left boundary, so it also matches
/// inside email addresses (e.g. `@example` in `hello@example.com`). It has no right boundary
/// either, so a longer handle is cut off after 31 characters.
static MENTION_PATTERN: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"@([a-zA-Z][a-zA-Z0-9_-]{1,30})").expect("valid regex"));

/// Email addresses: a local part of `[a-zA-Z0-9._%+-]`, `@`, and a dotted domain ending in a
/// TLD of two or more ASCII letters.
static EMAIL_PATTERN: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}").expect("valid regex")
});

/// `http`/`https` URLs. Group 1 is the host, limited to ASCII letters, digits, `.` and `-`, so a
/// port (`:`) or userinfo (`@`) ends it. Because `.` is allowed, a sentence-final period right
/// after the host is captured as part of it.
static URL_PATTERN: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"https?://([a-zA-Z0-9.-]+)(?:/[^\s]*)?").expect("valid regex"));

/// Candidate names: two or more whitespace-separated words, each an ASCII uppercase letter
/// followed by one or more ASCII lowercase letters (e.g. `John Smith`).
static NAME_PATTERN: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)\b").expect("valid regex"));
54
/// A single entity found in a piece of text, deduplicated by its normalized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractedEntity {
    /// Text recorded for the entity: `@handle` for mentions, the full address for emails,
    /// the host/domain for URLs, and the matched phrase for names.
    pub mention_text: String,
    /// `mention_text` passed through `normalize_alias`; used as the deduplication key and as
    /// the key for alias lookup.
    pub normalized: String,
    /// Extractor that produced the entity, or `KnownAlias` once it resolved to an identity.
    pub entity_type: ExtractedEntityType,
    /// The entity type's default confidence, or `1.0` once resolved to a known identity.
    pub confidence: f32,
    /// Byte offset of a recorded occurrence within the trimmed input; results are sorted by it.
    pub position: usize,
    /// Number of occurrences merged into this entity.
    pub count: usize,
    /// Canonical id of the identity this entity resolved to, if alias lookup succeeded.
    pub resolved_id: Option<String>,
}
73
/// Kind of an [`ExtractedEntity`]; serialized in `snake_case` (e.g. `known_alias`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ExtractedEntityType {
    /// An `@handle` mention.
    Mention,
    /// An email address.
    Email,
    /// The host/domain of an `http`/`https` URL.
    Url,
    /// A run of two or more capitalized words, treated as a possible proper name.
    Name,
    /// An extracted entity that alias lookup resolved to a stored identity.
    KnownAlias,
}
89
90impl ExtractedEntityType {
91 fn default_confidence(&self) -> f32 {
93 match self {
94 ExtractedEntityType::Mention => 0.9,
95 ExtractedEntityType::Email => 0.95,
96 ExtractedEntityType::Url => 0.7,
97 ExtractedEntityType::Name => 0.5,
98 ExtractedEntityType::KnownAlias => 1.0,
99 }
100 }
101}
102
/// Settings for [`extract_entities`]: which extractors run and how results are filtered.
#[derive(Debug, Clone)]
pub struct ExtractionConfig {
    /// Extract `@handle` mentions.
    pub extract_mentions: bool,
    /// Extract email addresses.
    pub extract_emails: bool,
    /// Extract the domain of `http`/`https` URLs; common sites are skipped.
    pub extract_urls: bool,
    /// Extract runs of capitalized words as candidate names; common phrases are skipped.
    pub extract_names: bool,
    /// Resolve extracted entities against stored identity aliases. Only takes effect when a
    /// database connection is passed to `extract_entities`.
    pub lookup_aliases: bool,
    /// Entities with a confidence below this value are dropped from the result.
    pub min_confidence: f32,
    /// Upper bound on the number of distinct entities collected and returned.
    pub max_entities: usize,
}
121
122impl Default for ExtractionConfig {
123 fn default() -> Self {
124 Self {
125 extract_mentions: true,
126 extract_emails: true,
127 extract_urls: true,
128 extract_names: true,
129 lookup_aliases: true,
130 min_confidence: MIN_CONFIDENCE,
131 max_entities: MAX_ENTITIES_PER_TEXT,
132 }
133 }
134}
135
/// Output of [`extract_entities`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionResult {
    /// Extracted entities, ordered by their position in the (trimmed) input.
    pub entities: Vec<ExtractedEntity>,
    /// Number of matches counted across all extractors, including repeated occurrences.
    pub total_mentions: usize,
    /// Number of entities that were resolved to a known identity.
    pub resolved_count: usize,
}
146
147#[instrument(skip(content, config, conn), fields(content_len = content.len()))]
159pub fn extract_entities(
160 content: &str,
161 config: &ExtractionConfig,
162 conn: Option<&Connection>,
163) -> ExtractionResult {
164 let content = content.trim();
166 if content.is_empty() {
167 return ExtractionResult {
168 entities: vec![],
169 total_mentions: 0,
170 resolved_count: 0,
171 };
172 }
173
174 let mut entities_map: HashMap<String, ExtractedEntity> = HashMap::new();
176 let mut total_mentions = 0;
177
178 if config.extract_mentions {
180 for cap in MENTION_PATTERN.captures_iter(content) {
181 if let Some(m) = cap.get(1) {
182 let mention_text = format!("@{}", m.as_str());
183 let normalized = normalize_alias(&mention_text);
184 let position = cap.get(0).map(|c| c.start()).unwrap_or(0);
185
186 add_or_increment(
187 &mut entities_map,
188 mention_text,
189 normalized,
190 ExtractedEntityType::Mention,
191 position,
192 );
193 total_mentions += 1;
194 }
195
196 if entities_map.len() >= config.max_entities {
198 break;
199 }
200 }
201 }
202
203 if config.extract_emails && entities_map.len() < config.max_entities {
205 for cap in EMAIL_PATTERN.find_iter(content) {
206 let email = cap.as_str();
207 let normalized = normalize_alias(email);
208
209 add_or_increment(
210 &mut entities_map,
211 email.to_string(),
212 normalized,
213 ExtractedEntityType::Email,
214 cap.start(),
215 );
216 total_mentions += 1;
217
218 if entities_map.len() >= config.max_entities {
219 break;
220 }
221 }
222 }
223
224 if config.extract_urls && entities_map.len() < config.max_entities {
226 for cap in URL_PATTERN.captures_iter(content) {
227 if let Some(domain) = cap.get(1) {
228 let domain_str = domain.as_str();
229 if !is_common_domain(domain_str) {
231 let normalized = normalize_alias(domain_str);
232
233 add_or_increment(
234 &mut entities_map,
235 domain_str.to_string(),
236 normalized,
237 ExtractedEntityType::Url,
238 cap.get(0).map(|c| c.start()).unwrap_or(0),
239 );
240 total_mentions += 1;
241 }
242 }
243
244 if entities_map.len() >= config.max_entities {
245 break;
246 }
247 }
248 }
249
250 if config.extract_names && entities_map.len() < config.max_entities {
252 for cap in NAME_PATTERN.find_iter(content) {
253 let name = cap.as_str();
254 if !is_common_phrase(name) {
256 let normalized = normalize_alias(name);
257
258 add_or_increment(
259 &mut entities_map,
260 name.to_string(),
261 normalized,
262 ExtractedEntityType::Name,
263 cap.start(),
264 );
265 total_mentions += 1;
266 }
267
268 if entities_map.len() >= config.max_entities {
269 break;
270 }
271 }
272 }
273
274 let mut resolved_count = 0;
276 if config.lookup_aliases {
277 if let Some(conn) = conn {
278 for entity in entities_map.values_mut() {
279 if let Ok(Some(identity)) = resolve_alias(conn, &entity.normalized) {
280 entity.resolved_id = Some(identity.canonical_id);
281 entity.entity_type = ExtractedEntityType::KnownAlias;
282 entity.confidence = 1.0;
283 resolved_count += 1;
284 }
285 }
286 }
287 }
288
289 let mut entities: Vec<ExtractedEntity> = entities_map
291 .into_values()
292 .filter(|e| e.confidence >= config.min_confidence)
293 .collect();
294
295 entities.sort_by_key(|e| e.position);
297
298 entities.truncate(config.max_entities);
300
301 debug!(
302 entity_count = entities.len(),
303 total_mentions, resolved_count, "Entity extraction complete"
304 );
305
306 ExtractionResult {
307 entities,
308 total_mentions,
309 resolved_count,
310 }
311}
312
313fn add_or_increment(
315 map: &mut HashMap<String, ExtractedEntity>,
316 mention_text: String,
317 normalized: String,
318 entity_type: ExtractedEntityType,
319 position: usize,
320) {
321 if let Some(existing) = map.get_mut(&normalized) {
322 existing.count += 1;
323 } else {
324 map.insert(
325 normalized.clone(),
326 ExtractedEntity {
327 mention_text,
328 normalized,
329 entity_type,
330 confidence: entity_type.default_confidence(),
331 position,
332 count: 1,
333 resolved_id: None,
334 },
335 );
336 }
337}
338
/// Returns `true` if `domain` belongs to a widely used site that is not worth tracking as an
/// entity.
///
/// Matching is ASCII case-insensitive and ignores trailing `.` characters, such as a period
/// left over from sentence punctuation. It also covers subdomains, so `www.google.com` and
/// `gist.github.com` count as common, while `notgoogle.com` does not.
fn is_common_domain(domain: &str) -> bool {
    const COMMON: &[&str] = &[
        "google.com",
        "github.com",
        "stackoverflow.com",
        "wikipedia.org",
        "twitter.com",
        "x.com",
        "facebook.com",
        "youtube.com",
        "linkedin.com",
        "medium.com",
        "docs.rs",
        "crates.io",
        "rust-lang.org",
    ];

    // COMMON entries are lowercase, so one ASCII lowercase pass makes the comparison
    // case-insensitive.
    let domain = domain.trim_end_matches('.').to_ascii_lowercase();
    COMMON.iter().any(|common| {
        domain == *common
            || domain
                .strip_suffix(*common)
                .map_or(false, |head| head.ends_with('.'))
    })
}
358
/// Returns `true` for capitalized multi-word phrases that look like personal names but are
/// well-known places or terms. The comparison is an exact, ASCII case-insensitive match.
fn is_common_phrase(phrase: &str) -> bool {
    const STOP_PHRASES: [&str; 9] = [
        "New York",
        "Los Angeles",
        "San Francisco",
        "United States",
        "Open Source",
        "Machine Learning",
        "Artificial Intelligence",
        "The End",
        "The Start",
    ];

    for candidate in STOP_PHRASES.iter() {
        if candidate.eq_ignore_ascii_case(phrase) {
            return true;
        }
    }
    false
}
374
375#[instrument(skip(conn, content), fields(memory_id, auto_create, content_len = content.len()))]
391pub fn auto_link_memory(
392 conn: &Connection,
393 memory_id: i64,
394 content: &str,
395 auto_create: bool,
396) -> Result<usize> {
397 use crate::storage::identity_links::{
398 create_identity, link_identity_to_memory, CreateIdentityInput, IdentityType,
399 };
400
401 let config = ExtractionConfig::default();
402 let result = extract_entities(content, &config, Some(conn));
403
404 let mut linked_count = 0;
405
406 for entity in result.entities {
407 let canonical_id = if let Some(id) = entity.resolved_id {
408 id
410 } else if auto_create {
411 let entity_type = match entity.entity_type {
413 ExtractedEntityType::Email => IdentityType::Person,
414 ExtractedEntityType::Mention => IdentityType::Person,
415 ExtractedEntityType::Url => IdentityType::Organization,
416 ExtractedEntityType::Name => IdentityType::Person,
417 ExtractedEntityType::KnownAlias => IdentityType::Other,
418 };
419
420 let input = CreateIdentityInput {
421 canonical_id: format!("auto:{}", entity.normalized),
422 display_name: entity.mention_text.clone(),
423 entity_type,
424 description: Some("Auto-created from entity extraction".to_string()),
425 metadata: HashMap::new(),
426 aliases: vec![entity.mention_text.clone()],
427 };
428
429 match create_identity(conn, &input) {
430 Ok(identity) => identity.canonical_id,
431 Err(_) => continue, }
433 } else {
434 continue; };
436
437 if link_identity_to_memory(conn, memory_id, &canonical_id, Some(&entity.mention_text))
439 .is_ok()
440 {
441 linked_count += 1;
442 }
443 }
444
445 Ok(linked_count)
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_mentions() {
        let config = ExtractionConfig {
            lookup_aliases: false,
            ..Default::default()
        };

        let result = extract_entities("Hello @alice and @bob-smith!", &config, None);

        assert_eq!(result.entities.len(), 2);
        assert_eq!(result.entities[0].mention_text, "@alice");
        assert_eq!(result.entities[1].mention_text, "@bob-smith");
    }

    #[test]
    fn test_extract_emails() {
        let config = ExtractionConfig {
            lookup_aliases: false,
            extract_names: false,
            extract_mentions: false,
            extract_urls: false,
            extract_emails: true,
            ..Default::default()
        };

        let result = extract_entities("Contact us at hello@example.com", &config, None);

        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.entities[0].mention_text, "hello@example.com");
        assert_eq!(result.entities[0].entity_type, ExtractedEntityType::Email);
    }

    #[test]
    fn test_extract_names() {
        let config = ExtractionConfig {
            lookup_aliases: false,
            ..Default::default()
        };

        let result = extract_entities("I met John Smith yesterday", &config, None);

        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.entities[0].mention_text, "John Smith");
        assert_eq!(result.entities[0].entity_type, ExtractedEntityType::Name);
    }

    #[test]
    fn test_empty_input() {
        let config = ExtractionConfig::default();
        let result = extract_entities("", &config, None);
        assert!(result.entities.is_empty());

        let result = extract_entities("   ", &config, None);
        assert!(result.entities.is_empty());
    }

    #[test]
    fn test_deduplication() {
        let config = ExtractionConfig {
            lookup_aliases: false,
            ..Default::default()
        };

        let result = extract_entities("@alice said hello. @alice waved.", &config, None);

        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.entities[0].count, 2);
        assert_eq!(result.total_mentions, 2);
    }

    #[test]
    fn test_max_entities_bound() {
        let config = ExtractionConfig {
            lookup_aliases: false,
            max_entities: 2,
            ..Default::default()
        };

        // Handles need at least two characters to match MENTION_PATTERN; single-letter
        // handles ("@a") would never be extracted and the bound would go untested.
        let result = extract_entities("@aa @bb @cc @dd @ee", &config, None);

        assert_eq!(result.entities.len(), 2);
        assert_eq!(result.entities[0].mention_text, "@aa");
        assert_eq!(result.entities[1].mention_text, "@bb");
    }

    #[test]
    fn test_common_domains_and_phrases() {
        assert!(is_common_domain("GitHub.com"));
        assert!(!is_common_domain("example.com"));
        assert!(is_common_phrase("new york"));
        assert!(!is_common_phrase("John Smith"));
    }

    #[test]
    fn test_normalization_invariant() {
        let inputs = vec![
            "@Alice",
            "  bob  ",
            "@CHARLIE",
            "user@email.com",
            "  @mixed CASE  ",
        ];

        for input in inputs {
            let once = normalize_alias(input);
            let twice = normalize_alias(&once);
            assert_eq!(
                once, twice,
                "Normalization should be idempotent for: {}",
                input
            );
        }
    }

    #[test]
    fn test_never_panics_on_bad_input() {
        let config = ExtractionConfig {
            lookup_aliases: false,
            ..Default::default()
        };

        let long_a = "a".repeat(10000);
        let long_at = "@".repeat(1000);

        let inputs: Vec<&str> = vec![
            "",
            " ",
            "@",
            "@@@@",
            "@a",
            "a@",
            "http://",
            "https://",
            &long_a,
            &long_at,
            "\0\0\0",
            "emoji: 😀🎉🚀",
            "unicode: 日本語 中文 한국어",
        ];

        for input in inputs {
            let result = extract_entities(input, &config, None);
            assert!(result.entities.len() <= config.max_entities);
        }
    }
}