use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
use std::time::Duration;
use crate::runtime::JoinHandle;
#[cfg(not(target_arch = "wasm32"))]
use crate::runtime::timeout;
use blazen_events::{AnyEvent, InputResponseEvent, UsageEvent};
use blazen_llm::types::TokenUsage;
use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
use tokio_stream::StreamExt;
use tokio_stream::wrappers::BroadcastStream;
use crate::error::WorkflowError;
use crate::session_ref::SessionRefRegistry;
use crate::snapshot::WorkflowSnapshot;
/// Running token-usage and cost totals for a single workflow run.
///
/// A value of this type is shared (behind `Arc<Mutex<_>>`) between
/// `WorkflowHandler` and the background task spawned in
/// `WorkflowHandler::new`, which adds to it for every `UsageEvent` it
/// receives on the event stream.
#[derive(Debug, Default, Clone)]
pub(crate) struct UsageTotals {
/// Sum of the token counts of every `UsageEvent` folded in so far.
pub usage: TokenUsage,
/// Sum of the `cost_usd` values of the folded events; events whose cost is
/// `None` contribute nothing.
pub cost_usd: f64,
}
/// Final outcome of a workflow run, as returned by `WorkflowHandler::result`.
#[derive(Debug)]
pub struct WorkflowResult {
/// The event delivered on the handler's result channel.
pub event: Box<dyn AnyEvent>,
/// Shared handle to the same session-ref registry the handler was holding.
pub session_refs: Arc<SessionRefRegistry>,
/// Token usage accumulated by the handler at the time the result was built.
pub usage_total: TokenUsage,
/// Cost in USD accumulated by the handler at the time the result was built.
pub cost_total_usd: f64,
}
/// Control messages sent from a `WorkflowHandler` over its control channel.
///
/// NOTE(review): the receiving side is not part of this file; the variant
/// descriptions below cover what the handler sends, and the receiver's
/// reaction to each message should be confirmed against the event loop.
pub(crate) enum WorkflowControl {
/// Sent by `WorkflowHandler::pause`.
Pause,
/// Sent by `WorkflowHandler::resume_in_place`.
Resume,
/// Sent by `WorkflowHandler::snapshot`; the receiver is expected to answer
/// on `reply`.
Snapshot {
/// One-shot channel on which the snapshot (or an error) is sent back.
reply: oneshot::Sender<Result<WorkflowSnapshot, WorkflowError>>,
},
/// Sent by `WorkflowHandler::abort` and when a `WorkflowHandler` is dropped.
Abort,
/// Sent by `WorkflowHandler::respond_to_input`, carrying the caller's
/// response event.
InputResponse(InputResponseEvent),
}
/// Caller-side handle to a running workflow.
///
/// It exposes the final result, a stream of published events, control
/// messages (pause / resume / snapshot / abort / input response) and the
/// token-usage totals accumulated from the event stream. Dropping the
/// handler sends `WorkflowControl::Abort` on the control channel.
pub struct WorkflowHandler {
/// Receiver for the workflow's final outcome; taken (left as `None`) by
/// `result()`.
result_rx: Option<oneshot::Receiver<Result<Box<dyn AnyEvent>, WorkflowError>>>,
/// Broadcast sender for streamed events; used here only to create new
/// subscriptions.
stream_tx: broadcast::Sender<Box<dyn AnyEvent>>,
/// Sender for `WorkflowControl` messages.
control_tx: mpsc::UnboundedSender<WorkflowControl>,
/// Join handle passed to `new` as `event_loop_handle`; awaited by
/// `result()` on non-wasm targets.
event_loop_handle: Option<JoinHandle<()>>,
/// Session-ref registry shared with callers via `session_refs()`.
session_refs: Arc<SessionRefRegistry>,
/// Usage and cost totals, updated by the accumulator task spawned in `new`.
usage_totals: Arc<Mutex<UsageTotals>>,
/// Join handle of the accumulator task spawned in `new`.
usage_accumulator_handle: Option<JoinHandle<()>>,
/// Receiver of telemetry history events, drained by `collect_history`.
#[cfg(feature = "telemetry")]
history_rx: Option<mpsc::UnboundedReceiver<blazen_telemetry::HistoryEvent>>,
}
impl WorkflowHandler {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
result_rx: oneshot::Receiver<Result<Box<dyn AnyEvent>, WorkflowError>>,
stream_tx: broadcast::Sender<Box<dyn AnyEvent>>,
control_tx: mpsc::UnboundedSender<WorkflowControl>,
event_loop_handle: JoinHandle<()>,
session_refs: Arc<SessionRefRegistry>,
#[cfg(feature = "telemetry")] history_rx: Option<
mpsc::UnboundedReceiver<blazen_telemetry::HistoryEvent>,
>,
) -> Self {
let usage_totals = Arc::new(Mutex::new(UsageTotals::default()));
let mut accumulator_rx = stream_tx.subscribe();
let totals_for_task = Arc::clone(&usage_totals);
let accumulator_handle = crate::runtime::spawn(async move {
loop {
match accumulator_rx.recv().await {
Ok(boxed) => {
if let Some(usage) = boxed.as_any().downcast_ref::<UsageEvent>() {
let mut totals = totals_for_task.lock().await;
totals.usage.add(&TokenUsage {
prompt_tokens: usage.prompt_tokens,
completion_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
reasoning_tokens: usage.reasoning_tokens,
cached_input_tokens: usage.cached_input_tokens,
audio_input_tokens: usage.audio_input_tokens,
audio_output_tokens: usage.audio_output_tokens,
});
if let Some(cost) = usage.cost_usd {
totals.cost_usd += cost;
}
}
}
Err(broadcast::error::RecvError::Lagged(_)) => {}
Err(broadcast::error::RecvError::Closed) => break,
}
}
});
Self {
result_rx: Some(result_rx),
stream_tx,
control_tx,
event_loop_handle: Some(event_loop_handle),
session_refs,
usage_totals,
usage_accumulator_handle: Some(accumulator_handle),
#[cfg(feature = "telemetry")]
history_rx,
}
}
/// Returns a copy of the token usage accumulated so far.
///
/// Waits for the lock on the shared totals, so the value reflects every
/// usage event the accumulator task has already folded in.
pub async fn usage_total(&self) -> TokenUsage {
    self.usage_totals.lock().await.usage.clone()
}
/// Returns the cost in USD accumulated so far.
///
/// Waits for the lock on the shared totals before reading the value.
pub async fn cost_total_usd(&self) -> f64 {
    self.usage_totals.lock().await.cost_usd
}
/// Returns another shared handle to this workflow's session-ref registry.
///
/// Only the `Arc` is cloned; the registry itself is not copied.
#[must_use]
pub fn session_refs(&self) -> Arc<SessionRefRegistry> {
    let registry = &self.session_refs;
    Arc::clone(registry)
}
/// Waits for the workflow's final outcome and returns it together with the
/// accumulated usage totals.
///
/// Consumes the handler. After the final event has been received, non-wasm
/// builds wait for the event-loop task, drop this handler's broadcast sender
/// and give the usage accumulator task up to 50 ms to finish before reading
/// the totals. On `wasm32` the task handles are simply dropped.
///
/// # Errors
///
/// Returns [`WorkflowError::ChannelClosed`] if the result sender was dropped
/// without sending, or the error that was sent on the result channel. In
/// both cases the function returns before any of the shutdown steps below.
///
/// # Panics
///
/// Panics if `result_rx` has already been taken. In the code visible here
/// only this method takes it, and it consumes `self`.
pub async fn result(mut self) -> Result<WorkflowResult, WorkflowError> {
let rx = self
.result_rx
.take()
.expect("result() called after result was already consumed");
// A receive error means the sender side was dropped; report it as a closed
// channel. `?` then propagates either that or the workflow's own error.
let event = rx.await.unwrap_or(Err(WorkflowError::ChannelClosed))?;
#[cfg(not(target_arch = "wasm32"))]
{
// Wait for the event-loop task to finish; its join result is ignored.
if let Some(handle) = self.event_loop_handle.take() {
let _ = handle.await;
}
// Swap in a throwaway sender so the original one can be dropped here.
// With no senders left the accumulator's `recv` reports `Closed` and its
// loop exits — presumably other clones are gone once the event loop has
// finished; the timeout below covers the case where they are not.
let (drained_stream_tx, _) = broadcast::channel::<Box<dyn AnyEvent>>(1);
let owned_sender = std::mem::replace(&mut self.stream_tx, drained_stream_tx);
drop(owned_sender);
// Give the accumulator a bounded amount of time to finish on its own;
// abort it if it is still running after 50 ms.
if let Some(mut handle) = self.usage_accumulator_handle.take()
&& timeout(Duration::from_millis(50), &mut handle)
.await
.is_err()
{
handle.abort();
}
}
#[cfg(target_arch = "wasm32")]
{
// The handles are taken and dropped without being awaited.
self.event_loop_handle.take();
self.usage_accumulator_handle.take();
}
// Snapshot the totals; the guard is released at the end of this statement.
let totals = self.usage_totals.lock().await.clone();
let session_refs = Arc::clone(&self.session_refs);
Ok(WorkflowResult {
event,
session_refs,
usage_total: totals.usage,
cost_total_usd: totals.cost_usd,
})
}
/// Returns a stream of the events published on the broadcast channel.
///
/// Each call creates a fresh subscription. Items for which the underlying
/// broadcast stream yields an error are silently skipped.
pub fn stream_events(
    &self,
) -> impl tokio_stream::Stream<Item = Box<dyn AnyEvent>> + Send + Unpin + use<> {
    let subscription = BroadcastStream::new(self.stream_tx.subscribe());
    // Keep only successfully received events.
    subscription.filter_map(std::result::Result::ok)
}
/// Sends a [`WorkflowControl::Pause`] message on the control channel.
///
/// # Errors
///
/// Returns [`WorkflowError::ChannelClosed`] if the message cannot be sent.
pub fn pause(&self) -> Result<(), WorkflowError> {
    match self.control_tx.send(WorkflowControl::Pause) {
        Ok(()) => Ok(()),
        Err(_) => Err(WorkflowError::ChannelClosed),
    }
}
/// Sends a [`WorkflowControl::Resume`] message on the control channel.
///
/// # Errors
///
/// Returns [`WorkflowError::ChannelClosed`] if the message cannot be sent.
pub fn resume_in_place(&self) -> Result<(), WorkflowError> {
    match self.control_tx.send(WorkflowControl::Resume) {
        Ok(()) => Ok(()),
        Err(_) => Err(WorkflowError::ChannelClosed),
    }
}
/// Requests a snapshot over the control channel and waits for the reply.
///
/// A one-shot channel is created for the answer and its sending half is
/// shipped inside [`WorkflowControl::Snapshot`].
///
/// # Errors
///
/// Returns [`WorkflowError::ChannelClosed`] if the request cannot be sent or
/// the reply sender is dropped without answering; otherwise returns whatever
/// result was sent back.
pub async fn snapshot(&self) -> Result<WorkflowSnapshot, WorkflowError> {
    let (reply, response) = oneshot::channel();
    let request = WorkflowControl::Snapshot { reply };
    if self.control_tx.send(request).is_err() {
        return Err(WorkflowError::ChannelClosed);
    }
    match response.await {
        Ok(outcome) => outcome,
        Err(_) => Err(WorkflowError::ChannelClosed),
    }
}
/// Forwards `response` on the control channel, wrapped in
/// [`WorkflowControl::InputResponse`].
///
/// # Errors
///
/// Returns [`WorkflowError::ChannelClosed`] if the message cannot be sent.
pub fn respond_to_input(&self, response: InputResponseEvent) -> Result<(), WorkflowError> {
    let message = WorkflowControl::InputResponse(response);
    match self.control_tx.send(message) {
        Ok(()) => Ok(()),
        Err(_) => Err(WorkflowError::ChannelClosed),
    }
}
/// Sends a [`WorkflowControl::Abort`] message on the control channel.
///
/// # Errors
///
/// Returns [`WorkflowError::ChannelClosed`] if the message cannot be sent.
pub fn abort(&self) -> Result<(), WorkflowError> {
    match self.control_tx.send(WorkflowControl::Abort) {
        Ok(()) => Ok(()),
        Err(_) => Err(WorkflowError::ChannelClosed),
    }
}
/// Drains the telemetry history events that are currently buffered and
/// returns them as a `WorkflowHistory` for `run_id` / `workflow_name`.
///
/// The history receiver is taken out of the handler, so this yields a
/// history at most once; later calls (or a handler built without a
/// receiver) return `None`. Draining stops at the first failed `try_recv`,
/// and the receiver is dropped when this method returns.
///
/// Each event's `sequence` is overwritten with the number of events already
/// stored in the history when it is appended.
#[cfg(feature = "telemetry")]
pub fn collect_history(
    &mut self,
    run_id: uuid::Uuid,
    workflow_name: String,
) -> Option<blazen_telemetry::WorkflowHistory> {
    let mut receiver = self.history_rx.take()?;
    let mut collected = blazen_telemetry::WorkflowHistory::new(run_id, workflow_name);
    loop {
        let Ok(mut entry) = receiver.try_recv() else {
            break;
        };
        entry.sequence = collected.events.len() as u64;
        collected.events.push(entry);
    }
    Some(collected)
}
}
impl Drop for WorkflowHandler {
    /// Sends [`WorkflowControl::Abort`] on the control channel when the
    /// handler goes away.
    fn drop(&mut self) {
        // Best effort: a send failure is deliberately ignored.
        let abort_signal = WorkflowControl::Abort;
        let _ = self.control_tx.send(abort_signal);
    }
}