1use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17
18#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
20pub enum EntityType {
21 Person,
22 Place,
23 Organization,
24 Concept,
25 Product,
26 Event,
27 Other(String),
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct Entity {
33 pub id: String,
35 pub name: String,
37 pub entity_type: EntityType,
39 pub attributes: HashMap<String, String>,
41 pub first_mentioned: u64,
43 pub last_mentioned: u64,
45 pub mention_count: u32,
47 pub source_memories: Vec<String>,
49}
50
51impl Entity {
52 pub fn new(name: String, entity_type: EntityType, memory_id: String) -> Self {
53 let timestamp = Self::current_timestamp();
54 Self {
55 id: uuid::Uuid::new_v4().to_string(),
56 name,
57 entity_type,
58 attributes: HashMap::new(),
59 first_mentioned: timestamp,
60 last_mentioned: timestamp,
61 mention_count: 1,
62 source_memories: vec![memory_id],
63 }
64 }
65
66 pub fn update_mention(&mut self, memory_id: String) {
67 self.last_mentioned = Self::current_timestamp();
68 self.mention_count += 1;
69 if !self.source_memories.contains(&memory_id) {
70 self.source_memories.push(memory_id);
71 }
72 }
73
74 fn current_timestamp() -> u64 {
75 std::time::SystemTime::now()
76 .duration_since(std::time::UNIX_EPOCH)
77 .unwrap()
78 .as_secs()
79 }
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct Fact {
85 pub id: String,
87 pub subject: String,
89 pub predicate: String,
91 pub object: String,
93 pub confidence: f32,
95 pub learned_at: u64,
97 pub last_confirmed: u64,
99 pub confirmation_count: u32,
101 pub source_memories: Vec<String>,
103}
104
105impl Fact {
106 pub fn new(subject: String, predicate: String, object: String, memory_id: String) -> Self {
107 let timestamp = Self::current_timestamp();
108 Self {
109 id: uuid::Uuid::new_v4().to_string(),
110 subject,
111 predicate,
112 object,
113 confidence: 0.7, learned_at: timestamp,
115 last_confirmed: timestamp,
116 confirmation_count: 1,
117 source_memories: vec![memory_id],
118 }
119 }
120
121 pub fn confirm(&mut self, memory_id: String) {
122 self.last_confirmed = Self::current_timestamp();
123 self.confirmation_count += 1;
124 self.confidence = (self.confidence + 0.1).min(1.0);
126 if !self.source_memories.contains(&memory_id) {
127 self.source_memories.push(memory_id);
128 }
129 }
130
131 fn current_timestamp() -> u64 {
132 std::time::SystemTime::now()
133 .duration_since(std::time::UNIX_EPOCH)
134 .unwrap()
135 .as_secs()
136 }
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct Relationship {
142 pub id: String,
143 pub from_entity: String,
144 pub to_entity: String,
145 pub relationship_type: String,
146 pub strength: f32,
147 pub created_at: u64,
148}
149
150pub struct SemanticMemory {
152 entities: Arc<RwLock<HashMap<String, Entity>>>,
154 entity_names: Arc<RwLock<HashMap<String, String>>>,
156 facts: Arc<RwLock<HashMap<String, Fact>>>,
158 fact_index: Arc<RwLock<HashMap<String, Vec<String>>>>,
160 relationships: Arc<RwLock<HashMap<String, Relationship>>>,
162}
163
164impl SemanticMemory {
165 pub fn new() -> Self {
166 Self {
167 entities: Arc::new(RwLock::new(HashMap::new())),
168 entity_names: Arc::new(RwLock::new(HashMap::new())),
169 facts: Arc::new(RwLock::new(HashMap::new())),
170 fact_index: Arc::new(RwLock::new(HashMap::new())),
171 relationships: Arc::new(RwLock::new(HashMap::new())),
172 }
173 }
174
175 pub async fn add_entity(&self, name: String, entity_type: EntityType, memory_id: String) -> String {
181 let mut entities = self.entities.write().await;
182 let mut entity_names = self.entity_names.write().await;
183
184 if let Some(entity_id) = entity_names.get(&name) {
186 if let Some(entity) = entities.get_mut(entity_id) {
187 entity.update_mention(memory_id);
188 return entity_id.clone();
189 }
190 }
191
192 let entity = Entity::new(name.clone(), entity_type, memory_id);
194 let id = entity.id.clone();
195 entity_names.insert(name, id.clone());
196 entities.insert(id.clone(), entity);
197
198 id
199 }
200
201 pub async fn get_entity(&self, id: &str) -> Option<Entity> {
203 let entities = self.entities.read().await;
204 entities.get(id).cloned()
205 }
206
207 pub async fn get_entity_by_name(&self, name: &str) -> Option<Entity> {
209 let entity_names = self.entity_names.read().await;
210 if let Some(id) = entity_names.get(name) {
211 let entities = self.entities.read().await;
212 entities.get(id).cloned()
213 } else {
214 None
215 }
216 }
217
218 pub async fn get_entities_by_type(&self, entity_type: EntityType) -> Vec<Entity> {
220 let entities = self.entities.read().await;
221 entities
222 .values()
223 .filter(|e| e.entity_type == entity_type)
224 .cloned()
225 .collect()
226 }
227
228 pub async fn search_entities(&self, query: &str) -> Vec<Entity> {
230 let entities = self.entities.read().await;
231 let query_lower = query.to_lowercase();
232 entities
233 .values()
234 .filter(|e| e.name.to_lowercase().contains(&query_lower))
235 .cloned()
236 .collect()
237 }
238
239 pub async fn add_fact(
245 &self,
246 subject: String,
247 predicate: String,
248 object: String,
249 memory_id: String,
250 ) -> String {
251 let mut facts = self.facts.write().await;
252 let mut fact_index = self.fact_index.write().await;
253
254 if let Some(fact_ids) = fact_index.get(&subject) {
256 for fact_id in fact_ids {
257 if let Some(fact) = facts.get_mut(fact_id) {
258 if fact.predicate == predicate && fact.object == object {
259 fact.confirm(memory_id);
260 return fact_id.clone();
261 }
262 }
263 }
264 }
265
266 let fact = Fact::new(subject.clone(), predicate, object, memory_id);
268 let id = fact.id.clone();
269
270 fact_index
272 .entry(subject)
273 .or_insert_with(Vec::new)
274 .push(id.clone());
275
276 facts.insert(id.clone(), fact);
277
278 id
279 }
280
281 pub async fn get_fact(&self, id: &str) -> Option<Fact> {
283 let facts = self.facts.read().await;
284 facts.get(id).cloned()
285 }
286
287 pub async fn get_facts_about(&self, subject: &str) -> Vec<Fact> {
289 let facts = self.facts.read().await;
290 let fact_index = self.fact_index.read().await;
291
292 if let Some(fact_ids) = fact_index.get(subject) {
293 fact_ids
294 .iter()
295 .filter_map(|id| facts.get(id).cloned())
296 .collect()
297 } else {
298 Vec::new()
299 }
300 }
301
302 pub async fn get_facts_by_predicate(&self, predicate: &str) -> Vec<Fact> {
304 let facts = self.facts.read().await;
305 facts
306 .values()
307 .filter(|f| f.predicate == predicate)
308 .cloned()
309 .collect()
310 }
311
312 pub async fn get_high_confidence_facts(&self, threshold: f32) -> Vec<Fact> {
314 let facts = self.facts.read().await;
315 facts
316 .values()
317 .filter(|f| f.confidence >= threshold)
318 .cloned()
319 .collect()
320 }
321
322 pub async fn add_relationship(
328 &self,
329 from_entity: String,
330 to_entity: String,
331 relationship_type: String,
332 ) -> String {
333 let mut relationships = self.relationships.write().await;
334
335 let relationship = Relationship {
336 id: uuid::Uuid::new_v4().to_string(),
337 from_entity,
338 to_entity,
339 relationship_type,
340 strength: 1.0,
341 created_at: Self::current_timestamp(),
342 };
343
344 let id = relationship.id.clone();
345 relationships.insert(id.clone(), relationship);
346 id
347 }
348
349 pub async fn get_relationships(&self, entity_id: &str) -> Vec<Relationship> {
351 let relationships = self.relationships.read().await;
352 relationships
353 .values()
354 .filter(|r| r.from_entity == entity_id || r.to_entity == entity_id)
355 .cloned()
356 .collect()
357 }
358
359 pub async fn get_knowledge_summary(&self, entity_name: &str) -> Option<KnowledgeSummary> {
365 let entity = self.get_entity_by_name(entity_name).await?;
366 let facts = self.get_facts_about(entity_name).await;
367 let relationships = self.get_relationships(&entity.id).await;
368
369 Some(KnowledgeSummary {
370 entity: entity.clone(),
371 fact_count: facts.len(),
372 relationship_count: relationships.len(),
373 high_confidence_facts: facts
374 .iter()
375 .filter(|f| f.confidence >= 0.8)
376 .cloned()
377 .collect(),
378 })
379 }
380
381 pub async fn get_all_entity_names(&self) -> Vec<String> {
383 let entity_names = self.entity_names.read().await;
384 entity_names.keys().cloned().collect()
385 }
386
387 pub async fn get_statistics(&self) -> SemanticMemoryStats {
389 let entities = self.entities.read().await;
390 let facts = self.facts.read().await;
391 let relationships = self.relationships.read().await;
392
393 SemanticMemoryStats {
394 total_entities: entities.len(),
395 total_facts: facts.len(),
396 total_relationships: relationships.len(),
397 high_confidence_facts: facts.values().filter(|f| f.confidence >= 0.8).count(),
398 }
399 }
400
401 fn current_timestamp() -> u64 {
402 std::time::SystemTime::now()
403 .duration_since(std::time::UNIX_EPOCH)
404 .unwrap()
405 .as_secs()
406 }
407}
408
409impl Default for SemanticMemory {
410 fn default() -> Self {
411 Self::new()
412 }
413}
414
415#[derive(Debug, Clone)]
417pub struct KnowledgeSummary {
418 pub entity: Entity,
419 pub fact_count: usize,
420 pub relationship_count: usize,
421 pub high_confidence_facts: Vec<Fact>,
422}
423
/// Aggregate counts describing the contents of the store.
#[derive(Debug, Clone)]
pub struct SemanticMemoryStats {
    /// Number of stored entities.
    pub total_entities: usize,
    /// Number of stored facts.
    pub total_facts: usize,
    /// Number of stored relationships.
    pub total_relationships: usize,
    /// Number of facts whose confidence is at least 0.8.
    pub high_confidence_facts: usize,
}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an owned `String` from a literal.
    fn owned(text: &str) -> String {
        text.to_owned()
    }

    #[tokio::test]
    async fn test_semantic_memory_entities() {
        let memory = SemanticMemory::new();

        let first = memory
            .add_entity(owned("Alice"), EntityType::Person, owned("mem-1"))
            .await;
        let second = memory
            .add_entity(owned("Alice"), EntityType::Person, owned("mem-2"))
            .await;

        // A second mention of the same name must reuse the existing entity.
        assert_eq!(first, second);

        let alice = memory.get_entity_by_name("Alice").await.unwrap();
        assert_eq!(alice.mention_count, 2);
    }

    #[tokio::test]
    async fn test_semantic_memory_facts() {
        let memory = SemanticMemory::new();

        let mut ids = Vec::new();
        for source in ["mem-1", "mem-2"] {
            let id = memory
                .add_fact(
                    owned("Alice"),
                    owned("likes"),
                    owned("programming"),
                    owned(source),
                )
                .await;
            ids.push(id);
        }

        // The identical triple from a second memory confirms the original fact.
        assert_eq!(ids[0], ids[1]);

        let fact = memory.get_fact(&ids[0]).await.unwrap();
        assert_eq!(fact.confirmation_count, 2);
        // Confirmation raises confidence above its initial value.
        assert!(fact.confidence > 0.7);
    }

    #[tokio::test]
    async fn test_semantic_memory_relationships() {
        let memory = SemanticMemory::new();

        let alice = memory
            .add_entity(owned("Alice"), EntityType::Person, owned("mem-1"))
            .await;
        let bob = memory
            .add_entity(owned("Bob"), EntityType::Person, owned("mem-1"))
            .await;

        memory
            .add_relationship(alice.clone(), bob.clone(), owned("friend"))
            .await;

        let found = memory.get_relationships(&alice).await;
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].relationship_type, "friend");
    }
}
503}