alith_core/
memory.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::{num::NonZeroUsize, sync::Arc};
4use tokio::sync::Mutex;
5
6/// Represents the type of a message.
7#[derive(PartialEq, Eq, Serialize, Deserialize, Debug, Clone)]
8pub enum MessageType {
9    #[serde(rename = "system")]
10    System,
11    #[serde(rename = "human")]
12    Human,
13    #[serde(rename = "ai")]
14    AI,
15    #[serde(rename = "tool")]
16    Tool,
17}
18
19impl Default for MessageType {
20    /// Default message type is `SystemMessage`.
21    fn default() -> Self {
22        Self::System
23    }
24}
25
26impl MessageType {
27    /// Converts the `MessageType` to a string representation.
28    pub fn type_string(&self) -> String {
29        match self {
30            MessageType::System => "system".to_owned(),
31            MessageType::Human => "user".to_owned(),
32            MessageType::AI => "assistant".to_owned(),
33            MessageType::Tool => "tool".to_owned(),
34        }
35    }
36}
37
38/// Represents a message with content, type, optional ID, and optional tool calls.
39#[derive(Serialize, Deserialize, Debug, Default, Clone)]
40pub struct Message {
41    pub content: String,
42    pub message_type: MessageType,
43    pub id: Option<String>,
44    pub tool_calls: Option<Value>,
45}
46
47impl Message {
48    /// Creates a new human message with the given content.
49    pub fn new_human_message<T: std::fmt::Display>(content: T) -> Self {
50        Message {
51            content: content.to_string(),
52            message_type: MessageType::Human,
53            id: None,
54            tool_calls: None,
55        }
56    }
57
58    /// Creates a new system message with the given content.
59    pub fn new_system_message<T: std::fmt::Display>(content: T) -> Self {
60        Message {
61            content: content.to_string(),
62            message_type: MessageType::System,
63            id: None,
64            tool_calls: None,
65        }
66    }
67
68    /// Creates a new tool message with the given content and ID.
69    pub fn new_tool_message<T: std::fmt::Display, S: Into<String>>(content: T, id: S) -> Self {
70        Message {
71            content: content.to_string(),
72            message_type: MessageType::Tool,
73            id: Some(id.into()),
74            tool_calls: None,
75        }
76    }
77
78    /// Creates a new =AI message with the given content.
79    pub fn new_ai_message<T: std::fmt::Display>(content: T) -> Self {
80        Message {
81            content: content.to_string(),
82            message_type: MessageType::AI,
83            id: None,
84            tool_calls: None,
85        }
86    }
87
88    /// Adds tool calls to the message.
89    pub fn with_tool_calls(mut self, tool_calls: Value) -> Self {
90        self.tool_calls = Some(tool_calls);
91        self
92    }
93
94    /// Deserializes a `Value` into a vector of `Message` objects.
95    pub fn messages_from_value(value: &Value) -> Result<Vec<Message>, serde_json::error::Error> {
96        serde_json::from_value(value.clone())
97    }
98
99    /// Converts a slice of `Message` objects into a single string representation.
100    pub fn messages_to_string(messages: &[Message]) -> String {
101        messages
102            .iter()
103            .map(|m| format!("{:?}: {}", m.message_type, m.content))
104            .collect::<Vec<String>>()
105            .join("\n")
106    }
107}
108
109/// A trait representing a memory storage for messages.
110pub trait Memory: Send + Sync {
111    /// Returns all messages stored in memory.
112    fn messages(&self) -> Vec<Message>;
113
114    /// Adds a user (human) message to the memory.
115    fn add_user_message(&mut self, message: &dyn std::fmt::Display) {
116        self.add_message(Message::new_human_message(message.to_string()));
117    }
118
119    /// Adds an AI (LLM) message to the memory.
120    fn add_ai_message(&mut self, message: &dyn std::fmt::Display) {
121        self.add_message(Message::new_ai_message(message.to_string()));
122    }
123
124    /// Adds a message to the memory.
125    fn add_message(&mut self, message: Message);
126
127    /// Clears all messages from memory.
128    fn clear(&mut self);
129
130    /// Converts the memory's messages to a string representation.
131    fn to_string(&self) -> String {
132        self.messages()
133            .iter()
134            .map(|msg| format!("{}: {}", msg.message_type.type_string(), msg.content))
135            .collect::<Vec<String>>()
136            .join("\n")
137    }
138}
139
140/// Converts a type implementing `Memory` into a boxed trait object.
141impl<M> From<M> for Box<dyn Memory>
142where
143    M: Memory + 'static,
144{
145    fn from(memory: M) -> Self {
146        Box::new(memory)
147    }
148}
149
150/// A memory structure that stores messages in a sliding window buffer.
151pub struct WindowBufferMemory {
152    window_size: usize,
153    messages: Vec<Message>,
154}
155
156impl Default for WindowBufferMemory {
157    /// Creates a default `WindowBufferMemory` with a window size of 10.
158    fn default() -> Self {
159        Self::new(10)
160    }
161}
162
163impl WindowBufferMemory {
164    /// Creates a new `WindowBufferMemory` with the specified window size.
165    pub fn new(window_size: usize) -> Self {
166        Self {
167            messages: Vec::new(),
168            window_size,
169        }
170    }
171
172    /// Get the window size.
173    #[inline]
174    pub fn window_size(&self) -> usize {
175        self.window_size
176    }
177}
178
179/// Converts `WindowBufferMemory` into an `Arc<dyn Memory>`.
180impl From<WindowBufferMemory> for Arc<dyn Memory> {
181    fn from(val: WindowBufferMemory) -> Self {
182        Arc::new(val)
183    }
184}
185
186/// Converts `WindowBufferMemory` into an `Arc<Mutex<dyn Memory>>`.
187impl From<WindowBufferMemory> for Arc<Mutex<dyn Memory>> {
188    fn from(val: WindowBufferMemory) -> Self {
189        Arc::new(Mutex::new(val))
190    }
191}
192
193impl Memory for WindowBufferMemory {
194    /// Returns all messages in the buffer.
195    fn messages(&self) -> Vec<Message> {
196        self.messages.clone()
197    }
198
199    /// Adds a message to the buffer, removing the oldest message if the buffer is full.
200    fn add_message(&mut self, message: Message) {
201        if self.messages.len() >= self.window_size {
202            self.messages.remove(0);
203        }
204        self.messages.push(message);
205    }
206
207    /// Clears all messages from the buffer.
208    fn clear(&mut self) {
209        self.messages.clear();
210    }
211}
212
213/// A memory structure that stores messages in an LRU (Least Recently Used) cache.
214pub struct RLUCacheMemory {
215    cache: lru::LruCache<String, Message>,
216    capacity: usize,
217}
218
219impl RLUCacheMemory {
220    /// Creates a new `RLUCacheMemory` with the specified capacity.
221    pub fn new(capacity: usize) -> Self {
222        Self {
223            cache: lru::LruCache::new(NonZeroUsize::new(capacity).unwrap()),
224            capacity,
225        }
226    }
227
228    /// Get the capacity.
229    #[inline]
230    pub fn capacity(&self) -> usize {
231        self.capacity
232    }
233}
234
235impl Memory for RLUCacheMemory {
236    /// Returns all messages in the cache.
237    fn messages(&self) -> Vec<Message> {
238        self.cache.iter().map(|(_, msg)| msg.clone()).collect()
239    }
240
241    /// Adds a message to the cache, evicting the least recently used message if the cache is full.
242    fn add_message(&mut self, message: Message) {
243        let id = message
244            .id
245            .clone()
246            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
247        self.cache.put(id, message);
248    }
249
250    /// Clears all messages from the cache.
251    fn clear(&mut self) {
252        self.cache.clear();
253    }
254}
255
256/// Converts `RLUCacheMemory` into an `Arc<dyn Memory>`.
257impl From<RLUCacheMemory> for Arc<dyn Memory> {
258    fn from(val: RLUCacheMemory) -> Self {
259        Arc::new(val)
260    }
261}
262
263/// Converts `RLUCacheMemory` into an `Arc<Mutex<dyn Memory>>`.
264impl From<RLUCacheMemory> for Arc<Mutex<dyn Memory>> {
265    fn from(val: RLUCacheMemory) -> Self {
266        Arc::new(Mutex::new(val))
267    }
268}