Skip to main content

fluers_runtime/
session.rs

//! Sessions and the event store.
//!
//! Mirrors Flue's session machinery (`SessionStore`, `event-stream-store`,
//! `dispatch`/`invoke`). The MVP keeps sessions in memory while exposing a
//! typed persistence envelope for durable resume.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use parking_lot::RwLock;
11use serde_json::Value;
12use uuid::Uuid;
13
14use fluers_core::AgentMessage;
15
16use crate::error::{RuntimeError, RuntimeResult};
17use crate::persistence::PersistenceAdapter;
18
/// Version of the persisted [`SessionState`] envelope layout.
///
/// Bump this whenever `SessionState` changes incompatibly; envelopes carrying
/// any other version are rejected when loaded.
pub const SCHEMA_VERSION: u32 = 1;
21
22/// A unique session id.
23pub type SessionId = Uuid;
24
25/// On-disk envelope for a resumable session. Carries everything the
26/// coordinator needs to reconstruct a run after a process restart.
27#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
28pub struct SessionState {
29    /// Envelope schema version. Bumped on breaking changes to this struct.
30    pub schema_version: u32,
31    /// The model id used for this session.
32    pub model: String,
33    /// Max turns configured for this session.
34    pub max_turns: usize,
35    /// The system message (instructions) for this session.
36    pub system_message: Option<String>,
37    /// The full conversation log.
38    pub messages: Vec<AgentMessage>,
39    /// Arbitrary metadata (e.g. created_at, tags).
40    pub metadata: HashMap<String, String>,
41}
42
43/// One session: its id, configuration, message log, and metadata.
44#[derive(Debug, Clone)]
45pub struct Session {
46    /// The id.
47    pub id: SessionId,
48    /// The model id used for this session.
49    pub model: String,
50    /// Max turns configured for this session.
51    pub max_turns: usize,
52    /// The system message (instructions) for this session.
53    pub system_message: Option<String>,
54    /// Conversation so far.
55    pub messages: Vec<AgentMessage>,
56    /// Arbitrary metadata.
57    pub metadata: HashMap<String, String>,
58}
59
60impl Session {
61    fn to_state(&self) -> SessionState {
62        SessionState {
63            schema_version: SCHEMA_VERSION,
64            model: self.model.clone(),
65            max_turns: self.max_turns,
66            system_message: self.system_message.clone(),
67            messages: self.messages.clone(),
68            metadata: self.metadata.clone(),
69        }
70    }
71
72    fn from_state(id: SessionId, state: SessionState) -> Self {
73        Self {
74            id,
75            model: state.model,
76            max_turns: state.max_turns,
77            system_message: state.system_message,
78            messages: state.messages,
79            metadata: state.metadata,
80        }
81    }
82}
83
84/// An in-memory session store.
85///
86/// The store remains synchronous for in-process mutation (`append`), while
87/// explicit `save`/`load` methods bridge to async persistence adapters.
88#[derive(Default)]
89pub struct SessionStore {
90    inner: RwLock<HashMap<SessionId, Session>>,
91}
92
93impl SessionStore {
94    /// Create an empty store.
95    #[must_use]
96    pub fn new() -> Self {
97        Self::default()
98    }
99
100    /// Create a new session with default, unspecified configuration and return
101    /// its id.
102    pub fn create(&self) -> SessionId {
103        self.create_with_config(String::new(), 0, None)
104    }
105
106    /// Create a new session with explicit run configuration and return its id.
107    pub fn create_with_config(
108        &self,
109        model: impl Into<String>,
110        max_turns: usize,
111        system_message: Option<String>,
112    ) -> SessionId {
113        let id = Uuid::new_v4();
114        let session = Session {
115            id,
116            model: model.into(),
117            max_turns,
118            system_message,
119            messages: Vec::new(),
120            metadata: HashMap::new(),
121        };
122        self.inner.write().insert(id, session);
123        id
124    }
125
126    /// Append a message to a session.
127    pub fn append(&self, id: SessionId, message: AgentMessage) -> RuntimeResult<()> {
128        let mut guard = self.inner.write();
129        let session = guard
130            .get_mut(&id)
131            .ok_or_else(|| RuntimeError::SessionNotFound(id.to_string()))?;
132        session.messages.push(message);
133        Ok(())
134    }
135
136    /// Snapshot a session's messages.
137    pub fn messages(&self, id: SessionId) -> RuntimeResult<Vec<AgentMessage>> {
138        let guard = self.inner.read();
139        guard
140            .get(&id)
141            .map(|s| s.messages.clone())
142            .ok_or_else(|| RuntimeError::SessionNotFound(id.to_string()))
143    }
144
145    /// Persist a session through the provided adapter.
146    pub async fn save(&self, adapter: &dyn PersistenceAdapter, id: SessionId) -> RuntimeResult<()> {
147        let session = {
148            let guard = self.inner.read();
149            guard
150                .get(&id)
151                .cloned()
152                .ok_or_else(|| RuntimeError::SessionNotFound(id.to_string()))?
153        };
154        let value = state_to_value(&session.to_state(), id)?;
155        adapter
156            .save_session(&id.to_string(), &value)
157            .await
158            .map_err(RuntimeError::from)
159    }
160
161    /// Load a session from the provided adapter.
162    pub async fn load(
163        adapter: &dyn PersistenceAdapter,
164        id: SessionId,
165    ) -> RuntimeResult<Option<Session>> {
166        let Some(value) = adapter
167            .load_session(&id.to_string())
168            .await
169            .map_err(RuntimeError::from)?
170        else {
171            return Ok(None);
172        };
173        let state = value_to_state(value, id)?;
174        Ok(Some(Session::from_state(id, state)))
175    }
176
177    /// List all persisted sessions from the provided adapter.
178    pub async fn list(adapter: &dyn PersistenceAdapter) -> RuntimeResult<Vec<SessionId>> {
179        adapter
180            .list_sessions()
181            .await
182            .map_err(RuntimeError::from)?
183            .into_iter()
184            .map(|raw| {
185                Uuid::parse_str(&raw).map_err(|err| {
186                    RuntimeError::Persistence(format!(
187                        "invalid persisted session id `{raw}`: {err}"
188                    ))
189                })
190            })
191            .collect()
192    }
193}
194
195fn state_to_value(state: &SessionState, id: SessionId) -> RuntimeResult<Value> {
196    serde_json::to_value(state).map_err(|err| {
197        RuntimeError::Persistence(format!("failed to serialize session `{id}`: {err}"))
198    })
199}
200
201fn value_to_state(value: Value, id: SessionId) -> RuntimeResult<SessionState> {
202    let state: SessionState = serde_json::from_value(value).map_err(|err| {
203        RuntimeError::Persistence(format!("failed to deserialize session `{id}`: {err}"))
204    })?;
205    if state.schema_version != SCHEMA_VERSION {
206        return Err(RuntimeError::Persistence(format!(
207            "unsupported session schema version {} for `{id}` (expected {SCHEMA_VERSION})",
208            state.schema_version
209        )));
210    }
211    Ok(state)
212}
213
214/// Shared pointer to a session store.
215pub type SharedSessionStore = Arc<SessionStore>;
216
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use fluers_core::{ContentBlock, Role};
    use serde_json::{json, Value};
    use tokio::sync::Mutex;

    use crate::persistence::{PersistenceAdapter, Result as PersistenceResult};

    type TestResult<T = ()> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

    /// In-memory adapter: envelopes live in a map keyed by the id string.
    #[derive(Default)]
    struct MockAdapter {
        sessions: Mutex<HashMap<String, Value>>,
    }

    impl MockAdapter {
        /// Store a raw envelope for `id`, bypassing `SessionStore::save`.
        async fn put(&self, id: SessionId, value: Value) {
            self.put_raw(&id.to_string(), value).await;
        }

        /// Store a raw envelope under an arbitrary key, which need not be a UUID.
        async fn put_raw(&self, key: &str, value: Value) {
            self.sessions.lock().await.insert(key.to_string(), value);
        }
    }

    #[async_trait]
    impl PersistenceAdapter for MockAdapter {
        async fn save_session(&self, id: &str, data: &Value) -> PersistenceResult<()> {
            self.sessions
                .lock()
                .await
                .insert(id.to_string(), data.clone());
            Ok(())
        }

        async fn load_session(&self, id: &str) -> PersistenceResult<Option<Value>> {
            Ok(self.sessions.lock().await.get(id).cloned())
        }

        async fn list_sessions(&self) -> PersistenceResult<Vec<String>> {
            Ok(self.sessions.lock().await.keys().cloned().collect())
        }
    }

    fn text_message(role: Role, text: &str) -> AgentMessage {
        AgentMessage {
            role,
            content: vec![ContentBlock::Text { text: text.into() }],
        }
    }

    fn first_text(messages: &[AgentMessage]) -> Option<&str> {
        messages
            .first()
            .and_then(|message| message.content.first())
            .and_then(|block| match block {
                ContentBlock::Text { text } => Some(text.as_str()),
                _ => None,
            })
    }

    #[test]
    fn session_state_roundtrips() -> TestResult {
        let state = SessionState {
            schema_version: SCHEMA_VERSION,
            model: "mock/model".into(),
            max_turns: 4,
            system_message: Some("be useful".into()),
            messages: vec![text_message(Role::User, "hello")],
            metadata: HashMap::from([("tag".into(), "test".into())]),
        };

        let value = serde_json::to_value(&state)?;
        let roundtrip: SessionState = serde_json::from_value(value)?;

        assert_eq!(roundtrip.schema_version, SCHEMA_VERSION);
        assert_eq!(roundtrip.model, "mock/model");
        assert_eq!(roundtrip.max_turns, 4);
        assert_eq!(roundtrip.system_message.as_deref(), Some("be useful"));
        assert_eq!(roundtrip.messages.len(), 1);
        assert_eq!(first_text(&roundtrip.messages), Some("hello"));
        assert_eq!(
            roundtrip.metadata.get("tag").map(String::as_str),
            Some("test")
        );
        Ok(())
    }

    #[test]
    fn append_and_messages_reject_unknown_session() {
        let store = SessionStore::new();
        let missing = Uuid::new_v4();

        let appended = store.append(missing, text_message(Role::User, "nobody home"));
        assert!(
            matches!(appended, Err(RuntimeError::SessionNotFound(ref raw)) if *raw == missing.to_string()),
            "expected SessionNotFound, got {appended:?}"
        );

        let snapshot = store.messages(missing);
        assert!(
            matches!(snapshot, Err(RuntimeError::SessionNotFound(_))),
            "expected SessionNotFound, got {snapshot:?}"
        );
    }

    #[tokio::test]
    async fn session_save_then_load() -> TestResult {
        let adapter = MockAdapter::default();
        let store = SessionStore::new();
        let id = store.create_with_config("mock/model", 8, Some("system".into()));
        store.append(id, text_message(Role::User, "persist me"))?;

        store.save(&adapter, id).await?;
        let loaded = SessionStore::load(&adapter, id).await?;
        let Some(session) = loaded else {
            return Err(std::io::Error::other("session was not loaded").into());
        };

        assert_eq!(session.id, id);
        assert_eq!(session.model, "mock/model");
        assert_eq!(session.max_turns, 8);
        assert_eq!(session.system_message.as_deref(), Some("system"));
        assert_eq!(session.messages.len(), 1);
        assert_eq!(first_text(&session.messages), Some("persist me"));
        Ok(())
    }

    #[tokio::test]
    async fn save_unknown_session_errors() -> TestResult {
        let adapter = MockAdapter::default();
        let store = SessionStore::new();

        let result = store.save(&adapter, Uuid::new_v4()).await;

        assert!(
            matches!(result, Err(RuntimeError::SessionNotFound(_))),
            "expected SessionNotFound, got {result:?}"
        );
        // Nothing should have reached the adapter.
        assert!(SessionStore::list(&adapter).await?.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn load_missing_session_returns_none() -> TestResult {
        let adapter = MockAdapter::default();

        let loaded = SessionStore::load(&adapter, Uuid::new_v4()).await?;

        assert!(loaded.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn list_returns_saved_ids() -> TestResult {
        let adapter = MockAdapter::default();
        let store = SessionStore::new();
        let first = store.create();
        let second = store.create_with_config("mock/model", 2, None);
        store.save(&adapter, first).await?;
        store.save(&adapter, second).await?;

        // The adapter makes no ordering promise, so compare sorted.
        let mut listed = SessionStore::list(&adapter).await?;
        listed.sort_unstable();
        let mut expected = vec![first, second];
        expected.sort_unstable();

        assert_eq!(listed, expected);
        Ok(())
    }

    #[tokio::test]
    async fn list_rejects_invalid_persisted_id() {
        let adapter = MockAdapter::default();
        adapter.put_raw("not-a-uuid", json!({})).await;

        let result = SessionStore::list(&adapter).await;

        assert!(
            matches!(result, Err(RuntimeError::Persistence(ref message)) if message.contains("invalid persisted session id")),
            "expected invalid id error, got {result:?}"
        );
    }

    #[tokio::test]
    async fn malformed_envelope_errors() {
        let adapter = MockAdapter::default();
        let id = Uuid::new_v4();
        // Correct version, but every other required field is absent.
        adapter
            .put(id, json!({ "schema_version": SCHEMA_VERSION }))
            .await;

        let result = SessionStore::load(&adapter, id).await;

        assert!(
            matches!(result, Err(RuntimeError::Persistence(ref message)) if message.contains("failed to deserialize session")),
            "expected deserialization error, got {result:?}"
        );
    }

    #[tokio::test]
    async fn schema_version_mismatch_errors() -> TestResult {
        let adapter = MockAdapter::default();
        let id = Uuid::new_v4();
        adapter
            .put(
                id,
                json!({
                    "schema_version": SCHEMA_VERSION + 1,
                    "model": "mock/model",
                    "max_turns": 4,
                    "system_message": null,
                    "messages": [],
                    "metadata": {}
                }),
            )
            .await;

        let result = SessionStore::load(&adapter, id).await;

        assert!(
            matches!(result, Err(RuntimeError::Persistence(ref message)) if message.contains("unsupported session schema version")),
            "expected schema version error, got {result:?}"
        );
        Ok(())
    }
}