hehe_agent/
session.rs

1use hehe_core::{Id, Message, Metadata, Timestamp};
2use serde::{Deserialize, Serialize};
3use std::sync::{Arc, RwLock};
4
5#[derive(Clone, Debug, Serialize, Deserialize)]
6pub struct SessionStats {
7    pub message_count: usize,
8    pub tool_call_count: usize,
9    pub iteration_count: usize,
10}
11
12impl Default for SessionStats {
13    fn default() -> Self {
14        Self {
15            message_count: 0,
16            tool_call_count: 0,
17            iteration_count: 0,
18        }
19    }
20}
21
22#[derive(Debug)]
23struct SessionInner {
24    messages: Vec<Message>,
25    stats: SessionStats,
26}
27
28#[derive(Clone, Debug)]
29pub struct Session {
30    id: Id,
31    created_at: Timestamp,
32    metadata: Metadata,
33    inner: Arc<RwLock<SessionInner>>,
34}
35
36impl Session {
37    pub fn new() -> Self {
38        Self {
39            id: Id::new(),
40            created_at: Timestamp::now(),
41            metadata: Metadata::new(),
42            inner: Arc::new(RwLock::new(SessionInner {
43                messages: Vec::new(),
44                stats: SessionStats::default(),
45            })),
46        }
47    }
48
49    pub fn with_id(id: Id) -> Self {
50        Self {
51            id,
52            created_at: Timestamp::now(),
53            metadata: Metadata::new(),
54            inner: Arc::new(RwLock::new(SessionInner {
55                messages: Vec::new(),
56                stats: SessionStats::default(),
57            })),
58        }
59    }
60
61    pub fn id(&self) -> &Id {
62        &self.id
63    }
64
65    pub fn created_at(&self) -> &Timestamp {
66        &self.created_at
67    }
68
69    pub fn metadata(&self) -> &Metadata {
70        &self.metadata
71    }
72
73    pub fn add_message(&self, message: Message) {
74        let mut inner = self.inner.write().unwrap();
75        inner.stats.message_count += 1;
76        inner.messages.push(message);
77    }
78
79    pub fn add_messages(&self, messages: impl IntoIterator<Item = Message>) {
80        let mut inner = self.inner.write().unwrap();
81        for message in messages {
82            inner.stats.message_count += 1;
83            inner.messages.push(message);
84        }
85    }
86
87    pub fn messages(&self) -> Vec<Message> {
88        self.inner.read().unwrap().messages.clone()
89    }
90
91    pub fn message_count(&self) -> usize {
92        self.inner.read().unwrap().messages.len()
93    }
94
95    pub fn last_messages(&self, n: usize) -> Vec<Message> {
96        let inner = self.inner.read().unwrap();
97        let len = inner.messages.len();
98        if n >= len {
99            inner.messages.clone()
100        } else {
101            inner.messages[len - n..].to_vec()
102        }
103    }
104
105    pub fn clear(&self) {
106        let mut inner = self.inner.write().unwrap();
107        inner.messages.clear();
108    }
109
110    pub fn stats(&self) -> SessionStats {
111        self.inner.read().unwrap().stats.clone()
112    }
113
114    pub fn increment_tool_calls(&self, count: usize) {
115        let mut inner = self.inner.write().unwrap();
116        inner.stats.tool_call_count += count;
117    }
118
119    pub fn increment_iterations(&self) {
120        let mut inner = self.inner.write().unwrap();
121        inner.stats.iteration_count += 1;
122    }
123
124    pub fn truncate_messages(&self, max_messages: usize) {
125        let mut inner = self.inner.write().unwrap();
126        if inner.messages.len() > max_messages {
127            let remove_count = inner.messages.len() - max_messages;
128            inner.messages.drain(0..remove_count);
129        }
130    }
131}
132
133impl Default for Session {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use hehe_core::Role;

    /// Builds a session holding `count` user messages ("Message 0", "Message 1", ...).
    fn session_with_user_messages(count: usize) -> Session {
        let session = Session::new();
        for i in 0..count {
            session.add_message(Message::user(format!("Message {}", i)));
        }
        session
    }

    #[test]
    fn test_session_new() {
        let session = Session::new();
        assert_eq!(session.message_count(), 0);
        assert!(session.messages().is_empty());
    }

    #[test]
    fn test_session_add_message() {
        let session = Session::new();
        session.add_message(Message::user("Hello"));
        session.add_message(Message::assistant("Hi there!"));

        assert_eq!(session.message_count(), 2);

        let messages = session.messages();
        assert_eq!(messages[0].role, Role::User);
        assert_eq!(messages[1].role, Role::Assistant);
    }

    #[test]
    fn test_session_add_messages() {
        let session = Session::new();
        session.add_messages(vec![Message::user("Hello"), Message::assistant("Hi there!")]);

        assert_eq!(session.message_count(), 2);
        assert_eq!(session.stats().message_count, 2);

        let messages = session.messages();
        assert_eq!(messages[0].role, Role::User);
        assert_eq!(messages[1].role, Role::Assistant);
    }

    #[test]
    fn test_session_last_messages() {
        let session = session_with_user_messages(10);

        assert_eq!(session.last_messages(3).len(), 3);
        // Asking for more than exist returns everything; zero returns nothing.
        assert_eq!(session.last_messages(100).len(), 10);
        assert!(session.last_messages(0).is_empty());
    }

    #[test]
    fn test_session_last_messages_keeps_tail() {
        let session = Session::new();
        session.add_message(Message::user("Hello"));
        session.add_message(Message::assistant("Hi there!"));

        let last = session.last_messages(1);
        assert_eq!(last.len(), 1);
        assert_eq!(last[0].role, Role::Assistant);
    }

    #[test]
    fn test_session_truncate() {
        let session = session_with_user_messages(10);

        session.truncate_messages(5);
        assert_eq!(session.message_count(), 5);

        // A limit above the current length leaves the history unchanged.
        session.truncate_messages(50);
        assert_eq!(session.message_count(), 5);
    }

    #[test]
    fn test_session_truncate_drops_oldest() {
        let session = Session::new();
        session.add_message(Message::user("Hello"));
        session.add_message(Message::assistant("Hi there!"));

        session.truncate_messages(1);
        let messages = session.messages();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, Role::Assistant);
    }

    #[test]
    fn test_session_clear_keeps_cumulative_stats() {
        let session = session_with_user_messages(3);

        session.clear();
        assert_eq!(session.message_count(), 0);
        assert!(session.messages().is_empty());
        // `stats().message_count` counts every message ever added.
        assert_eq!(session.stats().message_count, 3);
    }

    #[test]
    fn test_session_stats() {
        let session = Session::new();
        session.add_message(Message::user("Hello"));
        session.increment_tool_calls(2);
        session.increment_iterations();

        let stats = session.stats();
        assert_eq!(stats.message_count, 1);
        assert_eq!(stats.tool_call_count, 2);
        assert_eq!(stats.iteration_count, 1);
    }

    #[test]
    fn test_session_clone_shares_state() {
        let session1 = Session::new();
        let session2 = session1.clone();

        session1.add_message(Message::user("Hello"));

        assert_eq!(session2.message_count(), 1);
    }
}
207}