1use std::sync::Arc;
2
3use crate::body::Body;
4use crate::exchange::Exchange;
5use crate::message::Message;
6
7pub type SplitExpression = Arc<dyn Fn(&Exchange) -> Vec<Exchange> + Send + Sync>;
9
10#[derive(Clone, Default)]
12pub enum AggregationStrategy {
13 #[default]
15 LastWins,
16 CollectAll,
18 Original,
20 Custom(Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync>),
22}
23
24pub struct SplitterConfig {
26 pub expression: SplitExpression,
28 pub aggregation: AggregationStrategy,
30 pub parallel: bool,
32 pub parallel_limit: Option<usize>,
34 pub stop_on_exception: bool,
40}
41
42impl SplitterConfig {
43 pub fn new(expression: SplitExpression) -> Self {
45 Self {
46 expression,
47 aggregation: AggregationStrategy::default(),
48 parallel: false,
49 parallel_limit: None,
50 stop_on_exception: true,
51 }
52 }
53
54 pub fn aggregation(mut self, strategy: AggregationStrategy) -> Self {
56 self.aggregation = strategy;
57 self
58 }
59
60 pub fn parallel(mut self, parallel: bool) -> Self {
62 self.parallel = parallel;
63 self
64 }
65
66 pub fn parallel_limit(mut self, limit: usize) -> Self {
68 self.parallel_limit = Some(limit);
69 self
70 }
71
72 pub fn stop_on_exception(mut self, stop: bool) -> Self {
77 self.stop_on_exception = stop;
78 self
79 }
80}
81
82fn fragment_exchange(parent: &Exchange, body: Body) -> Exchange {
107 let mut msg = Message::new(body);
108 msg.headers = parent.input.headers.clone();
109 let mut ex = Exchange::new(msg);
110 ex.properties = parent.properties.clone();
111 ex.pattern = parent.pattern;
112 ex.otel_context = parent.otel_context.clone();
114 ex
115}
116
117pub fn split_body_lines() -> SplitExpression {
120 Arc::new(|exchange: &Exchange| {
121 let text = match &exchange.input.body {
122 Body::Text(s) => s.as_str(),
123 _ => return Vec::new(),
124 };
125 text.lines()
126 .map(|line| fragment_exchange(exchange, Body::Text(line.to_string())))
127 .collect()
128 })
129}
130
131pub fn split_body_json_array() -> SplitExpression {
134 Arc::new(|exchange: &Exchange| {
135 let arr = match &exchange.input.body {
136 Body::Json(serde_json::Value::Array(arr)) => arr,
137 _ => return Vec::new(),
138 };
139 arr.iter()
140 .map(|val| fragment_exchange(exchange, Body::Json(val.clone())))
141 .collect()
142 })
143}
144
145pub fn split_body<F>(f: F) -> SplitExpression
147where
148 F: Fn(&Body) -> Vec<Body> + Send + Sync + 'static,
149{
150 Arc::new(move |exchange: &Exchange| {
151 f(&exchange.input.body)
152 .into_iter()
153 .map(|body| fragment_exchange(exchange, body))
154 .collect()
155 })
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::value::Value;
    use opentelemetry::Context;
    use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId};

    /// Returns a context whose span is a sampled, remote span with the given ids. It
    /// stands in for a trace propagated into the route from an upstream caller.
    fn remote_parent_context(trace_id: TraceId, span_id: SpanId) -> Context {
        let span_context = SpanContext::new(
            trace_id,
            span_id,
            TraceFlags::SAMPLED,
            true,
            Default::default(),
        );
        Context::current().with_remote_span_context(span_context)
    }

    #[test]
    fn test_split_body_lines() {
        let mut ex = Exchange::new(Message::new("a\nb\nc"));
        ex.input.set_header("source", Value::String("test".into()));
        ex.set_property("trace", Value::Bool(true));

        let fragments = split_body_lines()(&ex);
        assert_eq!(fragments.len(), 3);
        assert_eq!(fragments[0].input.body.as_text(), Some("a"));
        assert_eq!(fragments[1].input.body.as_text(), Some("b"));
        assert_eq!(fragments[2].input.body.as_text(), Some("c"));

        // Every fragment inherits the parent's headers and properties.
        for frag in &fragments {
            assert_eq!(
                frag.input.header("source"),
                Some(&Value::String("test".into()))
            );
            assert_eq!(frag.property("trace"), Some(&Value::Bool(true)));
        }
    }

    #[test]
    fn test_split_body_lines_empty() {
        let ex = Exchange::new(Message::default());
        let fragments = split_body_lines()(&ex);
        assert!(fragments.is_empty());
    }

    /// `str::lines` strips `\r\n` as well as `\n`, and a trailing terminator must not
    /// produce an empty fragment.
    #[test]
    fn test_split_body_lines_crlf() {
        let ex = Exchange::new(Message::new("a\r\nb\r\n"));
        let fragments = split_body_lines()(&ex);
        assert_eq!(fragments.len(), 2);
        assert_eq!(fragments[0].input.body.as_text(), Some("a"));
        assert_eq!(fragments[1].input.body.as_text(), Some("b"));
    }

    #[test]
    fn test_split_body_lines_non_text_body() {
        let ex = Exchange::new(Message::new(serde_json::json!(["a", "b"])));
        assert!(split_body_lines()(&ex).is_empty());
    }

    #[test]
    fn test_split_body_json_array() {
        let arr = serde_json::json!([1, 2, 3]);
        let ex = Exchange::new(Message::new(arr));

        let fragments = split_body_json_array()(&ex);
        assert_eq!(fragments.len(), 3);
        assert!(matches!(&fragments[0].input.body, Body::Json(v) if *v == serde_json::json!(1)));
        assert!(matches!(&fragments[1].input.body, Body::Json(v) if *v == serde_json::json!(2)));
        assert!(matches!(&fragments[2].input.body, Body::Json(v) if *v == serde_json::json!(3)));
    }

    #[test]
    fn test_split_body_json_array_not_array() {
        let obj = serde_json::json!({"not": "array"});
        let ex = Exchange::new(Message::new(obj));

        let fragments = split_body_json_array()(&ex);
        assert!(fragments.is_empty());
    }

    #[test]
    fn test_split_body_json_array_inherits_headers() {
        let mut ex = Exchange::new(Message::new(serde_json::json!(["a", "b"])));
        ex.input.set_header("source", Value::String("test".into()));

        let fragments = split_body_json_array()(&ex);
        assert_eq!(fragments.len(), 2);
        for frag in &fragments {
            assert_eq!(
                frag.input.header("source"),
                Some(&Value::String("test".into()))
            );
        }
    }

    #[test]
    fn test_split_body_custom() {
        let splitter = split_body(|body: &Body| match body {
            Body::Text(s) => s
                .split(',')
                .map(|part| Body::Text(part.trim().to_string()))
                .collect(),
            _ => Vec::new(),
        });

        let mut ex = Exchange::new(Message::new("x, y, z"));
        ex.set_property("id", Value::from(42));

        let fragments = splitter(&ex);
        assert_eq!(fragments.len(), 3);
        assert_eq!(fragments[0].input.body.as_text(), Some("x"));
        assert_eq!(fragments[1].input.body.as_text(), Some("y"));
        assert_eq!(fragments[2].input.body.as_text(), Some("z"));

        // Properties propagate through the custom splitter too.
        for frag in &fragments {
            assert_eq!(frag.property("id"), Some(&Value::from(42)));
        }
    }

    #[test]
    fn test_splitter_config_defaults() {
        let config = SplitterConfig::new(split_body_lines());
        assert!(matches!(config.aggregation, AggregationStrategy::LastWins));
        assert!(!config.parallel);
        assert!(config.parallel_limit.is_none());
        assert!(config.stop_on_exception);
    }

    #[test]
    fn test_splitter_config_builder() {
        let config = SplitterConfig::new(split_body_lines())
            .aggregation(AggregationStrategy::CollectAll)
            .parallel(true)
            .parallel_limit(4)
            .stop_on_exception(false);

        assert!(matches!(
            config.aggregation,
            AggregationStrategy::CollectAll
        ));
        assert!(config.parallel);
        assert_eq!(config.parallel_limit, Some(4));
        assert!(!config.stop_on_exception);
    }

    #[test]
    fn test_fragment_exchange_inherits_otel_context() {
        let trace_id = TraceId::from_bytes([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 123]);
        let span_id = SpanId::from_bytes([0, 0, 0, 0, 0, 0, 1, 200]);
        let mut parent = Exchange::new(Message::new("test"));
        parent.otel_context = remote_parent_context(trace_id, span_id);

        let fragments = split_body_lines()(&parent);
        assert!(!fragments.is_empty(), "Should have at least one fragment");

        // The context is cloned into each fragment, so the active span is the parent's.
        for fragment in &fragments {
            let span = fragment.otel_context.span();
            let frag_span_ctx = span.span_context();
            assert!(
                frag_span_ctx.is_valid(),
                "Fragment should have valid span context"
            );
            assert_eq!(
                frag_span_ctx.trace_id(),
                trace_id,
                "Fragment should have same trace ID as parent"
            );
            assert_eq!(
                frag_span_ctx.span_id(),
                span_id,
                "Fragment should carry the parent's span ID"
            );
        }
    }

    #[test]
    fn test_all_fragments_share_same_trace_context() {
        let trace_id =
            TraceId::from_bytes([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x3B, 0x9A, 0xCA, 0x09]);
        let span_id = SpanId::from_bytes([0, 0, 0, 0, 0, 0, 0, 111]);
        let mut parent = Exchange::new(Message::new("line1\nline2\nline3"));
        parent.otel_context = remote_parent_context(trace_id, span_id);

        let fragments = split_body_lines()(&parent);
        assert_eq!(fragments.len(), 3);

        let trace_ids: Vec<_> = fragments
            .iter()
            .map(|f| {
                let span = f.otel_context.span();
                span.span_context().trace_id()
            })
            .collect();

        assert!(
            trace_ids.iter().all(|&id| id == trace_id),
            "All fragments should have the same trace ID"
        );
    }
}
339}