use std::sync::Arc;
use blazen_events::{AnyEvent, InputResponseEvent};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::BroadcastStream;
use crate::error::WorkflowError;
use crate::session_ref::SessionRefRegistry;
use crate::snapshot::WorkflowSnapshot;
/// Successful outcome of a workflow run, as returned by `WorkflowHandler::result`.
#[derive(Debug)]
pub struct WorkflowResult {
    /// The event the workflow delivered on its result channel.
    pub event: Box<dyn AnyEvent>,
    /// The same session-reference registry exposed by `WorkflowHandler::session_refs`.
    pub session_refs: Arc<SessionRefRegistry>,
}
/// Control messages sent by a [`WorkflowHandler`] over its control channel.
pub(crate) enum WorkflowControl {
    /// Sent by `WorkflowHandler::pause`.
    Pause,
    /// Sent by `WorkflowHandler::resume_in_place`.
    Resume,
    /// Sent by `WorkflowHandler::snapshot`, which then awaits an answer on `reply`.
    Snapshot {
        /// One-shot channel on which the snapshot (or an error) is sent back.
        reply: oneshot::Sender<Result<WorkflowSnapshot, WorkflowError>>,
    },
    /// Sent by `WorkflowHandler::abort`, and also when a `WorkflowHandler` is dropped.
    Abort,
    /// Sent by `WorkflowHandler::respond_to_input`, carrying the caller's response.
    InputResponse(InputResponseEvent),
}
/// Handle to a running workflow.
///
/// The handle can await the workflow's result, subscribe to its streamed events,
/// and send control messages (pause, resume, snapshot, input response, abort).
///
/// Dropping the handler sends an abort request over the control channel.
pub struct WorkflowHandler {
    /// Receives the run's outcome; taken (and consumed) by `result`.
    result_rx: Option<oneshot::Receiver<Result<Box<dyn AnyEvent>, WorkflowError>>>,
    /// Broadcast sender used to create new subscriptions in `stream_events`.
    stream_tx: broadcast::Sender<Box<dyn AnyEvent>>,
    /// Outgoing channel for `WorkflowControl` messages.
    control_tx: mpsc::UnboundedSender<WorkflowControl>,
    /// Join handle of the spawned event-loop task; taken and awaited by `result`.
    event_loop_handle: Option<JoinHandle<()>>,
    /// Session-reference registry, shared with callers via `session_refs`.
    session_refs: Arc<SessionRefRegistry>,
    /// History event receiver drained by `collect_history`; `None` once taken or
    /// when none was supplied at construction.
    #[cfg(feature = "telemetry")]
    history_rx: Option<mpsc::UnboundedReceiver<blazen_telemetry::HistoryEvent>>,
}
impl WorkflowHandler {
    /// Assembles a handler from the channels and task handle of a workflow run.
    ///
    /// * `result_rx` – receives the run's final event or error.
    /// * `stream_tx` – broadcast sender whose subscribers observe streamed events.
    /// * `control_tx` – carries [`WorkflowControl`] messages to the event loop.
    /// * `event_loop_handle` – join handle of the spawned event-loop task.
    /// * `session_refs` – session-reference registry exposed to callers.
    /// * `history_rx` – (telemetry only) optional receiver of history events.
    // One parameter per field, so the constructor mirrors the struct directly.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        result_rx: oneshot::Receiver<Result<Box<dyn AnyEvent>, WorkflowError>>,
        stream_tx: broadcast::Sender<Box<dyn AnyEvent>>,
        control_tx: mpsc::UnboundedSender<WorkflowControl>,
        event_loop_handle: JoinHandle<()>,
        session_refs: Arc<SessionRefRegistry>,
        #[cfg(feature = "telemetry")] history_rx: Option<
            mpsc::UnboundedReceiver<blazen_telemetry::HistoryEvent>,
        >,
    ) -> Self {
        Self {
            result_rx: Some(result_rx),
            stream_tx,
            control_tx,
            event_loop_handle: Some(event_loop_handle),
            session_refs,
            #[cfg(feature = "telemetry")]
            history_rx,
        }
    }

    /// Returns a shared handle (an `Arc` clone) to the session-reference registry.
    #[must_use]
    pub fn session_refs(&self) -> Arc<SessionRefRegistry> {
        Arc::clone(&self.session_refs)
    }

    /// Waits for the workflow to finish and returns its final event.
    ///
    /// When the run succeeds, this also awaits the event-loop task, so the task
    /// has completed before this returns. Any panic or cancellation of that task
    /// is ignored at that point, because the result has already been received.
    ///
    /// If the run fails, this returns without awaiting the event-loop task. The
    /// consumed handler is then dropped, which sends [`WorkflowControl::Abort`].
    ///
    /// # Errors
    ///
    /// Returns the workflow's own error if the run failed. Returns
    /// [`WorkflowError::ChannelClosed`] if the result channel closed without a
    /// value being sent.
    pub async fn result(mut self) -> Result<WorkflowResult, WorkflowError> {
        // `result_rx` stays `Some` until this by-value method takes it, so the
        // `None` arm cannot occur in practice. It reports a closed channel
        // instead of panicking inside library code.
        let rx = self
            .result_rx
            .take()
            .ok_or(WorkflowError::ChannelClosed)?;

        // A dropped sender maps to `ChannelClosed`; otherwise the workflow's
        // own outcome (success or error) is propagated.
        let event = rx.await.unwrap_or(Err(WorkflowError::ChannelClosed))?;

        if let Some(handle) = self.event_loop_handle.take() {
            // The join result is deliberately discarded: the outcome is already known.
            let _ = handle.await;
        }

        Ok(WorkflowResult {
            event,
            // `Self` implements `Drop`, so fields cannot be moved out; clone the `Arc`.
            session_refs: Arc::clone(&self.session_refs),
        })
    }

    /// Subscribes to the events broadcast by the workflow.
    ///
    /// Only events sent after this call are observed. If the subscriber falls so
    /// far behind that the broadcast channel reports a lag, the missed events are
    /// skipped silently rather than ending the stream.
    #[must_use = "the returned stream yields nothing unless it is polled"]
    pub fn stream_events(
        &self,
    ) -> impl tokio_stream::Stream<Item = Box<dyn AnyEvent>> + Send + Unpin + use<> {
        let rx = self.stream_tx.subscribe();
        // Lag errors from `BroadcastStream` are dropped; only delivered events pass.
        BroadcastStream::new(rx).filter_map(std::result::Result::ok)
    }

    /// Requests that the workflow pause.
    ///
    /// This only enqueues [`WorkflowControl::Pause`]. It returns before the
    /// request has been acted on.
    ///
    /// # Errors
    ///
    /// Returns [`WorkflowError::ChannelClosed`] if the receiving end of the
    /// control channel has been dropped.
    pub fn pause(&self) -> Result<(), WorkflowError> {
        self.send_control(WorkflowControl::Pause)
    }

    /// Requests that a paused workflow resume within this same handler.
    ///
    /// This only enqueues [`WorkflowControl::Resume`]. It returns before the
    /// request has been acted on.
    ///
    /// # Errors
    ///
    /// Returns [`WorkflowError::ChannelClosed`] if the receiving end of the
    /// control channel has been dropped.
    pub fn resume_in_place(&self) -> Result<(), WorkflowError> {
        self.send_control(WorkflowControl::Resume)
    }

    /// Requests a snapshot of the workflow and waits for the reply.
    ///
    /// # Errors
    ///
    /// Returns [`WorkflowError::ChannelClosed`] in either of these cases:
    /// * the control message cannot be sent;
    /// * the reply sender is dropped without answering.
    ///
    /// Otherwise it returns whatever error is sent back on the reply channel.
    pub async fn snapshot(&self) -> Result<WorkflowSnapshot, WorkflowError> {
        let (reply_tx, reply_rx) = oneshot::channel();
        self.send_control(WorkflowControl::Snapshot { reply: reply_tx })?;
        reply_rx.await.unwrap_or(Err(WorkflowError::ChannelClosed))
    }

    /// Delivers a caller-supplied input response to the workflow.
    ///
    /// # Errors
    ///
    /// Returns [`WorkflowError::ChannelClosed`] if the receiving end of the
    /// control channel has been dropped.
    pub fn respond_to_input(&self, response: InputResponseEvent) -> Result<(), WorkflowError> {
        self.send_control(WorkflowControl::InputResponse(response))
    }

    /// Requests that the workflow abort.
    ///
    /// Dropping the handler sends the same request.
    ///
    /// # Errors
    ///
    /// Returns [`WorkflowError::ChannelClosed`] if the receiving end of the
    /// control channel has been dropped.
    pub fn abort(&self) -> Result<(), WorkflowError> {
        self.send_control(WorkflowControl::Abort)
    }

    /// Drains the history events buffered so far into a workflow history.
    ///
    /// Reading is non-blocking: only events already queued on the history
    /// channel are collected. Calling this while the workflow is still running
    /// therefore yields a partial history.
    ///
    /// Each collected event's `sequence` is overwritten with its index in
    /// `history.events`.
    ///
    /// The receiver is taken on the first call. Later calls return `None`, as
    /// does a handler built without a history receiver.
    #[cfg(feature = "telemetry")]
    pub fn collect_history(
        &mut self,
        run_id: uuid::Uuid,
        workflow_name: String,
    ) -> Option<blazen_telemetry::WorkflowHistory> {
        let mut rx = self.history_rx.take()?;
        let mut history = blazen_telemetry::WorkflowHistory::new(run_id, workflow_name);
        // `try_recv` fails on both "empty" and "disconnected"; either one ends the drain.
        while let Ok(mut event) = rx.try_recv() {
            // `usize -> u64` is lossless on every target with pointers of at most 64 bits.
            event.sequence = history.events.len() as u64;
            history.events.push(event);
        }
        Some(history)
    }

    /// Enqueues `msg` on the control channel, mapping a dropped receiver to
    /// [`WorkflowError::ChannelClosed`].
    fn send_control(&self, msg: WorkflowControl) -> Result<(), WorkflowError> {
        self.control_tx
            .send(msg)
            .map_err(|_| WorkflowError::ChannelClosed)
    }
}
impl Drop for WorkflowHandler {
    /// Requests that the workflow abort once its handler goes away.
    fn drop(&mut self) {
        // Best effort: a send error means the control receiver has already been
        // dropped, so there is nothing left to notify and the error is ignored.
        let abort_request = WorkflowControl::Abort;
        let _ = self.control_tx.send(abort_request);
    }
}