Skip to main content

corro_client/
sub.rs

1//! Streaming response types for the Corrosion HTTP API.
2//!
3//! Each long-running endpoint exposes a dedicated [`Stream`] implementation:
4//!
5//! * [`QueryStream`] — one-shot query results
6//!   (`POST /v1/queries`).
7//! * [`SubscriptionStream`] — resumable live subscription
8//!   (`POST /v1/subscriptions`); transparently reconnects with the last
9//!   observed [`ChangeId`] on transient I/O errors.
10//! * [`UpdatesStream`] — table-level update feed (`POST /v1/updates/{table}`).
11//!
12//! All three decode JSON-lines responses through the [`LinesBytesCodec`]
13//! defined here, which is a port of `tokio-util`'s `LinesCodec` adapted to
14//! emit [`bytes::BytesMut`] frames.
15
16use std::{
17    error::Error,
18    io,
19    net::SocketAddr,
20    pin::Pin,
21    task::{Context, Poll},
22    time::Duration,
23};
24
25use bytes::{Buf, Bytes, BytesMut};
26use corro_api_types::{ChangeId, QueryEvent, TypedNotifyEvent, TypedQueryEvent};
27use futures::{ready, Future, Stream};
28use pin_project_lite::pin_project;
29use serde::de::DeserializeOwned;
30use tokio::time::{sleep, Sleep};
31use tokio_util::{
32    codec::{Decoder, FramedRead, LinesCodecError},
33    io::StreamReader,
34};
35use tracing::error;
36use uuid::Uuid;
37
38pin_project! {
39    /// Adapter that exposes a [`reqwest::Body`] as a [`Stream`] of
40    /// [`io::Result<Bytes>`], mapping framing errors into [`io::Error`] so
41    /// the underlying [`StreamReader`] can consume them.
42    pub struct IoBodyStream {
43        #[pin]
44        body: reqwest::Body
45    }
46}
47
48impl Stream for IoBodyStream {
49    type Item = io::Result<Bytes>;
50
51    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
52        use http_body::Body;
53        let this = self.project();
54        let res = ready!(this.body.poll_frame(cx));
55        match res {
56            Some(Ok(b)) => Poll::Ready(Some(
57                b.into_data()
58                    .map_err(|_| io::Error::other("not a data frame")),
59            )),
60            Some(Err(e)) => {
61                let io_err = match e
62                    .source()
63                    .and_then(|source| source.downcast_ref::<io::Error>())
64                {
65                    Some(io_err) => io::Error::from(io_err.kind()),
66                    None => io::Error::other(e),
67                };
68                Poll::Ready(Some(Err(io_err)))
69            }
70            None => Poll::Ready(None),
71        }
72    }
73}
74
75type IoBodyStreamReader = StreamReader<IoBodyStream, Bytes>;
76type FramedBody = FramedRead<IoBodyStreamReader, LinesBytesCodec>;
77type ResponseFuture =
78    Box<dyn Future<Output = Result<reqwest::Response, reqwest::Error>> + Unpin + Send + Sync>;
79
80/// Live subscription stream returned by
81/// [`crate::CorrosionApiClient::subscribe_typed`].
82///
83/// Yields [`TypedQueryEvent<T>`] frames produced by the agent. Once the
84/// initial result set is fully consumed (after the first
85/// [`TypedQueryEvent::EndOfQuery`]) the stream automatically reconnects
86/// when the underlying HTTP body terminates, using the last observed
87/// [`ChangeId`] to resume without gaps. Reconnects use a fixed-size
88/// linear backoff and give up after 10 consecutive failures with
89/// [`SubscriptionError::MaxRetryAttempts`].
90pub struct SubscriptionStream<T> {
91    id: Uuid,
92    hash: Option<String>,
93    client: reqwest::Client,
94    api_addr: SocketAddr,
95    observed_eoq: bool,
96    last_change_id: Option<ChangeId>,
97    stream: Option<FramedBody>,
98    backoff: Option<Pin<Box<Sleep>>>,
99    backoff_count: u32,
100    response: Option<ResponseFuture>,
101    _deser: std::marker::PhantomData<T>,
102}
103
104/// Errors yielded by [`SubscriptionStream`].
105#[derive(Debug, thiserror::Error)]
106pub enum SubscriptionError {
107    /// Underlying I/O error on the HTTP body.
108    #[error(transparent)]
109    Io(#[from] io::Error),
110    /// Generic HTTP error encountered while reconnecting.
111    #[error(transparent)]
112    Http(#[from] http::Error),
113    /// A frame could not be decoded as a [`TypedQueryEvent`].
114    #[error(transparent)]
115    Deserialize(#[from] serde_json::Error),
116    /// The agent skipped a [`ChangeId`], indicating the local view is no
117    /// longer consistent with the server.
118    #[error("missed a change (expected: {expected}, got: {got}), inconsistent state")]
119    MissedChange { expected: ChangeId, got: ChangeId },
120    /// A single JSON line exceeded the codec's maximum length.
121    #[error("max line length exceeded")]
122    MaxLineLengthExceeded,
123    /// The connection terminated before the initial query produced an
124    /// [`TypedQueryEvent::EndOfQuery`].
125    #[error("initial query never finished")]
126    UnfinishedQuery,
127    /// Error when maximum number of consecutive reconnect is exceeded.
128    #[error("max retry attempts exceeded")]
129    MaxRetryAttempts,
130}
131
132impl<T> SubscriptionStream<T>
133where
134    T: DeserializeOwned + Unpin,
135{
136    pub fn new(
137        id: Uuid,
138        hash: Option<String>,
139        client: reqwest::Client,
140        api_addr: SocketAddr,
141        body: reqwest::Body,
142        change_id: Option<ChangeId>,
143    ) -> Self {
144        Self {
145            id,
146            hash,
147            client,
148            api_addr,
149            observed_eoq: change_id.is_some(),
150            last_change_id: change_id,
151            stream: Some(FramedRead::new(
152                StreamReader::new(IoBodyStream { body }),
153                LinesBytesCodec::default(),
154            )),
155            backoff: None,
156            backoff_count: 0,
157            response: None,
158            _deser: Default::default(),
159        }
160    }
161
162    /// Server-assigned subscription identifier.
163    ///
164    /// Persist this id (along with the [`ChangeId`] from
165    /// [`TypedQueryEvent::Change`]) to resume the subscription later via
166    /// [`crate::CorrosionApiClient::subscription_typed`].
167    pub fn id(&self) -> Uuid {
168        self.id
169    }
170
171    /// Hash advertised by the server for this subscription's query.
172    ///
173    pub fn hash(&self) -> Option<&str> {
174        self.hash.as_deref()
175    }
176
177    /// API address the subscription was opened against.
178    pub fn api_addr(&self) -> SocketAddr {
179        self.api_addr
180    }
181
182    fn poll_stream(
183        mut self: Pin<&mut Self>,
184        cx: &mut Context<'_>,
185    ) -> Poll<Option<Result<TypedQueryEvent<T>, SubscriptionError>>> {
186        let stream = loop {
187            match self.stream.as_mut() {
188                None => match ready!(self.as_mut().poll_request(cx)) {
189                    Ok(stream) => {
190                        self.stream = Some(stream);
191                    }
192                    Err(e) => return Poll::Ready(Some(Err(e))),
193                },
194                Some(stream) => {
195                    break stream;
196                }
197            }
198        };
199
200        let res = ready!(Pin::new(stream).poll_next(cx));
201        match res {
202            Some(Ok(b)) => match serde_json::from_slice(&b) {
203                Ok(evt) => {
204                    if let TypedQueryEvent::EndOfQuery { change_id, .. } = &evt {
205                        self.handle_eoq(*change_id);
206                    }
207
208                    if let TypedQueryEvent::Change(_, _, _, change_id) = &evt {
209                        if let Err(e) = self.handle_change(*change_id) {
210                            return Poll::Ready(Some(Err(e)));
211                        }
212                    }
213
214                    Poll::Ready(Some(Ok(evt)))
215                }
216                Err(deser_err) => {
217                    // It failed to deserialize, try untyped variant to extract the metadata
218                    if let Ok(evt) = serde_json::from_slice::<QueryEvent>(&b) {
219                        if let TypedQueryEvent::EndOfQuery { change_id, .. } = &evt {
220                            self.handle_eoq(*change_id);
221                        }
222
223                        if let TypedQueryEvent::Change(_, _, _, change_id) = &evt {
224                            if let Err(e) = self.handle_change(*change_id) {
225                                return Poll::Ready(Some(Err(e)));
226                            }
227                        }
228                    }
229
230                    // But return the original error anyway (unless this is out-of-order event)
231                    Poll::Ready(Some(Err(deser_err.into())))
232                }
233            },
234            Some(Err(e)) => match e {
235                LinesCodecError::MaxLineLengthExceeded => {
236                    Poll::Ready(Some(Err(SubscriptionError::MaxLineLengthExceeded)))
237                }
238                LinesCodecError::Io(io_err) => Poll::Ready(Some(Err(io_err.into()))),
239            },
240            None => Poll::Ready(None),
241        }
242    }
243
244    fn handle_eoq(&mut self, change_id: Option<ChangeId>) {
245        self.observed_eoq = true;
246        self.last_change_id = change_id;
247    }
248
249    fn handle_change(&mut self, change_id: ChangeId) -> Result<(), SubscriptionError> {
250        match self.last_change_id {
251            Some(id) if id + 1 != change_id => {
252                return Err(SubscriptionError::MissedChange {
253                    expected: id + 1,
254                    got: change_id,
255                })
256            }
257            _ => (),
258        }
259
260        self.last_change_id = Some(change_id);
261
262        Ok(())
263    }
264
265    fn poll_request(
266        mut self: Pin<&mut Self>,
267        cx: &mut Context<'_>,
268    ) -> Poll<Result<FramedBody, SubscriptionError>> {
269        loop {
270            if let Some(res_fut) = self.response.as_mut() {
271                // return early w/ Poll::Pending if response is not ready
272                let res = ready!(Pin::new(res_fut).poll(cx));
273
274                // reset response
275                self.response = None;
276
277                return match res {
278                    Ok(res) => Poll::Ready(Ok(FramedRead::new(
279                        StreamReader::new(IoBodyStream { body: res.into() }),
280                        LinesBytesCodec::default(),
281                    ))),
282                    Err(e) => {
283                        let io_err = match e
284                            .source()
285                            .and_then(|source| source.downcast_ref::<io::Error>())
286                        {
287                            Some(io_err) => io::Error::from(io_err.kind()),
288                            None => io::Error::other(e),
289                        };
290                        Poll::Ready(Err(io_err.into()))
291                    }
292                };
293            } else if self.observed_eoq {
294                let response = self
295                    .client
296                    .get(format!(
297                        "http://{}/v1/subscriptions/{}?from={}",
298                        self.api_addr,
299                        self.id,
300                        self.last_change_id.unwrap_or_default()
301                    ))
302                    .header(http::header::ACCEPT, "application/json")
303                    .send();
304
305                self.response = Some(Box::new(response));
306                // loop around!
307            } else {
308                return Poll::Ready(Err(SubscriptionError::UnfinishedQuery));
309            }
310        }
311    }
312}
313
314impl<T> Stream for SubscriptionStream<T>
315where
316    T: DeserializeOwned + Unpin,
317{
318    type Item = Result<TypedQueryEvent<T>, SubscriptionError>;
319
320    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
321        // first, check if we need to wait for a backoff...
322        if let Some(backoff) = self.backoff.as_mut() {
323            ready!(backoff.as_mut().poll(cx));
324            self.backoff = None;
325        }
326
327        let io_err = match ready!(self.as_mut().poll_stream(cx)) {
328            Some(Err(SubscriptionError::Io(io_err))) => io_err,
329            other => {
330                self.backoff_count = 0;
331                return Poll::Ready(other);
332            }
333        };
334
335        // reset the stream
336        self.stream = None;
337
338        if self.backoff_count >= 10 {
339            return Poll::Ready(Some(Err(SubscriptionError::MaxRetryAttempts)));
340        }
341
342        error!("encountered a stream IO error: {io_err}, retrying in a bit");
343
344        let mut backoff = Box::pin(sleep(Duration::from_secs(1)));
345
346        // register w/ waker
347        _ = backoff.as_mut().poll(cx);
348
349        // this can't return Ready, right?
350        self.backoff = Some(backoff);
351
352        self.backoff_count += 1;
353
354        Poll::Pending
355    }
356}
357
358/// Stream returned by [`crate::CorrosionApiClient::updates_typed`].
359///
360/// Yields a [`TypedNotifyEvent`] for every row inserted, updated or deleted
361/// in the watched table after the update was registered.
362pub struct UpdatesStream<T> {
363    id: Uuid,
364    stream: FramedBody,
365    _deser: std::marker::PhantomData<T>,
366}
367
368/// Errors yielded by [`UpdatesStream`].
369#[derive(Debug, thiserror::Error)]
370pub enum UpdatesError {
371    /// Underlying I/O error on the HTTP body.
372    #[error(transparent)]
373    Io(#[from] io::Error),
374    /// A frame could not be decoded as a [`TypedNotifyEvent`].
375    #[error(transparent)]
376    Deserialize(#[from] serde_json::Error),
377    /// A single JSON line exceeded the codec's maximum length.
378    #[error("max line length exceeded")]
379    MaxLineLengthExceeded,
380}
381
382impl<T> UpdatesStream<T>
383where
384    T: DeserializeOwned + Unpin,
385{
386    /// Build an `UpdatesStream` from a freshly opened HTTP response.
387    pub fn new(id: Uuid, body: reqwest::Body) -> Self {
388        Self {
389            id,
390            stream: FramedRead::new(
391                StreamReader::new(IoBodyStream { body }),
392                LinesBytesCodec::default(),
393            ),
394            _deser: Default::default(),
395        }
396    }
397
398    /// Server-assigned subscription identifier for this stream.
399    pub fn id(&self) -> Uuid {
400        self.id
401    }
402}
403
404impl<T> Stream for UpdatesStream<T>
405where
406    T: DeserializeOwned + Unpin,
407{
408    type Item = Result<TypedNotifyEvent<T>, UpdatesError>;
409
410    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
411        let res = ready!(Pin::new(&mut self.stream).poll_next(cx));
412        match res {
413            Some(Ok(b)) => match serde_json::from_slice(&b) {
414                Ok(evt) => Poll::Ready(Some(Ok(evt))),
415                Err(e) => Poll::Ready(Some(Err(e.into()))),
416            },
417            Some(Err(e)) => match e {
418                LinesCodecError::MaxLineLengthExceeded => {
419                    Poll::Ready(Some(Err(UpdatesError::MaxLineLengthExceeded)))
420                }
421                LinesCodecError::Io(io_err) => Poll::Ready(Some(Err(io_err.into()))),
422            },
423            None => Poll::Ready(None),
424        }
425    }
426}
427
428/// Stream returned by [`crate::CorrosionApiClient::query_typed`] for a single query.
429///
430/// Yields a [`TypedQueryEvent`] with columns, each row of the result set, and a final [`TypedQueryEvent::EndOfQuery`].
431pub struct QueryStream<T> {
432    stream: FramedBody,
433    _deser: std::marker::PhantomData<T>,
434}
435
436/// Errors yielded by [`QueryStream`].
437#[derive(Debug, thiserror::Error)]
438pub enum QueryError {
439    /// Underlying I/O error on the HTTP body.
440    #[error(transparent)]
441    Io(#[from] io::Error),
442    /// A frame could not be decoded as a [`TypedQueryEvent`].
443    #[error(transparent)]
444    Deserialize(#[from] serde_json::Error),
445    /// A single JSON line exceeded the codec's maximum length.
446    #[error("max line length exceeded")]
447    MaxLineLengthExceeded,
448}
449
450impl<T> QueryStream<T>
451where
452    T: DeserializeOwned + Unpin,
453{
454    /// Build a `QueryStream` from a freshly opened HTTP response.
455    pub fn new(body: reqwest::Body) -> Self {
456        Self {
457            stream: FramedRead::new(
458                StreamReader::new(IoBodyStream { body }),
459                LinesBytesCodec::default(),
460            ),
461            _deser: Default::default(),
462        }
463    }
464}
465
466impl<T> Stream for QueryStream<T>
467where
468    T: DeserializeOwned + Unpin,
469{
470    type Item = Result<TypedQueryEvent<T>, QueryError>;
471
472    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
473        match ready!(Pin::new(&mut self.stream).poll_next(cx)) {
474            Some(Ok(b)) => match serde_json::from_slice(&b) {
475                Ok(evt) => Poll::Ready(Some(Ok(evt))),
476                Err(e) => Poll::Ready(Some(Err(e.into()))),
477            },
478            Some(Err(e)) => match e {
479                LinesCodecError::MaxLineLengthExceeded => {
480                    Poll::Ready(Some(Err(QueryError::MaxLineLengthExceeded)))
481                }
482                LinesCodecError::Io(io_err) => Poll::Ready(Some(Err(io_err.into()))),
483            },
484            None => Poll::Ready(None),
485        }
486    }
487}
488
/// Decoder that splits a byte buffer into lines, using `\n` as the line
/// delimiter.
pub struct LinesBytesCodec {
    // Offset into the buffer at which the next search for `\n` starts.
    // Bytes before it were already scanned by an earlier `decode` call and
    // contain no newline. For example, after `decode` has seen `abc` this is
    // `3`; when the buffer later holds `abcde\n`, only `de\n` is scanned.
    next_index: usize,

    /// The maximum length for a given line. If `usize::MAX`, lines will be
    /// read until a `\n` character is reached.
    max_length: usize,

    /// Set after a line exceeded `max_length`; while set, input is thrown
    /// away up to and including the next `\n`.
    is_discarding: bool,
}
508
509impl Default for LinesBytesCodec {
510    /// Returns a `LinesBytesCodec` for splitting up data into lines.
511    ///
512    /// # Note
513    ///
514    /// The returned `LinesBytesCodec` will not have an upper bound on the length
515    /// of a buffered line. See the documentation for [`new_with_max_length`]
516    /// for information on why this could be a potential security risk.
517    ///
518    fn default() -> Self {
519        LinesBytesCodec {
520            next_index: 0,
521            max_length: usize::MAX,
522            is_discarding: false,
523        }
524    }
525}
526
527impl Decoder for LinesBytesCodec {
528    type Item = BytesMut;
529    type Error = LinesCodecError;
530
531    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, LinesCodecError> {
532        loop {
533            // Determine how far into the buffer we'll search for a newline. If
534            // there's no max_length set, we'll read to the end of the buffer.
535            let read_to = std::cmp::min(self.max_length.saturating_add(1), buf.len());
536
537            let newline_offset = buf[self.next_index..read_to]
538                .iter()
539                .position(|b| *b == b'\n');
540
541            match (self.is_discarding, newline_offset) {
542                (true, Some(offset)) => {
543                    // If we found a newline, discard up to that offset and
544                    // then stop discarding. On the next iteration, we'll try
545                    // to read a line normally.
546                    buf.advance(offset + self.next_index + 1);
547                    self.is_discarding = false;
548                    self.next_index = 0;
549                }
550                (true, None) => {
551                    // Otherwise, we didn't find a newline, so we'll discard
552                    // everything we read. On the next iteration, we'll continue
553                    // discarding up to max_len bytes unless we find a newline.
554                    buf.advance(read_to);
555                    self.next_index = 0;
556                    if buf.is_empty() {
557                        return Ok(None);
558                    }
559                }
560                (false, Some(offset)) => {
561                    // Found a line!
562                    let newline_index = offset + self.next_index;
563                    self.next_index = 0;
564                    let mut line = buf.split_to(newline_index + 1);
565                    line.truncate(line.len() - 1);
566                    without_carriage_return(&mut line);
567                    return Ok(Some(line));
568                }
569                (false, None) if buf.len() > self.max_length => {
570                    // Reached the maximum length without finding a
571                    // newline, return an error and start discarding on the
572                    // next call.
573                    self.is_discarding = true;
574                    return Err(LinesCodecError::MaxLineLengthExceeded);
575                }
576                (false, None) => {
577                    // We didn't find a line or reach the length limit, so the next
578                    // call will resume searching at the current offset.
579                    self.next_index = read_to;
580                    return Ok(None);
581                }
582            }
583        }
584    }
585
586    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<BytesMut>, LinesCodecError> {
587        Ok(match self.decode(buf)? {
588            Some(frame) => Some(frame),
589            None => {
590                // No terminating newline - return remaining data, if any
591                if buf.is_empty() || buf == &b"\r"[..] {
592                    None
593                } else {
594                    let mut line = buf.split_to(buf.len());
595                    line.truncate(line.len() - 1);
596                    without_carriage_return(&mut line);
597                    self.next_index = 0;
598                    Some(line)
599                }
600            }
601        })
602    }
603}
604
605fn without_carriage_return(s: &mut BytesMut) {
606    if let Some(&b'\r') = s.last() {
607        s.truncate(s.len() - 1);
608    }
609}