1use std::fs::OpenOptions;
2use std::io::{BufRead, BufReader, Write};
3use std::path::{Path, PathBuf};
4
/// A conversation role that can be round-tripped through a string so it can
/// be written to, and read back from, a session file.
///
/// `Session` persists `as_str()` for each message and rebuilds roles with
/// `from_str()` when a session file is loaded.
pub trait MessageRole: Clone + PartialEq {
    /// Constructor for the system role (also used for trim markers).
    fn system() -> Self;
    /// Constructor for the user role.
    fn user() -> Self;
    /// Constructor for the assistant role.
    fn assistant() -> Self;
    /// Constructor for the tool role.
    fn tool() -> Self;
    /// String form of the role, as stored in session files.
    fn as_str(&self) -> &str;
    /// Inverse of [`MessageRole::as_str`]. Returns `None` for unrecognised
    /// strings; such lines are skipped when a session file is loaded.
    fn from_str(s: &str) -> Option<Self>;
    /// Whether this role is the system role.
    ///
    /// The default compares against [`MessageRole::system`] using the
    /// `PartialEq` bound instead of matching `as_str() == "system"`. This keeps
    /// it consistent with the role `Session::trim` uses for its marker message,
    /// even when an implementor serialises its system role under another name.
    fn is_system(&self) -> bool {
        *self == Self::system()
    }
}
17
18pub trait AgentMessage: Clone {
20 type Role: MessageRole;
21 fn new(role: Self::Role, content: String) -> Self;
22 fn role(&self) -> &Self::Role;
23 fn content(&self) -> &str;
24}
25
26#[derive(serde::Serialize, serde::Deserialize)]
27struct PersistedMessage {
28 role: String,
29 content: String,
30}
31
32pub struct Session<M: AgentMessage> {
34 messages: Vec<M>,
35 session_file: PathBuf,
36 max_history: usize,
37}
38
39impl<M: AgentMessage> Session<M> {
40 pub fn new(session_dir: &str, max_history: usize) -> Self {
42 let _ = std::fs::create_dir_all(session_dir);
43 let ts = std::time::SystemTime::now()
44 .duration_since(std::time::UNIX_EPOCH)
45 .unwrap_or_default()
46 .as_secs();
47 let session_file = PathBuf::from(format!("{}/session_{}.jsonl", session_dir, ts));
48 Self {
49 messages: Vec::new(),
50 session_file,
51 max_history,
52 }
53 }
54
55 pub fn resume(path: &Path, _session_dir: &str, max_history: usize) -> Self {
57 let messages = Self::load_file(path);
58 Self {
59 messages,
60 session_file: path.to_path_buf(),
61 max_history,
62 }
63 }
64
65 pub fn resume_last(session_dir: &str, max_history: usize) -> Option<Self> {
67 let last = Self::find_last_session(session_dir)?;
68 Some(Self::resume(&last, session_dir, max_history))
69 }
70
71 pub fn push(&mut self, role: <M as AgentMessage>::Role, content: String) -> &M {
73 let msg = M::new(role, content);
74 self.messages.push(msg);
75 self.persist_last();
76 self.messages.last().expect("just pushed")
77 }
78
79 pub fn push_msg(&mut self, msg: M) {
81 self.messages.push(msg);
82 self.persist_last();
83 }
84
85 pub fn messages(&self) -> &[M] {
87 &self.messages
88 }
89
90 pub fn messages_mut(&mut self) -> &mut Vec<M> {
92 &mut self.messages
93 }
94
95 pub fn is_empty(&self) -> bool {
96 self.messages.is_empty()
97 }
98
99 pub fn len(&self) -> usize {
100 self.messages.len()
101 }
102
103 pub fn session_file(&self) -> &Path {
104 &self.session_file
105 }
106
107 pub fn trim(&mut self) -> usize {
113 if self.messages.len() <= self.max_history {
114 return 0;
115 }
116
117 let system_msgs: Vec<M> = self.messages
118 .iter()
119 .filter(|m| m.role().is_system())
120 .cloned()
121 .collect();
122
123 let non_system: Vec<M> = self.messages
124 .iter()
125 .filter(|m| !m.role().is_system())
126 .cloned()
127 .collect();
128
129 let keep = self.max_history.saturating_sub(system_msgs.len());
130 let skip = non_system.len().saturating_sub(keep);
131
132 if skip == 0 {
133 return 0;
134 }
135
136 let mut trimmed = system_msgs;
137 trimmed.push(M::new(
138 <M as AgentMessage>::Role::system(),
139 format!("[{} earlier messages trimmed]", skip),
140 ));
141 trimmed.extend(non_system.into_iter().skip(skip));
142 self.messages = trimmed;
143 skip
144 }
145
146 fn persist_last(&self) {
149 let Some(msg) = self.messages.last() else { return };
150 let persisted = PersistedMessage {
151 role: msg.role().as_str().into(),
152 content: msg.content().into(),
153 };
154 let Ok(json) = serde_json::to_string(&persisted) else { return };
155 let Ok(mut f) = OpenOptions::new().create(true).append(true).open(&self.session_file) else { return };
156 let _ = writeln!(f, "{}", json);
157 }
158
159 fn load_file(path: &Path) -> Vec<M> {
160 let Ok(file) = std::fs::File::open(path) else { return vec![] };
161 BufReader::new(file)
162 .lines()
163 .map_while(Result::ok)
164 .filter_map(|line| serde_json::from_str::<PersistedMessage>(&line).ok())
165 .filter_map(|p| {
166 let role = <M as AgentMessage>::Role::from_str(&p.role)?;
167 Some(M::new(role, p.content))
168 })
169 .collect()
170 }
171
172 fn find_last_session(dir: &str) -> Option<PathBuf> {
173 let mut entries: Vec<_> = std::fs::read_dir(dir)
174 .ok()?
175 .filter_map(|e| e.ok())
176 .filter(|e| e.path().extension().is_some_and(|ext| ext == "jsonl"))
177 .collect();
178 entries.sort_by_key(|e| e.file_name());
179 entries.last().map(|e| e.path())
180 }
181}
182
#[cfg(test)]
pub(crate) mod tests {
    use super::*;

