use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashSet;
use std::time::{SystemTime, UNIX_EPOCH};
/// A single event in an agent run's event stream.
///
/// Serde treats this as an internally tagged enum: the variant is identified
/// by a `"type"` field whose value is the variant's `rename` string (for
/// example `"RUN_STARTED"`), and the variant's own fields sit next to it.
/// Field names carry no rename attribute, so they are (de)serialized under
/// their Rust identifiers (`run_id`, `message_id`, ...).
///
/// Every variant except [`AgUiEvent::Unknown`] carries the `run_id` it belongs
/// to and a `timestamp`. The constructor helpers on this type fill `timestamp`
/// with the current wall-clock time in milliseconds since the Unix epoch;
/// values arriving through deserialization are not validated here.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum AgUiEvent {
/// A run has started.
#[serde(rename = "RUN_STARTED")]
RunStarted {
run_id: String,
// Left out of the serialized output entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
thread_id: Option<String>,
timestamp: u64,
},
/// The step named `step_name` has started within the run.
#[serde(rename = "STEP_STARTED")]
StepStarted {
run_id: String,
step_name: String,
timestamp: u64,
},
/// The step named `step_name` has finished within the run.
#[serde(rename = "STEP_FINISHED")]
StepFinished {
run_id: String,
step_name: String,
timestamp: u64,
},
/// The run has finished.
#[serde(rename = "RUN_FINISHED")]
RunFinished {
run_id: String,
timestamp: u64,
},
/// The run reported an error, described by a `code` and a `message`.
#[serde(rename = "RUN_ERROR")]
RunError {
run_id: String,
code: String,
message: String,
timestamp: u64,
},
/// Start of the text message identified by `message_id`.
#[serde(rename = "TEXT_MESSAGE_START")]
TextMessageStart {
run_id: String,
message_id: String,
// Free-form string; the set of valid roles is not constrained in this file.
role: String,
timestamp: u64,
},
/// Content for the text message identified by `message_id`.
#[serde(rename = "TEXT_MESSAGE_CONTENT")]
TextMessageContent {
run_id: String,
message_id: String,
// Presumably an incremental chunk of the message text rather than the
// whole message so far — TODO confirm against the producer.
delta: String,
timestamp: u64,
},
/// End of the text message identified by `message_id`.
#[serde(rename = "TEXT_MESSAGE_END")]
TextMessageEnd {
run_id: String,
message_id: String,
timestamp: u64,
},
/// Start of the tool call identified by `tool_call_id`, naming the tool.
#[serde(rename = "TOOL_CALL_START")]
ToolCallStart {
run_id: String,
tool_call_id: String,
tool_name: String,
timestamp: u64,
},
/// Argument data for the tool call identified by `tool_call_id`.
#[serde(rename = "TOOL_CALL_ARGS")]
ToolCallArgs {
run_id: String,
tool_call_id: String,
// Presumably a fragment of the serialized arguments — TODO confirm.
delta: String,
timestamp: u64,
},
/// End of the tool call identified by `tool_call_id`.
#[serde(rename = "TOOL_CALL_END")]
ToolCallEnd {
run_id: String,
tool_call_id: String,
timestamp: u64,
},
/// Result of the tool call identified by `tool_call_id`, as arbitrary JSON.
#[serde(rename = "TOOL_CALL_RESULT")]
ToolCallResult {
run_id: String,
tool_call_id: String,
result: Value,
timestamp: u64,
},
/// A state snapshot, carried as arbitrary JSON.
#[serde(rename = "STATE_SNAPSHOT")]
StateSnapshot {
run_id: String,
snapshot: Value,
timestamp: u64,
},
/// A state delta, carried as arbitrary JSON.
#[serde(rename = "STATE_DELTA")]
StateDelta {
run_id: String,
// Any JSON value is accepted; the patch format is not fixed by this type.
delta: Value,
timestamp: u64,
},
/// A messages snapshot, carried as arbitrary JSON.
#[serde(rename = "MESSAGES_SNAPSHOT")]
MessagesSnapshot {
run_id: String,
messages: Value,
timestamp: u64,
},
/// A raw payload labelled with the `source` it came from.
#[serde(rename = "RAW")]
Raw {
run_id: String,
source: String,
payload: Value,
timestamp: u64,
},
/// A custom, named event with an arbitrary JSON payload.
#[serde(rename = "CUSTOM")]
Custom {
run_id: String,
name: String,
payload: Value,
timestamp: u64,
},
/// Fallback chosen by serde when the `"type"` tag matches no other variant.
///
/// It carries no fields, so any data that accompanied the unrecognized tag
/// is not retained, and `run_id()` returns an empty string for it. The
/// variant has no `rename`, so if it is serialized its tag is the plain
/// identifier `"Unknown"`.
#[serde(other)]
Unknown,
}
impl AgUiEvent {
    /// Maps this event to its payload-free [`AgUiEventKind`].
    ///
    /// The mapping is one-to-one: every variant, including
    /// [`AgUiEvent::Unknown`], has a kind of the same name.
    #[must_use]
    pub const fn kind(&self) -> AgUiEventKind {
        match self {
            // Fallback for unrecognized events.
            Self::Unknown => AgUiEventKind::Unknown,

            // Run and step lifecycle.
            Self::RunStarted { .. } => AgUiEventKind::RunStarted,
            Self::RunFinished { .. } => AgUiEventKind::RunFinished,
            Self::RunError { .. } => AgUiEventKind::RunError,
            Self::StepStarted { .. } => AgUiEventKind::StepStarted,
            Self::StepFinished { .. } => AgUiEventKind::StepFinished,

            // Text messages.
            Self::TextMessageStart { .. } => AgUiEventKind::TextMessageStart,
            Self::TextMessageContent { .. } => AgUiEventKind::TextMessageContent,
            Self::TextMessageEnd { .. } => AgUiEventKind::TextMessageEnd,

            // Tool calls.
            Self::ToolCallStart { .. } => AgUiEventKind::ToolCallStart,
            Self::ToolCallArgs { .. } => AgUiEventKind::ToolCallArgs,
            Self::ToolCallEnd { .. } => AgUiEventKind::ToolCallEnd,
            Self::ToolCallResult { .. } => AgUiEventKind::ToolCallResult,

            // State and message snapshots.
            Self::StateSnapshot { .. } => AgUiEventKind::StateSnapshot,
            Self::StateDelta { .. } => AgUiEventKind::StateDelta,
            Self::MessagesSnapshot { .. } => AgUiEventKind::MessagesSnapshot,

            // Pass-through payloads.
            Self::Raw { .. } => AgUiEventKind::Raw,
            Self::Custom { .. } => AgUiEventKind::Custom,
        }
    }

