use std::sync::Arc;
use std::time::Duration;
use futures::future::BoxFuture;
use thiserror::Error;
use tokio::task::JoinHandle;
use tracing::*;
use self::{
aggregator::Aggregator,
batch::{BatchBuilder, FlushResult, ResultHandle},
};
use crate::{
client::{
error::Error as ClientError,
partition::{Compression, PartitionClient},
producer::aggregator::TryPush,
},
record::Record,
};
pub mod aggregator;
mod batch;
pub(crate) mod broadcast;
/// Error returned by the batch producer.
#[derive(Debug, Error, Clone)]
pub enum Error {
    /// The aggregator failed to accept or process input.
    #[error("Aggregator error: {0}")]
    Aggregator(Arc<dyn std::error::Error + Send + Sync>),

    /// The underlying client failed to write a batch.
    #[error("Client error: {0}")]
    Client(#[from] Arc<ClientError>),

    /// A background flush task failed; the payload is a human-readable
    /// description of the failure.
    #[error("Flush error: {0}")]
    FlushError(String),

    /// A single input did not fit even into an empty aggregator, so it can
    /// never be produced.
    #[error("Input too large for aggregator")]
    TooLarge,
}

/// Convenience alias defaulting the error type to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Builder for [`BatchProducer`].
#[derive(Debug)]
pub struct BatchProducerBuilder {
    /// Client used to write assembled batches.
    client: Arc<dyn ProducerClient>,
    /// How long to wait for additional records before flushing a batch.
    linger: Duration,
    /// Compression applied to produced batches.
    compression: Compression,
}
impl BatchProducerBuilder {
    /// Construct a builder from a concrete [`PartitionClient`].
    pub fn new(client: Arc<PartitionClient>) -> Self {
        Self::new_with_client(client)
    }

    /// Construct a builder from any [`ProducerClient`] implementation.
    ///
    /// Defaults: a 5ms linger and the default compression.
    pub fn new_with_client(client: Arc<dyn ProducerClient>) -> Self {
        Self {
            client,
            linger: Duration::from_millis(5),
            compression: Compression::default(),
        }
    }

    /// Override the linger duration.
    pub fn with_linger(mut self, linger: Duration) -> Self {
        self.linger = linger;
        self
    }

    /// Override the compression used for produced batches.
    pub fn with_compression(mut self, compression: Compression) -> Self {
        self.compression = compression;
        self
    }

