1use hehe_core::{Id, Message, Metadata, Timestamp};
2use serde::{Deserialize, Serialize};
3use std::sync::{Arc, RwLock};
4
5#[derive(Clone, Debug, Serialize, Deserialize)]
6pub struct SessionStats {
7 pub message_count: usize,
8 pub tool_call_count: usize,
9 pub iteration_count: usize,
10}
11
12impl Default for SessionStats {
13 fn default() -> Self {
14 Self {
15 message_count: 0,
16 tool_call_count: 0,
17 iteration_count: 0,
18 }
19 }
20}
21
/// Mutable state shared by all clones of a `Session`, kept behind its `RwLock`.
#[derive(Debug)]
struct SessionInner {
    /// Conversation history, oldest first: new messages are appended at the
    /// end, and truncation removes from the front.
    messages: Vec<Message>,
    /// Cumulative counters; these are not reduced when `messages` shrinks.
    stats: SessionStats,
}
27
/// A conversation session: identity fields fixed at construction, plus shared
/// mutable history and counters.
///
/// Cloning a `Session` does not copy the history. Every clone points at the
/// same `inner` state through an `Arc`, so a message added via one clone is
/// visible through all of them.
#[derive(Clone, Debug)]
pub struct Session {
    /// Identifier chosen at construction (`Session::new` / `Session::with_id`).
    id: Id,
    /// Captured with `Timestamp::now()` when the session is constructed.
    created_at: Timestamp,
    /// Session metadata; starts out as `Metadata::new()`.
    metadata: Metadata,
    /// Message history and counters, behind a reader/writer lock so that all
    /// clones can share them.
    inner: Arc<RwLock<SessionInner>>,
}
35
36impl Session {
37 pub fn new() -> Self {
38 Self {
39 id: Id::new(),
40 created_at: Timestamp::now(),
41 metadata: Metadata::new(),
42 inner: Arc::new(RwLock::new(SessionInner {
43 messages: Vec::new(),
44 stats: SessionStats::default(),
45 })),
46 }
47 }
48
49 pub fn with_id(id: Id) -> Self {
50 Self {
51 id,
52 created_at: Timestamp::now(),
53 metadata: Metadata::new(),
54 inner: Arc::new(RwLock::new(SessionInner {
55 messages: Vec::new(),
56 stats: SessionStats::default(),
57 })),
58 }
59 }
60
61 pub fn id(&self) -> &Id {
62 &self.id
63 }
64
65 pub fn created_at(&self) -> &Timestamp {
66 &self.created_at
67 }
68
69 pub fn metadata(&self) -> &Metadata {
70 &self.metadata
71 }
72
73 pub fn add_message(&self, message: Message) {
74 let mut inner = self.inner.write().unwrap();
75 inner.stats.message_count += 1;
76 inner.messages.push(message);
77 }
78
79 pub fn add_messages(&self, messages: impl IntoIterator<Item = Message>) {
80 let mut inner = self.inner.write().unwrap();
81 for message in messages {
82 inner.stats.message_count += 1;
83 inner.messages.push(message);
84 }
85 }
86
87 pub fn messages(&self) -> Vec<Message> {
88 self.inner.read().unwrap().messages.clone()
89 }
90
91 pub fn message_count(&self) -> usize {
92 self.inner.read().unwrap().messages.len()
93 }
94
95 pub fn last_messages(&self, n: usize) -> Vec<Message> {
96 let inner = self.inner.read().unwrap();
97 let len = inner.messages.len();
98 if n >= len {
99 inner.messages.clone()
100 } else {
101 inner.messages[len - n..].to_vec()
102 }
103 }
104
105 pub fn clear(&self) {
106 let mut inner = self.inner.write().unwrap();
107 inner.messages.clear();
108 }
109
110 pub fn stats(&self) -> SessionStats {
111 self.inner.read().unwrap().stats.clone()
112 }
113
114 pub fn increment_tool_calls(&self, count: usize) {
115 let mut inner = self.inner.write().unwrap();
116 inner.stats.tool_call_count += count;
117 }
118
119 pub fn increment_iterations(&self) {
120 let mut inner = self.inner.write().unwrap();
121 inner.stats.iteration_count += 1;
122 }
123
124 pub fn truncate_messages(&self, max_messages: usize) {
125 let mut inner = self.inner.write().unwrap();
126 if inner.messages.len() > max_messages {
127 let remove_count = inner.messages.len() - max_messages;
128 inner.messages.drain(0..remove_count);
129 }
130 }
131}
132
133impl Default for Session {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use hehe_core::Role;

    /// Builds a session with `count` messages whose roles alternate, starting
    /// with a user message at index 0 (even = user, odd = assistant). The role
    /// at a given position then reveals which original messages survived a
    /// windowing or truncation operation.
    fn alternating_session(count: usize) -> Session {
        let session = Session::new();
        for i in 0..count {
            let message = if i % 2 == 0 {
                Message::user(format!("Message {}", i))
            } else {
                Message::assistant("reply")
            };
            session.add_message(message);
        }
        session
    }

    #[test]
    fn test_session_new() {
        let session = Session::new();
        assert_eq!(session.message_count(), 0);

        let stats = session.stats();
        assert_eq!(stats.message_count, 0);
        assert_eq!(stats.tool_call_count, 0);
        assert_eq!(stats.iteration_count, 0);
    }

    #[test]
    fn test_session_add_message() {
        let session = Session::new();
        session.add_message(Message::user("Hello"));
        session.add_message(Message::assistant("Hi there!"));

        assert_eq!(session.message_count(), 2);

        let messages = session.messages();
        assert_eq!(messages[0].role, Role::User);
        assert_eq!(messages[1].role, Role::Assistant);
    }

    #[test]
    fn test_session_add_messages() {
        let session = Session::new();
        session.add_messages(vec![Message::user("Hello"), Message::assistant("Hi there!")]);

        assert_eq!(session.message_count(), 2);
        assert_eq!(session.stats().message_count, 2);

        let messages = session.messages();
        assert_eq!(messages[0].role, Role::User);
        assert_eq!(messages[1].role, Role::Assistant);
    }

    #[test]
    fn test_session_last_messages() {
        let session = alternating_session(10);

        // The tail must be indices 7, 8, 9, i.e. assistant, user, assistant.
        let last_3 = session.last_messages(3);
        assert_eq!(last_3.len(), 3);
        assert_eq!(last_3[0].role, Role::Assistant);
        assert_eq!(last_3[1].role, Role::User);
        assert_eq!(last_3[2].role, Role::Assistant);

        assert!(session.last_messages(0).is_empty());
        assert_eq!(session.last_messages(20).len(), 10);
    }

    #[test]
    fn test_session_truncate() {
        let session = alternating_session(10);

        session.truncate_messages(5);
        assert_eq!(session.message_count(), 5);

        // The newest five (indices 5..10) must survive: index 5 is an
        // assistant message, index 6 a user message.
        let messages = session.messages();
        assert_eq!(messages[0].role, Role::Assistant);
        assert_eq!(messages[1].role, Role::User);

        // Truncating to a limit above the current length is a no-op.
        session.truncate_messages(10);
        assert_eq!(session.message_count(), 5);
    }

    #[test]
    fn test_session_clear_keeps_stats() {
        let session = alternating_session(3);

        session.clear();
        assert_eq!(session.message_count(), 0);
        // Stats are cumulative and are not reset by clear().
        assert_eq!(session.stats().message_count, 3);
    }

    #[test]
    fn test_session_stats() {
        let session = Session::new();
        session.add_message(Message::user("Hello"));
        session.increment_tool_calls(2);
        session.increment_iterations();

        let stats = session.stats();
        assert_eq!(stats.message_count, 1);
        assert_eq!(stats.tool_call_count, 2);
        assert_eq!(stats.iteration_count, 1);
    }

    #[test]
    fn test_session_clone_shares_state() {
        let session1 = Session::new();
        let session2 = session1.clone();

        session1.add_message(Message::user("Hello"));

        assert_eq!(session2.message_count(), 1);
    }
}