use std::{
error::Error,
io,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use bytes::{Buf, Bytes, BytesMut};
use corro_api_types::{ChangeId, QueryEvent, TypedNotifyEvent, TypedQueryEvent};
use futures::{ready, Future, Stream};
use pin_project_lite::pin_project;
use serde::de::DeserializeOwned;
use tokio::time::{sleep, Sleep};
use tokio_util::{
codec::{Decoder, FramedRead, LinesCodecError},
io::StreamReader,
};
use tracing::error;
use uuid::Uuid;
// Adapts a `reqwest::Body` into a `Stream` of `io::Result<Bytes>` (see its
// `Stream` impl) so it can be fed to `tokio_util::io::StreamReader` and then
// framed into lines.
pin_project! {
    pub struct IoBodyStream {
        // Structurally pinned: `http_body::Body::poll_frame` takes `Pin<&mut Self>`.
        #[pin]
        body: reqwest::Body
    }
}
impl Stream for IoBodyStream {
    type Item = io::Result<Bytes>;

    /// Yields the next data frame of the wrapped body as `Bytes`.
    ///
    /// A frame that is not a data frame (e.g. trailers) is reported as an
    /// `io::Error` with the message `"not a data frame"`. A body error whose
    /// source is an `io::Error` is re-created from that error's kind only;
    /// any other body error is wrapped with `io::Error::other`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use http_body::Body;

        let Some(frame) = ready!(self.project().body.poll_frame(cx)) else {
            return Poll::Ready(None);
        };

        let item = match frame {
            Ok(frame) => frame
                .into_data()
                .map_err(|_| io::Error::other("not a data frame")),
            Err(err) => {
                // Only the kind of an underlying io::Error survives; the
                // original message is not carried over.
                let io_kind = err
                    .source()
                    .and_then(|cause| cause.downcast_ref::<io::Error>())
                    .map(io::Error::kind);
                Err(match io_kind {
                    Some(kind) => io::Error::from(kind),
                    None => io::Error::other(err),
                })
            }
        };

        Poll::Ready(Some(item))
    }
}
/// Async reader over the data frames of a response body.
type IoBodyStreamReader = StreamReader<IoBodyStream, Bytes>;
/// Response body split into newline-delimited frames by [`LinesBytesCodec`].
type FramedBody = FramedRead<IoBodyStreamReader, LinesBytesCodec>;
/// Boxed in-flight HTTP request (the result of `RequestBuilder::send`), stored
/// while a subscription is being re-requested.
type ResponseFuture =
    Box<dyn Future<Output = Result<reqwest::Response, reqwest::Error>> + Unpin + Send + Sync>;
