1use crate::message::Message;
2use crate::storage::ConversationStore;
3
4#[derive(Default, Clone, Debug)]
6pub struct ConversationMemory {
7 messages: Vec<Message>,
8}
9
10impl ConversationMemory {
11 pub fn with_messages(messages: Vec<Message>) -> Self {
12 Self { messages }
13 }
14
15 pub fn push(&mut self, message: Message) {
16 self.messages.push(message);
17 }
18
19 pub fn iter(&self) -> impl DoubleEndedIterator<Item = &Message> + '_ {
20 self.messages.iter()
21 }
22
23 pub fn len(&self) -> usize {
24 self.messages.len()
25 }
26
27 pub fn is_empty(&self) -> bool {
28 self.messages.is_empty()
29 }
30}
31
32#[derive(Clone, Debug)]
34pub struct PersistentConversationMemory<S: ConversationStore> {
35 store: S,
36 inner: ConversationMemory,
37}
38
39impl<S: ConversationStore> PersistentConversationMemory<S> {
40 pub fn new(store: S) -> Self {
41 Self {
42 store,
43 inner: ConversationMemory::default(),
44 }
45 }
46
47 pub async fn load(mut self) -> crate::Result<Self> {
48 let stored = self.store.load().await?;
49 self.inner = ConversationMemory::with_messages(stored);
50 Ok(self)
51 }
52
53 pub fn as_memory(&self) -> &ConversationMemory {
54 &self.inner
55 }
56
57 pub async fn push(&mut self, message: Message) -> crate::Result<()> {
58 self.store.append(&message).await?;
59 self.inner.push(message);
60 Ok(())
61 }
62
63 pub async fn clear(&mut self) -> crate::Result<()> {
64 self.store.clear().await?;
65 self.inner = ConversationMemory::default();
66 Ok(())
67 }
68}
69
70pub trait MemoryStrategy: Send + Sync {
76 fn get_context_messages(&self, messages: &[Message]) -> Vec<Message>;
78
79 fn name(&self) -> &str;
81}
82
83#[derive(Clone, Default)]
85pub struct FullMemoryStrategy;
86
87impl MemoryStrategy for FullMemoryStrategy {
88 fn get_context_messages(&self, messages: &[Message]) -> Vec<Message> {
89 messages.to_vec()
90 }
91
92 fn name(&self) -> &str {
93 "full"
94 }
95}
96
/// Strategy that limits context to a sliding window of recent messages.
///
/// `window_size` bounds how many non-system messages are kept. `keep_system`
/// (default `true`, cleared by [`WindowedMemoryStrategy::without_system`])
/// controls whether system messages are pinned in front of the window.
#[derive(Clone)]
pub struct WindowedMemoryStrategy {
    // Maximum number of non-system messages kept in the window.
    window_size: usize,
    // Whether system messages are kept in front of the window.
    keep_system: bool,
}

impl WindowedMemoryStrategy {
    /// Creates a window of `window_size` messages that keeps system messages.
    pub fn new(window_size: usize) -> Self {
        Self {
            keep_system: true,
            window_size,
        }
    }

    /// Builder-style toggle that stops pinning system messages.
    pub fn without_system(self) -> Self {
        Self {
            keep_system: false,
            ..self
        }
    }
}
117
118impl MemoryStrategy for WindowedMemoryStrategy {
119 fn get_context_messages(&self, messages: &[Message]) -> Vec<Message> {
120 use crate::message::Role;
121
122 if messages.len() <= self.window_size {
123 return messages.to_vec();
124 }
125
126 let mut result = Vec::new();
127
128 if self.keep_system {
130 for msg in messages {
131 if msg.role == Role::System {
132 result.push(msg.clone());
133 }
134 }
135 }
136
137 let non_system: Vec<&Message> = messages
139 .iter()
140 .filter(|m| m.role != Role::System)
141 .collect();
142
143 let start = if non_system.len() > self.window_size {
144 non_system.len() - self.window_size
145 } else {
146 0
147 };
148
149 for msg in &non_system[start..] {
150 result.push((*msg).clone());
151 }
152
153 result
154 }
155
156 fn name(&self) -> &str {
157 "windowed"
158 }
159}
160
161#[derive(Clone)]
163pub struct SummarizedMemoryStrategy {
164 keep_first: usize,
166 keep_last: usize,
168 summary: Option<String>,
170}
171
172impl SummarizedMemoryStrategy {
173 pub fn new(keep_first: usize, keep_last: usize) -> Self {
174 Self {
175 keep_first,
176 keep_last,
177 summary: None,
178 }
179 }
180
181 pub fn with_summary(mut self, summary: impl Into<String>) -> Self {
182 self.summary = Some(summary.into());
183 self
184 }
185
186 pub fn needs_summary(&self, messages: &[Message]) -> bool {
188 messages.len() > self.keep_first + self.keep_last
189 }
190
191 pub fn messages_to_summarize<'a>(&self, messages: &'a [Message]) -> &'a [Message] {
193 if messages.len() <= self.keep_first + self.keep_last {
194 return &[];
195 }
196 let end = messages.len() - self.keep_last;
197 &messages[self.keep_first..end]
198 }
199}
200
201impl MemoryStrategy for SummarizedMemoryStrategy {
202 fn get_context_messages(&self, messages: &[Message]) -> Vec<Message> {
203 if messages.len() <= self.keep_first + self.keep_last {
204 return messages.to_vec();
205 }
206
207 let mut result = Vec::new();
208
209 for msg in messages.iter().take(self.keep_first) {
211 result.push(msg.clone());
212 }
213
214 if let Some(ref summary) = self.summary {
216 result.push(Message::system(format!(
217 "[Summary of {} messages]: {}",
218 messages.len() - self.keep_first - self.keep_last,
219 summary
220 )));
221 }
222
223 let start = messages.len() - self.keep_last;
225 for msg in &messages[start..] {
226 result.push(msg.clone());
227 }
228
229 result
230 }
231
232 fn name(&self) -> &str {
233 "summarized"
234 }
235}
236
/// Strategy that keeps as many of the most recent messages as fit in an
/// approximate token budget.
///
/// Token counts are estimated from content length and a characters-per-token
/// ratio.
#[derive(Clone)]
pub struct TokenLimitedMemoryStrategy {
    // Token budget for the returned context.
    max_tokens: usize,
    // Content bytes assumed per token; kept >= 1 because it is used as a divisor.
    chars_per_token: usize,
}

