1use super::process::ProcessWakeDelivery;
2use crate::{PluginMessage, TurnCause, TurnInput};
3
/// Time-to-live of a queued-work claim, in milliseconds (30 seconds).
pub const QUEUED_WORK_CLAIM_TTL_MS: u64 = 30_000;
5
6#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
7#[serde(tag = "kind", rename_all = "snake_case")]
8pub enum SessionCommand {
9 RefreshToolCatalog { reason: String },
14 ResetSession { reason: String },
15}
16
17impl SessionCommand {
18 pub fn kind(&self) -> &'static str {
19 match self {
20 Self::RefreshToolCatalog { .. } => "refresh_tool_catalog",
21 Self::ResetSession { .. } => "reset_session",
22 }
23 }
24
25 pub fn source_key(&self, idempotency_key: impl AsRef<str>) -> String {
26 format!("command:{}:{}", self.kind(), idempotency_key.as_ref())
27 }
28}
29
30#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
31pub struct SessionCommandReceipt {
32 pub session_id: String,
33 pub batch_id: String,
34 pub source_key: String,
35}
36
37#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
38#[serde(rename_all = "snake_case")]
39pub enum DeliveryPolicy {
40 EarliestSafeBoundary,
41 AfterCurrentTurnCommit,
42}
43
44impl DeliveryPolicy {
45 pub fn as_str(self) -> &'static str {
46 match self {
47 Self::EarliestSafeBoundary => "earliest_safe_boundary",
48 Self::AfterCurrentTurnCommit => "after_current_turn_commit",
49 }
50 }
51
52 pub fn from_wire_str(value: &str) -> Option<Self> {
53 match value {
54 "earliest_safe_boundary" => Some(Self::EarliestSafeBoundary),
55 "after_current_turn_commit" => Some(Self::AfterCurrentTurnCommit),
56 _ => None,
57 }
58 }
59}
60
61#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
62#[serde(rename_all = "snake_case")]
63pub enum SlotPolicy {
64 Join,
65 Exclusive,
66}
67
68impl SlotPolicy {
69 pub fn as_str(self) -> &'static str {
70 match self {
71 Self::Join => "join",
72 Self::Exclusive => "exclusive",
73 }
74 }
75
76 pub fn from_wire_str(value: &str) -> Option<Self> {
77 match value {
78 "join" => Some(Self::Join),
79 "exclusive" => Some(Self::Exclusive),
80 _ => None,
81 }
82 }
83}
84
85#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
86#[serde(rename_all = "snake_case")]
87pub enum MergeKey {
88 Never,
89 PayloadDefault,
90 Group(String),
91}
92
93#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
94#[serde(tag = "type", rename_all = "snake_case")]
95pub enum QueuedWorkPayload {
96 ProcessWake { wake: Box<ProcessWakeDelivery> },
97 SessionCommand { command: Box<SessionCommand> },
98}
99
100#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
101#[serde(rename_all = "snake_case")]
102pub enum QueuedWorkClass {
103 SessionCommand,
104 TurnWork,
105}
106
107impl QueuedWorkPayload {
108 pub fn process_wake(wake: ProcessWakeDelivery) -> Self {
109 Self::ProcessWake {
110 wake: Box::new(wake),
111 }
112 }
113
114 pub fn session_command(command: SessionCommand) -> Self {
115 Self::SessionCommand {
116 command: Box::new(command),
117 }
118 }
119
120 pub fn work_class(&self) -> QueuedWorkClass {
121 match self {
122 Self::SessionCommand { .. } => QueuedWorkClass::SessionCommand,
123 Self::ProcessWake { .. } => QueuedWorkClass::TurnWork,
124 }
125 }
126}
127
128#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
129pub struct QueuedWorkItem {
130 pub item_id: String,
131 pub payload: QueuedWorkPayload,
132}
133
134#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
135pub struct QueuedWorkBatch {
136 pub batch_id: String,
137 pub session_id: String,
138 pub enqueue_seq: u64,
139 #[serde(default, skip_serializing_if = "Option::is_none")]
140 pub source_key: Option<String>,
141 pub delivery_policy: DeliveryPolicy,
142 pub slot_policy: SlotPolicy,
143 pub merge_key: MergeKey,
144 pub available_at_ms: u64,
145 pub enqueued_at_ms: u64,
146 pub items: Vec<QueuedWorkItem>,
147}
148
149impl QueuedWorkBatch {
150 pub fn work_class(&self) -> Option<QueuedWorkClass> {
151 work_class_for_payloads(self.items.iter().map(|item| &item.payload))
152 }
153
154 pub fn is_session_command_work(&self) -> bool {
155 self.work_class() == Some(QueuedWorkClass::SessionCommand)
156 }
157
158 pub fn is_turn_work(&self) -> bool {
159 self.work_class() == Some(QueuedWorkClass::TurnWork)
160 }
161}
162
163#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
164pub struct QueuedWorkBatchDraft {
165 pub session_id: String,
166 #[serde(default, skip_serializing_if = "Option::is_none")]
167 pub source_key: Option<String>,
168 pub delivery_policy: DeliveryPolicy,
169 pub slot_policy: SlotPolicy,
170 pub merge_key: MergeKey,
171 pub available_at_ms: u64,
172 pub payloads: Vec<QueuedWorkPayload>,
173}
174
175impl QueuedWorkBatchDraft {
176 pub fn new(
177 session_id: impl Into<String>,
178 delivery_policy: DeliveryPolicy,
179 slot_policy: SlotPolicy,
180 payloads: impl Into<Vec<QueuedWorkPayload>>,
181 ) -> Self {
182 Self {
183 session_id: session_id.into(),
184 source_key: None,
185 delivery_policy,
186 slot_policy,
187 merge_key: MergeKey::Never,
188 available_at_ms: 0,
189 payloads: payloads.into(),
190 }
191 }
192
193 pub fn with_source_key(mut self, source_key: impl Into<String>) -> Self {
194 self.source_key = Some(source_key.into());
195 self
196 }
197
198 pub fn with_available_at_ms(mut self, available_at_ms: u64) -> Self {
199 self.available_at_ms = available_at_ms;
200 self
201 }
202
203 pub fn with_merge_key(mut self, merge_key: MergeKey) -> Self {
204 self.merge_key = merge_key;
205 self
206 }
207
208 pub fn work_class(&self) -> Option<QueuedWorkClass> {
209 work_class_for_payloads(self.payloads.iter())
210 }
211}
212
213fn work_class_for_payloads<'a>(
214 payloads: impl IntoIterator<Item = &'a QueuedWorkPayload>,
215) -> Option<QueuedWorkClass> {
216 let mut payloads = payloads.into_iter();
217 let first = payloads.next()?.work_class();
218 payloads
219 .all(|payload| payload.work_class() == first)
220 .then_some(first)
221}
222
223#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
224#[serde(rename_all = "snake_case")]
225pub enum QueuedWorkClaimBoundary {
226 ActiveTurnCheckpoint,
227 Idle,
228}
229
230#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
231pub struct QueuedWorkCompletion {
232 pub session_id: String,
233 pub claim_id: String,
234 pub lease_token: String,
235 pub batch_ids: Vec<String>,
236}
237
238#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
239pub struct QueuedWorkClaim {
240 pub session_id: String,
241 pub claim_id: String,
242 pub owner: crate::LeaseOwnerIdentity,
243 pub lease_token: String,
244 pub fencing_token: u64,
245 pub claimed_at_epoch_ms: u64,
246 pub expires_at_epoch_ms: u64,
247 pub batches: Vec<QueuedWorkBatch>,
248}
249
250impl QueuedWorkClaim {
251 pub fn completion(&self) -> QueuedWorkCompletion {
252 QueuedWorkCompletion {
253 session_id: self.session_id.clone(),
254 claim_id: self.claim_id.clone(),
255 lease_token: self.lease_token.clone(),
256 batch_ids: self
257 .batches
258 .iter()
259 .map(|batch| batch.batch_id.clone())
260 .collect(),
261 }
262 }
263
264 pub fn is_empty(&self) -> bool {
265 self.batches.iter().all(|batch| batch.items.is_empty())
266 }
267
268 pub fn materialize_for_checkpoint(&self) -> QueuedCheckpointWork {
269 let messages = Vec::new();
270 let transient_messages = Vec::new();
271 let mut turn_causes = Vec::new();
272 for batch in &self.batches {
273 for item in &batch.items {
274 match &item.payload {
275 QueuedWorkPayload::ProcessWake { wake } => {
276 turn_causes.push(crate::process_wake_turn_cause(wake));
277 }
278 QueuedWorkPayload::SessionCommand { .. } => {}
279 }
280 }
281 }
282 QueuedCheckpointWork {
283 messages,
284 transient_messages,
285 turn_causes,
286 }
287 }
288
289 pub async fn materialize_for_checkpoint_with_attachments(
290 &self,
291 _attachment_store: &dyn crate::AttachmentStore,
292 ) -> Result<QueuedCheckpointWork, String> {
293 let messages = Vec::new();
294 let transient_messages = Vec::new();
295 let mut turn_causes = Vec::new();
296 for batch in &self.batches {
297 for item in &batch.items {
298 match &item.payload {
299 QueuedWorkPayload::ProcessWake { wake } => {
300 turn_causes.push(crate::process_wake_turn_cause(wake));
301 }
302 QueuedWorkPayload::SessionCommand { .. } => {}
303 }
304 }
305 }
306 Ok(QueuedCheckpointWork {
307 messages,
308 transient_messages,
309 turn_causes,
310 })
311 }
312
313 pub fn exclusive_session_command(&self) -> Option<(&QueuedWorkBatch, &SessionCommand)> {
314 if self.batches.len() != 1 {
315 return None;
316 }
317 let batch = self.batches.first()?;
318 if batch.slot_policy != SlotPolicy::Exclusive || batch.items.len() != 1 {
319 return None;
320 }
321 let item = batch.items.first()?;
322 match &item.payload {
323 QueuedWorkPayload::SessionCommand { command } => Some((batch, command.as_ref())),
324 _ => None,
325 }
326 }
327
328 pub fn materialize_for_turn(&self) -> QueuedTurnWork {
329 let checkpoint = self.materialize_for_checkpoint();
330 QueuedTurnWork {
331 input: TurnInput::empty(),
332 messages: checkpoint.messages,
333 turn_causes: checkpoint.turn_causes,
334 }
335 }
336}
337
338#[derive(Clone, Debug, Default)]
339pub struct QueuedCheckpointWork {
340 pub messages: Vec<PluginMessage>,
341 pub transient_messages: Vec<PluginMessage>,
342 pub turn_causes: Vec<TurnCause>,
343}
344
345#[derive(Clone, Debug)]
346pub struct QueuedTurnWork {
347 pub input: TurnInput,
348 pub messages: Vec<PluginMessage>,
349 pub turn_causes: Vec<TurnCause>,
350}
351
352pub fn process_wake_batch_draft(wake: ProcessWakeDelivery) -> QueuedWorkBatchDraft {
353 let source_key = format!("process:{}:event:{}:wake", wake.process_id, wake.sequence);
354 QueuedWorkBatchDraft::new(
355 wake.target_session_id.clone(),
356 DeliveryPolicy::EarliestSafeBoundary,
357 SlotPolicy::Exclusive,
358 vec![QueuedWorkPayload::process_wake(wake)],
359 )
360 .with_source_key(source_key)
361}