use cortex_ai::FlowError;
use cortex_ai::{
flow::types::SourceOutput, Condition, ConditionFuture, Flow, FlowComponent, FlowFuture,
Processor, Source,
};
use flume::bounded;
use std::error::Error;
use std::time::Duration;
use tokio::sync::broadcast;
use tracing_subscriber::EnvFilter;
/// Processor that resolves to its input prefixed with `processed_`.
#[derive(Debug, Clone)]
pub struct TestProcessor;

/// Processor that resolves to its input unchanged.
#[derive(Clone)]
pub struct PassthroughProcessor;

/// Processor that always fails with `FlowError::Process`.
#[derive(Clone)]
pub struct ErrorProcessor;

/// Processor that ignores its input and resolves to the string `"skipped"`.
#[derive(Clone)]
pub struct SkipProcessor;

/// Condition that is met when the input contains the substring `"test"`.
#[derive(Clone)]
pub struct TestCondition;

/// Source whose stream yields `data` once and then closes.
pub struct TestSource {
    /// The single item emitted by the stream.
    pub data: String,
    /// Sender cloned into `SourceOutput::feedback` on every `stream` call.
    pub feedback: flume::Sender<Result<String, FlowError>>,
}

/// Source whose stream yields a single `Err(FlowError::Source(..))` item and then closes.
pub struct ErrorSource {
    /// Sender cloned into `SourceOutput::feedback` on every `stream` call.
    pub feedback: flume::Sender<Result<String, FlowError>>,
}

/// Source whose stream yields the string `"to_be_skipped"` once and then closes.
pub struct SkipSource {
    /// Sender cloned into `SourceOutput::feedback` on every `stream` call.
    pub feedback: flume::Sender<Result<String, FlowError>>,
}

/// Source whose stream is closed from the start and yields nothing.
pub struct EmptySource;

/// Source whose `stream` call itself fails with `FlowError::Source`.
pub struct StreamErrorSource;
impl FlowComponent for TestProcessor {
    type Error = FlowError;
    type Input = String;
    type Output = String;
}

impl Processor for TestProcessor {
    /// Resolves to `input` with the prefix `processed_` prepended.
    fn process(&self, input: Self::Input) -> FlowFuture<'_, Self::Output, Self::Error> {
        Box::pin(async move {
            let mut processed = String::from("processed_");
            processed.push_str(&input);
            Ok(processed)
        })
    }
}
impl FlowComponent for TestCondition {
    type Error = FlowError;
    type Input = String;
    type Output = String;
}

impl Condition for TestCondition {
    /// Reports whether `input` contains `"test"`, always forwarding the input unchanged as
    /// `Some(input)`.
    fn evaluate(&self, input: Self::Input) -> ConditionFuture<'_, Self::Output, Self::Error> {
        // The `contains` borrow ends before `input` is moved into the second tuple slot.
        Box::pin(async move { Ok((input.contains("test"), Some(input))) })
    }
}
impl FlowComponent for TestSource {
    type Error = FlowError;
    type Input = ();
    type Output = String;
}

impl Source for TestSource {
    /// Returns a stream that yields a clone of `self.data` once and then ends, paired with a
    /// clone of `self.feedback`.
    fn stream(&self) -> FlowFuture<'_, SourceOutput<Self::Output, Self::Error>, Self::Error> {
        let (payload, feedback) = (self.data.clone(), self.feedback.clone());
        Box::pin(async move {
            let receiver = {
                let (sender, receiver) = bounded(1);
                // Capacity 1 and a live receiver: this send neither blocks nor fails.
                sender.send(Ok(payload)).unwrap();
                // `sender` is dropped at the end of this scope, closing the stream after
                // its single item.
                receiver
            };
            Ok(SourceOutput { receiver, feedback })
        })
    }
}
impl FlowComponent for PassthroughProcessor {
    type Error = FlowError;
    type Input = String;
    type Output = String;
}

impl Processor for PassthroughProcessor {
    /// Resolves to `input` exactly as received.
    fn process(&self, input: Self::Input) -> FlowFuture<'_, Self::Output, Self::Error> {
        let passthrough = async move { Ok(input) };
        Box::pin(passthrough)
    }
}
impl FlowComponent for ErrorProcessor {
    type Error = FlowError;
    type Input = String;
    type Output = String;
}

impl Processor for ErrorProcessor {
    /// Ignores its input and always fails with `FlowError::Process("Processing failed")`.
    fn process(&self, _input: Self::Input) -> FlowFuture<'_, Self::Output, Self::Error> {
        Box::pin(async move {
            let failure = FlowError::Process("Processing failed".to_string());
            Err(failure)
        })
    }
}
impl FlowComponent for ErrorSource {
type Input = ();
type Output = String;
type Error = FlowError;
}
impl Source for ErrorSource {
fn stream(&self) -> FlowFuture<'_, SourceOutput<Self::Output, Self::Error>, Self::Error> {
let feedback = self.feedback.clone();
Box::pin(async move {
let (tx, rx) = bounded(1);
tx.send(Err(FlowError::Source("Source error".to_string())))
.unwrap();
drop(tx);
Ok(SourceOutput {
receiver: rx,
feedback,
})
})
}
}
impl FlowComponent for StreamErrorSource {
    type Error = FlowError;
    type Input = ();
    type Output = String;
}

impl Source for StreamErrorSource {
    /// Never produces a stream: the returned future always resolves to
    /// `Err(FlowError::Source("Stream initialization error"))`.
    fn stream(&self) -> FlowFuture<'_, SourceOutput<Self::Output, Self::Error>, Self::Error> {
        Box::pin(async {
            let reason = String::from("Stream initialization error");
            Err(FlowError::Source(reason))
        })
    }
}
impl FlowComponent for SkipProcessor {
    type Error = FlowError;
    type Input = String;
    type Output = String;
}

