Skip to main content

baml_agent/
session.rs

1use std::fs::OpenOptions;
2use std::io::{BufRead, BufReader, Write};
3use std::path::{Path, PathBuf};
4
/// A conversation role that an [`AgentMessage`] carries.
///
/// Implementors provide one constructor per role plus a string form. The session
/// layer writes roles to disk with [`as_str`](Self::as_str) and reads them back
/// with [`from_str`](Self::from_str), so for a resumed session to reproduce its
/// roles, `from_str(r.as_str())` should give back `r`.
pub trait MessageRole: Clone + PartialEq {
    /// Constructs the `system` role.
    fn system() -> Self;

    /// Constructs the `user` role.
    fn user() -> Self;

    /// Constructs the `assistant` role.
    fn assistant() -> Self;

    /// Constructs the `tool` role.
    fn tool() -> Self;

    /// Stable string name of this role, as persisted in session files.
    fn as_str(&self) -> &str;

    /// Parses a role from its string name; `None` if the name is not recognised.
    fn from_str(s: &str) -> Option<Self>;

    /// Returns `true` for the system role.
    ///
    /// The default compares [`as_str`](Self::as_str) against the exact,
    /// case-sensitive name `"system"`.
    fn is_system(&self) -> bool {
        matches!(self.as_str(), "system")
    }
}
17
18/// A message in the agent conversation.
19pub trait AgentMessage: Clone {
20    type Role: MessageRole;
21    fn new(role: Self::Role, content: String) -> Self;
22    fn role(&self) -> &Self::Role;
23    fn content(&self) -> &str;
24}
25
26#[derive(serde::Serialize, serde::Deserialize)]
27struct PersistedMessage {
28    role: String,
29    content: String,
30}
31
32/// Session manager: JSONL persistence, history access, context trimming.
33pub struct Session<M: AgentMessage> {
34    messages: Vec<M>,
35    session_file: PathBuf,
36    max_history: usize,
37}
38
39impl<M: AgentMessage> Session<M> {
40    /// Create a new session with a fresh JSONL file.
41    pub fn new(session_dir: &str, max_history: usize) -> Self {
42        let _ = std::fs::create_dir_all(session_dir);
43        let ts = std::time::SystemTime::now()
44            .duration_since(std::time::UNIX_EPOCH)
45            .unwrap_or_default()
46            .as_secs();
47        let session_file = PathBuf::from(format!("{}/session_{}.jsonl", session_dir, ts));
48        Self {
49            messages: Vec::new(),
50            session_file,
51            max_history,
52        }
53    }
54
55    /// Resume from a specific session file.
56    pub fn resume(path: &Path, _session_dir: &str, max_history: usize) -> Self {
57        let messages = Self::load_file(path);
58        Self {
59            messages,
60            session_file: path.to_path_buf(),
61            max_history,
62        }
63    }
64
65    /// Resume the most recent session in the session directory.
66    pub fn resume_last(session_dir: &str, max_history: usize) -> Option<Self> {
67        let last = Self::find_last_session(session_dir)?;
68        Some(Self::resume(&last, session_dir, max_history))
69    }
70
71    /// Push a message, persist to JSONL, return ref.
72    pub fn push(&mut self, role: <M as AgentMessage>::Role, content: String) -> &M {
73        let msg = M::new(role, content);
74        self.messages.push(msg);
75        self.persist_last();
76        self.messages.last().expect("just pushed")
77    }
78
79    /// Push a pre-built message.
80    pub fn push_msg(&mut self, msg: M) {
81        self.messages.push(msg);
82        self.persist_last();
83    }
84
85    /// Access messages.
86    pub fn messages(&self) -> &[M] {
87        &self.messages
88    }
89
90    /// Mutable access to messages (for external trimming).
91    pub fn messages_mut(&mut self) -> &mut Vec<M> {
92        &mut self.messages
93    }
94
95    pub fn is_empty(&self) -> bool {
96        self.messages.is_empty()
97    }
98
99    pub fn len(&self) -> usize {
100        self.messages.len()
101    }
102
103    pub fn session_file(&self) -> &Path {
104        &self.session_file
105    }
106
107    /// Trim history to fit context window.
108    ///
109    /// Preserves system messages and the most recent non-system messages.
110    /// Inserts a "[N earlier messages trimmed]" system notice.
111    /// Returns the number of trimmed messages (0 if no trimming needed).
112    pub fn trim(&mut self) -> usize {
113        if self.messages.len() <= self.max_history {
114            return 0;
115        }
116
117        let system_msgs: Vec<M> = self.messages
118            .iter()
119            .filter(|m| m.role().is_system())
120            .cloned()
121            .collect();
122
123        let non_system: Vec<M> = self.messages
124            .iter()
125            .filter(|m| !m.role().is_system())
126            .cloned()
127            .collect();
128
129        let keep = self.max_history.saturating_sub(system_msgs.len());
130        let skip = non_system.len().saturating_sub(keep);
131
132        if skip == 0 {
133            return 0;
134        }
135
136        let mut trimmed = system_msgs;
137        trimmed.push(M::new(
138            <M as AgentMessage>::Role::system(),
139            format!("[{} earlier messages trimmed]", skip),
140        ));
141        trimmed.extend(non_system.into_iter().skip(skip));
142        self.messages = trimmed;
143        skip
144    }
145
146    // --- Private ---
147
148    fn persist_last(&self) {
149        let Some(msg) = self.messages.last() else { return };
150        let persisted = PersistedMessage {
151            role: msg.role().as_str().into(),
152            content: msg.content().into(),
153        };
154        let Ok(json) = serde_json::to_string(&persisted) else { return };
155        let Ok(mut f) = OpenOptions::new().create(true).append(true).open(&self.session_file) else { return };
156        let _ = writeln!(f, "{}", json);
157    }
158
159    fn load_file(path: &Path) -> Vec<M> {
160        let Ok(file) = std::fs::File::open(path) else { return vec![] };
161        BufReader::new(file)
162            .lines()
163            .map_while(Result::ok)
164            .filter_map(|line| serde_json::from_str::<PersistedMessage>(&line).ok())
165            .filter_map(|p| {
166                let role = <M as AgentMessage>::Role::from_str(&p.role)?;
167                Some(M::new(role, p.content))
168            })
169            .collect()
170    }
171
172    fn find_last_session(dir: &str) -> Option<PathBuf> {
173        let mut entries: Vec<_> = std::fs::read_dir(dir)
174            .ok()?
175            .filter_map(|e| e.ok())
176            .filter(|e| e.path().extension().is_some_and(|ext| ext == "jsonl"))
177            .collect();
178        entries.sort_by_key(|e| e.file_name());
179        entries.last().map(|e| e.path())
180    }
181}
182
#[cfg(test)]
pub(crate) mod tests {
    use super::*;

