use std::collections::VecDeque;
use std::ops::Range;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use futures::Stream;
use futures::future::{BoxFuture, Fuse, FusedFuture, FutureExt};
use tracing::{debug, trace, warn};
use crate::{
client::{
error::{Error, ProtocolError, Result},
partition::PartitionClient,
},
record::RecordAndOffset,
};
use super::partition::OffsetAt;
/// Where a [`StreamConsumer`] begins reading when it has no known next offset.
#[derive(Debug, Clone, Copy)]
pub enum StartOffset {
/// Resolve the starting offset by asking the client for `OffsetAt::Earliest`.
Earliest,
/// Resolve the starting offset by asking the client for `OffsetAt::Latest`.
Latest,
/// Start at exactly this offset; no offset lookup is performed.
At(i64),
}
/// Collects the settings for a [`StreamConsumer`] before it is built.
#[derive(Debug)]
pub struct StreamConsumerBuilder {
/// Client used for offset lookups and record fetches.
client: Arc<dyn FetchClient>,
/// Where to start when no next offset is known yet.
start_offset: StartOffset,
/// Forwarded to every fetch as its `max_wait_ms` argument.
max_wait_ms: i32,
/// Start of the `bytes` range forwarded to every fetch.
min_batch_size: i32,
/// End of the `bytes` range forwarded to every fetch.
max_batch_size: i32,
}
impl StreamConsumerBuilder {
pub fn new(client: Arc<PartitionClient>, start_offset: StartOffset) -> Self {
Self::new_with_client(client, start_offset)
}
fn new_with_client(client: Arc<dyn FetchClient>, start_offset: StartOffset) -> Self {
Self {
client,
start_offset,
max_wait_ms: 500,
min_batch_size: 1,
max_batch_size: 52428800,
}
}
pub fn with_min_batch_size(self, min_batch_size: i32) -> Self {
Self {
min_batch_size,
..self
}
}
pub fn with_max_batch_size(self, max_batch_size: i32) -> Self {
Self {
max_batch_size,
..self
}
}
pub fn with_max_wait_ms(self, max_wait_ms: i32) -> Self {
Self {
max_wait_ms,
..self
}
}
pub fn build(self) -> StreamConsumer {
StreamConsumer {
client: self.client,
max_wait_ms: self.max_wait_ms,
min_batch_size: self.min_batch_size,
max_batch_size: self.max_batch_size,
next_offset: None,
next_backoff: None,
start_offset: self.start_offset,
terminated: false,
last_high_watermark: -1,
buffer: Default::default(),
fetch_fut: Fuse::terminated(),
}
}
}
/// Successful outcome of one fetch future created by [`StreamConsumer`].
struct FetchResultOk {
/// Records returned by the fetch, together with their offsets.
records_and_offsets: Vec<RecordAndOffset>,
/// Second element of the tuple returned by `FetchClient::fetch_records`;
/// the consumer stores it as its last high watermark.
watermark: i64,
/// Offset the fetch was issued at.
used_offset: i64,
}
/// Output type of the in-flight fetch future held by [`StreamConsumer`].
type FetchResult = Result<FetchResultOk>;
/// Abstraction over the two client operations the stream consumer needs.
///
/// Implemented for `PartitionClient` and for the mock used by the tests in this file.
trait FetchClient: std::fmt::Debug + Send + Sync {
/// Fetches records starting at `offset`.
///
/// `bytes` and `max_wait_ms` are forwarded from the consumer's batch-size and
/// wait settings. On success returns the records plus an `i64` that the
/// consumer treats as the high watermark.
fn fetch_records(
&self,
offset: i64,
bytes: Range<i32>,
max_wait_ms: i32,
) -> BoxFuture<'_, Result<(Vec<RecordAndOffset>, i64)>>;
/// Resolves the offset described by `at`.
fn get_offset(&self, at: OffsetAt) -> BoxFuture<'_, Result<i64>>;
}
/// Adapts `PartitionClient` to [`FetchClient`] by boxing the futures of its
/// same-named methods.
impl FetchClient for PartitionClient {
    /// Boxes the future returned by `self.fetch_records(..)`.
    ///
    /// NOTE(review): this relies on method resolution picking the inherent
    /// `PartitionClient::fetch_records` over this trait method — confirm the
    /// inherent method exists, otherwise this call would recurse.
    fn fetch_records(
        &self,
        offset: i64,
        bytes: Range<i32>,
        max_wait_ms: i32,
    ) -> BoxFuture<'_, Result<(Vec<RecordAndOffset>, i64)>> {
        self.fetch_records(offset, bytes, max_wait_ms).boxed()
    }

    /// Boxes the future returned by `self.get_offset(at)`.
    ///
    /// NOTE(review): same inherent-method assumption as `fetch_records`.
    fn get_offset(&self, at: OffsetAt) -> BoxFuture<'_, Result<i64>> {
        self.get_offset(at).boxed()
    }
}
/// A `Stream` of records fetched from a [`FetchClient`].
///
/// Each item is a record paired with the high watermark of the most recent
/// successful fetch. After yielding an error the stream ends.
pub struct StreamConsumer {
/// Client used for offset lookups and record fetches.
client: Arc<dyn FetchClient>,
/// Start of the `bytes` range passed to each fetch.
min_batch_size: i32,
/// End of the `bytes` range passed to each fetch.
max_batch_size: i32,
/// `max_wait_ms` argument passed to each fetch.
max_wait_ms: i32,
/// Used to pick the fetch offset whenever `next_offset` is `None`.
start_offset: StartOffset,
/// Offset of the next fetch. `None` until a fetch has succeeded, and reset to
/// `None` when an out-of-range error forces the start offset to be resolved again.
next_offset: Option<i64>,
/// Delay to sleep before the next fetch; taken (and thereby cleared) when that
/// fetch future is created.
next_backoff: Option<Duration>,
/// Set once an error has been yielded; from then on the stream returns `None`.
terminated: bool,
/// Watermark reported by the most recent successful fetch.
last_high_watermark: i64,
/// Fetched records that have not been yielded yet.
buffer: VecDeque<RecordAndOffset>,
/// The fetch currently in flight. A terminated fuse means there is none and the
/// next poll has to create one.
fetch_fut: Fuse<BoxFuture<'static, FetchResult>>,
}
impl Stream for StreamConsumer {
type Item = Result<(RecordAndOffset, i64)>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
if self.terminated {
return Poll::Ready(None);
}
if let Some(x) = self.buffer.pop_front() {
return Poll::Ready(Some(Ok((x, self.last_high_watermark))));
}
if self.fetch_fut.is_terminated() {
let next_offset = self.next_offset;
let start_offset = self.start_offset;
let bytes = (self.min_batch_size)..(self.max_batch_size);
let max_wait_ms = self.max_wait_ms;
let next_backoff = std::mem::take(&mut self.next_backoff);
let client = Arc::clone(&self.client);
trace!(?start_offset, ?next_offset, "Fetching records at offset");
self.fetch_fut = FutureExt::fuse(Box::pin(async move {
if let Some(backoff) = next_backoff {
tokio::time::sleep(backoff).await;
}
let offset = match next_offset {
Some(x) => x,
None => match start_offset {
StartOffset::Earliest => {
let offset = client.get_offset(OffsetAt::Earliest).await?;
debug!(offset, "resolved `earliest` offset");
offset
}
StartOffset::Latest => {
let offset = client.get_offset(OffsetAt::Latest).await?;
debug!(offset, "resolved `latest` offset");
offset
}
StartOffset::At(x) => x,
},
};
let (records_and_offsets, watermark) =
client.fetch_records(offset, bytes, max_wait_ms).await?;
Ok(FetchResultOk {
records_and_offsets,
watermark,
used_offset: offset,
})
}));
}
let data: FetchResult = futures::ready!(self.fetch_fut.poll_unpin(cx));
match (data, self.start_offset) {
(Ok(inner), _) => {
let FetchResultOk {
mut records_and_offsets,
watermark,
used_offset,
} = inner;
trace!(
high_watermark = watermark,
n_records = records_and_offsets.len(),
"Received records and a high watermark",
);
self.next_offset = Some(used_offset);
records_and_offsets.sort_by_key(|x| x.offset);
self.last_high_watermark = watermark;
if let Some(x) = records_and_offsets.last() {
self.next_offset = Some(x.offset + 1);
self.buffer.extend(records_and_offsets)
}
continue;
}
(
Err(Error::ServerError {
protocol_error: ProtocolError::OffsetOutOfRange,
..
}),
StartOffset::Earliest | StartOffset::Latest,
) => {
self.next_offset = None;
let backoff_secs = 1;
warn!(
start_offset=?self.start_offset,
backoff_secs,
"Records are gone between ListOffsets and Fetch, backoff a bit",
);
self.next_backoff = Some(Duration::from_secs(backoff_secs));
continue;
}
(Err(e), _) => {
self.terminated = true;
return Poll::Ready(Some(Err(e)));
}
}
}
}
}
/// Manual `Debug` implementation: every field is printed except `fetch_fut`,
/// and the output is marked non-exhaustive to signal that omission.
impl std::fmt::Debug for StreamConsumer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StreamConsumer")
            .field("client", &self.client)
            .field("min_batch_size", &self.min_batch_size)
            .field("max_batch_size", &self.max_batch_size)
            .field("max_wait_ms", &self.max_wait_ms)
            // `start_offset` and `next_backoff` decide how the next fetch behaves
            // (where it starts, whether it is delayed), so they belong in the
            // diagnostic output alongside `next_offset`.
            .field("start_offset", &self.start_offset)
            .field("next_offset", &self.next_offset)
            .field("next_backoff", &self.next_backoff)
            .field("terminated", &self.terminated)
            .field("last_high_watermark", &self.last_high_watermark)
            .field("buffer", &self.buffer)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use assert_matches::assert_matches;
