use super::lifecycle::{RunStatus, TerminationReason};
use super::message::Message;
use super::suspension::{ToolCallResume, ToolCallResumeMode};
use super::tool::ToolDescriptor;
use crate::state::PersistedState;
use crate::thread::Thread;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
mod error;
pub mod message_append;
pub use error::StorageError;
/// Identifies where the request that started a run came from.
///
/// No `rename_all` is applied, so serde writes each unit variant as its name
/// exactly as spelled here (`"User"`, `"Mcp"`, `"A2A"`, `"Internal"`).
/// `User` is the [`Default`] variant.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum RunRequestOrigin {
    /// Default origin; `default_run_origin` also yields this variant when a
    /// persisted snapshot omits `origin`.
    #[default]
    User,
    /// NOTE(review): presumably a request arriving over MCP — confirm with callers.
    Mcp,
    /// NOTE(review): presumably an agent-to-agent request — confirm with callers.
    A2A,
    /// NOTE(review): presumably a request raised by the runtime itself — confirm.
    Internal,
}
/// Snapshot of the request that started a run (stored as `RunRecord::request`).
///
/// Every field has a serde default, so partial payloads still deserialize.
/// When serializing, every field except `origin` is omitted while it is
/// empty, zero, or `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RunRequestSnapshot {
    /// Where the request came from; falls back to `default_run_origin()`
    /// (i.e. `RunRequestOrigin::User`) when absent. Always serialized.
    #[serde(default = "default_run_origin")]
    pub origin: RunRequestOrigin,
    /// Optional sender identifier.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub sender_id: Option<String>,
    /// Message IDs recorded as the request's input.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub input_message_ids: Vec<String>,
    /// Number of input messages; omitted from output when zero (`is_zero_u64`).
    #[serde(default, skip_serializing_if = "is_zero_u64")]
    pub input_message_count: u64,
    /// Arbitrary extra request data kept as raw JSON.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub request_extras: Option<Value>,
    /// Resume decisions for tool calls, one entry per call ID.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub decisions: Vec<RunResumeDecision>,
    /// Tool descriptors supplied with the request.
    /// NOTE(review): presumably tools executed by the frontend — confirm.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub frontend_tools: Vec<ToolDescriptor>,
    /// Optional parent thread identifier.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parent_thread_id: Option<String>,
    /// Optional transport-level request identifier.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub transport_request_id: Option<String>,
}
/// Serde fallback for `RunRequestSnapshot::origin` when the field is missing.
///
/// Delegates to the enum's `#[default]` variant (`User`) so the fallback and
/// `Default` cannot drift apart.
fn default_run_origin() -> RunRequestOrigin {
    RunRequestOrigin::default()
}
/// `skip_serializing_if` predicate: `true` when the counter is zero.
///
/// Takes a reference because serde passes the field by reference to
/// `skip_serializing_if` functions.
fn is_zero_u64(value: &u64) -> bool {
    matches!(*value, 0)
}
/// A resume decision for one tool call, identified by its call ID.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RunResumeDecision {
    /// ID of the tool call this decision applies to.
    pub call_id: String,
    /// How that tool call should be resumed.
    pub resume: ToolCallResume,
}
/// Inclusive range `from_seq..=to_seq` of message sequence numbers.
///
/// The fields are public and the derived `Deserialize` does not validate them,
/// so an instance may be inverted or start at 0. Use [`MessageSeqRange::new`]
/// to obtain a range that is guaranteed to be 1-based and non-empty.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct MessageSeqRange {
    /// First sequence number in the range (inclusive).
    pub from_seq: u64,
    /// Last sequence number in the range (inclusive).
    pub to_seq: u64,
}
impl MessageSeqRange {
    /// Builds the inclusive range `from_seq..=to_seq`.
    ///
    /// Returns `None` unless `from_seq` is at least 1 and not greater than
    /// `to_seq`, so every range produced here holds at least one sequence
    /// number.
    #[must_use]
    pub fn new(from_seq: u64, to_seq: u64) -> Option<Self> {
        (from_seq > 0 && from_seq <= to_seq).then_some(Self { from_seq, to_seq })
    }

