use std::collections::HashSet;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use futures::future::BoxFuture;
use tokio::sync::{mpsc, watch, Notify};
use crate::clock::BaseClock;
use crate::error::{PipecatError, Result};
use crate::frames::{ControlFrame, ErrorFrameData, Frame, FrameDirection, FrameHandler, FrameInner, FrameKind, FrameProcessor, FrameProcessorSetup, StartFrameData, SystemFrame};
use crate::observer::BaseObserver;
use super::pipeline::Pipeline;
/// Shared async callback taking no arguments (used for idle-timeout hooks).
type AsyncCb0 = Arc<dyn Fn() -> BoxFuture<'static, ()> + Send + Sync>;
/// Shared async callback invoked with the frame that triggered it.
type AsyncCbFrame = Arc<dyn Fn(Frame) -> BoxFuture<'static, ()> + Send + Sync>;
/// Shared async callback invoked when the pipeline finishes, with the terminal frame and reason.
type AsyncCbFinish = Arc<dyn Fn(Frame, FinishReason) -> BoxFuture<'static, ()> + Send + Sync>;
/// Shared async callback invoked with the payload of an `Error` frame.
type AsyncCbError = Arc<dyn Fn(ErrorFrameData) -> BoxFuture<'static, ()> + Send + Sync>;
/// Configuration for a `PipelineTask`.
#[derive(Clone)]
pub struct PipelineParams {
/// Whether interruptions are allowed (propagated via the Start frame).
pub allow_interruptions: bool,
/// Enables metrics collection (propagated via the Start frame).
pub enable_metrics: bool,
/// Enables usage metrics (propagated via the Start frame).
pub enable_usage_metrics: bool,
/// Report only the first time-to-first-byte measurement (propagated via the Start frame).
pub report_only_initial_ttfb: bool,
/// When true, `run()` spawns a task that queues heartbeat frames periodically.
pub enable_heartbeats: bool,
/// Interval between heartbeat frames, in seconds.
pub heartbeat_seconds: f64,
/// If set, `run()` spawns an idle watchdog with this timeout; `None` disables it.
pub idle_timeout: Option<Duration>,
/// When true, an idle timeout queues a Cancel frame into the pipeline.
pub cancel_on_idle_timeout: bool,
/// Frame kinds whose arrival at either boundary counts as activity
/// and resets the idle watchdog.
pub idle_timeout_frames: HashSet<FrameKind>,
}
impl Default for PipelineParams {
    /// Conservative defaults: interruptions, metrics and heartbeats disabled,
    /// a 1-second heartbeat period, no idle timeout, cancel-on-idle enabled,
    /// and only speaking activity counted against the idle timer.
    fn default() -> Self {
        Self {
            allow_interruptions: false,
            enable_metrics: false,
            enable_usage_metrics: false,
            report_only_initial_ttfb: false,
            enable_heartbeats: false,
            heartbeat_seconds: 1.0,
            idle_timeout: None,
            cancel_on_idle_timeout: true,
            idle_timeout_frames: HashSet::from([
                FrameKind::BotSpeaking,
                FrameKind::UserSpeaking,
            ]),
        }
    }
}
/// Why a pipeline finished, as reported to `on_pipeline_finished` callbacks
/// and carried inside `PipelineLifecycle::Finished`.
#[derive(Debug, Clone, PartialEq)]
pub enum FinishReason {
/// An End frame reached the task sink (graceful completion).
End,
/// A Stop frame reached the task sink.
Stop,
/// A Cancel frame reached the task sink. The `Option<String>` payload is a
/// reason message; everywhere in this file it is constructed as `None`.
Cancel(Option<String>),
}
/// Observable run state of a pipeline task, published over a watch channel.
#[derive(Debug, Clone, PartialEq)]
pub enum PipelineLifecycle {
/// Initial state: no Start frame has reached the task sink yet.
NotStarted,
/// A Start frame reached the task sink; the pipeline is processing frames.
Running,
/// A terminal frame reached the task sink; `run()` returns on observing this.
Finished(FinishReason),
}
/// State shared between the task, its source/sink handlers and the helper
/// tasks spawned by `run()`.
pub(crate) struct TaskState {
// Publishes lifecycle transitions; `run()` and external observers subscribe.
pub(crate) lifecycle_tx: watch::Sender<PipelineLifecycle>,
// Callback lists. Each list is cloned out of its mutex before the callbacks
// are awaited, so no lock is held across an await point.
pub(crate) on_pipeline_started: std::sync::Mutex<Vec<AsyncCbFrame>>,
pub(crate) on_pipeline_finished: std::sync::Mutex<Vec<AsyncCbFinish>>,
pub(crate) on_pipeline_error: std::sync::Mutex<Vec<AsyncCbError>>,
pub(crate) on_frame_reached_upstream: std::sync::Mutex<Vec<AsyncCbFrame>>,
pub(crate) on_frame_reached_downstream: std::sync::Mutex<Vec<AsyncCbFrame>>,
pub(crate) on_idle_timeout: std::sync::Mutex<Vec<AsyncCb0>>,
// Frame kinds that trigger the reached-upstream / reached-downstream callbacks.
pub(crate) upstream_filter: std::sync::Mutex<HashSet<FrameKind>>,
pub(crate) downstream_filter: std::sync::Mutex<HashSet<FrameKind>>,
// Pinged by the boundary handlers whenever an activity frame is seen;
// the idle watchdog spawned by `run()` waits on it.
pub(crate) idle_notify: Arc<Notify>,
// Frame kinds whose arrival counts as activity for the idle watchdog.
pub(crate) idle_timeout_frames: HashSet<FrameKind>,
// When true, an idle timeout queues a Cancel frame into the pipeline.
pub(crate) cancel_on_idle_timeout: bool,
// NOTE(review): never written after construction anywhere in this file —
// presumably set by an external cancellation path; confirm before relying on it.
pub(crate) cancelled: AtomicBool,
}
impl TaskState {
    /// Creates the shared task state together with the lifecycle receiver
    /// handed to `PipelineTask`, seeding the channel with `NotStarted`.
    fn new(params: &PipelineParams) -> (Arc<Self>, watch::Receiver<PipelineLifecycle>) {
        let (tx, rx) = watch::channel(PipelineLifecycle::NotStarted);
        let state = Self {
            lifecycle_tx: tx,
            on_pipeline_started: std::sync::Mutex::default(),
            on_pipeline_finished: std::sync::Mutex::default(),
            on_pipeline_error: std::sync::Mutex::default(),
            on_frame_reached_upstream: std::sync::Mutex::default(),
            on_frame_reached_downstream: std::sync::Mutex::default(),
            on_idle_timeout: std::sync::Mutex::default(),
            upstream_filter: std::sync::Mutex::default(),
            downstream_filter: std::sync::Mutex::default(),
            idle_notify: Arc::new(Notify::new()),
            idle_timeout_frames: params.idle_timeout_frames.clone(),
            cancel_on_idle_timeout: params.cancel_on_idle_timeout,
            cancelled: AtomicBool::new(false),
        };
        (Arc::new(state), rx)
    }
}
/// Frame handler for the processor placed at the upstream end of the pipeline;
/// intercepts frames that escape upstream (see `handle_upstream_escape`).
pub(crate) struct TaskSourceHandler {
state: Arc<TaskState>,
}
#[async_trait]
impl FrameHandler for TaskSourceHandler {
    /// Downstream traffic passes straight through the source boundary;
    /// upstream frames that escape the pipeline are interpreted by the task.
    async fn on_process_frame(
        &self,
        processor: &FrameProcessor,
        frame: Frame,
        direction: FrameDirection,
    ) -> Result<()> {
        if let FrameDirection::Upstream = direction {
            self.handle_upstream_escape(processor, frame).await
        } else {
            processor.push_frame(frame, FrameDirection::Downstream).await
        }
    }
}
impl TaskSourceHandler {
async fn handle_upstream_escape(
&self,
processor: &FrameProcessor,
frame: Frame,
) -> Result<()> {
match &frame.inner {
FrameInner::System(SystemFrame::EndTask { .. }) => {
processor
.push_frame(Frame::end(), FrameDirection::Downstream)
.await?;
}
FrameInner::System(SystemFrame::CancelTask { .. }) => {
processor
.push_frame(Frame::cancel(), FrameDirection::Downstream)
.await?;
}
FrameInner::System(SystemFrame::StopTask) => {
processor
.push_frame(Frame::stop(), FrameDirection::Downstream)
.await?;
}
FrameInner::System(SystemFrame::InterruptionTask) => {
processor.broadcast_interruption().await?;
}
FrameInner::System(SystemFrame::Error(ref data)) if data.fatal => {
let data_clone = data.clone();
fire_error(&self.state, data_clone).await;
processor
.push_frame(Frame::cancel(), FrameDirection::Downstream)
.await?;
}
FrameInner::System(SystemFrame::Error(ref data)) => {
let data_clone = data.clone();
fire_error(&self.state, data_clone).await;
processor.push_frame(frame, FrameDirection::Downstream).await?;
}
_ => {
if self.state.idle_timeout_frames.contains(&frame.kind()) {
self.state.idle_notify.notify_one();
}
let matches = self
.state
.upstream_filter
.lock()
.unwrap()
.contains(&frame.kind());
if matches {
fire_frame_cbs(&self.state.on_frame_reached_upstream, frame).await;
}
}
}
Ok(())
}
}
/// Frame handler for the processor placed at the downstream end of the
/// pipeline; observes terminal frames and publishes lifecycle transitions.
pub(crate) struct TaskSinkHandler {
state: Arc<TaskState>,
}
#[async_trait]
impl FrameHandler for TaskSinkHandler {
    /// Upstream traffic passes straight through the sink boundary; frames
    /// arriving downstream are inspected for lifecycle-relevant events.
    async fn on_process_frame(
        &self,
        processor: &FrameProcessor,
        frame: Frame,
        direction: FrameDirection,
    ) -> Result<()> {
        if let FrameDirection::Downstream = direction {
            self.handle_downstream_escape(processor, frame).await
        } else {
            processor.push_frame(frame, FrameDirection::Upstream).await
        }
    }
}
impl TaskSinkHandler {
async fn handle_downstream_escape(&self, processor: &FrameProcessor, frame: Frame) -> Result<()> {
match &frame.inner {
FrameInner::System(SystemFrame::Start(_)) => {
let _ = self.state.lifecycle_tx.send(PipelineLifecycle::Running);
fire_frame_cbs(&self.state.on_pipeline_started, frame.clone()).await;
processor.push_frame(frame, FrameDirection::Downstream).await?;
}
FrameInner::Control(ControlFrame::End { .. }) => {
fire_frame_finish_cbs(&self.state.on_pipeline_finished, frame.clone(), FinishReason::End).await;
let _ = self.state.lifecycle_tx.send(PipelineLifecycle::Finished(FinishReason::End));
processor.push_frame(frame, FrameDirection::Downstream).await?;
}
FrameInner::System(SystemFrame::Stop { .. }) => {
fire_frame_finish_cbs(&self.state.on_pipeline_finished, frame.clone(), FinishReason::Stop).await;
let _ = self.state.lifecycle_tx.send(PipelineLifecycle::Finished(FinishReason::Stop));
processor.push_frame(frame, FrameDirection::Downstream).await?;
}
FrameInner::System(SystemFrame::Cancel { .. }) => {
fire_frame_finish_cbs(&self.state.on_pipeline_finished, frame.clone(), FinishReason::Cancel(None)).await;
let _ = self.state.lifecycle_tx.send(PipelineLifecycle::Finished(FinishReason::Cancel(None)));
processor.push_frame(frame, FrameDirection::Downstream).await?;
}
FrameInner::System(SystemFrame::Error(ref data)) => {
let data_clone = data.clone();
fire_error(&self.state, data_clone).await;
processor.push_frame(frame, FrameDirection::Downstream).await?;
}
FrameInner::System(
SystemFrame::EndTask { .. }
| SystemFrame::CancelTask { .. }
| SystemFrame::StopTask
) => {
log::warn!(
"TaskSink: task-control frame {} reached downstream boundary, ignoring",
frame.name()
);
}
_ => {
if self.state.idle_timeout_frames.contains(&frame.kind()) {
self.state.idle_notify.notify_one();
}
let matches = self
.state
.downstream_filter
.lock()
.unwrap()
.contains(&frame.kind());
if matches {
fire_frame_cbs(&self.state.on_frame_reached_downstream, frame.clone()).await;
}
processor.push_frame(frame, FrameDirection::Downstream).await?;
}
}
Ok(())
}
}
/// Invokes every registered frame callback with a clone of `frame`.
///
/// The handler list is cloned out of the mutex before any `.await`, so the
/// lock is never held across an await point. A single lock acquisition
/// suffices: cloning an empty `Vec` is free, so the previous double-lock
/// (emptiness pre-check, then clone) was redundant.
async fn fire_frame_cbs(
    handlers: &std::sync::Mutex<Vec<AsyncCbFrame>>,
    frame: Frame,
) {
    let cbs: Vec<AsyncCbFrame> = handlers.lock().unwrap().clone();
    for cb in &cbs {
        cb(frame.clone()).await;
    }
}
/// Invokes every registered finish callback with the terminal frame and the
/// finish reason.
///
/// The handler list is cloned out of the mutex before any `.await`, so the
/// lock is never held across an await point. A single lock acquisition
/// suffices: cloning an empty `Vec` is free, so a separate emptiness
/// pre-check (a second lock) is unnecessary.
async fn fire_frame_finish_cbs(
    handlers: &std::sync::Mutex<Vec<AsyncCbFinish>>,
    frame: Frame,
    reason: FinishReason,
) {
    let cbs: Vec<AsyncCbFinish> = handlers.lock().unwrap().clone();
    for cb in &cbs {
        cb(frame.clone(), reason.clone()).await;
    }
}
/// Invokes every registered error callback with a clone of `data`.
///
/// The handler list is cloned out of the mutex before any `.await`, so the
/// lock is never held across an await point; an empty list simply makes the
/// loop a no-op, so no separate emptiness check (second lock) is needed.
async fn fire_error(state: &TaskState, data: ErrorFrameData) {
    let cbs: Vec<AsyncCbError> = state.on_pipeline_error.lock().unwrap().clone();
    for cb in &cbs {
        cb(data.clone()).await;
    }
}
/// A runnable pipeline sandwiched between a task source and a task sink,
/// exposing lifecycle observation, event callbacks and external frame pushes.
pub struct PipelineTask {
// The composed pipeline (TaskSource -> user processors -> TaskSink).
pipeline: FrameProcessor,
// Parameters captured at construction; read when building the Start frame
// and when spawning the heartbeat / idle-watchdog tasks.
params: PipelineParams,
// State shared with the boundary handlers.
state: Arc<TaskState>,
// Sender half for injecting frames from outside the pipeline.
push_tx: mpsc::Sender<(Frame, FrameDirection)>,
// Receiver half, taken exactly once by `run()` (hence `Option`).
push_rx: std::sync::Mutex<Option<mpsc::Receiver<(Frame, FrameDirection)>>>,
// Template receiver for lifecycle changes; cloned for each subscriber.
lifecycle_rx: std::sync::Mutex<watch::Receiver<PipelineLifecycle>>,
}
impl PipelineTask {
    /// Builds a task wrapping `processors` between a `PipelineTaskSource` and
    /// a `PipelineTaskSink`, so frames escaping either end of the pipeline
    /// are intercepted and interpreted by the task.
    pub fn new(processors: Vec<FrameProcessor>, params: PipelineParams) -> Self {
        let (state, lifecycle_rx) = TaskState::new(&params);
        let task_source = FrameProcessor::new(
            "PipelineTaskSource",
            Box::new(TaskSourceHandler { state: state.clone() }),
            true,
        );
        let task_sink = FrameProcessor::new(
            "PipelineTaskSink",
            Box::new(TaskSinkHandler { state: state.clone() }),
            true,
        );
        // Layout: source -> user processors -> sink.
        let all: Vec<FrameProcessor> = std::iter::once(task_source)
            .chain(processors)
            .chain(std::iter::once(task_sink))
            .collect();
        let pipeline = Pipeline::new(all);
        let (push_tx, push_rx) = mpsc::channel(64);
        Self {
            pipeline,
            params,
            state,
            push_tx,
            push_rx: std::sync::Mutex::new(Some(push_rx)),
            lifecycle_rx: std::sync::Mutex::new(lifecycle_rx),
        }
    }

