1use std::{
17 error::Error,
18 io,
19 net::SocketAddr,
20 pin::Pin,
21 task::{Context, Poll},
22 time::Duration,
23};
24
25use bytes::{Buf, Bytes, BytesMut};
26use corro_api_types::{ChangeId, QueryEvent, TypedNotifyEvent, TypedQueryEvent};
27use futures::{ready, Future, Stream};
28use pin_project_lite::pin_project;
29use serde::de::DeserializeOwned;
30use tokio::time::{sleep, Sleep};
31use tokio_util::{
32 codec::{Decoder, FramedRead, LinesCodecError},
33 io::StreamReader,
34};
35use tracing::error;
36use uuid::Uuid;
37
pin_project! {
    /// Adapter that wraps a [`reqwest::Body`] so it can be consumed as a
    /// [`Stream`] of `io::Result<Bytes>` items.
    pub struct IoBodyStream {
        // The wrapped body; structurally pinned so its frames can be polled in place.
        #[pin]
        body: reqwest::Body
    }
}
47
48impl Stream for IoBodyStream {
49 type Item = io::Result<Bytes>;
50
51 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
52 use http_body::Body;
53 let this = self.project();
54 let res = ready!(this.body.poll_frame(cx));
55 match res {
56 Some(Ok(b)) => Poll::Ready(Some(
57 b.into_data()
58 .map_err(|_| io::Error::other("not a data frame")),
59 )),
60 Some(Err(e)) => {
61 let io_err = match e
62 .source()
63 .and_then(|source| source.downcast_ref::<io::Error>())
64 {
65 Some(io_err) => io::Error::from(io_err.kind()),
66 None => io::Error::other(e),
67 };
68 Poll::Ready(Some(Err(io_err)))
69 }
70 None => Poll::Ready(None),
71 }
72 }
73}
74
/// Byte reader over the chunks yielded by an [`IoBodyStream`].
type IoBodyStreamReader = StreamReader<IoBodyStream, Bytes>;
/// A response body split into frames by [`LinesBytesCodec`].
type FramedBody = FramedRead<IoBodyStreamReader, LinesBytesCodec>;
/// Boxed, type-erased future resolving to the outcome of a `reqwest` request.
type ResponseFuture =
    Box<dyn Future<Output = Result<reqwest::Response, reqwest::Error>> + Unpin + Send + Sync>;