    /// Returns the identifier of the run this event belongs to.
    ///
    /// [`AgUiEvent::Unknown`] has no fields, so the empty string is returned
    /// for it.
    #[must_use]
    pub fn run_id(&self) -> &str {
        match self {
            Self::Unknown => "",
            Self::RunStarted { run_id: id, .. }
            | Self::StepStarted { run_id: id, .. }
            | Self::StepFinished { run_id: id, .. }
            | Self::RunFinished { run_id: id, .. }
            | Self::RunError { run_id: id, .. }
            | Self::TextMessageStart { run_id: id, .. }
            | Self::TextMessageContent { run_id: id, .. }
            | Self::TextMessageEnd { run_id: id, .. }
            | Self::ToolCallStart { run_id: id, .. }
            | Self::ToolCallArgs { run_id: id, .. }
            | Self::ToolCallEnd { run_id: id, .. }
            | Self::ToolCallResult { run_id: id, .. }
            | Self::StateSnapshot { run_id: id, .. }
            | Self::StateDelta { run_id: id, .. }
            | Self::MessagesSnapshot { run_id: id, .. }
            | Self::Raw { run_id: id, .. }
            | Self::Custom { run_id: id, .. } => id.as_str(),
        }
    }

    /// Builds a [`AgUiEvent::RunStarted`] event stamped with the current time.
    ///
    /// `thread_id` is copied into an owned `String` when present.
    #[must_use]
    pub fn run_started(run_id: impl Into<String>, thread_id: Option<&str>) -> Self {
        let run_id = run_id.into();
        let thread_id = thread_id.map(String::from);
        let timestamp = now_ms();
        Self::RunStarted {
            run_id,
            thread_id,
            timestamp,
        }
    }