use chrono::{TimeZone, Utc};
use futures::{StreamExt, pin_mut};
use tokio::sync::{Mutex, mpsc};
use crate::{
client::error::{Error, ProtocolError, RequestContext},
record::Record,
};
use super::*;
/// Test double for [`FetchClient`] whose records come from an mpsc channel.
#[derive(Debug)]
struct MockFetch {
/// Shared state; locked for the duration of each mocked fetch.
inner: Arc<Mutex<MockFetchInner>>,
}
/// Mutable state behind [`MockFetch`].
#[derive(Debug)]
struct MockFetchInner {
/// Number of records returned by each completed fetch, in call order.
batch_sizes: Vec<usize>,
/// Source of new records, fed by the test.
stream: mpsc::Receiver<Record>,
/// Error returned (once) by the next fetch, if any.
next_err: Option<Error>,
/// Every record received so far; a record's index is used as its offset.
buffer: Vec<Record>,
/// Offsets reported for `OffsetAt::Earliest` (`.0`) and `OffsetAt::Latest` (`.1`).
range: (i64, i64),
}
impl MockFetch {
    /// Builds a mock that reads records from `stream`, fails its first fetch
    /// with `next_err` (when given) and answers offset lookups from `range`.
    fn new(stream: mpsc::Receiver<Record>, next_err: Option<Error>, range: (i64, i64)) -> Self {
        let state = MockFetchInner {
            batch_sizes: Vec::new(),
            stream,
            next_err,
            buffer: Vec::new(),
            range,
        };

        Self {
            inner: Arc::new(Mutex::new(state)),
        }
    }