79
/// Stream of typed query events for a subscription.
///
/// On I/O errors the current body is dropped and, after a delay, a new request
/// is issued starting from the last observed change id.
pub struct SubscriptionStream<T> {
    /// Subscription identifier; also used in the path of the re-subscribe request.
    id: Uuid,
    /// Optional hash associated with the subscription, exposed through `hash()`.
    hash: Option<String>,
    /// HTTP client used to issue the re-subscribe request.
    client: reqwest::Client,
    /// Address of the API the re-subscribe request is sent to.
    api_addr: SocketAddr,
    /// Whether an end-of-query event has been seen (or a change id was supplied
    /// at construction); a re-subscribe is refused until this is `true`.
    observed_eoq: bool,
    /// Most recent change id seen; sent as the `from` query parameter when
    /// re-subscribing.
    last_change_id: Option<ChangeId>,
    /// The framed body currently being read; `None` after an I/O error until a
    /// new request has produced a fresh body.
    stream: Option<FramedBody>,
    /// Delay that must elapse before the next attempt, if one is armed.
    backoff: Option<Pin<Box<Sleep>>>,
    /// Number of I/O errors since the last non-I/O-error item was produced.
    backoff_count: u32,
    /// In-flight re-subscribe request, if any.
    response: Option<ResponseFuture>,
    /// Marker tying the stream to the row type `T` it deserializes.
    _deser: std::marker::PhantomData<T>,
}
103
/// Errors yielded by [`SubscriptionStream`].
#[derive(Debug, thiserror::Error)]
pub enum SubscriptionError {
    /// An I/O error occurred while requesting or reading the body.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// An `http` crate error.
    #[error(transparent)]
    Http(#[from] http::Error),
    /// A received line could not be deserialized into a typed event.
    #[error(transparent)]
    Deserialize(#[from] serde_json::Error),
    /// A change arrived whose id is not the successor of the last one seen.
    #[error("missed a change (expected: {expected}, got: {got}), inconsistent state")]
    MissedChange { expected: ChangeId, got: ChangeId },
    /// The line codec reported that a line exceeded its maximum length.
    #[error("max line length exceeded")]
    MaxLineLengthExceeded,
    /// A new request was needed before any end-of-query event had been seen.
    #[error("initial query never finished")]
    UnfinishedQuery,
    /// Too many consecutive I/O errors were encountered.
    #[error("max retry attempts exceeded")]
    MaxRetryAttempts,
}
131
132impl<T> SubscriptionStream<T>
133where
134 T: DeserializeOwned + Unpin,
135{
136 pub fn new(
137 id: Uuid,
138 hash: Option<String>,
139 client: reqwest::Client,
140 api_addr: SocketAddr,
141 body: reqwest::Body,
142 change_id: Option<ChangeId>,
143 ) -> Self {
144 Self {
145 id,
146 hash,
147 client,
148 api_addr,
149 observed_eoq: change_id.is_some(),
150 last_change_id: change_id,
151 stream: Some(FramedRead::new(
152 StreamReader::new(IoBodyStream { body }),
153 LinesBytesCodec::default(),
154 )),
155 backoff: None,
156 backoff_count: 0,
157 response: None,
158 _deser: Default::default(),
159 }
160 }
161
162 pub fn id(&self) -> Uuid {
168 self.id
169 }
170
171 pub fn hash(&self) -> Option<&str> {
174 self.hash.as_deref()
175 }
176
177 pub fn api_addr(&self) -> SocketAddr {
179 self.api_addr
180 }
181
182 fn poll_stream(
183 mut self: Pin<&mut Self>,
184 cx: &mut Context<'_>,
185 ) -> Poll<Option<Result<TypedQueryEvent<T>, SubscriptionError>>> {
186 let stream = loop {
187 match self.stream.as_mut() {
188 None => match ready!(self.as_mut().poll_request(cx)) {
189 Ok(stream) => {
190 self.stream = Some(stream);
191 }
192 Err(e) => return Poll::Ready(Some(Err(e))),
193 },
194 Some(stream) => {
195 break stream;
196 }
197 }
198 };
199
200 let res = ready!(Pin::new(stream).poll_next(cx));
201 match res {
202 Some(Ok(b)) => match serde_json::from_slice(&b) {
203 Ok(evt) => {
204 if let TypedQueryEvent::EndOfQuery { change_id, .. } = &evt {
205 self.handle_eoq(*change_id);
206 }
207
208 if let TypedQueryEvent::Change(_, _, _, change_id) = &evt {
209 if let Err(e) = self.handle_change(*change_id) {
210 return Poll::Ready(Some(Err(e)));
211 }
212 }
213
214 Poll::Ready(Some(Ok(evt)))
215 }
216 Err(deser_err) => {
217 if let Ok(evt) = serde_json::from_slice::<QueryEvent>(&b) {
219 if let TypedQueryEvent::EndOfQuery { change_id, .. } = &evt {
220 self.handle_eoq(*change_id);
221 }
222
223 if let TypedQueryEvent::Change(_, _, _, change_id) = &evt {
224 if let Err(e) = self.handle_change(*change_id) {
225 return Poll::Ready(Some(Err(e)));
226 }
227 }
228 }
229
230 Poll::Ready(Some(Err(deser_err.into())))
232 }
233 },
234 Some(Err(e)) => match e {
235 LinesCodecError::MaxLineLengthExceeded => {
236 Poll::Ready(Some(Err(SubscriptionError::MaxLineLengthExceeded)))
237 }
238 LinesCodecError::Io(io_err) => Poll::Ready(Some(Err(io_err.into()))),
239 },
240 None => Poll::Ready(None),
241 }
242 }
243
244 fn handle_eoq(&mut self, change_id: Option<ChangeId>) {
245 self.observed_eoq = true;
246 self.last_change_id = change_id;
247 }
248
249 fn handle_change(&mut self, change_id: ChangeId) -> Result<(), SubscriptionError> {
250 match self.last_change_id {
251 Some(id) if id + 1 != change_id => {
252 return Err(SubscriptionError::MissedChange {
253 expected: id + 1,
254 got: change_id,
255 })
256 }
257 _ => (),
258 }
259
260 self.last_change_id = Some(change_id);
261
262 Ok(())
263 }
264
265 fn poll_request(
266 mut self: Pin<&mut Self>,
267 cx: &mut Context<'_>,
268 ) -> Poll<Result<FramedBody, SubscriptionError>> {
269 loop {
270 if let Some(res_fut) = self.response.as_mut() {
271 let res = ready!(Pin::new(res_fut).poll(cx));
273
274 self.response = None;
276
277 return match res {
278 Ok(res) => Poll::Ready(Ok(FramedRead::new(
279 StreamReader::new(IoBodyStream { body: res.into() }),
280 LinesBytesCodec::default(),
281 ))),
282 Err(e) => {
283 let io_err = match e
284 .source()
285 .and_then(|source| source.downcast_ref::<io::Error>())
286 {
287 Some(io_err) => io::Error::from(io_err.kind()),
288 None => io::Error::other(e),
289 };
290 Poll::Ready(Err(io_err.into()))
291 }
292 };
293 } else if self.observed_eoq {
294 let response = self
295 .client
296 .get(format!(
297 "http://{}/v1/subscriptions/{}?from={}",
298 self.api_addr,
299 self.id,
300 self.last_change_id.unwrap_or_default()
301 ))
302 .header(http::header::ACCEPT, "application/json")
303 .send();
304
305 self.response = Some(Box::new(response));
306 } else {
308 return Poll::Ready(Err(SubscriptionError::UnfinishedQuery));
309 }
310 }
311 }
312}
313
314impl<T> Stream for SubscriptionStream<T>
315where
316 T: DeserializeOwned + Unpin,
317{
318 type Item = Result<TypedQueryEvent<T>, SubscriptionError>;
319
320 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
321 if let Some(backoff) = self.backoff.as_mut() {
323 ready!(backoff.as_mut().poll(cx));
324 self.backoff = None;
325 }
326
327 let io_err = match ready!(self.as_mut().poll_stream(cx)) {
328 Some(Err(SubscriptionError::Io(io_err))) => io_err,
329 other => {
330 self.backoff_count = 0;
331 return Poll::Ready(other);
332 }
333 };
334
335 self.stream = None;
337
338 if self.backoff_count >= 10 {
339 return Poll::Ready(Some(Err(SubscriptionError::MaxRetryAttempts)));
340 }
341
342 error!("encountered a stream IO error: {io_err}, retrying in a bit");
343
344 let mut backoff = Box::pin(sleep(Duration::from_secs(1)));
345
346 _ = backoff.as_mut().poll(cx);
348
349 self.backoff = Some(backoff);
351
352 self.backoff_count += 1;
353
354 Poll::Pending
355 }
356}
357
/// Stream of typed notify events decoded line by line from a response body.
pub struct UpdatesStream<T> {
    /// Identifier exposed through `id()`.
    id: Uuid,
    /// The framed body being read.
    stream: FramedBody,
    /// Marker tying the stream to the type `T` it deserializes.
    _deser: std::marker::PhantomData<T>,
}
367
/// Errors yielded by [`UpdatesStream`].
#[derive(Debug, thiserror::Error)]
pub enum UpdatesError {
    /// An I/O error occurred while reading the body.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// A received line could not be deserialized into a typed event.
    #[error(transparent)]
    Deserialize(#[from] serde_json::Error),
    /// The line codec reported that a line exceeded its maximum length.
    #[error("max line length exceeded")]
    MaxLineLengthExceeded,
}
381
382impl<T> UpdatesStream<T>
383where
384 T: DeserializeOwned + Unpin,
385{
386 pub fn new(id: Uuid, body: reqwest::Body) -> Self {
388 Self {
389 id,
390 stream: FramedRead::new(
391 StreamReader::new(IoBodyStream { body }),
392 LinesBytesCodec::default(),
393 ),
394 _deser: Default::default(),
395 }
396 }
397
398 pub fn id(&self) -> Uuid {
400 self.id
401 }
402}
403
404impl<T> Stream for UpdatesStream<T>
405where
406 T: DeserializeOwned + Unpin,
407{
408 type Item = Result<TypedNotifyEvent<T>, UpdatesError>;
409
410 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
411 let res = ready!(Pin::new(&mut self.stream).poll_next(cx));
412 match res {
413 Some(Ok(b)) => match serde_json::from_slice(&b) {
414 Ok(evt) => Poll::Ready(Some(Ok(evt))),
415 Err(e) => Poll::Ready(Some(Err(e.into()))),
416 },
417 Some(Err(e)) => match e {
418 LinesCodecError::MaxLineLengthExceeded => {
419 Poll::Ready(Some(Err(UpdatesError::MaxLineLengthExceeded)))
420 }
421 LinesCodecError::Io(io_err) => Poll::Ready(Some(Err(io_err.into()))),
422 },
423 None => Poll::Ready(None),
424 }
425 }
426}
427
/// Stream of typed query events decoded line by line from a response body.
pub struct QueryStream<T> {
    /// The framed body being read.
    stream: FramedBody,
    /// Marker tying the stream to the row type `T` it deserializes.
    _deser: std::marker::PhantomData<T>,
}
435
/// Errors yielded by [`QueryStream`].
#[derive(Debug, thiserror::Error)]
pub enum QueryError {
    /// An I/O error occurred while reading the body.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// A received line could not be deserialized into a typed event.
    #[error(transparent)]
    Deserialize(#[from] serde_json::Error),
    /// The line codec reported that a line exceeded its maximum length.
    #[error("max line length exceeded")]
    MaxLineLengthExceeded,
}
449
450impl<T> QueryStream<T>
451where
452 T: DeserializeOwned + Unpin,
453{
454 pub fn new(body: reqwest::Body) -> Self {
456 Self {
457 stream: FramedRead::new(
458 StreamReader::new(IoBodyStream { body }),
459 LinesBytesCodec::default(),
460 ),
461 _deser: Default::default(),
462 }
463 }
464}
465
466impl<T> Stream for QueryStream<T>
467where
468 T: DeserializeOwned + Unpin,
469{
470 type Item = Result<TypedQueryEvent<T>, QueryError>;
471
472 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
473 match ready!(Pin::new(&mut self.stream).poll_next(cx)) {
474 Some(Ok(b)) => match serde_json::from_slice(&b) {
475 Ok(evt) => Poll::Ready(Some(Ok(evt))),
476 Err(e) => Poll::Ready(Some(Err(e.into()))),
477 },
478 Some(Err(e)) => match e {
479 LinesCodecError::MaxLineLengthExceeded => {
480 Poll::Ready(Some(Err(QueryError::MaxLineLengthExceeded)))
481 }
482 LinesCodecError::Io(io_err) => Poll::Ready(Some(Err(io_err.into()))),
483 },
484 None => Poll::Ready(None),
485 }
486 }
487}
488
/// Decoder that splits a byte buffer on `\n` and yields each line as a
/// `BytesMut`.
pub struct LinesBytesCodec {
    /// Offset into the read buffer at which the next search for `\n` starts;
    /// the bytes before it have already been searched.
    next_index: usize,

