use super::process::ProcessWakeDelivery;
use crate::{PluginMessage, TurnCause, TurnInput};
/// A control command addressed to a session, queued as
/// [`QueuedWorkPayload::SessionCommand`].
///
/// Serialized as an internally tagged object: the variant name is written to a
/// `kind` field in snake_case, e.g. `{"kind":"reset_session","reason":"..."}`.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum SessionCommand {
    /// Presumably asks the session to rebuild its tool catalog — confirm with the handler.
    /// `reason` is an arbitrary string; nothing in this file validates it.
    RefreshToolCatalog { reason: String },
    /// Presumably asks the session to reset its state — confirm with the handler.
    /// `reason` is an arbitrary string; nothing in this file validates it.
    ResetSession { reason: String },
}
impl SessionCommand {
pub fn kind(&self) -> &'static str {
match self {
Self::RefreshToolCatalog { .. } => "refresh_tool_catalog",
Self::ResetSession { .. } => "reset_session",
}
}
pub fn source_key(&self, idempotency_key: impl AsRef<str>) -> String {
format!("command:{}:{}", self.kind(), idempotency_key.as_ref())
}
}
/// Receipt describing where a session command was queued.
///
/// NOTE(review): the producer is not visible in this file; presumably returned by the
/// enqueue path once the command's batch is stored — confirm.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct SessionCommandReceipt {
    /// Session the command was queued for.
    pub session_id: String,
    /// Presumably the id of the [`QueuedWorkBatch`] holding the command — confirm.
    pub batch_id: String,
    /// Presumably the value produced by [`SessionCommand::source_key`] — confirm.
    pub source_key: String,
}
/// Presumably controls when a queued batch may be delivered relative to the session's
/// current turn (inferred from the variant names — confirm).
///
/// Serialized in snake_case. [`DeliveryPolicy::as_str`] and
/// [`DeliveryPolicy::from_wire_str`] use the same strings.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DeliveryPolicy {
    /// Deliver at the earliest safe boundary. `process_wake_batch_draft` uses this policy.
    EarliestSafeBoundary,
    /// Presumably held until the in-progress turn has committed — confirm.
    AfterCurrentTurnCommit,
}
impl DeliveryPolicy {
pub fn as_str(self) -> &'static str {
match self {
Self::EarliestSafeBoundary => "earliest_safe_boundary",
Self::AfterCurrentTurnCommit => "after_current_turn_commit",
}
}
pub fn from_wire_str(value: &str) -> Option<Self> {
match value {
"earliest_safe_boundary" => Some(Self::EarliestSafeBoundary),
"after_current_turn_commit" => Some(Self::AfterCurrentTurnCommit),
_ => None,
}
}
}
/// Presumably controls whether a batch may be claimed together with other batches
/// (`Join`) or only on its own (`Exclusive`) — confirm against the claim logic.
///
/// `QueuedWorkClaim::exclusive_session_command` only returns a command from a claim
/// holding a single `Exclusive` batch. Serialized in snake_case; [`SlotPolicy::as_str`]
/// and [`SlotPolicy::from_wire_str`] use the same strings.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SlotPolicy {
    /// May share a claim with other batches (presumably — confirm).
    Join,
    /// Must be claimed alone (presumably — confirm). Used by `process_wake_batch_draft`.
    Exclusive,
}
impl SlotPolicy {
pub fn as_str(self) -> &'static str {
match self {
Self::Join => "join",
Self::Exclusive => "exclusive",
}
}
pub fn from_wire_str(value: &str) -> Option<Self> {
match value {
"join" => Some(Self::Join),
"exclusive" => Some(Self::Exclusive),
_ => None,
}
}
}
/// Presumably decides which queued batches may be merged together — confirm.
///
/// `QueuedWorkBatchDraft::new` defaults this to `Never`. Serialized in snake_case:
/// unit variants as plain strings (`"never"`, `"payload_default"`), and `Group` as
/// `{"group": "<name>"}`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MergeKey {
    /// Presumably never merged with other batches.
    Never,
    /// Presumably merged according to a key derived from the payload — confirm.
    PayloadDefault,
    /// Presumably merged with batches carrying the same group name — confirm.
    Group(String),
}
/// The work carried by a single [`QueuedWorkItem`].
///
/// Serialized as an internally tagged object with a `type` field in snake_case
/// (`"process_wake"` / `"session_command"`). Both payloads are boxed, presumably to keep
/// the enum small — confirm.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum QueuedWorkPayload {
    /// A process wake; classified as `QueuedWorkClass::TurnWork`.
    ProcessWake { wake: Box<ProcessWakeDelivery> },
    /// A session command; classified as `QueuedWorkClass::SessionCommand`.
    SessionCommand { command: Box<SessionCommand> },
}
/// Coarse classification of a payload, as returned by `QueuedWorkPayload::work_class`.
///
/// A batch or draft has a class only when all of its payloads share one.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum QueuedWorkClass {
    /// Class of `QueuedWorkPayload::SessionCommand`.
    SessionCommand,
    /// Class of `QueuedWorkPayload::ProcessWake`.
    TurnWork,
}
impl QueuedWorkPayload {
pub fn process_wake(wake: ProcessWakeDelivery) -> Self {
Self::ProcessWake {
wake: Box::new(wake),
}
}
pub fn session_command(command: SessionCommand) -> Self {
Self::SessionCommand {
command: Box::new(command),
}
}
pub fn work_class(&self) -> QueuedWorkClass {
match self {
Self::SessionCommand { .. } => QueuedWorkClass::SessionCommand,
Self::ProcessWake { .. } => QueuedWorkClass::TurnWork,
}
}
}
/// One payload inside a [`QueuedWorkBatch`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct QueuedWorkItem {
    /// Identifier of this item. Item ids are not assigned anywhere in this file.
    pub item_id: String,
    /// The work this item carries.
    pub payload: QueuedWorkPayload,
}
/// A stored group of queued items sharing the same delivery, slot, and merge policies.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct QueuedWorkBatch {
    /// Identifier of this batch; listed in `QueuedWorkCompletion::batch_ids` when completed.
    pub batch_id: String,
    /// Session whose queue holds this batch.
    pub session_id: String,
    /// Presumably a monotonically increasing enqueue sequence used for ordering — confirm.
    pub enqueue_seq: u64,
    /// Optional source key (e.g. `process:<id>:event:<seq>:wake`), presumably used for
    /// deduplication. Omitted from serialized output when `None`, and defaults to `None`
    /// when absent on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source_key: Option<String>,
    /// When the batch may be delivered relative to the current turn.
    pub delivery_policy: DeliveryPolicy,
    /// Whether the batch may share a claim with other batches.
    pub slot_policy: SlotPolicy,
    /// Which other batches this one may merge with.
    pub merge_key: MergeKey,
    /// Presumably epoch milliseconds before which the batch is not claimable — confirm.
    /// Drafts default this to `0`.
    pub available_at_ms: u64,
    /// Presumably epoch milliseconds at which the batch was enqueued — confirm.
    pub enqueued_at_ms: u64,
    /// Items of the batch; `is_turn_work` / `is_session_command_work` require them all to
    /// share one class.
    pub items: Vec<QueuedWorkItem>,
}
impl QueuedWorkBatch {
    /// Returns the class shared by every item's payload.
    ///
    /// `None` when the batch has no items or its items mix classes.
    pub fn work_class(&self) -> Option<QueuedWorkClass> {
        let payloads = self.items.iter().map(|item| &item.payload);
        work_class_for_payloads(payloads)
    }

