1use forja_core::types::Message;
2
3#[derive(Debug, Clone, Default)]
4pub struct SessionBuffer {
5 messages: Vec<Message>,
6}
7
8impl SessionBuffer {
9 pub fn new() -> Self {
10 Self {
11 messages: Vec::new(),
12 }
13 }
14
15 pub fn add(&mut self, message: Message) {
16 self.messages.push(message);
17 }
18
19 pub fn get_recent(&self, count: usize) -> Vec<Message> {
20 self.messages
21 .iter()
22 .rev()
23 .take(count)
24 .cloned()
25 .collect::<Vec<_>>()
26 .into_iter()
27 .rev()
28 .collect()
29 }
30
31 pub fn get_all(&self) -> Vec<Message> {
32 self.messages.clone()
33 }
34
35 pub fn clear(&mut self) {
36 self.messages.clear();
37 }
38
39 pub fn len(&self) -> usize {
40 self.messages.len()
41 }
42
43 pub fn is_empty(&self) -> bool {
44 self.messages.is_empty()
45 }
46
47 pub fn token_count(&self) -> usize {
48 self.messages
49 .iter()
50 .map(|message| message.content_text_len() / 4)
51 .sum()
52 }
53
54 pub(crate) fn drain_oldest(&mut self, count: usize) -> Vec<Message> {
55 let drain_count = count.min(self.messages.len());
56 self.messages.drain(0..drain_count).collect()
57 }
58}
59
#[cfg(test)]
mod tests {
    use super::SessionBuffer;
    use forja_core::types::{Message, Role};

    #[test]
    fn session_buffer_adds_and_returns_recent_messages() {
        let mut buffer = SessionBuffer::new();
        buffer.add(Message::text(Role::User, "first", None));
        buffer.add(Message::text(Role::Assistant, "second", None));
        buffer.add(Message::text(Role::User, "third", None));

        let recent = buffer.get_recent(2);

        assert_eq!(buffer.len(), 3);
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0].content_text_len(), "second".len());
        assert_eq!(recent[1].content_text_len(), "third".len());
    }

    #[test]
    fn get_recent_with_count_larger_than_buffer_returns_all_in_order() {
        let mut buffer = SessionBuffer::new();
        // Distinct lengths so the order of the returned messages is observable.
        buffer.add(Message::text(Role::User, "a", None));
        buffer.add(Message::text(Role::Assistant, "bb", None));

        let recent = buffer.get_recent(10);

        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0].content_text_len(), "a".len());
        assert_eq!(recent[1].content_text_len(), "bb".len());
        // Reading must not consume the buffer.
        assert_eq!(buffer.len(), 2);
    }

    #[test]
    fn get_recent_with_zero_count_or_empty_buffer_is_empty() {
        let mut buffer = SessionBuffer::new();
        assert!(buffer.get_recent(3).is_empty());

        buffer.add(Message::text(Role::User, "only", None));
        assert!(buffer.get_recent(0).is_empty());
    }

    #[test]
    fn drain_oldest_removes_from_front_and_clamps_count() {
        let mut buffer = SessionBuffer::new();
        buffer.add(Message::text(Role::User, "a", None));
        buffer.add(Message::text(Role::Assistant, "bb", None));
        buffer.add(Message::text(Role::User, "ccc", None));

        let drained = buffer.drain_oldest(1);
        assert_eq!(drained.len(), 1);
        assert_eq!(drained[0].content_text_len(), "a".len());
        assert_eq!(buffer.len(), 2);

        // Asking for more than is buffered drains everything that is left.
        let rest = buffer.drain_oldest(10);
        assert_eq!(rest.len(), 2);
        assert_eq!(rest[0].content_text_len(), "bb".len());
        assert_eq!(rest[1].content_text_len(), "ccc".len());
        assert!(buffer.is_empty());
        assert_eq!(buffer.token_count(), 0);
    }

    #[test]
    fn clear_empties_the_buffer() {
        let mut buffer = SessionBuffer::new();
        buffer.add(Message::text(Role::User, "hello", None));
        assert!(!buffer.is_empty());

        buffer.clear();

        assert!(buffer.is_empty());
        assert_eq!(buffer.len(), 0);
        assert!(buffer.get_all().is_empty());
    }

    #[test]
    fn session_buffer_token_count_increases_with_messages() {
        let mut buffer = SessionBuffer::new();
        buffer.add(Message::text(
            Role::User,
            "This is a moderately sized message for token counting.",
            None,
        ));
        buffer.add(Message::text(
            Role::Assistant,
            "Another message with enough text to count.",
            None,
        ));

        assert!(buffer.token_count() > 0);
    }
}