agent_chain_core/
chat_history.rs1use async_trait::async_trait;
7use std::fmt::Display;
8
9use crate::messages::{AIMessage, BaseMessage, HumanMessage, get_buffer_string};
10
11#[async_trait]
59pub trait BaseChatMessageHistory: Send + Sync {
60 fn messages(&self) -> Vec<BaseMessage>;
66
67 async fn get_messages_async(&self) -> Vec<BaseMessage> {
74 self.messages()
75 }
76
77 fn add_user_message(&mut self, message: HumanMessageInput) {
84 let human_message = match message {
85 HumanMessageInput::Message(m) => m,
86 HumanMessageInput::Text(text) => HumanMessage::new(&text),
87 };
88 self.add_message(BaseMessage::Human(human_message));
89 }
90
91 fn add_ai_message(&mut self, message: AIMessageInput) {
98 let ai_message = match message {
99 AIMessageInput::Message(m) => *m,
100 AIMessageInput::Text(text) => AIMessage::new(&text),
101 };
102 self.add_message(BaseMessage::AI(ai_message));
103 }
104
105 fn add_message(&mut self, message: BaseMessage) {
111 self.add_messages(&[message]);
112 }
113
114 fn add_messages(&mut self, messages: &[BaseMessage]);
119
120 async fn add_messages_async(&mut self, messages: Vec<BaseMessage>) {
125 self.add_messages(&messages);
126 }
127
128 fn clear(&mut self);
130
131 async fn clear_async(&mut self) {
136 self.clear();
137 }
138
139 fn to_buffer_string(&self) -> String {
141 get_buffer_string(&self.messages(), "Human", "AI")
142 }
143}
144
145pub enum HumanMessageInput {
149 Message(HumanMessage),
151 Text(String),
153}
154
155impl From<HumanMessage> for HumanMessageInput {
156 fn from(message: HumanMessage) -> Self {
157 HumanMessageInput::Message(message)
158 }
159}
160
161impl From<String> for HumanMessageInput {
162 fn from(text: String) -> Self {
163 HumanMessageInput::Text(text)
164 }
165}
166
167impl From<&str> for HumanMessageInput {
168 fn from(text: &str) -> Self {
169 HumanMessageInput::Text(text.to_string())
170 }
171}
172
173pub enum AIMessageInput {
177 Message(Box<AIMessage>),
179 Text(String),
181}
182
183impl From<AIMessage> for AIMessageInput {
184 fn from(message: AIMessage) -> Self {
185 AIMessageInput::Message(Box::new(message))
186 }
187}
188
189impl From<String> for AIMessageInput {
190 fn from(text: String) -> Self {
191 AIMessageInput::Text(text)
192 }
193}
194
195impl From<&str> for AIMessageInput {
196 fn from(text: &str) -> Self {
197 AIMessageInput::Text(text.to_string())
198 }
199}
200
201#[derive(Debug, Clone, Default)]
205pub struct InMemoryChatMessageHistory {
206 messages: Vec<BaseMessage>,
208}
209
210impl InMemoryChatMessageHistory {
211 pub fn new() -> Self {
213 Self {
214 messages: Vec::new(),
215 }
216 }
217
218 pub fn with_messages(messages: Vec<BaseMessage>) -> Self {
220 Self { messages }
221 }
222}
223
224impl Display for InMemoryChatMessageHistory {
225 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226 write!(f, "{}", self.to_buffer_string())
227 }
228}
229
230#[async_trait]
231impl BaseChatMessageHistory for InMemoryChatMessageHistory {
232 fn messages(&self) -> Vec<BaseMessage> {
233 self.messages.clone()
234 }
235
236 async fn get_messages_async(&self) -> Vec<BaseMessage> {
237 self.messages.clone()
238 }
239
240 fn add_messages(&mut self, messages: &[BaseMessage]) {
241 self.messages.extend(messages.iter().cloned());
242 }
243
244 async fn add_messages_async(&mut self, messages: Vec<BaseMessage>) {
245 self.add_messages(&messages);
246 }
247
248 fn clear(&mut self) {
249 self.messages.clear();
250 }
251
252 async fn clear_async(&mut self) {
253 self.clear();
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_in_memory_chat_history_new() {
        let history = InMemoryChatMessageHistory::new();
        assert!(history.messages().is_empty());
    }

    #[test]
    fn test_in_memory_chat_history_default_is_empty() {
        let history = InMemoryChatMessageHistory::default();
        assert!(history.messages().is_empty());
    }

    #[test]
    fn test_in_memory_chat_history_with_messages() {
        let history = InMemoryChatMessageHistory::with_messages(vec![
            BaseMessage::Human(HumanMessage::new("Hello")),
            BaseMessage::AI(AIMessage::new("Hi there!")),
        ]);

        // Order and content must be preserved, not just the count.
        let messages = history.messages();
        assert_eq!(messages.len(), 2);
        assert!(matches!(&messages[0], BaseMessage::Human(_)));
        assert!(matches!(&messages[1], BaseMessage::AI(_)));
        assert_eq!(messages[0].content(), "Hello");
        assert_eq!(messages[1].content(), "Hi there!");
    }

    #[test]
    fn test_add_user_message_string() {
        let mut history = InMemoryChatMessageHistory::new();
        history.add_user_message("Hello!".into());

        let messages = history.messages();
        assert_eq!(messages.len(), 1);
        assert!(matches!(&messages[0], BaseMessage::Human(_)));
        assert_eq!(messages[0].content(), "Hello!");
    }

    #[test]
    fn test_add_user_message_owned_string() {
        let mut history = InMemoryChatMessageHistory::new();
        history.add_user_message(String::from("Hello!").into());

        let messages = history.messages();
        assert_eq!(messages.len(), 1);
        assert!(matches!(&messages[0], BaseMessage::Human(_)));
        assert_eq!(messages[0].content(), "Hello!");
    }

    #[test]
    fn test_add_user_message_human_message() {
        let mut history = InMemoryChatMessageHistory::new();
        let human_msg = HumanMessage::new("Hello!");
        history.add_user_message(human_msg.into());

        let messages = history.messages();
        assert_eq!(messages.len(), 1);
        assert!(matches!(&messages[0], BaseMessage::Human(_)));
        assert_eq!(messages[0].content(), "Hello!");
    }

    #[test]
    fn test_add_ai_message_string() {
        let mut history = InMemoryChatMessageHistory::new();
        history.add_ai_message("Hi there!".into());

        let messages = history.messages();
        assert_eq!(messages.len(), 1);
        assert!(matches!(&messages[0], BaseMessage::AI(_)));
        assert_eq!(messages[0].content(), "Hi there!");
    }

    #[test]
    fn test_add_ai_message_ai_message() {
        let mut history = InMemoryChatMessageHistory::new();
        let ai_msg = AIMessage::new("Hi there!");
        history.add_ai_message(ai_msg.into());

        let messages = history.messages();
        assert_eq!(messages.len(), 1);
        assert!(matches!(&messages[0], BaseMessage::AI(_)));
        assert_eq!(messages[0].content(), "Hi there!");
    }

    #[test]
    fn test_add_message() {
        let mut history = InMemoryChatMessageHistory::new();
        history.add_message(BaseMessage::Human(HumanMessage::new("Hello")));
        history.add_message(BaseMessage::AI(AIMessage::new("Hi")));

        let messages = history.messages();
        assert_eq!(messages.len(), 2);
        assert!(matches!(&messages[0], BaseMessage::Human(_)));
        assert!(matches!(&messages[1], BaseMessage::AI(_)));
        assert_eq!(messages[0].content(), "Hello");
        assert_eq!(messages[1].content(), "Hi");
    }

    #[test]
    fn test_add_messages() {
        let mut history = InMemoryChatMessageHistory::new();
        let new_messages = vec![
            BaseMessage::Human(HumanMessage::new("Hello")),
            BaseMessage::AI(AIMessage::new("Hi")),
            BaseMessage::Human(HumanMessage::new("How are you?")),
        ];
        history.add_messages(&new_messages);

        let messages = history.messages();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].content(), "Hello");
        assert_eq!(messages[1].content(), "Hi");
        assert_eq!(messages[2].content(), "How are you?");

        // A second batch is appended after the first, not prepended or merged.
        history.add_messages(&[BaseMessage::AI(AIMessage::new("Fine"))]);
        let messages = history.messages();
        assert_eq!(messages.len(), 4);
        assert_eq!(messages[3].content(), "Fine");
    }

    #[test]
    fn test_add_messages_empty_slice_is_noop() {
        let mut history = InMemoryChatMessageHistory::new();
        history.add_user_message("Hello!".into());
        history.add_messages(&[]);
        assert_eq!(history.messages().len(), 1);
    }

    #[test]
    fn test_clone_is_independent() {
        let mut original = InMemoryChatMessageHistory::new();
        original.add_user_message("Hello!".into());

        let mut copy = original.clone();
        copy.add_ai_message("Hi!".into());

        assert_eq!(original.messages().len(), 1);
        assert_eq!(copy.messages().len(), 2);
    }

    #[test]
    fn test_clear() {
        let mut history = InMemoryChatMessageHistory::new();
        history.add_user_message("Hello!".into());
        history.add_ai_message("Hi!".into());

        assert_eq!(history.messages().len(), 2);

        history.clear();
        assert!(history.messages().is_empty());

        // The history stays usable after being cleared.
        history.add_user_message("Again".into());
        let messages = history.messages();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].content(), "Again");
    }

    #[test]
    fn test_to_buffer_string() {
        let mut history = InMemoryChatMessageHistory::new();
        history.add_user_message("Hello!".into());
        history.add_ai_message("Hi there!".into());

        let buffer = history.to_buffer_string();
        assert!(buffer.contains("Human: Hello!"));
        assert!(buffer.contains("AI: Hi there!"));
    }

    #[test]
    fn test_display() {
        let mut history = InMemoryChatMessageHistory::new();
        history.add_user_message("Hello!".into());
        history.add_ai_message("Hi there!".into());

        let display = format!("{}", history);
        assert!(display.contains("Human: Hello!"));
        assert!(display.contains("AI: Hi there!"));
        assert_eq!(display, history.to_buffer_string());
    }

    #[tokio::test]
    async fn test_get_messages_async() {
        let mut history = InMemoryChatMessageHistory::new();
        history.add_user_message("Hello!".into());

        let messages = history.get_messages_async().await;
        assert_eq!(messages.len(), 1);
        assert!(matches!(&messages[0], BaseMessage::Human(_)));
        assert_eq!(messages[0].content(), "Hello!");
    }

    #[tokio::test]
    async fn test_add_messages_async() {
        let mut history = InMemoryChatMessageHistory::new();
        let new_messages = vec![
            BaseMessage::Human(HumanMessage::new("Hello")),
            BaseMessage::AI(AIMessage::new("Hi")),
        ];
        history.add_messages_async(new_messages).await;

        let messages = history.messages();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content(), "Hello");
        assert_eq!(messages[1].content(), "Hi");
    }

    #[tokio::test]
    async fn test_add_messages_async_appends_after_existing() {
        let mut history = InMemoryChatMessageHistory::new();
        history.add_user_message("First".into());

        history
            .add_messages_async(vec![BaseMessage::AI(AIMessage::new("Second"))])
            .await;

        let messages = history.messages();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content(), "First");
        assert_eq!(messages[1].content(), "Second");
    }

    #[tokio::test]
    async fn test_clear_async() {
        let mut history = InMemoryChatMessageHistory::new();
        history.add_user_message("Hello!".into());

        assert_eq!(history.messages().len(), 1);

        history.clear_async().await;
        assert!(history.messages().is_empty());
    }
}