    /// `true` only when every item is a session command (and there is at least one).
    pub fn is_session_command_work(&self) -> bool {
        matches!(self.work_class(), Some(QueuedWorkClass::SessionCommand))
    }

    /// `true` only when every item is turn work (and there is at least one).
    pub fn is_turn_work(&self) -> bool {
        matches!(self.work_class(), Some(QueuedWorkClass::TurnWork))
    }
}
/// A batch that has not been enqueued yet.
///
/// Unlike [`QueuedWorkBatch`], a draft has no `batch_id`, `enqueue_seq`,
/// `enqueued_at_ms`, or per-item ids. Those are presumably assigned by the queue store on
/// enqueue — confirm. Build one with [`QueuedWorkBatchDraft::new`] and the `with_*`
/// methods.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct QueuedWorkBatchDraft {
    /// Session whose queue the batch is destined for.
    pub session_id: String,
    /// Optional source key. Omitted from serialized output when `None`, and defaults to
    /// `None` when absent on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source_key: Option<String>,
    /// When the batch may be delivered relative to the current turn.
    pub delivery_policy: DeliveryPolicy,
    /// Whether the batch may share a claim with other batches.
    pub slot_policy: SlotPolicy,
    /// Merge behavior; `new` sets `MergeKey::Never`.
    pub merge_key: MergeKey,
    /// Availability time (presumably epoch ms — confirm); `new` sets `0`.
    pub available_at_ms: u64,
    /// Payloads that will become the batch's items.
    pub payloads: Vec<QueuedWorkPayload>,
}
impl QueuedWorkBatchDraft {
    /// Creates a draft for `session_id` with the given policies and payloads.
    ///
    /// Defaults: no source key, [`MergeKey::Never`], and `available_at_ms == 0`.
    /// Override them with the `with_*` builder methods.
    #[must_use]
    pub fn new(
        session_id: impl Into<String>,
        delivery_policy: DeliveryPolicy,
        slot_policy: SlotPolicy,
        payloads: impl Into<Vec<QueuedWorkPayload>>,
    ) -> Self {
        Self {
            session_id: session_id.into(),
            source_key: None,
            delivery_policy,
            slot_policy,
            merge_key: MergeKey::Never,
            available_at_ms: 0,
            payloads: payloads.into(),
        }
    }