impl Processor for SkipProcessor {
    /// Discards its input and resolves to the fixed string `"skipped"`.
    fn process(&self, _input: Self::Input) -> FlowFuture<'_, Self::Output, Self::Error> {
        Box::pin(async move {
            let marker = String::from("skipped");
            Ok(marker)
        })
    }
}
impl FlowComponent for SkipSource {
    type Error = FlowError;
    type Input = ();
    type Output = String;
}

impl Source for SkipSource {
    /// Returns a stream that yields `"to_be_skipped"` once and then ends, paired with a clone
    /// of `self.feedback`.
    fn stream(&self) -> FlowFuture<'_, SourceOutput<Self::Output, Self::Error>, Self::Error> {
        let feedback = self.feedback.clone();
        Box::pin(async move {
            let item_rx = {
                let (item_tx, item_rx) = bounded(1);
                // Capacity 1 and a live receiver: this send neither blocks nor fails.
                item_tx.send(Ok(String::from("to_be_skipped"))).unwrap();
                // `item_tx` is dropped when this scope ends, closing the one-item stream.
                item_rx
            };
            Ok(SourceOutput {
                receiver: item_rx,
                feedback,
            })
        })
    }
}
impl FlowComponent for EmptySource {
    type Error = FlowError;
    type Input = ();
    type Output = String;
}

impl Source for EmptySource {
    /// Returns a stream that is already closed and yields no items.
    ///
    /// Each call creates its own feedback channel. A detached Tokio task drains that channel
    /// and prints every reported result to stdout. The task stops once `recv_async` returns
    /// an error, i.e. once the channel is disconnected.
    fn stream(&self) -> FlowFuture<'_, SourceOutput<Self::Output, Self::Error>, Self::Error> {
        Box::pin(async move {
            let (item_tx, item_rx) = bounded(1);
            // Closing the only sender up front makes the item stream empty.
            drop(item_tx);

            let (feedback_tx, feedback_rx) = bounded(1);
            tokio::spawn(async move {
                while let Ok(outcome) = feedback_rx.recv_async().await {
                    match outcome {
                        Err(e) => println!("Processing failed: {e}"),
                        Ok(processed_data) => println!("Processing succeeded: {processed_data}"),
                    }
                }
            });

            Ok(SourceOutput {
                receiver: item_rx,
                feedback: feedback_tx,
            })
        })
    }
}
/// Runs `flow` on a spawned Tokio task and returns its result, forcing a shutdown after
/// `timeout`.
///
/// The flow receives a fresh broadcast shutdown receiver. If the flow task finishes on its
/// own before `timeout` elapses, its result is returned right away instead of waiting out
/// the full timeout. Otherwise a shutdown signal is broadcast and the function waits for
/// the task to finish.
///
/// # Errors
///
/// Returns whatever error `Flow::run_stream` resolves to.
///
/// # Panics
///
/// If the flow task panics, that panic is re-raised on the caller with its original
/// payload. Panics with a descriptive message if the task was cancelled.
pub async fn run_flow_with_timeout<DataType, ErrorType, OutputType>(
    flow: Flow<DataType, ErrorType, OutputType>,
    timeout: Duration,
) -> Result<Vec<DataType>, ErrorType>
where
    DataType: Clone + Send + Sync + 'static,
    OutputType: Send + Sync + 'static,
    ErrorType: Error + Send + Sync + Clone + 'static + From<FlowError>,
{
    // `shutdown_tx` stays alive until this function returns, so the flow never observes a
    // closed shutdown channel prematurely.
    let (shutdown_tx, shutdown_rx) = broadcast::channel(1);
    let mut handle = tokio::spawn(async move { flow.run_stream(shutdown_rx).await });

    // Race the task against the deadline. `JoinHandle` is `Unpin`, so polling it through
    // `&mut` leaves it usable if the deadline wins. Binding the result first ends the
    // borrow before `handle` may be awaited by value below.
    let raced = tokio::time::timeout(timeout, &mut handle).await;
    let joined = match raced {
        Ok(joined) => joined,
        Err(_elapsed) => {
            // If the flow exited in the meantime there are no receivers; ignoring is fine.
            let _ = shutdown_tx.send(());
            handle.await
        }
    };

    match joined {
        Ok(flow_result) => flow_result,
        // Surface the flow's own panic message instead of a generic JoinError.
        Err(join_err) if join_err.is_panic() => std::panic::resume_unwind(join_err.into_panic()),
        Err(join_err) => panic!("flow task did not complete: {join_err}"),
    }
}
/// Installs a global `tracing` subscriber configured for tests.
///
/// The filter starts from `EnvFilter::from_default_env()`. It additionally enables `debug`
/// for the `cortex_ai` and `test` targets.
///
/// Output has these properties:
/// - it goes through the test writer in compact form;
/// - it includes thread ids and file/line locations;
/// - target names are omitted.
///
/// Calling this more than once is harmless: if `try_init` fails (typically because a
/// global subscriber is already set), a warning is printed instead of panicking.
pub fn init_tracing() {
    let filter = EnvFilter::from_default_env()
        .add_directive("cortex_ai=debug".parse().unwrap())
        .add_directive("test=debug".parse().unwrap());

    let installed = tracing_subscriber::fmt()
        .with_env_filter(filter)
        .with_test_writer()
        .with_thread_ids(true)
        .with_file(true)
        .with_line_number(true)
        .with_target(false)
        .compact()
        .try_init()
        .is_ok();

    if installed {
        return;
    }
    println!("Warning: tracing already initialized");
}