use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::error::Result;
use crate::request::RequestId;
/// Identifier of a response step: a newtype around a [`Uuid`].
///
/// `#[serde(transparent)]` makes the id serialize and deserialize exactly
/// like the inner `Uuid`, with no wrapper layer around it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct StepId(pub Uuid);
impl std::fmt::Display for StepId {
    /// Writes the abbreviated form of the id: the first eight characters of
    /// the UUID's hyphenated lowercase text, i.e. the hex digits of its
    /// first four bytes.
    ///
    /// The digits are formatted straight from the raw bytes so that no
    /// temporary `String` holding the full 36-character UUID is allocated
    /// on every call. Width/fill flags on `f` are not applied, matching the
    /// previous `write!(f, "{}", ..)` behaviour.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The first hyphen-delimited group of a UUID is its first four
        // bytes rendered big-endian as eight hex digits.
        let [b0, b1, b2, b3, ..] = *self.0.as_bytes();
        write!(f, "{:08x}", u32::from_be_bytes([b0, b1, b2, b3]))
    }
}
impl From<Uuid> for StepId {
fn from(uuid: Uuid) -> Self {
StepId(uuid)
}
}
impl std::ops::Deref for StepId {
type Target = Uuid;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// Kind of a response step.
///
/// Serialized by serde in snake_case (`model_call`, `tool_call`), the same
/// spellings that `StepKind::as_str` returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StepKind {
/// NOTE(review): presumably a step that calls a model — confirm with the code that creates steps.
ModelCall,
/// NOTE(review): presumably a step that invokes a tool — confirm with the code that creates steps.
ToolCall,
}
impl StepKind {
pub fn as_str(&self) -> &'static str {
match self {
StepKind::ModelCall => "model_call",
StepKind::ToolCall => "tool_call",
}
}
pub fn parse(s: &str) -> Option<Self> {
match s {
"model_call" => Some(StepKind::ModelCall),
"tool_call" => Some(StepKind::ToolCall),
_ => None,
}
}
}
/// Lifecycle state of a response step.
///
/// Serialized by serde in snake_case (`pending`, `processing`, `completed`,
/// `failed`, `canceled`), the same spellings that `StepState::as_str`
/// returns. `Completed`, `Failed` and `Canceled` are the terminal states
/// according to `StepState::is_terminal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StepState {
/// Non-terminal. NOTE(review): presumably the initial state of a newly created step — confirm.
Pending,
/// Non-terminal. NOTE(review): presumably set while the step is being executed — confirm.
Processing,
/// Terminal state.
Completed,
/// Terminal state.
Failed,
/// Terminal state.
Canceled,
}
impl StepState {
pub fn as_str(&self) -> &'static str {
match self {
StepState::Pending => "pending",
StepState::Processing => "processing",
StepState::Completed => "completed",
StepState::Failed => "failed",
StepState::Canceled => "canceled",
}
}
pub fn parse(s: &str) -> Option<Self> {
match s {
"pending" => Some(StepState::Pending),
"processing" => Some(StepState::Processing),
"completed" => Some(StepState::Completed),
"failed" => Some(StepState::Failed),
"canceled" => Some(StepState::Canceled),
_ => None,
}
}
pub fn is_terminal(&self) -> bool {
matches!(
self,
StepState::Completed | StepState::Failed | StepState::Canceled
)
}
}
/// A response step record, as returned by the `ResponseStepStore` read
/// methods.
///
/// Only `Serialize` is derived (not `Deserialize`), so this type can be
/// written out through serde but not read back in.
#[derive(Debug, Clone, Serialize)]
pub struct ResponseStep {
/// Identifier of this step.
pub id: StepId,
/// Request associated with this step, if any.
pub request_id: Option<RequestId>,
/// NOTE(review): presumably the step that precedes this one in a chain — confirm against the store implementation.
pub prev_step_id: Option<StepId>,
/// NOTE(review): presumably the step this one is nested under — confirm against the store implementation.
pub parent_step_id: Option<StepId>,
/// Kind of this step.
pub step_kind: StepKind,
/// NOTE(review): presumably the ordering position of the step; the scope of the numbering is not visible here — confirm.
pub step_sequence: i64,
/// Request payload as arbitrary JSON.
pub request_payload: serde_json::Value,
/// Response payload as arbitrary JSON, if present.
/// NOTE(review): presumably populated by `complete_step` — confirm.
pub response_payload: Option<serde_json::Value>,
/// Current lifecycle state of the step.
pub state: StepState,
/// NOTE(review): presumably when processing started — confirm.
pub started_at: Option<DateTime<Utc>>,
/// NOTE(review): presumably when the step reached `Completed` — confirm.
pub completed_at: Option<DateTime<Utc>>,
/// NOTE(review): presumably when the step reached `Failed` — confirm.
pub failed_at: Option<DateTime<Utc>>,
/// NOTE(review): presumably when the step reached `Canceled` — confirm.
pub canceled_at: Option<DateTime<Utc>>,
/// NOTE(review): presumably the number of retries so far; whether it starts at 0 is not visible here — confirm.
pub retry_attempt: i32,
/// Error details as arbitrary JSON, if present.
/// NOTE(review): presumably populated by `fail_step` — confirm.
pub error: Option<serde_json::Value>,
/// Creation timestamp (UTC).
pub created_at: DateTime<Utc>,
/// Last-update timestamp (UTC).
pub updated_at: DateTime<Utc>,
}
/// Arguments for `ResponseStepStore::create_step`.
#[derive(Debug, Clone)]
pub struct CreateStepInput {
/// Optional caller-supplied id for the new step.
/// NOTE(review): presumably the store generates an id when this is `None` — confirm against the implementation.
pub id: Option<Uuid>,
/// Request to associate with the new step, if any.
pub request_id: Option<RequestId>,
/// NOTE(review): presumably the step that precedes the new one in a chain — confirm.
pub prev_step_id: Option<StepId>,
/// NOTE(review): presumably the step the new one is nested under — confirm.
pub parent_step_id: Option<StepId>,
/// Kind of the step to create.
pub step_kind: StepKind,
/// NOTE(review): presumably the ordering position of the new step — confirm.
pub step_sequence: i64,
/// Request payload as arbitrary JSON.
pub request_payload: serde_json::Value,
}
/// Asynchronous storage interface for response steps.
///
/// Implementors must be `Send + Sync`. `#[async_trait]` is what allows the
/// `async fn` methods in this trait. Every method returns the crate's
/// `Result`, so failures are reported as the crate error type.
///
/// NOTE(review): no implementation is visible in this file; the per-method
/// notes below are inferred from names and signatures and should be
/// confirmed against the concrete store.
#[async_trait]
pub trait ResponseStepStore: Send + Sync {
/// Creates a step from `input` and returns the id of the step.
async fn create_step(&self, input: CreateStepInput) -> Result<StepId>;
/// Looks up a step by id; `Ok(None)` presumably means no such step exists — confirm.
async fn get_step(&self, id: StepId) -> Result<Option<ResponseStep>>;
/// Looks up a step by its associated request id; `Ok(None)` presumably means no match — confirm.
async fn get_step_by_request(&self, request_id: RequestId) -> Result<Option<ResponseStep>>;
/// Returns the steps of the chain identified by `head_step_id`.
/// NOTE(review): traversal direction, ordering and whether the head itself is included are not visible here — confirm.
async fn list_chain(&self, head_step_id: StepId) -> Result<Vec<ResponseStep>>;
/// NOTE(review): presumably moves the step to `StepState::Processing` — confirm.
async fn mark_step_processing(&self, id: StepId) -> Result<()>;
/// NOTE(review): presumably records `response` and moves the step to `StepState::Completed` — confirm.
async fn complete_step(&self, id: StepId, response: serde_json::Value) -> Result<()>;
/// NOTE(review): presumably records `error` and moves the step to `StepState::Failed` — confirm.
async fn fail_step(&self, id: StepId, error: serde_json::Value) -> Result<()>;
/// NOTE(review): presumably moves the step to `StepState::Canceled` — confirm.
async fn cancel_step(&self, id: StepId) -> Result<()>;
/// NOTE(review): presumably returns the step to a runnable state for another attempt; the effect on `retry_attempt` is not visible here — confirm.
async fn requeue_step_for_retry(&self, id: StepId) -> Result<()>;
}