    /// Sets the source key (e.g. `process:<id>:event:<seq>:wake`).
    // The builders consume `self`, so ignoring the return value silently drops the whole
    // draft along with the new setting; `#[must_use]` turns that mistake into a warning.
    #[must_use = "builder methods return the updated draft; the original is consumed"]
    pub fn with_source_key(mut self, source_key: impl Into<String>) -> Self {
        self.source_key = Some(source_key.into());
        self
    }

    /// Sets the availability time (presumably epoch milliseconds — confirm).
    #[must_use = "builder methods return the updated draft; the original is consumed"]
    pub fn with_available_at_ms(mut self, available_at_ms: u64) -> Self {
        self.available_at_ms = available_at_ms;
        self
    }

    /// Replaces the default [`MergeKey::Never`] merge key.
    #[must_use = "builder methods return the updated draft; the original is consumed"]
    pub fn with_merge_key(mut self, merge_key: MergeKey) -> Self {
        self.merge_key = merge_key;
        self
    }

    /// Returns the class shared by all payloads.
    ///
    /// `None` when there are no payloads or they mix classes.
    pub fn work_class(&self) -> Option<QueuedWorkClass> {
        work_class_for_payloads(self.payloads.iter())
    }
}
fn work_class_for_payloads<'a>(
payloads: impl IntoIterator<Item = &'a QueuedWorkPayload>,
) -> Option<QueuedWorkClass> {
let mut payloads = payloads.into_iter();
let first = payloads.next()?.work_class();
payloads
.all(|payload| payload.work_class() == first)
.then_some(first)
}
/// The point at which queued work is being claimed.
///
/// NOTE(review): not used in this file; the semantics below are inferred from the variant
/// names — confirm.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum QueuedWorkClaimBoundary {
    /// Presumably a checkpoint inside an in-progress turn.
    ActiveTurnCheckpoint,
    /// Presumably a moment when the session has no active turn.
    Idle,
}
/// Completion record for a claim, built by `QueuedWorkClaim::completion`.
///
/// Presumably submitted back to the queue to mark the listed batches as done — confirm.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct QueuedWorkCompletion {
    /// Session of the originating claim.
    pub session_id: String,
    /// Id of the originating claim.
    pub claim_id: String,
    /// Lease token copied from the claim.
    pub lease_token: String,
    /// Ids of every batch in the claim, in claim order.
    pub batch_ids: Vec<String>,
}
/// A lease-protected set of queued batches claimed for one session.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct QueuedWorkClaim {
    /// Session whose queue the batches were claimed from.
    pub session_id: String,
    /// Identifier of this claim; copied into the completion record.
    pub claim_id: String,
    /// Identity of the lease holder.
    pub owner: crate::LeaseOwnerIdentity,
    /// Lease token; copied into the completion record (presumably verified by the store
    /// on completion — confirm).
    pub lease_token: String,
    /// Presumably a monotonically increasing token used to reject stale lease holders —
    /// confirm.
    pub fencing_token: u64,
    /// Claim time, in epoch milliseconds per the field name.
    pub claimed_at_epoch_ms: u64,
    /// Lease expiry, in epoch milliseconds per the field name.
    pub expires_at_epoch_ms: u64,
    /// Batches covered by this claim.
    pub batches: Vec<QueuedWorkBatch>,
}
impl QueuedWorkClaim {
    /// Builds the completion record for this claim.
    ///
    /// Copies the session id, claim id, and lease token, and lists the id of every claimed
    /// batch in claim order.
    pub fn completion(&self) -> QueuedWorkCompletion {
        QueuedWorkCompletion {
            session_id: self.session_id.clone(),
            claim_id: self.claim_id.clone(),
            lease_token: self.lease_token.clone(),
            batch_ids: self
                .batches
                .iter()
                .map(|batch| batch.batch_id.clone())
                .collect(),
        }
    }

    /// Returns `true` when no claimed batch carries any item.
    ///
    /// A claim with no batches at all is also empty.
    pub fn is_empty(&self) -> bool {
        self.batches.iter().all(|batch| batch.items.is_empty())
    }

