1use std::sync::Arc;
2
3use crate::body::Body;
4use crate::error::CamelError;
5use crate::exchange::Exchange;
6use crate::message::Message;
7
8pub type SplitExpression = Arc<dyn Fn(&Exchange) -> Vec<Exchange> + Send + Sync>;
10
11#[derive(Clone, Default)]
13pub enum AggregationStrategy {
14 #[default]
16 LastWins,
17 CollectAll,
19 Original,
21 Custom(Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync>),
23}
24
25pub struct SplitterConfig {
27 pub expression: SplitExpression,
29 pub aggregation: AggregationStrategy,
31 pub parallel: bool,
33 pub parallel_limit: Option<usize>,
35 pub stop_on_exception: bool,
41}
42
43impl SplitterConfig {
44 pub fn new(expression: SplitExpression) -> Self {
46 Self {
47 expression,
48 aggregation: AggregationStrategy::default(),
49 parallel: false,
50 parallel_limit: None,
51 stop_on_exception: true,
52 }
53 }
54
55 pub fn aggregation(mut self, strategy: AggregationStrategy) -> Self {
57 self.aggregation = strategy;
58 self
59 }
60
61 pub fn parallel(mut self, parallel: bool) -> Self {
63 self.parallel = parallel;
64 self
65 }
66
67 pub fn parallel_limit(mut self, limit: usize) -> Self {
69 self.parallel_limit = Some(limit);
70 self
71 }
72
73 pub fn stop_on_exception(mut self, stop: bool) -> Self {
78 self.stop_on_exception = stop;
79 self
80 }
81
82 pub fn validate(&self) -> Result<(), CamelError> {
87 if self.parallel && self.parallel_limit == Some(0) {
88 return Err(CamelError::Config(
89 "splitter parallel_limit must be > 0".to_string(),
90 ));
91 }
92 Ok(())
93 }
94}
95
96fn fragment_exchange(parent: &Exchange, body: Body) -> Exchange {
121 let mut msg = Message::new(body);
122 msg.headers = parent.input.headers.clone();
123 let mut ex = Exchange::new(msg);
124 ex.properties = parent.properties.clone();
125 ex.pattern = parent.pattern;
126 ex.otel_context = parent.otel_context.clone();
128 ex
129}
130
131pub fn split_body_lines() -> SplitExpression {
134 Arc::new(|exchange: &Exchange| {
135 let text = match &exchange.input.body {
136 Body::Text(s) => s.as_str(),
137 _ => return Vec::new(),
138 };
139 text.lines()
140 .map(|line| fragment_exchange(exchange, Body::Text(line.to_string())))
141 .collect()
142 })
143}
144
145pub fn split_body_json_array() -> SplitExpression {
148 Arc::new(|exchange: &Exchange| {
149 let arr = match &exchange.input.body {
150 Body::Json(serde_json::Value::Array(arr)) => arr,
151 _ => return Vec::new(),
152 };
153 arr.iter()
154 .map(|val| fragment_exchange(exchange, Body::Json(val.clone())))
155 .collect()
156 })
157}
158
159pub fn split_body<F>(f: F) -> SplitExpression
161where
162 F: Fn(&Body) -> Vec<Body> + Send + Sync + 'static,
163{
164 Arc::new(move |exchange: &Exchange| {
165 f(&exchange.input.body)
166 .into_iter()
167 .map(|body| fragment_exchange(exchange, body))
168 .collect()
169 })
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::value::Value;

    #[test]
    fn test_split_body_lines() {
        let mut ex = Exchange::new(Message::new("a\nb\nc"));
        ex.input.set_header("source", Value::String("test".into()));
        ex.set_property("trace", Value::Bool(true));

        let fragments = split_body_lines()(&ex);
        assert_eq!(fragments.len(), 3);
        assert_eq!(fragments[0].input.body.as_text(), Some("a"));
        assert_eq!(fragments[1].input.body.as_text(), Some("b"));
        assert_eq!(fragments[2].input.body.as_text(), Some("c"));

        // Every fragment inherits the parent's headers and properties.
        for frag in &fragments {
            assert_eq!(
                frag.input.header("source"),
                Some(&Value::String("test".into()))
            );
            assert_eq!(frag.property("trace"), Some(&Value::Bool(true)));
        }
    }

    #[test]
    fn test_split_body_lines_empty() {
        let ex = Exchange::new(Message::default());
        let fragments = split_body_lines()(&ex);
        assert!(fragments.is_empty());
    }

    #[test]
    fn test_split_body_json_array() {
        let arr = serde_json::json!([1, 2, 3]);
        let ex = Exchange::new(Message::new(arr));

        let fragments = split_body_json_array()(&ex);
        assert_eq!(fragments.len(), 3);
        assert!(matches!(&fragments[0].input.body, Body::Json(v) if *v == serde_json::json!(1)));
        assert!(matches!(&fragments[1].input.body, Body::Json(v) if *v == serde_json::json!(2)));
        assert!(matches!(&fragments[2].input.body, Body::Json(v) if *v == serde_json::json!(3)));
    }

    #[test]
    fn test_split_body_json_array_not_array() {
        let obj = serde_json::json!({"not": "array"});
        let ex = Exchange::new(Message::new(obj));

        let fragments = split_body_json_array()(&ex);
        assert!(fragments.is_empty());
    }

    #[test]
    fn test_split_body_custom() {
        let splitter = split_body(|body: &Body| match body {
            Body::Text(s) => s
                .split(',')
                .map(|part| Body::Text(part.trim().to_string()))
                .collect(),
            _ => Vec::new(),
        });

        let mut ex = Exchange::new(Message::new("x, y, z"));
        ex.set_property("id", Value::from(42));

        let fragments = splitter(&ex);
        assert_eq!(fragments.len(), 3);
        assert_eq!(fragments[0].input.body.as_text(), Some("x"));
        assert_eq!(fragments[1].input.body.as_text(), Some("y"));
        assert_eq!(fragments[2].input.body.as_text(), Some("z"));

        // Properties propagate to fragments built by custom split functions too.
        for frag in &fragments {
            assert_eq!(frag.property("id"), Some(&Value::from(42)));
        }
    }

    #[test]
    fn test_splitter_config_defaults() {
        let config = SplitterConfig::new(split_body_lines());
        assert!(matches!(config.aggregation, AggregationStrategy::LastWins));
        assert!(!config.parallel);
        assert!(config.parallel_limit.is_none());
        assert!(config.stop_on_exception);
    }

    #[test]
    fn test_splitter_config_builder() {
        let config = SplitterConfig::new(split_body_lines())
            .aggregation(AggregationStrategy::CollectAll)
            .parallel(true)
            .parallel_limit(4)
            .stop_on_exception(false);

        assert!(matches!(
            config.aggregation,
            AggregationStrategy::CollectAll
        ));
        assert!(config.parallel);
        assert_eq!(config.parallel_limit, Some(4));
        assert!(!config.stop_on_exception);
    }

    #[test]
    fn test_splitter_config_validate_rejects_zero_parallel_limit() {
        let config = SplitterConfig::new(split_body_lines())
            .parallel(true)
            .parallel_limit(0);
        assert!(matches!(config.validate(), Err(CamelError::Config(_))));
    }

    #[test]
    fn test_splitter_config_validate_accepts_valid_settings() {
        assert!(SplitterConfig::new(split_body_lines()).validate().is_ok());

        let parallel = SplitterConfig::new(split_body_lines())
            .parallel(true)
            .parallel_limit(4);
        assert!(parallel.validate().is_ok());

        // A zero limit is only rejected when parallel processing is enabled.
        let sequential = SplitterConfig::new(split_body_lines()).parallel_limit(0);
        assert!(sequential.validate().is_ok());
    }

    #[test]
    fn test_fragment_exchange_inherits_otel_context() {
        use opentelemetry::Context;
        use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId};

        let mut parent = Exchange::new(Message::new("test"));
        let trace_id = TraceId::from_bytes([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 123]);
        let span_id = SpanId::from_bytes([0, 0, 0, 0, 0, 0, 1, 200]);
        let span_context = SpanContext::new(
            trace_id,
            span_id,
            TraceFlags::SAMPLED,
            true,
            Default::default(),
        );
        let expected_trace_id = span_context.trace_id();
        parent.otel_context = Context::current().with_remote_span_context(span_context);

        let fragments = split_body_lines()(&parent);
        assert!(!fragments.is_empty(), "Should have at least one fragment");

        for fragment in &fragments {
            let span = fragment.otel_context.span();
            let frag_span_ctx = span.span_context();
            assert!(
                frag_span_ctx.is_valid(),
                "Fragment should have valid span context"
            );
            assert_eq!(
                frag_span_ctx.trace_id(),
                expected_trace_id,
                "Fragment should have same trace ID as parent"
            );
        }
    }

    #[test]
    fn test_all_fragments_share_same_trace_context() {
        use opentelemetry::Context;
        use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId};

        let mut parent = Exchange::new(Message::new("line1\nline2\nline3"));
        let trace_id =
            TraceId::from_bytes([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x3B, 0x9A, 0xCA, 0x09]);
        let span_id = SpanId::from_bytes([0, 0, 0, 0, 0, 0, 0, 111]);
        let span_context = SpanContext::new(
            trace_id,
            span_id,
            TraceFlags::SAMPLED,
            true,
            Default::default(),
        );
        parent.otel_context = Context::current().with_remote_span_context(span_context);

        let fragments = split_body_lines()(&parent);
        assert_eq!(fragments.len(), 3);

        let trace_ids: Vec<_> = fragments
            .iter()
            .map(|f| {
                let span = f.otel_context.span();
                span.span_context().trace_id()
            })
            .collect();

        assert!(
            trace_ids.iter().all(|&id| id == trace_id),
            "All fragments should have the same trace ID"
        );
    }
}