1use std::collections::BTreeMap;
2
3use agent_sdk_core::{
4 AgentError, AgentPoolId, AgentPoolMember, AgentPoolSnapshot, AgentPoolStore,
5 AgentPoolStoreConfig, AgentPoolStoreCursor, AgentPoolStoreRecord, AgentPoolStoreRecordPayload,
6 AgentPoolStoreStream, AgentPoolStoredMessage, AgentPoolStoredWake, IdempotencyKey,
7 MessageReceipt, RunId, RunMessage, WakeCondition, WakeConditionId, WakeRegistration,
8 agent_pool::AgentPoolStoredWake as CoreStoredWake, event::CompiledEventFilter,
9};
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12
13use crate::{client::SupabaseClient, transport::supabase_error};
14
/// [`AgentPoolStore`] implementation backed by Supabase.
///
/// Each pool's whole state is kept as a single JSON document: it is read from
/// the `agent_sdk_agent_pools` table and written through the
/// `agent_sdk_upsert_agent_pool_state` RPC, always filtered by the client's
/// configured store scope.
#[derive(Clone)]
pub struct SupabaseAgentPoolStore {
    /// Client used for table reads, state upserts and sequence allocation.
    client: SupabaseClient,
}
20
/// Complete persisted state of one agent pool.
///
/// The whole struct is serialized as the `p_state` JSON document and decoded
/// back from the `state` column, so field names here form the stored schema.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct PoolState {
    // Config the pool was first opened with; `open_pool` rejects a different one.
    config: AgentPoolStoreConfig,
    // Set to `true` by `record_pool_created`.
    created: bool,
    // Current members, keyed by their run id.
    members: BTreeMap<RunId, AgentPoolMember>,
    // Recorded messages, keyed by the message id's string form.
    messages: BTreeMap<String, AgentPoolStoredMessage>,
    // Idempotency key -> receipt of the message recorded under that key.
    message_dedupe: BTreeMap<IdempotencyKey, MessageReceipt>,
    // Registered wake conditions, keyed by condition id.
    wakes: BTreeMap<WakeConditionId, AgentPoolStoredWake>,
    // Idempotency key -> registration of the wake recorded under that key.
    wake_dedupe: BTreeMap<IdempotencyKey, WakeRegistration>,
    // Append-only change log; `append_record` gives each entry sequence = position + 1.
    records: Vec<AgentPoolStoreRecord>,
    // NOTE(review): initialised to 0 and never read or updated in this file;
    // event sequences come from the `agent_sdk_next_agent_pool_event_sequence`
    // RPC instead. Kept because it is part of the persisted schema.
    next_event_counter: u64,
}
33
34impl PoolState {
35 fn new(config: AgentPoolStoreConfig) -> Self {
36 Self {
37 config,
38 created: false,
39 members: BTreeMap::new(),
40 messages: BTreeMap::new(),
41 message_dedupe: BTreeMap::new(),
42 wakes: BTreeMap::new(),
43 wake_dedupe: BTreeMap::new(),
44 records: Vec::new(),
45 next_event_counter: 0,
46 }
47 }
48
49 fn snapshot(&self, pool_id: AgentPoolId) -> AgentPoolSnapshot {
50 AgentPoolSnapshot {
51 pool_id,
52 created: self.created,
53 topics: self
54 .members
55 .values()
56 .flat_map(|member| member.topics.clone())
57 .collect(),
58 members: self.members.values().cloned().collect(),
59 message_policy: self.config.message_policy.clone(),
60 wake_policy: self.config.wake_policy.clone(),
61 policy_refs: self.config.policy_refs.clone(),
62 messages: self.messages.values().cloned().collect(),
63 wakes: self.wakes.values().cloned().collect(),
64 cursor: self.records.last().map(|record| record.cursor.clone()),
65 }
66 }
67}
68
69impl SupabaseAgentPoolStore {
70 pub fn new(client: SupabaseClient) -> Self {
71 Self { client }
72 }
73
74 fn load_state(&self, pool_id: &AgentPoolId) -> Result<Option<PoolState>, AgentError> {
75 let query = format!(
76 "store_scope=eq.{}&pool_id=eq.{}&select=state&limit=1",
77 self.client.config().store_scope(),
78 pool_id.as_str()
79 );
80 let response = self.client.select("agent_sdk_agent_pools", &query)?;
81 if !(200..300).contains(&response.status) {
82 return Err(supabase_error(format!(
83 "supabase agent pool read failed with status {}",
84 response.status
85 )));
86 }
87 let rows = serde_json::from_slice::<Vec<serde_json::Value>>(&response.body)
88 .map_err(|error| supabase_error(error.to_string()))?;
89 rows.into_iter()
90 .next()
91 .map(|row| {
92 serde_json::from_value(row["state"].clone()).map_err(|error| {
93 AgentError::contract_violation(format!(
94 "supabase agent pool state decode failed: {error}"
95 ))
96 })
97 })
98 .transpose()
99 }
100
101 fn save_state(&self, pool_id: &AgentPoolId, state: &PoolState) -> Result<(), AgentError> {
102 let response = self.client.rpc(
103 "agent_sdk_upsert_agent_pool_state",
104 &json!({
105 "p_store_scope": self.client.config().store_scope(),
106 "p_pool_id": pool_id.as_str(),
107 "p_state": state,
108 }),
109 )?;
110 if !(200..300).contains(&response.status) {
111 return Err(supabase_error(format!(
112 "supabase agent pool save failed with status {}",
113 response.status
114 )));
115 }
116 Ok(())
117 }
118
119 fn with_state<T>(
120 &self,
121 pool_id: &AgentPoolId,
122 f: impl FnOnce(&mut PoolState) -> Result<T, AgentError>,
123 ) -> Result<T, AgentError> {
124 let mut state = self
125 .load_state(pool_id)?
126 .ok_or_else(|| AgentError::host_configuration_needed("agent pool is not open"))?;
127 let output = f(&mut state)?;
128 self.save_state(pool_id, &state)?;
129 Ok(output)
130 }
131
132 fn append_record(
133 pool_id: &AgentPoolId,
134 state: &mut PoolState,
135 payload: AgentPoolStoreRecordPayload,
136 ) -> AgentPoolStoreCursor {
137 let cursor = AgentPoolStoreCursor::new(state.records.len() as u64 + 1);
138 state.records.push(AgentPoolStoreRecord {
139 pool_id: pool_id.clone(),
140 cursor: cursor.clone(),
141 payload,
142 });
143 cursor
144 }
145}
146
147impl AgentPoolStore for SupabaseAgentPoolStore {
148 fn open_pool(
149 &self,
150 pool_id: AgentPoolId,
151 config: AgentPoolStoreConfig,
152 ) -> Result<AgentPoolSnapshot, AgentError> {
153 let state = if let Some(existing) = self.load_state(&pool_id)? {
154 if existing.config != config {
155 return Err(AgentError::contract_violation(
156 "agent pool store config conflicts with existing pool",
157 ));
158 }
159 existing
160 } else {
161 let mut state = PoolState::new(config.clone());
162 Self::append_record(
163 &pool_id,
164 &mut state,
165 AgentPoolStoreRecordPayload::PoolOpened { config },
166 );
167 state
168 };
169 let snapshot = state.snapshot(pool_id.clone());
170 self.save_state(&pool_id, &state)?;
171 Ok(snapshot)
172 }
173
174 fn snapshot(&self, pool_id: &AgentPoolId) -> Result<AgentPoolSnapshot, AgentError> {
175 self.load_state(pool_id)?
176 .map(|state| state.snapshot(pool_id.clone()))
177 .ok_or_else(|| AgentError::host_configuration_needed("agent pool is not open"))
178 }
179
180 fn record_pool_created(
181 &self,
182 pool_id: &AgentPoolId,
183 ) -> Result<AgentPoolStoreCursor, AgentError> {
184 self.with_state(pool_id, |state| {
185 state.created = true;
186 Ok(Self::append_record(
187 pool_id,
188 state,
189 AgentPoolStoreRecordPayload::PoolCreated,
190 ))
191 })
192 }
193
194 fn join_member(
195 &self,
196 pool_id: &AgentPoolId,
197 member: AgentPoolMember,
198 ) -> Result<AgentPoolStoreCursor, AgentError> {
199 self.with_state(pool_id, |state| {
200 state.members.insert(member.run_id.clone(), member.clone());
201 Ok(Self::append_record(
202 pool_id,
203 state,
204 AgentPoolStoreRecordPayload::MemberJoined { member },
205 ))
206 })
207 }
208
209 fn leave_member(
210 &self,
211 pool_id: &AgentPoolId,
212 run_id: &RunId,
213 ) -> Result<(AgentPoolMember, AgentPoolStoreCursor), AgentError> {
214 self.with_state(pool_id, |state| {
215 let member = state.members.remove(run_id).ok_or_else(|| {
216 AgentError::contract_violation("run is not a member of this agent pool")
217 })?;
218 let cursor = Self::append_record(
219 pool_id,
220 state,
221 AgentPoolStoreRecordPayload::MemberLeft {
222 member: member.clone(),
223 },
224 );
225 Ok((member, cursor))
226 })
227 }
228
229 fn message_receipt(
230 &self,
231 pool_id: &AgentPoolId,
232 idempotency_key: &IdempotencyKey,
233 ) -> Result<Option<MessageReceipt>, AgentError> {
234 Ok(self
235 .load_state(pool_id)?
236 .and_then(|state| state.message_dedupe.get(idempotency_key).cloned()))
237 }
238
239 fn record_message(
240 &self,
241 pool_id: &AgentPoolId,
242 message: RunMessage,
243 receipt: MessageReceipt,
244 ) -> Result<AgentPoolStoreCursor, AgentError> {
245 self.with_state(pool_id, |state| {
246 let stored = AgentPoolStoredMessage {
247 message: message.clone(),
248 receipt: receipt.clone(),
249 };
250 state
251 .messages
252 .insert(message.message_id.as_str().to_string(), stored.clone());
253 state
254 .message_dedupe
255 .insert(message.idempotency_key.clone(), receipt);
256 Ok(Self::append_record(
257 pool_id,
258 state,
259 AgentPoolStoreRecordPayload::RunMessage { stored },
260 ))
261 })
262 }
263
264 fn wake_registration(
265 &self,
266 pool_id: &AgentPoolId,
267 idempotency_key: &IdempotencyKey,
268 ) -> Result<Option<WakeRegistration>, AgentError> {
269 Ok(self
270 .load_state(pool_id)?
271 .and_then(|state| state.wake_dedupe.get(idempotency_key).cloned()))
272 }
273
274 fn wake(
275 &self,
276 pool_id: &AgentPoolId,
277 condition_id: &WakeConditionId,
278 ) -> Result<Option<CoreStoredWake>, AgentError> {
279 Ok(self
280 .load_state(pool_id)?
281 .and_then(|state| state.wakes.get(condition_id).cloned()))
282 }
283
284 fn record_wake(
285 &self,
286 pool_id: &AgentPoolId,
287 condition: WakeCondition,
288 compiled_filter: CompiledEventFilter,
289 registration: WakeRegistration,
290 ) -> Result<AgentPoolStoreCursor, AgentError> {
291 self.with_state(pool_id, |state| {
292 let stored = AgentPoolStoredWake {
293 condition: condition.clone(),
294 compiled_filter,
295 registration: registration.clone(),
296 };
297 state
298 .wakes
299 .insert(condition.condition_id.clone(), stored.clone());
300 state
301 .wake_dedupe
302 .insert(condition.idempotency_key.clone(), registration);
303 Ok(Self::append_record(
304 pool_id,
305 state,
306 AgentPoolStoreRecordPayload::Wake { stored },
307 ))
308 })
309 }
310
311 fn watch(
312 &self,
313 pool_id: &AgentPoolId,
314 cursor: Option<AgentPoolStoreCursor>,
315 ) -> Result<AgentPoolStoreStream, AgentError> {
316 let after = cursor.map(|cursor| cursor.sequence).unwrap_or(0);
317 let records = self
318 .load_state(pool_id)?
319 .map(|state| {
320 state
321 .records
322 .into_iter()
323 .filter(|record| record.cursor.sequence > after)
324 .collect::<Vec<_>>()
325 })
326 .unwrap_or_default();
327 Ok(AgentPoolStoreStream::new(records))
328 }
329
330 fn next_event_sequence(&self, pool_id: &AgentPoolId) -> Result<u64, AgentError> {
331 let response = self.client.rpc(
332 "agent_sdk_next_agent_pool_event_sequence",
333 &json!({
334 "p_store_scope": self.client.config().store_scope(),
335 "p_pool_id": pool_id.as_str(),
336 }),
337 )?;
338 if !(200..300).contains(&response.status) {
339 return Err(supabase_error(format!(
340 "supabase agent pool sequence allocation failed with status {}",
341 response.status
342 )));
343 }
344 let value = serde_json::from_slice::<serde_json::Value>(&response.body)
345 .map_err(|error| supabase_error(error.to_string()))?;
346 parse_sequence(value)
347 .ok_or_else(|| supabase_error("supabase agent pool sequence response missing value"))
348 }
349}
350
351fn parse_sequence(value: serde_json::Value) -> Option<u64> {
352 value
353 .as_u64()
354 .or_else(|| value.as_array()?.first()?.as_u64())
355 .or_else(|| value.as_array()?.first()?.get("next_sequence")?.as_u64())
356 .or_else(|| value.get("next_sequence")?.as_u64())
357}