/// Stream of typed events for a subscription.
///
/// After an I/O error it drops the current body and re-requests the
/// subscription, resuming from the last change id it observed.
pub struct SubscriptionStream<T> {
    /// Subscription id; part of the re-request URL.
    id: Uuid,
    /// Optional hash supplied at construction; exposed via `hash()` but not
    /// used when re-requesting.
    hash: Option<String>,
    /// HTTP client used to re-issue the subscription request.
    client: reqwest::Client,
    /// Address of the HTTP API, used to build the re-request URL.
    api_addr: SocketAddr,
    /// True once the end-of-query event was seen, or when a starting change
    /// id was supplied. Re-requesting is refused while this is false.
    observed_eoq: bool,
    /// Last change id observed; sent as `from=` when re-requesting and used
    /// to detect gaps in the change sequence.
    last_change_id: Option<ChangeId>,
    /// Current line-framed body; set to `None` after an I/O error until a new
    /// request succeeds.
    stream: Option<FramedBody>,
    /// Retry delay currently being waited out, if any.
    backoff: Option<Pin<Box<Sleep>>>,
    /// Consecutive I/O-error retries; reset by any non-I/O-error result.
    backoff_count: u32,
    /// In-flight re-request, if any.
    response: Option<ResponseFuture>,
    _deser: std::marker::PhantomData<T>,
}
/// Errors yielded by [`SubscriptionStream`].
#[derive(Debug, thiserror::Error)]
pub enum SubscriptionError {
    /// I/O failure while reading the body or sending a re-request.
    /// `SubscriptionStream` retries on this variant.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// Error from the `http` crate.
    #[error(transparent)]
    Http(#[from] http::Error),
    /// A line could not be deserialized into the expected event type.
    #[error(transparent)]
    Deserialize(#[from] serde_json::Error),
    /// A change id arrived that is not exactly one more than the previous one.
    #[error("missed a change (expected: {expected}, got: {got}), inconsistent state")]
    MissedChange { expected: ChangeId, got: ChangeId },
    /// A line exceeded the codec's maximum length.
    #[error("max line length exceeded")]
    MaxLineLengthExceeded,
    /// A re-request was needed before the initial query's end-of-query event
    /// was seen, so there is no position to resume from.
    #[error("initial query never finished")]
    UnfinishedQuery,
    /// Too many consecutive I/O errors while retrying.
    #[error("max retry attempts exceeded")]
    MaxRetryAttempts,
}
impl<T> SubscriptionStream<T>
where
T: DeserializeOwned + Unpin,
{
pub fn new(
id: Uuid,
hash: Option<String>,
client: reqwest::Client,
api_addr: SocketAddr,
body: reqwest::Body,
change_id: Option<ChangeId>,
) -> Self {
Self {
id,
hash,
client,
api_addr,
observed_eoq: change_id.is_some(),
last_change_id: change_id,
stream: Some(FramedRead::new(
StreamReader::new(IoBodyStream { body }),
LinesBytesCodec::default(),
)),
backoff: None,
backoff_count: 0,
response: None,
_deser: Default::default(),
}
}
pub fn id(&self) -> Uuid {
self.id
}
pub fn hash(&self) -> Option<&str> {
self.hash.as_deref()
}
pub fn api_addr(&self) -> SocketAddr {
self.api_addr
}
fn poll_stream(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<TypedQueryEvent<T>, SubscriptionError>>> {
let stream = loop {
match self.stream.as_mut() {
None => match ready!(self.as_mut().poll_request(cx)) {
Ok(stream) => {
self.stream = Some(stream);
}
Err(e) => return Poll::Ready(Some(Err(e))),
},
Some(stream) => {
break stream;
}
}
};
let res = ready!(Pin::new(stream).poll_next(cx));
match res {
Some(Ok(b)) => match serde_json::from_slice(&b) {
Ok(evt) => {
if let TypedQueryEvent::EndOfQuery { change_id, .. } = &evt {
self.handle_eoq(*change_id);
}
if let TypedQueryEvent::Change(_, _, _, change_id) = &evt {
if let Err(e) = self.handle_change(*change_id) {
return Poll::Ready(Some(Err(e)));
}
}
Poll::Ready(Some(Ok(evt)))
}
Err(deser_err) => {
if let Ok(evt) = serde_json::from_slice::<QueryEvent>(&b) {
if let TypedQueryEvent::EndOfQuery { change_id, .. } = &evt {
self.handle_eoq(*change_id);
}
if let TypedQueryEvent::Change(_, _, _, change_id) = &evt {
if let Err(e) = self.handle_change(*change_id) {
return Poll::Ready(Some(Err(e)));
}
}
}
Poll::Ready(Some(Err(deser_err.into())))
}
},
Some(Err(e)) => match e {
LinesCodecError::MaxLineLengthExceeded => {
Poll::Ready(Some(Err(SubscriptionError::MaxLineLengthExceeded)))
}
LinesCodecError::Io(io_err) => Poll::Ready(Some(Err(io_err.into()))),
},
None => Poll::Ready(None),
}
}
fn handle_eoq(&mut self, change_id: Option<ChangeId>) {
self.observed_eoq = true;
self.last_change_id = change_id;
}
fn handle_change(&mut self, change_id: ChangeId) -> Result<(), SubscriptionError> {
match self.last_change_id {
Some(id) if id + 1 != change_id => {
return Err(SubscriptionError::MissedChange {
expected: id + 1,
got: change_id,
})
}
_ => (),
}
self.last_change_id = Some(change_id);
Ok(())
}
fn poll_request(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<FramedBody, SubscriptionError>> {
loop {
if let Some(res_fut) = self.response.as_mut() {
let res = ready!(Pin::new(res_fut).poll(cx));
self.response = None;
return match res {
Ok(res) => Poll::Ready(Ok(FramedRead::new(
StreamReader::new(IoBodyStream { body: res.into() }),
LinesBytesCodec::default(),
))),
Err(e) => {
let io_err = match e
.source()
.and_then(|source| source.downcast_ref::<io::Error>())
{
Some(io_err) => io::Error::from(io_err.kind()),
None => io::Error::other(e),
};
Poll::Ready(Err(io_err.into()))
}
};
} else if self.observed_eoq {
let response = self
.client
.get(format!(
"http://{}/v1/subscriptions/{}?from={}",
self.api_addr,
self.id,
self.last_change_id.unwrap_or_default()
))
.header(http::header::ACCEPT, "application/json")
.send();
self.response = Some(Box::new(response));
} else {
return Poll::Ready(Err(SubscriptionError::UnfinishedQuery));
}
}
}
}
impl<T> Stream for SubscriptionStream<T>
where
    T: DeserializeOwned + Unpin,
{
    type Item = Result<TypedQueryEvent<T>, SubscriptionError>;

    /// Yields the next subscription event, retrying after I/O errors.
    ///
    /// Any scheduled retry delay is waited out first. An I/O error drops the
    /// current body and schedules a re-request after one second. After 10
    /// consecutive I/O errors, [`SubscriptionError::MaxRetryAttempts`] is
    /// yielded instead. Every other outcome (an event, a non-I/O error, or end
    /// of stream) resets the retry counter and is passed through unchanged.
    ///
    /// NOTE(review): yielding `MaxRetryAttempts` does not terminate the
    /// stream. Polling again re-attempts the request without any delay.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if let Some(delay) = self.backoff.as_mut() {
            ready!(delay.as_mut().poll(cx));
            self.backoff = None;
        }

        match ready!(self.as_mut().poll_stream(cx)) {
            Some(Err(SubscriptionError::Io(io_err))) => {
                // Forget the broken body so the next poll re-requests.
                self.stream = None;
                if self.backoff_count >= 10 {
                    return Poll::Ready(Some(Err(SubscriptionError::MaxRetryAttempts)));
                }
                error!("encountered a stream IO error: {io_err}, retrying in a bit");
                let mut delay = Box::pin(sleep(Duration::from_secs(1)));
                // Poll once so the timer registers this task's waker.
                let _ = delay.as_mut().poll(cx);
                self.backoff = Some(delay);
                self.backoff_count += 1;
                Poll::Pending
            }
            passthrough => {
                self.backoff_count = 0;
                Poll::Ready(passthrough)
            }
        }
    }
}
/// Stream of typed notify events parsed from a newline-delimited JSON body.
pub struct UpdatesStream<T> {
    /// Id supplied at construction; returned by `id()`.
    id: Uuid,
    /// Line-framed response body.
    stream: FramedBody,
    _deser: std::marker::PhantomData<T>,
}
/// Errors yielded by [`UpdatesStream`].
#[derive(Debug, thiserror::Error)]
pub enum UpdatesError {
    /// I/O failure while reading the body.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// A line could not be deserialized into the expected event type.
    #[error(transparent)]
    Deserialize(#[from] serde_json::Error),
    /// A line exceeded the codec's maximum length.
    #[error("max line length exceeded")]
    MaxLineLengthExceeded,
}
impl<T> UpdatesStream<T>
where
    T: DeserializeOwned + Unpin,
{
    /// Wraps a response `body` containing newline-delimited notify events.
    ///
    /// `id` is stored as-is and returned by [`UpdatesStream::id`].
    pub fn new(id: Uuid, body: reqwest::Body) -> Self {
        let reader = StreamReader::new(IoBodyStream { body });
        let stream = FramedRead::new(reader, LinesBytesCodec::default());
        Self {
            id,
            stream,
            _deser: std::marker::PhantomData,
        }
    }

    /// Returns the id this stream was created with.
    pub fn id(&self) -> Uuid {
        self.id
    }
}
impl<T> Stream for UpdatesStream<T>
where
    T: DeserializeOwned + Unpin,
{
    type Item = Result<TypedNotifyEvent<T>, UpdatesError>;

    /// Reads the next line and deserializes it as a notify event.
    ///
    /// Codec failures map to [`UpdatesError::MaxLineLengthExceeded`] or
    /// [`UpdatesError::Io`]. A line that fails to parse maps to
    /// [`UpdatesError::Deserialize`]. The stream ends when the body ends.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let line = match ready!(Pin::new(&mut self.stream).poll_next(cx)) {
            None => return Poll::Ready(None),
            Some(Err(LinesCodecError::MaxLineLengthExceeded)) => {
                return Poll::Ready(Some(Err(UpdatesError::MaxLineLengthExceeded)))
            }
            Some(Err(LinesCodecError::Io(io_err))) => {
                return Poll::Ready(Some(Err(UpdatesError::Io(io_err))))
            }
            Some(Ok(line)) => line,
        };
        let parsed: Self::Item = serde_json::from_slice(&line).map_err(UpdatesError::Deserialize);
        Poll::Ready(Some(parsed))
    }
}
/// Stream of typed query events parsed from a newline-delimited JSON body.
///
/// Unlike `SubscriptionStream`, it has no reconnect logic.
pub struct QueryStream<T> {
    /// Line-framed response body.
    stream: FramedBody,
    _deser: std::marker::PhantomData<T>,
}
/// Errors yielded by [`QueryStream`].
#[derive(Debug, thiserror::Error)]
pub enum QueryError {
    /// I/O failure while reading the body.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// A line could not be deserialized into the expected event type.
    #[error(transparent)]
    Deserialize(#[from] serde_json::Error),
    /// A line exceeded the codec's maximum length.
    #[error("max line length exceeded")]
    MaxLineLengthExceeded,
}
impl<T> QueryStream<T>
where
T: DeserializeOwned + Unpin,
{
pub fn new(body: reqwest::Body) -> Self {
Self {
stream: FramedRead::new(
StreamReader::new(IoBodyStream { body }),
LinesBytesCodec::default(),
),
_deser: Default::default(),
}
}
}
impl<T> Stream for QueryStream<T>
where
    T: DeserializeOwned + Unpin,
{
    type Item = Result<TypedQueryEvent<T>, QueryError>;

    /// Reads the next line and deserializes it as a query event.
    ///
    /// Codec failures map to [`QueryError::MaxLineLengthExceeded`] or
    /// [`QueryError::Io`]. A line that fails to parse maps to
    /// [`QueryError::Deserialize`]. The stream ends when the body ends.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let Some(frame) = ready!(Pin::new(&mut self.stream).poll_next(cx)) else {
            return Poll::Ready(None);
        };
        let item: Self::Item = match frame {
            Err(LinesCodecError::MaxLineLengthExceeded) => Err(QueryError::MaxLineLengthExceeded),
            Err(LinesCodecError::Io(io_err)) => Err(QueryError::Io(io_err)),
            Ok(line) => serde_json::from_slice(&line).map_err(QueryError::Deserialize),
        };
        Poll::Ready(Some(item))
    }
}
/// Newline-delimited frame decoder that yields raw `BytesMut` lines.
///
/// Each frame has its terminating `\n` (and a `\r` immediately before it, if
/// any) removed. Unlike `tokio_util::codec::LinesCodec`, it does no UTF-8
/// validation, so frames can go straight to `serde_json::from_slice`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LinesBytesCodec {
    /// Offset in the buffer where the next newline search resumes, so bytes
    /// scanned by an earlier `decode` call are not scanned again.
    next_index: usize,
    /// Maximum line length (excluding the newline) before the line is
    /// discarded and `MaxLineLengthExceeded` is reported.
    max_length: usize,
    /// True while skipping the rest of an over-long line.
    is_discarding: bool,
}
impl Default for LinesBytesCodec {
fn default() -> Self {
LinesBytesCodec {
next_index: 0,
max_length: usize::MAX,
is_discarding: false,
}
}
}
impl Decoder for LinesBytesCodec {
    type Item = BytesMut;
    type Error = LinesCodecError;

