use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use serde_json::Value;
use thiserror::Error;
use tokio::sync::broadcast;
use tokio::task::JoinHandle;
use uuid::Uuid;
use crate::provider::{ChatStreamEvent, ModelName, ProviderId, ToolChoice};
use crate::runtime::agent::{AgentRuntime, RunOutput};
use crate::runtime::error::RuntimeError;
use crate::runtime::event::AgentEvent;
use crate::runtime::run::{RunId, RunRequest};
/// Parameters for one run started through [`RuntimeInvocation::emit`].
///
/// Create it with [`EmitRequest::new`], refine it with the `with_*` builder
/// methods, and turn it into a [`RunRequest`] with [`EmitRequest::into_run_request`].
#[derive(Debug, Clone)]
pub struct EmitRequest {
/// Provider identifier; copied into `RunRequest::provider`.
pub provider: ProviderId,
/// Model name; copied into `RunRequest::model`.
pub model: ModelName,
/// Input text; copied into `RunRequest::input`.
pub input: String,
/// Optional session id; `None` unless set via `with_session_id`.
pub session_id: Option<Uuid>,
/// Optional caller-chosen request id; `None` unless set via `with_client_request_id`.
pub client_request_id: Option<String>,
/// Arbitrary JSON metadata; `Value::Null` unless set via `with_metadata`.
pub metadata: Value,
}
impl EmitRequest {
    /// Creates a request for `input` against `provider`/`model`.
    ///
    /// The session id and client request id start as `None`, and the metadata
    /// starts as `Value::Null`.
    #[must_use]
    pub fn new(provider: ProviderId, model: ModelName, input: impl Into<String>) -> Self {
        let input = input.into();
        Self {
            provider,
            model,
            input,
            session_id: None,
            client_request_id: None,
            metadata: Value::Null,
        }
    }

    /// Returns the request with `session_id` attached.
    #[must_use]
    pub fn with_session_id(self, session_id: Uuid) -> Self {
        Self {
            session_id: Some(session_id),
            ..self
        }
    }

    /// Returns the request with a caller-chosen request id attached.
    #[must_use]
    pub fn with_client_request_id(self, id: impl Into<String>) -> Self {
        Self {
            client_request_id: Some(id.into()),
            ..self
        }
    }

    /// Returns the request with its metadata replaced by `metadata`.
    #[must_use]
    pub fn with_metadata(self, metadata: Value) -> Self {
        Self { metadata, ..self }
    }