    /// Number of sequence numbers in `from_seq..=to_seq`.
    ///
    /// The fields are public and deserialization does not validate them, so
    /// an inverted range can exist. It reports `0`, consistent with
    /// [`Self::is_empty`], instead of underflowing.
    #[must_use]
    pub fn len(self) -> u64 {
        if self.is_empty() {
            return 0;
        }
        // `0..=u64::MAX` holds 2^64 values, which a u64 cannot represent;
        // saturate instead of overflowing on the `+ 1`.
        (self.to_seq - self.from_seq).saturating_add(1)
    }

    /// Returns `true` when `from_seq > to_seq`, i.e. the range holds no
    /// sequence numbers.
    #[must_use]
    pub fn is_empty(self) -> bool {
        self.from_seq > self.to_seq
    }
}
/// Describes which messages of a thread were used as a run's input.
///
/// `RunRecord::validate_for_persist` requires `thread_id` to equal the run's
/// thread and `range`, when present, to be non-empty and 1-based.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct RunMessageInput {
    /// Thread the input messages belong to; must match the run's thread.
    pub thread_id: String,
    /// Optional inclusive sequence range of input messages.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub range: Option<MessageSeqRange>,
    /// IDs of messages recorded as having triggered the run.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub trigger_message_ids: Vec<String>,
    /// IDs of messages recorded as selected for the run.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub selected_message_ids: Vec<String>,
    /// Free-form policy name. NOTE(review): presumably the context-selection
    /// policy used to build the input — confirm the accepted values.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_policy: Option<String>,
    /// Optional ID of a compacted snapshot used as input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub compacted_snapshot_id: Option<String>,
}
/// Describes the messages a run produced on its thread.
///
/// `RunRecord::validate_for_persist` requires `thread_id` to equal the run's
/// thread and, when `range` is present, that it is non-empty, 1-based, and
/// exactly as long as `message_ids`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct RunMessageOutput {
    /// Thread the output messages belong to; must match the run's thread.
    pub thread_id: String,
    /// Optional inclusive sequence range covered by the output messages.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub range: Option<MessageSeqRange>,
    /// IDs of the output messages (one per sequence number when `range` is set).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub message_ids: Vec<String>,
}
/// Why a run in `RunStatus::Waiting` is paused.
///
/// Serialized in snake_case (e.g. `"tool_permission"`, `"background_tasks"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WaitingReason {
    /// Waiting on a tool permission.
    ToolPermission,
    /// Waiting on user input.
    UserInput,
    /// Waiting on background tasks; see `RunRecord::is_background_task_waiting`.
    BackgroundTasks,
    /// Waiting on an external event.
    ExternalEvent,
    /// Paused by rate limiting.
    RateLimit,
    /// Paused manually.
    ManualPause,
}
/// A ticket in `RunWaitingState::tickets`, describing a tool call the waiting
/// run is associated with.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RunWaitingTicket {
    /// Ticket identifier. NOTE(review): presumably the same value listed in
    /// `RunWaitingState::ticket_ids` — confirm; nothing here checks that.
    pub ticket_id: String,
    /// ID of the tool call this ticket refers to.
    pub tool_call_id: String,
    /// Name of the tool being called.
    pub tool_name: String,
    /// Tool-call arguments as raw JSON. Defaults to `Value::Null` when absent
    /// and is omitted from serialized output while null.
    #[serde(default, skip_serializing_if = "Value::is_null")]
    pub arguments: Value,
    /// How the call is resumed; `ToolCallResumeMode::default()` when absent.
    #[serde(default)]
    pub resume_mode: ToolCallResumeMode,
    /// Optional free-form reason text.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reason: Option<String>,
    /// Defaults to 0 when absent. NOTE(review): units not established here
    /// (presumably a timestamp) — confirm.
    #[serde(default)]
    pub updated_at: u64,
}
/// Details of why and on what a run is waiting.
///
/// `RunRecord::validate_for_persist` requires this on `Waiting` runs and
/// rejects it on every other status.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RunWaitingState {
    /// Why the run is paused.
    pub reason: WaitingReason,
    /// Ticket IDs the run is waiting on.
    /// NOTE(review): consistency with `tickets` is not checked in this file.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub ticket_ids: Vec<String>,
    /// Full ticket records the run is waiting on.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tickets: Vec<RunWaitingTicket>,
    /// Optional dispatch ID associated with the start of the wait.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub since_dispatch_id: Option<String>,
    /// Optional free-form message.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
}
/// Terminal outcome of a run.
///
/// `RunRecord::validate_for_persist` rejects it on `Created`, `Running`, and
/// `Waiting` runs. On `Done` runs it is optional; when present, any flat
/// `termination_reason` / `final_output` / `error_payload` on the record must
/// equal the values here.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RunOutcome {
    /// Why the run terminated.
    pub termination_reason: TerminationReason,
    /// Optional final textual output.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub final_output: Option<String>,
    /// Optional error details as raw JSON.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error_payload: Option<Value>,
}
/// Persisted record of a single run of an agent on a thread.
///
/// Call [`RunRecord::validate_for_persist`] before writing a record. It checks
/// identifiers, thread consistency of `activation`/`input`/`output`, and the
/// invariants between `status` and `waiting`, `outcome`, `finished_at`, and
/// the flat terminal fields.
///
/// Fields without a serde default (`run_id`, `thread_id`, `agent_id`,
/// `status`, `created_at`, `updated_at`, `steps`, `input_tokens`,
/// `output_tokens`) must be present when deserializing. `parent_run_id` and
/// `state` are bare `Option`s: serde's derive treats them as `None` when
/// missing, and they are always serialized (as `null` when `None`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RunRecord {
    /// Run identifier; must not be blank.
    pub run_id: String,
    /// Thread the run belongs to; must not be blank.
    pub thread_id: String,
    /// Agent that executed the run; must not be blank.
    pub agent_id: String,
    /// Optional parent run identifier.
    pub parent_run_id: Option<String>,
    /// Optional resolution identifier.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub resolution_id: Option<String>,
    /// Activation snapshot; when present it must validate and its
    /// `intent.thread_id` must equal `thread_id`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub activation: Option<super::run::RunActivationSnapshot>,
    /// Snapshot of the request that started the run.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub request: Option<RunRequestSnapshot>,
    /// Input messages used by the run; must reference `thread_id`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input: Option<RunMessageInput>,
    /// Messages produced by the run; must reference `thread_id`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output: Option<RunMessageOutput>,
    /// Lifecycle status; determines which of `waiting`, `outcome`, and
    /// `finished_at` may or must be set.
    pub status: RunStatus,
    /// Flat terminal reason; on `Done` runs with an `outcome` it must equal
    /// `outcome.termination_reason` when set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub termination_reason: Option<TerminationReason>,
    /// Flat final output; on `Done` runs with an `outcome` it must equal
    /// `outcome.final_output` when set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub final_output: Option<String>,
    /// Flat error payload; on `Done` runs with an `outcome` it must equal
    /// `outcome.error_payload` when set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error_payload: Option<Value>,
    /// Optional dispatch identifier.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dispatch_id: Option<String>,
    /// Optional session identifier.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Optional transport-level request identifier.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub transport_request_id: Option<String>,
    /// Waiting details; required when `status` is `Waiting`, forbidden otherwise.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub waiting: Option<RunWaitingState>,
    /// Terminal outcome; only allowed when `status` is `Done`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub outcome: Option<RunOutcome>,
    /// Creation time. NOTE(review): units not established here — confirm.
    pub created_at: u64,
    /// Optional start time (same units as `created_at`, presumably).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub started_at: Option<u64>,
    /// Finish time; required when `status` is `Done`, forbidden otherwise.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub finished_at: Option<u64>,
    /// Last-update time. NOTE(review): units not established here — confirm.
    pub updated_at: u64,
    /// Step counter for the run.
    pub steps: usize,
    /// Input token count for the run.
    pub input_tokens: u64,
    /// Output token count for the run.
    pub output_tokens: u64,
    /// Optional persisted state captured with the run.
    pub state: Option<PersistedState>,
}
impl RunRecord {
    /// Checks that the record is internally consistent before it is persisted.
    ///
    /// The checks run in this order and stop at the first failure:
    ///
    /// 1. `run_id`, `thread_id` and `agent_id` must not be blank.
    /// 2. `activation`, when present, must pass its own `validate()` and name
    ///    this run's `thread_id`.
    /// 3. `input` / `output`, when present, must name this run's `thread_id`
    ///    and carry a well-formed sequence range.
    /// 4. The lifecycle fields must agree with `status`:
    ///    - `Created` / `Running`: no `waiting`, `outcome` or `finished_at`;
    ///    - `Waiting`: `waiting` required, no `outcome` or `finished_at`;
    ///    - `Done`: no `waiting`, `finished_at` required, and when `outcome`
    ///      is present any flat `termination_reason` / `final_output` /
    ///      `error_payload` must equal the outcome's value.
    ///
    /// # Errors
    ///
    /// Returns the first failing check. Failures detected in this impl are
    /// `StorageError::Validation`; errors from `activation.validate()` and the
    /// message-range helpers are propagated through `?`.
    pub fn validate_for_persist(&self) -> Result<(), StorageError> {
        for (field, value) in [
            ("run_id", &self.run_id),
            ("thread_id", &self.thread_id),
            ("agent_id", &self.agent_id),
        ] {
            require_non_empty(field, value)?;
        }

        if let Some(activation) = self.activation.as_ref() {
            activation.validate()?;
            if activation.intent.thread_id != self.thread_id {
                return Err(StorageError::Validation(format!(
                    "run activation thread_id '{}' must match run thread_id '{}'",
                    activation.intent.thread_id, self.thread_id
                )));
            }
        }

        if let Some(input) = self.input.as_ref() {
            validate_run_message_input(&self.thread_id, input)?;
        }
        if let Some(output) = self.output.as_ref() {
            validate_run_message_output(&self.thread_id, output)?;
        }

        match self.status {
            RunStatus::Created | RunStatus::Running => self.check_lifecycle_active(),
            RunStatus::Waiting => self.check_lifecycle_waiting(),
            RunStatus::Done => self.check_lifecycle_done(),
        }
    }

