//! SQLite-backed tool-execution projection store (`agent_sdk_store_sqlite/tool_execution.rs`).

1use std::path::{Path, PathBuf};
2
3use agent_sdk_core::{
4    AgentError, EffectId, IdempotencyKey, JournalCursor, RunId, ToolCallId, ToolExecutionStore,
5    ToolExecutionStoreCursor, ToolExecutionStoreRecord,
6};
7use rusqlite::{OptionalExtension, params};
8
9use crate::util::{decode, encode, open, sqlite_error};
10
/// DDL handed to `crate::util::init` when the store is opened.
///
/// Every statement is guarded by `IF NOT EXISTS`, so applying it to an already-initialised
/// database leaves existing objects untouched. Besides the composite primary key, secondary
/// indexes back the idempotency-key lookup, the effect-id lookup, and per-run scans ordered by
/// `journal_seq`.
const SCHEMA: &str = r"
CREATE TABLE IF NOT EXISTS tool_execution_records (
    run_id TEXT NOT NULL,
    tool_call_id TEXT NOT NULL,
    journal_seq INTEGER NOT NULL,
    idempotency_key TEXT,
    effect_id TEXT,
    record_json TEXT NOT NULL,
    PRIMARY KEY (run_id, tool_call_id, journal_seq)
);
CREATE INDEX IF NOT EXISTS idx_tool_execution_idempotency
ON tool_execution_records(idempotency_key);
CREATE INDEX IF NOT EXISTS idx_tool_execution_effect
ON tool_execution_records(effect_id);
CREATE INDEX IF NOT EXISTS idx_tool_execution_run_seq
ON tool_execution_records(run_id, journal_seq);
";
28
/// Rebuildable tool-execution projection store persisted in a SQLite database file.
///
/// Only the database location is held here; connections are obtained from this path when the
/// store's operations run.
#[derive(Debug, Clone)]
pub struct SqliteToolExecutionStore {
    /// Filesystem location of the backing SQLite database.
    path: PathBuf,
}
34
35impl SqliteToolExecutionStore {
36    /// Opens or creates a SQLite tool-execution projection store.
37    pub fn open(path: impl AsRef<Path>) -> Result<Self, AgentError> {
38        crate::util::init(path.as_ref(), SCHEMA)?;
39        Ok(Self {
40            path: path.as_ref().to_path_buf(),
41        })
42    }
43}
44
45impl ToolExecutionStore for SqliteToolExecutionStore {
46    fn put_tool_execution_record(
47        &self,
48        record: ToolExecutionStoreRecord,
49    ) -> Result<ToolExecutionStoreCursor, AgentError> {
50        let connection = open(&self.path)?;
51        connection
52            .execute(
53                "INSERT OR REPLACE INTO tool_execution_records
54                 (run_id, tool_call_id, journal_seq, idempotency_key, effect_id, record_json)
55                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
56                params![
57                    record.run_id.as_str(),
58                    record.tool_call_id.as_str(),
59                    record.journal_seq as i64,
60                    record.idempotency_key.as_ref().map(|key| key.as_str()),
61                    record
62                        .effect_id
63                        .as_ref()
64                        .map(|effect_id| effect_id.as_str()),
65                    encode(&record)?,
66                ],
67            )
68            .map_err(sqlite_error)?;
69        let sequence = connection
70            .query_row(
71                "SELECT COUNT(*) FROM tool_execution_records WHERE run_id = ?1",
72                params![record.run_id.as_str()],
73                |row| row.get::<_, i64>(0),
74            )
75            .map_err(sqlite_error)?;
76        Ok(ToolExecutionStoreCursor::new(sequence as u64))
77    }
78
79    fn records_for_run(&self, run_id: &RunId) -> Result<Vec<ToolExecutionStoreRecord>, AgentError> {
80        let connection = open(&self.path)?;
81        let mut statement = connection
82            .prepare(
83                "SELECT record_json FROM tool_execution_records
84                 WHERE run_id = ?1 ORDER BY journal_seq ASC",
85            )
86            .map_err(sqlite_error)?;
87        let rows = statement
88            .query_map(params![run_id.as_str()], |row| row.get::<_, String>(0))
89            .map_err(sqlite_error)?;
90        let mut records = Vec::new();
91        for row in rows {
92            records.push(decode(&row.map_err(sqlite_error)?)?);
93        }
94        Ok(records)
95    }
96
97    fn record_for_tool_call(
98        &self,
99        run_id: &RunId,
100        tool_call_id: &ToolCallId,
101    ) -> Result<Option<ToolExecutionStoreRecord>, AgentError> {
102        let connection = open(&self.path)?;
103        let row = connection
104            .query_row(
105                "SELECT record_json FROM tool_execution_records
106                 WHERE run_id = ?1 AND tool_call_id = ?2
107                 ORDER BY journal_seq DESC LIMIT 1",
108                params![run_id.as_str(), tool_call_id.as_str()],
109                |row| row.get::<_, String>(0),
110            )
111            .optional()
112            .map_err(sqlite_error)?;
113        row.map(|json| decode(&json)).transpose()
114    }
115
116    fn records_for_idempotency_key(
117        &self,
118        idempotency_key: &IdempotencyKey,
119    ) -> Result<Vec<ToolExecutionStoreRecord>, AgentError> {
120        let connection = open(&self.path)?;
121        let mut statement = connection
122            .prepare(
123                "SELECT record_json FROM tool_execution_records
124                 WHERE idempotency_key = ?1 ORDER BY run_id ASC, journal_seq ASC",
125            )
126            .map_err(sqlite_error)?;
127        let rows = statement
128            .query_map(params![idempotency_key.as_str()], |row| {
129                row.get::<_, String>(0)
130            })
131            .map_err(sqlite_error)?;
132        let mut records = Vec::new();
133        for row in rows {
134            records.push(decode(&row.map_err(sqlite_error)?)?);
135        }
136        Ok(records)
137    }
138
139    fn records_for_effect_id(
140        &self,
141        effect_id: &EffectId,
142    ) -> Result<Vec<ToolExecutionStoreRecord>, AgentError> {
143        let connection = open(&self.path)?;
144        let mut statement = connection
145            .prepare(
146                "SELECT record_json FROM tool_execution_records
147                 WHERE effect_id = ?1 ORDER BY run_id ASC, journal_seq ASC",
148            )
149            .map_err(sqlite_error)?;
150        let rows = statement
151            .query_map(params![effect_id.as_str()], |row| row.get::<_, String>(0))
152            .map_err(sqlite_error)?;
153        let mut records = Vec::new();
154        for row in rows {
155            records.push(decode(&row.map_err(sqlite_error)?)?);
156        }
157        Ok(records)
158    }
159
160    fn records_after_journal_seq(
161        &self,
162        run_id: &RunId,
163        journal_seq: u64,
164    ) -> Result<Vec<ToolExecutionStoreRecord>, AgentError> {
165        let connection = open(&self.path)?;
166        let mut statement = connection
167            .prepare(
168                "SELECT record_json FROM tool_execution_records
169                 WHERE run_id = ?1 AND journal_seq > ?2 ORDER BY journal_seq ASC",
170            )
171            .map_err(sqlite_error)?;
172        let rows = statement
173            .query_map(params![run_id.as_str(), journal_seq as i64], |row| {
174                row.get::<_, String>(0)
175            })
176            .map_err(sqlite_error)?;
177        let mut records = Vec::new();
178        for row in rows {
179            records.push(decode(&row.map_err(sqlite_error)?)?);
180        }
181        Ok(records)
182    }
183
184    fn records_in_journal_cursor_range(
185        &self,
186        run_id: &RunId,
187        after: Option<&JournalCursor>,
188        through: Option<&JournalCursor>,
189    ) -> Result<Vec<ToolExecutionStoreRecord>, AgentError> {
190        let after_seq = after
191            .map(|cursor| {
192                ToolExecutionStoreRecord::journal_sequence_for_cursor(cursor).ok_or_else(|| {
193                    AgentError::contract_violation(
194                        "tool execution cursor range uses an unsupported journal cursor",
195                    )
196                })
197            })
198            .transpose()?;
199        let through_seq = through
200            .map(|cursor| {
201                ToolExecutionStoreRecord::journal_sequence_for_cursor(cursor).ok_or_else(|| {
202                    AgentError::contract_violation(
203                        "tool execution cursor range uses an unsupported journal cursor",
204                    )
205                })
206            })
207            .transpose()?;
208        let connection = open(&self.path)?;
209        let mut statement = connection
210            .prepare(
211                "SELECT record_json FROM tool_execution_records
212                 WHERE run_id = ?1
213                   AND (?2 IS NULL OR journal_seq > ?2)
214                   AND (?3 IS NULL OR journal_seq <= ?3)
215                 ORDER BY journal_seq ASC",
216            )
217            .map_err(sqlite_error)?;
218        let rows = statement
219            .query_map(
220                params![
221                    run_id.as_str(),
222                    after_seq.map(|seq| seq as i64),
223                    through_seq.map(|seq| seq as i64),
224                ],
225                |row| row.get::<_, String>(0),
226            )
227            .map_err(sqlite_error)?;
228        let mut records = Vec::new();
229        for row in rows {
230            records.push(decode(&row.map_err(sqlite_error)?)?);
231        }
232        Ok(records)
233    }
234}