1use std::collections::HashMap;
8use std::sync::Arc;
9
10use parking_lot::RwLock;
11use serde_json::Value;
12use uuid::Uuid;
13
14use fluers_core::AgentMessage;
15
16use crate::error::{RuntimeError, RuntimeResult};
17use crate::persistence::PersistenceAdapter;
18
/// Schema version stamped into every [`SessionState`] produced by this module.
///
/// Persisted payloads whose `schema_version` differs from this value are
/// rejected with a persistence error when they are loaded.
pub const SCHEMA_VERSION: u32 = 1;
21
/// Identifier of a session; an alias for [`Uuid`].
pub type SessionId = Uuid;
24
/// Serializable snapshot of a [`Session`], without its id.
///
/// This is the shape converted to and from a JSON [`Value`] when a session is
/// saved or loaded; the session id is passed to the adapter separately as the
/// storage key.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SessionState {
    /// Schema version of this snapshot; compared against [`SCHEMA_VERSION`] on load.
    pub schema_version: u32,
    /// Model name recorded for the session (may be empty).
    pub model: String,
    /// Presumably the maximum number of turns; not interpreted in this module.
    pub max_turns: usize,
    /// Optional system message recorded for the session.
    pub system_message: Option<String>,
    /// Messages recorded for the session, in the order they were appended.
    pub messages: Vec<AgentMessage>,
    /// String key/value metadata attached to the session.
    pub metadata: HashMap<String, String>,
}
42
/// A single session as held in memory: its id plus configuration, message
/// history, and metadata.
#[derive(Debug, Clone)]
pub struct Session {
    /// Unique id of this session.
    pub id: SessionId,
    /// Model name recorded for the session (may be empty).
    pub model: String,
    /// Presumably the maximum number of turns; not interpreted in this module.
    pub max_turns: usize,
    /// Optional system message recorded for the session.
    pub system_message: Option<String>,
    /// Messages appended so far, oldest first.
    pub messages: Vec<AgentMessage>,
    /// String key/value metadata attached to the session.
    pub metadata: HashMap<String, String>,
}
59
60impl Session {
61 fn to_state(&self) -> SessionState {
62 SessionState {
63 schema_version: SCHEMA_VERSION,
64 model: self.model.clone(),
65 max_turns: self.max_turns,
66 system_message: self.system_message.clone(),
67 messages: self.messages.clone(),
68 metadata: self.metadata.clone(),
69 }
70 }
71
72 fn from_state(id: SessionId, state: SessionState) -> Self {
73 Self {
74 id,
75 model: state.model,
76 max_turns: state.max_turns,
77 system_message: state.system_message,
78 messages: state.messages,
79 metadata: state.metadata,
80 }
81 }
82}
83
/// In-memory collection of [`Session`]s keyed by [`SessionId`].
///
/// The map lives behind an [`RwLock`], so all store methods take `&self`.
#[derive(Default)]
pub struct SessionStore {
    /// Sessions currently held by the store, guarded by a read/write lock.
    inner: RwLock<HashMap<SessionId, Session>>,
}
92
93impl SessionStore {
94 #[must_use]
96 pub fn new() -> Self {
97 Self::default()
98 }
99
100 pub fn create(&self) -> SessionId {
103 self.create_with_config(String::new(), 0, None)
104 }
105
106 pub fn create_with_config(
108 &self,
109 model: impl Into<String>,
110 max_turns: usize,
111 system_message: Option<String>,
112 ) -> SessionId {
113 let id = Uuid::new_v4();
114 let session = Session {
115 id,
116 model: model.into(),
117 max_turns,
118 system_message,
119 messages: Vec::new(),
120 metadata: HashMap::new(),
121 };
122 self.inner.write().insert(id, session);
123 id
124 }
125
126 pub fn append(&self, id: SessionId, message: AgentMessage) -> RuntimeResult<()> {
128 let mut guard = self.inner.write();
129 let session = guard
130 .get_mut(&id)
131 .ok_or_else(|| RuntimeError::SessionNotFound(id.to_string()))?;
132 session.messages.push(message);
133 Ok(())
134 }
135
136 pub fn messages(&self, id: SessionId) -> RuntimeResult<Vec<AgentMessage>> {
138 let guard = self.inner.read();
139 guard
140 .get(&id)
141 .map(|s| s.messages.clone())
142 .ok_or_else(|| RuntimeError::SessionNotFound(id.to_string()))
143 }
144
145 pub async fn save(&self, adapter: &dyn PersistenceAdapter, id: SessionId) -> RuntimeResult<()> {
147 let session = {
148 let guard = self.inner.read();
149 guard
150 .get(&id)
151 .cloned()
152 .ok_or_else(|| RuntimeError::SessionNotFound(id.to_string()))?
153 };
154 let value = state_to_value(&session.to_state(), id)?;
155 adapter
156 .save_session(&id.to_string(), &value)
157 .await
158 .map_err(RuntimeError::from)
159 }
160
161 pub async fn load(
163 adapter: &dyn PersistenceAdapter,
164 id: SessionId,
165 ) -> RuntimeResult<Option<Session>> {
166 let Some(value) = adapter
167 .load_session(&id.to_string())
168 .await
169 .map_err(RuntimeError::from)?
170 else {
171 return Ok(None);
172 };
173 let state = value_to_state(value, id)?;
174 Ok(Some(Session::from_state(id, state)))
175 }
176
177 pub async fn list(adapter: &dyn PersistenceAdapter) -> RuntimeResult<Vec<SessionId>> {
179 adapter
180 .list_sessions()
181 .await
182 .map_err(RuntimeError::from)?
183 .into_iter()
184 .map(|raw| {
185 Uuid::parse_str(&raw).map_err(|err| {
186 RuntimeError::Persistence(format!(
187 "invalid persisted session id `{raw}`: {err}"
188 ))
189 })
190 })
191 .collect()
192 }
193}
194
195fn state_to_value(state: &SessionState, id: SessionId) -> RuntimeResult<Value> {
196 serde_json::to_value(state).map_err(|err| {
197 RuntimeError::Persistence(format!("failed to serialize session `{id}`: {err}"))
198 })
199}
200
201fn value_to_state(value: Value, id: SessionId) -> RuntimeResult<SessionState> {
202 let state: SessionState = serde_json::from_value(value).map_err(|err| {
203 RuntimeError::Persistence(format!("failed to deserialize session `{id}`: {err}"))
204 })?;
205 if state.schema_version != SCHEMA_VERSION {
206 return Err(RuntimeError::Persistence(format!(
207 "unsupported session schema version {} for `{id}` (expected {SCHEMA_VERSION})",
208 state.schema_version
209 )));
210 }
211 Ok(state)
212}
213
/// A [`SessionStore`] behind an [`Arc`], for handing the same store to several owners.
pub type SharedSessionStore = Arc<SessionStore>;
216
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use fluers_core::{ContentBlock, Role};
    use serde_json::{json, Value};
    use tokio::sync::Mutex;

    use crate::persistence::{PersistenceAdapter, Result as PersistenceResult};

    type TestResult<T = ()> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

    /// Adapter that keeps persisted payloads in an in-memory map.
    #[derive(Default)]
    struct MockAdapter {
        sessions: Mutex<HashMap<String, Value>>,
    }

    impl MockAdapter {
        /// Stores a raw payload under `id`, bypassing `SessionStore::save`.
        async fn put(&self, id: SessionId, value: Value) {
            let mut stored = self.sessions.lock().await;
            stored.insert(id.to_string(), value);
        }
    }

    #[async_trait]
    impl PersistenceAdapter for MockAdapter {
        async fn save_session(&self, id: &str, data: &Value) -> PersistenceResult<()> {
            let mut stored = self.sessions.lock().await;
            stored.insert(id.to_string(), data.clone());
            Ok(())
        }

        async fn load_session(&self, id: &str) -> PersistenceResult<Option<Value>> {
            let stored = self.sessions.lock().await;
            let found = stored.get(id).cloned();
            Ok(found)
        }

        async fn list_sessions(&self) -> PersistenceResult<Vec<String>> {
            let stored = self.sessions.lock().await;
            let ids: Vec<String> = stored.keys().cloned().collect();
            Ok(ids)
        }
    }

    /// Builds a message holding a single text block.
    fn text_message(role: Role, text: &str) -> AgentMessage {
        let block = ContentBlock::Text { text: text.into() };
        AgentMessage {
            role,
            content: vec![block],
        }
    }

    /// Returns the text of the first block of the first message, if that block is text.
    fn first_text(messages: &[AgentMessage]) -> Option<&str> {
        let block = messages.first()?.content.first()?;
        match block {
            ContentBlock::Text { text } => Some(text.as_str()),
            _ => None,
        }
    }

    #[test]
    fn session_state_roundtrips() -> TestResult {
        let mut metadata: HashMap<String, String> = HashMap::new();
        metadata.insert("tag".into(), "test".into());

        let original = SessionState {
            schema_version: SCHEMA_VERSION,
            model: "mock/model".into(),
            max_turns: 4,
            system_message: Some("be useful".into()),
            messages: vec![text_message(Role::User, "hello")],
            metadata,
        };

        let encoded = serde_json::to_value(&original)?;
        let decoded: SessionState = serde_json::from_value(encoded)?;

        assert_eq!(decoded.schema_version, SCHEMA_VERSION);
        assert_eq!(decoded.model, "mock/model");
        assert_eq!(decoded.max_turns, 4);
        assert_eq!(decoded.system_message.as_deref(), Some("be useful"));
        assert_eq!(decoded.messages.len(), 1);
        assert_eq!(first_text(&decoded.messages), Some("hello"));
        assert_eq!(
            decoded.metadata.get("tag").map(String::as_str),
            Some("test")
        );
        Ok(())
    }

    #[tokio::test]
    async fn session_save_then_load() -> TestResult {
        let adapter = MockAdapter::default();
        let store = SessionStore::new();
        let id = store.create_with_config("mock/model", 8, Some("system".into()));
        store.append(id, text_message(Role::User, "persist me"))?;

        store.save(&adapter, id).await?;
        let session = SessionStore::load(&adapter, id)
            .await?
            .ok_or_else(|| std::io::Error::other("session was not loaded"))?;

        assert_eq!(session.id, id);
        assert_eq!(session.model, "mock/model");
        assert_eq!(session.max_turns, 8);
        assert_eq!(session.system_message.as_deref(), Some("system"));
        assert_eq!(session.messages.len(), 1);
        assert_eq!(first_text(&session.messages), Some("persist me"));
        Ok(())
    }

    #[tokio::test]
    async fn schema_version_mismatch_errors() -> TestResult {
        let adapter = MockAdapter::default();
        let id = Uuid::new_v4();

        // Payload that is well-formed except for a schema version from the future.
        let payload = json!({
            "schema_version": SCHEMA_VERSION + 1,
            "model": "mock/model",
            "max_turns": 4,
            "system_message": null,
            "messages": [],
            "metadata": {}
        });
        adapter.put(id, payload).await;

        let result = SessionStore::load(&adapter, id).await;

        assert!(
            matches!(result, Err(RuntimeError::Persistence(ref message)) if message.contains("unsupported session schema version")),
            "expected schema version error, got {result:?}"
        );
        Ok(())
    }
}