Skip to main content

everruns_local/
task_registry.rs

1// SQLite-backed SessionTaskRegistry.
2//
3// Persists tasks and their message channel so a freshly-spawned process can
4// reopen the database file and read / continue / inspect tasks
5// (restart-survivability). Lifecycle invariants are NOT reimplemented here:
6// every update routes through `everruns_core::session_task::apply_task_update`,
7// matching the postgres / in-memory backends.
8//
9// Storage shape: one `local_tasks` row per task. The full `SessionTask` is
10// stored as a JSON snapshot for faithful round-tripping; `session_id`, `kind`,
11// and `state` are also stored as plain columns so `SessionTaskFilter` queries
12// hit an index instead of deserializing every row. Messages live in
13// `local_task_messages`, ordered by an autoincrement `seq` to give a stable
14// oldest-first order and a cheap `after_id` cursor.
15
16use async_trait::async_trait;
17use chrono::Utc;
18use everruns_core::error::{AgentLoopError, Result};
19use everruns_core::session_task::{
20    CreateSessionTask, NewTaskMessage, SessionTask, SessionTaskFilter, SessionTaskRegistry,
21    SessionTaskState, SessionTaskUpdate, TaskMessage, TaskMessageDirection, apply_task_update,
22    generate_task_message_id, new_session_task,
23};
24use everruns_core::typed_id::SessionId;
25use rusqlite::OptionalExtension;
26
27use crate::db::SqliteDb;
28use crate::error::LocalError;
29
/// SQLite-backed task registry for local embedded hosts.
///
/// Holds nothing but a [`SqliteDb`] handle; every read and write in this file
/// goes through that handle's `with_conn`. The type derives `Clone`, so
/// cloning a registry clones the handle.
// NOTE(review): presumably a cloned `SqliteDb` refers to the same underlying
// database — confirm against `crate::db`.
#[derive(Clone)]
pub struct LocalSessionTaskRegistry {
    /// Database handle used for all task and message queries.
    db: SqliteDb,
}
35
36impl LocalSessionTaskRegistry {
37    /// Open (and migrate) a registry over the given database handle.
38    pub fn new(db: SqliteDb) -> Result<Self> {
39        db.with_conn(|conn| {
40            conn.execute_batch(
41                "CREATE TABLE IF NOT EXISTS local_tasks (
42                    id          TEXT PRIMARY KEY,
43                    session_id  TEXT NOT NULL,
44                    kind        TEXT NOT NULL,
45                    state       TEXT NOT NULL,
46                    snapshot    TEXT NOT NULL
47                 );
48                 CREATE INDEX IF NOT EXISTS idx_local_tasks_session
49                    ON local_tasks(session_id);
50                 CREATE TABLE IF NOT EXISTS local_task_messages (
51                    seq         INTEGER PRIMARY KEY AUTOINCREMENT,
52                    id          TEXT NOT NULL UNIQUE,
53                    task_id     TEXT NOT NULL,
54                    snapshot    TEXT NOT NULL,
55                    FOREIGN KEY(task_id) REFERENCES local_tasks(id)
56                 );
57                 CREATE INDEX IF NOT EXISTS idx_local_task_messages_task
58                    ON local_task_messages(task_id, seq);",
59            )
60        })
61        .map_err(AgentLoopError::from)?;
62        Ok(Self { db })
63    }
64
65    fn load_task(&self, task_id: &str) -> Result<Option<SessionTask>> {
66        let snapshot: Option<String> = self
67            .db
68            .with_conn(|conn| {
69                conn.query_row(
70                    "SELECT snapshot FROM local_tasks WHERE id = ?1",
71                    [task_id],
72                    |row| row.get(0),
73                )
74                .optional()
75            })
76            .map_err(AgentLoopError::from)?;
77        match snapshot {
78            Some(json) => Ok(Some(
79                serde_json::from_str(&json)
80                    .map_err(|e| AgentLoopError::from(LocalError::from(e)))?,
81            )),
82            None => Ok(None),
83        }
84    }
85
86    fn store_task(&self, task: &SessionTask) -> Result<()> {
87        let snapshot =
88            serde_json::to_string(task).map_err(|e| AgentLoopError::from(LocalError::from(e)))?;
89        let id = task.id.clone();
90        let session_id = task.session_id.to_string();
91        let kind = task.kind.clone();
92        let state = task.state.to_string();
93        self.db
94            .with_conn(|conn| {
95                conn.execute(
96                    "INSERT INTO local_tasks (id, session_id, kind, state, snapshot)
97                     VALUES (?1, ?2, ?3, ?4, ?5)
98                     ON CONFLICT(id) DO UPDATE SET
99                        session_id = excluded.session_id,
100                        kind = excluded.kind,
101                        state = excluded.state,
102                        snapshot = excluded.snapshot",
103                    rusqlite::params![id, session_id, kind, state, snapshot],
104                )
105            })
106            .map_err(AgentLoopError::from)?;
107        Ok(())
108    }
109}
110
111#[async_trait]
112impl SessionTaskRegistry for LocalSessionTaskRegistry {
113    async fn create(&self, input: CreateSessionTask) -> Result<SessionTask> {
114        // Idempotent on a caller-supplied id, but only within the same session.
115        // Reusing an id across sessions is rejected, matching the canonical
116        // DB-backed registry, so a caller cannot alias another session's task.
117        if let Some(id) = &input.id
118            && let Some(existing) = self.load_task(id)?
119        {
120            if existing.session_id == input.session_id {
121                return Ok(existing);
122            }
123            return Err(AgentLoopError::store(format!(
124                "task id {id} already exists under a different session"
125            )));
126        }
127        let task = new_session_task(input, Utc::now());
128        self.store_task(&task)?;
129        Ok(task)
130    }
131
132    async fn update(
133        &self,
134        session_id: SessionId,
135        task_id: &str,
136        update: SessionTaskUpdate,
137    ) -> Result<Option<SessionTask>> {
138        let Some(mut task) = self.load_task(task_id)? else {
139            return Ok(None);
140        };
141        // Session-scoped: ignore updates targeting a task in another session.
142        if task.session_id != session_id {
143            return Ok(None);
144        }
145        apply_task_update(&mut task, update, Utc::now());
146        self.store_task(&task)?;
147        Ok(Some(task))
148    }
149
150    async fn get(&self, session_id: SessionId, task_id: &str) -> Result<Option<SessionTask>> {
151        // Session-scoped: a task id from another session is not visible here.
152        Ok(self
153            .load_task(task_id)?
154            .filter(|task| task.session_id == session_id))
155    }
156
157    async fn list(
158        &self,
159        session_id: SessionId,
160        filter: Option<&SessionTaskFilter>,
161    ) -> Result<Vec<SessionTask>> {
162        let session = session_id.to_string();
163        let kind = filter.and_then(|f| f.kind.clone());
164        let state = filter.and_then(|f| f.state.map(|s| s.to_string()));
165        let snapshots: Vec<String> = self
166            .db
167            .with_conn(|conn| {
168                // Build the query with optional kind/state predicates. Bind
169                // params positionally to keep the prepared statement simple.
170                let mut sql =
171                    String::from("SELECT snapshot FROM local_tasks WHERE session_id = ?1");
172                if kind.is_some() {
173                    sql.push_str(" AND kind = ?2");
174                }
175                if state.is_some() {
176                    // ?3 if kind present, else ?2 — rusqlite positional binding
177                    // tolerates gaps, so always use ?3 and bind kind as NULL
178                    // when absent is not possible; instead branch explicitly.
179                    sql.push_str(if kind.is_some() {
180                        " AND state = ?3"
181                    } else {
182                        " AND state = ?2"
183                    });
184                }
185                sql.push_str(" ORDER BY rowid ASC");
186
187                let mut stmt = conn.prepare(&sql)?;
188                let rows = match (&kind, &state) {
189                    (Some(k), Some(s)) => stmt
190                        .query_map(rusqlite::params![session, k, s], |row| row.get(0))?
191                        .collect::<rusqlite::Result<Vec<String>>>()?,
192                    (Some(k), None) => stmt
193                        .query_map(rusqlite::params![session, k], |row| row.get(0))?
194                        .collect::<rusqlite::Result<Vec<String>>>()?,
195                    (None, Some(s)) => stmt
196                        .query_map(rusqlite::params![session, s], |row| row.get(0))?
197                        .collect::<rusqlite::Result<Vec<String>>>()?,
198                    (None, None) => stmt
199                        .query_map(rusqlite::params![session], |row| row.get(0))?
200                        .collect::<rusqlite::Result<Vec<String>>>()?,
201                };
202                Ok(rows)
203            })
204            .map_err(AgentLoopError::from)?;
205        snapshots
206            .into_iter()
207            .map(|json| {
208                serde_json::from_str(&json).map_err(|e| AgentLoopError::from(LocalError::from(e)))
209            })
210            .collect()
211    }
212
213    async fn request_cancel(
214        &self,
215        session_id: SessionId,
216        task_id: &str,
217    ) -> Result<Option<SessionTask>> {
218        let Some(mut task) = self.load_task(task_id)? else {
219            return Ok(None);
220        };
221        // Session-scoped: do not record cancel intent on another session's task.
222        if task.session_id != session_id {
223            return Ok(None);
224        }
225        // Cooperative cancel: record intent, do not change state. Idempotent.
226        task.cancel_requested_at.get_or_insert_with(Utc::now);
227        task.updated_at = Utc::now();
228        self.store_task(&task)?;
229        Ok(Some(task))
230    }
231
232    async fn record_message(
233        &self,
234        session_id: SessionId,
235        task_id: &str,
236        message: NewTaskMessage,
237    ) -> Result<TaskMessage> {
238        // Session-scoped: a message may only be appended to a task that belongs
239        // to the calling session (mirrors the DB-backed registry).
240        let mut task = self
241            .get(session_id, task_id)
242            .await?
243            .ok_or_else(|| AgentLoopError::tool(format!("no task {task_id}")))?;
244        // Stale-attempt fence: reject writes from a superseded executor so the
245        // thread cannot grow under a zombie. Mirrors the postgres backend.
246        if let Some(expected) = message.expected_attempt
247            && expected != task.attempt
248        {
249            return Err(AgentLoopError::tool(format!(
250                "stale attempt for task {task_id}: expected {expected}, current {}",
251                task.attempt
252            )));
253        }
254
255        let record = TaskMessage {
256            id: generate_task_message_id(),
257            task_id: task_id.to_string(),
258            direction: message.direction,
259            content: message.content,
260            in_reply_to: message.in_reply_to.clone(),
261            created_at: Utc::now(),
262        };
263        let snapshot = serde_json::to_string(&record)
264            .map_err(|e| AgentLoopError::from(LocalError::from(e)))?;
265        let id = record.id.clone();
266        let tid = task_id.to_string();
267        self.db
268            .with_conn(|conn| {
269                conn.execute(
270                    "INSERT INTO local_task_messages (id, task_id, snapshot)
271                     VALUES (?1, ?2, ?3)",
272                    rusqlite::params![id, tid, snapshot],
273                )
274            })
275            .map_err(AgentLoopError::from)?;
276
277        // An inbound answer (in_reply_to set) clears a matching pending input
278        // request and returns the task to running. Only inbound messages resume
279        // the task, matching the DB-backed registry; outbound messages never do.
280        if message.direction == TaskMessageDirection::Inbound
281            && let Some(reply_id) = &message.in_reply_to
282            && task
283                .input_request
284                .as_ref()
285                .is_some_and(|req| &req.id == reply_id)
286        {
287            apply_task_update(
288                &mut task,
289                SessionTaskUpdate {
290                    state: Some(SessionTaskState::Running),
291                    ..Default::default()
292                },
293                Utc::now(),
294            );
295            self.store_task(&task)?;
296        }
297
298        Ok(record)
299    }
300
301    async fn list_messages(
302        &self,
303        session_id: SessionId,
304        task_id: &str,
305        limit: Option<u32>,
306        after_id: Option<&str>,
307    ) -> Result<Vec<TaskMessage>> {
308        // Session-scoped: do not leak another session's message history even
309        // when the task id is known. Missing/foreign task -> empty list.
310        if self.get(session_id, task_id).await?.is_none() {
311            return Ok(Vec::new());
312        }
313        let tid = task_id.to_string();
314        let after = after_id.map(|s| s.to_string());
315        let limit = limit.map(|l| l as i64);
316        let snapshots: Vec<String> = self
317            .db
318            .with_conn(|conn| {
319                // Resolve the exclusive cursor seq, if any.
320                let after_seq: Option<i64> = match &after {
321                    Some(id) => conn
322                        .query_row(
323                            "SELECT seq FROM local_task_messages WHERE id = ?1",
324                            [id],
325                            |row| row.get(0),
326                        )
327                        .optional()?,
328                    None => None,
329                };
330                let mut sql =
331                    String::from("SELECT snapshot FROM local_task_messages WHERE task_id = ?1");
332                if after_seq.is_some() {
333                    sql.push_str(" AND seq > ?2");
334                }
335                sql.push_str(" ORDER BY seq ASC");
336                if limit.is_some() {
337                    sql.push_str(if after_seq.is_some() {
338                        " LIMIT ?3"
339                    } else {
340                        " LIMIT ?2"
341                    });
342                }
343                let mut stmt = conn.prepare(&sql)?;
344                let rows = match (after_seq, limit) {
345                    (Some(seq), Some(lim)) => stmt
346                        .query_map(rusqlite::params![tid, seq, lim], |row| row.get(0))?
347                        .collect::<rusqlite::Result<Vec<String>>>()?,
348                    (Some(seq), None) => stmt
349                        .query_map(rusqlite::params![tid, seq], |row| row.get(0))?
350                        .collect::<rusqlite::Result<Vec<String>>>()?,
351                    (None, Some(lim)) => stmt
352                        .query_map(rusqlite::params![tid, lim], |row| row.get(0))?
353                        .collect::<rusqlite::Result<Vec<String>>>()?,
354                    (None, None) => stmt
355                        .query_map(rusqlite::params![tid], |row| row.get(0))?
356                        .collect::<rusqlite::Result<Vec<String>>>()?,
357                };
358                Ok(rows)
359            })
360            .map_err(AgentLoopError::from)?;
361        snapshots
362            .into_iter()
363            .map(|json| {
364                serde_json::from_str(&json).map_err(|e| AgentLoopError::from(LocalError::from(e)))
365            })
366            .collect()
367    }
368}