impl TokenLimitedMemoryStrategy {
    /// Creates a strategy with a budget of `max_tokens`, assuming 4 characters
    /// per token.
    pub fn new(max_tokens: usize) -> Self {
        Self {
            max_tokens,
            chars_per_token: 4,
        }
    }

    /// Overrides the characters-per-token ratio used for estimation.
    ///
    /// `0` is treated as `1`. The ratio is a divisor, so `0` would otherwise
    /// panic on the first estimate.
    pub fn with_chars_per_token(mut self, chars: usize) -> Self {
        self.chars_per_token = chars.max(1);
        self
    }

    /// Roughly estimates the token count of `content`.
    ///
    /// Uses the UTF-8 byte length (not the `char` count) divided by
    /// `chars_per_token`, rounded up.
    fn estimate_tokens(&self, content: &str) -> usize {
        let len = content.len();
        // Round up: with plain integer division, any message shorter than
        // `chars_per_token` would cost 0 tokens and never consume the budget.
        len / self.chars_per_token + usize::from(len % self.chars_per_token != 0)
    }
}
262
263impl MemoryStrategy for TokenLimitedMemoryStrategy {
264 fn get_context_messages(&self, messages: &[Message]) -> Vec<Message> {
265 use crate::message::Role;
266
267 let mut result = Vec::new();
268 let mut total_tokens = 0;
269
270 for msg in messages {
272 if msg.role == Role::System {
273 let tokens = self.estimate_tokens(&msg.content);
274 total_tokens += tokens;
275 result.push(msg.clone());
276 }
277 }
278
279 let non_system: Vec<&Message> = messages
281 .iter()
282 .filter(|m| m.role != Role::System)
283 .collect();
284
285 let mut temp = Vec::new();
286 for msg in non_system.iter().rev() {
287 let tokens = self.estimate_tokens(&msg.content);
288 if total_tokens + tokens > self.max_tokens {
289 break;
290 }
291 total_tokens += tokens;
292 temp.push((*msg).clone());
293 }
294
295 temp.reverse();
297 result.extend(temp);
298
299 result
300 }
301
302 fn name(&self) -> &str {
303 "token_limited"
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Seven-message conversation: one system prompt followed by six turns.
    fn sample_conversation() -> Vec<Message> {
        vec![
            Message::system("You are a helpful assistant"),
            Message::user("Hello"),
            Message::assistant("Hi there!"),
            Message::user("How are you?"),
            Message::assistant("I'm doing well!"),
            Message::user("What's 2+2?"),
            Message::assistant("4"),
        ]
    }

    #[test]
    fn test_windowed_strategy() {
        let messages = sample_conversation();

        let strategy = WindowedMemoryStrategy::new(4);
        let context = strategy.get_context_messages(&messages);

        // The system prompt is pinned in front of the four most recent turns.
        assert_eq!(context.len(), 5);
        assert_eq!(context[0].content, "You are a helpful assistant");
        assert_eq!(context[1].content, "How are you?");
        assert_eq!(context[4].content, "4");
    }

    #[test]
    fn test_windowed_strategy_without_system_long_history() {
        let messages = sample_conversation();

        let strategy = WindowedMemoryStrategy::new(2).without_system();
        let context = strategy.get_context_messages(&messages);

        assert_eq!(context.len(), 2);
        assert_eq!(context[0].content, "What's 2+2?");
        assert_eq!(context[1].content, "4");
    }

    #[test]
    fn test_summarized_strategy() {
        let messages = vec![
            Message::system("sys"),
            Message::user("1"),
            Message::assistant("2"),
            Message::user("3"),
            Message::assistant("4"),
        ];

        let strategy = SummarizedMemoryStrategy::new(1, 2).with_summary("earlier chat");
        assert!(strategy.needs_summary(&messages));
        assert!(!strategy.needs_summary(&messages[..3]));
        assert_eq!(strategy.messages_to_summarize(&messages).len(), 2);

        let context = strategy.get_context_messages(&messages);
        assert_eq!(context.len(), 4);
        assert_eq!(context[0].content, "sys");
        assert_eq!(context[1].content, "[Summary of 2 messages]: earlier chat");
        assert_eq!(context[3].content, "4");

        // Without a summary the middle messages are dropped outright.
        let no_summary = SummarizedMemoryStrategy::new(1, 2);
        assert_eq!(no_summary.get_context_messages(&messages).len(), 3);
    }

    #[test]
    fn test_token_limited_strategy() {
        let messages = vec![
            Message::system("System"),
            Message::user("A".repeat(100)),
            Message::assistant("B".repeat(100)),
            Message::user("C".repeat(100)),
        ];

        // At the default 4 chars/token, "System" costs 1-2 tokens and each
        // 100-char message 25. Only the newest non-system message fits in 50.
        let strategy = TokenLimitedMemoryStrategy::new(50);
        let context = strategy.get_context_messages(&messages);

        assert_eq!(context.len(), 2);
        assert_eq!(context[0].content, "System");
        assert_eq!(context[1].content, "C".repeat(100));
    }
}
347