Skip to main content

forja_memory/
session.rs

1use forja_core::types::Message;
2
3#[derive(Debug, Clone, Default)]
4pub struct SessionBuffer {
5    messages: Vec<Message>,
6}
7
8impl SessionBuffer {
9    pub fn new() -> Self {
10        Self {
11            messages: Vec::new(),
12        }
13    }
14
15    pub fn add(&mut self, message: Message) {
16        self.messages.push(message);
17    }
18
19    pub fn get_recent(&self, count: usize) -> Vec<Message> {
20        self.messages
21            .iter()
22            .rev()
23            .take(count)
24            .cloned()
25            .collect::<Vec<_>>()
26            .into_iter()
27            .rev()
28            .collect()
29    }
30
31    pub fn get_all(&self) -> Vec<Message> {
32        self.messages.clone()
33    }
34
35    pub fn clear(&mut self) {
36        self.messages.clear();
37    }
38
39    pub fn len(&self) -> usize {
40        self.messages.len()
41    }
42
43    pub fn is_empty(&self) -> bool {
44        self.messages.is_empty()
45    }
46
47    pub fn token_count(&self) -> usize {
48        self.messages
49            .iter()
50            .map(|message| message.content_text_len() / 4)
51            .sum()
52    }
53
54    pub(crate) fn drain_oldest(&mut self, count: usize) -> Vec<Message> {
55        let drain_count = count.min(self.messages.len());
56        self.messages.drain(0..drain_count).collect()
57    }
58}
59
#[cfg(test)]
mod tests {
    use super::SessionBuffer;
    use forja_core::types::{Message, Role};

    /// Builds a buffer from `texts`, alternating user and assistant roles
    /// (starting with the user).
    fn buffer_with(texts: &[&'static str]) -> SessionBuffer {
        let mut buffer = SessionBuffer::new();
        for (index, &text) in texts.iter().enumerate() {
            let role = if index % 2 == 0 {
                Role::User
            } else {
                Role::Assistant
            };
            buffer.add(Message::text(role, text, None));
        }
        buffer
    }

    #[test]
    fn session_buffer_adds_and_returns_recent_messages() {
        let buffer = buffer_with(&["first", "second", "third"]);

        let recent = buffer.get_recent(2);

        assert_eq!(buffer.len(), 3);
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0].content_text_len(), "second".len());
        assert_eq!(recent[1].content_text_len(), "third".len());
    }

    #[test]
    fn get_recent_with_zero_count_is_empty() {
        let buffer = buffer_with(&["a", "bb"]);

        assert!(buffer.get_recent(0).is_empty());
    }

    #[test]
    fn get_recent_beyond_len_returns_everything_in_order() {
        let buffer = buffer_with(&["a", "bb", "ccc"]);

        let lengths: Vec<usize> = buffer
            .get_recent(10)
            .iter()
            .map(|message| message.content_text_len())
            .collect();

        assert_eq!(lengths, vec![1, 2, 3]);
    }

    #[test]
    fn session_buffer_token_count_increases_with_messages() {
        let mut buffer = SessionBuffer::new();
        buffer.add(Message::text(
            Role::User,
            "This is a moderately sized message for token counting.",
            None,
        ));
        let after_one = buffer.token_count();

        buffer.add(Message::text(
            Role::Assistant,
            "Another message with enough text to count.",
            None,
        ));
        let after_two = buffer.token_count();

        assert!(after_one > 0);
        assert!(after_two > after_one);
    }

    #[test]
    fn clear_empties_the_buffer() {
        let mut buffer = buffer_with(&["a", "bb"]);

        buffer.clear();

        assert!(buffer.is_empty());
        assert_eq!(buffer.len(), 0);
        assert_eq!(buffer.token_count(), 0);
        assert!(buffer.get_all().is_empty());
    }

    #[test]
    fn drain_oldest_removes_from_the_front_and_clamps() {
        let mut buffer = buffer_with(&["a", "bb", "ccc"]);

        let drained = buffer.drain_oldest(2);
        assert_eq!(drained.len(), 2);
        assert_eq!(drained[0].content_text_len(), 1);
        assert_eq!(drained[1].content_text_len(), 2);
        assert_eq!(buffer.len(), 1);
        assert_eq!(buffer.get_all()[0].content_text_len(), 3);

        // Asking for more than remains drains everything without panicking.
        let rest = buffer.drain_oldest(10);
        assert_eq!(rest.len(), 1);
        assert!(buffer.is_empty());
    }
}