Skip to main content

hanzo_engine/engine/
agentic_session.rs

1use std::collections::HashMap;
2use std::io::Cursor;
3use std::time::{Duration, Instant};
4
5use anyhow::{Context, Result};
6use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
7use either::Either;
8use image::{DynamicImage, ImageFormat};
9use indexmap::IndexMap;
10use serde::{Deserialize, Serialize};
11
12use crate::{MessageContent, NormalRequest, RequestMessage, VideoInput};
13
/// Capacity of the session store: once this many sessions are held, the least recently
/// accessed ones are evicted to make room.
const MAX_SESSIONS: usize = 128;
/// Idle lifetime of a session or agent-action approval: thirty minutes since last access.
const SESSION_TTL: Duration = Duration::from_secs(60 * 30);
16
17/// A stored agentic conversation, tool call/response messages included.
18#[derive(Clone)]
19pub struct AgenticSessionEntry {
20    pub messages: Vec<IndexMap<String, MessageContent>>,
21    /// Positional with `messages`.
22    pub images: Vec<DynamicImage>,
23    pub videos: Vec<VideoInput>,
24    last_accessed: Instant,
25}
26
27/// Agentic conversation state, keyed by session ID. Also supports content-based matching for clients that don't pass an ID.
28pub struct AgenticSessionStore {
29    sessions: HashMap<String, AgenticSessionEntry>,
30    approved_agent_sessions: HashMap<String, Instant>,
31}
32
33impl Default for AgenticSessionStore {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl AgenticSessionStore {
40    pub fn new() -> Self {
41        Self {
42            sessions: HashMap::new(),
43            approved_agent_sessions: HashMap::new(),
44        }
45    }
46
47    pub fn approve_agent_actions(&mut self, session_id: impl Into<String>) {
48        self.evict();
49        self.approved_agent_sessions
50            .insert(session_id.into(), Instant::now());
51    }
52
53    pub fn agent_actions_approved(&mut self, session_id: &str) -> bool {
54        self.evict();
55        let Some(last_accessed) = self.approved_agent_sessions.get_mut(session_id) else {
56            return false;
57        };
58        *last_accessed = Instant::now();
59        true
60    }
61
62    /// Updates `last_accessed`.
63    pub fn get(&mut self, session_id: &str) -> Option<AgenticSessionEntry> {
64        let entry = self.sessions.get_mut(session_id)?;
65        entry.last_accessed = Instant::now();
66        Some(entry.clone())
67    }
68
69    /// Find a stored session whose user-visible messages (no tool turns) are a prefix of `incoming`.
70    pub fn find_by_messages(
71        &mut self,
72        incoming: &[IndexMap<String, MessageContent>],
73    ) -> Option<(String, AgenticSessionEntry)> {
74        // Need at least 2 messages (system/user + assistant) to match meaningfully.
75        if incoming.len() < 2 {
76            return None;
77        }
78
79        for (id, entry) in &mut self.sessions {
80            let stored_visible = user_visible_messages(&entry.messages);
81            if stored_visible.len() > incoming.len() {
82                continue;
83            }
84
85            let matches = stored_visible
86                .iter()
87                .zip(incoming.iter())
88                .all(|(stored, inc)| messages_match(stored, inc));
89
90            if matches && !stored_visible.is_empty() {
91                entry.last_accessed = Instant::now();
92                return Some((id.clone(), entry.clone()));
93            }
94        }
95
96        None
97    }
98
99    /// Save or update. Evicts stale entries if needed.
100    pub fn save(&mut self, session_id: String, entry: AgenticSessionEntry) {
101        self.evict();
102        self.sessions.insert(session_id, entry);
103    }
104
105    /// Returns whether the session existed.
106    pub fn delete(&mut self, session_id: &str) -> bool {
107        self.approved_agent_sessions.remove(session_id);
108        self.sessions.remove(session_id).is_some()
109    }
110
111    pub fn list_ids(&self) -> Vec<String> {
112        self.sessions.keys().cloned().collect()
113    }
114
115    pub fn export(&mut self, session_id: &str) -> Result<Option<SerializedSession>> {
116        let Some(entry) = self.get(session_id) else {
117            return Ok(None);
118        };
119        Ok(Some(SerializedSession::from_entry(&entry)?))
120    }
121
122    /// Replaces any existing entry with the same ID.
123    pub fn import(&mut self, session_id: String, serialized: SerializedSession) -> Result<()> {
124        let entry = serialized.into_entry()?;
125        self.save(session_id, entry);
126        Ok(())
127    }
128
129    /// Clone the first `num_turns` complete turns of `src` into `dest`. A turn ends at the first
130    /// `role: assistant` message that has no `tool_calls` field. Images and videos are copied as-is.
131    pub fn fork(&mut self, src: &str, dest: String, num_turns: usize) -> Result<()> {
132        let entry = self
133            .get(src)
134            .ok_or_else(|| anyhow::anyhow!("source session {src} not found"))?;
135
136        let mut turns_seen = 0;
137        let mut cutoff: Option<usize> = None;
138        for (i, m) in entry.messages.iter().enumerate() {
139            let role = m
140                .get("role")
141                .and_then(|r| match r {
142                    Either::Left(s) => Some(s.as_str()),
143                    _ => None,
144                })
145                .unwrap_or("");
146            if role == "assistant" && !m.contains_key("tool_calls") {
147                turns_seen += 1;
148                if turns_seen == num_turns {
149                    cutoff = Some(i);
150                    break;
151                }
152            }
153        }
154        let messages = match cutoff {
155            Some(i) => entry.messages[..=i].to_vec(),
156            None => entry.messages.clone(),
157        };
158        let forked = AgenticSessionEntry::new(messages, entry.images.clone(), entry.videos.clone());
159        self.save(dest, forked);
160        Ok(())
161    }
162
163    /// Drop expired and over-limit entries.
164    fn evict(&mut self) {
165        let now = Instant::now();
166
167        self.sessions
168            .retain(|_, entry| now.duration_since(entry.last_accessed) < SESSION_TTL);
169        self.approved_agent_sessions
170            .retain(|_, last_accessed| now.duration_since(*last_accessed) < SESSION_TTL);
171
172        while self.sessions.len() >= MAX_SESSIONS {
173            let oldest = self
174                .sessions
175                .iter()
176                .min_by_key(|(_, e)| e.last_accessed)
177                .map(|(k, _)| k.clone());
178            if let Some(key) = oldest {
179                self.sessions.remove(&key);
180            } else {
181                break;
182            }
183        }
184    }
185}
186
187impl AgenticSessionEntry {
188    pub fn new(
189        messages: Vec<IndexMap<String, MessageContent>>,
190        images: Vec<DynamicImage>,
191        videos: Vec<VideoInput>,
192    ) -> Self {
193        Self {
194            messages,
195            images,
196            videos,
197            last_accessed: Instant::now(),
198        }
199    }
200}
201
202/// User-visible messages only, skipping tool call/response messages.
203fn user_visible_messages(
204    messages: &[IndexMap<String, MessageContent>],
205) -> Vec<&IndexMap<String, MessageContent>> {
206    messages
207        .iter()
208        .filter(|msg| !is_tool_message(msg))
209        .collect()
210}
211
212/// True for tool call / tool response messages.
213fn is_tool_message(msg: &IndexMap<String, MessageContent>) -> bool {
214    let role = msg
215        .get("role")
216        .and_then(|r| match r {
217            Either::Left(s) => Some(s.as_str()),
218            _ => None,
219        })
220        .unwrap_or("");
221
222    if role == "tool" {
223        return true;
224    }
225
226    // Assistant messages with tool_calls.
227    if msg.contains_key("tool_calls") {
228        return true;
229    }
230
231    false
232}
233
234fn messages_match(
235    a: &IndexMap<String, MessageContent>,
236    b: &IndexMap<String, MessageContent>,
237) -> bool {
238    a.get("role") == b.get("role")
239        && a.get("content") == b.get("content")
240        && a.get("tool_calls") == b.get("tool_calls")
241}
242
243/// Splice stored tool call/response messages back into an incoming request between matched user-visible messages.
244pub fn splice_session_into_request(request: &mut NormalRequest, entry: &AgenticSessionEntry) {
245    let incoming = match &mut request.messages {
246        RequestMessage::Chat { messages, .. } | RequestMessage::MultimodalChat { messages, .. } => {
247            messages
248        }
249        _ => return,
250    };
251
252    let stored = &entry.messages;
253
254    let mut result: Vec<IndexMap<String, MessageContent>> = Vec::new();
255    let mut incoming_idx = 0;
256    let mut stored_idx = 0;
257
258    while stored_idx < stored.len() && incoming_idx < incoming.len() {
259        let stored_msg = &stored[stored_idx];
260
261        if is_tool_message(stored_msg) {
262            result.push(stored_msg.clone());
263            stored_idx += 1;
264        } else {
265            let incoming_msg = &incoming[incoming_idx];
266            if messages_match(stored_msg, incoming_msg) {
267                result.push(stored_msg.clone());
268                stored_idx += 1;
269                incoming_idx += 1;
270            } else {
271                // Conversation diverged. Stop splicing.
272                break;
273            }
274        }
275    }
276
277    // Drain trailing tool messages after the last matched user-visible message.
278    while stored_idx < stored.len() && is_tool_message(&stored[stored_idx]) {
279        result.push(stored[stored_idx].clone());
280        stored_idx += 1;
281    }
282
283    while incoming_idx < incoming.len() {
284        result.push(incoming[incoming_idx].clone());
285        incoming_idx += 1;
286    }
287
288    *incoming = result;
289
290    if !entry.images.is_empty() || !entry.videos.is_empty() {
291        super::agentic_loop::upgrade_to_multimodal(request);
292        if !entry.images.is_empty() {
293            let req_images = super::agentic_loop::get_images_mut(request);
294            *req_images = entry.images.clone();
295        }
296        if !entry.videos.is_empty() {
297            let req_videos = super::agentic_loop::get_videos_mut(request);
298            *req_videos = entry.videos.clone();
299        }
300    }
301}
302
/// Wire format. Images and video frames are base64 PNGs.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SerializedSession {
    /// Full message history, tool call/response messages included.
    #[cfg_attr(feature = "utoipa", schema(value_type = Vec<serde_json::Value>))]
    pub messages: Vec<IndexMap<String, MessageContent>>,
    /// Base64-encoded PNG images; defaults to empty when absent from the input.
    #[serde(default)]
    pub images: Vec<String>,
    /// Serialized videos; defaults to empty when absent from the input.
    #[serde(default)]
    pub videos: Vec<SerializedVideo>,
    /// Attached files; defaults to empty when absent from the input.
    // NOTE(review): `from_entry` always writes an empty list here and `into_entry` never reads
    // this field — confirm files are meant to be dropped on import.
    #[serde(default)]
    #[cfg_attr(feature = "utoipa", schema(value_type = Vec<serde_json::Value>))]
    pub files: Vec<crate::files::File>,
}
317
/// Wire form of a `VideoInput`: frames are base64-encoded PNGs and the remaining fields are
/// copied verbatim.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct SerializedVideo {
    /// Copied from `VideoInput::fps`.
    pub fps: f64,
    /// One base64-encoded PNG per frame, in `VideoInput::frames` order.
    pub frames: Vec<String>,
    /// Copied from `VideoInput::total_num_frames`.
    pub total_num_frames: usize,
    /// Copied from `VideoInput::sampled_indices`.
    pub sampled_indices: Vec<usize>,
}
326
327impl SerializedSession {
328    pub fn from_entry(entry: &AgenticSessionEntry) -> Result<Self> {
329        let images = entry
330            .images
331            .iter()
332            .map(encode_png_base64)
333            .collect::<Result<Vec<_>>>()?;
334
335        let videos = entry
336            .videos
337            .iter()
338            .map(SerializedVideo::from_video)
339            .collect::<Result<Vec<_>>>()?;
340
341        Ok(Self {
342            messages: entry.messages.clone(),
343            images,
344            videos,
345            files: Vec::new(),
346        })
347    }
348
349    pub fn into_entry(self) -> Result<AgenticSessionEntry> {
350        let images = self
351            .images
352            .iter()
353            .map(|s| decode_png_base64(s))
354            .collect::<Result<Vec<_>>>()?;
355
356        let videos = self
357            .videos
358            .into_iter()
359            .map(SerializedVideo::into_video)
360            .collect::<Result<Vec<_>>>()?;
361
362        Ok(AgenticSessionEntry {
363            messages: self.messages,
364            images,
365            videos,
366            last_accessed: Instant::now(),
367        })
368    }
369}
370
371impl SerializedVideo {
372    fn from_video(video: &VideoInput) -> Result<Self> {
373        let frames = video
374            .frames
375            .iter()
376            .map(encode_png_base64)
377            .collect::<Result<Vec<_>>>()?;
378        Ok(Self {
379            fps: video.fps,
380            frames,
381            total_num_frames: video.total_num_frames,
382            sampled_indices: video.sampled_indices.clone(),
383        })
384    }
385
386    fn into_video(self) -> Result<VideoInput> {
387        let frames = self
388            .frames
389            .iter()
390            .map(|s| decode_png_base64(s))
391            .collect::<Result<Vec<_>>>()?;
392        Ok(VideoInput {
393            frames,
394            fps: self.fps,
395            total_num_frames: self.total_num_frames,
396            sampled_indices: self.sampled_indices,
397        })
398    }
399}
400
401fn encode_png_base64(img: &DynamicImage) -> Result<String> {
402    let mut buf = Vec::new();
403    img.write_to(&mut Cursor::new(&mut buf), ImageFormat::Png)
404        .context("encoding image as PNG")?;
405    Ok(BASE64.encode(&buf))
406}
407
408fn decode_png_base64(s: &str) -> Result<DynamicImage> {
409    let bytes = BASE64.decode(s).context("base64 decoding image")?;
410    image::load_from_memory(&bytes).context("loading image bytes")
411}