use std::sync::Arc;
use crate::body::Body;
use crate::exchange::Exchange;
use crate::message::Message;
/// Shareable, thread-safe function that turns one [`Exchange`] into the list of
/// fragment exchanges a splitter should process.
///
/// It is stored behind an [`Arc`] so a single expression can be cloned cheaply
/// into every place that needs it.
pub type SplitExpression = Arc<dyn Fn(&Exchange) -> Vec<Exchange> + Sync + Send>;
/// Strategy for combining the results of individual fragments back into a single exchange.
///
/// NOTE(review): the aggregation itself is not performed in this module; the per-variant
/// descriptions below are inferred from the variant names — confirm against the splitter
/// processor that consumes this value.
#[derive(Clone, Default)]
pub enum AggregationStrategy {
    /// The default strategy; presumably keeps the result of the last fragment processed.
    #[default]
    LastWins,
    /// Presumably gathers the results of all fragments into one exchange.
    CollectAll,
    /// Presumably discards fragment results and continues with the original exchange.
    Original,
    /// User-supplied combiner that receives two exchanges and returns the merged one
    /// (presumably called as `(accumulated, next)` — TODO confirm the argument order).
    Custom(Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync>),
}

// `Debug` cannot be derived because the boxed `Custom` closure does not implement it.
// The closure is rendered opaquely so this enum (and anything embedding it) can still be
// logged and inspected in assertions.
impl std::fmt::Debug for AggregationStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::LastWins => f.write_str("LastWins"),
            Self::CollectAll => f.write_str("CollectAll"),
            Self::Original => f.write_str("Original"),
            Self::Custom(_) => f.write_str("Custom(..)"),
        }
    }
}
/// Configuration for a splitter step: how an exchange is broken into fragments and
/// how those fragments are handled.
///
/// Create one with [`SplitterConfig::new`] and refine it with the chained setters,
/// each of which consumes and returns the config.
#[derive(Clone)]
pub struct SplitterConfig {
    /// Produces the fragment exchanges for an incoming exchange.
    pub expression: SplitExpression,
    /// How fragment results are combined; `new` sets [`AggregationStrategy::LastWins`].
    pub aggregation: AggregationStrategy,
    /// Whether fragments may be processed concurrently; `new` sets `false`.
    pub parallel: bool,
    /// Optional cap on concurrent fragment processing; `new` sets `None` (no cap).
    /// NOTE(review): presumably only consulted when `parallel` is `true`, and the meaning
    /// of `Some(0)` is up to the consumer — confirm with the splitter processor.
    pub parallel_limit: Option<usize>,
    /// Presumably whether a failing fragment aborts the remaining ones; `new` sets `true`.
    pub stop_on_exception: bool,
}

impl SplitterConfig {
    /// Creates a config around `expression` with the default settings:
    /// `LastWins` aggregation, sequential processing, no parallel limit, and
    /// `stop_on_exception` enabled.
    #[must_use]
    pub fn new(expression: SplitExpression) -> Self {
        Self {
            expression,
            aggregation: AggregationStrategy::default(),
            parallel: false,
            parallel_limit: None,
            stop_on_exception: true,
        }
    }

    /// Replaces the aggregation strategy.
    #[must_use]
    pub fn aggregation(mut self, strategy: AggregationStrategy) -> Self {
        self.aggregation = strategy;
        self
    }

    /// Enables or disables parallel fragment processing.
    #[must_use]
    pub fn parallel(mut self, parallel: bool) -> Self {
        self.parallel = parallel;
        self
    }

    /// Sets the parallelism cap to `Some(limit)`.
    #[must_use]
    pub fn parallel_limit(mut self, limit: usize) -> Self {
        self.parallel_limit = Some(limit);
        self
    }

