1use std::pin::Pin;
2use std::sync::Arc;
3
4use futures::Stream;
5
6use crate::body::Body;
7use crate::error::CamelError;
8use crate::exchange::Exchange;
9use crate::message::Message;
10
/// Shared, thread-safe closure that derives fragment exchanges from a borrowed
/// parent [`Exchange`] and returns all of them at once in a `Vec`.
///
/// The helpers in this file ([`split_body_lines`], [`split_body_json_array`],
/// [`split_body`]) all produce values of this type.
pub type SplitExpression = Arc<dyn Fn(&Exchange) -> Vec<Exchange> + Send + Sync>;
13
/// Shared, thread-safe closure that takes ownership of an [`Exchange`] and
/// returns a pinned, boxed, `Send` [`Stream`].
///
/// Each stream item is either a fragment [`Exchange`] or a [`CamelError`].
pub type StreamingSplitExpression = Arc<
    dyn Fn(Exchange) -> Pin<Box<dyn Stream<Item = Result<Exchange, CamelError>> + Send>>
        + Send
        + Sync,
>;
25
/// Aggregation mode stored in [`SplitterConfig::aggregation`].
///
/// The code that applies a strategy is not part of this file, so the
/// per-variant notes below are assumptions to be confirmed against it.
#[derive(Clone, Default)]
pub enum AggregationStrategy {
    /// Default variant.
    /// NOTE(review): presumably the last fragment's result becomes the output — confirm.
    #[default]
    LastWins,
    /// NOTE(review): presumably gathers every fragment result into one output — confirm.
    CollectAll,
    /// NOTE(review): presumably returns the original (pre-split) exchange — confirm.
    Original,
    /// User-supplied, shareable closure that folds two exchanges into one.
    /// NOTE(review): which argument is the accumulated value and which is the
    /// newly processed fragment is not visible here — confirm against the caller.
    Custom(Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync>),
}
39
/// Format selector used by [`StreamSplitConfig`].
///
/// Variants are (de)serialized and exported to TypeScript in `snake_case`
/// (`auto`, `ndjson`, `lines`, `chunks`, `zip`).
#[derive(
    Clone,
    Debug,
    Default,
    PartialEq,
    Eq,
    serde::Serialize,
    serde::Deserialize,
    schemars::JsonSchema,
    ts_rs::TS,
)]
#[serde(rename_all = "snake_case")]
#[ts(rename_all = "snake_case")]
pub enum StreamSplitFormat {
    /// Default variant.
    /// NOTE(review): presumably lets the splitter detect the format itself — confirm.
    #[default]
    Auto,
    /// NOTE(review): presumably newline-delimited JSON records — confirm.
    Ndjson,
    /// NOTE(review): presumably one record per text line — confirm.
    Lines,
    /// Chunked splitting; [`StreamSplitConfig::validate`] requires `chunk_size`
    /// to be set for this format.
    Chunks,
    /// [`StreamSplitConfig::validate`] rejects a configured `chunk_size` for
    /// this format.
    /// NOTE(review): presumably one record per ZIP archive entry — confirm.
    Zip,
}
67
/// Serializable settings for streaming splits.
///
/// Field names are (de)serialized and exported to TypeScript in `snake_case`.
/// Use [`StreamSplitConfig::validate`] to check the constraints noted on each
/// field; the `Default` implementation yields a configuration that passes it.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    serde::Serialize,
    serde::Deserialize,
    schemars::JsonSchema,
    ts_rs::TS,
)]
#[serde(rename_all = "snake_case")]
#[ts(rename_all = "snake_case")]
pub struct StreamSplitConfig {
    /// Selected format; defaults to [`StreamSplitFormat::Auto`].
    pub format: StreamSplitFormat,
    /// Must be greater than zero; defaults to `1024 * 1024`.
    /// NOTE(review): presumably the maximum size of one record in bytes — confirm.
    pub max_record_bytes: usize,
    /// Must be greater than zero; defaults to `1`.
    /// NOTE(review): presumably the number of records grouped per emitted
    /// exchange — confirm.
    pub batch_size: usize,
    /// Required when `format` is `Chunks`, rejected when `format` is `Zip`,
    /// never allowed to be zero, and for `Chunks` must not exceed
    /// `max_record_bytes`. Defaults to `None`.
    pub chunk_size: Option<usize>,
    /// Defaults to `true`.
    /// NOTE(review): presumably controls whether fragments carry metadata about
    /// their origin — confirm against the streaming splitter.
    pub include_origin: bool,
}
96
97impl Default for StreamSplitConfig {
98 fn default() -> Self {
99 Self {
100 format: StreamSplitFormat::Auto,
101 max_record_bytes: 1024 * 1024,
102 batch_size: 1,
103 chunk_size: None,
104 include_origin: true,
105 }
106 }
107}
108
109impl StreamSplitConfig {
110 pub fn validate(&self) -> Result<(), CamelError> {
121 if self.batch_size == 0 {
122 return Err(CamelError::Config(
123 "stream split batch_size must be > 0".into(),
124 ));
125 }
126 if self.max_record_bytes == 0 {
127 return Err(CamelError::Config(
128 "stream split max_record_bytes must be > 0".into(),
129 ));
130 }
131 if self.format == StreamSplitFormat::Chunks && self.chunk_size.is_none() {
132 return Err(CamelError::Config(
133 "stream split format=Chunks requires chunk_size".into(),
134 ));
135 }
136 if self.format == StreamSplitFormat::Zip && self.chunk_size.is_some() {
139 return Err(CamelError::Config(
140 "stream split format=Zip does not support chunk_size".into(),
141 ));
142 }
143 if let Some(cs) = self.chunk_size
144 && cs == 0
145 {
146 return Err(CamelError::Config(
147 "stream split chunk_size must be > 0".into(),
148 ));
149 }
150 if self.format == StreamSplitFormat::Chunks
151 && let Some(cs) = self.chunk_size
152 && cs > self.max_record_bytes
153 {
154 return Err(CamelError::Config(
155 "stream split chunk_size must be <= max_record_bytes".into(),
156 ));
157 }
158 Ok(())
159 }
160}
161
/// Settings for a splitter, built with [`SplitterConfig::new`] and its chained
/// builder methods.
pub struct SplitterConfig {
    /// Closure that turns the incoming exchange into fragment exchanges.
    pub expression: SplitExpression,
    /// Aggregation mode; `new` uses `AggregationStrategy::default()`.
    pub aggregation: AggregationStrategy,
    /// Parallel flag; `new` sets it to `false`.
    /// NOTE(review): presumably enables concurrent fragment processing — confirm.
    pub parallel: bool,
    /// Optional limit; `new` sets it to `None`. `validate` rejects `Some(0)`
    /// when `parallel` is `true`.
    /// NOTE(review): presumably the maximum number of fragments in flight — confirm.
    pub parallel_limit: Option<usize>,
    /// `new` sets it to `true`.
    /// NOTE(review): presumably aborts the split on the first failing fragment — confirm.
    pub stop_on_exception: bool,
}
179
180impl SplitterConfig {
181 pub fn new(expression: SplitExpression) -> Self {
183 Self {
184 expression,
185 aggregation: AggregationStrategy::default(),
186 parallel: false,
187 parallel_limit: None,
188 stop_on_exception: true,
189 }
190 }
191
192 pub fn aggregation(mut self, strategy: AggregationStrategy) -> Self {
194 self.aggregation = strategy;
195 self
196 }
197
198 pub fn parallel(mut self, parallel: bool) -> Self {
200 self.parallel = parallel;
201 self
202 }
203
204 pub fn parallel_limit(mut self, limit: usize) -> Self {
206 self.parallel_limit = Some(limit);
207 self
208 }
209
210 pub fn stop_on_exception(mut self, stop: bool) -> Self {
215 self.stop_on_exception = stop;
216 self
217 }
218
219 pub fn validate(&self) -> Result<(), CamelError> {
224 if self.parallel && self.parallel_limit == Some(0) {
225 return Err(CamelError::Config(
226 "splitter parallel_limit must be > 0".to_string(),
227 ));
228 }
229 Ok(())
230 }
231}
232
233pub fn fragment_exchange(parent: &Exchange, body: Body) -> Exchange {
258 let mut msg = Message::new(body);
259 msg.headers = parent.input.headers.clone();
260 let mut ex = Exchange::new(msg);
261 ex.properties = parent.properties.clone();
262 ex.pattern = parent.pattern;
263 ex.otel_context = parent.otel_context.clone();
265 ex
266}
267
268pub fn split_body_lines() -> SplitExpression {
271 Arc::new(|exchange: &Exchange| {
272 let text = match &exchange.input.body {
273 Body::Text(s) => s.as_str(),
274 _ => return Vec::new(),
275 };
276 text.lines()
277 .map(|line| fragment_exchange(exchange, Body::Text(line.to_string())))
278 .collect()
279 })
280}
281
282pub fn split_body_json_array() -> SplitExpression {
285 Arc::new(|exchange: &Exchange| {
286 let arr = match &exchange.input.body {
287 Body::Json(serde_json::Value::Array(arr)) => arr,
288 _ => return Vec::new(),
289 };
290 arr.iter()
291 .map(|val| fragment_exchange(exchange, Body::Json(val.clone())))
292 .collect()
293 })
294}
295
296pub fn split_body<F>(f: F) -> SplitExpression
298where
299 F: Fn(&Body) -> Vec<Body> + Send + Sync + 'static,
300{
301 Arc::new(move |exchange: &Exchange| {
302 f(&exchange.input.body)
303 .into_iter()
304 .map(|body| fragment_exchange(exchange, body))
305 .collect()
306 })
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use crate::value::Value;

    // ---------------------------------------------------------------------
    // Split expressions
    // ---------------------------------------------------------------------

    /// Text bodies are split per line and fragments inherit headers/properties.
    #[test]
    fn test_split_body_lines() {
        let mut ex = Exchange::new(Message::new("a\nb\nc"));
        ex.input.set_header("source", Value::String("test".into()));
        ex.set_property("trace", Value::Bool(true));

        let fragments = split_body_lines()(&ex);
        assert_eq!(fragments.len(), 3);
        assert_eq!(fragments[0].input.body.as_text(), Some("a"));
        assert_eq!(fragments[1].input.body.as_text(), Some("b"));
        assert_eq!(fragments[2].input.body.as_text(), Some("c"));

        for frag in &fragments {
            assert_eq!(
                frag.input.header("source"),
                Some(&Value::String("test".into()))
            );
            assert_eq!(frag.property("trace"), Some(&Value::Bool(true)));
        }
    }

    /// A default message yields no fragments.
    #[test]
    fn test_split_body_lines_empty() {
        let ex = Exchange::new(Message::default());
        let fragments = split_body_lines()(&ex);
        assert!(fragments.is_empty());
    }

    /// A non-text (JSON) body yields no fragments from the line splitter.
    #[test]
    fn test_split_body_lines_non_text_body() {
        let ex = Exchange::new(Message::new(serde_json::json!(["a", "b"])));
        let fragments = split_body_lines()(&ex);
        assert!(fragments.is_empty());
    }

    /// JSON arrays are split into one JSON fragment per element.
    #[test]
    fn test_split_body_json_array() {
        let arr = serde_json::json!([1, 2, 3]);
        let ex = Exchange::new(Message::new(arr));

        let fragments = split_body_json_array()(&ex);
        assert_eq!(fragments.len(), 3);
        assert!(matches!(&fragments[0].input.body, Body::Json(v) if *v == serde_json::json!(1)));
        assert!(matches!(&fragments[1].input.body, Body::Json(v) if *v == serde_json::json!(2)));
        assert!(matches!(&fragments[2].input.body, Body::Json(v) if *v == serde_json::json!(3)));
    }

    /// JSON values that are not arrays yield no fragments.
    #[test]
    fn test_split_body_json_array_not_array() {
        let obj = serde_json::json!({"not": "array"});
        let ex = Exchange::new(Message::new(obj));

        let fragments = split_body_json_array()(&ex);
        assert!(fragments.is_empty());
    }

    /// A custom body splitter drives fragment creation and properties are inherited.
    #[test]
    fn test_split_body_custom() {
        let splitter = split_body(|body: &Body| match body {
            Body::Text(s) => s
                .split(',')
                .map(|part| Body::Text(part.trim().to_string()))
                .collect(),
            _ => Vec::new(),
        });

        let mut ex = Exchange::new(Message::new("x, y, z"));
        ex.set_property("id", Value::from(42));

        let fragments = splitter(&ex);
        assert_eq!(fragments.len(), 3);
        assert_eq!(fragments[0].input.body.as_text(), Some("x"));
        assert_eq!(fragments[1].input.body.as_text(), Some("y"));
        assert_eq!(fragments[2].input.body.as_text(), Some("z"));

        for frag in &fragments {
            assert_eq!(frag.property("id"), Some(&Value::from(42)));
        }
    }

    // ---------------------------------------------------------------------
    // SplitterConfig
    // ---------------------------------------------------------------------

    #[test]
    fn test_splitter_config_defaults() {
        let config = SplitterConfig::new(split_body_lines());
        assert!(matches!(config.aggregation, AggregationStrategy::LastWins));
        assert!(!config.parallel);
        assert!(config.parallel_limit.is_none());
        assert!(config.stop_on_exception);
    }

    #[test]
    fn test_splitter_config_builder() {
        let config = SplitterConfig::new(split_body_lines())
            .aggregation(AggregationStrategy::CollectAll)
            .parallel(true)
            .parallel_limit(4)
            .stop_on_exception(false);

        assert!(matches!(
            config.aggregation,
            AggregationStrategy::CollectAll
        ));
        assert!(config.parallel);
        assert_eq!(config.parallel_limit, Some(4));
        assert!(!config.stop_on_exception);
    }

    /// The configuration produced by `new` passes validation.
    #[test]
    fn test_splitter_config_defaults_valid() {
        let config = SplitterConfig::new(split_body_lines());
        assert!(config.validate().is_ok());
    }

    /// A zero limit is rejected once parallel processing is enabled.
    #[test]
    fn test_splitter_config_parallel_limit_zero_rejected() {
        let config = SplitterConfig::new(split_body_lines())
            .parallel(true)
            .parallel_limit(0);
        let err = config.validate().unwrap_err();
        assert!(err.to_string().contains("parallel_limit must be > 0"));
    }

    /// A zero limit is not checked while parallel processing is disabled.
    #[test]
    fn test_splitter_config_parallel_limit_zero_ignored_when_sequential() {
        let config = SplitterConfig::new(split_body_lines()).parallel_limit(0);
        assert!(config.validate().is_ok());
    }

    /// A positive limit with parallel processing enabled is accepted.
    #[test]
    fn test_splitter_config_parallel_with_positive_limit_valid() {
        let config = SplitterConfig::new(split_body_lines())
            .parallel(true)
            .parallel_limit(4);
        assert!(config.validate().is_ok());
    }

    // ---------------------------------------------------------------------
    // OpenTelemetry context propagation
    // ---------------------------------------------------------------------

    #[test]
    fn test_fragment_exchange_inherits_otel_context() {
        use opentelemetry::Context;
        use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId};

        let mut parent = Exchange::new(Message::new("test"));
        let trace_id = TraceId::from_bytes([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 123]);
        let span_id = SpanId::from_bytes([0, 0, 0, 0, 0, 0, 1, 200]);
        let span_context = SpanContext::new(
            trace_id,
            span_id,
            TraceFlags::SAMPLED,
            true,
            Default::default(),
        );
        let expected_trace_id = span_context.trace_id();
        parent.otel_context = Context::current().with_remote_span_context(span_context);

        let fragments = split_body_lines()(&parent);
        assert!(!fragments.is_empty(), "Should have at least one fragment");

        for fragment in &fragments {
            let span = fragment.otel_context.span();
            let frag_span_ctx = span.span_context();
            assert!(
                frag_span_ctx.is_valid(),
                "Fragment should have valid span context"
            );
            assert_eq!(
                frag_span_ctx.trace_id(),
                expected_trace_id,
                "Fragment should have same trace ID as parent"
            );
        }
    }

    #[test]
    fn test_all_fragments_share_same_trace_context() {
        use opentelemetry::Context;
        use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId};

        let mut parent = Exchange::new(Message::new("line1\nline2\nline3"));
        let trace_id =
            TraceId::from_bytes([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x3B, 0x9A, 0xCA, 0x09]);
        let span_id = SpanId::from_bytes([0, 0, 0, 0, 0, 0, 0, 111]);
        let span_context = SpanContext::new(
            trace_id,
            span_id,
            TraceFlags::SAMPLED,
            true,
            Default::default(),
        );
        parent.otel_context = Context::current().with_remote_span_context(span_context);

        let fragments = split_body_lines()(&parent);
        assert_eq!(fragments.len(), 3);

        let trace_ids: Vec<_> = fragments
            .iter()
            .map(|f| {
                let span = f.otel_context.span();
                span.span_context().trace_id()
            })
            .collect();

        assert!(
            trace_ids.iter().all(|&id| id == trace_id),
            "All fragments should have the same trace ID"
        );
    }

    // ---------------------------------------------------------------------
    // StreamSplitConfig
    // ---------------------------------------------------------------------

    #[test]
    fn test_stream_split_config_defaults_valid() {
        let config = StreamSplitConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_stream_split_config_batch_size_zero_rejected() {
        let config = StreamSplitConfig {
            batch_size: 0,
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert!(err.to_string().contains("batch_size"));
    }

    #[test]
    fn test_stream_split_config_max_record_bytes_zero_rejected() {
        let config = StreamSplitConfig {
            max_record_bytes: 0,
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert!(err.to_string().contains("max_record_bytes"));
    }

    #[test]
    fn test_stream_split_config_chunks_requires_chunk_size() {
        let config = StreamSplitConfig {
            format: StreamSplitFormat::Chunks,
            chunk_size: None,
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert!(err.to_string().contains("Chunks requires chunk_size"));
    }

    #[test]
    fn test_stream_split_config_chunk_size_zero_rejected() {
        let config = StreamSplitConfig {
            format: StreamSplitFormat::Chunks,
            chunk_size: Some(0),
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert!(err.to_string().contains("chunk_size must be > 0"));
    }

    #[test]
    fn test_stream_split_config_chunk_size_exceeds_max_record_bytes() {
        let config = StreamSplitConfig {
            format: StreamSplitFormat::Chunks,
            chunk_size: Some(2000),
            max_record_bytes: 1000,
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert!(
            err.to_string()
                .contains("chunk_size must be <= max_record_bytes")
        );
    }

    /// The upper bound is inclusive: `chunk_size == max_record_bytes` is accepted.
    #[test]
    fn test_stream_split_config_chunk_size_equal_to_max_record_bytes_valid() {
        let config = StreamSplitConfig {
            format: StreamSplitFormat::Chunks,
            chunk_size: Some(1000),
            max_record_bytes: 1000,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    /// A `Chunks` configuration with a positive, in-range `chunk_size` is accepted.
    #[test]
    fn test_stream_split_config_chunks_with_chunk_size_valid() {
        let config = StreamSplitConfig {
            format: StreamSplitFormat::Chunks,
            chunk_size: Some(1024),
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_stream_split_config_zip_rejects_chunk_size() {
        let config = StreamSplitConfig {
            format: StreamSplitFormat::Zip,
            chunk_size: Some(1024),
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert!(err.to_string().contains("Zip does not support chunk_size"));
    }

    /// `Zip` without a `chunk_size` is accepted.
    #[test]
    fn test_stream_split_config_zip_without_chunk_size_valid() {
        let config = StreamSplitConfig {
            format: StreamSplitFormat::Zip,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }
}