    /// Converts into a [`RunRequest`].
    ///
    /// Every field is moved across as-is. The run id is left unset (`None`), and
    /// tool choice is fixed to `ToolChoice::Auto`.
    #[must_use]
    pub fn into_run_request(self) -> RunRequest {
        let Self {
            provider,
            model,
            input,
            session_id,
            client_request_id,
            metadata,
        } = self;
        RunRequest {
            session_id,
            run_id: None,
            provider,
            model,
            input,
            metadata,
            tool_choice: ToolChoice::Auto,
            client_request_id,
        }
    }
}
/// Errors produced by [`RuntimeInvocation`] operations.
#[derive(Debug, Error)]
pub enum InvocationError {
/// An error from the agent runtime, forwarded unchanged (including its `Display` text).
#[error(transparent)]
Runtime(#[from] RuntimeError),
/// The invocation did not go ahead; `emit` uses this with message `"cancelled"`
/// when the request builder cancelled its `Control`.
#[error("invocation task failed: {message}")]
TaskFailed {
/// Human-readable reason for the failure.
message: String,
},
/// The invocation request was rejected as invalid.
/// NOTE(review): not constructed anywhere in the code shown here.
#[error("invalid invocation request: {message}")]
InvalidRequest {
/// Human-readable reason the request was rejected.
message: String,
},
}
/// An event delivered to listener handlers: either an [`AgentEvent`] or a
/// provider [`ChatStreamEvent`].
#[derive(Debug, Clone)]
pub enum InvocationEvent {
/// An event received from the agent runtime's broadcast subscription.
Agent(AgentEvent),
/// A provider chat-stream event.
/// NOTE(review): `spawn_listener` only produces `Agent`; confirm where `Chat` is emitted.
Chat(ChatStreamEvent),
}
impl InvocationEvent {
#[must_use]
pub fn run_id(&self) -> Option<RunId> {
match self {
InvocationEvent::Agent(e) => Some(e.run_id()),
InvocationEvent::Chat(_) => None,
}
}
#[must_use]
pub fn as_agent(&self) -> Option<&AgentEvent> {
match self {
InvocationEvent::Agent(e) => Some(e),
InvocationEvent::Chat(_) => None,
}
}
}
/// Selects which [`InvocationEvent`]s a listener handles.
///
/// `Any` selects every event. Variants prefixed `Chat` select the
/// `ChatStreamEvent` variant named by the rest of the name. Every other variant
/// selects the `AgentEvent` variant of the same name (see `EventKind::matches`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventKind {
/// Matches every event, agent or chat.
Any,
// Agent runtime events (`InvocationEvent::Agent`).
RunStarted,
ContextBuilt,
ModelStarted,
TextDelta,
ToolCallStarted,
ToolCallDelta,
ToolCallCompleted,
ToolExecutionStarted,
ToolExecutionFinished,
AssistantMessageCommitted,
ToolMessageCommitted,
UsageRecorded,
RunCompleted,
RunFailed,
RunCancelled,
DoomLoopDetected,
CompactionCircuitOpened,
// Provider chat-stream events (`InvocationEvent::Chat`).
ChatStarted,
ChatTextDelta,
ChatToolCallStarted,
ChatToolCallArgumentsDelta,
ChatToolCallCompleted,
ChatFinished,
}
impl EventKind {
    /// Reports whether `event` is selected by this kind.
    ///
    /// - `EventKind::Any` selects every event.
    /// - Each `Chat*` kind selects the [`ChatStreamEvent`] variant named by the rest
    ///   of its name (for example, `ChatFinished` selects `ChatStreamEvent::Finished`),
    ///   wrapped in [`InvocationEvent::Chat`].
    /// - Every other kind selects the [`AgentEvent`] variant of the same name,
    ///   wrapped in [`InvocationEvent::Agent`].
    ///
    /// All other combinations return `false`.
    #[must_use]
    // One flat row per kind; splitting this lookup table would only scatter it.
    #[allow(clippy::too_many_lines)]
    pub fn matches(self, event: &InvocationEvent) -> bool {
        matches!(
            (self, event),
            (Self::Any, _)
                | (Self::RunStarted, InvocationEvent::Agent(AgentEvent::RunStarted(_)))
                | (Self::ContextBuilt, InvocationEvent::Agent(AgentEvent::ContextBuilt(_)))
                | (Self::ModelStarted, InvocationEvent::Agent(AgentEvent::ModelStarted(_)))
                | (Self::TextDelta, InvocationEvent::Agent(AgentEvent::TextDelta(_)))
                | (Self::ToolCallStarted, InvocationEvent::Agent(AgentEvent::ToolCallStarted(_)))
                | (Self::ToolCallDelta, InvocationEvent::Agent(AgentEvent::ToolCallDelta(_)))
                | (
                    Self::ToolCallCompleted,
                    InvocationEvent::Agent(AgentEvent::ToolCallCompleted(_)),
                )
                | (
                    Self::ToolExecutionStarted,
                    InvocationEvent::Agent(AgentEvent::ToolExecutionStarted(_)),
                )
                | (
                    Self::ToolExecutionFinished,
                    InvocationEvent::Agent(AgentEvent::ToolExecutionFinished(_)),
                )
                | (
                    Self::AssistantMessageCommitted,
                    InvocationEvent::Agent(AgentEvent::AssistantMessageCommitted(_)),
                )
                | (
                    Self::ToolMessageCommitted,
                    InvocationEvent::Agent(AgentEvent::ToolMessageCommitted(_)),
                )
                | (Self::UsageRecorded, InvocationEvent::Agent(AgentEvent::UsageRecorded(_)))
                | (Self::RunCompleted, InvocationEvent::Agent(AgentEvent::RunCompleted(_)))
                | (Self::RunFailed, InvocationEvent::Agent(AgentEvent::RunFailed(_)))
                | (Self::RunCancelled, InvocationEvent::Agent(AgentEvent::RunCancelled(_)))
                | (Self::DoomLoopDetected, InvocationEvent::Agent(AgentEvent::DoomLoopDetected(_)))
                | (
                    Self::CompactionCircuitOpened,
                    InvocationEvent::Agent(AgentEvent::CompactionCircuitOpened(_)),
                )
                | (Self::ChatStarted, InvocationEvent::Chat(ChatStreamEvent::Started { .. }))
                | (Self::ChatTextDelta, InvocationEvent::Chat(ChatStreamEvent::TextDelta { .. }))
                | (
                    Self::ChatToolCallStarted,
                    InvocationEvent::Chat(ChatStreamEvent::ToolCallStarted { .. }),
                )
                | (
                    Self::ChatToolCallArgumentsDelta,
                    InvocationEvent::Chat(ChatStreamEvent::ToolCallArgumentsDelta { .. }),
                )
                | (
                    Self::ChatToolCallCompleted,
                    InvocationEvent::Chat(ChatStreamEvent::ToolCallCompleted { .. }),
                )
                | (Self::ChatFinished, InvocationEvent::Chat(ChatStreamEvent::Finished { .. }))
        )
    }
}
/// Context handed to `emit` request builders and `on` handlers.
///
/// `emit` passes `SessionContext::default()` (both fields `None`). Listener
/// handlers receive one built by `SessionContext::from_agent_event`.
#[derive(Debug, Clone, Default)]
pub struct SessionContext {
/// Session id; `from_agent_event` only fills this in for `RunStarted` events.
pub session_id: Option<Uuid>,
/// Run id of the originating agent event, when there is one.
pub run_id: Option<RunId>,
}
impl SessionContext {
    /// Builds the handler context for `event`.
    ///
    /// The run id is always taken from the event. The session id is only known
    /// for [`AgentEvent::RunStarted`]; for every other event it is `None`.
    #[must_use]
    pub fn from_agent_event(event: &AgentEvent) -> Self {
        let session_id = if let AgentEvent::RunStarted(started) = event {
            Some(started.session_id)
        } else {
            None
        };
        let run_id = Some(event.run_id());
        Self { session_id, run_id }
    }
}
/// Shared state behind every clone of a [`Control`].
#[derive(Debug)]
struct ControlInner {
/// Set by `Control::cancel`; nothing visible here ever resets it to `false`.
cancelled: AtomicBool,
/// Last value passed to `Control::set_timeout`.
/// NOTE(review): no code visible here enforces this timeout — confirm whether anything reads it.
timeout: Mutex<Option<Duration>>,
/// Last value passed to `Control::set_concurrency_limit`.
/// NOTE(review): `spawn_listener` spawns handlers without consulting this limit — confirm intended use.
concurrency_limit: Mutex<Option<usize>>,
}
/// Cloneable handle for cooperative cancellation and per-invocation settings.
///
/// All clones share one `ControlInner` through an `Arc`, so a cancel or setting
/// applied through any clone is seen by all of them.
#[derive(Debug, Clone)]
pub struct Control {
inner: Arc<ControlInner>,
}
impl Default for Control {
fn default() -> Self {
Self::new()
}
}
impl Control {
    /// Creates a fresh control that is not cancelled and has no timeout or
    /// concurrency limit set.
    #[must_use]
    pub fn new() -> Self {
        let inner = ControlInner {
            cancelled: AtomicBool::new(false),
            timeout: Mutex::new(None),
            concurrency_limit: Mutex::new(None),
        };
        Self {
            inner: Arc::new(inner),
        }
    }

