1use std::collections::HashMap;
2use std::io::Cursor;
3use std::time::{Duration, Instant};
4
5use anyhow::{Context, Result};
6use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
7use either::Either;
8use image::{DynamicImage, ImageFormat};
9use indexmap::IndexMap;
10use serde::{Deserialize, Serialize};
11
12use crate::{MessageContent, NormalRequest, RequestMessage, VideoInput};
13
/// Upper bound on the number of sessions kept in an [`AgenticSessionStore`]; once reached,
/// the least recently accessed session is dropped to make room for a new one.
const MAX_SESSIONS: usize = 128;

/// How long a session (or an agent-action approval) may go untouched before eviction
/// removes it: 30 minutes.
const SESSION_TTL: Duration = Duration::from_secs(30 * 60);
16
17#[derive(Clone)]
19pub struct AgenticSessionEntry {
20 pub messages: Vec<IndexMap<String, MessageContent>>,
21 pub images: Vec<DynamicImage>,
23 pub videos: Vec<VideoInput>,
24 last_accessed: Instant,
25}
26
27pub struct AgenticSessionStore {
29 sessions: HashMap<String, AgenticSessionEntry>,
30 approved_agent_sessions: HashMap<String, Instant>,
31}
32
33impl Default for AgenticSessionStore {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl AgenticSessionStore {
40 pub fn new() -> Self {
41 Self {
42 sessions: HashMap::new(),
43 approved_agent_sessions: HashMap::new(),
44 }
45 }
46
47 pub fn approve_agent_actions(&mut self, session_id: impl Into<String>) {
48 self.evict();
49 self.approved_agent_sessions
50 .insert(session_id.into(), Instant::now());
51 }
52
53 pub fn agent_actions_approved(&mut self, session_id: &str) -> bool {
54 self.evict();
55 let Some(last_accessed) = self.approved_agent_sessions.get_mut(session_id) else {
56 return false;
57 };
58 *last_accessed = Instant::now();
59 true
60 }
61
62 pub fn get(&mut self, session_id: &str) -> Option<AgenticSessionEntry> {
64 let entry = self.sessions.get_mut(session_id)?;
65 entry.last_accessed = Instant::now();
66 Some(entry.clone())
67 }
68
69 pub fn find_by_messages(
71 &mut self,
72 incoming: &[IndexMap<String, MessageContent>],
73 ) -> Option<(String, AgenticSessionEntry)> {
74 if incoming.len() < 2 {
76 return None;
77 }
78
79 for (id, entry) in &mut self.sessions {
80 let stored_visible = user_visible_messages(&entry.messages);
81 if stored_visible.len() > incoming.len() {
82 continue;
83 }
84
85 let matches = stored_visible
86 .iter()
87 .zip(incoming.iter())
88 .all(|(stored, inc)| messages_match(stored, inc));
89
90 if matches && !stored_visible.is_empty() {
91 entry.last_accessed = Instant::now();
92 return Some((id.clone(), entry.clone()));
93 }
94 }
95
96 None
97 }
98
99 pub fn save(&mut self, session_id: String, entry: AgenticSessionEntry) {
101 self.evict();
102 self.sessions.insert(session_id, entry);
103 }
104
105 pub fn delete(&mut self, session_id: &str) -> bool {
107 self.approved_agent_sessions.remove(session_id);
108 self.sessions.remove(session_id).is_some()
109 }
110
111 pub fn list_ids(&self) -> Vec<String> {
112 self.sessions.keys().cloned().collect()
113 }
114
115 pub fn export(&mut self, session_id: &str) -> Result<Option<SerializedSession>> {
116 let Some(entry) = self.get(session_id) else {
117 return Ok(None);
118 };
119 Ok(Some(SerializedSession::from_entry(&entry)?))
120 }
121
122 pub fn import(&mut self, session_id: String, serialized: SerializedSession) -> Result<()> {
124 let entry = serialized.into_entry()?;
125 self.save(session_id, entry);
126 Ok(())
127 }
128
129 pub fn fork(&mut self, src: &str, dest: String, num_turns: usize) -> Result<()> {
132 let entry = self
133 .get(src)
134 .ok_or_else(|| anyhow::anyhow!("source session {src} not found"))?;
135
136 let mut turns_seen = 0;
137 let mut cutoff: Option<usize> = None;
138 for (i, m) in entry.messages.iter().enumerate() {
139 let role = m
140 .get("role")
141 .and_then(|r| match r {
142 Either::Left(s) => Some(s.as_str()),
143 _ => None,
144 })
145 .unwrap_or("");
146 if role == "assistant" && !m.contains_key("tool_calls") {
147 turns_seen += 1;
148 if turns_seen == num_turns {
149 cutoff = Some(i);
150 break;
151 }
152 }
153 }
154 let messages = match cutoff {
155 Some(i) => entry.messages[..=i].to_vec(),
156 None => entry.messages.clone(),
157 };
158 let forked = AgenticSessionEntry::new(messages, entry.images.clone(), entry.videos.clone());
159 self.save(dest, forked);
160 Ok(())
161 }
162
163 fn evict(&mut self) {
165 let now = Instant::now();
166
167 self.sessions
168 .retain(|_, entry| now.duration_since(entry.last_accessed) < SESSION_TTL);
169 self.approved_agent_sessions
170 .retain(|_, last_accessed| now.duration_since(*last_accessed) < SESSION_TTL);
171
172 while self.sessions.len() >= MAX_SESSIONS {
173 let oldest = self
174 .sessions
175 .iter()
176 .min_by_key(|(_, e)| e.last_accessed)
177 .map(|(k, _)| k.clone());
178 if let Some(key) = oldest {
179 self.sessions.remove(&key);
180 } else {
181 break;
182 }
183 }
184 }
185}
186
187impl AgenticSessionEntry {
188 pub fn new(
189 messages: Vec<IndexMap<String, MessageContent>>,
190 images: Vec<DynamicImage>,
191 videos: Vec<VideoInput>,
192 ) -> Self {
193 Self {
194 messages,
195 images,
196 videos,
197 last_accessed: Instant::now(),
198 }
199 }
200}
201
202fn user_visible_messages(
204 messages: &[IndexMap<String, MessageContent>],
205) -> Vec<&IndexMap<String, MessageContent>> {
206 messages
207 .iter()
208 .filter(|msg| !is_tool_message(msg))
209 .collect()
210}
211
212fn is_tool_message(msg: &IndexMap<String, MessageContent>) -> bool {
214 let role = msg
215 .get("role")
216 .and_then(|r| match r {
217 Either::Left(s) => Some(s.as_str()),
218 _ => None,
219 })
220 .unwrap_or("");
221
222 if role == "tool" {
223 return true;
224 }
225
226 if msg.contains_key("tool_calls") {
228 return true;
229 }
230
231 false
232}
233
234fn messages_match(
235 a: &IndexMap<String, MessageContent>,
236 b: &IndexMap<String, MessageContent>,
237) -> bool {
238 a.get("role") == b.get("role")
239 && a.get("content") == b.get("content")
240 && a.get("tool_calls") == b.get("tool_calls")
241}
242
243pub fn splice_session_into_request(request: &mut NormalRequest, entry: &AgenticSessionEntry) {
245 let incoming = match &mut request.messages {
246 RequestMessage::Chat { messages, .. } | RequestMessage::MultimodalChat { messages, .. } => {
247 messages
248 }
249 _ => return,
250 };
251
252 let stored = &entry.messages;
253
254 let mut result: Vec<IndexMap<String, MessageContent>> = Vec::new();
255 let mut incoming_idx = 0;
256 let mut stored_idx = 0;
257
258 while stored_idx < stored.len() && incoming_idx < incoming.len() {
259 let stored_msg = &stored[stored_idx];
260
261 if is_tool_message(stored_msg) {
262 result.push(stored_msg.clone());
263 stored_idx += 1;
264 } else {
265 let incoming_msg = &incoming[incoming_idx];
266 if messages_match(stored_msg, incoming_msg) {
267 result.push(stored_msg.clone());
268 stored_idx += 1;
269 incoming_idx += 1;
270 } else {
271 break;
273 }
274 }
275 }
276
277 while stored_idx < stored.len() && is_tool_message(&stored[stored_idx]) {
279 result.push(stored[stored_idx].clone());
280 stored_idx += 1;
281 }
282
283 while incoming_idx < incoming.len() {
284 result.push(incoming[incoming_idx].clone());
285 incoming_idx += 1;
286 }
287
288 *incoming = result;
289
290 if !entry.images.is_empty() || !entry.videos.is_empty() {
291 super::agentic_loop::upgrade_to_multimodal(request);
292 if !entry.images.is_empty() {
293 let req_images = super::agentic_loop::get_images_mut(request);
294 *req_images = entry.images.clone();
295 }
296 if !entry.videos.is_empty() {
297 let req_videos = super::agentic_loop::get_videos_mut(request);
298 *req_videos = entry.videos.clone();
299 }
300 }
301}
302
303#[derive(Debug, Clone, Serialize, Deserialize)]
305#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
306pub struct SerializedSession {
307 #[cfg_attr(feature = "utoipa", schema(value_type = Vec<serde_json::Value>))]
308 pub messages: Vec<IndexMap<String, MessageContent>>,
309 #[serde(default)]
310 pub images: Vec<String>,
311 #[serde(default)]
312 pub videos: Vec<SerializedVideo>,
313 #[serde(default)]
314 #[cfg_attr(feature = "utoipa", schema(value_type = Vec<serde_json::Value>))]
315 pub files: Vec<crate::files::File>,
316}
317
318#[derive(Debug, Clone, Serialize, Deserialize)]
319#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
320pub struct SerializedVideo {
321 pub fps: f64,
322 pub frames: Vec<String>,
323 pub total_num_frames: usize,
324 pub sampled_indices: Vec<usize>,
325}
326
327impl SerializedSession {
328 pub fn from_entry(entry: &AgenticSessionEntry) -> Result<Self> {
329 let images = entry
330 .images
331 .iter()
332 .map(encode_png_base64)
333 .collect::<Result<Vec<_>>>()?;
334
335 let videos = entry
336 .videos
337 .iter()
338 .map(SerializedVideo::from_video)
339 .collect::<Result<Vec<_>>>()?;
340
341 Ok(Self {
342 messages: entry.messages.clone(),
343 images,
344 videos,
345 files: Vec::new(),
346 })
347 }
348
349 pub fn into_entry(self) -> Result<AgenticSessionEntry> {
350 let images = self
351 .images
352 .iter()
353 .map(|s| decode_png_base64(s))
354 .collect::<Result<Vec<_>>>()?;
355
356 let videos = self
357 .videos
358 .into_iter()
359 .map(SerializedVideo::into_video)
360 .collect::<Result<Vec<_>>>()?;
361
362 Ok(AgenticSessionEntry {
363 messages: self.messages,
364 images,
365 videos,
366 last_accessed: Instant::now(),
367 })
368 }
369}
370
371impl SerializedVideo {
372 fn from_video(video: &VideoInput) -> Result<Self> {
373 let frames = video
374 .frames
375 .iter()
376 .map(encode_png_base64)
377 .collect::<Result<Vec<_>>>()?;
378 Ok(Self {
379 fps: video.fps,
380 frames,
381 total_num_frames: video.total_num_frames,
382 sampled_indices: video.sampled_indices.clone(),
383 })
384 }
385
386 fn into_video(self) -> Result<VideoInput> {
387 let frames = self
388 .frames
389 .iter()
390 .map(|s| decode_png_base64(s))
391 .collect::<Result<Vec<_>>>()?;
392 Ok(VideoInput {
393 frames,
394 fps: self.fps,
395 total_num_frames: self.total_num_frames,
396 sampled_indices: self.sampled_indices,
397 })
398 }
399}
400
401fn encode_png_base64(img: &DynamicImage) -> Result<String> {
402 let mut buf = Vec::new();
403 img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)
404 .context("encoding image as PNG")?;
405 Ok(BASE64.encode(&buf))
406}
407
408fn decode_png_base64(s: &str) -> Result<DynamicImage> {
409 let bytes = BASE64.decode(s).context("base64 decoding image")?;
410 image::load_from_memory(&bytes).context("loading image bytes")
411}