    /// Lifecycle rule for `Created` / `Running` runs: none of `waiting`,
    /// `outcome` or `finished_at` may be set.
    fn check_lifecycle_active(&self) -> Result<(), StorageError> {
        let problem = if self.waiting.is_some() {
            format!(
                "{:?} run '{}' must not carry waiting state",
                self.status, self.run_id
            )
        } else if self.outcome.is_some() {
            format!(
                "{:?} run '{}' must not carry terminal outcome",
                self.status, self.run_id
            )
        } else if self.finished_at.is_some() {
            format!(
                "{:?} run '{}' must not carry finished_at",
                self.status, self.run_id
            )
        } else {
            return Ok(());
        };
        Err(StorageError::Validation(problem))
    }

    /// Lifecycle rule for `Waiting` runs: `waiting` is required, while
    /// `outcome` and `finished_at` must be absent.
    fn check_lifecycle_waiting(&self) -> Result<(), StorageError> {
        let problem = if self.waiting.is_none() {
            format!("waiting run '{}' must carry waiting state", self.run_id)
        } else if self.outcome.is_some() {
            format!(
                "waiting run '{}' must not carry terminal outcome",
                self.run_id
            )
        } else if self.finished_at.is_some() {
            format!("waiting run '{}' must not carry finished_at", self.run_id)
        } else {
            return Ok(());
        };
        Err(StorageError::Validation(problem))
    }