    /// Builds a [`AgUiEvent::StepStarted`] event stamped with the current time.
    #[must_use]
    pub fn step_started(run_id: impl Into<String>, step_name: impl Into<String>) -> Self {
        let run_id = run_id.into();
        let step_name = step_name.into();
        let timestamp = now_ms();
        Self::StepStarted {
            run_id,
            step_name,
            timestamp,
        }
    }

    /// Builds a [`AgUiEvent::StepFinished`] event stamped with the current time.
    #[must_use]
    pub fn step_finished(run_id: impl Into<String>, step_name: impl Into<String>) -> Self {
        let run_id = run_id.into();
        let step_name = step_name.into();
        let timestamp = now_ms();
        Self::StepFinished {
            run_id,
            step_name,
            timestamp,
        }
    }

    /// Builds a [`AgUiEvent::RunFinished`] event stamped with the current time.
    #[must_use]
    pub fn run_finished(run_id: impl Into<String>) -> Self {
        let run_id = run_id.into();
        let timestamp = now_ms();
        Self::RunFinished { run_id, timestamp }
    }

    /// Builds a [`AgUiEvent::RunError`] event stamped with the current time.
    #[must_use]
    pub fn run_error(
        run_id: impl Into<String>,
        code: impl Into<String>,
        message: impl Into<String>,
    ) -> Self {
        let run_id = run_id.into();
        let code = code.into();
        let message = message.into();
        let timestamp = now_ms();
        Self::RunError {
            run_id,
            code,
            message,
            timestamp,
        }
    }

    /// Builds a [`AgUiEvent::ToolCallStart`] event stamped with the current time.
    #[must_use]
    pub fn tool_call_start(
        run_id: impl Into<String>,
        tool_call_id: impl Into<String>,
        tool_name: impl Into<String>,
    ) -> Self {
        let run_id = run_id.into();
        let tool_call_id = tool_call_id.into();
        let tool_name = tool_name.into();
        let timestamp = now_ms();
        Self::ToolCallStart {
            run_id,
            tool_call_id,
            tool_name,
            timestamp,
        }
    }

    /// Builds a [`AgUiEvent::ToolCallResult`] event stamped with the current time.
    ///
    /// `result` is stored as given; no validation is applied to the JSON value.
    #[must_use]
    pub fn tool_call_result(
        run_id: impl Into<String>,
        tool_call_id: impl Into<String>,
        result: Value,
    ) -> Self {
        let run_id = run_id.into();
        let tool_call_id = tool_call_id.into();
        let timestamp = now_ms();
        Self::ToolCallResult {
            run_id,
            tool_call_id,
            result,
            timestamp,
        }
    }
}
/// Payload-free discriminant of [`AgUiEvent`], with one variant per event
/// variant of the same name.
///
/// The type is `Copy`, `Eq` and `Hash`, so values can be stored in a
/// `HashSet` (which is how [`AgUiEventFilter`] uses them).
///
/// It derives `Serialize`/`Deserialize` with no rename attributes, so serde
/// uses the Rust variant identifiers (for example `"RunStarted"`) and not the
/// upper-case `"type"` tags that [`AgUiEvent`] uses (`"RUN_STARTED"`).
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum AgUiEventKind {
RunStarted,
StepStarted,
StepFinished,
RunFinished,
RunError,
TextMessageStart,
TextMessageContent,
TextMessageEnd,
ToolCallStart,
ToolCallArgs,
ToolCallEnd,
ToolCallResult,
StateSnapshot,
StateDelta,
MessagesSnapshot,
Raw,
Custom,
// Kind reported for `AgUiEvent::Unknown`.
Unknown,
}
impl AgUiEventKind {
/// Every event kind, listed in the same order as the enum declaration.
///
/// The list is maintained by hand: nothing forces a newly added variant to
/// appear here, so a new variant needs both a new entry and a bumped array
/// length. `AgUiEventFilter::allow_all` is built from this list.
pub const ALL: [Self; 18] = [
Self::RunStarted,
Self::StepStarted,
Self::StepFinished,
Self::RunFinished,
Self::RunError,
Self::TextMessageStart,
Self::TextMessageContent,
Self::TextMessageEnd,
Self::ToolCallStart,
Self::ToolCallArgs,
Self::ToolCallEnd,
Self::ToolCallResult,
Self::StateSnapshot,
Self::StateDelta,
Self::MessagesSnapshot,
Self::Raw,
Self::Custom,
Self::Unknown,
];
}
/// Allow-list of [`AgUiEventKind`]s.
///
/// A kind is allowed exactly when it is present in the internal set.
/// The `Default` implementation allows every kind.
#[derive(Debug, Clone)]
pub struct AgUiEventFilter {
// Kinds that `allows` answers `true` for; anything absent is denied.
allowed: HashSet<AgUiEventKind>,
}
impl AgUiEventFilter {
    /// Creates a filter that allows every kind listed in [`AgUiEventKind::ALL`].
    #[must_use]
    pub fn allow_all() -> Self {
        // The const array is iterated by value; the kinds are `Copy`.
        Self::only(AgUiEventKind::ALL)
    }

