1use crate::error::RragResult;
9use crate::storage::{Memory, MemoryValue};
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12
13#[cfg(feature = "vector-search")]
14use super::vector::{Embedding, EmbeddingProvider, SearchResult};
15
/// A single piece of knowledge stored as a subject–predicate–object
/// triple (e.g. "user:alice" / "prefers" / "dark_mode"), with a
/// confidence score, timestamps, and free-form metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Fact {
    /// Unique identifier; a UUID v4 string generated by [`Fact::new`].
    pub id: String,

    /// Entity the fact is about, e.g. "user:alice".
    pub subject: String,

    /// Relationship linking the subject to the object, e.g. "prefers".
    pub predicate: String,

    /// Value of the fact.
    pub object: MemoryValue,

    /// Confidence in the fact, clamped to [0.0, 1.0] by
    /// `with_confidence`; defaults to 1.0.
    pub confidence: f64,

    /// UTC timestamp at creation.
    pub created_at: chrono::DateTime<chrono::Utc>,

    /// UTC timestamp of the last update (equal to `created_at` initially).
    pub updated_at: chrono::DateTime<chrono::Utc>,

    /// Arbitrary string key/value metadata attached to the fact.
    pub metadata: std::collections::HashMap<String, String>,

    /// Optional embedding used for vector similarity search; omitted
    /// from serialized output when absent.
    #[cfg(feature = "vector-search")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embedding: Option<Embedding>,
}
48
49impl Fact {
50 pub fn new(
52 subject: impl Into<String>,
53 predicate: impl Into<String>,
54 object: impl Into<MemoryValue>,
55 ) -> Self {
56 let now = chrono::Utc::now();
57 Self {
58 id: uuid::Uuid::new_v4().to_string(),
59 subject: subject.into(),
60 predicate: predicate.into(),
61 object: object.into(),
62 confidence: 1.0,
63 created_at: now,
64 updated_at: now,
65 metadata: std::collections::HashMap::new(),
66 #[cfg(feature = "vector-search")]
67 embedding: None,
68 }
69 }
70
71 #[cfg(feature = "vector-search")]
73 pub fn with_embedding(mut self, embedding: Embedding) -> Self {
74 self.embedding = Some(embedding);
75 self
76 }
77
78 pub fn with_confidence(mut self, confidence: f64) -> Self {
80 self.confidence = confidence.clamp(0.0, 1.0);
81 self
82 }
83
84 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
86 self.metadata.insert(key.into(), value.into());
87 self
88 }
89}
90
/// Long-term fact store for a single agent, persisting [`Fact`]s as JSON
/// in a namespaced key-value backend.
pub struct SemanticMemory {
    /// Shared key-value backend all facts are read from and written to.
    storage: Arc<dyn Memory>,