    /// Sets whether fragment errors stop the split.
    #[must_use]
    pub fn stop_on_exception(mut self, stop: bool) -> Self {
        self.stop_on_exception = stop;
        self
    }
}
/// Builds a child exchange that carries `body` and inherits the parent's context.
///
/// The child's input message gets a clone of the parent's input headers. The child
/// exchange copies the parent's properties, its exchange pattern, and its
/// OpenTelemetry context, so fragments share the parent's trace.
fn fragment_exchange(parent: &Exchange, body: Body) -> Exchange {
    let mut message = Message::new(body);
    // Headers are attached before the exchange is built, as in the original ordering.
    message.headers = parent.input.headers.clone();

    let mut child = Exchange::new(message);
    child.otel_context = parent.otel_context.clone();
    child.pattern = parent.pattern;
    child.properties = parent.properties.clone();
    child
}
/// Returns a [`SplitExpression`] that emits one fragment per line of a text body.
///
/// Lines come from [`str::lines`], so both `\n` and `\r\n` terminators are stripped
/// and a trailing terminator does not add an empty fragment. An exchange whose
/// body is not [`Body::Text`] yields no fragments.
pub fn split_body_lines() -> SplitExpression {
    Arc::new(|parent: &Exchange| {
        let Body::Text(content) = &parent.input.body else {
            return Vec::new();
        };
        let mut fragments = Vec::new();
        for line in content.lines() {
            fragments.push(fragment_exchange(parent, Body::Text(line.to_owned())));
        }
        fragments
    })
}
/// Returns a [`SplitExpression`] that emits one fragment per element of a JSON array body.
///
/// Each element is cloned into its own [`Body::Json`] fragment. Any other body,
/// including JSON values that are not arrays, yields no fragments.
pub fn split_body_json_array() -> SplitExpression {
    Arc::new(|parent: &Exchange| match &parent.input.body {
        Body::Json(serde_json::Value::Array(items)) => items
            .iter()
            .cloned()
            .map(|item| fragment_exchange(parent, Body::Json(item)))
            .collect::<Vec<Exchange>>(),
        _ => Vec::new(),
    })
}
/// Wraps a body-level splitting function `f` into a [`SplitExpression`].
///
/// `f` receives the parent's input body and returns the bodies of the fragments.
/// Each returned body is wrapped in a child exchange that inherits the parent's
/// headers, properties, pattern, and trace context. The fragments keep the order
/// `f` produced.
pub fn split_body<F>(f: F) -> SplitExpression
where
    F: Fn(&Body) -> Vec<Body> + Send + Sync + 'static,
{
    Arc::new(move |parent: &Exchange| {
        let bodies = f(&parent.input.body);
        let mut fragments = Vec::with_capacity(bodies.len());
        for body in bodies {
            fragments.push(fragment_exchange(parent, body));
        }
        fragments
    })
}
#[cfg(test)]
mod tests {
    // Unit tests for the split expressions and the splitter configuration builder.
    use super::*;
    use crate::value::Value;

    #[test]
    fn test_split_body_lines() {
        let mut ex = Exchange::new(Message::new("a\nb\nc"));
        ex.input.set_header("source", Value::String("test".into()));
        ex.set_property("trace", Value::Bool(true));
        let fragments = split_body_lines()(&ex);
        assert_eq!(fragments.len(), 3);
        assert_eq!(fragments[0].input.body.as_text(), Some("a"));
        assert_eq!(fragments[1].input.body.as_text(), Some("b"));
        assert_eq!(fragments[2].input.body.as_text(), Some("c"));
        for frag in &fragments {
            assert_eq!(
                frag.input.header("source"),
                Some(&Value::String("test".into()))
            );
            assert_eq!(frag.property("trace"), Some(&Value::Bool(true)));
        }
    }

    #[test]
    fn test_split_body_lines_empty() {
        let ex = Exchange::new(Message::default());
        let fragments = split_body_lines()(&ex);
        assert!(fragments.is_empty());
    }

    #[test]
    fn test_split_body_lines_crlf() {
        // `str::lines` strips `\r\n` and does not emit a fragment for the trailing terminator.
        let ex = Exchange::new(Message::new("a\r\nb\r\n"));
        let fragments = split_body_lines()(&ex);
        assert_eq!(fragments.len(), 2);
        assert_eq!(fragments[0].input.body.as_text(), Some("a"));
        assert_eq!(fragments[1].input.body.as_text(), Some("b"));
    }

    #[test]
    fn test_split_body_lines_non_text_body() {
        let ex = Exchange::new(Message::new(serde_json::json!(["a", "b"])));
        let fragments = split_body_lines()(&ex);
        assert!(fragments.is_empty());
    }

    #[test]
    fn test_split_body_json_array() {
        let arr = serde_json::json!([1, 2, 3]);
        let ex = Exchange::new(Message::new(arr));
        let fragments = split_body_json_array()(&ex);
        assert_eq!(fragments.len(), 3);
        assert!(matches!(&fragments[0].input.body, Body::Json(v) if *v == serde_json::json!(1)));
        assert!(matches!(&fragments[1].input.body, Body::Json(v) if *v == serde_json::json!(2)));
        assert!(matches!(&fragments[2].input.body, Body::Json(v) if *v == serde_json::json!(3)));
    }