    /// Returns a copy of the per-fetch record counts recorded so far.
    ///
    /// Waits for the state lock, so it resolves only once no fetch holds it.
    async fn batch_sizes(&self) -> Vec<usize> {
        let state = self.inner.lock().await;
        state.batch_sizes.clone()
    }
}
// Mock implementation: records are numbered by their position in the internal
// buffer, and that position is reported as the record's offset.
impl FetchClient for MockFetch {
// Returns buffered records from `start_offset` onwards. While fewer than
// `bytes.start` bytes have been collected it keeps waiting for new records,
// for at most `max_wait_ms` milliseconds. Never returns more than `bytes.end`
// bytes, and panics if not even one record fits.
fn fetch_records(
&self,
start_offset: i64,
bytes: Range<i32>,
max_wait_ms: i32,
) -> BoxFuture<'_, Result<(Vec<RecordAndOffset>, i64)>> {
let inner = Arc::clone(&self.inner);
Box::pin(async move {
// A configured error is handed out exactly once (`take`), before any records.
if let Some(err) = inner.lock().await.next_err.take() {
return Err(err);
}
println!("MockFetch::fetch_records");
// From here on the state stays locked until the fetch completes.
let mut inner = inner.lock().await;
println!("MockFetch::fetch_records locked");
let mut buffer = vec![];
let mut buffered = 0;
// Move everything already queued in the channel into the record buffer.
while let Ok(x) = inner.stream.try_recv() {
inner.buffer.push(x)
}
// Collect already-buffered records at or after `start_offset` until the next
// one would exceed the upper byte bound.
for (record_offset, record) in
inner.buffer.iter().enumerate().skip(start_offset as usize)
{
let size = record.approximate_size();
if size + buffered > bytes.end as usize {
assert_ne!(buffered, 0, "record too large");
break;
}
buffer.push(RecordAndOffset {
record: record.clone(),
offset: record_offset as i64,
});
buffered += size;
}
println!("Waiting up to {} ms for more data", max_wait_ms);
let timeout = tokio::time::sleep(Duration::from_millis(max_wait_ms as u64)).fuse();
pin_mut!(timeout);
// Below the lower byte bound: wait for more records or for the timeout.
while buffered < bytes.start as usize && !timeout.is_terminated() {
futures::select! {
maybe_record = inner.stream.recv().fuse() => match maybe_record {
Some(record) => {
println!("Received a new record");
let size = record.approximate_size();
let record_offset = inner.buffer.len() as i64;
inner.buffer.push(record.clone());
// Records before the requested offset are stored but not returned.
if record_offset < start_offset {
continue
}
if size + buffered > bytes.end as usize {
assert_ne!(buffered, 0, "record too large");
break;
}
buffer.push(RecordAndOffset {
record,
offset: record_offset,
});
buffered += size
}
// Channel closed: no more records will arrive.
None => break,
},
_ = timeout => {
println!("Timeout receiving records");
break
},
}
}
inner.batch_sizes.push(buffer.len());
// The reported watermark is the index of the last stored record (-1 when empty).
Ok((buffer, inner.buffer.len() as i64 - 1))
})
}
// Answers `Earliest` / `Latest` lookups from the configured `range`;
// timestamp lookups are not supported by the mock.
fn get_offset(&self, at: OffsetAt) -> BoxFuture<'_, Result<i64>> {
let inner = Arc::clone(&self.inner);
Box::pin(async move {
match at {
OffsetAt::Earliest => Ok(inner.lock().await.range.0),
OffsetAt::Latest => Ok(inner.lock().await.range.1),
OffsetAt::Timestamp(_) => {
unreachable!("timestamp based offset is tested in e2e test")
}
}
})
}
}
// Consuming from a fixed start offset: records below the offset are skipped,
// later records are yielded in order together with the current high watermark.
#[tokio::test]
async fn test_consumer() {
let record = Record {
key: Some(vec![0; 4]),
value: Some(vec![0; 6]),
headers: Default::default(),
timestamp: Utc.timestamp_millis_opt(1337).unwrap(),
};
let (sender, receiver) = mpsc::channel(10);
let consumer = Arc::new(MockFetch::new(receiver, None, (0, 1_000)));
let mut stream = StreamConsumerBuilder::new_with_client(
Arc::<MockFetch>::clone(&consumer),
StartOffset::At(2),
)
.with_max_wait_ms(10)
.build();
// Nothing has been produced yet.
assert_stream_pending(&mut stream).await;
// The mock numbers these two records 0 and 1, i.e. below the start offset of 2.
sender.send(record.clone()).await.unwrap();
sender.send(record.clone()).await.unwrap();
assert_stream_pending(&mut stream).await;
// Third record gets offset 2 and is the first one the consumer should see.
sender.send(record.clone()).await.unwrap();
let unwrap = |e: Result<Option<Result<_, _>>, _>| e.unwrap().unwrap().unwrap();
let (record_and_offset, high_watermark) =
unwrap(tokio::time::timeout(Duration::from_micros(10), stream.next()).await);
assert_eq!(record_and_offset.offset, 2);
assert_eq!(high_watermark, 2);
// Three more records (offsets 3, 4 and 5).
sender.send(record.clone()).await.unwrap();
sender.send(record.clone()).await.unwrap();
sender.send(record.clone()).await.unwrap();
let (record_and_offset, high_watermark) =
unwrap(tokio::time::timeout(Duration::from_micros(1), stream.next()).await);
assert_eq!(record_and_offset.offset, 3);
assert_eq!(high_watermark, 5);
let (record_and_offset, high_watermark) =
tokio::time::timeout(Duration::from_millis(1), stream.next())
.await
.unwrap()
.unwrap()
.unwrap();
assert_eq!(record_and_offset.offset, 4);
assert_eq!(high_watermark, 5);
let (record_and_offset, high_watermark) =
tokio::time::timeout(Duration::from_millis(1), stream.next())
.await
.unwrap()
.unwrap()
.unwrap();
assert_eq!(record_and_offset.offset, 5);
assert_eq!(high_watermark, 5);
// Expect exactly two completed fetches: one record, then the three sent together.
let received = consumer.batch_sizes().await;
assert_eq!(&received, &[1, 3]);
}
// With a minimum batch size of two records, fetches that time out with fewer
// records are still completed, and whatever they returned is yielded.
#[tokio::test]
async fn test_consumer_timeout() {
let record = Record {
key: Some(vec![0; 4]),
value: Some(vec![0; 6]),
headers: Default::default(),
timestamp: Utc.timestamp_millis_opt(1337).unwrap(),
};
let (sender, receiver) = mpsc::channel(10);
let consumer = Arc::new(MockFetch::new(receiver, None, (0, 1_000)));
assert!(consumer.batch_sizes().await.is_empty());
// Minimum batch: two records' worth of bytes; maximum: three records' worth.
let mut stream = StreamConsumerBuilder::new_with_client(
Arc::<MockFetch>::clone(&consumer),
StartOffset::At(0),
)
.with_min_batch_size((record.approximate_size() * 2) as i32)
.with_max_batch_size((record.approximate_size() * 3) as i32)
.with_max_wait_ms(5)
.build();
assert_stream_pending(&mut stream).await;
// `batch_sizes()` has to wait for the mock's lock, which an in-flight fetch
// holds, so keep polling the stream until it resolves.
let received = tokio::select! {
_ = stream.next() => panic!("stream returned!"),
x = consumer.batch_sizes() => x,
};
// At least one fetch has completed, and every completed fetch was empty.
assert!(!received.is_empty());
assert!(received.iter().all(|x| *x == 0));
// One record is below the minimum batch size; it should be yielded once a
// fetch has timed out.
sender.send(record.clone()).await.unwrap();
tokio::time::timeout(Duration::from_millis(10), stream.next())
.await
.unwrap()
.unwrap()
.unwrap();
// Two records reach the minimum batch size, so both should be available
// almost immediately.
sender.send(record.clone()).await.unwrap();
sender.send(record.clone()).await.unwrap();
tokio::time::timeout(Duration::from_micros(1), stream.next())
.await
.unwrap()
.unwrap()
.unwrap();
tokio::time::timeout(Duration::from_micros(1), stream.next())
.await
.unwrap()
.unwrap()
.unwrap();
}
// With a fixed start offset, an `OffsetOutOfRange` error is not retried:
// the stream yields the error once and then ends.
#[tokio::test]
async fn test_consumer_terminate() {
let e = Error::ServerError {
protocol_error: ProtocolError::OffsetOutOfRange,
error_message: None,
request: RequestContext::Partition("foo".into(), 1),
response: None,
is_virtual: true,
};
// Keep the sender alive so the mock's channel is not closed.
let (_sender, receiver) = mpsc::channel(10);
let consumer = Arc::new(MockFetch::new(receiver, Some(e), (0, 1_000)));
let mut stream =
StreamConsumerBuilder::new_with_client(consumer, StartOffset::At(0)).build();
let error = stream.next().await.expect("stream not empty").unwrap_err();
assert_matches!(
error,
Error::ServerError {
protocol_error: ProtocolError::OffsetOutOfRange,
..
}
);
// After the error the stream is terminated.
assert!(stream.next().await.is_none());
}
// With `StartOffset::Earliest`, an `OffsetOutOfRange` error on the first fetch
// is retried instead of ending the stream.
#[tokio::test]
async fn test_consumer_earliest() {
let record = Record {
key: Some(vec![0; 4]),
value: Some(vec![0; 6]),
headers: Default::default(),
timestamp: Utc.timestamp_millis_opt(1337).unwrap(),
};
let e = Error::ServerError {
protocol_error: ProtocolError::OffsetOutOfRange,
error_message: None,
request: RequestContext::Partition("foo".into(), 1),
response: None,
is_virtual: true,
};
let (sender, receiver) = mpsc::channel(10);
// The mock reports 2 as the earliest offset and fails its first fetch with `e`.
let consumer = Arc::new(MockFetch::new(receiver, Some(e), (2, 1_000)));
let mut stream = StreamConsumerBuilder::new_with_client(
Arc::<MockFetch>::clone(&consumer),
StartOffset::Earliest,
)
.with_max_wait_ms(10)
.build();
assert_stream_pending(&mut stream).await;
// Offsets 0 and 1 are below the earliest offset and must not be yielded.
sender.send(record.clone()).await.unwrap();
sender.send(record.clone()).await.unwrap();
assert_stream_pending(&mut stream).await;
sender.send(record.clone()).await.unwrap();
let unwrap = |e: Result<Option<Result<_, _>>, _>| e.unwrap().unwrap().unwrap();
// Allow two seconds: the consumer backs off for one second before retrying.
let (record_and_offset, high_watermark) =
unwrap(tokio::time::timeout(Duration::from_secs(2), stream.next()).await);
assert_eq!(record_and_offset.offset, 2);
assert_eq!(high_watermark, 2);
}
// With `StartOffset::Latest`, an `OffsetOutOfRange` error on the first fetch
// is retried instead of ending the stream.
#[tokio::test]
async fn test_consumer_latest() {
let record = Record {
key: Some(vec![0; 4]),
value: Some(vec![0; 6]),
headers: Default::default(),
timestamp: Utc.timestamp_millis_opt(1337).unwrap(),
};
let e = Error::ServerError {
protocol_error: ProtocolError::OffsetOutOfRange,
error_message: None,
request: RequestContext::Partition("foo".into(), 1),
response: None,
is_virtual: true,
};
let (sender, receiver) = mpsc::channel(10);
// The mock reports 2 as the latest offset and fails its first fetch with `e`.
let consumer = Arc::new(MockFetch::new(receiver, Some(e), (0, 2)));
let mut stream = StreamConsumerBuilder::new_with_client(
Arc::<MockFetch>::clone(&consumer),
StartOffset::Latest,
)
.with_max_wait_ms(10)
.build();
assert_stream_pending(&mut stream).await;
// Offsets 0 and 1 are below the latest offset and must not be yielded.
sender.send(record.clone()).await.unwrap();
sender.send(record.clone()).await.unwrap();
assert_stream_pending(&mut stream).await;
sender.send(record.clone()).await.unwrap();
let unwrap = |e: Result<Option<Result<_, _>>, _>| e.unwrap().unwrap().unwrap();
// Allow two seconds: the consumer backs off for one second before retrying.
let (record_and_offset, high_watermark) =
unwrap(tokio::time::timeout(Duration::from_secs(2), stream.next()).await);
assert_eq!(record_and_offset.offset, 2);
assert_eq!(high_watermark, 2);
}
async fn assert_stream_pending<S>(stream: &mut S)
where
S: Stream + Send + Unpin,
S::Item: std::fmt::Debug,
{
tokio::select! {
e = stream.next() => panic!("stream is not pending, yielded: {e:?}"),
_ = tokio::time::sleep(Duration::from_millis(1)) => {},
};
}
}