rexis_rag/agent/memory/
semantic.rs

//! Semantic memory - facts and knowledge storage
//!
//! Semantic memory stores facts, preferences, and learned information about users,
//! entities, and concepts. It's agent-scoped and persists across sessions.
//!
//! Supports optional vector embeddings for semantic similarity search.
7
8use crate::error::RragResult;
9use crate::storage::{Memory, MemoryValue};
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12
13#[cfg(feature = "vector-search")]
14use super::vector::{Embedding, EmbeddingProvider, SearchResult};
15
16/// A semantic fact stored in memory
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct Fact {
19    /// Unique identifier for the fact
20    pub id: String,
21
22    /// The subject of the fact (e.g., "user:123", "product:456")
23    pub subject: String,
24
25    /// The predicate/relation (e.g., "prefers", "is_located_in", "purchased")
26    pub predicate: String,
27
28    /// The object/value (can be any type)
29    pub object: MemoryValue,
30
31    /// Confidence score (0.0 to 1.0)
32    pub confidence: f64,
33
34    /// When the fact was created
35    pub created_at: chrono::DateTime<chrono::Utc>,
36
37    /// When the fact was last updated
38    pub updated_at: chrono::DateTime<chrono::Utc>,
39
40    /// Optional metadata
41    pub metadata: std::collections::HashMap<String, String>,
42
43    /// Optional vector embedding for similarity search
44    #[cfg(feature = "vector-search")]
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub embedding: Option<Embedding>,
47}
48
49impl Fact {
50    /// Create a new fact
51    pub fn new(
52        subject: impl Into<String>,
53        predicate: impl Into<String>,
54        object: impl Into<MemoryValue>,
55    ) -> Self {
56        let now = chrono::Utc::now();
57        Self {
58            id: uuid::Uuid::new_v4().to_string(),
59            subject: subject.into(),
60            predicate: predicate.into(),
61            object: object.into(),
62            confidence: 1.0,
63            created_at: now,
64            updated_at: now,
65            metadata: std::collections::HashMap::new(),
66            #[cfg(feature = "vector-search")]
67            embedding: None,
68        }
69    }
70
71    /// Set the embedding for this fact
72    #[cfg(feature = "vector-search")]
73    pub fn with_embedding(mut self, embedding: Embedding) -> Self {
74        self.embedding = Some(embedding);
75        self
76    }
77
78    /// Set confidence score
79    pub fn with_confidence(mut self, confidence: f64) -> Self {
80        self.confidence = confidence.clamp(0.0, 1.0);
81        self
82    }
83
84    /// Add metadata
85    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
86        self.metadata.insert(key.into(), value.into());
87        self
88    }
89}
90
91/// Semantic memory for agent knowledge
92pub struct SemanticMemory {
93    /// Storage backend
94    storage: Arc<dyn Memory>,
95
96    /// Namespace for this semantic memory (agent::{agent_id}::semantic)
97    namespace: String,
98}
99
100impl SemanticMemory {
101    /// Create a new semantic memory
102    pub fn new(storage: Arc<dyn Memory>, agent_id: String) -> Self {
103        let namespace = format!("agent::{}::semantic", agent_id);
104
105        Self { storage, namespace }
106    }
107
108    /// Store a fact
109    pub async fn store_fact(&self, fact: Fact) -> RragResult<()> {
110        let key = self.fact_key(&fact.id);
111        let value = serde_json::to_value(&fact).map_err(|e| {
112            crate::error::RragError::storage(
113                "serialize_fact",
114                std::io::Error::new(std::io::ErrorKind::Other, e),
115            )
116        })?;
117
118        self.storage.set(&key, MemoryValue::Json(value)).await
119    }
120
121    /// Retrieve a fact by ID
122    pub async fn get_fact(&self, fact_id: &str) -> RragResult<Option<Fact>> {
123        let key = self.fact_key(fact_id);
124        if let Some(value) = self.storage.get(&key).await? {
125            if let Some(json) = value.as_json() {
126                let fact = serde_json::from_value(json.clone()).map_err(|e| {
127                    crate::error::RragError::storage(
128                        "deserialize_fact",
129                        std::io::Error::new(std::io::ErrorKind::Other, e),
130                    )
131                })?;
132                return Ok(Some(fact));
133            }
134        }
135        Ok(None)
136    }
137
138    /// Delete a fact
139    pub async fn delete_fact(&self, fact_id: &str) -> RragResult<bool> {
140        let key = self.fact_key(fact_id);
141        self.storage.delete(&key).await
142    }
143
144    /// Find facts by subject
145    pub async fn find_by_subject(&self, subject: &str) -> RragResult<Vec<Fact>> {
146        // This is a simplified implementation
147        // In a production system, you'd want indexing or vector search
148        let all_keys = self.list_fact_keys().await?;
149        let mut matching_facts = Vec::new();
150
151        for key in all_keys {
152            if let Some(fact) = self.get_fact(&key).await? {
153                if fact.subject == subject {
154                    matching_facts.push(fact);
155                }
156            }
157        }
158
159        Ok(matching_facts)
160    }
161
162    /// Find facts by predicate
163    pub async fn find_by_predicate(&self, predicate: &str) -> RragResult<Vec<Fact>> {
164        let all_keys = self.list_fact_keys().await?;
165        let mut matching_facts = Vec::new();
166
167        for key in all_keys {
168            if let Some(fact) = self.get_fact(&key).await? {
169                if fact.predicate == predicate {
170                    matching_facts.push(fact);
171                }
172            }
173        }
174
175        Ok(matching_facts)
176    }
177
178    /// Find facts by subject and predicate
179    pub async fn find_by_subject_and_predicate(
180        &self,
181        subject: &str,
182        predicate: &str,
183    ) -> RragResult<Vec<Fact>> {
184        let all_keys = self.list_fact_keys().await?;
185        let mut matching_facts = Vec::new();
186
187        for key in all_keys {
188            if let Some(fact) = self.get_fact(&key).await? {
189                if fact.subject == subject && fact.predicate == predicate {
190                    matching_facts.push(fact);
191                }
192            }
193        }
194
195        Ok(matching_facts)
196    }
197
198    /// Get all facts
199    pub async fn get_all_facts(&self) -> RragResult<Vec<Fact>> {
200        let all_keys = self.list_fact_keys().await?;
201        let mut facts = Vec::new();
202
203        for key in all_keys {
204            if let Some(fact) = self.get_fact(&key).await? {
205                facts.push(fact);
206            }
207        }
208
209        Ok(facts)
210    }
211
212    /// Count facts
213    pub async fn count(&self) -> RragResult<usize> {
214        self.storage.count(Some(&self.namespace)).await
215    }
216
217    /// Clear all facts
218    pub async fn clear(&self) -> RragResult<()> {
219        self.storage.clear(Some(&self.namespace)).await
220    }
221
222    /// Generate fact key
223    fn fact_key(&self, fact_id: &str) -> String {
224        format!("{}::fact::{}", self.namespace, fact_id)
225    }
226
227    /// List all fact keys (IDs)
228    async fn list_fact_keys(&self) -> RragResult<Vec<String>> {
229        use crate::storage::MemoryQuery;
230
231        let query = MemoryQuery::new().with_namespace(self.namespace.clone());
232        let all_keys = self.storage.keys(&query).await?;
233
234        // Extract fact IDs from keys
235        let prefix = format!("{}::fact::", self.namespace);
236        let ids = all_keys
237            .into_iter()
238            .filter_map(|k| k.strip_prefix(&prefix).map(String::from))
239            .collect();
240
241        Ok(ids)
242    }
243
244    /// Search for facts using vector similarity (requires 'vector-search' feature)
245    #[cfg(feature = "vector-search")]
246    pub async fn vector_search(
247        &self,
248        query_embedding: &Embedding,
249        limit: usize,
250        min_similarity: f32,
251    ) -> RragResult<Vec<SearchResult<Fact>>> {
252        let all_facts = self.get_all_facts().await?;
253        let mut results = Vec::new();
254
255        for fact in all_facts {
256            if let Some(fact_embedding) = &fact.embedding {
257                match query_embedding.cosine_similarity(fact_embedding) {
258                    Ok(similarity) => {
259                        if similarity >= min_similarity {
260                            results.push(SearchResult::new(fact, similarity));
261                        }
262                    }
263                    Err(_) => continue, // Skip facts with incompatible embeddings
264                }
265            }
266        }
267
268        // Sort by similarity (highest first)
269        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
270
271        // Take top N results
272        results.truncate(limit);
273
274        Ok(results)
275    }
276
277    /// Store a fact with automatic embedding generation (requires 'vector-search' feature)
278    #[cfg(feature = "vector-search")]
279    pub async fn store_fact_with_embedding<P>(&self, mut fact: Fact, provider: &P) -> RragResult<()>
280    where
281        P: EmbeddingProvider,
282    {
283        // Generate text representation for embedding
284        let text = format!(
285            "{} {} {}",
286            fact.subject,
287            fact.predicate,
288            fact.object.as_string().unwrap_or_default()
289        );
290
291        // Generate embedding
292        let embedding = provider.embed(&text).await?;
293        fact.embedding = Some(embedding);
294
295        // Store the fact
296        self.store_fact(fact).await
297    }
298
299    /// Find similar facts to a query text (requires 'vector-search' feature)
300    #[cfg(feature = "vector-search")]
301    pub async fn find_similar<P>(
302        &self,
303        query: &str,
304        provider: &P,
305        limit: usize,
306        min_similarity: f32,
307    ) -> RragResult<Vec<SearchResult<Fact>>>
308    where
309        P: EmbeddingProvider,
310    {
311        // Generate embedding for query
312        let query_embedding = provider.embed(query).await?;
313
314        // Search using embedding
315        self.vector_search(&query_embedding, limit, min_similarity)
316            .await
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::InMemoryStorage;

    /// Build a semantic memory for "test-agent" backed by fresh in-memory storage.
    fn new_memory() -> SemanticMemory {
        let backend = Arc::new(InMemoryStorage::new());
        SemanticMemory::new(backend, "test-agent".to_string())
    }

    /// Store one fact per `(subject, predicate, object)` triple, in order.
    async fn store_all(
        memory: &SemanticMemory,
        triples: &[(&'static str, &'static str, &'static str)],
    ) {
        for &(subject, predicate, object) in triples {
            let fact = Fact::new(subject, predicate, MemoryValue::from(object));
            memory.store_fact(fact).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_semantic_memory_store_and_retrieve() {
        let memory = new_memory();

        // Persist a single fact with a non-default confidence.
        let original =
            Fact::new("user:alice", "prefers", MemoryValue::from("dark_mode")).with_confidence(0.9);
        let id = original.id.clone();
        memory.store_fact(original).await.unwrap();

        // Read it back and check that every field survived.
        let loaded = memory.get_fact(&id).await.unwrap().unwrap();
        assert_eq!(loaded.subject, "user:alice");
        assert_eq!(loaded.predicate, "prefers");
        assert_eq!(loaded.object.as_string(), Some("dark_mode"));
        assert_eq!(loaded.confidence, 0.9);
    }

    #[tokio::test]
    async fn test_semantic_memory_find_by_subject() {
        let memory = new_memory();

        store_all(
            &memory,
            &[
                ("user:alice", "prefers", "dark_mode"),
                ("user:alice", "likes", "coffee"),
                ("user:bob", "prefers", "light_mode"),
            ],
        )
        .await;

        // Two facts are about alice, one about bob.
        let about_alice = memory.find_by_subject("user:alice").await.unwrap();
        assert_eq!(about_alice.len(), 2);

        let about_bob = memory.find_by_subject("user:bob").await.unwrap();
        assert_eq!(about_bob.len(), 1);
    }

    #[tokio::test]
    async fn test_semantic_memory_find_by_predicate() {
        let memory = new_memory();

        store_all(
            &memory,
            &[
                ("user:alice", "prefers", "dark_mode"),
                ("user:bob", "prefers", "light_mode"),
                ("user:alice", "likes", "coffee"),
            ],
        )
        .await;

        // Two facts use "prefers", one uses "likes".
        let preferences = memory.find_by_predicate("prefers").await.unwrap();
        assert_eq!(preferences.len(), 2);

        let likes = memory.find_by_predicate("likes").await.unwrap();
        assert_eq!(likes.len(), 1);
    }

    #[tokio::test]
    async fn test_semantic_memory_delete() {
        let memory = new_memory();

        // A stored fact is counted...
        let fact = Fact::new("user:alice", "prefers", MemoryValue::from("dark_mode"));
        let id = fact.id.clone();
        memory.store_fact(fact).await.unwrap();

        assert_eq!(memory.count().await.unwrap(), 1);

        // ...and no longer counted once deleted.
        memory.delete_fact(&id).await.unwrap();
        assert_eq!(memory.count().await.unwrap(), 0);
    }
}