    /// Returns a sender that feeds frames into the running pipeline.
    pub fn push_sender(&self) -> mpsc::Sender<(Frame, FrameDirection)> {
        self.push_tx.clone()
    }

    /// Queues `frame` for injection into the pipeline in `direction`.
    ///
    /// # Errors
    /// Fails if the push channel is closed (the task has shut down).
    pub async fn push_frame(&self, frame: Frame, direction: FrameDirection) -> Result<()> {
        self.push_tx
            .send((frame, direction))
            .await
            .map_err(|_| PipecatError::pipeline("Push channel closed"))
    }

    /// Returns a fresh receiver for observing lifecycle transitions.
    pub fn lifecycle_receiver(&self) -> watch::Receiver<PipelineLifecycle> {
        self.lifecycle_rx.lock().unwrap().clone()
    }

    /// Registers a callback fired when the Start frame reaches the sink.
    pub fn add_on_pipeline_started<F>(&self, f: F)
    where
        F: Fn(Frame) -> BoxFuture<'static, ()> + Send + Sync + 'static,
    {
        self.state
            .on_pipeline_started
            .lock()
            .unwrap()
            .push(Arc::new(f));
    }

    /// Registers a callback fired when a terminal frame reaches the sink.
    pub fn add_on_pipeline_finished<F>(&self, f: F)
    where
        F: Fn(Frame, FinishReason) -> BoxFuture<'static, ()> + Send + Sync + 'static,
    {
        self.state
            .on_pipeline_finished
            .lock()
            .unwrap()
            .push(Arc::new(f));
    }

