harn_vm/
agent_sessions.rs1use std::cell::{Cell, RefCell};
21use std::collections::{BTreeMap, HashMap};
22use std::rc::Rc;
23use std::time::Instant;
24
25use crate::value::VmValue;
26
/// Session cap that each thread's store starts with until `set_session_cap`
/// overrides it.
pub const DEFAULT_SESSION_CAP: usize = 128;
30
31pub struct SessionState {
32 pub id: String,
33 pub transcript: VmValue,
34 pub subscribers: Vec<VmValue>,
35 pub created_at: Instant,
36 pub last_accessed: Instant,
37}
38
39impl SessionState {
40 fn new(id: String) -> Self {
41 let now = Instant::now();
42 let transcript = empty_transcript(&id);
43 Self {
44 id,
45 transcript,
46 subscribers: Vec::new(),
47 created_at: now,
48 last_accessed: now,
49 }
50 }
51}
52
53thread_local! {
54 static SESSIONS: RefCell<HashMap<String, SessionState>> = RefCell::new(HashMap::new());
55 static SESSION_CAP: Cell<usize> = const { Cell::new(DEFAULT_SESSION_CAP) };
56}
57
58pub fn set_session_cap(cap: usize) {
61 SESSION_CAP.with(|c| c.set(cap.max(1)));
62}
63
64pub fn session_cap() -> usize {
65 SESSION_CAP.with(|c| c.get())
66}
67
68pub fn reset_session_store() {
70 SESSIONS.with(|s| s.borrow_mut().clear());
71}
72
73pub fn exists(id: &str) -> bool {
74 SESSIONS.with(|s| s.borrow().contains_key(id))
75}
76
77pub fn length(id: &str) -> Option<usize> {
78 SESSIONS.with(|s| {
79 s.borrow().get(id).map(|state| {
80 state
81 .transcript
82 .as_dict()
83 .and_then(|d| d.get("messages"))
84 .and_then(|v| match v {
85 VmValue::List(list) => Some(list.len()),
86 _ => None,
87 })
88 .unwrap_or(0)
89 })
90 })
91}
92
93pub fn snapshot(id: &str) -> Option<VmValue> {
94 SESSIONS.with(|s| s.borrow().get(id).map(|state| state.transcript.clone()))
95}
96
97pub fn open_or_create(id: Option<String>) -> String {
99 let resolved = id.unwrap_or_else(|| uuid::Uuid::now_v7().to_string());
100 SESSIONS.with(|s| {
101 let mut map = s.borrow_mut();
102 if let Some(state) = map.get_mut(&resolved) {
103 state.last_accessed = Instant::now();
104 return;
105 }
106 let cap = SESSION_CAP.with(|c| c.get());
107 if map.len() >= cap {
108 if let Some(victim) = map
109 .iter()
110 .min_by_key(|(_, state)| state.last_accessed)
111 .map(|(id, _)| id.clone())
112 {
113 map.remove(&victim);
114 }
115 }
116 map.insert(resolved.clone(), SessionState::new(resolved.clone()));
117 });
118 resolved
119}
120
121pub fn close(id: &str) {
122 SESSIONS.with(|s| {
123 s.borrow_mut().remove(id);
124 });
125}
126
127pub fn reset_transcript(id: &str) -> bool {
128 SESSIONS.with(|s| {
129 let mut map = s.borrow_mut();
130 let Some(state) = map.get_mut(id) else {
131 return false;
132 };
133 state.transcript = empty_transcript(id);
134 state.last_accessed = Instant::now();
135 true
136 })
137}
138
139pub fn fork(src_id: &str, dst_id: Option<String>) -> Option<String> {
146 let (src_transcript, dst) = SESSIONS.with(|s| {
147 let mut map = s.borrow_mut();
148 let src = map.get_mut(src_id)?;
149 src.last_accessed = Instant::now();
150 let dst = dst_id.unwrap_or_else(|| uuid::Uuid::now_v7().to_string());
151 let forked_transcript = clone_transcript_with_id(&src.transcript, &dst);
152 Some((forked_transcript, dst))
153 })?;
154 open_or_create(Some(dst.clone()));
156 SESSIONS.with(|s| {
157 if let Some(state) = s.borrow_mut().get_mut(&dst) {
158 state.transcript = src_transcript;
159 state.last_accessed = Instant::now();
160 }
161 });
162 if exists(&dst) {
166 Some(dst)
167 } else {
168 None
169 }
170}
171
172pub fn trim(id: &str, keep_last: usize) -> Option<usize> {
175 SESSIONS.with(|s| {
176 let mut map = s.borrow_mut();
177 let state = map.get_mut(id)?;
178 let dict = state.transcript.as_dict()?.clone();
179 let messages: Vec<VmValue> = match dict.get("messages") {
180 Some(VmValue::List(list)) => list.iter().cloned().collect(),
181 _ => Vec::new(),
182 };
183 let start = messages.len().saturating_sub(keep_last);
184 let retained: Vec<VmValue> = messages.into_iter().skip(start).collect();
185 let kept = retained.len();
186 let mut next = dict;
187 next.insert(
188 "events".to_string(),
189 VmValue::List(Rc::new(
190 crate::llm::helpers::transcript_events_from_messages(&retained),
191 )),
192 );
193 next.insert("messages".to_string(), VmValue::List(Rc::new(retained)));
194 state.transcript = VmValue::Dict(Rc::new(next));
195 state.last_accessed = Instant::now();
196 Some(kept)
197 })
198}
199
200pub fn inject_message(id: &str, message: VmValue) -> Result<(), String> {
203 let Some(msg_dict) = message.as_dict().cloned() else {
204 return Err("agent_session_inject: message must be a dict".into());
205 };
206 let role_ok = matches!(msg_dict.get("role"), Some(VmValue::String(_)));
207 if !role_ok {
208 return Err(
209 "agent_session_inject: message must have a string `role` (user|assistant|tool_result|system)"
210 .into(),
211 );
212 }
213 SESSIONS.with(|s| {
214 let mut map = s.borrow_mut();
215 let Some(state) = map.get_mut(id) else {
216 return Err(format!("agent_session_inject: unknown session id '{id}'"));
217 };
218 let dict = state
219 .transcript
220 .as_dict()
221 .cloned()
222 .unwrap_or_else(BTreeMap::new);
223 let mut messages: Vec<VmValue> = match dict.get("messages") {
224 Some(VmValue::List(list)) => list.iter().cloned().collect(),
225 _ => Vec::new(),
226 };
227 messages.push(VmValue::Dict(Rc::new(msg_dict)));
228 let mut next = dict;
229 next.insert(
230 "events".to_string(),
231 VmValue::List(Rc::new(
232 crate::llm::helpers::transcript_events_from_messages(&messages),
233 )),
234 );
235 next.insert("messages".to_string(), VmValue::List(Rc::new(messages)));
236 state.transcript = VmValue::Dict(Rc::new(next));
237 state.last_accessed = Instant::now();
238 Ok(())
239 })
240}
241
242pub fn messages_json(id: &str) -> Vec<serde_json::Value> {
246 SESSIONS.with(|s| {
247 let map = s.borrow();
248 let Some(state) = map.get(id) else {
249 return Vec::new();
250 };
251 let Some(dict) = state.transcript.as_dict() else {
252 return Vec::new();
253 };
254 match dict.get("messages") {
255 Some(VmValue::List(list)) => list
256 .iter()
257 .map(crate::llm::helpers::vm_value_to_json)
258 .collect(),
259 _ => Vec::new(),
260 }
261 })
262}
263
264pub fn store_transcript(id: &str, transcript: VmValue) {
267 SESSIONS.with(|s| {
268 if let Some(state) = s.borrow_mut().get_mut(id) {
269 state.transcript = transcript;
270 state.last_accessed = Instant::now();
271 }
272 });
273}
274
275pub fn replace_messages(id: &str, messages: &[serde_json::Value]) {
278 SESSIONS.with(|s| {
279 let mut map = s.borrow_mut();
280 let Some(state) = map.get_mut(id) else {
281 return;
282 };
283 let dict = state
284 .transcript
285 .as_dict()
286 .cloned()
287 .unwrap_or_else(BTreeMap::new);
288 let vm_messages: Vec<VmValue> = messages
289 .iter()
290 .map(crate::stdlib::json_to_vm_value)
291 .collect();
292 let mut next = dict;
293 next.insert(
294 "events".to_string(),
295 VmValue::List(Rc::new(
296 crate::llm::helpers::transcript_events_from_messages(&vm_messages),
297 )),
298 );
299 next.insert("messages".to_string(), VmValue::List(Rc::new(vm_messages)));
300 state.transcript = VmValue::Dict(Rc::new(next));
301 state.last_accessed = Instant::now();
302 });
303}
304
305pub fn append_subscriber(id: &str, callback: VmValue) {
306 open_or_create(Some(id.to_string()));
307 SESSIONS.with(|s| {
308 if let Some(state) = s.borrow_mut().get_mut(id) {
309 state.subscribers.push(callback);
310 state.last_accessed = Instant::now();
311 }
312 });
313}
314
315pub fn subscribers_for(id: &str) -> Vec<VmValue> {
316 SESSIONS.with(|s| {
317 s.borrow()
318 .get(id)
319 .map(|state| state.subscribers.clone())
320 .unwrap_or_default()
321 })
322}
323
324pub fn subscriber_count(id: &str) -> usize {
325 SESSIONS.with(|s| {
326 s.borrow()
327 .get(id)
328 .map(|state| state.subscribers.len())
329 .unwrap_or(0)
330 })
331}
332
333fn empty_transcript(id: &str) -> VmValue {
334 use crate::llm::helpers::new_transcript_with;
335 new_transcript_with(Some(id.to_string()), Vec::new(), None, None)
336}
337
338fn clone_transcript_with_id(transcript: &VmValue, new_id: &str) -> VmValue {
339 let Some(dict) = transcript.as_dict() else {
340 return empty_transcript(new_id);
341 };
342 let mut next = dict.clone();
343 next.insert(
344 "id".to_string(),
345 VmValue::String(Rc::from(new_id.to_string())),
346 );
347 VmValue::Dict(Rc::new(next))
348}