1#![deny(missing_docs)]
19
20#[cfg(doctest)]
21use doc_comment::doctest;
22#[cfg(doctest)]
23doctest!("../README.md");
24
25mod batch;
26mod batch_queue;
27mod batcher;
28mod error;
29mod policies;
30mod processor;
31mod timeout;
32mod worker;
33
34pub use batcher::Batcher;
35pub use error::BatchError;
36pub use policies::{BatchingPolicy, Limits, OnFull};
37pub use processor::Processor;
38
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::join;
    use tracing::{Instrument, span};

    use crate::{Batcher, BatchingPolicy, Limits, Processor};

    /// Test processor that sleeps for the wrapped duration and then tags every
    /// input with the batch key, producing `"<input> processed for <key>"`.
    #[derive(Debug, Clone)]
    pub struct SimpleBatchProcessor(pub Duration);

    impl Processor for SimpleBatchProcessor {
        type Key = String;
        type Input = String;
        type Output = String;
        type Error = String;
        type Resources = ();

        /// No resources are needed, so acquisition always succeeds.
        async fn acquire_resources(&self, _key: String) -> Result<(), String> {
            Ok(())
        }

        /// Sleeps for `self.0` to simulate work, then returns exactly one output
        /// per input, in input order.
        async fn process(
            &self,
            key: String,
            inputs: impl Iterator<Item = String> + Send,
            _resources: (),
        ) -> Result<Vec<String>, String> {
            tokio::time::sleep(self.0).await;
            Ok(inputs.map(|s| s + " processed for " + &key).collect())
        }
    }

    /// Submits two items for the same key under a size-2 batching policy and
    /// checks the tracing spans the batcher emits:
    /// - a single "process batch" span recording `batch.size` as a number,
    ///   with `follows_from` links to both caller (handler) spans;
    /// - one "batch finished" span per item, each linking back to the
    ///   process span.
    ///
    /// Ignored as flaky; the cause of the flakiness is not evident from this file.
    #[tokio::test]
    #[ignore = "flaky"]
    async fn test_tracing() {
        use tracing::Level;
        use tracing_capture::{CaptureLayer, SharedStorage};
        use tracing_subscriber::layer::SubscriberExt;

        let storage = SharedStorage::default();
        let subscriber = tracing_subscriber::fmt()
            .pretty()
            .with_max_level(Level::INFO)
            .finish()
            .with(CaptureLayer::new(&storage));

        // `set_default` is thread-local. `#[tokio::test]` uses a current-thread
        // runtime by default, so tasks spawned onto it (presumably including the
        // batcher's worker) run on this thread and are captured as well.
        let _guard = tracing::subscriber::set_default(subscriber);

        let batcher = Batcher::builder()
            .name("test_tracing")
            .processor(SimpleBatchProcessor(Duration::ZERO))
            .limits(Limits::default().with_max_batch_size(2))
            .batching_policy(BatchingPolicy::Size)
            .build();

        // Each request runs inside its own handler span, so the process span can
        // be checked for `follows_from` links back to both. The instrumented
        // futures are polled concurrently by `join!` directly; no mock-task
        // wrapper is needed.
        let first = batcher
            .add("A".to_string(), "1".to_string())
            .instrument(span!(Level::INFO, "test_handler_span1"));
        let second = batcher
            .add("A".to_string(), "2".to_string())
            .instrument(span!(Level::INFO, "test_handler_span2"));

        let (o1, o2) = join!(first, second);

        // Pin the actual outputs, not just success. `ok()` discards the error so
        // the assertion does not require the error type to implement `Debug`.
        assert_eq!(o1.ok().as_deref(), Some("1 processed for A"));
        assert_eq!(o2.ok().as_deref(), Some("2 processed for A"));

        let worker = batcher.worker_handle();
        worker.shut_down().await;
        tokio::time::timeout(Duration::from_secs(1), worker.wait_for_shutdown())
            .await
            .expect("Worker should shut down");

        let storage = storage.lock();

        let process_spans: Vec<_> = storage
            .all_spans()
            .filter(|span| span.metadata().name().contains("process batch"))
            .collect();
        assert_eq!(
            process_spans.len(),
            1,
            "should be a single span for processing the batch"
        );

        // Length was asserted to be exactly 1 just above.
        let process_span = &process_spans[0];

        assert_eq!(
            process_span["batch.size"], 2u64,
            "batch.size shouldn't be emitted as a string",
        );

        assert_eq!(
            process_span.follows_from().len(),
            2,
            "should follow from both handler spans"
        );

        let link_back_spans: Vec<_> = storage
            .all_spans()
            .filter(|span| span.metadata().name().contains("batch finished"))
            .collect();
        assert_eq!(
            link_back_spans.len(),
            2,
            "should be two spans for linking back to the process span"
        );

        for span in link_back_spans {
            assert_eq!(
                span.follows_from().len(),
                1,
                "link back spans should follow from the process span"
            );
        }

        // Two handler spans + one process span + two link-back spans account for
        // five. NOTE(review): the sixth is emitted elsewhere in the batcher;
        // confirm which span it is.
        assert_eq!(storage.all_spans().len(), 6, "should be 6 spans in total");
    }
}