    /// Role type for unit tests; string forms match the names in `as_str`.
    #[derive(Clone, Debug, PartialEq)]
    pub(crate) enum TestRole {
        System,
        User,
        Assistant,
        Tool,
    }

    impl MessageRole for TestRole {
        fn system() -> Self {
            TestRole::System
        }
        fn user() -> Self {
            TestRole::User
        }
        fn assistant() -> Self {
            TestRole::Assistant
        }
        fn tool() -> Self {
            TestRole::Tool
        }
        fn as_str(&self) -> &str {
            match self {
                TestRole::System => "system",
                TestRole::User => "user",
                TestRole::Assistant => "assistant",
                TestRole::Tool => "tool",
            }
        }
        fn from_str(s: &str) -> Option<Self> {
            // Inverse of `as_str`: pick the variant whose string form matches.
            [TestRole::System, TestRole::User, TestRole::Assistant, TestRole::Tool]
                .into_iter()
                .find(|role| role.as_str() == s)
        }
    }

    /// Minimal message type for unit tests.
    #[derive(Clone)]
    pub(crate) struct TestMsg {
        pub role: TestRole,
        pub content: String,
    }

    impl AgentMessage for TestMsg {
        type Role = TestRole;
        fn new(role: TestRole, content: String) -> Self {
            TestMsg { role, content }
        }
        fn role(&self) -> &TestRole {
            &self.role
        }
        fn content(&self) -> &str {
            &self.content
        }
    }

    /// Path under the system temp dir, with any leftovers from a previous run removed.
    fn fresh_dir(name: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(name);
        let _ = std::fs::remove_dir_all(&dir);
        dir
    }

    #[test]
    fn trim_preserves_system_and_recent() {
        let dir = fresh_dir("baml_rt_test_trim");
        let mut session = Session::<TestMsg>::new(dir.to_str().unwrap(), 10);

        session.push(TestRole::System, "sys prompt".into());
        for i in 0..20 {
            let role = match i % 2 {
                0 => TestRole::User,
                _ => TestRole::Assistant,
            };
            session.push(role, format!("msg {}", i));
        }
        assert_eq!(session.len(), 21);

        assert!(session.trim() > 0);
        let msgs = session.messages();
        assert!(msgs.len() <= 12);
        assert_eq!(msgs[0].role(), &TestRole::System);
        assert_eq!(msgs[0].content(), "sys prompt");
        assert!(msgs[1].content().contains("trimmed"));
        assert_eq!(msgs.last().unwrap().content(), "msg 19");

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn trim_noop_small_history() {
        let dir = fresh_dir("baml_rt_test_noop");
        let mut session = Session::<TestMsg>::new(dir.to_str().unwrap(), 60);
        session.push(TestRole::User, "hello".into());
        assert_eq!(session.trim(), 0);
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn persist_and_reload() {
        let dir = fresh_dir("baml_rt_test_persist");
        let dir_str = dir.to_str().unwrap();
        let mut session = Session::<TestMsg>::new(dir_str, 60);
        session.push(TestRole::User, "hello world".into());
        session.push(TestRole::Assistant, "hi there".into());

        let path = session.session_file().to_path_buf();
        let loaded = Session::<TestMsg>::resume(&path, dir_str, 60);
        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded.messages()[0].content(), "hello world");
        assert_eq!(loaded.messages()[1].role(), &TestRole::Assistant);

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn resume_last_finds_latest() {
        let dir = fresh_dir("baml_rt_test_resume");
        let dir_str = dir.to_str().unwrap();

        let mut first = Session::<TestMsg>::new(dir_str, 60);
        first.push(TestRole::User, "first".into());

        // File names carry a seconds timestamp; wait so the second session
        // gets a strictly later name.
        std::thread::sleep(std::time::Duration::from_millis(1100));

        let mut second = Session::<TestMsg>::new(dir_str, 60);
        second.push(TestRole::User, "second".into());

        let resumed = Session::<TestMsg>::resume_last(dir_str, 60).unwrap();
        assert_eq!(resumed.messages()[0].content(), "second");

        let _ = std::fs::remove_dir_all(&dir);
    }
}
295}