    #[test]
    fn test_split_body_json_array_not_array() {
        let obj = serde_json::json!({"not": "array"});
        let ex = Exchange::new(Message::new(obj));
        let fragments = split_body_json_array()(&ex);
        assert!(fragments.is_empty());
    }

    #[test]
    fn test_split_body_json_array_non_json_body() {
        // Text that merely looks like a JSON array is not parsed.
        let ex = Exchange::new(Message::new("[1, 2, 3]"));
        let fragments = split_body_json_array()(&ex);
        assert!(fragments.is_empty());
    }

    #[test]
    fn test_split_body_custom() {
        let splitter = split_body(|body: &Body| match body {
            Body::Text(s) => s
                .split(',')
                .map(|part| Body::Text(part.trim().to_string()))
                .collect(),
            _ => Vec::new(),
        });
        let mut ex = Exchange::new(Message::new("x, y, z"));
        ex.set_property("id", Value::from(42));
        let fragments = splitter(&ex);
        assert_eq!(fragments.len(), 3);
        assert_eq!(fragments[0].input.body.as_text(), Some("x"));
        assert_eq!(fragments[1].input.body.as_text(), Some("y"));
        assert_eq!(fragments[2].input.body.as_text(), Some("z"));
        for frag in &fragments {
            assert_eq!(frag.property("id"), Some(&Value::from(42)));
        }
    }

    #[test]
    fn test_splitter_config_defaults() {
        let config = SplitterConfig::new(split_body_lines());
        assert!(matches!(config.aggregation, AggregationStrategy::LastWins));
        assert!(!config.parallel);
        assert!(config.parallel_limit.is_none());
        assert!(config.stop_on_exception);
    }

    #[test]
    fn test_splitter_config_builder() {
        let config = SplitterConfig::new(split_body_lines())
            .aggregation(AggregationStrategy::CollectAll)
            .parallel(true)
            .parallel_limit(4)
            .stop_on_exception(false);
        assert!(matches!(
            config.aggregation,
            AggregationStrategy::CollectAll
        ));
        assert!(config.parallel);
        assert_eq!(config.parallel_limit, Some(4));
        assert!(!config.stop_on_exception);
    }

    #[test]
    fn test_fragment_exchange_inherits_otel_context() {
        use opentelemetry::Context;
        use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId};
        let mut parent = Exchange::new(Message::new("test"));
        let trace_id = TraceId::from_bytes([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 123]);
        let span_id = SpanId::from_bytes([0, 0, 0, 0, 0, 0, 1, 200]);
        let span_context = SpanContext::new(
            trace_id,
            span_id,
            TraceFlags::SAMPLED,
            true,
            Default::default(),
        );
        let expected_trace_id = span_context.trace_id();
        parent.otel_context = Context::current().with_remote_span_context(span_context);
        let fragments = split_body_lines()(&parent);
        assert!(!fragments.is_empty(), "Should have at least one fragment");
        for fragment in &fragments {
            let span = fragment.otel_context.span();
            let frag_span_ctx = span.span_context();
            assert!(
                frag_span_ctx.is_valid(),
                "Fragment should have valid span context"
            );
            assert_eq!(
                frag_span_ctx.trace_id(),
                expected_trace_id,
                "Fragment should have same trace ID as parent"
            );
        }
    }

    #[test]
    fn test_all_fragments_share_same_trace_context() {
        use opentelemetry::Context;
        use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId};
        let mut parent = Exchange::new(Message::new("line1\nline2\nline3"));
        let trace_id =
            TraceId::from_bytes([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x3B, 0x9A, 0xCA, 0x09]);
        let span_id = SpanId::from_bytes([0, 0, 0, 0, 0, 0, 0, 111]);
        let span_context = SpanContext::new(
            trace_id,
            span_id,
            TraceFlags::SAMPLED,
            true,
            Default::default(),
        );
        parent.otel_context = Context::current().with_remote_span_context(span_context);
        let fragments = split_body_lines()(&parent);
        assert_eq!(fragments.len(), 3);
        let trace_ids: Vec<_> = fragments
            .iter()
            .map(|f| {
                let span = f.otel_context.span();
                span.span_context().trace_id()
            })
            .collect();
        assert!(
            trace_ids.iter().all(|&id| id == trace_id),
            "All fragments should have the same trace ID"
        );
    }
}