    /// Turns the claimed items into work for an active-turn checkpoint.
    ///
    /// Each `ProcessWake` item contributes one turn cause, in batch-then-item order.
    /// `SessionCommand` items contribute nothing here;
    /// [`Self::exclusive_session_command`] is the accessor that exposes them.
    /// `messages` and `transient_messages` are always empty.
    pub fn materialize_for_checkpoint(&self) -> QueuedCheckpointWork {
        let turn_causes = self
            .batches
            .iter()
            .flat_map(|batch| &batch.items)
            .filter_map(|item| match &item.payload {
                QueuedWorkPayload::ProcessWake { wake } => {
                    Some(crate::process_wake_turn_cause(wake))
                }
                QueuedWorkPayload::SessionCommand { .. } => None,
            })
            .collect();
        QueuedCheckpointWork {
            messages: Vec::new(),
            transient_messages: Vec::new(),
            turn_causes,
        }
    }

    /// Async counterpart of [`Self::materialize_for_checkpoint`] that takes an attachment
    /// store.
    ///
    /// The store is currently unused, so this simply delegates to the synchronous version.
    /// Delegating instead of duplicating the materialization loop keeps the two entry
    /// points from drifting apart.
    ///
    /// # Errors
    ///
    /// Never returns `Err` today. The `String` error is presumably reserved for attachment
    /// loading failures — confirm.
    pub async fn materialize_for_checkpoint_with_attachments(
        &self,
        _attachment_store: &dyn crate::AttachmentStore,
    ) -> Result<QueuedCheckpointWork, String> {
        Ok(self.materialize_for_checkpoint())
    }

    /// Returns the claim's session command when the claim is a single `Exclusive` batch
    /// holding a single `SessionCommand` item; otherwise `None`.
    pub fn exclusive_session_command(&self) -> Option<(&QueuedWorkBatch, &SessionCommand)> {
        let batch = match self.batches.as_slice() {
            [batch] if batch.slot_policy == SlotPolicy::Exclusive => batch,
            _ => return None,
        };
        let item = match batch.items.as_slice() {
            [item] => item,
            _ => return None,
        };
        // Exhaustive on purpose: a new payload variant must decide here whether it can be an
        // exclusive command rather than silently falling into a catch-all arm.
        match &item.payload {
            QueuedWorkPayload::SessionCommand { command } => Some((batch, command.as_ref())),
            QueuedWorkPayload::ProcessWake { .. } => None,
        }
    }

    /// Materializes the claim as work for starting a new turn.
    ///
    /// Uses `TurnInput::empty()` as input and carries over the checkpoint's messages and
    /// turn causes. The checkpoint's `transient_messages` are not carried over.
    pub fn materialize_for_turn(&self) -> QueuedTurnWork {
        let QueuedCheckpointWork {
            messages,
            turn_causes,
            ..
        } = self.materialize_for_checkpoint();
        QueuedTurnWork {
            input: TurnInput::empty(),
            messages,
            turn_causes,
        }
    }
}
/// Work produced by `QueuedWorkClaim::materialize_for_checkpoint`.
#[derive(Clone, Debug, Default)]
pub struct QueuedCheckpointWork {
    /// Messages to inject (presumably into the conversation — confirm). Always empty when
    /// built from a claim.
    pub messages: Vec<PluginMessage>,
    /// Presumably messages visible for the current step only and not persisted — confirm.
    /// Always empty when built from a claim.
    pub transient_messages: Vec<PluginMessage>,
    /// One turn cause per process-wake item in the claim.
    pub turn_causes: Vec<TurnCause>,
}
/// Work for starting a new turn, produced by `QueuedWorkClaim::materialize_for_turn`.
#[derive(Clone, Debug)]
pub struct QueuedTurnWork {
    /// Turn input; `materialize_for_turn` always uses `TurnInput::empty()`.
    pub input: TurnInput,
    /// Messages taken from the checkpoint materialization.
    pub messages: Vec<PluginMessage>,
    /// Turn causes taken from the checkpoint materialization.
    pub turn_causes: Vec<TurnCause>,
}
/// Builds the draft batch that enqueues a single process wake.
///
/// The draft targets `wake.target_session_id` and uses
/// `DeliveryPolicy::EarliestSafeBoundary` with `SlotPolicy::Exclusive`. Its source key is
/// `process:<process_id>:event:<sequence>:wake`, derived solely from the wake, so the same
/// process event always yields the same key.
pub fn process_wake_batch_draft(wake: ProcessWakeDelivery) -> QueuedWorkBatchDraft {
    let session_id = wake.target_session_id.clone();
    let event_key = format!("process:{}:event:{}:wake", wake.process_id, wake.sequence);
    let payloads = vec![QueuedWorkPayload::process_wake(wake)];
    QueuedWorkBatchDraft::new(
        session_id,
        DeliveryPolicy::EarliestSafeBoundary,
        SlotPolicy::Exclusive,
        payloads,
    )
    .with_source_key(event_key)
}