    /// Marks this control, and every clone of it, as cancelled.
    pub fn cancel(&self) {
        self.inner.cancelled.store(true, Ordering::Release);
    }

    /// Returns whether `cancel` has been called on this control or any clone of it.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.inner.cancelled.load(Ordering::Acquire)
    }

    /// Stores `timeout`, replacing any previously stored value.
    pub fn set_timeout(&self, timeout: Duration) {
        let mut slot = lock_or_recover(&self.inner.timeout);
        *slot = Some(timeout);
    }

    /// Returns the stored timeout, if one was set.
    #[must_use]
    pub fn timeout(&self) -> Option<Duration> {
        let slot = lock_or_recover(&self.inner.timeout);
        *slot
    }

    /// Stores `limit`, replacing any previously stored value.
    pub fn set_concurrency_limit(&self, limit: usize) {
        let mut slot = lock_or_recover(&self.inner.concurrency_limit);
        *slot = Some(limit);
    }

    /// Returns the stored concurrency limit, if one was set.
    #[must_use]
    pub fn concurrency_limit(&self) -> Option<usize> {
        let slot = lock_or_recover(&self.inner.concurrency_limit);
        *slot
    }
}
/// Locks `lock`, ignoring poisoning.
///
/// If a previous holder panicked while holding the lock, the guard is recovered
/// from the `PoisonError`. Callers therefore always get access to whatever value
/// was last written.
fn lock_or_recover<T>(lock: &Mutex<T>) -> std::sync::MutexGuard<'_, T> {
    match lock.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
