1use super::process::ProcessWakeDelivery;
2use crate::{PluginMessage, TurnCause, TurnInput};
3
/// A command addressed to a session, queued as a
/// [`QueuedWorkPayload::SessionCommand`] payload.
///
/// Serialized as an internally tagged object: the snake_case variant name is
/// stored under the `"kind"` key next to the variant's fields (the same names
/// `SessionCommand::kind` returns).
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum SessionCommand {
    /// Refresh-tool-catalog command. `reason` is carried as an opaque string;
    /// it is not interpreted in this file.
    RefreshToolCatalog { reason: String },
    /// Reset-session command. `reason` is carried as an opaque string; it is
    /// not interpreted in this file.
    ResetSession { reason: String },
}
14
15impl SessionCommand {
16 pub fn kind(&self) -> &'static str {
17 match self {
18 Self::RefreshToolCatalog { .. } => "refresh_tool_catalog",
19 Self::ResetSession { .. } => "reset_session",
20 }
21 }
22
23 pub fn source_key(&self, idempotency_key: impl AsRef<str>) -> String {
24 format!("command:{}:{}", self.kind(), idempotency_key.as_ref())
25 }
26}
27
/// Identifiers describing a queued session command: the target session, the
/// batch it was placed in, and its source key.
///
/// NOTE(review): presumably returned to the caller when a command is enqueued;
/// the producer is not in this file — confirm.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct SessionCommandReceipt {
    pub session_id: String,
    pub batch_id: String,
    /// Source key of the command (see `SessionCommand::source_key` for the
    /// `command:<kind>:<idempotency_key>` format it produces).
    pub source_key: String,
}
34
/// Policy attached to a queued batch that governs when it is delivered.
///
/// The delivery semantics are enforced by the queue consumer, not in this file.
/// Serialized as snake_case strings; `DeliveryPolicy::as_str` and
/// `DeliveryPolicy::from_wire_str` use the same spellings.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DeliveryPolicy {
    /// Wire name `"earliest_safe_boundary"`; used for process-wake drafts.
    EarliestSafeBoundary,
    /// Wire name `"after_current_turn_commit"`.
    AfterCurrentTurnCommit,
}
41
42impl DeliveryPolicy {
43 pub fn as_str(self) -> &'static str {
44 match self {
45 Self::EarliestSafeBoundary => "earliest_safe_boundary",
46 Self::AfterCurrentTurnCommit => "after_current_turn_commit",
47 }
48 }
49
50 pub fn from_wire_str(value: &str) -> Option<Self> {
51 match value {
52 "earliest_safe_boundary" => Some(Self::EarliestSafeBoundary),
53 "after_current_turn_commit" => Some(Self::AfterCurrentTurnCommit),
54 _ => None,
55 }
56 }
57}
58
/// Slot policy attached to a queued batch.
///
/// Serialized as snake_case strings; `SlotPolicy::as_str` and
/// `SlotPolicy::from_wire_str` use the same spellings. How slots are shared is
/// decided by the queue consumer, not in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SlotPolicy {
    /// Wire name `"join"`.
    Join,
    /// Wire name `"exclusive"`. Process-wake drafts use this, and
    /// `QueuedWorkClaim::exclusive_session_command` only recognizes batches
    /// with this policy.
    Exclusive,
}
65
66impl SlotPolicy {
67 pub fn as_str(self) -> &'static str {
68 match self {
69 Self::Join => "join",
70 Self::Exclusive => "exclusive",
71 }
72 }
73
74 pub fn from_wire_str(value: &str) -> Option<Self> {
75 match value {
76 "join" => Some(Self::Join),
77 "exclusive" => Some(Self::Exclusive),
78 _ => None,
79 }
80 }
81}
82
/// Merge key attached to a queued batch.
///
/// Uses serde's default (externally tagged) enum representation with
/// snake_case names: `"never"`, `"payload_default"`, `{"group": "<name>"}`.
/// NOTE(review): how batches with equal keys are merged is decided outside
/// this file — confirm against the queue implementation.
#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MergeKey {
    /// The default for `QueuedWorkBatchDraft::new`.
    Never,
    PayloadDefault,
    /// Caller-named group.
    Group(String),
}
90
/// The unit of work carried by a queued item.
///
/// Serialized as an internally tagged object with the snake_case variant name
/// under `"type"`. Variant contents are boxed. `QueuedWorkPayload::work_class`
/// maps `ProcessWake` to turn work and `SessionCommand` to session-command work.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum QueuedWorkPayload {
    ProcessWake { wake: Box<ProcessWakeDelivery> },
    SessionCommand { command: Box<SessionCommand> },
}
97
/// Coarse classification of queued payloads, as returned by
/// `QueuedWorkPayload::work_class`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum QueuedWorkClass {
    /// Class of `QueuedWorkPayload::SessionCommand` payloads.
    SessionCommand,
    /// Class of `QueuedWorkPayload::ProcessWake` payloads.
    TurnWork,
}
104
105impl QueuedWorkPayload {
106 pub fn process_wake(wake: ProcessWakeDelivery) -> Self {
107 Self::ProcessWake {
108 wake: Box::new(wake),
109 }
110 }
111
112 pub fn session_command(command: SessionCommand) -> Self {
113 Self::SessionCommand {
114 command: Box::new(command),
115 }
116 }
117
118 pub fn work_class(&self) -> QueuedWorkClass {
119 match self {
120 Self::SessionCommand { .. } => QueuedWorkClass::SessionCommand,
121 Self::ProcessWake { .. } => QueuedWorkClass::TurnWork,
122 }
123 }
124}
125
/// One item inside a queued batch: an identifier plus its payload.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct QueuedWorkItem {
    pub item_id: String,
    pub payload: QueuedWorkPayload,
}
131
/// A batch of queued work items for one session, together with the policies
/// that govern its delivery.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct QueuedWorkBatch {
    pub batch_id: String,
    pub session_id: String,
    /// NOTE(review): presumably a per-queue ordering sequence assigned on
    /// enqueue — confirm against the store.
    pub enqueue_seq: u64,
    /// Optional source key. Omitted from serialized output when `None`, and
    /// defaults to `None` when absent on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source_key: Option<String>,
    pub delivery_policy: DeliveryPolicy,
    pub slot_policy: SlotPolicy,
    pub merge_key: MergeKey,
    /// Assumed to be a millisecond timestamp (per the `_ms` suffix) — TODO
    /// confirm the epoch/clock.
    pub available_at_ms: u64,
    /// Assumed to be a millisecond timestamp (per the `_ms` suffix) — TODO
    /// confirm the epoch/clock.
    pub enqueued_at_ms: u64,
    pub items: Vec<QueuedWorkItem>,
}
146
147impl QueuedWorkBatch {
148 pub fn work_class(&self) -> Option<QueuedWorkClass> {
149 work_class_for_payloads(self.items.iter().map(|item| &item.payload))
150 }
151
152 pub fn is_session_command_work(&self) -> bool {
153 self.work_class() == Some(QueuedWorkClass::SessionCommand)
154 }
155
156 pub fn is_turn_work(&self) -> bool {
157 self.work_class() == Some(QueuedWorkClass::TurnWork)
158 }
159}
160
/// Caller-side description of a batch to enqueue.
///
/// Compared with `QueuedWorkBatch`, it has no batch id, enqueue sequence or
/// enqueue time, and holds bare payloads instead of identified items.
/// NOTE(review): those fields are presumably assigned by the queue on
/// enqueue — confirm.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct QueuedWorkBatchDraft {
    pub session_id: String,
    /// Optional source key. Omitted from serialized output when `None`, and
    /// defaults to `None` when absent on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source_key: Option<String>,
    pub delivery_policy: DeliveryPolicy,
    pub slot_policy: SlotPolicy,
    pub merge_key: MergeKey,
    pub available_at_ms: u64,
    pub payloads: Vec<QueuedWorkPayload>,
}
172
173impl QueuedWorkBatchDraft {
174 pub fn new(
175 session_id: impl Into<String>,
176 delivery_policy: DeliveryPolicy,
177 slot_policy: SlotPolicy,
178 payloads: impl Into<Vec<QueuedWorkPayload>>,
179 ) -> Self {
180 Self {
181 session_id: session_id.into(),
182 source_key: None,
183 delivery_policy,
184 slot_policy,
185 merge_key: MergeKey::Never,
186 available_at_ms: 0,
187 payloads: payloads.into(),
188 }
189 }
190
191 pub fn with_source_key(mut self, source_key: impl Into<String>) -> Self {
192 self.source_key = Some(source_key.into());
193 self
194 }
195
196 pub fn with_available_at_ms(mut self, available_at_ms: u64) -> Self {
197 self.available_at_ms = available_at_ms;
198 self
199 }
200
201 pub fn with_merge_key(mut self, merge_key: MergeKey) -> Self {
202 self.merge_key = merge_key;
203 self
204 }
205
206 pub fn work_class(&self) -> Option<QueuedWorkClass> {
207 work_class_for_payloads(self.payloads.iter())
208 }
209}
210
211fn work_class_for_payloads<'a>(
212 payloads: impl IntoIterator<Item = &'a QueuedWorkPayload>,
213) -> Option<QueuedWorkClass> {
214 let mut payloads = payloads.into_iter();
215 let first = payloads.next()?.work_class();
216 payloads
217 .all(|payload| payload.work_class() == first)
218 .then_some(first)
219}
220
/// The point at which queued work is being claimed.
///
/// Not referenced elsewhere in this file. NOTE(review): presumably passed to
/// the claim operation to select which batches are eligible — confirm against
/// the queue implementation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum QueuedWorkClaimBoundary {
    ActiveTurnCheckpoint,
    Idle,
}
227
/// Completion record for a claim, built by `QueuedWorkClaim::completion`.
///
/// Carries the claim's session id, claim id and lease token, plus the ids of
/// the claimed batches in claim order.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct QueuedWorkCompletion {
    pub session_id: String,
    pub claim_id: String,
    pub lease_token: String,
    pub batch_ids: Vec<String>,
}
235
/// A claim over one or more queued batches for a single session.
///
/// The claim is identified by `claim_id` and held by `owner` under
/// `lease_token`. Lease enforcement happens outside this file.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct QueuedWorkClaim {
    pub session_id: String,
    pub claim_id: String,
    pub owner: crate::LeaseOwnerIdentity,
    pub lease_token: String,
    /// NOTE(review): presumably a fencing token used to reject stale lease
    /// holders — confirm against the lease store.
    pub fencing_token: u64,
    /// Assumed to be milliseconds since the Unix epoch (per the name) — TODO
    /// confirm.
    pub claimed_at_epoch_ms: u64,
    /// Assumed to be milliseconds since the Unix epoch (per the name) — TODO
    /// confirm.
    pub expires_at_epoch_ms: u64,
    /// Claimed batches, in the order the claim returned them.
    pub batches: Vec<QueuedWorkBatch>,
}
247
248impl QueuedWorkClaim {
249 pub fn completion(&self) -> QueuedWorkCompletion {
250 QueuedWorkCompletion {
251 session_id: self.session_id.clone(),
252 claim_id: self.claim_id.clone(),
253 lease_token: self.lease_token.clone(),
254 batch_ids: self
255 .batches
256 .iter()
257 .map(|batch| batch.batch_id.clone())
258 .collect(),
259 }
260 }
261
262 pub fn is_empty(&self) -> bool {
263 self.batches.iter().all(|batch| batch.items.is_empty())
264 }
265
266 pub fn materialize_for_checkpoint(&self) -> QueuedCheckpointWork {
267 let messages = Vec::new();
268 let transient_messages = Vec::new();
269 let mut turn_causes = Vec::new();
270 for batch in &self.batches {
271 for item in &batch.items {
272 match &item.payload {
273 QueuedWorkPayload::ProcessWake { wake } => {
274 turn_causes.push(crate::process_wake_turn_cause(wake));
275 }
276 QueuedWorkPayload::SessionCommand { .. } => {}
277 }
278 }
279 }
280 QueuedCheckpointWork {
281 messages,
282 transient_messages,
283 turn_causes,
284 }
285 }
286
287 pub async fn materialize_for_checkpoint_with_attachments(
288 &self,
289 _attachment_store: &dyn crate::AttachmentStore,
290 ) -> Result<QueuedCheckpointWork, String> {
291 let messages = Vec::new();
292 let transient_messages = Vec::new();
293 let mut turn_causes = Vec::new();
294 for batch in &self.batches {
295 for item in &batch.items {
296 match &item.payload {
297 QueuedWorkPayload::ProcessWake { wake } => {
298 turn_causes.push(crate::process_wake_turn_cause(wake));
299 }
300 QueuedWorkPayload::SessionCommand { .. } => {}
301 }
302 }
303 }
304 Ok(QueuedCheckpointWork {
305 messages,
306 transient_messages,
307 turn_causes,
308 })
309 }
310
311 pub fn exclusive_session_command(&self) -> Option<(&QueuedWorkBatch, &SessionCommand)> {
312 if self.batches.len() != 1 {
313 return None;
314 }
315 let batch = self.batches.first()?;
316 if batch.slot_policy != SlotPolicy::Exclusive || batch.items.len() != 1 {
317 return None;
318 }
319 let item = batch.items.first()?;
320 match &item.payload {
321 QueuedWorkPayload::SessionCommand { command } => Some((batch, command.as_ref())),
322 _ => None,
323 }
324 }
325
326 pub fn materialize_for_turn(&self) -> QueuedTurnWork {
327 let checkpoint = self.materialize_for_checkpoint();
328 QueuedTurnWork {
329 input: TurnInput::empty(),
330 messages: checkpoint.messages,
331 turn_causes: checkpoint.turn_causes,
332 }
333 }
334}
335
/// Work materialized from a claim for a checkpoint, as produced by
/// `QueuedWorkClaim::materialize_for_checkpoint`.
///
/// In this file only `turn_causes` is ever populated; both message lists are
/// left empty.
#[derive(Clone, Debug, Default)]
pub struct QueuedCheckpointWork {
    pub messages: Vec<PluginMessage>,
    pub transient_messages: Vec<PluginMessage>,
    pub turn_causes: Vec<TurnCause>,
}
342
/// Work materialized from a claim to start a turn, as produced by
/// `QueuedWorkClaim::materialize_for_turn`.
///
/// That method always sets `input` to `TurnInput::empty()`.
#[derive(Clone, Debug)]
pub struct QueuedTurnWork {
    pub input: TurnInput,
    pub messages: Vec<PluginMessage>,
    pub turn_causes: Vec<TurnCause>,
}
349
350pub fn process_wake_batch_draft(wake: ProcessWakeDelivery) -> QueuedWorkBatchDraft {
351 let source_key = format!("process:{}:event:{}:wake", wake.process_id, wake.sequence);
352 QueuedWorkBatchDraft::new(
353 wake.target_session_id.clone(),
354 DeliveryPolicy::EarliestSafeBoundary,
355 SlotPolicy::Exclusive,
356 vec![QueuedWorkPayload::process_wake(wake)],
357 )
358 .with_source_key(source_key)
359}