    /// Lifecycle rule for `Done` runs: no `waiting`, `finished_at` required,
    /// and flat terminal fields that are set must agree with `outcome`.
    fn check_lifecycle_done(&self) -> Result<(), StorageError> {
        if self.waiting.is_some() {
            return Err(StorageError::Validation(format!(
                "done run '{}' must not carry waiting state",
                self.run_id
            )));
        }
        if self.finished_at.is_none() {
            return Err(StorageError::Validation(format!(
                "done run '{}' must carry finished_at",
                self.run_id
            )));
        }

        // Without an outcome there is nothing to cross-check the flat fields against.
        let Some(outcome) = &self.outcome else {
            return Ok(());
        };

        // Each flat field is only compared when it is set.
        let problem = if self.termination_reason.is_some()
            && self.termination_reason.as_ref() != Some(&outcome.termination_reason)
        {
            format!(
                "done run '{}' termination_reason must match outcome.termination_reason",
                self.run_id
            )
        } else if self.final_output.is_some()
            && self.final_output.as_ref() != outcome.final_output.as_ref()
        {
            format!(
                "done run '{}' final_output must match outcome.final_output",
                self.run_id
            )
        } else if self.error_payload.is_some()
            && self.error_payload.as_ref() != outcome.error_payload.as_ref()
        {
            format!(
                "done run '{}' error_payload must match outcome.error_payload",
                self.run_id
            )
        } else {
            return Ok(());
        };
        Err(StorageError::Validation(problem))
    }