    /// Registers a callback fired for every Error frame seen at a boundary.
    pub fn add_on_pipeline_error<F>(&self, f: F)
    where
        F: Fn(ErrorFrameData) -> BoxFuture<'static, ()> + Send + Sync + 'static,
    {
        self.state
            .on_pipeline_error
            .lock()
            .unwrap()
            .push(Arc::new(f));
    }

    /// Registers a callback fired for filtered frames that escape upstream
    /// (see `set_upstream_filter`).
    pub fn add_on_frame_reached_upstream<F>(&self, f: F)
    where
        F: Fn(Frame) -> BoxFuture<'static, ()> + Send + Sync + 'static,
    {
        self.state
            .on_frame_reached_upstream
            .lock()
            .unwrap()
            .push(Arc::new(f));
    }

    /// Registers a callback fired for filtered frames that reach the sink
    /// (see `set_downstream_filter`).
    pub fn add_on_frame_reached_downstream<F>(&self, f: F)
    where
        F: Fn(Frame) -> BoxFuture<'static, ()> + Send + Sync + 'static,
    {
        self.state
            .on_frame_reached_downstream
            .lock()
            .unwrap()
            .push(Arc::new(f));
    }

    /// Registers a callback fired when the idle watchdog times out.
    pub fn add_on_idle_timeout<F>(&self, f: F)
    where
        F: Fn() -> BoxFuture<'static, ()> + Send + Sync + 'static,
    {
        self.state
            .on_idle_timeout
            .lock()
            .unwrap()
            .push(Arc::new(f));
    }

