pub mod config;
pub mod intern;
pub mod metrics;
pub mod parse;
pub mod pre_route;
pub mod types;
pub use config::{BatchProcessingConfig, ParseErrorAction, PreRouteFilterConfig};
pub use intern::FieldInterner;
pub use types::{MessageMetadata, ParsedMessage, PreRouteResult, RawMessage};
/// Errors that can terminate a transport-driven engine loop
/// (`run`, `run_raw`, `run_async`, `run_raw_async`).
#[cfg(feature = "transport")]
#[derive(Debug, thiserror::Error)]
pub enum EngineError {
/// Failure in the underlying transport (receive or commit).
#[error("transport error: {0}")]
Transport(#[from] crate::TransportError),
/// A downstream sink rejected a batch of results.
#[error("sink error: {0}")]
Sink(String),
/// The engine was asked to shut down.
#[error("shutdown")]
Shutdown,
}
use std::sync::Arc;
use rayon::prelude::*;
use super::pool::AdaptiveWorkerPool;
use super::stats::PipelineStats;
use self::pre_route::{PreRouteOutcome, apply_filters, extract_routing_field, filters_from_config};
use self::types::PayloadFormat;
use super::config::WorkerPoolConfig;
/// Chunked batch-processing engine: optional pre-route filtering, payload
/// parsing, and parallel transformation on a shared worker pool, with
/// pipeline statistics tracked throughout.
pub struct BatchEngine {
// Tuning knobs: chunk size, routing field, parse-error policy, filters.
config: BatchProcessingConfig,
// Rayon-backed pool used for the parallel transform stage.
pool: Arc<AdaptiveWorkerPool>,
// Counters updated as batches flow through (received/processed/errors/...).
stats: Arc<PipelineStats>,
// Interner seeded with `config.known_fields` for cheap field extraction.
interner: Arc<FieldInterner>,
// Compiled pre-route filters derived from `config.pre_route_filters`.
filters: Vec<pre_route::PreRouteFilter>,
// Optional guard consulted between chunks to pause under memory pressure.
#[cfg(feature = "memory")]
memory_guard: Option<Arc<crate::memory::MemoryGuard>>,
}
impl BatchEngine {
/// Builds an engine from `config`, backed by a freshly created
/// default-configured adaptive worker pool.
#[must_use]
pub fn new(config: BatchProcessingConfig) -> Self {
    Self::with_pool(
        Arc::new(AdaptiveWorkerPool::new(WorkerPoolConfig::default())),
        config,
    )
}
/// Builds an engine from `config` that shares the given worker `pool`.
///
/// The field interner is pre-seeded with `config.known_fields`, and the
/// pre-route filters are compiled once from `config.pre_route_filters`.
#[must_use]
pub fn with_pool(pool: Arc<AdaptiveWorkerPool>, config: BatchProcessingConfig) -> Self {
    // Seed the interner with the configured well-known field names; the
    // borrow of `config` is scoped so the struct literal can move it.
    let interner = {
        let known: Vec<&str> = config.known_fields.iter().map(String::as_str).collect();
        Arc::new(FieldInterner::with_known_fields(&known))
    };
    let filters = filters_from_config(&config.pre_route_filters);
    Self {
        filters,
        interner,
        stats: Arc::new(PipelineStats::new()),
        pool,
        config,
        #[cfg(feature = "memory")]
        memory_guard: None,
    }
}
/// Builds an engine whose configuration is loaded from the config cascade
/// under `key`.
///
/// # Errors
/// Returns the `ConfigError` from `BatchProcessingConfig::from_cascade`
/// when the cascade lookup or parse fails.
pub fn from_cascade(key: &str) -> Result<Self, crate::config::ConfigError> {
    BatchProcessingConfig::from_cascade(key).map(Self::new)
}
/// Shared handle to the engine's pipeline statistics.
#[must_use]
pub fn stats(&self) -> &Arc<PipelineStats> {
&self.stats
}
/// Shared handle to the worker pool backing parallel transforms.
#[must_use]
pub fn pool(&self) -> &Arc<AdaptiveWorkerPool> {
&self.pool
}
/// The engine's effective configuration.
#[must_use]
pub fn config(&self) -> &BatchProcessingConfig {
&self.config
}
/// Registers this engine's metrics with `metrics_manager` and, when the
/// `memory` feature is enabled, optionally adopts a shared memory guard.
///
/// Passing `None` for the guard deliberately leaves any previously wired
/// guard in place (only a `Some` overwrites it).
pub fn auto_wire(
&mut self,
metrics_manager: &crate::metrics::MetricsManager,
#[cfg(feature = "memory")] memory_guard: Option<&Arc<crate::memory::MemoryGuard>>,
) {
metrics::register(metrics_manager, &self.config);
#[cfg(feature = "memory")]
if let Some(guard) = memory_guard {
self.memory_guard = Some(Arc::clone(guard));
}
}
/// Parses each message (optionally pre-route filtering first) and applies
/// `transform` to the parsed form in parallel on the worker pool.
///
/// Messages are processed in chunks of `config.max_chunk_size` (0 means a
/// single chunk covering the whole batch). Within a chunk, output order
/// matches input order among surviving messages: filtered messages are
/// dropped from the output entirely, while DLQ'd and parse-failed messages
/// become `Err` entries at their original position.
///
/// With `ParseErrorAction::FailBatch`, the first parse error converts every
/// entry accumulated so far in the current chunk to `Err` and returns
/// immediately. NOTE(review): messages after the failing one in the chunk,
/// and all later chunks, are then absent from the output — confirm callers
/// expect a shorter result on batch failure.
pub fn process_mid_tier<O, E, F>(
&self,
messages: &[RawMessage],
transform: F,
) -> Vec<Result<O, E>>
where
O: Send,
E: Send + From<String>,
F: Fn(&mut ParsedMessage) -> Result<O, E> + Sync,
{
if messages.is_empty() {
return Vec::new();
}
// 0 is the "unbounded" sentinel: treat the whole batch as one chunk.
let chunk_size = if self.config.max_chunk_size == 0 {
messages.len()
} else {
self.config.max_chunk_size
};
let has_routing = self.config.routing_field.is_some();
let mut all_results = Vec::with_capacity(messages.len());
for chunk in messages.chunks(chunk_size) {
// Received/byte counters include messages later filtered or DLQ'd.
self.stats.add_received(chunk.len() as u64);
let chunk_bytes: u64 = chunk.iter().map(|m| m.payload.len() as u64).sum();
self.stats.add_bytes_received(chunk_bytes);
let mut parsed_msgs: Vec<Result<ParsedMessage, String>> =
Vec::with_capacity(chunk.len());
for msg in chunk {
// Stage 1: pre-route filtering on the raw payload, before parsing.
if has_routing {
let field_name = self.config.routing_field.as_ref().expect("checked above");
let extraction = extract_routing_field(&msg.payload, field_name);
let outcome = apply_filters(&extraction, &self.filters);
match outcome {
PreRouteOutcome::Continue => {}
// Filtered messages produce no output entry at all.
PreRouteOutcome::Filtered => {
self.stats.incr_filtered();
continue; }
// DLQ'd messages surface as Err so callers can route them.
PreRouteOutcome::Dlq(reason) => {
self.stats.incr_dlq();
self.stats.incr_errors();
parsed_msgs.push(Err(reason));
continue;
}
}
}
// Stage 2: resolve Auto format by sniffing the payload, then parse.
let format = match msg.metadata.format {
PayloadFormat::Auto => PayloadFormat::detect(&msg.payload),
other => other,
};
match parse::parse_payload(&msg.payload, format) {
Ok(value) => {
// Pre-extract interned known fields so transforms avoid re-lookup.
let extracted = self.interner.extract_known(&value);
parsed_msgs.push(Ok(ParsedMessage::Parsed {
value,
raw: msg.payload.clone(),
format,
key: msg.key.clone(),
headers: msg.headers.clone(),
metadata: msg.metadata.clone(),
extracted,
}));
}
Err(e) => {
self.stats.incr_errors();
match self.config.parse_error_action {
ParseErrorAction::Dlq => {
self.stats.incr_dlq();
parsed_msgs.push(Err(format!("parse error: {e}")));
}
// Skip: drop the message silently (error counter already bumped).
ParseErrorAction::Skip => {
}
// FailBatch: poison the whole chunk and bail out early.
ParseErrorAction::FailBatch => {
parsed_msgs.push(Err(format!("parse error (fail_batch): {e}")));
let results: Vec<Result<O, E>> = parsed_msgs
.into_iter()
.map(|r| match r {
Ok(_) => Err(E::from(
"batch failed due to parse error".to_string(),
)),
Err(reason) => Err(E::from(reason)),
})
.collect();
all_results.extend(results);
return all_results;
}
}
}
}
}
// Stage 3: split the chunk into already-failed entries and messages to
// transform, remembering each one's position so order can be restored.
let mut indexed: Vec<(usize, Result<ParsedMessage, String>)> =
parsed_msgs.into_iter().enumerate().collect();
let mut chunk_results: Vec<(usize, Result<O, E>)> = Vec::with_capacity(indexed.len());
let mut to_transform: Vec<(usize, ParsedMessage)> = Vec::with_capacity(indexed.len());
for (idx, item) in indexed.drain(..) {
match item {
Ok(pm) => to_transform.push((idx, pm)),
Err(reason) => chunk_results.push((idx, Err(E::from(reason)))),
}
}
// Run the transforms in parallel inside the engine's pool.
let transformed: Vec<(usize, Result<O, E>)> = self.pool.install(|| {
to_transform
.into_par_iter()
.map(|(idx, mut pm)| {
let result = transform(&mut pm);
(idx, result)
})
.collect()
});
chunk_results.extend(transformed);
// Restore input order across failed + transformed entries.
chunk_results.sort_by_key(|(idx, _)| *idx);
let ok_count = chunk_results.iter().filter(|(_, r)| r.is_ok()).count();
self.stats.add_processed(ok_count as u64);
all_results.extend(chunk_results.into_iter().map(|(_, r)| r));
// May sleep between chunks when the memory guard reports pressure.
self.check_memory_pressure();
}
all_results
}
/// Applies `transform` directly to raw (unparsed) messages in parallel,
/// chunked by `config.max_chunk_size` (0 means one chunk for the batch).
///
/// Pre-route filters still run when a routing field is configured; filtered
/// messages are dropped from the output.
///
/// NOTE(review): unlike `process_mid_tier`, DLQ errors are appended to the
/// output before the chunk's transform results, so output order does not
/// track input order when DLQ filters fire — confirm callers do not rely
/// on positional alignment here.
pub fn process_raw<O, E, F>(&self, messages: &[RawMessage], transform: F) -> Vec<Result<O, E>>
where
O: Send,
E: Send + From<String>,
F: Fn(&RawMessage) -> Result<O, E> + Sync,
{
if messages.is_empty() {
return Vec::new();
}
// 0 is the "unbounded" sentinel: treat the whole batch as one chunk.
let chunk_size = if self.config.max_chunk_size == 0 {
messages.len()
} else {
self.config.max_chunk_size
};
let has_routing = self.config.routing_field.is_some();
let mut all_results = Vec::with_capacity(messages.len());
for chunk in messages.chunks(chunk_size) {
self.stats.add_received(chunk.len() as u64);
let chunk_bytes: u64 = chunk.iter().map(|m| m.payload.len() as u64).sum();
self.stats.add_bytes_received(chunk_bytes);
// Pre-route filtering pass: keep survivors, emit DLQ errors eagerly.
let to_process: Vec<&RawMessage> = if has_routing {
let field_name = self.config.routing_field.as_ref().expect("checked above");
let mut passed = Vec::with_capacity(chunk.len());
for msg in chunk {
let extraction = extract_routing_field(&msg.payload, field_name);
let outcome = apply_filters(&extraction, &self.filters);
match outcome {
PreRouteOutcome::Continue => passed.push(msg),
PreRouteOutcome::Filtered => {
self.stats.incr_filtered();
}
PreRouteOutcome::Dlq(reason) => {
self.stats.incr_dlq();
self.stats.incr_errors();
all_results.push(Err(E::from(reason)));
}
}
}
passed
} else {
chunk.iter().collect()
};
// Parallel transform of the surviving messages on the shared pool.
let results = self.pool.process_batch(&to_process, |msg| transform(msg));
let ok_count = results.iter().filter(|r| r.is_ok()).count();
self.stats.add_processed(ok_count as u64);
all_results.extend(results);
// May sleep between chunks when the memory guard reports pressure.
self.check_memory_pressure();
}
all_results
}
/// Drives the engine from a transport receiver until `shutdown` fires:
/// receive a batch, process it via [`Self::process_mid_tier`], hand results
/// to the synchronous `sink`, then commit the batch's tokens.
///
/// The commit is deliberately skipped when the sink fails, so unacked
/// messages can be redelivered by the transport. Commit failures are
/// logged but do not stop the loop.
///
/// # Errors
/// Returns `EngineError::Transport` when the receive call fails.
#[cfg(feature = "transport")]
pub async fn run<R, O, E, Transform, Sink>(
&self,
receiver: &R,
shutdown: tokio_util::sync::CancellationToken,
transform: Transform,
mut sink: Sink,
) -> Result<(), EngineError>
where
R: crate::transport::TransportReceiver,
O: Send + 'static,
E: Send + From<String> + std::fmt::Display + 'static,
Transform: Fn(&mut ParsedMessage) -> Result<O, E> + Sync,
Sink: FnMut(Vec<Result<O, E>>) -> Result<(), EngineError>,
{
tracing::info!(
chunk_size = self.config.max_chunk_size,
routing_field = ?self.config.routing_field,
"BatchEngine starting"
);
loop {
tokio::select! {
// `biased` so a pending shutdown always wins over a ready receive.
biased;
() = shutdown.cancelled() => {
tracing::info!("BatchEngine shutting down");
return Ok(());
}
recv_result = receiver.recv(self.config.max_chunk_size) => {
let messages = recv_result.map_err(EngineError::Transport)?;
if messages.is_empty() {
continue;
}
// Capture commit tokens before the messages are consumed below.
let tokens: Vec<R::Token> = messages.iter()
.map(|m| m.token.clone())
.collect();
let raw: Vec<RawMessage> = messages.into_iter()
.map(RawMessage::from)
.collect();
let results = self.process_mid_tier(&raw, &transform);
// Sink failure skips the commit so the batch is redelivered.
if let Err(e) = sink(results) {
tracing::error!(error = %e, "Sink failed, skipping commit");
continue;
}
if let Err(e) = receiver.commit(&tokens).await {
tracing::error!(error = %e, "Commit failed");
}
}
}
}
}
/// Raw-mode counterpart of [`Self::run`]: identical receive/sink/commit
/// loop, but batches go through [`Self::process_raw`] (no payload parsing).
///
/// The commit is skipped when the sink fails, so unacked messages can be
/// redelivered; commit failures are logged but do not stop the loop.
///
/// # Errors
/// Returns `EngineError::Transport` when the receive call fails.
#[cfg(feature = "transport")]
pub async fn run_raw<R, O, E, Transform, Sink>(
&self,
receiver: &R,
shutdown: tokio_util::sync::CancellationToken,
transform: Transform,
mut sink: Sink,
) -> Result<(), EngineError>
where
R: crate::transport::TransportReceiver,
O: Send + 'static,
E: Send + From<String> + std::fmt::Display + 'static,
Transform: Fn(&RawMessage) -> Result<O, E> + Sync,
Sink: FnMut(Vec<Result<O, E>>) -> Result<(), EngineError>,
{
tracing::info!(
chunk_size = self.config.max_chunk_size,
"BatchEngine (raw) starting"
);
loop {
tokio::select! {
// `biased` so a pending shutdown always wins over a ready receive.
biased;
() = shutdown.cancelled() => {
tracing::info!("BatchEngine (raw) shutting down");
return Ok(());
}
recv_result = receiver.recv(self.config.max_chunk_size) => {
let messages = recv_result.map_err(EngineError::Transport)?;
if messages.is_empty() {
continue;
}
// Capture commit tokens before the messages are consumed below.
let tokens: Vec<R::Token> = messages.iter()
.map(|m| m.token.clone())
.collect();
let raw: Vec<RawMessage> = messages.into_iter()
.map(RawMessage::from)
.collect();
let results = self.process_raw(&raw, &transform);
// Sink failure skips the commit so the batch is redelivered.
if let Err(e) = sink(results) {
tracing::error!(error = %e, "Sink failed (raw), skipping commit");
continue;
}
if let Err(e) = receiver.commit(&tokens).await {
tracing::error!(error = %e, "Commit failed (raw)");
}
}
}
}
}
/// Async-sink variant of [`Self::run`]: the sink receives both the results
/// and the batch's commit tokens, so the caller owns commit semantics.
/// An optional periodic `ticker` future runs alongside the receive loop.
///
/// Sink errors are logged but the loop continues (the caller's sink holds
/// the tokens and decides what "failure" means for commits).
///
/// # Errors
/// Returns `EngineError::Transport` when the receive call fails.
#[cfg(feature = "transport")]
pub async fn run_async<R, O, E, Transform, Sink, SinkFut, Ticker, TickerFut>(
&self,
receiver: &R,
shutdown: tokio_util::sync::CancellationToken,
transform: Transform,
mut sink: Sink,
ticker: Option<(std::time::Duration, Ticker)>,
) -> Result<(), EngineError>
where
R: crate::transport::TransportReceiver,
O: Send + 'static,
E: Send + From<String> + std::fmt::Display + 'static,
Transform: Fn(&mut ParsedMessage) -> Result<O, E> + Sync,
Sink: FnMut(Vec<Result<O, E>>, Vec<R::Token>) -> SinkFut,
SinkFut: std::future::Future<Output = Result<(), EngineError>>,
Ticker: FnMut() -> TickerFut,
TickerFut: std::future::Future<Output = Result<(), EngineError>>,
{
tracing::info!(
chunk_size = self.config.max_chunk_size,
routing_field = ?self.config.routing_field,
ticker = ticker.is_some(),
"BatchEngine (async) starting"
);
// Split the optional ticker into its interval and its callback.
let mut tick_interval = ticker.as_ref().map(|(d, _)| tokio::time::interval(*d));
let mut ticker_fn = ticker.map(|(_, f)| f);
// Consume the interval's immediate first tick so the callback only
// fires after a full period has elapsed.
if let Some(ref mut interval) = tick_interval {
interval.tick().await;
}
loop {
tokio::select! {
// `biased` so shutdown wins over ticks and receives.
biased;
() = shutdown.cancelled() => {
tracing::info!("BatchEngine (async) shutting down");
return Ok(());
}
// With no ticker configured, this arm never completes.
_ = async {
match tick_interval.as_mut() {
Some(interval) => interval.tick().await,
None => std::future::pending().await,
}
} => {
if let Some(ref mut f) = ticker_fn
&& let Err(e) = f().await
{
tracing::error!(error = %e, "Ticker failed");
}
}
recv_result = receiver.recv(self.config.max_chunk_size) => {
let messages = recv_result.map_err(EngineError::Transport)?;
if messages.is_empty() {
continue;
}
// Tokens are handed to the sink, which owns commit semantics.
let tokens: Vec<R::Token> = messages.iter()
.map(|m| m.token.clone())
.collect();
let raw: Vec<RawMessage> = messages.into_iter()
.map(RawMessage::from)
.collect();
let results = self.process_mid_tier(&raw, &transform);
if let Err(e) = sink(results, tokens).await {
tracing::error!(error = %e, "Sink failed (async)");
}
}
}
}
}
/// Raw-mode counterpart of [`Self::run_async`]: same async-sink + optional
/// ticker loop, but batches go through [`Self::process_raw`] (no parsing).
/// The sink receives the commit tokens and owns commit semantics.
///
/// # Errors
/// Returns `EngineError::Transport` when the receive call fails.
#[cfg(feature = "transport")]
pub async fn run_raw_async<R, O, E, Transform, Sink, SinkFut, Ticker, TickerFut>(
&self,
receiver: &R,
shutdown: tokio_util::sync::CancellationToken,
transform: Transform,
mut sink: Sink,
ticker: Option<(std::time::Duration, Ticker)>,
) -> Result<(), EngineError>
where
R: crate::transport::TransportReceiver,
O: Send + 'static,
E: Send + From<String> + std::fmt::Display + 'static,
Transform: Fn(&RawMessage) -> Result<O, E> + Sync,
Sink: FnMut(Vec<Result<O, E>>, Vec<R::Token>) -> SinkFut,
SinkFut: std::future::Future<Output = Result<(), EngineError>>,
Ticker: FnMut() -> TickerFut,
TickerFut: std::future::Future<Output = Result<(), EngineError>>,
{
tracing::info!(
chunk_size = self.config.max_chunk_size,
ticker = ticker.is_some(),
"BatchEngine (raw async) starting"
);
// Split the optional ticker into its interval and its callback.
let mut tick_interval = ticker.as_ref().map(|(d, _)| tokio::time::interval(*d));
let mut ticker_fn = ticker.map(|(_, f)| f);
// Consume the interval's immediate first tick so the callback only
// fires after a full period has elapsed.
if let Some(ref mut interval) = tick_interval {
interval.tick().await;
}
loop {
tokio::select! {
// `biased` so shutdown wins over ticks and receives.
biased;
() = shutdown.cancelled() => {
tracing::info!("BatchEngine (raw async) shutting down");
return Ok(());
}
// With no ticker configured, this arm never completes.
_ = async {
match tick_interval.as_mut() {
Some(interval) => interval.tick().await,
None => std::future::pending().await,
}
} => {
if let Some(ref mut f) = ticker_fn
&& let Err(e) = f().await
{
tracing::error!(error = %e, "Ticker (raw) failed");
}
}
recv_result = receiver.recv(self.config.max_chunk_size) => {
let messages = recv_result.map_err(EngineError::Transport)?;
if messages.is_empty() {
continue;
}
// Tokens are handed to the sink, which owns commit semantics.
let tokens: Vec<R::Token> = messages.iter()
.map(|m| m.token.clone())
.collect();
let raw: Vec<RawMessage> = messages.into_iter()
.map(RawMessage::from)
.collect();
let results = self.process_raw(&raw, &transform);
if let Err(e) = sink(results, tokens).await {
tracing::error!(error = %e, "Sink failed (raw async)");
}
}
}
}
}
/// Between-chunk backpressure hook: when the `memory` feature is enabled
/// and a wired guard reports pressure, blocks the current thread for
/// `config.memory_pressure_pause_ms` to let memory drain.
///
/// No-op when the feature is disabled or no guard is wired (hence the
/// `unused_self` allow for that configuration).
#[allow(clippy::unused_self)]
fn check_memory_pressure(&self) {
#[cfg(feature = "memory")]
if let Some(guard) = &self.memory_guard
&& guard.under_pressure()
{
tracing::warn!(
pause_ms = self.config.memory_pressure_pause_ms,
"BatchEngine: memory pressure detected, pausing between chunks"
);
// Callers run on pool/blocking threads, so a thread sleep is the
// intended throttle here (not an async sleep).
std::thread::sleep(std::time::Duration::from_millis(
self.config.memory_pressure_pause_ms,
));
}
}
}
// Manual Debug: the pool, stats, and interner are summarized (thread count,
// stats snapshot, interner length) rather than dumped wholesale.
impl std::fmt::Debug for BatchEngine {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut s = f.debug_struct("BatchEngine");
s.field("config", &self.config)
.field("pool_max_threads", &self.pool.max_threads())
.field("stats", &self.stats.snapshot())
.field("interner_len", &self.interner.len())
.field("filters", &self.filters);
// Only report presence/absence of the guard, not its contents.
#[cfg(feature = "memory")]
s.field("memory_guard", &self.memory_guard.is_some());
s.finish()
}
}
#[cfg(test)]
mod engine_tests {
use super::*;
use bytes::Bytes;
// Builds `n` JSON messages of the form {"_table":"events","id":i}.
fn make_json_messages(n: usize) -> Vec<RawMessage> {
(0..n)
.map(|i| RawMessage {
payload: Bytes::from(format!(r#"{{"_table":"events","id":{i}}}"#)),
key: None,
headers: vec![],
metadata: MessageMetadata {
timestamp_ms: None,
format: types::PayloadFormat::Json,
commit_token: None,
},
})
.collect()
}
// Engine with default config (no routing field, no filters).
fn default_engine() -> BatchEngine {
BatchEngine::new(BatchProcessingConfig::default())
}
// Happy path: every message parses and the transform sees parsed fields.
#[test]
fn process_mid_tier_basic() {
let engine = default_engine();
let msgs = make_json_messages(100);
let results: Vec<Result<String, String>> = engine.process_mid_tier(&msgs, |pm| {
Ok(pm
.field("_table")
.and_then(|v| sonic_rs::JsonValueTrait::as_str(v))
.unwrap_or("unknown")
.to_string())
});
assert_eq!(results.len(), 100);
assert!(results.iter().all(|r| r.is_ok()));
assert_eq!(results[0].as_ref().unwrap(), "events");
}
// A malformed payload becomes an in-order Err with default (Dlq) policy.
#[test]
fn process_mid_tier_parse_error() {
let engine = default_engine();
let mut msgs = make_json_messages(2);
msgs.insert(
1,
RawMessage {
payload: Bytes::from_static(b"not json {{{"),
key: None,
headers: vec![],
metadata: MessageMetadata {
timestamp_ms: None,
format: types::PayloadFormat::Json,
commit_token: None,
},
},
);
let results: Vec<Result<String, String>> =
engine.process_mid_tier(&msgs, |pm| Ok(pm.raw_payload().len().to_string()));
assert_eq!(results.len(), 3);
assert!(results[0].is_ok());
assert!(results[1].is_err());
assert!(results[1].as_ref().unwrap_err().contains("parse error"));
assert!(results[2].is_ok());
}
// An empty input batch short-circuits to an empty result.
#[test]
fn process_mid_tier_empty_batch() {
let engine = default_engine();
let results: Vec<Result<(), String>> = engine.process_mid_tier(&[], |_| Ok(()));
assert!(results.is_empty());
}
// Chunking (50 per chunk over 120 messages) must not drop or duplicate.
#[test]
fn process_mid_tier_respects_chunk_size() {
let config = BatchProcessingConfig {
max_chunk_size: 50,
..Default::default()
};
let engine = BatchEngine::new(config);
let msgs = make_json_messages(120);
let results: Vec<Result<usize, String>> =
engine.process_mid_tier(&msgs, |pm| Ok(pm.raw_payload().len()));
assert_eq!(results.len(), 120);
assert!(results.iter().all(|r| r.is_ok()));
let snap = engine.stats().snapshot();
assert_eq!(snap.received, 120);
}
// Stats counters reflect a clean run.
#[test]
fn stats_updated_after_processing() {
let engine = default_engine();
let msgs = make_json_messages(10);
let _results: Vec<Result<(), String>> = engine.process_mid_tier(&msgs, |_| Ok(()));
let snap = engine.stats().snapshot();
assert_eq!(snap.received, 10);
assert_eq!(snap.processed, 10);
assert_eq!(snap.errors, 0);
assert_eq!(snap.filtered, 0);
}
// Raw path: transform sees unparsed payloads; stats still tracked.
#[test]
fn process_raw_passthrough() {
let engine = default_engine();
let msgs = make_json_messages(50);
let results: Vec<Result<usize, String>> =
engine.process_raw(&msgs, |msg| Ok(msg.payload.len()));
assert_eq!(results.len(), 50);
assert!(results.iter().all(|r| r.is_ok()));
assert!(results[0].as_ref().unwrap() > &0);
let snap = engine.stats().snapshot();
assert_eq!(snap.received, 50);
assert_eq!(snap.processed, 50);
}
// A DlqFieldValue filter turns the matching message into an in-order Err.
#[test]
fn process_mid_tier_with_pre_route() {
let config = BatchProcessingConfig {
routing_field: Some("_table".to_string()),
pre_route_filters: vec![config::PreRouteFilterConfig::DlqFieldValue {
field: "_table".to_string(),
value: "poison".to_string(),
}],
..Default::default()
};
let engine = BatchEngine::new(config);
let mut msgs = make_json_messages(3);
msgs[1] = RawMessage {
payload: Bytes::from(r#"{"_table":"poison","id":999}"#),
key: None,
headers: vec![],
metadata: MessageMetadata {
timestamp_ms: None,
format: types::PayloadFormat::Json,
commit_token: None,
},
};
let results: Vec<Result<String, String>> = engine.process_mid_tier(&msgs, |pm| {
Ok(pm
.field("_table")
.and_then(|v| sonic_rs::JsonValueTrait::as_str(v))
.unwrap_or("?")
.to_string())
});
assert_eq!(results.len(), 3);
assert!(results[0].is_ok());
assert!(results[1].is_err());
assert!(results[1].as_ref().unwrap_err().contains("DLQ"));
assert!(results[2].is_ok());
let snap = engine.stats().snapshot();
assert_eq!(snap.dlq, 1);
assert_eq!(snap.errors, 1);
}
// A DropFieldMissing filter silently removes the message from the output.
#[test]
fn process_mid_tier_filtered_not_in_results() {
let config = BatchProcessingConfig {
routing_field: Some("_table".to_string()),
pre_route_filters: vec![config::PreRouteFilterConfig::DropFieldMissing {
field: "_table".to_string(),
}],
..Default::default()
};
let engine = BatchEngine::new(config);
let mut msgs = make_json_messages(3);
msgs[1] = RawMessage {
payload: Bytes::from(r#"{"host":"web1"}"#),
key: None,
headers: vec![],
metadata: MessageMetadata {
timestamp_ms: None,
format: types::PayloadFormat::Json,
commit_token: None,
},
};
let results: Vec<Result<String, String>> =
engine.process_mid_tier(&msgs, |_pm| Ok("ok".to_string()));
assert_eq!(results.len(), 2);
assert!(results.iter().all(|r| r.is_ok()));
let snap = engine.stats().snapshot();
assert_eq!(snap.filtered, 1);
assert_eq!(snap.received, 3);
}
// Config-cascade constructor yields the documented default chunk size.
#[test]
fn from_cascade_creates_engine() {
let engine = BatchEngine::from_cascade("batch_processing").unwrap();
assert_eq!(engine.config().max_chunk_size, 10_000);
}
// Accessors are callable and a fresh engine starts at zero stats.
#[test]
fn accessors_return_expected_types() {
let engine = default_engine();
let _stats = engine.stats();
let _pool = engine.pool();
let _config = engine.config();
assert_eq!(engine.stats().snapshot().received, 0);
}
// Wiring metrics (and a None guard) must not break processing.
#[test]
fn auto_wire_does_not_panic() {
let mut engine = default_engine();
let mgr = crate::metrics::MetricsManager::new_for_test("test_auto_wire");
engine.auto_wire(
&mgr,
#[cfg(feature = "memory")]
None,
);
let msgs = make_json_messages(5);
let results: Vec<Result<(), String>> = engine.process_mid_tier(&msgs, |_| Ok(()));
assert_eq!(results.len(), 5);
}
// The manual Debug impl renders without panicking.
#[test]
fn debug_impl_works() {
let engine = default_engine();
let debug = format!("{engine:?}");
assert!(debug.contains("BatchEngine"));
assert!(debug.contains("config"));
}
// End-to-end tests of the async run loops against the in-memory transport.
#[cfg(feature = "transport-memory")]
mod async_engine_tests {
use super::*;
use std::sync::atomic::{AtomicU64, Ordering};
fn json_payload(table: &str, id: usize) -> Vec<u8> {
format!(r#"{{"_table":"{table}","id":{id}}}"#).into_bytes()
}
// Sink receives one result and one token per injected message.
#[tokio::test]
async fn run_async_processes_and_passes_tokens_to_sink() {
let config = crate::transport::memory::MemoryConfig {
recv_timeout_ms: 50,
..Default::default()
};
let transport = crate::transport::memory::MemoryTransport::new(&config);
for i in 0..5 {
transport
.inject(None, json_payload("events", i))
.await
.unwrap();
}
let engine = default_engine();
let shutdown = tokio_util::sync::CancellationToken::new();
let shutdown_clone = shutdown.clone();
let sink_count = Arc::new(AtomicU64::new(0));
let token_count = Arc::new(AtomicU64::new(0));
let sink_count_clone = Arc::clone(&sink_count);
let token_count_clone = Arc::clone(&token_count);
// Cancel after a short grace period so the loop terminates.
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
shutdown_clone.cancel();
});
let result = engine
.run_async(
&transport,
shutdown,
|pm: &mut ParsedMessage| -> Result<String, String> {
Ok(pm
.field("_table")
.and_then(|v| sonic_rs::JsonValueTrait::as_str(v))
.unwrap_or("?")
.to_string())
},
|results, tokens| {
let sc = Arc::clone(&sink_count_clone);
let tc = Arc::clone(&token_count_clone);
async move {
sc.fetch_add(results.len() as u64, Ordering::Relaxed);
tc.fetch_add(tokens.len() as u64, Ordering::Relaxed);
Ok(())
}
},
None::<(
std::time::Duration,
fn() -> std::future::Ready<Result<(), EngineError>>,
)>,
)
.await;
assert!(result.is_ok());
assert_eq!(sink_count.load(Ordering::Relaxed), 5);
assert_eq!(token_count.load(Ordering::Relaxed), 5);
}
// A 100ms ticker over a 350ms run should fire at least twice.
#[tokio::test]
async fn run_async_ticker_fires() {
let config = crate::transport::memory::MemoryConfig {
recv_timeout_ms: 50,
..Default::default()
};
let transport = crate::transport::memory::MemoryTransport::new(&config);
let engine = default_engine();
let shutdown = tokio_util::sync::CancellationToken::new();
let shutdown_clone = shutdown.clone();
let tick_count = Arc::new(AtomicU64::new(0));
let tick_count_clone = Arc::clone(&tick_count);
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(350)).await;
shutdown_clone.cancel();
});
let result = engine
.run_async(
&transport,
shutdown,
|_pm: &mut ParsedMessage| -> Result<(), String> { Ok(()) },
|_results, _tokens| async { Ok(()) },
Some((std::time::Duration::from_millis(100), move || {
let tc = Arc::clone(&tick_count_clone);
async move {
tc.fetch_add(1, Ordering::Relaxed);
Ok(())
}
})),
)
.await;
assert!(result.is_ok());
let ticks = tick_count.load(Ordering::Relaxed);
assert!(ticks >= 2, "Expected at least 2 ticks, got {ticks}");
}
// Raw async loop hands unparsed payload bytes straight to the transform.
#[tokio::test]
async fn run_raw_async_processes_without_parse() {
let config = crate::transport::memory::MemoryConfig {
recv_timeout_ms: 50,
..Default::default()
};
let transport = crate::transport::memory::MemoryTransport::new(&config);
for i in 0..3 {
transport
.inject(None, json_payload("logs", i))
.await
.unwrap();
}
let engine = default_engine();
let shutdown = tokio_util::sync::CancellationToken::new();
let shutdown_clone = shutdown.clone();
let total_bytes = Arc::new(AtomicU64::new(0));
let total_bytes_clone = Arc::clone(&total_bytes);
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
shutdown_clone.cancel();
});
let result = engine
.run_raw_async(
&transport,
shutdown,
|msg: &RawMessage| -> Result<usize, String> { Ok(msg.payload.len()) },
|results, _tokens| {
let tb = Arc::clone(&total_bytes_clone);
async move {
for len in results.iter().flatten() {
tb.fetch_add(*len as u64, Ordering::Relaxed);
}
Ok(())
}
},
None::<(
std::time::Duration,
fn() -> std::future::Ready<Result<(), EngineError>>,
)>,
)
.await;
assert!(result.is_ok());
assert!(total_bytes.load(Ordering::Relaxed) > 0);
}
// A failing sink is logged and skipped; the loop still exits cleanly.
#[tokio::test]
async fn run_async_sink_error_does_not_crash() {
let config = crate::transport::memory::MemoryConfig {
recv_timeout_ms: 50,
..Default::default()
};
let transport = crate::transport::memory::MemoryTransport::new(&config);
transport
.inject(None, json_payload("events", 0))
.await
.unwrap();
let engine = default_engine();
let shutdown = tokio_util::sync::CancellationToken::new();
let shutdown_clone = shutdown.clone();
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
shutdown_clone.cancel();
});
let result = engine
.run_async(
&transport,
shutdown,
|_pm: &mut ParsedMessage| -> Result<(), String> { Ok(()) },
|_results, _tokens| async { Err(EngineError::Sink("test sink error".into())) },
None::<(
std::time::Duration,
fn() -> std::future::Ready<Result<(), EngineError>>,
)>,
)
.await;
assert!(result.is_ok());
}
}
}