/// Owns the listener task spawned by `RuntimeInvocation::on`.
///
/// Dropping the handle aborts the listener task. Handler invocations that were
/// already spawned for earlier events run as separate tasks and are not aborted.
#[derive(Debug)]
pub struct InvocationHandle {
/// The listener loop created by `spawn_listener`.
task: JoinHandle<()>,
}
impl InvocationHandle {
pub fn abort(&self) {
self.task.abort();
}
#[must_use]
pub fn is_finished(&self) -> bool {
self.task.is_finished()
}
}
impl Drop for InvocationHandle {
    /// Stops the listener when the handle goes out of scope, so a forgotten
    /// handle does not leave a listener running.
    fn drop(&mut self) {
        self.abort();
    }
}
/// Entry point for emitting runs on, and listening to events from, a shared
/// `AgentRuntime`. Cloning only clones the inner `Arc`.
#[derive(Clone)]
pub struct RuntimeInvocation {
runtime: Arc<AgentRuntime>,
}
impl RuntimeInvocation {
    /// Wraps a shared [`AgentRuntime`].
    #[must_use]
    pub fn new(runtime: Arc<AgentRuntime>) -> Self {
        Self { runtime }
    }

    /// Borrows the underlying runtime.
    #[must_use]
    pub fn runtime(&self) -> &AgentRuntime {
        &self.runtime
    }

    /// Builds a request with `f` and passes it to `AgentRuntime::run`.
    ///
    /// `f` receives an empty [`SessionContext`] and a fresh [`Control`]. If that
    /// control has been cancelled by the time `f`'s future resolves, no run is
    /// started.
    ///
    /// NOTE(review): a timeout or concurrency limit set on the control is not
    /// applied here.
    ///
    /// # Errors
    ///
    /// - [`InvocationError::TaskFailed`] with message `"cancelled"` if the control
    ///   was cancelled while building the request.
    /// - [`InvocationError::Runtime`] if `AgentRuntime::run` fails.
    pub async fn emit<F, Fut>(&self, f: F) -> Result<RunOutput, InvocationError>
    where
        F: FnOnce(SessionContext, Control) -> Fut + Send,
        Fut: Future<Output = EmitRequest> + Send,
    {
        // The control is brand new and shared with no one until `f` receives it,
        // so cancellation can only be observed after `f` has run.
        let control = Control::new();
        let request = f(SessionContext::default(), control.clone()).await;
        if control.is_cancelled() {
            return Err(InvocationError::TaskFailed {
                message: "cancelled".into(),
            });
        }
        Ok(self.runtime.run(request.into_run_request()).await?)
    }

