1use async_trait::async_trait;
17use chrono::Utc;
18use everruns_core::error::{AgentLoopError, Result};
19use everruns_core::session_task::{
20 CreateSessionTask, NewTaskMessage, SessionTask, SessionTaskFilter, SessionTaskRegistry,
21 SessionTaskState, SessionTaskUpdate, TaskMessage, TaskMessageDirection, apply_task_update,
22 generate_task_message_id, new_session_task,
23};
24use everruns_core::typed_id::SessionId;
25use rusqlite::OptionalExtension;
26
27use crate::db::SqliteDb;
28use crate::error::LocalError;
29
30#[derive(Clone)]
32pub struct LocalSessionTaskRegistry {
33 db: SqliteDb,
34}
35
36impl LocalSessionTaskRegistry {
37 pub fn new(db: SqliteDb) -> Result<Self> {
39 db.with_conn(|conn| {
40 conn.execute_batch(
41 "CREATE TABLE IF NOT EXISTS local_tasks (
42 id TEXT PRIMARY KEY,
43 session_id TEXT NOT NULL,
44 kind TEXT NOT NULL,
45 state TEXT NOT NULL,
46 snapshot TEXT NOT NULL
47 );
48 CREATE INDEX IF NOT EXISTS idx_local_tasks_session
49 ON local_tasks(session_id);
50 CREATE TABLE IF NOT EXISTS local_task_messages (
51 seq INTEGER PRIMARY KEY AUTOINCREMENT,
52 id TEXT NOT NULL UNIQUE,
53 task_id TEXT NOT NULL,
54 snapshot TEXT NOT NULL,
55 FOREIGN KEY(task_id) REFERENCES local_tasks(id)
56 );
57 CREATE INDEX IF NOT EXISTS idx_local_task_messages_task
58 ON local_task_messages(task_id, seq);",
59 )
60 })
61 .map_err(AgentLoopError::from)?;
62 Ok(Self { db })
63 }
64
65 fn load_task(&self, task_id: &str) -> Result<Option<SessionTask>> {
66 let snapshot: Option<String> = self
67 .db
68 .with_conn(|conn| {
69 conn.query_row(
70 "SELECT snapshot FROM local_tasks WHERE id = ?1",
71 [task_id],
72 |row| row.get(0),
73 )
74 .optional()
75 })
76 .map_err(AgentLoopError::from)?;
77 match snapshot {
78 Some(json) => Ok(Some(
79 serde_json::from_str(&json)
80 .map_err(|e| AgentLoopError::from(LocalError::from(e)))?,
81 )),
82 None => Ok(None),
83 }
84 }
85
86 fn store_task(&self, task: &SessionTask) -> Result<()> {
87 let snapshot =
88 serde_json::to_string(task).map_err(|e| AgentLoopError::from(LocalError::from(e)))?;
89 let id = task.id.clone();
90 let session_id = task.session_id.to_string();
91 let kind = task.kind.clone();
92 let state = task.state.to_string();
93 self.db
94 .with_conn(|conn| {
95 conn.execute(
96 "INSERT INTO local_tasks (id, session_id, kind, state, snapshot)
97 VALUES (?1, ?2, ?3, ?4, ?5)
98 ON CONFLICT(id) DO UPDATE SET
99 session_id = excluded.session_id,
100 kind = excluded.kind,
101 state = excluded.state,
102 snapshot = excluded.snapshot",
103 rusqlite::params![id, session_id, kind, state, snapshot],
104 )
105 })
106 .map_err(AgentLoopError::from)?;
107 Ok(())
108 }
109}
110
111#[async_trait]
112impl SessionTaskRegistry for LocalSessionTaskRegistry {
113 async fn create(&self, input: CreateSessionTask) -> Result<SessionTask> {
114 if let Some(id) = &input.id
118 && let Some(existing) = self.load_task(id)?
119 {
120 if existing.session_id == input.session_id {
121 return Ok(existing);
122 }
123 return Err(AgentLoopError::store(format!(
124 "task id {id} already exists under a different session"
125 )));
126 }
127 let task = new_session_task(input, Utc::now());
128 self.store_task(&task)?;
129 Ok(task)
130 }
131
132 async fn update(
133 &self,
134 session_id: SessionId,
135 task_id: &str,
136 update: SessionTaskUpdate,
137 ) -> Result<Option<SessionTask>> {
138 let Some(mut task) = self.load_task(task_id)? else {
139 return Ok(None);
140 };
141 if task.session_id != session_id {
143 return Ok(None);
144 }
145 apply_task_update(&mut task, update, Utc::now());
146 self.store_task(&task)?;
147 Ok(Some(task))
148 }
149
150 async fn get(&self, session_id: SessionId, task_id: &str) -> Result<Option<SessionTask>> {
151 Ok(self
153 .load_task(task_id)?
154 .filter(|task| task.session_id == session_id))
155 }
156
157 async fn list(
158 &self,
159 session_id: SessionId,
160 filter: Option<&SessionTaskFilter>,
161 ) -> Result<Vec<SessionTask>> {
162 let session = session_id.to_string();
163 let kind = filter.and_then(|f| f.kind.clone());
164 let state = filter.and_then(|f| f.state.map(|s| s.to_string()));
165 let snapshots: Vec<String> = self
166 .db
167 .with_conn(|conn| {
168 let mut sql =
171 String::from("SELECT snapshot FROM local_tasks WHERE session_id = ?1");
172 if kind.is_some() {
173 sql.push_str(" AND kind = ?2");
174 }
175 if state.is_some() {
176 sql.push_str(if kind.is_some() {
180 " AND state = ?3"
181 } else {
182 " AND state = ?2"
183 });
184 }
185 sql.push_str(" ORDER BY rowid ASC");
186
187 let mut stmt = conn.prepare(&sql)?;
188 let rows = match (&kind, &state) {
189 (Some(k), Some(s)) => stmt
190 .query_map(rusqlite::params![session, k, s], |row| row.get(0))?
191 .collect::<rusqlite::Result<Vec<String>>>()?,
192 (Some(k), None) => stmt
193 .query_map(rusqlite::params![session, k], |row| row.get(0))?
194 .collect::<rusqlite::Result<Vec<String>>>()?,
195 (None, Some(s)) => stmt
196 .query_map(rusqlite::params![session, s], |row| row.get(0))?
197 .collect::<rusqlite::Result<Vec<String>>>()?,
198 (None, None) => stmt
199 .query_map(rusqlite::params![session], |row| row.get(0))?
200 .collect::<rusqlite::Result<Vec<String>>>()?,
201 };
202 Ok(rows)
203 })
204 .map_err(AgentLoopError::from)?;
205 snapshots
206 .into_iter()
207 .map(|json| {
208 serde_json::from_str(&json).map_err(|e| AgentLoopError::from(LocalError::from(e)))
209 })
210 .collect()
211 }
212
213 async fn request_cancel(
214 &self,
215 session_id: SessionId,
216 task_id: &str,
217 ) -> Result<Option<SessionTask>> {
218 let Some(mut task) = self.load_task(task_id)? else {
219 return Ok(None);
220 };
221 if task.session_id != session_id {
223 return Ok(None);
224 }
225 task.cancel_requested_at.get_or_insert_with(Utc::now);
227 task.updated_at = Utc::now();
228 self.store_task(&task)?;
229 Ok(Some(task))
230 }
231
232 async fn record_message(
233 &self,
234 session_id: SessionId,
235 task_id: &str,
236 message: NewTaskMessage,
237 ) -> Result<TaskMessage> {
238 let mut task = self
241 .get(session_id, task_id)
242 .await?
243 .ok_or_else(|| AgentLoopError::tool(format!("no task {task_id}")))?;
244 if let Some(expected) = message.expected_attempt
247 && expected != task.attempt
248 {
249 return Err(AgentLoopError::tool(format!(
250 "stale attempt for task {task_id}: expected {expected}, current {}",
251 task.attempt
252 )));
253 }
254
255 let record = TaskMessage {
256 id: generate_task_message_id(),
257 task_id: task_id.to_string(),
258 direction: message.direction,
259 content: message.content,
260 in_reply_to: message.in_reply_to.clone(),
261 created_at: Utc::now(),
262 };
263 let snapshot = serde_json::to_string(&record)
264 .map_err(|e| AgentLoopError::from(LocalError::from(e)))?;
265 let id = record.id.clone();
266 let tid = task_id.to_string();
267 self.db
268 .with_conn(|conn| {
269 conn.execute(
270 "INSERT INTO local_task_messages (id, task_id, snapshot)
271 VALUES (?1, ?2, ?3)",
272 rusqlite::params![id, tid, snapshot],
273 )
274 })
275 .map_err(AgentLoopError::from)?;
276
277 if message.direction == TaskMessageDirection::Inbound
281 && let Some(reply_id) = &message.in_reply_to
282 && task
283 .input_request
284 .as_ref()
285 .is_some_and(|req| &req.id == reply_id)
286 {
287 apply_task_update(
288 &mut task,
289 SessionTaskUpdate {
290 state: Some(SessionTaskState::Running),
291 ..Default::default()
292 },
293 Utc::now(),
294 );
295 self.store_task(&task)?;
296 }
297
298 Ok(record)
299 }
300
301 async fn list_messages(
302 &self,
303 session_id: SessionId,
304 task_id: &str,
305 limit: Option<u32>,
306 after_id: Option<&str>,
307 ) -> Result<Vec<TaskMessage>> {
308 if self.get(session_id, task_id).await?.is_none() {
311 return Ok(Vec::new());
312 }
313 let tid = task_id.to_string();
314 let after = after_id.map(|s| s.to_string());
315 let limit = limit.map(|l| l as i64);
316 let snapshots: Vec<String> = self
317 .db
318 .with_conn(|conn| {
319 let after_seq: Option<i64> = match &after {
321 Some(id) => conn
322 .query_row(
323 "SELECT seq FROM local_task_messages WHERE id = ?1",
324 [id],
325 |row| row.get(0),
326 )
327 .optional()?,
328 None => None,
329 };
330 let mut sql =
331 String::from("SELECT snapshot FROM local_task_messages WHERE task_id = ?1");
332 if after_seq.is_some() {
333 sql.push_str(" AND seq > ?2");
334 }
335 sql.push_str(" ORDER BY seq ASC");
336 if limit.is_some() {
337 sql.push_str(if after_seq.is_some() {
338 " LIMIT ?3"
339 } else {
340 " LIMIT ?2"
341 });
342 }
343 let mut stmt = conn.prepare(&sql)?;
344 let rows = match (after_seq, limit) {
345 (Some(seq), Some(lim)) => stmt
346 .query_map(rusqlite::params![tid, seq, lim], |row| row.get(0))?
347 .collect::<rusqlite::Result<Vec<String>>>()?,
348 (Some(seq), None) => stmt
349 .query_map(rusqlite::params![tid, seq], |row| row.get(0))?
350 .collect::<rusqlite::Result<Vec<String>>>()?,
351 (None, Some(lim)) => stmt
352 .query_map(rusqlite::params![tid, lim], |row| row.get(0))?
353 .collect::<rusqlite::Result<Vec<String>>>()?,
354 (None, None) => stmt
355 .query_map(rusqlite::params![tid], |row| row.get(0))?
356 .collect::<rusqlite::Result<Vec<String>>>()?,
357 };
358 Ok(rows)
359 })
360 .map_err(AgentLoopError::from)?;
361 snapshots
362 .into_iter()
363 .map(|json| {
364 serde_json::from_str(&json).map_err(|e| AgentLoopError::from(LocalError::from(e)))
365 })
366 .collect()
367 }
368}