use crate::helpers::{
init_tracing, EmptySource, ErrorProcessor, ErrorSource, PassthroughProcessor, SkipProcessor,
SkipSource, StreamErrorSource, TestSource,
};
use cortex_ai::Flow;
use flume::bounded;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tracing::info;
#[cfg(test)]
mod flow_tests {
use cortex_ai::{flow::types::SourceOutput, FlowComponent, FlowError, FlowFuture, Source};
use flume::Receiver;
use super::*;
use crate::helpers::{run_flow_with_timeout, TestCondition};
/// Test-only source backed by a caller-owned channel.
///
/// Tests pre-load `rx` with any number of items (or errors) before running the
/// flow, which lets a single source emit more than one value.
struct MultiSource {
    /// Receiver handed to the flow as the source's item stream.
    rx: Receiver<Result<String, FlowError>>,
    /// Sender handed to the flow as the source's feedback channel.
    feedback: flume::Sender<Result<String, FlowError>>,
}
/// Component typing for `MultiSource`: unit input, `String` items and
/// `FlowError` as the error type.
impl FlowComponent for MultiSource {
    type Error = FlowError;
    type Input = ();
    type Output = String;
}
impl Source for MultiSource {
    /// Returns clones of the stored receiver and feedback sender.
    ///
    /// The clones are taken eagerly, so the returned future does not borrow
    /// `self`.
    fn stream(&self) -> FlowFuture<'_, SourceOutput<Self::Output, Self::Error>, Self::Error> {
        let output = SourceOutput {
            receiver: self.rx.clone(),
            feedback: self.feedback.clone(),
        };
        Box::pin(async move { Ok(output) })
    }
}
/// A flow built with `new()` and no source must refuse to run.
#[tokio::test]
async fn it_should_error_when_source_not_set() {
    init_tracing();
    info!("Starting source not set test");
    // The shutdown sender is discarded immediately; only the receiver is used.
    let (_, shutdown_rx) = tokio::sync::broadcast::channel(1);
    let outcome = Flow::<String, FlowError, String>::new()
        .run_stream(shutdown_rx)
        .await;
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert_eq!(message, "Flow error: No source configured");
}
/// A source whose `stream()` itself fails must surface that failure from
/// `run_stream`.
#[tokio::test]
async fn it_should_handle_source_stream_error() {
    init_tracing();
    info!("Starting source stream error test");
    // The shutdown sender is discarded immediately; only the receiver is used.
    let (_, shutdown_rx) = tokio::sync::broadcast::channel(1);
    let outcome = Flow::<String, FlowError, String>::new()
        .source(StreamErrorSource)
        .run_stream(shutdown_rx)
        .await;
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert_eq!(message, "Source error: Stream initialization error");
}
/// An error returned by a processor stage must fail the whole flow.
#[tokio::test]
async fn it_should_handle_processor_error() {
    init_tracing();
    info!("Starting processor error test");
    // The feedback receiver is discarded; this test only inspects the flow result.
    let (feedback_tx, _) = bounded::<Result<String, FlowError>>(1);
    let source = TestSource {
        data: "test_input".to_string(),
        feedback: feedback_tx,
    };
    let flow = Flow::<String, FlowError, String>::new()
        .source(source)
        .process(ErrorProcessor);
    let outcome = run_flow_with_timeout(flow, Duration::from_millis(100)).await;
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert_eq!(message, "Process error: Processing failed");
}
/// A source that yields nothing must produce a successful, empty result.
#[tokio::test]
async fn it_should_handle_empty_source() {
    init_tracing();
    info!("Starting empty source test");
    // The shutdown sender is discarded immediately; only the receiver is used.
    let (_, shutdown_rx) = tokio::sync::broadcast::channel(1);
    let outcome = Flow::<String, FlowError, String>::new()
        .source(EmptySource)
        .run_stream(shutdown_rx)
        .await;
    assert!(outcome.is_ok());
    let items = outcome.unwrap();
    assert_eq!(items.len(), 0);
}
/// `Flow::default()` must behave like `Flow::new()`: no source, so running fails.
#[tokio::test]
async fn it_should_create_flow_using_default() {
    init_tracing();
    info!("Starting flow using default test");
    // The shutdown sender is discarded immediately; only the receiver is used.
    let (_, shutdown_rx) = tokio::sync::broadcast::channel(1);
    let outcome = Flow::<String, FlowError, String>::default()
        .run_stream(shutdown_rx)
        .await;
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert_eq!(message, "Flow error: No source configured");
}
#[tokio::test]
async fn it_should_send_error_feedback_for_source_errors() {
init_tracing();
info!("Starting send error feedback for source errors test");
let (feedback_tx, feedback_rx) = bounded::<Result<String, FlowError>>(1);
let feedback_results = Arc::new(Mutex::new(Vec::<Result<String, FlowError>>::new()));
let feedback_results_clone = feedback_results.clone();
tokio::spawn(async move {
while let Ok(result) = feedback_rx.recv_async().await {
let mut results = feedback_results_clone.lock().unwrap();
results.push(result);
}
});
let flow: Flow<String, FlowError, String> = Flow::new()
.source(ErrorSource {
feedback: feedback_tx,
})
.process(PassthroughProcessor);
let result = run_flow_with_timeout(flow, Duration::from_millis(100)).await;
assert!(result.is_err());
assert_eq!(
result.unwrap_err().to_string(),
"Source error: Source error"
);
tokio::time::sleep(Duration::from_millis(50)).await;
let feedback_results = feedback_results.lock().unwrap();
assert_eq!(feedback_results.len(), 1);
assert!(matches!(
&feedback_results[0],
Err(e) if e.to_string() == "Source error: Source error"
));
drop(feedback_results);
}
#[tokio::test]
async fn it_should_handle_skipped_items() {
init_tracing();
info!("Starting handle skipped items test");
let (feedback_tx, feedback_rx) = bounded::<Result<String, FlowError>>(1);
let feedback_results = Arc::new(Mutex::new(Vec::<Result<String, FlowError>>::new()));
let feedback_results_clone = feedback_results.clone();
tokio::spawn(async move {
while let Ok(result) = feedback_rx.recv_async().await {
let mut results = feedback_results_clone.lock().unwrap();
results.push(result);
}
});
let flow: Flow<String, FlowError, String> = Flow::new()
.source(SkipSource {
feedback: feedback_tx,
})
.process(SkipProcessor);
let result = run_flow_with_timeout(flow, Duration::from_millis(100)).await;
assert!(result.is_ok());
let results = result.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0], "skipped");
tokio::time::sleep(Duration::from_millis(50)).await;
let feedback_results = feedback_results.lock().unwrap();
assert_eq!(feedback_results.len(), 1);
assert!(matches!(
&feedback_results[0],
Ok(msg) if msg == "skipped"
));
drop(feedback_results);
}
/// Attaching a source via `.source(...)` must let `run_stream` succeed.
#[tokio::test]
async fn it_should_set_source() {
    init_tracing();
    info!("Starting set source test");
    // The feedback receiver is discarded; only the flow result is checked.
    let (feedback_tx, _) = bounded::<Result<String, FlowError>>(1);
    let source = TestSource {
        data: "test".to_string(),
        feedback: feedback_tx,
    };
    let flow = Flow::<String, FlowError, String>::new().source(source);
    // The shutdown sender is discarded immediately; only the receiver is used.
    let (_, shutdown_rx) = tokio::sync::broadcast::channel(1);
    let outcome = flow.run_stream(shutdown_rx).await;
    assert!(outcome.is_ok());
}
/// An item-level error from the source must fail the flow and must be reported
/// exactly once on the source's feedback channel.
#[tokio::test]
async fn it_should_handle_source_item_error() {
    init_tracing();
    info!("Starting handle source item error test");
    let (feedback_tx, feedback_rx) = bounded::<Result<String, FlowError>>(1);
    let feedback_results = Arc::new(Mutex::new(Vec::<Result<String, FlowError>>::new()));
    let feedback_results_clone = feedback_results.clone();
    // Background task that records everything arriving on the feedback channel.
    tokio::spawn(async move {
        while let Ok(result) = feedback_rx.recv_async().await {
            let mut results = feedback_results_clone.lock().unwrap();
            results.push(result);
        }
    });
    let flow: Flow<String, FlowError, String> = Flow::new()
        .source(ErrorSource {
            feedback: feedback_tx,
        })
        .process(PassthroughProcessor);
    let result = run_flow_with_timeout(flow, Duration::from_millis(100)).await;
    assert!(result.is_err());
    // The feedback is recorded by the separately spawned task above. Without
    // yielding here that task may not have pushed the item yet, which made the
    // length check below flaky. Wait the same way the other feedback tests do.
    tokio::time::sleep(Duration::from_millis(50)).await;
    let feedback_results = feedback_results.lock().unwrap();
    assert_eq!(feedback_results.len(), 1);
    assert!(matches!(
        &feedback_results[0],
        Err(e) if e.to_string() == "Source error: Source error"
    ));
    drop(feedback_results);
}
/// A flow whose source yields nothing must return an empty result vector.
#[tokio::test]
async fn it_should_return_empty_vec_when_no_items_processed() {
    init_tracing();
    info!("Starting return empty vec when no items processed test");
    let timeout = Duration::from_millis(100);
    let flow: Flow<String, FlowError, String> = Flow::new().source(EmptySource);
    let outcome = run_flow_with_timeout(flow, timeout).await;
    assert!(outcome.is_ok());
    let items = outcome.unwrap();
    assert!(items.is_empty());
}
/// A passthrough processor stage must forward the source's single item unchanged.
#[tokio::test]
async fn it_should_add_processor_to_stages() {
    init_tracing();
    info!("Starting add processor to stages test");
    // The feedback receiver is discarded; only the flow output is checked.
    let (feedback_tx, _) = bounded::<Result<String, FlowError>>(1);
    let source = TestSource {
        data: "test".to_string(),
        feedback: feedback_tx,
    };
    let flow = Flow::<String, FlowError, String>::new()
        .source(source)
        .process(PassthroughProcessor);
    let outcome = run_flow_with_timeout(flow, Duration::from_millis(100)).await;
    assert!(outcome.is_ok());
    let outputs = outcome.unwrap();
    assert_eq!(outputs.len(), 1);
    assert_eq!(outputs[0], "test");
}
/// With no processors, the flow output must be exactly the source's data.
#[tokio::test]
async fn it_should_preserve_source_after_setting() {
    init_tracing();
    info!("Starting preserve source after setting test");
    let expected = "test".to_string();
    // The feedback receiver is discarded; only the flow output is checked.
    let (feedback_tx, _) = bounded::<Result<String, FlowError>>(1);
    let source = TestSource {
        data: expected.clone(),
        feedback: feedback_tx,
    };
    let flow = Flow::<String, FlowError, String>::new().source(source);
    let outcome = run_flow_with_timeout(flow, Duration::from_millis(100)).await;
    assert!(outcome.is_ok());
    let outputs = outcome.unwrap();
    assert_eq!(outputs.len(), 1);
    assert_eq!(outputs[0], expected);
}
#[tokio::test]
async fn it_should_handle_source_item_with_feedback() {
init_tracing();
info!("Starting handle source item with feedback test");
let (feedback_tx, feedback_rx) = bounded::<Result<String, FlowError>>(1);
let feedback_results = Arc::new(Mutex::new(Vec::<Result<String, FlowError>>::new()));
let feedback_results_clone = feedback_results.clone();
tokio::spawn(async move {
while let Ok(result) = feedback_rx.recv_async().await {
let mut results = feedback_results_clone.lock().unwrap();
results.push(result);
}
});
let test_data = "test_data".to_string();
let flow: Flow<String, FlowError, String> = Flow::new()
.source(TestSource {
data: test_data.clone(),
feedback: feedback_tx, })
.process(PassthroughProcessor);
let result = run_flow_with_timeout(flow, Duration::from_millis(100)).await;
assert!(result.is_ok());
let results = result.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0], test_data);
tokio::time::sleep(Duration::from_millis(50)).await;
let feedback_results = feedback_results.lock().unwrap();
assert_eq!(feedback_results.len(), 1);
assert!(matches!(
&feedback_results[0],
Ok(data) if data == &test_data
));
drop(feedback_results);
}
/// Every item pre-loaded into a `MultiSource` must come out of a passthrough
/// flow, in order.
#[tokio::test]
async fn it_should_process_multiple_items() {
    init_tracing();
    info!("Starting process multiple items test");
    let expected = vec!["item1".to_string(), "item2".to_string()];

    let (tx, rx) = bounded(2);
    for item in &expected {
        tx.send(Ok(item.clone())).unwrap();
    }
    // Drop the only sender so the receiver reports disconnection once the
    // buffered items have been drained.
    drop(tx);

    // The feedback receiver is discarded; only the flow output is checked.
    let (feedback_tx, _) = bounded(2);
    let flow: Flow<String, FlowError, String> = Flow::new()
        .source(MultiSource {
            rx,
            feedback: feedback_tx,
        })
        .process(PassthroughProcessor);
    let outcome = run_flow_with_timeout(flow, Duration::from_millis(100)).await;
    assert!(outcome.is_ok());
    let outputs = outcome.unwrap();
    assert_eq!(outputs.len(), 2);
    assert_eq!(outputs, expected);
}
/// Runs a branching flow (`when`/`otherwise`) under a verbose tracing
/// subscriber, so the span events can be inspected in the test output.
#[tokio::test]
async fn it_should_show_tracing_metrics() {
    // `try_init` returns an error when a global subscriber is already set
    // (e.g. by another test); that case is tolerated with a warning.
    let init_result = tracing_subscriber::fmt()
        .with_env_filter("cortex_ai=info,test=info")
        .with_thread_ids(true)
        .with_thread_names(true)
        .with_file(true)
        .with_line_number(true)
        .with_target(true)
        .with_span_events(tracing_subscriber::fmt::format::FmtSpan::FULL)
        .try_init();
    if init_result.is_err() {
        println!("Warning: tracing already initialized");
    }

    // The feedback receiver is discarded; only the flow result is checked.
    let (feedback_tx, _) = bounded::<Result<String, FlowError>>(1);
    let flow = Flow::new()
        .source(TestSource {
            data: "test_data".to_string(),
            feedback: feedback_tx,
        })
        .process(PassthroughProcessor)
        .when(TestCondition)
        .process(PassthroughProcessor)
        .otherwise()
        .process(PassthroughProcessor)
        .end();
    let outcome = run_flow_with_timeout(flow, Duration::from_millis(100)).await;
    assert!(outcome.is_ok());
}
}