// lash_core/runtime/turn_queue.rs

1use super::process::ProcessWakeDelivery;
2use crate::{PluginMessage, TurnCause, TurnInput};
3
/// Time-to-live of a queued-work claim, in milliseconds: 15 min × 60 s × 1000 ms.
pub const QUEUED_WORK_CLAIM_TTL_MS: u64 = 900_000;
5
6#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
7#[serde(tag = "kind", rename_all = "snake_case")]
8pub enum SessionCommand {
9    RefreshToolSurface {
10        reason: String,
11        #[serde(default, skip_serializing_if = "Option::is_none")]
12        expected_generation: Option<u64>,
13    },
14    EmitHostEvent {
15        resource_type: String,
16        alias: String,
17        event: String,
18        #[serde(default)]
19        payload: serde_json::Value,
20    },
21    ResetSession {
22        reason: String,
23    },
24}
25
26impl SessionCommand {
27    pub fn kind(&self) -> &'static str {
28        match self {
29            Self::RefreshToolSurface { .. } => "refresh_tool_surface",
30            Self::EmitHostEvent { .. } => "emit_host_event",
31            Self::ResetSession { .. } => "reset_session",
32        }
33    }
34
35    pub fn source_key(&self, idempotency_key: impl AsRef<str>) -> String {
36        format!("command:{}:{}", self.kind(), idempotency_key.as_ref())
37    }
38}
39
40#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
41pub struct SessionCommandReceipt {
42    pub session_id: String,
43    pub batch_id: String,
44    pub source_key: String,
45}
46
47#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
48#[serde(rename_all = "snake_case")]
49pub enum DeliveryPolicy {
50    EarliestSafeBoundary,
51    AfterCurrentTurnCommit,
52}
53
54impl DeliveryPolicy {
55    pub fn as_str(self) -> &'static str {
56        match self {
57            Self::EarliestSafeBoundary => "earliest_safe_boundary",
58            Self::AfterCurrentTurnCommit => "after_current_turn_commit",
59        }
60    }
61
62    pub fn from_wire_str(value: &str) -> Option<Self> {
63        match value {
64            "earliest_safe_boundary" => Some(Self::EarliestSafeBoundary),
65            "after_current_turn_commit" => Some(Self::AfterCurrentTurnCommit),
66            _ => None,
67        }
68    }
69}
70
71#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
72#[serde(rename_all = "snake_case")]
73pub enum SlotPolicy {
74    Join,
75    Exclusive,
76}
77
78impl SlotPolicy {
79    pub fn as_str(self) -> &'static str {
80        match self {
81            Self::Join => "join",
82            Self::Exclusive => "exclusive",
83        }
84    }
85
86    pub fn from_wire_str(value: &str) -> Option<Self> {
87        match value {
88            "join" => Some(Self::Join),
89            "exclusive" => Some(Self::Exclusive),
90            _ => None,
91        }
92    }
93}
94
95#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
96#[serde(rename_all = "snake_case")]
97pub enum MergeKey {
98    Never,
99    PayloadDefault,
100    Group(String),
101}
102
103#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
104#[serde(tag = "type", rename_all = "snake_case")]
105pub enum QueuedWorkPayload {
106    TurnInput {
107        input: Box<TurnInput>,
108    },
109    ProcessWake {
110        wake: Box<ProcessWakeDelivery>,
111    },
112    HostEvent {
113        name: String,
114        #[serde(default)]
115        payload: serde_json::Value,
116    },
117    Timer {
118        name: String,
119        #[serde(default)]
120        payload: serde_json::Value,
121    },
122    SessionCommand {
123        command: Box<SessionCommand>,
124    },
125}
126
127impl QueuedWorkPayload {
128    pub fn turn_input(input: TurnInput) -> Self {
129        Self::TurnInput {
130            input: Box::new(input),
131        }
132    }
133
134    pub fn process_wake(wake: ProcessWakeDelivery) -> Self {
135        Self::ProcessWake {
136            wake: Box::new(wake),
137        }
138    }
139
140    pub fn session_command(command: SessionCommand) -> Self {
141        Self::SessionCommand {
142            command: Box::new(command),
143        }
144    }
145}
146
147#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
148pub struct QueuedWorkItem {
149    pub item_id: String,
150    pub payload: QueuedWorkPayload,
151}
152
153#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
154pub struct QueuedWorkBatch {
155    pub batch_id: String,
156    pub session_id: String,
157    pub enqueue_seq: u64,
158    #[serde(default, skip_serializing_if = "Option::is_none")]
159    pub source_key: Option<String>,
160    pub delivery_policy: DeliveryPolicy,
161    pub slot_policy: SlotPolicy,
162    pub merge_key: MergeKey,
163    pub available_at_ms: u64,
164    pub enqueued_at_ms: u64,
165    pub items: Vec<QueuedWorkItem>,
166}
167
168#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
169pub struct QueuedWorkBatchDraft {
170    pub session_id: String,
171    #[serde(default, skip_serializing_if = "Option::is_none")]
172    pub source_key: Option<String>,
173    pub delivery_policy: DeliveryPolicy,
174    pub slot_policy: SlotPolicy,
175    pub merge_key: MergeKey,
176    pub available_at_ms: u64,
177    pub payloads: Vec<QueuedWorkPayload>,
178}
179
180impl QueuedWorkBatchDraft {
181    pub fn new(
182        session_id: impl Into<String>,
183        delivery_policy: DeliveryPolicy,
184        slot_policy: SlotPolicy,
185        payloads: impl Into<Vec<QueuedWorkPayload>>,
186    ) -> Self {
187        Self {
188            session_id: session_id.into(),
189            source_key: None,
190            delivery_policy,
191            slot_policy,
192            merge_key: MergeKey::Never,
193            available_at_ms: 0,
194            payloads: payloads.into(),
195        }
196    }
197
198    pub fn with_source_key(mut self, source_key: impl Into<String>) -> Self {
199        self.source_key = Some(source_key.into());
200        self
201    }
202
203    pub fn with_available_at_ms(mut self, available_at_ms: u64) -> Self {
204        self.available_at_ms = available_at_ms;
205        self
206    }
207
208    pub fn with_merge_key(mut self, merge_key: MergeKey) -> Self {
209        self.merge_key = merge_key;
210        self
211    }
212}
213
214#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
215#[serde(rename_all = "snake_case")]
216pub enum QueuedWorkClaimBoundary {
217    ActiveTurnCheckpoint,
218    Idle,
219}
220
221#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
222pub struct QueuedWorkCompletion {
223    pub session_id: String,
224    pub claim_id: String,
225    pub lease_token: String,
226    pub batch_ids: Vec<String>,
227}
228
229#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
230pub struct QueuedWorkClaim {
231    pub session_id: String,
232    pub claim_id: String,
233    pub owner_id: String,
234    pub lease_token: String,
235    pub fencing_token: u64,
236    pub claimed_at_epoch_ms: u64,
237    pub expires_at_epoch_ms: u64,
238    pub batches: Vec<QueuedWorkBatch>,
239}
240
241impl QueuedWorkClaim {
242    pub fn completion(&self) -> QueuedWorkCompletion {
243        QueuedWorkCompletion {
244            session_id: self.session_id.clone(),
245            claim_id: self.claim_id.clone(),
246            lease_token: self.lease_token.clone(),
247            batch_ids: self
248                .batches
249                .iter()
250                .map(|batch| batch.batch_id.clone())
251                .collect(),
252        }
253    }
254
255    pub fn is_empty(&self) -> bool {
256        self.batches.iter().all(|batch| batch.items.is_empty())
257    }
258
259    pub fn materialize_for_checkpoint(&self) -> QueuedCheckpointWork {
260        let messages = Vec::new();
261        let mut transient_messages = Vec::new();
262        let mut turn_causes = Vec::new();
263        for batch in &self.batches {
264            for item in &batch.items {
265                match &item.payload {
266                    QueuedWorkPayload::TurnInput { input } => {
267                        if let Some(message) = plugin_message_from_turn_input(input) {
268                            transient_messages.push(message);
269                        }
270                    }
271                    QueuedWorkPayload::ProcessWake { wake } => {
272                        turn_causes.push(crate::process_wake_turn_cause(wake));
273                    }
274                    QueuedWorkPayload::HostEvent { name, payload }
275                    | QueuedWorkPayload::Timer { name, payload } => {
276                        turn_causes.push(host_event_cause(
277                            &item.item_id,
278                            name,
279                            payload,
280                            matches!(&item.payload, QueuedWorkPayload::Timer { .. }),
281                        ));
282                    }
283                    QueuedWorkPayload::SessionCommand { .. } => {}
284                }
285            }
286        }
287        QueuedCheckpointWork {
288            messages,
289            transient_messages,
290            turn_causes,
291        }
292    }
293
294    pub fn materialize_for_checkpoint_with_attachments(
295        &self,
296        attachment_store: &dyn crate::AttachmentStore,
297    ) -> Result<QueuedCheckpointWork, String> {
298        let messages = Vec::new();
299        let mut transient_messages = Vec::new();
300        let mut turn_causes = Vec::new();
301        for batch in &self.batches {
302            for item in &batch.items {
303                match &item.payload {
304                    QueuedWorkPayload::TurnInput { input } => {
305                        if let Some(message) = plugin_message_from_turn_input_with_attachments(
306                            input,
307                            attachment_store,
308                        )? {
309                            transient_messages.push(message);
310                        }
311                    }
312                    QueuedWorkPayload::ProcessWake { wake } => {
313                        turn_causes.push(crate::process_wake_turn_cause(wake));
314                    }
315                    QueuedWorkPayload::HostEvent { name, payload }
316                    | QueuedWorkPayload::Timer { name, payload } => {
317                        turn_causes.push(host_event_cause(
318                            &item.item_id,
319                            name,
320                            payload,
321                            matches!(&item.payload, QueuedWorkPayload::Timer { .. }),
322                        ));
323                    }
324                    QueuedWorkPayload::SessionCommand { .. } => {}
325                }
326            }
327        }
328        Ok(QueuedCheckpointWork {
329            messages,
330            transient_messages,
331            turn_causes,
332        })
333    }
334
335    pub fn accepted_turn_inputs(&self) -> Vec<crate::AcceptedInjectedTurnInput> {
336        let mut accepted = Vec::new();
337        for batch in &self.batches {
338            let id = batch.source_key.as_deref().map(|source| {
339                source
340                    .strip_prefix("host:")
341                    .or_else(|| source.strip_prefix("injection:"))
342                    .unwrap_or(source)
343                    .to_string()
344            });
345            for item in &batch.items {
346                if let QueuedWorkPayload::TurnInput { input } = &item.payload
347                    && let Some(message) = plugin_message_from_turn_input(input)
348                {
349                    accepted.push(crate::AcceptedInjectedTurnInput {
350                        id: id.clone(),
351                        message,
352                    });
353                }
354            }
355        }
356        accepted
357    }
358
359    pub fn exclusive_session_command(&self) -> Option<(&QueuedWorkBatch, &SessionCommand)> {
360        if self.batches.len() != 1 {
361            return None;
362        }
363        let batch = self.batches.first()?;
364        if batch.slot_policy != SlotPolicy::Exclusive || batch.items.len() != 1 {
365            return None;
366        }
367        let item = batch.items.first()?;
368        match &item.payload {
369            QueuedWorkPayload::SessionCommand { command } => Some((batch, command.as_ref())),
370            _ => None,
371        }
372    }
373
374    pub fn materialize_for_turn(&self) -> QueuedTurnWork {
375        let checkpoint = self.materialize_for_checkpoint();
376        let mut input_items = Vec::new();
377        let mut image_blobs = std::collections::HashMap::new();
378        let mut protocol_turn_options = None;
379        let mut trace_turn_id = None;
380        for batch in &self.batches {
381            for item in &batch.items {
382                if let QueuedWorkPayload::TurnInput { input } = &item.payload {
383                    input_items.extend(input.items.clone());
384                    image_blobs.extend(input.image_blobs.clone());
385                    if protocol_turn_options.is_none() {
386                        protocol_turn_options = input.protocol_turn_options.clone();
387                    }
388                    if trace_turn_id.is_none() {
389                        trace_turn_id = input.trace_turn_id.clone();
390                    }
391                }
392            }
393        }
394        QueuedTurnWork {
395            input: TurnInput {
396                items: input_items,
397                image_blobs,
398                protocol_turn_options,
399                trace_turn_id,
400                protocol_extension: None,
401                turn_context: crate::TurnContext::default(),
402            },
403            messages: checkpoint.messages,
404            turn_causes: checkpoint.turn_causes,
405        }
406    }
407}
408
409#[derive(Clone, Debug, Default)]
410pub struct QueuedCheckpointWork {
411    pub messages: Vec<PluginMessage>,
412    pub transient_messages: Vec<PluginMessage>,
413    pub turn_causes: Vec<TurnCause>,
414}
415
416#[derive(Clone, Debug)]
417pub struct QueuedTurnWork {
418    pub input: TurnInput,
419    pub messages: Vec<PluginMessage>,
420    pub turn_causes: Vec<TurnCause>,
421}
422
423pub fn process_wake_batch_draft(wake: ProcessWakeDelivery) -> QueuedWorkBatchDraft {
424    let source_key = format!("process:{}:event:{}:wake", wake.process_id, wake.sequence);
425    QueuedWorkBatchDraft::new(
426        wake.target_session_id.clone(),
427        DeliveryPolicy::EarliestSafeBoundary,
428        SlotPolicy::Exclusive,
429        vec![QueuedWorkPayload::process_wake(wake)],
430    )
431    .with_source_key(source_key)
432}
433
434fn plugin_message_from_turn_input(input: &TurnInput) -> Option<PluginMessage> {
435    let mut text = Vec::new();
436    let mut images = Vec::new();
437    for item in &input.items {
438        match item {
439            crate::InputItem::Text { text: item_text } if !item_text.is_empty() => {
440                text.push(item_text.clone());
441            }
442            crate::InputItem::Text { .. } => {}
443            crate::InputItem::ImageRef { id } => {
444                if let Some(bytes) = input.image_blobs.get(id).cloned() {
445                    images.push(bytes);
446                }
447            }
448        }
449    }
450    if text.is_empty() && images.is_empty() {
451        return None;
452    }
453    Some(PluginMessage {
454        role: crate::MessageRole::User,
455        content: text.join("\n"),
456        origin: None,
457        parts: Vec::new(),
458        images,
459    })
460}
461
462fn plugin_message_from_turn_input_with_attachments(
463    input: &TurnInput,
464    attachment_store: &dyn crate::AttachmentStore,
465) -> Result<Option<PluginMessage>, String> {
466    let normalized =
467        super::io::normalize_input_items(&input.items, &input.image_blobs, attachment_store)?;
468    let has_image = normalized
469        .iter()
470        .any(|item| matches!(item, super::NormalizedItem::Image(_)));
471    if !has_image {
472        return Ok(plugin_message_from_turn_input(input));
473    }
474
475    let mut content = Vec::new();
476    let mut parts = Vec::new();
477    for item in normalized {
478        match item {
479            super::NormalizedItem::Text(text) if !text.is_empty() => {
480                let part_id = format!("queued.p{}", parts.len());
481                content.push(text.clone());
482                parts.push(crate::Part {
483                    id: part_id,
484                    kind: crate::PartKind::Text,
485                    content: text,
486                    attachment: None,
487                    tool_call_id: None,
488                    tool_name: None,
489                    tool_replay: None,
490                    prune_state: crate::PruneState::Intact,
491                    reasoning_meta: None,
492                    response_meta: None,
493                });
494            }
495            super::NormalizedItem::Text(_) => {}
496            super::NormalizedItem::Image(reference) => {
497                let part_id = format!("queued.p{}", parts.len());
498                parts.push(crate::Part {
499                    id: part_id,
500                    kind: crate::PartKind::Image,
501                    content: String::new(),
502                    attachment: Some(crate::session_model::message::PartAttachment { reference }),
503                    tool_call_id: None,
504                    tool_name: None,
505                    tool_replay: None,
506                    prune_state: crate::PruneState::Intact,
507                    reasoning_meta: None,
508                    response_meta: None,
509                });
510            }
511        }
512    }
513    if parts.is_empty() {
514        return Ok(None);
515    }
516    Ok(Some(PluginMessage {
517        role: crate::MessageRole::User,
518        content: content.join("\n"),
519        origin: None,
520        parts,
521        images: Vec::new(),
522    }))
523}
524
525fn host_event_cause(
526    item_id: &str,
527    name: &str,
528    payload: &serde_json::Value,
529    timer: bool,
530) -> TurnCause {
531    let event_type = if timer { "timer" } else { "host_event" };
532    TurnCause {
533        id: item_id.to_string(),
534        event_type: name.to_string(),
535        origin: crate::MessageOrigin::Plugin {
536            plugin_id: event_type.to_string(),
537            transient: false,
538        },
539        text: if payload.is_null() {
540            name.to_string()
541        } else {
542            format!("{name}\n{payload}")
543        },
544    }
545}