Skip to main content

agentrs_memory/
in_memory.rs

1use async_trait::async_trait;
2
3use agentrs_core::{Memory, Message, Result};
4
5use crate::SearchableMemory;
6
7/// Default in-process memory backend.
8#[derive(Debug, Clone, Default)]
9pub struct InMemoryMemory {
10    messages: Vec<Message>,
11    max_messages: Option<usize>,
12}
13
14impl InMemoryMemory {
15    /// Creates an empty memory backend.
16    pub fn new() -> Self {
17        Self::default()
18    }
19
20    /// Creates an in-memory backend capped to the most recent messages.
21    pub fn with_max_messages(max_messages: usize) -> Self {
22        Self {
23            messages: Vec::new(),
24            max_messages: Some(max_messages),
25        }
26    }
27
28    fn trim(&mut self) {
29        let Some(max_messages) = self.max_messages else {
30            return;
31        };
32
33        if self.messages.len() <= max_messages {
34            return;
35        }
36
37        let overflow = self.messages.len() - max_messages;
38        self.messages.drain(0..overflow);
39    }
40}
41
42#[async_trait]
43impl Memory for InMemoryMemory {
44    async fn store(&mut self, _key: &str, value: Message) -> Result<()> {
45        self.messages.push(value);
46        self.trim();
47        Ok(())
48    }
49
50    async fn retrieve(&self, query: &str, limit: usize) -> Result<Vec<Message>> {
51        let query = query.to_lowercase();
52        Ok(self
53            .messages
54            .iter()
55            .filter(|message| message.text_content().to_lowercase().contains(&query))
56            .take(limit)
57            .cloned()
58            .collect())
59    }
60
61    async fn history(&self) -> Result<Vec<Message>> {
62        Ok(self.messages.clone())
63    }
64
65    async fn clear(&mut self) -> Result<()> {
66        self.messages.clear();
67        Ok(())
68    }
69}
70
71#[async_trait]
72impl SearchableMemory for InMemoryMemory {
73    async fn token_count(&self) -> Result<usize> {
74        Ok(self
75            .messages
76            .iter()
77            .map(|message| message.text_content().chars().count() / 4)
78            .sum())
79    }
80}