1use std::collections::HashMap;
13
14use chrono::{DateTime, Utc};
15use rusqlite::{params, Connection};
16use serde::{Deserialize, Serialize};
17
18use crate::error::{EngramError, Result};
19
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
22#[serde(rename_all = "lowercase")]
23pub enum IdentityType {
24 #[default]
25 Person,
26 Organization,
27 Project,
28 Tool,
29 Concept,
30 Other,
31}
32
33impl IdentityType {
34 pub fn as_str(&self) -> &'static str {
35 match self {
36 IdentityType::Person => "person",
37 IdentityType::Organization => "organization",
38 IdentityType::Project => "project",
39 IdentityType::Tool => "tool",
40 IdentityType::Concept => "concept",
41 IdentityType::Other => "other",
42 }
43 }
44}
45
46impl std::str::FromStr for IdentityType {
47 type Err = String;
48
49 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
50 match s.to_lowercase().as_str() {
51 "person" => Ok(IdentityType::Person),
52 "organization" | "org" => Ok(IdentityType::Organization),
53 "project" => Ok(IdentityType::Project),
54 "tool" => Ok(IdentityType::Tool),
55 "concept" => Ok(IdentityType::Concept),
56 "other" => Ok(IdentityType::Other),
57 _ => Err(format!("Unknown identity type: {}", s)),
58 }
59 }
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct Identity {
65 pub id: i64,
66 pub canonical_id: String,
67 pub display_name: String,
68 pub entity_type: IdentityType,
69 pub description: Option<String>,
70 pub metadata: HashMap<String, serde_json::Value>,
71 pub created_at: DateTime<Utc>,
72 pub updated_at: DateTime<Utc>,
73 #[serde(default)]
74 pub aliases: Vec<IdentityAlias>,
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct IdentityAlias {
80 pub id: i64,
81 pub canonical_id: String,
82 pub alias: String,
83 pub alias_normalized: String,
84 pub source: Option<String>,
85 pub confidence: f32,
86 pub created_at: DateTime<Utc>,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct MemoryIdentityLink {
92 pub id: i64,
93 pub memory_id: i64,
94 pub canonical_id: String,
95 pub mention_text: Option<String>,
96 pub mention_count: i32,
97 pub created_at: DateTime<Utc>,
98}
99
100#[derive(Debug, Clone)]
102pub struct CreateIdentityInput {
103 pub canonical_id: String,
104 pub display_name: String,
105 pub entity_type: IdentityType,
106 pub description: Option<String>,
107 pub metadata: HashMap<String, serde_json::Value>,
108 pub aliases: Vec<String>,
109}
110
/// Produces the canonical comparison form of an alias.
///
/// The input is lowercased, every run of whitespace is collapsed to a single
/// space (leading and trailing whitespace disappears), and finally any
/// non-alphanumeric characters at either end are stripped. Characters in the
/// interior, such as the hyphen in `project-x`, are left untouched. The
/// result may be empty.
pub fn normalize_alias(s: &str) -> String {
    let lowered = s.to_lowercase();

    // Rebuild the string word by word so whitespace runs become one space.
    let mut collapsed = String::with_capacity(lowered.len());
    for word in lowered.split_whitespace() {
        if !collapsed.is_empty() {
            collapsed.push(' ');
        }
        collapsed.push_str(word);
    }

    // Strip decorations such as a leading '@' or '#' from both ends.
    collapsed
        .trim_matches(|c: char| !c.is_alphanumeric())
        .to_owned()
}
128
129pub fn create_identity(conn: &Connection, input: &CreateIdentityInput) -> Result<Identity> {
131 let now = Utc::now();
132 let now_str = now.to_rfc3339();
133 let metadata_json = serde_json::to_string(&input.metadata)?;
134
135 conn.execute(
136 r#"
137 INSERT INTO identities (canonical_id, display_name, entity_type, description, metadata, created_at, updated_at)
138 VALUES (?, ?, ?, ?, ?, ?, ?)
139 "#,
140 params![
141 input.canonical_id,
142 input.display_name,
143 input.entity_type.as_str(),
144 input.description,
145 metadata_json,
146 now_str,
147 now_str,
148 ],
149 )?;
150
151 let _id = conn.last_insert_rowid();
152
153 for alias in &input.aliases {
155 add_alias_internal(conn, &input.canonical_id, alias, None)?;
156 }
157
158 let _ = add_alias_internal(
160 conn,
161 &input.canonical_id,
162 &input.display_name,
163 Some("display_name"),
164 );
165
166 get_identity(conn, &input.canonical_id)
167}
168
169pub fn get_identity(conn: &Connection, canonical_id: &str) -> Result<Identity> {
171 let identity = conn.query_row(
172 r#"
173 SELECT id, canonical_id, display_name, entity_type, description, metadata, created_at, updated_at
174 FROM identities WHERE canonical_id = ?
175 "#,
176 params![canonical_id],
177 |row| {
178 let entity_type_str: String = row.get(3)?;
179 let metadata_str: String = row.get(5)?;
180 let created_at: String = row.get(6)?;
181 let updated_at: String = row.get(7)?;
182
183 Ok(Identity {
184 id: row.get(0)?,
185 canonical_id: row.get(1)?,
186 display_name: row.get(2)?,
187 entity_type: entity_type_str.parse().unwrap_or_default(),
188 description: row.get(4)?,
189 metadata: serde_json::from_str(&metadata_str).unwrap_or_default(),
190 created_at: DateTime::parse_from_rfc3339(&created_at)
191 .map(|dt| dt.with_timezone(&Utc))
192 .unwrap_or_else(|_| Utc::now()),
193 updated_at: DateTime::parse_from_rfc3339(&updated_at)
194 .map(|dt| dt.with_timezone(&Utc))
195 .unwrap_or_else(|_| Utc::now()),
196 aliases: vec![],
197 })
198 },
199 ).map_err(|_| EngramError::NotFound(0))?;
200
201 let mut identity = identity;
203 identity.aliases = get_aliases(conn, canonical_id)?;
204
205 Ok(identity)
206}
207
208pub fn update_identity(
210 conn: &Connection,
211 canonical_id: &str,
212 display_name: Option<&str>,
213 description: Option<&str>,
214 entity_type: Option<IdentityType>,
215) -> Result<Identity> {
216 let now = Utc::now().to_rfc3339();
217
218 let mut updates = vec!["updated_at = ?".to_string()];
220 let mut params: Vec<Box<dyn rusqlite::ToSql>> = vec![Box::new(now)];
221
222 if let Some(name) = display_name {
223 updates.push("display_name = ?".to_string());
224 params.push(Box::new(name.to_string()));
225 }
226
227 if let Some(desc) = description {
228 updates.push("description = ?".to_string());
229 params.push(Box::new(desc.to_string()));
230 }
231
232 if let Some(et) = entity_type {
233 updates.push("entity_type = ?".to_string());
234 params.push(Box::new(et.as_str().to_string()));
235 }
236
237 params.push(Box::new(canonical_id.to_string()));
238
239 let sql = format!(
240 "UPDATE identities SET {} WHERE canonical_id = ?",
241 updates.join(", ")
242 );
243
244 let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|b| b.as_ref()).collect();
245 let affected = conn.execute(&sql, param_refs.as_slice())?;
246
247 if affected == 0 {
248 return Err(EngramError::NotFound(0));
249 }
250
251 get_identity(conn, canonical_id)
252}
253
254pub fn delete_identity(conn: &Connection, canonical_id: &str) -> Result<()> {
256 let affected = conn.execute(
257 "DELETE FROM identities WHERE canonical_id = ?",
258 params![canonical_id],
259 )?;
260
261 if affected == 0 {
262 return Err(EngramError::NotFound(0));
263 }
264
265 Ok(())
266}
267
268fn add_alias_internal(
274 conn: &Connection,
275 canonical_id: &str,
276 alias: &str,
277 source: Option<&str>,
278) -> Result<IdentityAlias> {
279 let normalized = normalize_alias(alias);
280
281 if normalized.is_empty() {
282 return Err(EngramError::InvalidInput(
283 "Alias cannot be empty".to_string(),
284 ));
285 }
286
287 let now = Utc::now();
288 let now_str = now.to_rfc3339();
289
290 let existing: Option<String> = conn
292 .query_row(
293 "SELECT canonical_id FROM identity_aliases WHERE alias_normalized = ?",
294 params![normalized],
295 |row| row.get(0),
296 )
297 .ok();
298
299 if let Some(existing_canonical) = existing {
300 if existing_canonical != canonical_id {
301 return Err(EngramError::Conflict(format!(
302 "Alias '{}' already belongs to identity '{}'",
303 alias, existing_canonical
304 )));
305 }
306 if let Some(src) = source {
308 conn.execute(
309 "UPDATE identity_aliases SET source = ? WHERE alias_normalized = ?",
310 params![src, normalized],
311 )?;
312 }
313 } else {
314 conn.execute(
316 r#"
317 INSERT INTO identity_aliases (canonical_id, alias, alias_normalized, source, created_at)
318 VALUES (?, ?, ?, ?, ?)
319 "#,
320 params![canonical_id, alias, normalized, source, now_str],
321 )?;
322 }
323
324 conn.query_row(
326 r#"
327 SELECT id, canonical_id, alias, alias_normalized, source, confidence, created_at
328 FROM identity_aliases WHERE alias_normalized = ?
329 "#,
330 params![normalized],
331 |row| {
332 let created_at: String = row.get(6)?;
333 Ok(IdentityAlias {
334 id: row.get(0)?,
335 canonical_id: row.get(1)?,
336 alias: row.get(2)?,
337 alias_normalized: row.get(3)?,
338 source: row.get(4)?,
339 confidence: row.get(5)?,
340 created_at: DateTime::parse_from_rfc3339(&created_at)
341 .map(|dt| dt.with_timezone(&Utc))
342 .unwrap_or_else(|_| Utc::now()),
343 })
344 },
345 )
346 .map_err(EngramError::Database)
347}
348
349pub fn add_alias(
351 conn: &Connection,
352 canonical_id: &str,
353 alias: &str,
354 source: Option<&str>,
355) -> Result<IdentityAlias> {
356 let _ = get_identity(conn, canonical_id)?;
358 add_alias_internal(conn, canonical_id, alias, source)
359}
360
361pub fn remove_alias(conn: &Connection, alias: &str) -> Result<()> {
363 let normalized = normalize_alias(alias);
364
365 let affected = conn.execute(
366 "DELETE FROM identity_aliases WHERE alias_normalized = ?",
367 params![normalized],
368 )?;
369
370 if affected == 0 {
371 return Err(EngramError::NotFound(0));
372 }
373
374 Ok(())
375}
376
377pub fn get_aliases(conn: &Connection, canonical_id: &str) -> Result<Vec<IdentityAlias>> {
379 let mut stmt = conn.prepare(
380 r#"
381 SELECT id, canonical_id, alias, alias_normalized, source, confidence, created_at
382 FROM identity_aliases WHERE canonical_id = ?
383 ORDER BY created_at
384 "#,
385 )?;
386
387 let aliases = stmt
388 .query_map(params![canonical_id], |row| {
389 let created_at: String = row.get(6)?;
390 Ok(IdentityAlias {
391 id: row.get(0)?,
392 canonical_id: row.get(1)?,
393 alias: row.get(2)?,
394 alias_normalized: row.get(3)?,
395 source: row.get(4)?,
396 confidence: row.get(5)?,
397 created_at: DateTime::parse_from_rfc3339(&created_at)
398 .map(|dt| dt.with_timezone(&Utc))
399 .unwrap_or_else(|_| Utc::now()),
400 })
401 })?
402 .filter_map(|r| r.ok())
403 .collect();
404
405 Ok(aliases)
406}
407
408pub fn resolve_alias(conn: &Connection, alias: &str) -> Result<Option<Identity>> {
410 let normalized = normalize_alias(alias);
411
412 let canonical_id: Option<String> = conn
413 .query_row(
414 "SELECT canonical_id FROM identity_aliases WHERE alias_normalized = ?",
415 params![normalized],
416 |row| row.get(0),
417 )
418 .ok();
419
420 match canonical_id {
421 Some(cid) => Ok(Some(get_identity(conn, &cid)?)),
422 None => Ok(None),
423 }
424}
425
426pub fn link_identity_to_memory(
428 conn: &Connection,
429 memory_id: i64,
430 canonical_id: &str,
431 mention_text: Option<&str>,
432) -> Result<MemoryIdentityLink> {
433 let _ = get_identity(conn, canonical_id)?;
435
436 let now = Utc::now().to_rfc3339();
437
438 conn.execute(
439 r#"
440 INSERT INTO memory_identity_links (memory_id, canonical_id, mention_text, mention_count, created_at)
441 VALUES (?, ?, ?, 1, ?)
442 ON CONFLICT(memory_id, canonical_id) DO UPDATE SET
443 mention_count = memory_identity_links.mention_count + 1,
444 mention_text = COALESCE(excluded.mention_text, memory_identity_links.mention_text)
445 "#,
446 params![memory_id, canonical_id, mention_text, now],
447 )?;
448
449 conn.query_row(
450 r#"
451 SELECT id, memory_id, canonical_id, mention_text, mention_count, created_at
452 FROM memory_identity_links WHERE memory_id = ? AND canonical_id = ?
453 "#,
454 params![memory_id, canonical_id],
455 |row| {
456 let created_at: String = row.get(5)?;
457 Ok(MemoryIdentityLink {
458 id: row.get(0)?,
459 memory_id: row.get(1)?,
460 canonical_id: row.get(2)?,
461 mention_text: row.get(3)?,
462 mention_count: row.get(4)?,
463 created_at: DateTime::parse_from_rfc3339(&created_at)
464 .map(|dt| dt.with_timezone(&Utc))
465 .unwrap_or_else(|_| Utc::now()),
466 })
467 },
468 )
469 .map_err(EngramError::Database)
470}
471
472pub fn unlink_identity_from_memory(
474 conn: &Connection,
475 memory_id: i64,
476 canonical_id: &str,
477) -> Result<()> {
478 let affected = conn.execute(
479 "DELETE FROM memory_identity_links WHERE memory_id = ? AND canonical_id = ?",
480 params![memory_id, canonical_id],
481 )?;
482
483 if affected == 0 {
484 return Err(EngramError::NotFound(0));
485 }
486
487 Ok(())
488}
489
490pub fn get_memory_identities(conn: &Connection, memory_id: i64) -> Result<Vec<Identity>> {
492 let mut stmt = conn.prepare(
493 r#"
494 SELECT DISTINCT i.canonical_id
495 FROM identities i
496 JOIN memory_identity_links mil ON i.canonical_id = mil.canonical_id
497 WHERE mil.memory_id = ?
498 "#,
499 )?;
500
501 let canonical_ids: Vec<String> = stmt
502 .query_map(params![memory_id], |row| row.get(0))?
503 .filter_map(|r| r.ok())
504 .collect();
505
506 let mut identities = Vec::new();
507 for cid in canonical_ids {
508 if let Ok(identity) = get_identity(conn, &cid) {
509 identities.push(identity);
510 }
511 }
512
513 Ok(identities)
514}
515
516#[derive(Debug, Clone, Serialize, Deserialize)]
518pub struct IdentityWithMention {
519 #[serde(flatten)]
520 pub identity: Identity,
521 pub mention_text: Option<String>,
522 pub mention_count: i32,
523}
524
525pub fn get_memory_identities_with_mentions(
528 conn: &Connection,
529 memory_id: i64,
530) -> Result<Vec<IdentityWithMention>> {
531 let mut stmt = conn.prepare(
532 r#"
533 SELECT i.canonical_id, i.display_name, i.entity_type, i.description,
534 i.metadata, i.created_at, i.updated_at,
535 mil.mention_text, mil.mention_count
536 FROM identities i
537 JOIN memory_identity_links mil ON i.canonical_id = mil.canonical_id
538 WHERE mil.memory_id = ?
539 "#,
540 )?;
541
542 let results: Vec<IdentityWithMention> = stmt
543 .query_map(params![memory_id], |row| {
544 let canonical_id: String = row.get(0)?;
545 let display_name: String = row.get(1)?;
546 let entity_type: String = row.get(2)?;
547 let description: Option<String> = row.get(3)?;
548 let metadata_str: String = row.get(4)?;
549 let created_at: String = row.get(5)?;
550 let updated_at: String = row.get(6)?;
551 let mention_text: Option<String> = row.get(7)?;
552 let mention_count: i32 = row.get(8)?;
553
554 let metadata: std::collections::HashMap<String, serde_json::Value> =
555 serde_json::from_str(&metadata_str).unwrap_or_default();
556
557 Ok(IdentityWithMention {
558 identity: Identity {
559 id: 0, canonical_id,
561 display_name,
562 entity_type: entity_type.parse().unwrap_or(IdentityType::Other),
563 description,
564 metadata,
565 created_at: chrono::DateTime::parse_from_rfc3339(&created_at)
566 .map(|dt| dt.with_timezone(&chrono::Utc))
567 .unwrap_or_else(|_| chrono::Utc::now()),
568 updated_at: chrono::DateTime::parse_from_rfc3339(&updated_at)
569 .map(|dt| dt.with_timezone(&chrono::Utc))
570 .unwrap_or_else(|_| chrono::Utc::now()),
571 aliases: vec![], },
573 mention_text,
574 mention_count,
575 })
576 })?
577 .filter_map(|r| r.ok())
578 .collect();
579
580 Ok(results)
581}
582
583pub fn get_identity_memories(conn: &Connection, canonical_id: &str) -> Result<Vec<i64>> {
585 let mut stmt =
586 conn.prepare("SELECT memory_id FROM memory_identity_links WHERE canonical_id = ?")?;
587
588 let memory_ids = stmt
589 .query_map(params![canonical_id], |row| row.get(0))?
590 .filter_map(|r| r.ok())
591 .collect();
592
593 Ok(memory_ids)
594}
595
596pub fn list_identities(
598 conn: &Connection,
599 entity_type: Option<IdentityType>,
600 limit: i64,
601) -> Result<Vec<Identity>> {
602 let mut sql = String::from("SELECT canonical_id FROM identities");
603
604 let mut params: Vec<Box<dyn rusqlite::ToSql>> = vec![];
605
606 if let Some(et) = entity_type {
607 sql.push_str(" WHERE entity_type = ?");
608 params.push(Box::new(et.as_str().to_string()));
609 }
610
611 sql.push_str(" ORDER BY display_name LIMIT ?");
612 params.push(Box::new(limit));
613
614 let param_refs: Vec<&dyn rusqlite::ToSql> = params.iter().map(|b| b.as_ref()).collect();
615 let mut stmt = conn.prepare(&sql)?;
616
617 let canonical_ids: Vec<String> = stmt
618 .query_map(param_refs.as_slice(), |row| row.get(0))?
619 .filter_map(|r| r.ok())
620 .collect();
621
622 let mut identities = Vec::new();
623 for cid in canonical_ids {
624 if let Ok(identity) = get_identity(conn, &cid) {
625 identities.push(identity);
626 }
627 }
628
629 Ok(identities)
630}
631
632pub fn search_identities_by_alias(
634 conn: &Connection,
635 query: &str,
636 limit: i64,
637) -> Result<Vec<Identity>> {
638 let normalized = normalize_alias(query);
639 let pattern = format!("%{}%", normalized);
640
641 let mut stmt = conn.prepare(
642 r#"
643 SELECT DISTINCT i.canonical_id
644 FROM identities i
645 LEFT JOIN identity_aliases ia ON i.canonical_id = ia.canonical_id
646 WHERE ia.alias_normalized LIKE ? OR i.display_name LIKE ?
647 LIMIT ?
648 "#,
649 )?;
650
651 let canonical_ids: Vec<String> = stmt
652 .query_map(params![pattern, pattern, limit], |row| row.get(0))?
653 .filter_map(|r| r.ok())
654 .collect();
655
656 let mut identities = Vec::new();
657 for cid in canonical_ids {
658 if let Ok(identity) = get_identity(conn, &cid) {
659 identities.push(identity);
660 }
661 }
662
663 Ok(identities)
664}
665
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::Storage;

    /// Builds a `Person` input with no description and empty metadata.
    fn person(canonical_id: &str, display_name: &str, aliases: &[&str]) -> CreateIdentityInput {
        CreateIdentityInput {
            canonical_id: canonical_id.to_string(),
            display_name: display_name.to_string(),
            entity_type: IdentityType::Person,
            description: None,
            metadata: HashMap::new(),
            aliases: aliases.iter().map(|a| a.to_string()).collect(),
        }
    }

    #[test]
    fn test_normalize_alias() {
        let cases = [
            (" Ronaldo ", "ronaldo"),
            ("@ronaldo", "ronaldo"),
            ("Lima Ronaldo", "lima ronaldo"),
            ("#project-x", "project-x"),
            (" UPPER CASE ", "upper case"),
        ];

        for (raw, expected) in cases {
            assert_eq!(normalize_alias(raw), expected);
        }
    }

    #[test]
    fn test_create_identity() {
        let storage = Storage::open_in_memory().unwrap();

        storage
            .with_connection(|conn| {
                let input = CreateIdentityInput {
                    description: Some("A developer".to_string()),
                    ..person("user:ronaldo", "Ronaldo", &["@ronaldo", "limaronaldo"])
                };

                let created = create_identity(conn, &input)?;

                assert_eq!(created.canonical_id, "user:ronaldo");
                assert_eq!(created.display_name, "Ronaldo");
                assert_eq!(created.entity_type, IdentityType::Person);
                // "@ronaldo" and the display name normalize to the same
                // alias, so at least the two distinct aliases must exist.
                assert!(created.aliases.len() >= 2);

                Ok(())
            })
            .unwrap();
    }

    #[test]
    fn test_alias_conflict() {
        let storage = Storage::open_in_memory().unwrap();

        storage
            .with_connection(|conn| {
                // Alice owns the alias "ally".
                create_identity(conn, &person("user:alice", "Alice", &["ally"]))?;

                // Bob starts without any explicit aliases.
                create_identity(conn, &person("user:bob", "Bob", &[]))?;

                // Normalization is case-insensitive, so "ALLY" collides.
                assert!(add_alias(conn, "user:bob", "ALLY", None).is_err());

                Ok(())
            })
            .unwrap();
    }

    #[test]
    fn test_resolve_alias() {
        let storage = Storage::open_in_memory().unwrap();

        storage
            .with_connection(|conn| {
                create_identity(
                    conn,
                    &person("user:charlie", "Charlie", &["chuck", "@charlie"]),
                )?;

                // Lookup ignores letter case.
                let by_nickname = resolve_alias(conn, "CHUCK")?;
                assert!(by_nickname.is_some());
                assert_eq!(by_nickname.unwrap().canonical_id, "user:charlie");

                // Lookup ignores the leading '@'.
                let by_handle = resolve_alias(conn, "@Charlie")?;
                assert!(by_handle.is_some());

                let missing = resolve_alias(conn, "unknown")?;
                assert!(missing.is_none());

                Ok(())
            })
            .unwrap();
    }
}