    /// Returns the waiting reason, but only for runs whose status is
    /// `Waiting` and that carry waiting state; `None` otherwise.
    #[must_use]
    pub fn waiting_reason(&self) -> Option<WaitingReason> {
        let in_waiting_status = self.status == RunStatus::Waiting;
        self.waiting
            .as_ref()
            .filter(|_| in_waiting_status)
            .map(|waiting| waiting.reason)
    }

    /// `true` when [`Self::waiting_reason`] yields any reason.
    #[must_use]
    pub fn is_resumable_waiting(&self) -> bool {
        matches!(self.waiting_reason(), Some(_))
    }

    /// `true` when the run is waiting specifically on background tasks.
    #[must_use]
    pub fn is_background_task_waiting(&self) -> bool {
        matches!(self.waiting_reason(), Some(WaitingReason::BackgroundTasks))
    }
}
/// Rejects a required identifier that is empty or whitespace-only.
///
/// # Errors
///
/// Returns `StorageError::Validation("<field> must not be empty")` when
/// `value` contains no non-whitespace characters.
fn require_non_empty(field: &str, value: &str) -> Result<(), StorageError> {
    // `char::is_whitespace` uses the same Unicode White_Space property as `str::trim`.
    let blank = value.chars().all(char::is_whitespace);
    if !blank {
        return Ok(());
    }
    Err(StorageError::Validation(format!("{field} must not be empty")))
}
fn validate_seq_range(field: &str, range: MessageSeqRange) -> Result<(), StorageError> {
if range.from_seq == 0 || range.from_seq > range.to_seq {
return Err(StorageError::Validation(format!(
"{field} range must be non-empty and 1-based"
)));
}
Ok(())
}
/// Checks that a run's input descriptor targets the run's own thread and,
/// when it has a sequence range, that the range is 1-based and non-empty.
///
/// # Errors
///
/// Returns `StorageError::Validation` on a thread mismatch, or the error from
/// `validate_seq_range` for a malformed range. The thread check runs first.
fn validate_run_message_input(
    run_thread_id: &str,
    input: &RunMessageInput,
) -> Result<(), StorageError> {
    if input.thread_id == run_thread_id {
        input
            .range
            .map_or(Ok(()), |range| validate_seq_range("run input", range))
    } else {
        Err(StorageError::Validation(format!(
            "run input thread_id '{}' must match run thread_id '{}'",
            input.thread_id, run_thread_id
        )))
    }
}
/// Checks that a run's output descriptor targets the run's own thread and,
/// when it has a sequence range, that the range is well formed and covers
/// exactly one sequence number per entry in `message_ids`.
///
/// # Errors
///
/// Returns `StorageError::Validation` when `output.thread_id` differs from
/// `run_thread_id`, when `output.range` is not 1-based and non-empty, or when
/// the range length differs from `output.message_ids.len()`.
fn validate_run_message_output(
    run_thread_id: &str,
    output: &RunMessageOutput,
) -> Result<(), StorageError> {
    if output.thread_id != run_thread_id {
        return Err(StorageError::Validation(format!(
            "run output thread_id '{}' must match run thread_id '{}'",
            output.thread_id, run_thread_id
        )));
    }
    if let Some(range) = output.range {
        validate_seq_range("run output", range)?;
        // Compare in the u64 domain. Widening usize -> u64 is lossless on
        // every supported target, whereas narrowing the u64 range length to
        // usize truncates on 32-bit targets and could let a mismatch pass.
        let id_count = output.message_ids.len() as u64;
        if range.len() != id_count {
            return Err(StorageError::Validation(format!(
                "run output message_ids length {} must match range length {}",
                output.message_ids.len(),
                range.len()
            )));
        }
    }
    Ok(())
}
/// Everything needed to restore a thread, as assembled by
/// `RuntimeCheckpointStore::load_checkpoint`.
#[derive(Debug, Clone, Default)]
pub struct CheckpointSnapshot {
    /// Committed messages after `super::message::effective_committed_view`
    /// has been applied (in the default `load_checkpoint`).
    pub messages: Vec<Message>,
    /// In the default `load_checkpoint`, the number of raw committed messages
    /// before the effective view is applied (0 when none were stored).
    pub message_version: u64,
    /// Most recent run on the thread, if any.
    pub latest_run: Option<RunRecord>,
    /// Thread-level persisted state, if the store provides one.
    pub thread_state: Option<PersistedState>,
}
/// Read-side storage interface the runtime uses to restore threads and runs.
///
/// Implementors provide the thread, message, and run loaders.
/// `load_thread_state` and `load_checkpoint` have default implementations
/// built on top of them.
///
/// NOTE(review): the meaning of `Ok(None)` (e.g. "unknown thread/run") is left
/// to implementors and is not established in this file — confirm.
#[async_trait]
pub trait RuntimeCheckpointStore: Send + Sync {
    /// Loads the thread identified by `thread_id`.
    async fn load_thread(&self, thread_id: &str) -> Result<Option<Thread>, StorageError>;