    /// Splits the next `\n`-terminated line off the front of `buf`.
    ///
    /// Returns `Ok(Some(line))` with the `\n` (and a `\r` before it, if any)
    /// removed, or `Ok(None)` when no complete line is buffered yet.
    ///
    /// # Errors
    ///
    /// Returns `LinesCodecError::MaxLineLengthExceeded` once more than
    /// `max_length` bytes have accumulated without a newline. Later calls then
    /// discard bytes up to and including the next newline.
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, LinesCodecError> {
        loop {
            // Never scan more than one byte past the maximum line length.
            let read_to = std::cmp::min(self.max_length.saturating_add(1), buf.len());
            // Resume where the previous call stopped scanning.
            let newline_offset = buf[self.next_index..read_to]
                .iter()
                .position(|b| *b == b'\n');
            match (self.is_discarding, newline_offset) {
                (true, Some(offset)) => {
                    // End of the over-long line: drop it with its newline, then keep going.
                    buf.advance(offset + self.next_index + 1);
                    self.is_discarding = false;
                    self.next_index = 0;
                }
                (true, None) => {
                    // Still inside the over-long line: drop everything scanned.
                    buf.advance(read_to);
                    self.next_index = 0;
                    if buf.is_empty() {
                        return Ok(None);
                    }
                }
                (false, Some(offset)) => {
                    let newline_index = offset + self.next_index;
                    self.next_index = 0;
                    let mut line = buf.split_to(newline_index + 1);
                    // Strip the `\n`, then an optional `\r`.
                    line.truncate(line.len() - 1);
                    without_carriage_return(&mut line);
                    return Ok(Some(line));
                }
                (false, None) if buf.len() > self.max_length => {
                    // Report once; the rest of this line is discarded on later calls.
                    self.is_discarding = true;
                    return Err(LinesCodecError::MaxLineLengthExceeded);
                }
                (false, None) => {
                    // Incomplete line: remember how far we already scanned.
                    self.next_index = read_to;
                    return Ok(None);
                }
            }
        }
    }

    /// Like [`decode`](Self::decode), but at end of input any remaining
    /// unterminated bytes are returned as a final line, minus a trailing `\r`.
    ///
    /// A leftover buffer that is empty or holds only `\r` yields `Ok(None)`.
    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, LinesCodecError> {
        Ok(match self.decode(buf)? {
            Some(frame) => Some(frame),
            None => {
                self.next_index = 0;
                if buf.is_empty() || buf == &b"\r"[..] {
                    None
                } else {
                    // `decode` found no `\n`, so there is no newline to strip here.
                    // Truncating a byte would chop the last real character of the line.
                    let mut line = buf.split_to(buf.len());
                    without_carriage_return(&mut line);
                    Some(line)
                }
            }
        })
    }
}
/// Removes a single trailing `\r` from `s`, if present, so CRLF-terminated
/// lines come out the same as LF-terminated ones.
fn without_carriage_return(s: &mut BytesMut) {
    if s.ends_with(b"\r") {
        let trimmed_len = s.len() - 1;
        s.truncate(trimmed_len);
    }
}