Skip to main content

agent_sdk_store_sqlite/
agent_pool.rs

1use std::{
2    collections::{BTreeMap, BTreeSet, VecDeque},
3    path::{Path, PathBuf},
4    sync::{Arc, Mutex},
5};
6
7use agent_sdk_core::{
8    AgentError, AgentErrorKind, AgentPoolId, AgentPoolMember, AgentPoolSnapshot, AgentPoolStore,
9    AgentPoolStoreConfig, AgentPoolStoreCursor, AgentPoolStoreRecord, AgentPoolStoreRecordPayload,
10    AgentPoolStoreStream, AgentPoolStoredMessage, AgentPoolStoredWake, CompiledEventFilter,
11    IdempotencyKey, MessageId, MessageReceipt, RetryClassification, RunId, RunMessage, TopicId,
12    WakeCondition, WakeConditionId, WakeRegistration,
13};
14use rusqlite::{Connection, params};
15
16/// SQLite-backed implementation of `AgentPoolStore`.
17///
18/// Two independent `SqliteAgentPoolStore` values opened against the same
19/// database file share pool membership, messages, wake registrations, dedupe,
20/// rehydration, and watch cursors. The adapter stores core pool records as JSON
21/// and replays them into snapshots; it does not own workflow scheduling or a
22/// second event bus.
23#[derive(Clone)]
24pub struct SqliteAgentPoolStore {
25    path: PathBuf,
26    connection: Arc<Mutex<Connection>>,
27}
28
29impl SqliteAgentPoolStore {
30    /// Opens or creates a SQLite-backed agent-pool store.
31    pub fn open(path: impl AsRef<Path>) -> Result<Self, AgentError> {
32        let path = path.as_ref().to_path_buf();
33        let connection = Connection::open(&path).map_err(sqlite_error)?;
34        connection
35            .execute_batch(
36                "PRAGMA journal_mode = WAL;
37                 CREATE TABLE IF NOT EXISTS agent_pool_records (
38                     pool_id TEXT NOT NULL,
39                     seq INTEGER NOT NULL,
40                     kind TEXT NOT NULL,
41                     payload_json TEXT NOT NULL,
42                     PRIMARY KEY (pool_id, seq)
43                 );
44                 CREATE TABLE IF NOT EXISTS agent_pool_event_seq (
45                     pool_id TEXT PRIMARY KEY,
46                     seq INTEGER NOT NULL
47                 );",
48            )
49            .map_err(sqlite_error)?;
50        Ok(Self {
51            path,
52            connection: Arc::new(Mutex::new(connection)),
53        })
54    }
55
56    /// Returns the backing database path.
57    pub fn path(&self) -> &Path {
58        &self.path
59    }
60
61    fn append_record(
62        &self,
63        pool_id: &AgentPoolId,
64        payload: AgentPoolStoreRecordPayload,
65    ) -> Result<AgentPoolStoreCursor, AgentError> {
66        let mut connection = self.connection()?;
67        let transaction = connection.transaction().map_err(sqlite_error)?;
68        let next_seq = transaction
69            .query_row(
70                "SELECT COALESCE(MAX(seq), 0) + 1 FROM agent_pool_records WHERE pool_id = ?1",
71                params![pool_id.as_str()],
72                |row| row.get::<_, i64>(0),
73            )
74            .map_err(sqlite_error)?;
75        let payload_json = serde_json::to_string(&payload).map_err(serde_error)?;
76        transaction
77            .execute(
78                "INSERT INTO agent_pool_records (pool_id, seq, kind, payload_json)
79                 VALUES (?1, ?2, ?3, ?4)",
80                params![
81                    pool_id.as_str(),
82                    next_seq,
83                    payload_kind(&payload),
84                    payload_json
85                ],
86            )
87            .map_err(sqlite_error)?;
88        transaction.commit().map_err(sqlite_error)?;
89        Ok(AgentPoolStoreCursor::new(next_seq as u64))
90    }
91
92    fn records_after(
93        &self,
94        pool_id: &AgentPoolId,
95        cursor: Option<AgentPoolStoreCursor>,
96    ) -> Result<Vec<AgentPoolStoreRecord>, AgentError> {
97        let start_after = cursor.map(|cursor| cursor.sequence).unwrap_or(0);
98        let connection = self.connection()?;
99        let mut statement = connection
100            .prepare(
101                "SELECT seq, payload_json
102                 FROM agent_pool_records
103                 WHERE pool_id = ?1 AND seq > ?2
104                 ORDER BY seq ASC",
105            )
106            .map_err(sqlite_error)?;
107        let rows = statement
108            .query_map(params![pool_id.as_str(), start_after as i64], |row| {
109                let seq: i64 = row.get(0)?;
110                let payload_json: String = row.get(1)?;
111                Ok((seq, payload_json))
112            })
113            .map_err(sqlite_error)?;
114
115        let mut records = Vec::new();
116        for row in rows {
117            let (seq, payload_json) = row.map_err(sqlite_error)?;
118            let payload = serde_json::from_str::<AgentPoolStoreRecordPayload>(&payload_json)
119                .map_err(serde_error)?;
120            records.push(AgentPoolStoreRecord {
121                pool_id: pool_id.clone(),
122                cursor: AgentPoolStoreCursor::new(seq as u64),
123                payload,
124            });
125        }
126        Ok(records)
127    }
128
129    fn replay(&self, pool_id: &AgentPoolId) -> Result<PoolReplay, AgentError> {
130        let mut replay = PoolReplay::default();
131        for record in self.records_after(pool_id, Some(AgentPoolStoreCursor::start()))? {
132            replay.cursor = Some(record.cursor.clone());
133            replay.apply(record.payload)?;
134        }
135        Ok(replay)
136    }
137
138    fn connection(&self) -> Result<std::sync::MutexGuard<'_, Connection>, AgentError> {
139        self.connection
140            .lock()
141            .map_err(|_| AgentError::contract_violation("sqlite agent pool store lock poisoned"))
142    }
143}
144
145impl AgentPoolStore for SqliteAgentPoolStore {
146    fn open_pool(
147        &self,
148        pool_id: AgentPoolId,
149        config: AgentPoolStoreConfig,
150    ) -> Result<AgentPoolSnapshot, AgentError> {
151        let replay = self.replay(&pool_id)?;
152        if let Some(existing) = replay.config.as_ref() {
153            if existing != &config {
154                return Err(AgentError::new(
155                    AgentErrorKind::InvalidStateTransition,
156                    RetryClassification::RepairNeeded,
157                    "sqlite agent pool store config conflicts with existing pool",
158                ));
159            }
160        } else {
161            self.append_record(&pool_id, AgentPoolStoreRecordPayload::PoolOpened { config })?;
162        }
163        self.snapshot(&pool_id)
164    }
165
166    fn snapshot(&self, pool_id: &AgentPoolId) -> Result<AgentPoolSnapshot, AgentError> {
167        self.replay(pool_id)?.snapshot(pool_id.clone())
168    }
169
170    fn record_pool_created(
171        &self,
172        pool_id: &AgentPoolId,
173    ) -> Result<AgentPoolStoreCursor, AgentError> {
174        self.append_record(pool_id, AgentPoolStoreRecordPayload::PoolCreated)
175    }
176
177    fn join_member(
178        &self,
179        pool_id: &AgentPoolId,
180        member: AgentPoolMember,
181    ) -> Result<AgentPoolStoreCursor, AgentError> {
182        self.snapshot(pool_id)?;
183        self.append_record(
184            pool_id,
185            AgentPoolStoreRecordPayload::MemberJoined { member },
186        )
187    }
188
189    fn leave_member(
190        &self,
191        pool_id: &AgentPoolId,
192        run_id: &RunId,
193    ) -> Result<(AgentPoolMember, AgentPoolStoreCursor), AgentError> {
194        let replay = self.replay(pool_id)?;
195        let member = replay.members.get(run_id).cloned().ok_or_else(|| {
196            AgentError::new(
197                AgentErrorKind::InvalidStateTransition,
198                RetryClassification::NotRetryable,
199                "run is not a member of this agent pool",
200            )
201        })?;
202        let cursor = self.append_record(
203            pool_id,
204            AgentPoolStoreRecordPayload::MemberLeft {
205                member: member.clone(),
206            },
207        )?;
208        Ok((member, cursor))
209    }
210
211    fn message_receipt(
212        &self,
213        pool_id: &AgentPoolId,
214        idempotency_key: &IdempotencyKey,
215    ) -> Result<Option<MessageReceipt>, AgentError> {
216        Ok(self
217            .replay(pool_id)?
218            .message_dedupe
219            .get(idempotency_key)
220            .cloned())
221    }
222
223    fn record_message(
224        &self,
225        pool_id: &AgentPoolId,
226        message: RunMessage,
227        receipt: MessageReceipt,
228    ) -> Result<AgentPoolStoreCursor, AgentError> {
229        self.snapshot(pool_id)?;
230        self.append_record(
231            pool_id,
232            AgentPoolStoreRecordPayload::RunMessage {
233                stored: AgentPoolStoredMessage { message, receipt },
234            },
235        )
236    }
237
238    fn wake_registration(
239        &self,
240        pool_id: &AgentPoolId,
241        idempotency_key: &IdempotencyKey,
242    ) -> Result<Option<WakeRegistration>, AgentError> {
243        Ok(self
244            .replay(pool_id)?
245            .wake_dedupe
246            .get(idempotency_key)
247            .cloned())
248    }
249
250    fn wake(
251        &self,
252        pool_id: &AgentPoolId,
253        condition_id: &WakeConditionId,
254    ) -> Result<Option<AgentPoolStoredWake>, AgentError> {
255        Ok(self.replay(pool_id)?.wakes.get(condition_id).cloned())
256    }
257
258    fn record_wake(
259        &self,
260        pool_id: &AgentPoolId,
261        condition: WakeCondition,
262        compiled_filter: CompiledEventFilter,
263        registration: WakeRegistration,
264    ) -> Result<AgentPoolStoreCursor, AgentError> {
265        self.snapshot(pool_id)?;
266        self.append_record(
267            pool_id,
268            AgentPoolStoreRecordPayload::Wake {
269                stored: AgentPoolStoredWake {
270                    condition,
271                    compiled_filter,
272                    registration,
273                },
274            },
275        )
276    }
277
278    fn watch(
279        &self,
280        pool_id: &AgentPoolId,
281        cursor: Option<AgentPoolStoreCursor>,
282    ) -> Result<AgentPoolStoreStream, AgentError> {
283        Ok(AgentPoolStoreStream::new(VecDeque::from(
284            self.records_after(pool_id, cursor)?,
285        )))
286    }
287
288    fn next_event_sequence(&self, pool_id: &AgentPoolId) -> Result<u64, AgentError> {
289        let mut connection = self.connection()?;
290        let transaction = connection.transaction().map_err(sqlite_error)?;
291        transaction
292            .execute(
293                "INSERT OR IGNORE INTO agent_pool_event_seq (pool_id, seq) VALUES (?1, 0)",
294                params![pool_id.as_str()],
295            )
296            .map_err(sqlite_error)?;
297        transaction
298            .execute(
299                "UPDATE agent_pool_event_seq SET seq = seq + 1 WHERE pool_id = ?1",
300                params![pool_id.as_str()],
301            )
302            .map_err(sqlite_error)?;
303        let seq = transaction
304            .query_row(
305                "SELECT seq FROM agent_pool_event_seq WHERE pool_id = ?1",
306                params![pool_id.as_str()],
307                |row| row.get::<_, i64>(0),
308            )
309            .map_err(sqlite_error)?;
310        transaction.commit().map_err(sqlite_error)?;
311        Ok(seq as u64)
312    }
313}
314
315#[derive(Default)]
316struct PoolReplay {
317    config: Option<AgentPoolStoreConfig>,
318    created: bool,
319    members: BTreeMap<RunId, AgentPoolMember>,
320    messages: BTreeMap<MessageId, AgentPoolStoredMessage>,
321    message_dedupe: BTreeMap<IdempotencyKey, MessageReceipt>,
322    wakes: BTreeMap<WakeConditionId, AgentPoolStoredWake>,
323    wake_dedupe: BTreeMap<IdempotencyKey, WakeRegistration>,
324    cursor: Option<AgentPoolStoreCursor>,
325}
326
327impl PoolReplay {
328    fn apply(&mut self, payload: AgentPoolStoreRecordPayload) -> Result<(), AgentError> {
329        match payload {
330            AgentPoolStoreRecordPayload::PoolOpened { config } => {
331                if self
332                    .config
333                    .as_ref()
334                    .is_some_and(|existing| existing != &config)
335                {
336                    return Err(AgentError::new(
337                        AgentErrorKind::InvalidStateTransition,
338                        RetryClassification::RepairNeeded,
339                        "sqlite agent pool store contains conflicting pool open records",
340                    ));
341                }
342                self.config = Some(config);
343            }
344            AgentPoolStoreRecordPayload::PoolCreated => {
345                self.created = true;
346            }
347            AgentPoolStoreRecordPayload::MemberJoined { member } => {
348                self.members.insert(member.run_id.clone(), member);
349            }
350            AgentPoolStoreRecordPayload::MemberLeft { member } => {
351                self.members.remove(&member.run_id);
352            }
353            AgentPoolStoreRecordPayload::RunMessage { stored } => {
354                self.message_dedupe.insert(
355                    stored.message.idempotency_key.clone(),
356                    stored.receipt.clone(),
357                );
358                self.messages
359                    .insert(stored.message.message_id.clone(), stored);
360            }
361            AgentPoolStoreRecordPayload::Wake { stored } => {
362                self.wake_dedupe.insert(
363                    stored.condition.idempotency_key.clone(),
364                    stored.registration.clone(),
365                );
366                self.wakes
367                    .insert(stored.condition.condition_id.clone(), stored);
368            }
369        }
370        Ok(())
371    }
372
373    fn snapshot(self, pool_id: AgentPoolId) -> Result<AgentPoolSnapshot, AgentError> {
374        let config = self.config.ok_or_else(|| {
375            AgentError::new(
376                AgentErrorKind::HostConfigurationNeeded,
377                RetryClassification::HostConfigurationNeeded,
378                "sqlite agent pool store has not opened this pool",
379            )
380        })?;
381        let topics = topics_from_members(self.members.values());
382        Ok(AgentPoolSnapshot {
383            pool_id,
384            created: self.created,
385            members: self.members.into_values().collect(),
386            topics,
387            message_policy: config.message_policy,
388            wake_policy: config.wake_policy,
389            policy_refs: config.policy_refs,
390            messages: self.messages.into_values().collect(),
391            wakes: self.wakes.into_values().collect(),
392            cursor: self.cursor,
393        })
394    }
395}
396
397fn topics_from_members<'a>(members: impl IntoIterator<Item = &'a AgentPoolMember>) -> Vec<TopicId> {
398    let mut topics = BTreeSet::new();
399    for member in members {
400        topics.extend(member.topics.iter().cloned());
401    }
402    topics.into_iter().collect()
403}
404
405fn payload_kind(payload: &AgentPoolStoreRecordPayload) -> &'static str {
406    match payload {
407        AgentPoolStoreRecordPayload::PoolOpened { .. } => "pool_opened",
408        AgentPoolStoreRecordPayload::PoolCreated => "pool_created",
409        AgentPoolStoreRecordPayload::MemberJoined { .. } => "member_joined",
410        AgentPoolStoreRecordPayload::MemberLeft { .. } => "member_left",
411        AgentPoolStoreRecordPayload::RunMessage { .. } => "run_message",
412        AgentPoolStoreRecordPayload::Wake { .. } => "wake",
413    }
414}
415
416fn sqlite_error(error: rusqlite::Error) -> AgentError {
417    AgentError::new(
418        AgentErrorKind::InvalidStateTransition,
419        RetryClassification::RepairNeeded,
420        format!("sqlite agent pool store failure: {error}"),
421    )
422}
423
424fn serde_error(error: serde_json::Error) -> AgentError {
425    AgentError::new(
426        AgentErrorKind::InvalidStateTransition,
427        RetryClassification::RepairNeeded,
428        format!("sqlite agent pool store serialization failure: {error}"),
429    )
430}