    /// Loads the messages stored for `thread_id`.
    async fn load_messages(&self, thread_id: &str) -> Result<Option<Vec<Message>>, StorageError>;

    /// Loads the committed messages for `thread_id`; these feed
    /// `load_checkpoint`.
    async fn load_committed_messages(
        &self,
        thread_id: &str,
    ) -> Result<Option<Vec<Message>>, StorageError>;

    /// Loads the run identified by `run_id`.
    async fn load_run(&self, run_id: &str) -> Result<Option<RunRecord>, StorageError>;

    /// Loads the most recent run on `thread_id`.
    async fn latest_run(&self, thread_id: &str) -> Result<Option<RunRecord>, StorageError>;

    /// Loads thread-level persisted state.
    ///
    /// The default implementation returns `Ok(None)` without consulting any
    /// storage.
    async fn load_thread_state(
        &self,
        thread_id: &str,
    ) -> Result<Option<crate::state::PersistedState>, StorageError> {
        // The default keeps no thread-level state, so the ID is deliberately unused.
        let _ = thread_id;
        Ok(None)
    }

    /// Assembles a [`CheckpointSnapshot`] for `thread_id` from the other loaders.
    ///
    /// Returns `Ok(None)` only when there are neither committed messages nor
    /// a latest run. Otherwise `message_version` is the raw committed-message
    /// count (0 when none were found), `messages` is that list passed through
    /// `super::message::effective_committed_view`, and `thread_state` comes
    /// from [`Self::load_thread_state`].
    ///
    /// # Errors
    ///
    /// Propagates the first error from `load_committed_messages`,
    /// `latest_run`, or `load_thread_state`, which are awaited in that order.
    async fn load_checkpoint(
        &self,
        thread_id: &str,
    ) -> Result<Option<CheckpointSnapshot>, StorageError> {
        let committed = self.load_committed_messages(thread_id).await?;
        let latest_run = self.latest_run(thread_id).await?;

        let raw = match committed {
            Some(stored) => stored,
            // A run with no committed messages still yields a message-less checkpoint.
            None if latest_run.is_some() => Vec::new(),
            None => return Ok(None),
        };

        // Version by the raw count, taken before the list is handed to the view helper.
        let message_version = raw.len() as u64;
        let messages = super::message::effective_committed_view(raw, thread_id);
        let thread_state = self.load_thread_state(thread_id).await?;

        Ok(Some(CheckpointSnapshot {
            messages,
            message_version,
            latest_run,
            thread_state,
        }))
    }
}
#[cfg(test)]
mod tests;