1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::{num::NonZeroUsize, sync::Arc};
4use tokio::sync::Mutex;
5
6#[derive(PartialEq, Eq, Serialize, Deserialize, Debug, Clone)]
8pub enum MessageType {
9 #[serde(rename = "system")]
10 System,
11 #[serde(rename = "human")]
12 Human,
13 #[serde(rename = "ai")]
14 AI,
15 #[serde(rename = "tool")]
16 Tool,
17}
18
19impl Default for MessageType {
20 fn default() -> Self {
22 Self::System
23 }
24}
25
26impl MessageType {
27 pub fn type_string(&self) -> String {
29 match self {
30 MessageType::System => "system".to_owned(),
31 MessageType::Human => "user".to_owned(),
32 MessageType::AI => "assistant".to_owned(),
33 MessageType::Tool => "tool".to_owned(),
34 }
35 }
36}
37
38#[derive(Serialize, Deserialize, Debug, Default, Clone)]
40pub struct Message {
41 pub content: String,
42 pub message_type: MessageType,
43 pub id: Option<String>,
44 pub tool_calls: Option<Value>,
45}
46
47impl Message {
48 pub fn new_human_message<T: std::fmt::Display>(content: T) -> Self {
50 Message {
51 content: content.to_string(),
52 message_type: MessageType::Human,
53 id: None,
54 tool_calls: None,
55 }
56 }
57
58 pub fn new_system_message<T: std::fmt::Display>(content: T) -> Self {
60 Message {
61 content: content.to_string(),
62 message_type: MessageType::System,
63 id: None,
64 tool_calls: None,
65 }
66 }
67
68 pub fn new_tool_message<T: std::fmt::Display, S: Into<String>>(content: T, id: S) -> Self {
70 Message {
71 content: content.to_string(),
72 message_type: MessageType::Tool,
73 id: Some(id.into()),
74 tool_calls: None,
75 }
76 }
77
78 pub fn new_ai_message<T: std::fmt::Display>(content: T) -> Self {
80 Message {
81 content: content.to_string(),
82 message_type: MessageType::AI,
83 id: None,
84 tool_calls: None,
85 }
86 }
87
88 pub fn with_tool_calls(mut self, tool_calls: Value) -> Self {
90 self.tool_calls = Some(tool_calls);
91 self
92 }
93
94 pub fn messages_from_value(value: &Value) -> Result<Vec<Message>, serde_json::error::Error> {
96 serde_json::from_value(value.clone())
97 }
98
99 pub fn messages_to_string(messages: &[Message]) -> String {
101 messages
102 .iter()
103 .map(|m| format!("{:?}: {}", m.message_type, m.content))
104 .collect::<Vec<String>>()
105 .join("\n")
106 }
107}
108
109pub trait Memory: Send + Sync {
111 fn messages(&self) -> Vec<Message>;
113
114 fn add_user_message(&mut self, message: &dyn std::fmt::Display) {
116 self.add_message(Message::new_human_message(message.to_string()));
117 }
118
119 fn add_ai_message(&mut self, message: &dyn std::fmt::Display) {
121 self.add_message(Message::new_ai_message(message.to_string()));
122 }
123
124 fn add_message(&mut self, message: Message);
126
127 fn clear(&mut self);
129
130 fn to_string(&self) -> String {
132 self.messages()
133 .iter()
134 .map(|msg| format!("{}: {}", msg.message_type.type_string(), msg.content))
135 .collect::<Vec<String>>()
136 .join("\n")
137 }
138}
139
140impl<M> From<M> for Box<dyn Memory>
142where
143 M: Memory + 'static,
144{
145 fn from(memory: M) -> Self {
146 Box::new(memory)
147 }
148}
149
150pub struct WindowBufferMemory {
152 window_size: usize,
153 messages: Vec<Message>,
154}
155
156impl Default for WindowBufferMemory {
157 fn default() -> Self {
159 Self::new(10)
160 }
161}
162
163impl WindowBufferMemory {
164 pub fn new(window_size: usize) -> Self {
166 Self {
167 messages: Vec::new(),
168 window_size,
169 }
170 }
171
172 #[inline]
174 pub fn window_size(&self) -> usize {
175 self.window_size
176 }
177}
178
179impl From<WindowBufferMemory> for Arc<dyn Memory> {
181 fn from(val: WindowBufferMemory) -> Self {
182 Arc::new(val)
183 }
184}
185
186impl From<WindowBufferMemory> for Arc<Mutex<dyn Memory>> {
188 fn from(val: WindowBufferMemory) -> Self {
189 Arc::new(Mutex::new(val))
190 }
191}
192
193impl Memory for WindowBufferMemory {
194 fn messages(&self) -> Vec<Message> {
196 self.messages.clone()
197 }
198
199 fn add_message(&mut self, message: Message) {
201 if self.messages.len() >= self.window_size {
202 self.messages.remove(0);
203 }
204 self.messages.push(message);
205 }
206
207 fn clear(&mut self) {
209 self.messages.clear();
210 }
211}
212
213pub struct RLUCacheMemory {
215 cache: lru::LruCache<String, Message>,
216 capacity: usize,
217}
218
219impl RLUCacheMemory {
220 pub fn new(capacity: usize) -> Self {
222 Self {
223 cache: lru::LruCache::new(NonZeroUsize::new(capacity).unwrap()),
224 capacity,
225 }
226 }
227
228 #[inline]
230 pub fn capacity(&self) -> usize {
231 self.capacity
232 }
233}
234
235impl Memory for RLUCacheMemory {
236 fn messages(&self) -> Vec<Message> {
238 self.cache.iter().map(|(_, msg)| msg.clone()).collect()
239 }
240
241 fn add_message(&mut self, message: Message) {
243 let id = message
244 .id
245 .clone()
246 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
247 self.cache.put(id, message);
248 }
249
250 fn clear(&mut self) {
252 self.cache.clear();
253 }
254}
255
256impl From<RLUCacheMemory> for Arc<dyn Memory> {
258 fn from(val: RLUCacheMemory) -> Self {
259 Arc::new(val)
260 }
261}
262
263impl From<RLUCacheMemory> for Arc<Mutex<dyn Memory>> {
265 fn from(val: RLUCacheMemory) -> Self {
266 Arc::new(Mutex::new(val))
267 }
268}