    /// Finalize the builder, wiring the given aggregator into a
    /// [`BatchProducer`].
    pub fn build<A>(self, aggregator: A) -> BatchProducer<A>
    where
        A: aggregator::Aggregator,
    {
        let inner = ProducerInner::new(aggregator, self.client, self.compression);
        BatchProducer {
            linger: self.linger,
            inner: Arc::new(parking_lot::Mutex::new(inner)),
        }
    }
}
/// Abstraction over the client used to write batches, allowing tests to
/// substitute a mock for [`PartitionClient`].
pub trait ProducerClient: std::fmt::Debug + Send + Sync {
    /// Write `records` as a single batch, returning the offset assigned to
    /// each record on success.
    fn produce(
        &self,
        records: Vec<Record>,
        compression: Compression,
    ) -> BoxFuture<'_, Result<Vec<i64>, ClientError>>;
}
impl ProducerClient for PartitionClient {
    // Forwards to the inherent `PartitionClient::produce`, boxing the future
    // to satisfy the object-safe trait signature.
    fn produce(
        &self,
        records: Vec<Record>,
        compression: Compression,
    ) -> BoxFuture<'_, Result<Vec<i64>, ClientError>> {
        Box::pin(self.produce(records, compression))
    }
}
/// Mutex-protected state shared by all callers of a [`BatchProducer`].
#[derive(Debug)]
struct ProducerInner<A>
where
    A: aggregator::Aggregator,
{
    /// Batch currently being assembled. `Option` only so it can be taken
    /// temporarily during a flush; it is always `Some` between calls.
    batch_builder: Option<BatchBuilder<A>>,
    /// Wrapping counter incremented on every flush; used to detect stale
    /// linger-triggered flush requests for already-flushed batches.
    flush_clock: usize,
    /// Whether some caller is already lingering and will flush this batch.
    has_linger_waiter: bool,
    /// Compression applied to produced batches.
    compression: Compression,
    /// Client used to write assembled batches.
    client: Arc<dyn ProducerClient>,
    /// Background flush tasks that may still be running.
    pending_flushes: Vec<JoinHandle<()>>,
}
impl<A> Drop for ProducerInner<A>
where
    A: Aggregator,
{
    /// Abort any still-running background flush tasks so they do not
    /// outlive the producer.
    fn drop(&mut self) {
        for task in self.pending_flushes.drain(..) {
            task.abort();
        }
    }
}
impl<A> ProducerInner<A>
where
    A: aggregator::Aggregator,
{
    /// Create a fresh producer state with an empty batch.
    fn new(aggregator: A, client: Arc<dyn ProducerClient>, compression: Compression) -> Self {
        Self {
            batch_builder: Some(BatchBuilder::new(aggregator)),
            flush_clock: 0,
            has_linger_waiter: false,
            client,
            compression,
            pending_flushes: Vec::new(),
        }
    }

    /// Push `data` into the current batch, flushing first if it is full.
    ///
    /// Returns the role the calling task must take: either just await the
    /// result, or additionally drive the linger timer and flush the batch.
    fn try_push(&mut self, data: A::Input) -> Result<CallerRole<A>, Error> {
        let handle = match self
            .batch_builder
            .as_mut()
            .expect("batch_builder is always Some between calls")
            .try_push(data)?
        {
            TryPush::Aggregated(handle) => handle,
            TryPush::NoCapacity(data) => {
                debug!(client=?self.client, "insufficient capacity in aggregator - flushing");
                self.flush(None)?;
                // Retry once against the freshly flushed (empty) aggregator.
                match self
                    .batch_builder
                    .as_mut()
                    .expect("batch_builder restored by flush")
                    .try_push(data)?
                {
                    TryPush::Aggregated(handle) => handle,
                    TryPush::NoCapacity(_) => {
                        // Did not fit even into an empty aggregator; this
                        // input can never be produced.
                        error!(client=?self.client, "record too large for aggregator");
                        return Err(Error::TooLarge);
                    }
                }
            }
        };

        // Only one caller per batch needs to drive the linger timer; everyone
        // else just waits on the shared handle.
        if self.has_linger_waiter {
            return Ok(CallerRole::JustWait(handle));
        }
        self.has_linger_waiter = true;
        Ok(CallerRole::Linger {
            handle,
            flush_token: self.flush_clock,
        })
    }

    /// Flush the current batch, spawning a background write task.
    ///
    /// `flusher_token` is the `flush_clock` value a lingering caller observed
    /// when it was elected; if it no longer matches, the batch was already
    /// flushed by somebody else and this call is a no-op.
    fn flush(&mut self, flusher_token: Option<usize>) -> Result<()> {
        if let Some(token) = flusher_token {
            if token != self.flush_clock {
                debug!(client=?self.client, "spurious batch flush call");
                return Ok(());
            }
        }
        debug!(client=?self.client, "flushing batch");

        let batch = self.batch_builder.take().expect("no batch to flush");
        let (new_builder, flush_task, maybe_err) =
            match batch.background_flush(Arc::clone(&self.client), self.compression) {
                FlushResult::Ok(b, flush_task) => (b, flush_task, None),
                FlushResult::Error(b, e) => {
                    error!(client=?self.client, error=%e, "failed to write record batch");
                    (b, None, Some(e))
                }
            };
        self.batch_builder = Some(new_builder);

        // Prune completed tasks so `pending_flushes` does not grow without
        // bound. `retain` suffices: `is_finished` only needs `&self`, so the
        // previous `retain_mut` was unnecessary.
        self.pending_flushes.retain(|t| !t.is_finished());
        if let Some(t) = flush_task {
            self.pending_flushes.push(t);
        }

        // Invalidate outstanding linger tokens and allow a new linger waiter
        // for the next batch.
        self.flush_clock = self.flush_clock.wrapping_add(1);
        self.has_linger_waiter = false;

        match maybe_err {
            Some(e) => Err(e),
            None => Ok(()),
        }
    }
}
/// Role assigned to a caller of [`BatchProducer::produce`] for the batch its
/// record ended up in.
enum CallerRole<A>
where
    A: Aggregator,
{
    /// Another caller is already driving the linger timer; just await the
    /// shared result handle.
    JustWait(ResultHandle<A>),
    /// This caller must sleep for the linger period and then flush the batch
    /// identified by `flush_token`.
    Linger {
        handle: ResultHandle<A>,
        flush_token: usize,
    },
}
/// Producer that aggregates records into batches and writes them once the
/// linger period elapses or the aggregator runs out of capacity.
#[derive(Debug)]
pub struct BatchProducer<A>
where
    A: aggregator::Aggregator,
{
    /// How long to wait for additional records before flushing a batch.
    linger: Duration,
    /// Shared mutable state; also cloned into spawned linger tasks.
    inner: Arc<parking_lot::Mutex<ProducerInner<A>>>,
}
impl<A> BatchProducer<A>
where
    A: aggregator::Aggregator,
{
    /// Write `data` to this producer, returning once the batch containing it
    /// has been flushed and its per-record status is known.
    ///
    /// The first caller of a fresh batch is elected to drive the linger
    /// timer; all other callers simply wait on the shared result handle.
    pub async fn produce(
        &self,
        data: A::Input,
    ) -> Result<<A as aggregator::AggregatorStatus>::Status> {
        // Push under the lock, but await outside of it.
        let role = {
            let mut inner = self.inner.lock();
            inner.try_push(data)?
        };
        match role {
            CallerRole::JustWait(mut handle) => {
                let status = handle.wait().await?;
                handle.result(status)
            }
            CallerRole::Linger {
                mut handle,
                flush_token,
            } => {
                // The linger timer is spawned as its own task so the flush
                // still happens even if this caller's future is dropped
                // before the timer fires (see `test_producer_aggregator_cancel`).
                let linger: JoinHandle<Result<(), Error>> = tokio::spawn({
                    let linger = self.linger;
                    let inner = Arc::clone(&self.inner);
                    async move {
                        tokio::time::sleep(linger).await;
                        // The token makes this a no-op if the batch was
                        // already flushed for another reason.
                        inner.lock().flush(Some(flush_token))?;
                        Ok(())
                    }
                });
                tokio::select! {
                    // Linger elapsed (or its flush errored) before the batch
                    // result arrived; fall through and wait for the result.
                    res = linger => res.expect("linger panic")?,
                    // Batch was flushed by someone else first.
                    r = handle.wait() => return handle.result(r?),
                }
                let status = handle.wait().await?;
                handle.result(status)
            }
        }
    }

    /// Flush the current batch immediately and wait for all outstanding
    /// background flushes to complete.
    pub async fn flush(&self) -> Result<()> {
        // Take ownership of the pending tasks so they can be awaited without
        // holding the lock.
        let outstanding = {
            let mut inner = self.inner.lock();
            debug!("Manual flush");
            inner.flush(None)?;
            std::mem::take(&mut inner.pending_flushes)
        };
        for t in outstanding.into_iter() {
            if !t.is_finished() {
                t.await.expect("flush task panic");
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::aggregator::{Aggregator, RecordAggregatorStatusDeaggregator, StatusDeaggregator};
use super::*;
use crate::client::error::RequestContext;
use crate::{
client::producer::aggregator::RecordAggregator, protocol::error::Error as ProtocolError,
};
use chrono::{TimeZone, Utc};
use futures::stream::{FuturesOrdered, FuturesUnordered};
use futures::{FutureExt, StreamExt, TryStreamExt, pin_mut};
/// Scripted [`ProducerClient`] used by the tests below.
#[derive(Debug)]
struct MockClient {
    // If set, every produce call fails with this protocol error.
    error: Option<ProtocolError>,
    // If set, every produce call panics with this message.
    panic: Option<String>,
    // Artificial latency applied before each produce call completes.
    delay: Duration,
    // Sizes of the batches successfully produced so far.
    batch_sizes: parking_lot::Mutex<Vec<usize>>,
}
impl ProducerClient for MockClient {
    fn produce(
        &self,
        records: Vec<Record>,
        _compression: Compression,
    ) -> BoxFuture<'_, Result<Vec<i64>, ClientError>> {
        Box::pin(async move {
            tokio::time::sleep(self.delay).await;

            // Scripted failure modes take precedence over success.
            if let Some(e) = self.error {
                return Err(ClientError::ServerError {
                    protocol_error: e,
                    error_message: None,
                    request: RequestContext::Partition("foo".into(), 1),
                    response: None,
                    is_virtual: false,
                });
            }
            if let Some(p) = self.panic.as_ref() {
                panic!("{}", p);
            }

            // Assign sequential offsets continuing from previous batches and
            // record this batch's size.
            let mut batch_sizes = self.batch_sizes.lock();
            let offset_base = batch_sizes.iter().sum::<usize>();
            let offsets = (0..records.len())
                .map(|x| (x + offset_base) as i64)
                .collect();
            batch_sizes.push(records.len());
            Ok(offsets)
        })
    }
}
/// Small record fixture shared by all tests.
fn record() -> Record {
    Record {
        key: Some(vec![0; 4]),
        value: Some(vec![0; 6]),
        headers: Default::default(),
        timestamp: Utc.timestamp_millis_opt(320).unwrap(),
    }
}
#[tokio::test]
async fn test_producer() {
    let record = record();
    let linger = Duration::from_millis(100);

    // Exercise both an instant and a slightly delayed mock client.
    for delay in [Duration::from_secs(0), Duration::from_millis(1)] {
        let client = Arc::new(MockClient {
            error: None,
            panic: None,
            delay,
            batch_sizes: Default::default(),
        });

        // Aggregator capacity of exactly two records.
        let aggregator = RecordAggregator::new(record.approximate_size() * 2);
        let producer = BatchProducerBuilder::new_with_client(Arc::<MockClient>::clone(&client))
            .with_linger(linger)
            .build(aggregator);

        let mut futures = FuturesOrdered::new();
        futures.push_back(producer.produce(record.clone()));
        futures.push_back(producer.produce(record.clone()));
        futures.push_back(producer.produce(record.clone()));

        // Helper unwrapping timeout -> stream item -> produce result.
        let assert_ok = |a: Result<Option<Result<_, _>>, _>, expected: i64| {
            let offset = a
                .expect("no timeout")
                .expect("Some future left")
                .expect("no producer error");
            assert_eq!(offset, expected);
        };

        // The first two records fill the aggregator, so the batch flushes
        // well before the linger would expire.
        assert_ok(
            tokio::time::timeout(Duration::from_millis(10), futures.next()).await,
            0,
        );
        assert_ok(
            tokio::time::timeout(Duration::from_millis(10), futures.next()).await,
            1,
        );

        // The third record sits in a fresh batch until the linger fires.
        tokio::time::timeout(Duration::from_millis(10), futures.next())
            .await
            .expect_err("timeout");
        assert_eq!(client.batch_sizes.lock().as_slice(), &[2]);

        assert_ok(tokio::time::timeout(linger * 2, futures.next()).await, 2);
        assert_eq!(client.batch_sizes.lock().as_slice(), &[2, 1]);
    }
}
#[tokio::test]
async fn test_manual_flush() {
    let record = record();
    // Linger effectively never fires; only the manual flush can trigger a
    // write.
    let linger = Duration::from_secs(3600);
    let client = Arc::new(MockClient {
        error: None,
        panic: None,
        delay: Duration::from_millis(1),
        batch_sizes: Default::default(),
    });
    let aggregator = RecordAggregator::new(usize::MAX);
    let producer = BatchProducerBuilder::new_with_client(Arc::<MockClient>::clone(&client))
        .with_linger(linger)
        .build(aggregator);

    let a = producer.produce(record.clone()).fuse();
    pin_mut!(a);
    let b = producer.produce(record).fuse();
    pin_mut!(b);

    // Neither produce future may complete while the batch is lingering.
    futures::select! {
        _ = a => panic!("a finished!"),
        _ = b => panic!("b finished!"),
        _ = tokio::time::sleep(Duration::from_millis(100)).fuse() => {}
    };

    producer.flush().await.unwrap();

    // After the manual flush both futures resolve promptly.
    let offset_a = tokio::time::timeout(Duration::from_millis(10), a)
        .await
        .unwrap()
        .unwrap();
    let offset_b = tokio::time::timeout(Duration::from_millis(10), b)
        .await
        .unwrap()
        .unwrap();
    // Offsets 0 and 1 are assigned, but the order within the batch is
    // unspecified.
    assert!(((offset_a == 0) && (offset_b == 1)) || ((offset_a == 1) && (offset_b == 0)));
}
#[tokio::test]
async fn test_producer_empty_aggregator_with_linger() {
    let record = record();
    let linger = Duration::from_millis(2);
    let client = Arc::new(MockClient {
        error: None,
        panic: None,
        delay: Duration::from_millis(0),
        batch_sizes: Default::default(),
    });

    // Aggregator that accepts everything but always flushes an empty record
    // set, exercising the empty-batch path.
    struct EmptyAgg {}
    impl Aggregator for EmptyAgg {
        type Input = Record;
        type Tag = ();
        type StatusDeaggregator = EmptyDeagg;

        fn try_push(
            &mut self,
            _record: Self::Input,
        ) -> Result<TryPush<Self::Input, Self::Tag>, aggregator::Error> {
            // Pretend every record was aggregated (and then discarded).
            Ok(TryPush::Aggregated(()))
        }

        fn flush(
            &mut self,
        ) -> Result<(Vec<Record>, Self::StatusDeaggregator), aggregator::Error> {
            Ok((vec![], EmptyDeagg {}))
        }
    }

    #[derive(Debug)]
    struct EmptyDeagg {}
    impl StatusDeaggregator for EmptyDeagg {
        type Status = ();
        type Tag = ();

        fn deaggregate(
            &self,
            _input: &[i64],
            _tag: Self::Tag,
        ) -> Result<Self::Status, aggregator::Error> {
            Ok(())
        }
    }

    let producer = BatchProducerBuilder::new_with_client(Arc::<MockClient>::clone(&client))
        .with_linger(linger)
        .build(EmptyAgg {});

    // All produce calls must still resolve even though nothing is ever
    // written to the client.
    let mut futures: FuturesUnordered<_> = (0..10)
        .map(|_| async {
            producer.produce(record.clone()).await.unwrap();
        })
        .collect();
    while futures.next().await.is_some() {}
}
#[tokio::test]
async fn test_producer_client_error() {
    let record = record();
    let linger = Duration::from_millis(5);
    // Client fails every produce call with a server-side error.
    let client = Arc::new(MockClient {
        error: Some(ProtocolError::NetworkException),
        panic: None,
        delay: Duration::from_millis(1),
        batch_sizes: Default::default(),
    });
    let aggregator = RecordAggregator::new(record.approximate_size() * 2);
    let producer = BatchProducerBuilder::new_with_client(Arc::<MockClient>::clone(&client))
        .with_linger(linger)
        .build(aggregator);

    let mut futures = FuturesUnordered::new();
    futures.push(producer.produce(record.clone()));
    futures.push(producer.produce(record.clone()));

    // Both records share the failed batch, so both callers observe the
    // error.
    futures.next().await.unwrap().unwrap_err();
    futures.next().await.unwrap().unwrap_err();
}
#[tokio::test]
async fn test_producer_aggregator_error_push() {
    let record = record();
    let linger = Duration::from_millis(5);
    let client = Arc::new(MockClient {
        error: None,
        panic: None,
        delay: Duration::from_millis(1),
        batch_sizes: Default::default(),
    });
    // The first push into the aggregator fails; later pushes succeed.
    let aggregator = MockAggregator {
        inner: RecordAggregator::new(record.approximate_size() * 2),
        push_errors: vec!["test".to_owned().into()],
        flush_errors: vec![],
        deagg_errors: vec![],
    };
    let producer = BatchProducerBuilder::new_with_client(Arc::<MockClient>::clone(&client))
        .with_linger(linger)
        .build(aggregator);

    let mut futures = FuturesUnordered::new();
    futures.push(producer.produce(record.clone()));
    futures.push(producer.produce(record.clone()));
    futures.push(producer.produce(record.clone()));

    // A push error only fails the caller that hit it; another caller still
    // succeeds.
    futures.next().await.unwrap().unwrap_err();
    futures.next().await.unwrap().unwrap();
}
#[tokio::test]
async fn test_producer_aggregator_error_flush() {
    let record = record();
    let linger = Duration::from_millis(5);
    let client = Arc::new(MockClient {
        error: None,
        panic: None,
        delay: Duration::from_millis(1),
        batch_sizes: Default::default(),
    });
    // The first aggregator flush fails.
    let aggregator = MockAggregator {
        inner: RecordAggregator::new(record.approximate_size() * 2),
        push_errors: vec![],
        flush_errors: vec!["test".to_owned().into()],
        deagg_errors: vec![],
    };
    let producer = BatchProducerBuilder::new_with_client(Arc::<MockClient>::clone(&client))
        .with_linger(linger)
        .build(aggregator);

    let mut futures = FuturesUnordered::new();
    futures.push(producer.produce(record.clone()));
    futures.push(producer.produce(record.clone()));

    // A flush error fails every record in the batch.
    futures.next().await.unwrap().unwrap_err();
    futures.next().await.unwrap().unwrap_err();
}
#[tokio::test]
async fn test_producer_aggregator_error_deagg() {
    let record = record();
    let linger = Duration::from_millis(5);
    let client = Arc::new(MockClient {
        error: None,
        panic: None,
        delay: Duration::from_millis(1),
        batch_sizes: Default::default(),
    });
    // Status extraction fails for the first record (tag 0) of the first
    // flush only.
    let aggregator = MockAggregator {
        inner: RecordAggregator::new(record.approximate_size() * 2),
        push_errors: vec![],
        flush_errors: vec![],
        deagg_errors: vec![vec![Some("test".to_owned().into()), None]],
    };
    let producer = BatchProducerBuilder::new_with_client(Arc::<MockClient>::clone(&client))
        .with_linger(linger)
        .build(aggregator);

    let futures = FuturesUnordered::new();
    futures.push(producer.produce(record.clone()));
    futures.push(producer.produce(record.clone()));

    // One record gets its offset; the other gets the deaggregation error.
    let mut results = futures.map_err(|e| e.to_string()).collect::<Vec<_>>().await;
    results.sort();
    assert_eq!(
        results,
        vec![Ok(1), Err("Aggregator error: test".to_owned())],
    );
}
#[tokio::test]
async fn test_producer_aggregator_cancel() {
    let record = record();
    let linger = Duration::from_micros(100);
    // Slow client so the flush is still in flight when `a` is dropped.
    let client = Arc::new(MockClient {
        error: None,
        panic: None,
        delay: Duration::from_millis(10),
        batch_sizes: Default::default(),
    });
    let aggregator = RecordAggregator::new(record.approximate_size() * 2);
    let producer = BatchProducerBuilder::new_with_client(Arc::<MockClient>::clone(&client))
        .with_linger(linger)
        .build(aggregator);

    let a = producer.produce(record.clone()).fuse();
    let b = producer.produce(record).fuse();
    pin_mut!(b);

    {
        // Poll both produce calls briefly, then drop `a` (the caller elected
        // to drive the linger) while the write is still in flight.
        pin_mut!(a);
        futures::select_biased! {
            _ = &mut a => panic!("a should not have flushed"),
            _ = &mut b => panic!("b should not have flushed"),
            _ = tokio::time::sleep(Duration::from_millis(1)).fuse() => {},
        }
    }

    // Dropping `a` must not cancel the flush: `b` still completes and the
    // written batch still contains both records.
    tokio::time::timeout(Duration::from_secs(1), b)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(client.batch_sizes.lock().as_slice(), &[2]);
}
#[tokio::test]
async fn test_producer_aggregator_panic() {
    let record = record();
    let linger = Duration::from_millis(100);
    // The client panics inside the background flush task.
    let client = Arc::new(MockClient {
        error: None,
        panic: Some("test panic".into()),
        delay: Duration::from_millis(0),
        batch_sizes: Default::default(),
    });
    let aggregator = RecordAggregator::new(record.approximate_size() * 2);
    let producer = BatchProducerBuilder::new_with_client(Arc::<MockClient>::clone(&client))
        .with_linger(linger)
        .build(aggregator);

    let a = producer.produce(record.clone());
    let b = producer.produce(record);
    let (a, b) = futures::future::join(a, b).await;

    // The panic is surfaced to both callers as a flush error rather than
    // hanging them or propagating the panic.
    assert!(matches!(&a, Err(Error::FlushError(_))));
    assert!(matches!(&b, Err(Error::FlushError(_))));
}
/// [`RecordAggregator`] wrapper that injects scripted errors at each stage
/// of the aggregation pipeline.
#[derive(Debug)]
struct MockAggregator {
    inner: RecordAggregator,
    // Errors returned from `try_push`, consumed front to back.
    push_errors: Vec<aggregator::Error>,
    // Errors returned from `flush`, consumed front to back.
    flush_errors: Vec<aggregator::Error>,
    // Per-flush, per-tag errors handed to the returned deaggregator.
    deagg_errors: Vec<Vec<Option<aggregator::Error>>>,
}
impl Aggregator for MockAggregator {
    type Input = Record;
    type Tag = usize;
    type StatusDeaggregator = MockDeaggregator;

    fn try_push(
        &mut self,
        record: Self::Input,
    ) -> Result<TryPush<Self::Input, Self::Tag>, aggregator::Error> {
        // Consume scripted push errors first, then delegate.
        if !self.push_errors.is_empty() {
            return Err(self.push_errors.remove(0));
        }
        Ok(self.inner.try_push(record).unwrap())
    }

    fn flush(&mut self) -> Result<(Vec<Record>, Self::StatusDeaggregator), aggregator::Error> {
        // Consume scripted flush errors first, then delegate.
        if !self.flush_errors.is_empty() {
            return Err(self.flush_errors.remove(0));
        }
        // Attach this flush's scripted deaggregation errors (if any) to the
        // returned deaggregator.
        let deagg_errors = if self.deagg_errors.is_empty() {
            vec![]
        } else {
            self.deagg_errors.remove(0)
        };
        let (records, deagg) = self.inner.flush().unwrap();
        Ok((
            records,
            MockDeaggregator {
                inner: deagg,
                errors: std::sync::Mutex::new(deagg_errors),
            },
        ))
    }
}
/// Deaggregator wrapper that returns scripted per-tag errors (each at most
/// once) before delegating to the real implementation.
#[derive(Debug)]
struct MockDeaggregator {
    inner: RecordAggregatorStatusDeaggregator,
    // One optional scripted error per tag; a `Mutex` gives interior
    // mutability behind the `&self` of `deaggregate`.
    errors: std::sync::Mutex<Vec<Option<aggregator::Error>>>,
}
impl StatusDeaggregator for MockDeaggregator {
    type Status = i64;
    type Tag = usize;

    /// Return the scripted error for `tag` (at most once), otherwise
    /// delegate to the real deaggregator.
    fn deaggregate(
        &self,
        input: &[i64],
        tag: Self::Tag,
    ) -> Result<Self::Status, aggregator::Error> {
        let scripted = self
            .errors
            .lock()
            .unwrap()
            .get_mut(tag)
            .and_then(Option::take);
        if let Some(e) = scripted {
            return Err(e);
        }
        Ok(self.inner.deaggregate(input, tag).unwrap())
    }
}
}