    /// Subscribes to runtime events and calls `f` for every event that `kind` selects.
    ///
    /// Each selected event is handled on its own spawned task, so handler calls may
    /// overlap. All calls share one [`Control`]. Cancelling it stops the listener
    /// when the next event is received. Dropping or aborting the returned handle
    /// stops the listener immediately.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    // NOTE(review): the body never awaits; `async` is presumably kept so the public
    // signature can stay stable if subscription becomes asynchronous.
    #[allow(clippy::unused_async)]
    pub async fn on<F, Fut>(
        &self,
        kind: EventKind,
        f: F,
    ) -> Result<InvocationHandle, InvocationError>
    where
        F: Fn(InvocationEvent, SessionContext, Control) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        let receiver = self.runtime.subscribe();
        let control = Control::new();
        Ok(spawn_listener(receiver, kind, control, f))
    }
}
fn spawn_listener<F, Fut>(
mut receiver: broadcast::Receiver<AgentEvent>,
kind: EventKind,
control: Control,
handler: F,
) -> InvocationHandle
where
F: Fn(InvocationEvent, SessionContext, Control) -> Fut + Send + Sync + 'static,
Fut: Future<Output = ()> + Send + 'static,
{
let handler = Arc::new(handler);
let task = tokio::spawn(async move {
loop {
match receiver.recv().await {
Ok(event) => {
if control.is_cancelled() {
break;
}
let ctx = SessionContext::from_agent_event(&event);
let inv = InvocationEvent::Agent(event);
if !kind.matches(&inv) {
continue;
}
let h = Arc::clone(&handler);
let c = control.clone();
tokio::spawn(async move {
h(inv, ctx, c).await;
});
}
Err(broadcast::error::RecvError::Closed) => break,
Err(broadcast::error::RecvError::Lagged(_)) => {}
}
}
});
InvocationHandle { task }
}
// Unit tests for request conversion, event filtering, `Control`, and the
// listener task created by `spawn_listener`.
#[cfg(test)]
// Unwrapping is acceptable in test code.
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
use crate::provider::{ChatStreamEvent, FinishReason, ModelName, ProviderId, ToolChoice};
use crate::runtime::event::{RunCompleted, RunStarted, TextDelta};
use chrono::Utc;
use serde_json::json;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::broadcast;
use uuid::Uuid;
// Every builder field is carried into the `RunRequest`; the run id stays unset
// and tool choice is `Auto`.
#[test]
fn emit_request_converts_to_run_request() {
let sid = Uuid::new_v4();
let req = EmitRequest::new(ProviderId::new("p"), ModelName::new("m"), "hi")
.with_session_id(sid)
.with_client_request_id("cid")
.with_metadata(json!({"k": "v"}));
let run = req.into_run_request();
assert_eq!(run.provider, ProviderId::new("p"));
assert_eq!(run.model, ModelName::new("m"));
assert_eq!(run.input, "hi");
assert_eq!(run.session_id, Some(sid));
assert_eq!(run.client_request_id.as_deref(), Some("cid"));
assert_eq!(run.metadata, json!({"k": "v"}));
assert!(matches!(run.tool_choice, ToolChoice::Auto));
assert!(run.run_id.is_none());
}
// Without `with_metadata`, metadata is JSON null both before and after conversion.
#[test]
fn emit_request_default_metadata_is_null() {
let req = EmitRequest::new(ProviderId::new("p"), ModelName::new("m"), "hi");
assert!(req.metadata.is_null());
let run = req.clone().into_run_request();
assert!(run.metadata.is_null());
}
// An agent text delta matches its own kind and `Any`, but not other agent
// kinds or the chat kind with a similar name.
#[test]
fn event_kind_matches_agent_text_delta() {
let ev = InvocationEvent::Agent(AgentEvent::TextDelta(TextDelta {
run_id: RunId::new(),
delta: "x".into(),
timestamp: Utc::now(),
}));
assert!(EventKind::TextDelta.matches(&ev));
assert!(EventKind::Any.matches(&ev));
assert!(!EventKind::RunCompleted.matches(&ev));
assert!(!EventKind::ChatTextDelta.matches(&ev));
}
// A completed-run event matches `RunCompleted` and `Any` only.
#[test]
fn event_kind_matches_agent_run_completed() {
let ev = InvocationEvent::Agent(AgentEvent::RunCompleted(RunCompleted {
run_id: RunId::new(),
finish_reason: FinishReason::Stop,
iterations: 1,
timestamp: Utc::now(),
}));
assert!(EventKind::RunCompleted.matches(&ev));
assert!(EventKind::Any.matches(&ev));
assert!(!EventKind::TextDelta.matches(&ev));
}
// A chat text delta matches `ChatTextDelta` and `Any`, never the agent `TextDelta`.
#[test]
fn event_kind_matches_chat_text_delta() {
let ev = InvocationEvent::Chat(ChatStreamEvent::TextDelta { delta: "x".into() });
assert!(EventKind::ChatTextDelta.matches(&ev));
assert!(EventKind::Any.matches(&ev));
assert!(!EventKind::TextDelta.matches(&ev));
}
// Cancellation is visible through the original and every clone.
#[test]
fn control_cancel_sets_flag_and_shares_state() {
let c = Control::new();
assert!(!c.is_cancelled());
c.cancel();
assert!(c.is_cancelled());
let cloned = c.clone();
assert!(cloned.is_cancelled(), "clone must share cancel state");
}
// Only events selected by the listener's `EventKind` reach the handler.
// NOTE(review): waits on fixed sleeps for spawned tasks to run; may be flaky on
// a heavily loaded machine.
#[tokio::test]
async fn on_only_handles_matching_events() {
let (tx, rx) = broadcast::channel::<AgentEvent>(16);
let counter = Arc::new(AtomicUsize::new(0));
let c = counter.clone();
let handler = move |_, _, _| {
let c = c.clone();
async move {
c.fetch_add(1, Ordering::SeqCst);
}
};
let handle = spawn_listener(rx, EventKind::TextDelta, Control::new(), handler);
let _ = tx.send(AgentEvent::TextDelta(TextDelta {
run_id: RunId::new(),
delta: "a".into(),
timestamp: Utc::now(),
}));
let _ = tx.send(AgentEvent::RunStarted(RunStarted {
run_id: RunId::new(),
session_id: Uuid::new_v4(),
provider: ProviderId::new("p"),
model: ModelName::new("m"),
timestamp: Utc::now(),
}));
tokio::time::sleep(Duration::from_millis(100)).await;
assert_eq!(
counter.load(Ordering::SeqCst),
1,
"only the matching TextDelta event should be handled"
);
handle.abort();
}
// After `abort`, the listener task finishes and later events are not dispatched.
// NOTE(review): also relies on fixed sleeps for task scheduling.
#[tokio::test]
async fn invocation_handle_abort_stops_listener() {
let (tx, rx) = broadcast::channel::<AgentEvent>(16);
let counter = Arc::new(AtomicUsize::new(0));
let c = counter.clone();
let handler = move |_, _, _| {
let c = c.clone();
async move {
c.fetch_add(1, Ordering::SeqCst);
}
};
let handle = spawn_listener(rx, EventKind::Any, Control::new(), handler);
let _ = tx.send(AgentEvent::RunStarted(RunStarted {
run_id: RunId::new(),
session_id: Uuid::new_v4(),
provider: ProviderId::new("p"),
model: ModelName::new("m"),
timestamp: Utc::now(),
}));
tokio::time::sleep(Duration::from_millis(100)).await;
let before = counter.load(Ordering::SeqCst);
assert!(before >= 1, "first event should be handled");
handle.abort();
tokio::time::sleep(Duration::from_millis(50)).await;
assert!(handle.is_finished(), "listener should finish after abort");
let _ = tx.send(AgentEvent::RunStarted(RunStarted {
run_id: RunId::new(),
session_id: Uuid::new_v4(),
provider: ProviderId::new("p"),
model: ModelName::new("m"),
timestamp: Utc::now(),
}));
tokio::time::sleep(Duration::from_millis(100)).await;
assert_eq!(
counter.load(Ordering::SeqCst),
before,
"no events should be handled after abort"
);
}
}