1use std::collections::HashMap;
4use std::sync::Arc;
5use tracing::{debug, info, warn};
6use uuid::Uuid;
7use chrono::Utc;
8
9use crate::config::MemoryConfig;
10use crate::embeddings::{create_embedder, Embedder};
11use crate::errors::{LLMError, MemoryError};
12use crate::history::HistoryManager;
13use crate::llms::{create_llm, generate_json, GenerateOptions, LLM};
14use crate::models::{
15 AddOptions, AddResult, EventType, Filters, GetAllOptions, HistoryEntry, MemoryEvent,
16 MemoryRecord, Message, Messages, Payload, ResetOptions, Role, ScoredMemory, SearchOptions,
17 SearchResult,
18};
19use crate::vector_stores::{create_vector_store, VectorStore};
20use crate::rerankers::{create_reranker, Reranker};
21
22use super::prompts::{
23 format_fact_extraction_input, format_memory_update_input, FACT_EXTRACTION_PROMPT,
24 MEMORY_UPDATE_PROMPT,
25};
26
/// Long-term memory store.
///
/// Content is embedded with `embedder` and persisted in `vector_store`. When configured, an
/// LLM extracts and merges facts on `add`, a history manager records add/update/delete
/// events, and a reranker reorders search results.
pub struct Memory {
    /// Produces the embeddings stored in, and queried against, `vector_store`.
    embedder: Arc<dyn Embedder>,
    /// Backing store for memory embeddings and their payloads.
    vector_store: Arc<dyn VectorStore>,
    /// LLM used for fact extraction and update decisions; `None` means `add` stores
    /// messages verbatim even when inference is requested.
    llm: Option<Arc<dyn LLM>>,
    /// When present, add/update/delete events are recorded here.
    history: Option<Arc<HistoryManager>>,
    /// Applied to search candidates when a search requests reranking.
    reranker: Option<Arc<dyn Reranker>>,
    /// Configuration this instance was built from.
    // Not read after construction in this file, hence the allow; presumably retained for
    // future use or introspection — confirm before removing.
    #[allow(dead_code)]
    config: MemoryConfig,
}
37
38impl Memory {
39 pub async fn new(config: MemoryConfig) -> Result<Self, MemoryError> {
41 let embedder = create_embedder(&config.embedder)?;
42 let dimensions = embedder.dimensions();
43
44 let vector_store =
45 create_vector_store(&config.vector_store, &config.collection_name, dimensions).await?;
46
47 let llm = if let Some(llm_config) = &config.llm {
48 Some(create_llm(llm_config)?)
49 } else {
50 None
51 };
52
53 let history = if let Some(path) = &config.history_db_path {
54 Some(Arc::new(HistoryManager::new(path)?))
55 } else {
56 None
57 };
58
59 let reranker = if let Some(reranker_config) = &config.reranker {
60 Some(create_reranker(reranker_config)?)
61 } else {
62 None
63 };
64
65 info!(
66 "Initialized Memory with {} embedder, {} dimensions",
67 embedder.model_name(),
68 dimensions
69 );
70
71 Ok(Self {
72 embedder,
73 vector_store,
74 llm,
75 history,
76 reranker,
77 config,
78 })
79 }
80
81 pub async fn add(
83 &self,
84 messages: impl Into<Messages>,
85 options: AddOptions,
86 ) -> Result<AddResult, MemoryError> {
87 let messages = messages.into().into_messages();
88 if options.user_id.is_none() && options.agent_id.is_none() && options.run_id.is_none() {
90 return Err(MemoryError::InvalidInput(
91 "At least one of user_id, agent_id, or run_id is required".to_string(),
92 ));
93 }
94
95 let results = if options.infer && self.llm.is_some() {
96 self.add_with_inference(&messages, &options).await?
98 } else {
99 self.add_raw(&messages, &options).await?
101 };
102
103 Ok(AddResult { results })
104 }
105
106 async fn add_raw(
108 &self,
109 messages: &[Message],
110 options: &AddOptions,
111 ) -> Result<Vec<MemoryEvent>, MemoryError> {
112 let mut results = Vec::new();
113
114 for msg in messages {
115 if msg.role == Role::System {
116 continue;
117 }
118
119 let record = MemoryRecord::with_scoping(
120 msg.content.clone(),
121 options
122 .metadata
123 .as_ref()
124 .map(|m| serde_json::to_value(m).unwrap_or_default())
125 .unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
126 options.user_id.clone(),
127 options.agent_id.clone(),
128 options.run_id.clone(),
129 );
130
131 let embedding = self.embedder.embed(&record.content).await?;
132 let payload = Payload::from(&record);
133
134 self.vector_store
135 .insert(&record.id.to_string(), embedding, payload)
136 .await?;
137
138 if let Some(history) = &self.history {
139 let _ = history.add_history(
140 record.id,
141 None,
142 record.content.clone(),
143 EventType::Add,
144 record.created_at,
145 record.user_id.clone(),
146 record.agent_id.clone(),
147 record.run_id.clone(),
148 );
149 }
150
151 results.push(MemoryEvent {
152 id: record.id,
153 memory: record.content,
154 event: EventType::Add,
155 });
156 }
157
158 Ok(results)
159 }
160
161 async fn add_with_inference(
163 &self,
164 messages: &[Message],
165 options: &AddOptions,
166 ) -> Result<Vec<MemoryEvent>, MemoryError> {
167 let llm = self.llm.as_ref().ok_or(LLMError::NotConfigured)?;
168
169 let messages_text = messages
171 .iter()
172 .map(|m| format!("{:?}: {}", m.role, m.content))
173 .collect::<Vec<_>>()
174 .join("\n");
175
176 let extraction_messages = vec![
178 Message::system(FACT_EXTRACTION_PROMPT),
179 Message::user(format_fact_extraction_input(&messages_text)),
180 ];
181
182 #[derive(serde::Deserialize)]
183 struct FactsResponse {
184 facts: Vec<String>,
185 }
186
187 let facts: FactsResponse = generate_json(
188 llm.as_ref(),
189 &extraction_messages,
190 GenerateOptions::default(),
191 )
192 .await?;
193
194 if facts.facts.is_empty() {
195 debug!("No facts extracted from messages");
196 return Ok(Vec::new());
197 }
198
199 info!("Extracted {} facts", facts.facts.len());
200
201 let mut existing_memories: Vec<(String, String)> = Vec::new(); let mut memory_map: HashMap<String, String> = HashMap::new(); let search_filters = Filters {
206 conditions: vec![],
207 logic: crate::models::FilterLogic::And,
208 };
209
210 for fact in &facts.facts {
211 let embedding = self.embedder.embed(fact).await?;
212
213 let similar = self
214 .vector_store
215 .search(&embedding, 5, Some(&search_filters))
216 .await?;
217
218 for result in similar {
219 let real_id = result.id.clone();
221 if !memory_map.values().any(|rid| rid == &real_id) {
222 let index = memory_map.len().to_string();
223 memory_map.insert(index.clone(), real_id);
224 existing_memories.push((index, result.payload.data));
225 }
226 }
227 }
228
229 let update_messages = vec![
231 Message::system(MEMORY_UPDATE_PROMPT),
232 Message::user(format_memory_update_input(&existing_memories, &facts.facts)),
233 ];
234
235 #[derive(serde::Deserialize)]
236 struct MemoryAction {
237 event: String,
238 text: Option<String>,
239 id: Option<String>,
240 }
241
242 #[derive(serde::Deserialize)]
243 struct MemoryActionsResponse {
244 memory: Vec<MemoryAction>,
245 }
246
247 let actions: MemoryActionsResponse = generate_json(
248 llm.as_ref(),
249 &update_messages,
250 GenerateOptions::default(),
251 )
252 .await?;
253
254 let mut results = Vec::new();
255
256 for action in actions.memory {
257 match action.event.to_uppercase().as_str() {
258 "ADD" => {
259 if let Some(text) = action.text {
260 let record = MemoryRecord::with_scoping(
261 &text,
262 options
263 .metadata
264 .as_ref()
265 .map(|m| serde_json::to_value(m).unwrap_or_default())
266 .unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
267 options.user_id.clone(),
268 options.agent_id.clone(),
269 options.run_id.clone(),
270 );
271
272 let embedding = self.embedder.embed(&text).await?;
273 let payload = Payload::from(&record);
274
275 self.vector_store
276 .insert(&record.id.to_string(), embedding, payload)
277 .await?;
278
279 if let Some(history) = &self.history {
280 let _ = history.add_history(
281 record.id,
282 None,
283 record.content.clone(),
284 EventType::Add,
285 record.created_at,
286 record.user_id.clone(),
287 record.agent_id.clone(),
288 record.run_id.clone(),
289 );
290 }
291
292 results.push(MemoryEvent {
293 id: record.id,
294 memory: text,
295 event: EventType::Add,
296 });
297 }
298 }
299 "UPDATE" => {
300 if let (Some(index_id), Some(text)) = (action.id, action.text) {
301 if let Some(real_id) = memory_map.get(&index_id) {
302 debug!("Updating memory {} (index {}) with: {}", real_id, index_id, text);
303
304 match self.update(real_id, &text).await {
306 Ok(record) => {
307 results.push(MemoryEvent {
308 id: record.id,
309 memory: text,
310 event: EventType::Update,
311 });
312 },
313 Err(e) => {
314 warn!("Failed to update memory {}: {}", real_id, e);
315 }
316 }
317 } else {
318 warn!("LLM tried to update unknown memory index: {}", index_id);
319 }
320 }
321 }
322 "DELETE" => {
323 if let Some(index_id) = action.id {
324 if let Some(real_id) = memory_map.get(&index_id) {
325 debug!("Deleting memory {} (index {})", real_id, index_id);
326
327 match self.delete(real_id).await {
329 Ok(_) => {
330 if let Ok(uuid) = Uuid::parse_str(real_id) {
333 results.push(MemoryEvent {
334 id: uuid,
335 memory: String::new(), event: EventType::Delete,
337 });
338 }
339 },
340 Err(e) => {
341 warn!("Failed to delete memory {}: {}", real_id, e);
342 }
343 }
344 } else {
345 warn!("LLM tried to delete unknown memory index: {}", index_id);
346 }
347 }
348 }
349 "NOOP" => {
350 debug!("No action needed");
351 }
352 _ => {
353 warn!("Unknown memory action: {}", action.event);
354 }
355 }
356 }
357
358 Ok(results)
359 }
360
361 pub async fn search(
363 &self,
364 query: &str,
365 options: SearchOptions,
366 ) -> Result<SearchResult, MemoryError> {
367 let embedding = self.embedder.embed(query).await?;
368 let limit = options.limit.unwrap_or(10);
369 let threshold = options.threshold.unwrap_or(0.0);
370
371 let search_limit = if options.rerank { limit * 10 } else { limit * 2 };
373
374 let results = self
375 .vector_store
376 .search(&embedding, search_limit, options.filters.as_ref())
377 .await?;
378
379 let mut scored: Vec<ScoredMemory> = results
380 .into_iter()
381 .map(|r| r.to_scored_memory())
382 .collect();
383
384 scored.retain(|m| {
386 if let Some(ref user_id) = options.user_id {
387 if m.record.user_id.as_ref() != Some(user_id) {
388 return false;
389 }
390 }
391 if let Some(ref agent_id) = options.agent_id {
392 if m.record.agent_id.as_ref() != Some(agent_id) {
393 return false;
394 }
395 }
396 if let Some(ref run_id) = options.run_id {
397 if m.record.run_id.as_ref() != Some(run_id) {
398 return false;
399 }
400 }
401 true
402 });
403
404 scored.retain(|m| m.score >= threshold);
406
407 if options.rerank {
409 if let Some(reranker) = &self.reranker {
410 scored = reranker.rerank(query, scored).await?;
411 } else {
412 warn!("Reranking requested but no reranker configured");
413 }
414 }
415
416 scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
418 scored.truncate(limit);
419
420 Ok(SearchResult { results: scored })
421 }
422
423 pub async fn get(&self, id: &str) -> Result<Option<MemoryRecord>, MemoryError> {
425 let result = self.vector_store.get(id).await?;
426 Ok(result.map(|r| r.to_memory_record()))
427 }
428
429 pub async fn get_all(&self, options: GetAllOptions) -> Result<Vec<MemoryRecord>, MemoryError> {
431 let limit = options.limit.unwrap_or(100);
432 let results = self.vector_store.list(None, limit).await?;
433
434 let mut records: Vec<MemoryRecord> =
435 results.into_iter().map(|r| r.to_memory_record()).collect();
436
437 records.retain(|m| {
439 if let Some(ref user_id) = options.user_id {
440 if m.user_id.as_ref() != Some(user_id) {
441 return false;
442 }
443 }
444 if let Some(ref agent_id) = options.agent_id {
445 if m.agent_id.as_ref() != Some(agent_id) {
446 return false;
447 }
448 }
449 if let Some(ref run_id) = options.run_id {
450 if m.run_id.as_ref() != Some(run_id) {
451 return false;
452 }
453 }
454 true
455 });
456
457 Ok(records)
458 }
459
460 pub async fn update(&self, id: &str, content: &str) -> Result<MemoryRecord, MemoryError> {
462 let existing = self
464 .vector_store
465 .get(id)
466 .await?
467 .ok_or_else(|| MemoryError::NotFound(id.to_string()))?;
468
469 let mut record = existing.to_memory_record();
470 let previous_content = record.content.clone();
471 record.update_content(content);
472
473 let embedding = self.embedder.embed(content).await?;
474 let payload = Payload::from(&record);
475
476 self.vector_store
477 .update(id, Some(embedding), payload)
478 .await?;
479
480 if let Some(history) = &self.history {
481 let _ = history.add_history(
482 record.id,
483 Some(previous_content),
484 record.content.clone(),
485 EventType::Update,
486 Utc::now(),
487 record.user_id.clone(),
488 record.agent_id.clone(),
489 record.run_id.clone(),
490 );
491 }
492
493 Ok(record)
494 }
495
496 pub async fn delete(&self, id: &str) -> Result<(), MemoryError> {
498 let record = self.get(id).await?;
500
501 self.vector_store.delete(id).await?;
502
503 if let Some(record) = record {
504 if let Some(history) = &self.history {
505 let _ = history.add_history(
506 record.id,
507 Some(record.content),
508 "DELETED".to_string(),
509 EventType::Delete,
510 Utc::now(),
511 record.user_id,
512 record.agent_id,
513 record.run_id,
514 );
515 }
516 }
517
518 Ok(())
519 }
520
521 pub async fn history(&self, id: &str) -> Result<Vec<HistoryEntry>, MemoryError> {
523 if let Some(history) = &self.history {
524 let memory_id = Uuid::parse_str(id).map_err(|e| MemoryError::InvalidInput(e.to_string()))?;
525 history.get_history(memory_id)
526 } else {
527 Ok(Vec::new())
528 }
529 }
530
531 pub async fn reset(&self, options: ResetOptions) -> Result<(), MemoryError> {
533 let filters = if options.user_id.is_some() || options.agent_id.is_some() {
535 None
537 } else {
538 None
539 };
540
541 self.vector_store.delete_all(filters.as_ref()).await?;
542
543 if let Some(history) = &self.history {
544 if filters.is_none() {
546 history.reset()?;
547 }
548 }
549
550 Ok(())
551 }
552}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// Options for a verbatim (non-inferred) add scoped to the shared test user.
    fn raw_add_for_test_user() -> AddOptions {
        AddOptions {
            user_id: Some("test_user".to_string()),
            infer: false,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_memory_creation() {
        let created = Memory::new(MemoryConfig::default()).await;
        assert!(created.is_ok());
    }

    #[tokio::test]
    async fn test_add_raw() {
        let store = Memory::new(MemoryConfig::default()).await.unwrap();

        let outcome = store
            .add("Test memory content", raw_add_for_test_user())
            .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().results.len(), 1);
    }

    #[tokio::test]
    async fn test_search() {
        let store = Memory::new(MemoryConfig::default()).await.unwrap();

        store
            .add("I love programming in Rust", raw_add_for_test_user())
            .await
            .unwrap();

        let query_options = SearchOptions {
            user_id: Some("test_user".to_string()),
            ..Default::default()
        };
        let found = store
            .search("Rust programming", query_options)
            .await
            .unwrap();

        assert!(!found.results.is_empty());
    }
}