    /// Selects which frame kinds trigger the reached-upstream callbacks.
    pub fn set_upstream_filter(&self, kinds: HashSet<FrameKind>) {
        *self.state.upstream_filter.lock().unwrap() = kinds;
    }

    /// Selects which frame kinds trigger the reached-downstream callbacks.
    pub fn set_downstream_filter(&self, kinds: HashSet<FrameKind>) {
        *self.state.downstream_filter.lock().unwrap() = kinds;
    }

    /// Runs the pipeline until a terminal lifecycle state is reached.
    ///
    /// Sets up the pipeline, queues the Start frame, then drives three
    /// auxiliary tasks — external frame pushes, optional heartbeats, and an
    /// optional idle watchdog — while waiting for the sink to publish a
    /// `Finished` lifecycle state. Auxiliary tasks are aborted and the
    /// pipeline cleaned up before returning.
    ///
    /// # Errors
    /// Returns an error if called more than once, if setup or initial frame
    /// queuing fails, or if the lifecycle channel closes unexpectedly.
    pub async fn run(
        &self,
        clock: Arc<dyn BaseClock>,
        observer: Option<Arc<dyn BaseObserver>>,
    ) -> Result<()> {
        // `push_rx` is consumed on the first call; a second call fails fast.
        let push_rx = self
            .push_rx
            .lock()
            .unwrap()
            .take()
            .ok_or_else(|| PipecatError::pipeline("PipelineTask::run() called more than once"))?;
        self.pipeline
            .setup(FrameProcessorSetup { clock, observer })
            .await?;
        let start_frame = Frame::start(StartFrameData {
            allow_interruptions: self.params.allow_interruptions,
            enable_metrics: self.params.enable_metrics,
            enable_usage_metrics: self.params.enable_usage_metrics,
            report_only_initial_ttfb: self.params.report_only_initial_ttfb,
            ..Default::default()
        });
        self.pipeline
            .queue_frame(start_frame, FrameDirection::Downstream, None)
            .await?;
        // Forward externally pushed frames into the pipeline until the
        // channel closes or this task is aborted below.
        let pipeline_for_push = self.pipeline.clone();
        let push_task = tokio::spawn(async move {
            let mut rx = push_rx;
            while let Some((frame, direction)) = rx.recv().await {
                let _ = pipeline_for_push
                    .queue_frame(frame, direction, None)
                    .await;
            }
        });
        // Optional periodic heartbeat frames, stamped with wall-clock time.
        let heartbeat_task = if self.params.enable_heartbeats {
            let pipeline_hb = self.pipeline.clone();
            let period = Duration::from_secs_f64(self.params.heartbeat_seconds);
            Some(tokio::spawn(async move {
                loop {
                    tokio::time::sleep(period).await;
                    let ts = std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .unwrap_or_default()
                        .as_secs_f64();
                    let _ = pipeline_hb
                        .queue_frame(
                            Frame::heartbeat(ts),
                            FrameDirection::Downstream,
                            None,
                        )
                        .await;
                }
            }))
        } else {
            None
        };
        // Optional idle watchdog: boundary handlers ping `idle_notify` on
        // activity frames; if no ping arrives within `timeout`, the idle
        // callbacks fire and (if configured) the pipeline is cancelled.
        let idle_task = if let Some(timeout) = self.params.idle_timeout {
            let state = self.state.clone();
            let pipeline_idle = self.pipeline.clone();
            Some(tokio::spawn(async move {
                let idle_notify = state.idle_notify.clone();
                loop {
                    let timed_out = tokio::time::timeout(
                        timeout,
                        idle_notify.notified(),
                    )
                    .await
                    .is_err();
                    if timed_out {
                        let cbs: Vec<AsyncCb0> =
                            state.on_idle_timeout.lock().unwrap().clone();
                        for cb in &cbs {
                            cb().await;
                        }
                        if state.cancel_on_idle_timeout {
                            log::info!("PipelineTask: idle timeout — cancelling pipeline");
                            let _ = pipeline_idle
                                .queue_frame(
                                    Frame::cancel(),
                                    FrameDirection::Downstream,
                                    None,
                                )
                                .await;
                            break;
                        }
                    }
                }
            }))
        } else {
            None
        };
        // Wait for a terminal lifecycle state. Inspect the current value
        // before each `changed()` await: this way a `Finished` state that was
        // published before (or between) notifications can never be missed,
        // regardless of how far this receiver's clone source had advanced.
        let mut lifecycle_rx = self.lifecycle_rx.lock().unwrap().clone();
        loop {
            if matches!(
                *lifecycle_rx.borrow_and_update(),
                PipelineLifecycle::Finished(_)
            ) {
                break;
            }
            lifecycle_rx
                .changed()
                .await
                .map_err(|_| PipecatError::pipeline("Lifecycle channel closed unexpectedly"))?;
        }
        // Stop auxiliary tasks before tearing the pipeline down.
        push_task.abort();
        if let Some(t) = heartbeat_task {
            t.abort();
        }
        if let Some(t) = idle_task {
            t.abort();
        }
        self.pipeline.cleanup().await?;
        Ok(())
    }
}