    /// Maximum accepted line length in bytes; the default is `usize::MAX`.
    max_length: usize,

    /// Set after a line exceeded `max_length`; while set, input is thrown away
    /// up to and including the next `\n`.
    is_discarding: bool,
}
508
509impl Default for LinesBytesCodec {
510 fn default() -> Self {
519 LinesBytesCodec {
520 next_index: 0,
521 max_length: usize::MAX,
522 is_discarding: false,
523 }
524 }
525}
526
527impl Decoder for LinesBytesCodec {
528 type Item = BytesMut;
529 type Error = LinesCodecError;
530
531 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, LinesCodecError> {
532 loop {
533 let read_to = std::cmp::min(self.max_length.saturating_add(1), buf.len());
536
537 let newline_offset = buf[self.next_index..read_to]
538 .iter()
539 .position(|b| *b == b'\n');
540
541 match (self.is_discarding, newline_offset) {
542 (true, Some(offset)) => {
543 buf.advance(offset + self.next_index + 1);
547 self.is_discarding = false;
548 self.next_index = 0;
549 }
550 (true, None) => {
551 buf.advance(read_to);
555 self.next_index = 0;
556 if buf.is_empty() {
557 return Ok(None);
558 }
559 }
560 (false, Some(offset)) => {
561 let newline_index = offset + self.next_index;
563 self.next_index = 0;
564 let mut line = buf.split_to(newline_index + 1);
565 line.truncate(line.len() - 1);
566 without_carriage_return(&mut line);
567 return Ok(Some(line));
568 }
569 (false, None) if buf.len() > self.max_length => {
570 self.is_discarding = true;
574 return Err(LinesCodecError::MaxLineLengthExceeded);
575 }
576 (false, None) => {
577 self.next_index = read_to;
580 return Ok(None);
581 }
582 }
583 }
584 }
585
586 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, LinesCodecError> {
587 Ok(match self.decode(buf)? {
588 Some(frame) => Some(frame),
589 None => {
590 if buf.is_empty() || buf == &b"\r"[..] {
592 None
593 } else {
594 let mut line = buf.split_to(buf.len());
595 line.truncate(line.len() - 1);
596 without_carriage_return(&mut line);
597 self.next_index = 0;
598 Some(line)
599 }
600 }
601 })
602 }
603}
604
605fn without_carriage_return(s: &mut BytesMut) {
606 if let Some(&b'\r') = s.last() {
607 s.truncate(s.len() - 1);
608 }
609}