    /// Concrete role enum used to drive `Session` in tests.
    #[derive(Clone, Debug, PartialEq)]
    pub(crate) enum TestRole {
        System,
        User,
        Assistant,
        Tool,
    }

    impl MessageRole for TestRole {
        fn system() -> Self {
            TestRole::System
        }

        fn user() -> Self {
            TestRole::User
        }

        fn assistant() -> Self {
            TestRole::Assistant
        }

        fn tool() -> Self {
            TestRole::Tool
        }

        fn as_str(&self) -> &str {
            match self {
                TestRole::System => "system",
                TestRole::User => "user",
                TestRole::Assistant => "assistant",
                TestRole::Tool => "tool",
            }
        }

        fn from_str(s: &str) -> Option<Self> {
            let role = match s {
                "system" => TestRole::System,
                "user" => TestRole::User,
                "assistant" => TestRole::Assistant,
                "tool" => TestRole::Tool,
                _ => return None,
            };
            Some(role)
        }
    }

    /// Concrete message type: a [`TestRole`] plus its text.
    #[derive(Clone)]
    pub(crate) struct TestMsg {
        pub role: TestRole,
        pub content: String,
    }

    impl AgentMessage for TestMsg {
        type Role = TestRole;

        fn new(role: TestRole, content: String) -> Self {
            TestMsg { role, content }
        }

        fn role(&self) -> &TestRole {
            &self.role
        }

        fn content(&self) -> &str {
            &self.content
        }
    }

    /// Per-test scratch directory under the system temp dir, with leftovers
    /// from any previous run removed.
    fn scratch_dir(name: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(name);
        let _ = std::fs::remove_dir_all(&dir);
        dir
    }

    #[test]
    fn trim_preserves_system_and_recent() {
        let dir = scratch_dir("baml_rt_test_trim");
        let mut session = Session::<TestMsg>::new(dir.to_str().unwrap(), 10);

        // One system prompt, then 20 alternating user/assistant turns.
        session.push(TestRole::System, "sys prompt".into());
        for i in 0..20 {
            let role = match i % 2 {
                0 => TestRole::User,
                _ => TestRole::Assistant,
            };
            session.push(role, format!("msg {}", i));
        }
        assert_eq!(session.len(), 21);

        let dropped = session.trim();
        assert!(dropped > 0);
        assert!(session.len() <= 12); // 10 max + system + trim notice

        let msgs = session.messages();
        assert_eq!(msgs[0].role(), &TestRole::System);
        assert_eq!(msgs[0].content(), "sys prompt");
        assert!(msgs[1].content().contains("trimmed"));
        assert_eq!(msgs.last().unwrap().content(), "msg 19");

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn trim_noop_small_history() {
        let dir = scratch_dir("baml_rt_test_noop");
        let mut session = Session::<TestMsg>::new(dir.to_str().unwrap(), 60);
        session.push(TestRole::User, "hello".into());
        assert_eq!(session.trim(), 0);
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn persist_and_reload() {
        let dir = scratch_dir("baml_rt_test_persist");
        let dir_str = dir.to_str().unwrap();

        let mut session = Session::<TestMsg>::new(dir_str, 60);
        session.push(TestRole::User, "hello world".into());
        session.push(TestRole::Assistant, "hi there".into());

        let reloaded = Session::<TestMsg>::resume(session.session_file(), dir_str, 60);
        assert_eq!(reloaded.len(), 2);
        assert_eq!(reloaded.messages()[0].content(), "hello world");
        assert_eq!(reloaded.messages()[1].role(), &TestRole::Assistant);

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn resume_last_finds_latest() {
        let dir = scratch_dir("baml_rt_test_resume");
        let dir_str = dir.to_str().unwrap();

        let mut older = Session::<TestMsg>::new(dir_str, 60);
        older.push(TestRole::User, "first".into());

        // Session files are named by whole seconds; wait so the next name differs.
        std::thread::sleep(std::time::Duration::from_millis(1100));

        let mut newer = Session::<TestMsg>::new(dir_str, 60);
        newer.push(TestRole::User, "second".into());

        let resumed = Session::<TestMsg>::resume_last(dir_str, 60).unwrap();
        assert_eq!(resumed.messages()[0].content(), "second");

        let _ = std::fs::remove_dir_all(&dir);
    }
}