use std::{mem, num::NonZeroUsize, time::Duration};
use tracing::trace;
use crate::{
client_request::{ClientRequest, RequestBatch},
output::TimeoutRequest,
RequestPayload,
};
#[derive(Debug)]
pub(crate) struct RequestBatcher<P> {
next_batch: Vec<ClientRequest<P>>,
timeout_duration: Duration,
max_size: Option<NonZeroUsize>,
}
impl<P: RequestPayload> RequestBatcher<P> {
    /// Creates a batcher that flushes once `max_size` requests have
    /// accumulated, or after `timeout` for a smaller non-empty batch
    /// (`max_size == None` disables the size-based flush entirely).
    pub(super) fn new(timeout: Duration, max_size: Option<NonZeroUsize>) -> Self {
        Self {
            next_batch: Vec::new(),
            timeout_duration: timeout,
            max_size,
        }
    }

    /// Appends `request` to the pending batch.
    ///
    /// If the configured size limit is now reached, the whole pending batch
    /// is returned together with a request to stop the batch timer.
    /// Otherwise no batch is produced and the caller is asked to (re)start
    /// the batch timer with the configured duration.
    pub(super) fn batch(
        &mut self,
        request: ClientRequest<P>,
    ) -> (Option<RequestBatch<P>>, TimeoutRequest) {
        self.next_batch.push(request);

        // A `None` limit means the size-based flush never triggers.
        let limit_reached = self
            .max_size
            .map_or(false, |limit| self.next_batch.len() >= limit.get());

        if limit_reached {
            trace!("Reached the maximum size of batched client requests.");
            // Move the accumulated requests out, leaving an empty Vec behind.
            let pending = mem::take(&mut self.next_batch);
            (
                Some(RequestBatch::new(pending.into_boxed_slice())),
                TimeoutRequest::new_stop_batch_req(),
            )
        } else {
            (
                None,
                TimeoutRequest::new_start_batch_req(self.timeout_duration),
            )
        }
    }

    /// Flushes the pending batch because the batch timer fired.
    ///
    /// Returns `None` when nothing was pending; either way the caller is
    /// asked to stop the (now-expired) batch timer.
    pub(crate) fn timeout(&mut self) -> (Option<RequestBatch<P>>, TimeoutRequest) {
        let flushed = if self.next_batch.is_empty() {
            None
        } else {
            let pending = mem::take(&mut self.next_batch);
            Some(RequestBatch::new(pending.into_boxed_slice()))
        };
        (flushed, TimeoutRequest::new_stop_batch_req())
    }
}