1use std::{
4 pin::Pin,
5 task::{Context, Poll},
6 time::Duration,
7};
8
9use futures::{Stream, StreamExt};
10use s2_common::{caps::RECORD_BATCH_MAX, read_extent::CountOrBytes};
11use tokio::time::Instant;
12
13use crate::types::{
14 AppendInput, AppendRecord, AppendRecordBatch, FencingToken, MeteredBytes, ValidationError,
15};
16
17const RECORD_BATCH_MIN: CountOrBytes = CountOrBytes { count: 1, bytes: 8 };
18
/// Controls how [`AppendRecordBatches`] groups records into batches.
///
/// Start from [`BatchingConfig::default`] (or [`BatchingConfig::new`]) and adjust
/// with the `with_*` setters; the size-limit setters validate their arguments.
#[derive(Debug, Clone)]
pub struct BatchingConfig {
    /// How long a batch that is not yet full waits, after its first record is
    /// added, for more records before it is emitted anyway.
    linger: Duration,
    /// Upper bound on the total metered bytes of the records in one batch.
    max_batch_bytes: usize,
    /// Upper bound on the number of records in one batch.
    max_batch_records: usize,
}
26
27impl Default for BatchingConfig {
28 fn default() -> Self {
29 Self {
30 linger: Duration::from_millis(5),
31 max_batch_bytes: RECORD_BATCH_MAX.bytes,
32 max_batch_records: RECORD_BATCH_MAX.count,
33 }
34 }
35}
36
37impl BatchingConfig {
38 pub fn new() -> Self {
40 Self::default()
41 }
42
43 pub fn with_linger(self, linger: Duration) -> Self {
47 Self { linger, ..self }
48 }
49
50 pub fn with_max_batch_bytes(self, max_batch_bytes: usize) -> Result<Self, ValidationError> {
56 if max_batch_bytes < RECORD_BATCH_MIN.bytes {
57 return Err(ValidationError(format!(
58 "max_batch_bytes ({max_batch_bytes}) must be at least {}",
59 RECORD_BATCH_MIN.bytes
60 )));
61 }
62 if max_batch_bytes > RECORD_BATCH_MAX.bytes {
63 return Err(ValidationError(format!(
64 "max_batch_bytes ({max_batch_bytes}) must not exceed {}",
65 RECORD_BATCH_MAX.bytes
66 )));
67 }
68 Ok(Self {
69 max_batch_bytes,
70 ..self
71 })
72 }
73
74 pub fn with_max_batch_records(self, max_batch_records: usize) -> Result<Self, ValidationError> {
80 if max_batch_records < RECORD_BATCH_MIN.count {
81 return Err(ValidationError(format!(
82 "max_batch_records ({max_batch_records}) must be at least {}",
83 RECORD_BATCH_MIN.count
84 )));
85 }
86 if max_batch_records > RECORD_BATCH_MAX.count {
87 return Err(ValidationError(format!(
88 "max_batch_records ({max_batch_records}) must not exceed {}",
89 RECORD_BATCH_MAX.count
90 )));
91 }
92 Ok(Self {
93 max_batch_records,
94 ..self
95 })
96 }
97}
98
/// A stream of [`AppendInput`]s produced by batching a stream of records.
///
/// Every emitted input carries a clone of `fencing_token`. If `match_seq_num`
/// is set, the first input carries that value, and each later input carries it
/// advanced by the total number of records in the inputs emitted before it.
pub struct AppendInputs {
    /// Source of record batches; each batch becomes one `AppendInput`.
    pub(crate) batches: AppendRecordBatches,
    /// Token copied into every emitted input.
    pub(crate) fencing_token: Option<FencingToken>,
    /// Value placed in the next emitted input's `match_seq_num`; advanced by
    /// each emitted batch's record count.
    pub(crate) match_seq_num: Option<u64>,
}
105
106impl AppendInputs {
107 pub fn new(
109 records: impl Stream<Item = impl Into<AppendRecord> + Send> + Send + Unpin + 'static,
110 config: BatchingConfig,
111 ) -> Self {
112 Self {
113 batches: AppendRecordBatches::new(records, config),
114 fencing_token: None,
115 match_seq_num: None,
116 }
117 }
118
119 pub fn with_fencing_token(self, fencing_token: FencingToken) -> Self {
121 Self {
122 fencing_token: Some(fencing_token),
123 ..self
124 }
125 }
126
127 pub fn with_match_seq_num(self, seq_num: u64) -> Self {
130 Self {
131 match_seq_num: Some(seq_num),
132 ..self
133 }
134 }
135}
136
137impl Stream for AppendInputs {
138 type Item = Result<AppendInput, ValidationError>;
139
140 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
141 match self.batches.poll_next_unpin(cx) {
142 Poll::Ready(Some(Ok(batch))) => {
143 let match_seq_num = self.match_seq_num;
144 if let Some(seq_num) = self.match_seq_num.as_mut() {
145 *seq_num += batch.len() as u64;
146 }
147 Poll::Ready(Some(Ok(AppendInput {
148 records: batch,
149 match_seq_num,
150 fencing_token: self.fencing_token.clone(),
151 })))
152 }
153 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
154 Poll::Ready(None) => Poll::Ready(None),
155 Poll::Pending => Poll::Pending,
156 }
157 }
158}
159
/// A stream of record batches built from a stream of records according to a
/// [`BatchingConfig`].
///
/// Yields a [`ValidationError`] for a record whose metered size exceeds the
/// configured `max_batch_bytes`.
pub struct AppendRecordBatches {
    /// The batching stream. `append_record_batches` returns an unnameable
    /// `impl Stream`, so it is stored boxed and type-erased.
    inner: Pin<Box<dyn Stream<Item = Result<AppendRecordBatch, ValidationError>> + Send>>,
}
164
165impl AppendRecordBatches {
166 pub fn new(
168 records: impl Stream<Item = impl Into<AppendRecord> + Send> + Send + Unpin + 'static,
169 config: BatchingConfig,
170 ) -> Self {
171 Self {
172 inner: Box::pin(append_record_batches(records, config)),
173 }
174 }
175}
176
177impl Stream for AppendRecordBatches {
178 type Item = Result<AppendRecordBatch, ValidationError>;
179
180 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
181 self.inner.as_mut().poll_next(cx)
182 }
183}
184
185fn is_batch_full(config: &BatchingConfig, count: usize, bytes: usize) -> bool {
186 count >= config.max_batch_records || bytes >= config.max_batch_bytes
187}
188
189fn would_overflow_batch(
190 config: &BatchingConfig,
191 count: usize,
192 bytes: usize,
193 record: &AppendRecord,
194) -> bool {
195 count + 1 > config.max_batch_records || bytes + record.metered_bytes() > config.max_batch_bytes
196}
197
/// Groups `records` into batches according to `config`.
///
/// A batch is emitted as soon as one of the following happens:
///
/// - it holds `config.max_batch_records` records or `config.max_batch_bytes`
///   metered bytes;
/// - the next record would push it past either limit. That record is held back
///   and becomes the first record of the following batch;
/// - `config.linger` has elapsed since the batch's first record was added;
/// - the input stream ends. A non-empty partial batch is flushed first.
///
/// Emitted batches are never empty, and an empty input produces no batches.
///
/// # Errors
///
/// Yields a [`ValidationError`] when a record's metered size exceeds
/// `config.max_batch_bytes`. Records that arrived before it are emitted in
/// earlier batches. The error is raised with `?` inside `try_stream!`, so the
/// stream ends after yielding it.
fn append_record_batches(
    mut records: impl Stream<Item = impl Into<AppendRecord> + Send> + Send + Unpin + 'static,
    config: BatchingConfig,
) -> impl Stream<Item = Result<AppendRecordBatch, ValidationError>> + Send + 'static {
    async_stream::try_stream! {
        let mut batch = AppendRecordBatch::with_capacity(config.max_batch_records);
        // A record that did not fit into the previous batch; it seeds the next one.
        let mut overflowed_record: Option<AppendRecord> = None;

        // A single timer, pinned once and re-armed for each new batch below.
        let linger_deadline = tokio::time::sleep(config.linger);
        tokio::pin!(linger_deadline);

        'outer: loop {
            // Start the next batch with the carried-over record, or wait (with no
            // deadline) for a new one. Stop when the input is exhausted.
            let first_record = match overflowed_record.take() {
                Some(record) => record,
                None => match records.next().await {
                    Some(item) => item.into(),
                    None => break,
                },
            };

            // Only a batch's first record is size-checked here. A later record
            // that is too large fails `would_overflow_batch`, is carried over,
            // and reaches this check as the first record of the next batch.
            let record_bytes = first_record.metered_bytes();
            if record_bytes > config.max_batch_bytes {
                Err(ValidationError(format!(
                    "record size in metered bytes ({record_bytes}) exceeds max_batch_bytes ({})",
                    config.max_batch_bytes
                )))?;
            }
            batch.push(first_record);

            while !is_batch_full(&config, batch.len(), batch.metered_bytes())
                && overflowed_record.is_none()
            {
                // The linger window starts when the batch gets its first record.
                if batch.len() == 1 {
                    linger_deadline
                        .as_mut()
                        .reset(Instant::now() + config.linger);
                }

                // `select!` is unbiased here: if a record and the deadline are
                // both ready, either branch may run.
                tokio::select! {
                    next_record = records.next() => {
                        match next_record {
                            Some(record) => {
                                let record: AppendRecord = record.into();
                                if would_overflow_batch(&config, batch.len(), batch.metered_bytes(), &record) {
                                    overflowed_record = Some(record);
                                } else {
                                    batch.push(record);
                                }
                            }
                            None => {
                                // Input exhausted: flush the partial batch and finish.
                                yield std::mem::replace(&mut batch, AppendRecordBatch::with_capacity(config.max_batch_records));
                                break 'outer;
                            }
                        }
                    },
                    // `batch` always holds at least `first_record` here, so this
                    // precondition is always satisfied.
                    _ = &mut linger_deadline, if !batch.is_empty() => {
                        break;
                    }
                };
            }

            yield std::mem::replace(
                &mut batch,
                AppendRecordBatch::with_capacity(config.max_batch_records),
            );
        }
    }
}
266
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::TryStreamExt;

    use super::*;

    #[tokio::test]
    async fn batches_should_be_empty_when_record_stream_is_empty() {
        let batches: Vec<_> = AppendRecordBatches::new(
            futures::stream::iter::<Vec<AppendRecord>>(vec![]),
            BatchingConfig::default(),
        )
        .collect()
        .await;
        assert_eq!(batches.len(), 0);
    }

    #[tokio::test]
    async fn batches_respect_count_limit() -> Result<(), ValidationError> {
        let records: Vec<_> = (0..10)
            .map(|i| AppendRecord::new(format!("record{i}")))
            .collect::<Result<_, _>>()?;
        let config = BatchingConfig::default().with_max_batch_records(3)?;
        let batches: Vec<_> = AppendRecordBatches::new(futures::stream::iter(records), config)
            .try_collect()
            .await?;

        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        assert_eq!(batches[3].len(), 1);

        Ok(())
    }

    #[tokio::test]
    async fn batches_respect_bytes_limit() -> Result<(), ValidationError> {
        let records: Vec<_> = (0..10)
            .map(|i| AppendRecord::new(format!("record{i}")))
            .collect::<Result<_, _>>()?;
        let single_record_bytes = records[0].metered_bytes();
        let max_batch_bytes = single_record_bytes * 3;

        let config = BatchingConfig::default().with_max_batch_bytes(max_batch_bytes)?;
        let batches: Vec<_> = AppendRecordBatches::new(futures::stream::iter(records), config)
            .try_collect()
            .await?;

        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0].metered_bytes(), max_batch_bytes);
        assert_eq!(batches[1].metered_bytes(), max_batch_bytes);
        assert_eq!(batches[2].metered_bytes(), max_batch_bytes);
        assert_eq!(batches[3].metered_bytes(), single_record_bytes);

        Ok(())
    }

    #[tokio::test]
    async fn batching_should_error_when_it_sees_oversized_record() -> Result<(), ValidationError> {
        let record = AppendRecord::new("hello-world")?;
        let record_bytes = record.metered_bytes();
        let max_batch_bytes = 10;

        let config = BatchingConfig::default().with_max_batch_bytes(max_batch_bytes)?;
        let results: Vec<_> = AppendRecordBatches::new(futures::stream::iter(vec![record]), config)
            .collect()
            .await;

        assert_eq!(results.len(), 1);
        assert_matches!(&results[0], Err(err) => {
            assert_eq!(
                err.to_string(),
                format!("record size in metered bytes ({record_bytes}) exceeds max_batch_bytes ({max_batch_bytes})")
            );
        });

        Ok(())
    }

    /// An oversized record that arrives after a pending record must not swallow
    /// that record: the pending batch is emitted first, then the error.
    #[tokio::test]
    async fn batching_should_flush_pending_batch_before_oversized_record_error(
    ) -> Result<(), ValidationError> {
        let small = AppendRecord::new("hi")?;
        let big = AppendRecord::new("hello-world")?;
        let big_bytes = big.metered_bytes();
        // Just too small for `big`, but large enough for `small`.
        let max_batch_bytes = big_bytes - 1;

        let config = BatchingConfig::default().with_max_batch_bytes(max_batch_bytes)?;
        let results: Vec<_> =
            AppendRecordBatches::new(futures::stream::iter(vec![small, big]), config)
                .collect()
                .await;

        assert_eq!(results.len(), 2);
        assert_matches!(&results[0], Ok(batch) => {
            assert_eq!(batch.len(), 1);
        });
        assert_matches!(&results[1], Err(err) => {
            assert_eq!(
                err.to_string(),
                format!("record size in metered bytes ({big_bytes}) exceeds max_batch_bytes ({max_batch_bytes})")
            );
        });

        Ok(())
    }

    /// The config setters accept the inclusive `[RECORD_BATCH_MIN, RECORD_BATCH_MAX]`
    /// range and reject anything outside it with a descriptive message.
    #[test]
    fn batching_config_rejects_out_of_range_limits() {
        let config = BatchingConfig::default();

        let too_few = RECORD_BATCH_MIN.count - 1;
        let err = config.clone().with_max_batch_records(too_few).unwrap_err();
        assert_eq!(
            err.to_string(),
            format!(
                "max_batch_records ({too_few}) must be at least {}",
                RECORD_BATCH_MIN.count
            )
        );

        let too_many = RECORD_BATCH_MAX.count + 1;
        let err = config.clone().with_max_batch_records(too_many).unwrap_err();
        assert_eq!(
            err.to_string(),
            format!(
                "max_batch_records ({too_many}) must not exceed {}",
                RECORD_BATCH_MAX.count
            )
        );

        let too_small = RECORD_BATCH_MIN.bytes - 1;
        let err = config.clone().with_max_batch_bytes(too_small).unwrap_err();
        assert_eq!(
            err.to_string(),
            format!(
                "max_batch_bytes ({too_small}) must be at least {}",
                RECORD_BATCH_MIN.bytes
            )
        );

        let too_large = RECORD_BATCH_MAX.bytes + 1;
        let err = config.clone().with_max_batch_bytes(too_large).unwrap_err();
        assert_eq!(
            err.to_string(),
            format!(
                "max_batch_bytes ({too_large}) must not exceed {}",
                RECORD_BATCH_MAX.bytes
            )
        );

        // Both ends of each range are inclusive.
        assert!(config.clone().with_max_batch_records(RECORD_BATCH_MIN.count).is_ok());
        assert!(config.clone().with_max_batch_records(RECORD_BATCH_MAX.count).is_ok());
        assert!(config.clone().with_max_batch_bytes(RECORD_BATCH_MIN.bytes).is_ok());
        assert!(config.with_max_batch_bytes(RECORD_BATCH_MAX.bytes).is_ok());
    }

    /// Each emitted input carries the match sequence number expected for its
    /// first record; it advances by the size of every preceding batch.
    #[tokio::test]
    async fn append_inputs_advance_match_seq_num_per_batch() -> Result<(), ValidationError> {
        let records: Vec<_> = (0..5)
            .map(|i| AppendRecord::new(format!("record{i}")))
            .collect::<Result<_, _>>()?;
        let config = BatchingConfig::default().with_max_batch_records(2)?;
        let inputs: Vec<_> = AppendInputs::new(futures::stream::iter(records), config)
            .with_match_seq_num(10)
            .try_collect()
            .await?;

        let seq_nums: Vec<_> = inputs.iter().map(|input| input.match_seq_num).collect();
        assert_eq!(seq_nums, vec![Some(10), Some(12), Some(14)]);

        Ok(())
    }
}
347}