Skip to main content

agent_sdk_store_postgres/
tool_execution.rs

1use agent_sdk_core::{
2    AgentError, EffectId, IdempotencyKey, JournalCursor, RunId, ToolCallId, ToolExecutionStore,
3    ToolExecutionStoreCursor, ToolExecutionStoreRecord,
4};
5use serde_json::Value;
6
7use crate::{
8    PostgresStoreClient,
9    util::{decode_row, json_value},
10};
11
12#[derive(Clone)]
13pub struct PostgresToolExecutionStore {
14    client: PostgresStoreClient,
15}
16
17impl PostgresToolExecutionStore {
18    pub fn new(client: PostgresStoreClient) -> Self {
19        Self { client }
20    }
21}
22
23impl ToolExecutionStore for PostgresToolExecutionStore {
24    fn put_tool_execution_record(
25        &self,
26        record: ToolExecutionStoreRecord,
27    ) -> Result<ToolExecutionStoreCursor, AgentError> {
28        self.client.execute(
29            format!("insert into {} (store_scope, run_id, tool_call_id, journal_seq, idempotency_key, effect_id, record_json) values ($1, $2, $3, $4, $5, $6, $7) on conflict (store_scope, run_id, tool_call_id, journal_seq) do update set idempotency_key = excluded.idempotency_key, effect_id = excluded.effect_id, record_json = excluded.record_json", self.client.table("agent_sdk_tool_execution")),
30            vec![
31                self.client.scope(),
32                Value::String(record.run_id.as_str().to_string()),
33                Value::String(record.tool_call_id.as_str().to_string()),
34                Value::from(record.journal_seq),
35                record.idempotency_key.as_ref().map(|key| Value::String(key.as_str().to_string())).unwrap_or(Value::Null),
36                record.effect_id.as_ref().map(|effect_id| Value::String(effect_id.as_str().to_string())).unwrap_or(Value::Null),
37                json_value(&record)?,
38            ],
39        )?;
40        Ok(ToolExecutionStoreCursor::new(record.journal_seq))
41    }
42
43    fn records_for_run(&self, run_id: &RunId) -> Result<Vec<ToolExecutionStoreRecord>, AgentError> {
44        let response = self.client.execute(
45            format!("select record_json from {} where store_scope = $1 and run_id = $2 order by journal_seq asc", self.client.table("agent_sdk_tool_execution")),
46            vec![self.client.scope(), Value::String(run_id.as_str().to_string())],
47        )?;
48        response
49            .rows
50            .into_iter()
51            .map(|row| decode_row(row, "record_json"))
52            .collect()
53    }
54
55    fn records_for_effect_id(
56        &self,
57        effect_id: &EffectId,
58    ) -> Result<Vec<ToolExecutionStoreRecord>, AgentError> {
59        let response = self.client.execute(
60            format!("select record_json from {} where store_scope = $1 and effect_id = $2 order by run_id asc, journal_seq asc", self.client.table("agent_sdk_tool_execution")),
61            vec![self.client.scope(), Value::String(effect_id.as_str().to_string())],
62        )?;
63        response
64            .rows
65            .into_iter()
66            .map(|row| decode_row(row, "record_json"))
67            .collect()
68    }
69
70    fn record_for_tool_call(
71        &self,
72        run_id: &RunId,
73        tool_call_id: &ToolCallId,
74    ) -> Result<Option<ToolExecutionStoreRecord>, AgentError> {
75        let response = self.client.execute(
76            format!("select record_json from {} where store_scope = $1 and run_id = $2 and tool_call_id = $3 order by journal_seq desc limit 1", self.client.table("agent_sdk_tool_execution")),
77            vec![self.client.scope(), Value::String(run_id.as_str().to_string()), Value::String(tool_call_id.as_str().to_string())],
78        )?;
79        response
80            .rows
81            .into_iter()
82            .next()
83            .map(|row| decode_row(row, "record_json"))
84            .transpose()
85    }
86
87    fn records_for_idempotency_key(
88        &self,
89        idempotency_key: &IdempotencyKey,
90    ) -> Result<Vec<ToolExecutionStoreRecord>, AgentError> {
91        let response = self.client.execute(
92            format!("select record_json from {} where store_scope = $1 and idempotency_key = $2 order by run_id asc, journal_seq asc", self.client.table("agent_sdk_tool_execution")),
93            vec![self.client.scope(), Value::String(idempotency_key.as_str().to_string())],
94        )?;
95        response
96            .rows
97            .into_iter()
98            .map(|row| decode_row(row, "record_json"))
99            .collect()
100    }
101
102    fn records_after_journal_seq(
103        &self,
104        run_id: &RunId,
105        journal_seq: u64,
106    ) -> Result<Vec<ToolExecutionStoreRecord>, AgentError> {
107        let response = self.client.execute(
108            format!("select record_json from {} where store_scope = $1 and run_id = $2 and journal_seq > $3 order by journal_seq asc", self.client.table("agent_sdk_tool_execution")),
109            vec![self.client.scope(), Value::String(run_id.as_str().to_string()), Value::from(journal_seq)],
110        )?;
111        response
112            .rows
113            .into_iter()
114            .map(|row| decode_row(row, "record_json"))
115            .collect()
116    }
117
118    fn records_in_journal_cursor_range(
119        &self,
120        run_id: &RunId,
121        after: Option<&JournalCursor>,
122        through: Option<&JournalCursor>,
123    ) -> Result<Vec<ToolExecutionStoreRecord>, AgentError> {
124        let after_seq = after
125            .map(|cursor| {
126                ToolExecutionStoreRecord::journal_sequence_for_cursor(cursor).ok_or_else(|| {
127                    AgentError::contract_violation(
128                        "tool execution cursor range uses an unsupported journal cursor",
129                    )
130                })
131            })
132            .transpose()?;
133        let through_seq = through
134            .map(|cursor| {
135                ToolExecutionStoreRecord::journal_sequence_for_cursor(cursor).ok_or_else(|| {
136                    AgentError::contract_violation(
137                        "tool execution cursor range uses an unsupported journal cursor",
138                    )
139                })
140            })
141            .transpose()?;
142        let response = self.client.execute(
143            format!("select record_json from {} where store_scope = $1 and run_id = $2 and ($3 is null or journal_seq > $3) and ($4 is null or journal_seq <= $4) order by journal_seq asc", self.client.table("agent_sdk_tool_execution")),
144            vec![
145                self.client.scope(),
146                Value::String(run_id.as_str().to_string()),
147                after_seq.map(Value::from).unwrap_or(Value::Null),
148                through_seq.map(Value::from).unwrap_or(Value::Null),
149            ],
150        )?;
151        response
152            .rows
153            .into_iter()
154            .map(|row| decode_row(row, "record_json"))
155            .collect()
156    }
157}