    /// Key prefix `agent::{agent_id}::semantic` isolating this agent's facts.
    namespace: String,
}
99
100impl SemanticMemory {
101 pub fn new(storage: Arc<dyn Memory>, agent_id: String) -> Self {
103 let namespace = format!("agent::{}::semantic", agent_id);
104
105 Self { storage, namespace }
106 }
107
108 pub async fn store_fact(&self, fact: Fact) -> RragResult<()> {
110 let key = self.fact_key(&fact.id);
111 let value = serde_json::to_value(&fact).map_err(|e| {
112 crate::error::RragError::storage(
113 "serialize_fact",
114 std::io::Error::new(std::io::ErrorKind::Other, e),
115 )
116 })?;
117
118 self.storage.set(&key, MemoryValue::Json(value)).await
119 }
120
121 pub async fn get_fact(&self, fact_id: &str) -> RragResult<Option<Fact>> {
123 let key = self.fact_key(fact_id);
124 if let Some(value) = self.storage.get(&key).await? {
125 if let Some(json) = value.as_json() {
126 let fact = serde_json::from_value(json.clone()).map_err(|e| {
127 crate::error::RragError::storage(
128 "deserialize_fact",
129 std::io::Error::new(std::io::ErrorKind::Other, e),
130 )
131 })?;
132 return Ok(Some(fact));
133 }
134 }
135 Ok(None)
136 }
137
138 pub async fn delete_fact(&self, fact_id: &str) -> RragResult<bool> {
140 let key = self.fact_key(fact_id);
141 self.storage.delete(&key).await
142 }
143
144 pub async fn find_by_subject(&self, subject: &str) -> RragResult<Vec<Fact>> {
146 let all_keys = self.list_fact_keys().await?;
149 let mut matching_facts = Vec::new();
150
151 for key in all_keys {
152 if let Some(fact) = self.get_fact(&key).await? {
153 if fact.subject == subject {
154 matching_facts.push(fact);
155 }
156 }
157 }
158
159 Ok(matching_facts)
160 }
161
162 pub async fn find_by_predicate(&self, predicate: &str) -> RragResult<Vec<Fact>> {
164 let all_keys = self.list_fact_keys().await?;
165 let mut matching_facts = Vec::new();
166
167 for key in all_keys {
168 if let Some(fact) = self.get_fact(&key).await? {
169 if fact.predicate == predicate {
170 matching_facts.push(fact);
171 }
172 }
173 }
174
175 Ok(matching_facts)
176 }
177
178 pub async fn find_by_subject_and_predicate(
180 &self,
181 subject: &str,
182 predicate: &str,
183 ) -> RragResult<Vec<Fact>> {
184 let all_keys = self.list_fact_keys().await?;
185 let mut matching_facts = Vec::new();
186
187 for key in all_keys {
188 if let Some(fact) = self.get_fact(&key).await? {
189 if fact.subject == subject && fact.predicate == predicate {
190 matching_facts.push(fact);
191 }
192 }
193 }
194
195 Ok(matching_facts)
196 }
197
198 pub async fn get_all_facts(&self) -> RragResult<Vec<Fact>> {
200 let all_keys = self.list_fact_keys().await?;
201 let mut facts = Vec::new();
202
203 for key in all_keys {
204 if let Some(fact) = self.get_fact(&key).await? {
205 facts.push(fact);
206 }
207 }
208
209 Ok(facts)
210 }
211
212 pub async fn count(&self) -> RragResult<usize> {
214 self.storage.count(Some(&self.namespace)).await
215 }
216
217 pub async fn clear(&self) -> RragResult<()> {
219 self.storage.clear(Some(&self.namespace)).await
220 }
221
222 fn fact_key(&self, fact_id: &str) -> String {
224 format!("{}::fact::{}", self.namespace, fact_id)
225 }
226
227 async fn list_fact_keys(&self) -> RragResult<Vec<String>> {
229 use crate::storage::MemoryQuery;
230
231 let query = MemoryQuery::new().with_namespace(self.namespace.clone());
232 let all_keys = self.storage.keys(&query).await?;
233
234 let prefix = format!("{}::fact::", self.namespace);
236 let ids = all_keys
237 .into_iter()
238 .filter_map(|k| k.strip_prefix(&prefix).map(String::from))
239 .collect();
240
241 Ok(ids)
242 }
243
244 #[cfg(feature = "vector-search")]
246 pub async fn vector_search(
247 &self,
248 query_embedding: &Embedding,
249 limit: usize,
250 min_similarity: f32,
251 ) -> RragResult<Vec<SearchResult<Fact>>> {
252 let all_facts = self.get_all_facts().await?;
253 let mut results = Vec::new();
254
255 for fact in all_facts {
256 if let Some(fact_embedding) = &fact.embedding {
257 match query_embedding.cosine_similarity(fact_embedding) {
258 Ok(similarity) => {
259 if similarity >= min_similarity {
260 results.push(SearchResult::new(fact, similarity));
261 }
262 }
263 Err(_) => continue, }
265 }
266 }
267
268 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
270
271 results.truncate(limit);
273
274 Ok(results)
275 }
276
277 #[cfg(feature = "vector-search")]
279 pub async fn store_fact_with_embedding<P>(&self, mut fact: Fact, provider: &P) -> RragResult<()>
280 where
281 P: EmbeddingProvider,
282 {
283 let text = format!(
285 "{} {} {}",
286 fact.subject,
287 fact.predicate,
288 fact.object.as_string().unwrap_or_default()
289 );
290
291 let embedding = provider.embed(&text).await?;
293 fact.embedding = Some(embedding);
294
295 self.store_fact(fact).await
297 }
298
299 #[cfg(feature = "vector-search")]
301 pub async fn find_similar<P>(
302 &self,
303 query: &str,
304 provider: &P,
305 limit: usize,
306 min_similarity: f32,
307 ) -> RragResult<Vec<SearchResult<Fact>>>
308 where
309 P: EmbeddingProvider,
310 {
311 let query_embedding = provider.embed(query).await?;
313
314 self.vector_search(&query_embedding, limit, min_similarity)
316 .await
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::InMemoryStorage;

    /// Builds a fresh semantic memory backed by in-memory storage.
    fn test_memory() -> SemanticMemory {
        SemanticMemory::new(Arc::new(InMemoryStorage::new()), "test-agent".to_string())
    }

    #[tokio::test]
    async fn test_semantic_memory_store_and_retrieve() {
        let memory = test_memory();

        let fact = Fact::new("user:alice", "prefers", MemoryValue::from("dark_mode"))
            .with_confidence(0.9);
        let id = fact.id.clone();
        memory.store_fact(fact).await.unwrap();

        let stored = memory.get_fact(&id).await.unwrap().unwrap();
        assert_eq!(stored.subject, "user:alice");
        assert_eq!(stored.predicate, "prefers");
        assert_eq!(stored.object.as_string(), Some("dark_mode"));
        assert_eq!(stored.confidence, 0.9);
    }

    #[tokio::test]
    async fn test_semantic_memory_find_by_subject() {
        let memory = test_memory();

        for (subject, predicate, object) in [
            ("user:alice", "prefers", "dark_mode"),
            ("user:alice", "likes", "coffee"),
            ("user:bob", "prefers", "light_mode"),
        ] {
            memory
                .store_fact(Fact::new(subject, predicate, MemoryValue::from(object)))
                .await
                .unwrap();
        }

        let alice = memory.find_by_subject("user:alice").await.unwrap();
        assert_eq!(alice.len(), 2);

        let bob = memory.find_by_subject("user:bob").await.unwrap();
        assert_eq!(bob.len(), 1);
    }

    #[tokio::test]
    async fn test_semantic_memory_find_by_predicate() {
        let memory = test_memory();

        for (subject, predicate, object) in [
            ("user:alice", "prefers", "dark_mode"),
            ("user:bob", "prefers", "light_mode"),
            ("user:alice", "likes", "coffee"),
        ] {
            memory
                .store_fact(Fact::new(subject, predicate, MemoryValue::from(object)))
                .await
                .unwrap();
        }

        let prefers = memory.find_by_predicate("prefers").await.unwrap();
        assert_eq!(prefers.len(), 2);

        let likes = memory.find_by_predicate("likes").await.unwrap();
        assert_eq!(likes.len(), 1);
    }

    #[tokio::test]
    async fn test_semantic_memory_delete() {
        let memory = test_memory();

        let fact = Fact::new("user:alice", "prefers", MemoryValue::from("dark_mode"));
        let id = fact.id.clone();
        memory.store_fact(fact).await.unwrap();

        assert_eq!(memory.count().await.unwrap(), 1);

        memory.delete_fact(&id).await.unwrap();
        assert_eq!(memory.count().await.unwrap(), 0);
    }
}