1use std::collections::HashMap;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use fluers_core::{AgentMessage, CoreError, Result as CoreResult, TurnSink};
11use parking_lot::RwLock;
12
13use crate::error::RuntimeResult;
14use crate::persistence::PersistenceAdapter;
15use crate::session::{Session, SessionId, SessionState, SessionStore, SCHEMA_VERSION};
16
17pub struct SessionRunner {
20 adapter: Arc<dyn PersistenceAdapter>,
21 session_id: SessionId,
22 model: String,
23 max_turns: usize,
24 system_message: Option<String>,
25 messages: Arc<RwLock<Vec<AgentMessage>>>,
26 metadata: Arc<RwLock<HashMap<String, String>>>,
27}
28
29impl SessionRunner {
30 #[must_use]
32 pub fn new(
33 adapter: Arc<dyn PersistenceAdapter>,
34 session_id: SessionId,
35 model: impl Into<String>,
36 max_turns: usize,
37 system_message: Option<String>,
38 ) -> Self {
39 Self {
40 adapter,
41 session_id,
42 model: model.into(),
43 max_turns,
44 system_message,
45 messages: Arc::new(RwLock::new(Vec::new())),
46 metadata: Arc::new(RwLock::new(HashMap::new())),
47 }
48 }
49
50 pub async fn load(
52 adapter: Arc<dyn PersistenceAdapter>,
53 session_id: SessionId,
54 ) -> RuntimeResult<Option<Self>> {
55 let Some(session) = SessionStore::load(adapter.as_ref(), session_id).await? else {
56 return Ok(None);
57 };
58 Ok(Some(Self::from_session(adapter, session)))
59 }
60
61 #[must_use]
63 pub fn messages(&self) -> Vec<AgentMessage> {
64 self.messages.read().clone()
65 }
66
67 #[must_use]
69 pub fn model_id(&self) -> &str {
70 &self.model
71 }
72
73 #[must_use]
75 pub fn max_turns(&self) -> usize {
76 self.max_turns
77 }
78
79 #[must_use]
81 pub fn system_message(&self) -> Option<String> {
82 self.system_message.clone()
83 }
84
85 fn from_session(adapter: Arc<dyn PersistenceAdapter>, session: Session) -> Self {
86 Self {
87 adapter,
88 session_id: session.id,
89 model: session.model,
90 max_turns: session.max_turns,
91 system_message: session.system_message,
92 messages: Arc::new(RwLock::new(session.messages)),
93 metadata: Arc::new(RwLock::new(session.metadata)),
94 }
95 }
96
97 fn state(&self, messages: Vec<AgentMessage>) -> SessionState {
98 SessionState {
99 schema_version: SCHEMA_VERSION,
100 model: self.model.clone(),
101 max_turns: self.max_turns,
102 system_message: self.system_message.clone(),
103 messages,
104 metadata: self.metadata.read().clone(),
105 }
106 }
107}
108
109#[async_trait]
110impl TurnSink for SessionRunner {
111 async fn after_turn(&self, _turn: usize, messages: &[AgentMessage]) -> CoreResult<()> {
112 let snapshot = messages.to_vec();
113 {
114 let mut current = self.messages.write();
115 *current = snapshot.clone();
116 }
117 let state = self.state(snapshot);
118 let value = serde_json::to_value(&state).map_err(|err| {
119 CoreError::Transport(format!(
120 "failed to serialize session `{}`: {err}",
121 self.session_id
122 ))
123 })?;
124 self.adapter
125 .save_session(&self.session_id.to_string(), &value)
126 .await
127 .map_err(|err| {
128 CoreError::Transport(format!(
129 "failed to save session `{}`: {err}",
130 self.session_id
131 ))
132 })?;
133 Ok(())
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use fluers_core::{ContentBlock, Role};
    use serde_json::Value;
    use tokio::sync::Mutex;
    use uuid::Uuid;

    use crate::persistence::{PersistenceAdapter, Result as PersistenceResult};

    /// Boxed-error result so tests can `?` over heterogeneous error types.
    type TestResult<T = ()> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

    /// In-memory [`PersistenceAdapter`] that stores session JSON by id.
    #[derive(Default)]
    struct MockAdapter {
        sessions: Mutex<HashMap<String, Value>>,
    }

    #[async_trait]
    impl PersistenceAdapter for MockAdapter {
        async fn save_session(&self, id: &str, data: &Value) -> PersistenceResult<()> {
            let mut store = self.sessions.lock().await;
            store.insert(id.to_owned(), data.clone());
            Ok(())
        }

        async fn load_session(&self, id: &str) -> PersistenceResult<Option<Value>> {
            let store = self.sessions.lock().await;
            let found = store.get(id).cloned();
            Ok(found)
        }

        async fn list_sessions(&self) -> PersistenceResult<Vec<String>> {
            let store = self.sessions.lock().await;
            let ids: Vec<String> = store.keys().cloned().collect();
            Ok(ids)
        }
    }

    /// Builds a message for `role` holding a single text block.
    fn text_message(role: Role, text: &str) -> AgentMessage {
        let block = ContentBlock::Text { text: text.into() };
        AgentMessage {
            role,
            content: vec![block],
        }
    }

    /// Returns the text of the first block of the first message, if that block is text.
    fn first_text(messages: &[AgentMessage]) -> Option<&str> {
        let block = messages.first()?.content.first()?;
        match block {
            ContentBlock::Text { text } => Some(text.as_str()),
            _ => None,
        }
    }

    #[tokio::test]
    async fn session_runner_persists_after_turn() -> TestResult {
        let adapter = Arc::new(MockAdapter::default());
        let session_id = Uuid::new_v4();
        let history = vec![text_message(Role::User, "hello")];
        let runner = SessionRunner::new(
            adapter.clone(),
            session_id,
            "mock/model",
            5,
            Some("be useful".into()),
        );

        TurnSink::after_turn(&runner, 1, &history).await?;

        let value = adapter
            .load_session(&session_id.to_string())
            .await?
            .ok_or_else(|| std::io::Error::other("session was not saved"))?;
        let state: SessionState = serde_json::from_value(value)?;

        assert_eq!(state.schema_version, SCHEMA_VERSION);
        assert_eq!(state.model, "mock/model");
        assert_eq!(state.max_turns, 5);
        assert_eq!(state.system_message.as_deref(), Some("be useful"));
        assert_eq!(state.messages.len(), 1);
        assert_eq!(first_text(&state.messages), Some("hello"));
        assert_eq!(runner.messages().len(), 1);
        Ok(())
    }
}