    /// Creates a filter that allows no kinds at all.
    #[must_use]
    pub fn deny_all() -> Self {
        Self {
            allowed: HashSet::default(),
        }
    }

    /// Creates a filter that allows exactly the given kinds.
    ///
    /// Duplicates in `kinds` are harmless because the kinds are stored in a set.
    #[must_use]
    pub fn only<I: IntoIterator<Item = AgUiEventKind>>(kinds: I) -> Self {
        let allowed: HashSet<AgUiEventKind> = kinds.into_iter().collect();
        Self { allowed }
    }

    /// Returns the filter with `kind` additionally allowed.
    ///
    /// Adding a kind that is already allowed leaves the filter unchanged.
    #[must_use]
    pub fn with(mut self, kind: AgUiEventKind) -> Self {
        // `insert` reports whether the kind was new; either way it is now allowed.
        self.allowed.insert(kind);
        self
    }

    /// Returns the filter with `kind` no longer allowed.
    ///
    /// Removing a kind that was not allowed leaves the filter unchanged.
    #[must_use]
    pub fn without(mut self, kind: AgUiEventKind) -> Self {
        // `remove` reports whether the kind was present; either way it is now denied.
        self.allowed.remove(&kind);
        self
    }

    /// Reports whether `kind` is currently allowed by this filter.
    #[must_use]
    pub fn allows(&self, kind: AgUiEventKind) -> bool {
        let permitted = &self.allowed;
        permitted.contains(&kind)
    }
}
impl Default for AgUiEventFilter {
/// Returns the permissive filter, i.e. the same value as `allow_all()`.
fn default() -> Self {
Self::allow_all()
}
}
/// Destination for [`AgUiEvent`]s.
///
/// Implementations must be `Send + Sync`. The trait is declared through
/// `#[async_trait]` so that `emit` can be written as an `async fn`.
#[async_trait]
pub trait AgUiEmitter: Send + Sync {
/// Returns the event-kind filter associated with this emitter.
///
/// NOTE(review): nothing in this trait applies the filter inside `emit`;
/// presumably callers check `filter().allows(event.kind())` before
/// emitting — confirm against the call sites.
fn filter(&self) -> &AgUiEventFilter;
/// Delivers one event to the emitter.
///
/// The method has no return value, so the signature gives implementations
/// no way to report a delivery failure to the caller.
async fn emit(&self, event: &AgUiEvent);
}
/// An [`AgUiEmitter`] that discards every event it is given.
///
/// It still carries a filter so that `AgUiEmitter::filter` has something to
/// return. The derived `Default` uses `AgUiEventFilter::default()`, which
/// allows every kind.
#[derive(Debug, Clone, Default)]
pub struct NoopEmitter {
// Returned by reference from `AgUiEmitter::filter`.
filter: AgUiEventFilter,
}
impl NoopEmitter {
/// Creates a no-op emitter that reports the given `filter`.
#[must_use]
pub fn new(filter: AgUiEventFilter) -> Self {
Self { filter }
}
}
#[async_trait]
impl AgUiEmitter for NoopEmitter {
/// Returns the filter this emitter was constructed with.
fn filter(&self) -> &AgUiEventFilter {
&self.filter
}
/// Ignores the event: the body is intentionally empty.
async fn emit(&self, _event: &AgUiEvent) {}
}
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Two edge cases are absorbed instead of being reported:
/// - if the system clock reads earlier than the epoch, `0` is returned;
/// - if the millisecond count does not fit in a `u64`, `u64::MAX` is returned.
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        // `as_millis` yields a `u128`; saturate rather than truncate.
        Ok(since_epoch) => u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX),
        // The clock is set before 1970-01-01T00:00:00Z.
        Err(_) => 0,
    }
}