Skip to main content

triplox_client/
subscription.rs

1//! Live incremental query subscription over a streaming HTTP/2 response.
2//!
3//! [`ClientNode::subscribe`](crate::ClientNode::subscribe) returns a
4//! [`Subscription`], a `Stream` of [`Delta`]s decoded from the bare,
5//! self-delimiting msgpack frames the server streams. Dropping the subscription
6//! cancels the HTTP/2 stream, which the server observes as a teardown signal.
7
8use std::io;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
12use anyhow::{anyhow, bail, Error, Result};
13use bytes::{Buf, Bytes, BytesMut};
14use futures::{Stream, StreamExt};
15use tokio::sync::mpsc;
16use tokio::task::JoinHandle;
17use tokio_util::codec::{Decoder, FramedRead};
18use tokio_util::io::StreamReader;
19
20use crate::msgpack_codec::{subscription_frame_from_value, ErrorResponseBody, SubscriptionFrame};
21use crate::ops::DataType;
22use crate::protocol::DEFAULT_MAX_MESSAGE_SIZE;
23use crate::transaction::TxKey;
24
25type ByteStream = Pin<Box<dyn Stream<Item = io::Result<Bytes>> + Send>>;
26type FrameStream = FramedRead<StreamReader<ByteStream, Bytes>, MsgpackFrameDecoder>;
27
28const SUBSCRIPTION_QUEUE_CAPACITY: usize = 128;
29
30/// A single transaction's z-set changes for a subscribed query.
31#[derive(Debug, Clone, PartialEq)]
32pub struct Delta {
33    /// The transaction key that produced this delta.
34    pub tx_key: TxKey,
35    /// `(values, weight)` rows; `weight` is the raw signed multiplicity.
36    pub rows: Vec<(Vec<DataType>, i64)>,
37}
38
39/// A live incremental query subscription.
40///
41/// Implements `Stream<Item = Result<Delta>>`; use `StreamExt::next().await` or
42/// stream combinators. Dropping it cancels the HTTP/2 stream and unsubscribes.
43pub struct Subscription {
44    tx_key: TxKey,
45    deltas: mpsc::Receiver<Result<Delta>>,
46    reader: JoinHandle<()>,
47}
48
49fn error_frame_to_error(err: ErrorResponseBody) -> Error {
50    anyhow!("subscription error (code {}): {}", err.code, err.message)
51}
52
53async fn read_deltas(mut frames: FrameStream, sender: mpsc::Sender<Result<Delta>>) {
54    while let Some(frame) = frames.next().await {
55        let (item, terminal) = match frame {
56            Ok(SubscriptionFrame::Delta { tx_key, rows }) => (Ok(Delta { tx_key, rows }), false),
57            Ok(SubscriptionFrame::Error(err)) => (Err(error_frame_to_error(err)), true),
58            Ok(SubscriptionFrame::Open { .. }) => {
59                (Err(anyhow!("unexpected open frame mid-stream")), true)
60            }
61            Err(err) => (Err(err), true),
62        };
63
64        if sender.send(item).await.is_err() || terminal {
65            break;
66        }
67    }
68}
69
70impl Subscription {
71    /// The registration tx_key. Deltas describe transactions strictly after it.
72    pub fn tx_key(&self) -> TxKey {
73        self.tx_key
74    }
75
76    /// Wrap a streaming subscription response: frame its body and read the
77    /// leading `open` frame, returning the ready-to-poll subscription.
78    pub(crate) async fn connect(resp: reqwest::Response) -> Result<Self> {
79        let byte_stream = resp
80            .bytes_stream()
81            .map(|chunk| chunk.map_err(io::Error::other));
82        Self::from_byte_stream(byte_stream).await
83    }
84
85    async fn from_byte_stream<S>(stream: S) -> Result<Self>
86    where
87        S: Stream<Item = io::Result<Bytes>> + Send + 'static,
88    {
89        let reader = StreamReader::new(Box::pin(stream) as ByteStream);
90        let mut frames = FramedRead::new(reader, MsgpackFrameDecoder::default());
91        let tx_key = match frames.next().await {
92            Some(Ok(SubscriptionFrame::Open { tx_key, .. })) => tx_key,
93            Some(Ok(SubscriptionFrame::Error(err))) => return Err(error_frame_to_error(err)),
94            Some(Ok(other)) => bail!("expected open frame, got {other:?}"),
95            Some(Err(err)) => return Err(err),
96            None => bail!("subscription stream closed before the open frame"),
97        };
98        let (sender, deltas) = mpsc::channel(SUBSCRIPTION_QUEUE_CAPACITY);
99        let reader = tokio::spawn(read_deltas(frames, sender));
100        Ok(Subscription {
101            tx_key,
102            deltas,
103            reader,
104        })
105    }
106}
107
108impl Stream for Subscription {
109    type Item = Result<Delta>;
110
111    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
112        let this = self.get_mut();
113        this.deltas.poll_recv(cx)
114    }
115}
116
117impl Drop for Subscription {
118    fn drop(&mut self) {
119        self.reader.abort();
120    }
121}
122
/// Splits a byte stream of bare, self-delimiting msgpack values into
/// [`SubscriptionFrame`]s.
///
/// Decoding yields `Ok(None)` while a frame is still incomplete (more bytes are
/// needed); a corrupt or oversized frame is reported as an error.
pub(crate) struct MsgpackFrameDecoder {
    /// Size cap, in bytes, applied to a frame while it is being buffered.
    max_frame_size: usize,
}
129
130impl Default for MsgpackFrameDecoder {
131    fn default() -> Self {
132        Self {
133            max_frame_size: DEFAULT_MAX_MESSAGE_SIZE as usize,
134        }
135    }
136}
137
138/// `true` when the decode error is a truncated value (need more bytes) rather
139/// than a corrupt one.
140fn needs_more_data(err: &rmpv::decode::Error) -> bool {
141    match err {
142        rmpv::decode::Error::InvalidMarkerRead(e) | rmpv::decode::Error::InvalidDataRead(e) => {
143            e.kind() == io::ErrorKind::UnexpectedEof
144        }
145        rmpv::decode::Error::DepthLimitExceeded => false,
146    }
147}
148
149impl Decoder for MsgpackFrameDecoder {
150    type Item = SubscriptionFrame;
151    type Error = Error;
152
153    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
154        if src.is_empty() {
155            return Ok(None);
156        }
157        let mut cursor: &[u8] = &src[..];
158        let remaining_before = cursor.len();
159        match rmpv::decode::read_value(&mut cursor) {
160            Ok(value) => {
161                let consumed = remaining_before - cursor.len();
162                let frame = subscription_frame_from_value(value)?;
163                src.advance(consumed);
164                Ok(Some(frame))
165            }
166            Err(err) if needs_more_data(&err) => {
167                if src.len() > self.max_frame_size {
168                    bail!(
169                        "subscription frame exceeds maximum size of {} bytes",
170                        self.max_frame_size
171                    );
172                }
173                Ok(None)
174            }
175            Err(err) => Err(anyhow!("msgpack frame decode error: {err}")),
176        }
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::msgpack_codec::encode_subscription_frame;
    use crate::protocol::ColumnDescription;
    use crate::transaction::TxKey;
    use chrono::{TimeZone, Utc};
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::{Arc, LazyLock};
    use tokio::time::{sleep, timeout, Duration};

    /// Transaction key stamped on every fixture frame.
    static SAMPLE_TX_KEY: LazyLock<TxKey> = LazyLock::new(|| TxKey {
        tx_id: 3,
        system_time: Utc.timestamp_opt(1_700_000_000, 0).unwrap(),
    });

    /// Encoded `open` frame describing a single column.
    fn open_bytes() -> Vec<u8> {
        let column = ColumnDescription {
            name: "n".to_string(),
            data_type: 255,
            members: None,
        };
        let frame = SubscriptionFrame::Open {
            tx_key: *SAMPLE_TX_KEY,
            columns: vec![column],
        };
        encode_subscription_frame(&frame).unwrap()
    }

    /// Encoded `delta` frame inserting one single-column row holding `name`.
    fn delta_bytes(name: &str) -> Vec<u8> {
        let row = (vec![DataType::String(name.to_string())], 1);
        let frame = SubscriptionFrame::Delta {
            tx_key: *SAMPLE_TX_KEY,
            rows: vec![row],
        };
        encode_subscription_frame(&frame).unwrap()
    }

    /// Encoded error frame with code 4000.
    fn error_bytes() -> Vec<u8> {
        let body = ErrorResponseBody {
            severity: b'F',
            code: 4000,
            message: "boom".to_string(),
            detail: None,
            hint: None,
        };
        encode_subscription_frame(&SubscriptionFrame::Error(body)).unwrap()
    }

    /// `{"kind": "heartbeat"}` — an unsupported frame kind the client must reject.
    fn unknown_bytes() -> Vec<u8> {
        let mut encoded = Vec::new();
        rmp::encode::write_map_len(&mut encoded, 1).unwrap();
        rmp::encode::write_str(&mut encoded, "kind").unwrap();
        rmp::encode::write_str(&mut encoded, "heartbeat").unwrap();
        encoded
    }

    /// A byte stream that delivers all `frames` back to back in one chunk.
    fn single_chunk(frames: Vec<Vec<u8>>) -> impl Stream<Item = io::Result<Bytes>> + Send + 'static {
        let payload = frames.concat();
        futures::stream::once(async move { Ok::<Bytes, io::Error>(Bytes::from(payload)) })
    }

    /// Yields to the scheduler until `done()` holds; panics with `failure` if
    /// that takes longer than one second.
    async fn wait_until(done: impl Fn() -> bool, failure: &str) {
        timeout(Duration::from_secs(1), async {
            while !done() {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect(failure);
    }

    #[test]
    fn decoder_needs_more_then_completes() {
        let encoded = open_bytes();
        let (head, last_byte) = encoded.split_at(encoded.len() - 1);
        let mut decoder = MsgpackFrameDecoder::default();

        let mut buf = BytesMut::from(head);
        assert!(decoder.decode(&mut buf).unwrap().is_none(), "truncated");

        buf.extend_from_slice(last_byte);
        let frame = decoder.decode(&mut buf).unwrap().expect("complete frame");
        assert!(matches!(frame, SubscriptionFrame::Open { .. }));
        assert!(buf.is_empty());
    }

    #[test]
    fn decoder_rejects_non_map_frame() {
        // A complete msgpack value that is not a frame map is a protocol error.
        let mut encoded = Vec::new();
        rmp::encode::write_uint(&mut encoded, 5).unwrap();
        let mut buf = BytesMut::from(encoded.as_slice());
        let outcome = MsgpackFrameDecoder::default().decode(&mut buf);
        assert!(outcome.is_err());
    }

    #[test]
    fn decoder_rejects_oversize_frame() {
        // str8 header declaring 100 bytes; an incomplete frame past the cap errors.
        let mut decoder = MsgpackFrameDecoder { max_frame_size: 3 };
        let mut buf = BytesMut::from(&[0xd9u8, 100, 0x00, 0x00][..]);
        assert!(decoder.decode(&mut buf).is_err());
    }

    #[tokio::test]
    async fn subscription_surfaces_unknown_frame_kind_error() {
        let stream = single_chunk(vec![open_bytes(), unknown_bytes()]);

        let mut sub = Subscription::from_byte_stream(stream).await.unwrap();
        assert_eq!(sub.tx_key(), *SAMPLE_TX_KEY);

        let err = sub.next().await.expect("an item").unwrap_err();
        assert!(err
            .to_string()
            .contains("unknown subscription frame kind: heartbeat"));
        assert!(sub.next().await.is_none(), "done after error");
    }

    #[tokio::test]
    async fn subscription_surfaces_error_frame() {
        let stream = single_chunk(vec![open_bytes(), error_bytes()]);

        let mut sub = Subscription::from_byte_stream(stream).await.unwrap();
        let err = sub.next().await.expect("an item").unwrap_err();
        assert!(err.to_string().contains("4000"));
        assert!(sub.next().await.is_none(), "done after error");
    }

    #[tokio::test]
    async fn subscription_preserves_queued_delta_before_error_frame() {
        let stream = single_chunk(vec![open_bytes(), delta_bytes("Alice"), error_bytes()]);

        let mut sub = Subscription::from_byte_stream(stream).await.unwrap();
        let delta = sub.next().await.expect("first item").unwrap();
        assert_eq!(
            delta.rows,
            vec![(vec![DataType::String("Alice".to_string())], 1)]
        );

        let err = sub.next().await.expect("second item").unwrap_err();
        assert!(err.to_string().contains("4000"));
        assert!(sub.next().await.is_none(), "done after error");
    }

    /// Always-ready stream that counts how many chunks have been pulled from it.
    struct CountingByteStream {
        chunks: VecDeque<Bytes>,
        yielded: Arc<AtomicUsize>,
    }

    impl Stream for CountingByteStream {
        type Item = io::Result<Bytes>;

        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let next = self.chunks.pop_front();
            if next.is_some() {
                self.yielded.fetch_add(1, Ordering::SeqCst);
            }
            Poll::Ready(next.map(Ok))
        }
    }

    #[tokio::test]
    async fn subscription_reads_ahead_until_delta_queue_is_full() {
        let yielded = Arc::new(AtomicUsize::new(0));
        let chunks: VecDeque<Bytes> = std::iter::once(Bytes::from(open_bytes()))
            .chain(
                (0..SUBSCRIPTION_QUEUE_CAPACITY + 3)
                    .map(|idx| Bytes::from(delta_bytes(&format!("Alice {idx}")))),
            )
            .collect();
        let stream = CountingByteStream {
            chunks,
            yielded: yielded.clone(),
        };

        let mut sub = Subscription::from_byte_stream(stream).await.unwrap();
        // The open frame, a full queue of deltas, and one more delta held by the
        // reader while its send is blocked.
        let blocked_at = SUBSCRIPTION_QUEUE_CAPACITY + 2;

        wait_until(
            || yielded.load(Ordering::SeqCst) >= blocked_at,
            "reader should fill the bounded delta queue",
        )
        .await;

        sleep(Duration::from_millis(25)).await;
        assert_eq!(
            yielded.load(Ordering::SeqCst),
            blocked_at,
            "reader should stop pulling once the delta queue is full"
        );

        let delta = sub.next().await.expect("buffered delta").unwrap();
        assert_eq!(
            delta.rows,
            vec![(vec![DataType::String("Alice 0".to_string())], 1)]
        );
        wait_until(
            || yielded.load(Ordering::SeqCst) >= blocked_at + 1,
            "draining one delta should let the reader pull one more frame",
        )
        .await;
    }

    /// Stream that stays pending forever once its chunks run out, recording both
    /// that it was polled while empty and that it was dropped.
    struct DropNotifyStream {
        chunks: VecDeque<Bytes>,
        polled_pending: Arc<AtomicBool>,
        dropped: Arc<AtomicBool>,
    }

    impl Stream for DropNotifyStream {
        type Item = io::Result<Bytes>;

        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let Some(chunk) = self.chunks.pop_front() else {
                self.polled_pending.store(true, Ordering::SeqCst);
                return Poll::Pending;
            };
            Poll::Ready(Some(Ok(chunk)))
        }
    }

    impl Drop for DropNotifyStream {
        fn drop(&mut self) {
            self.dropped.store(true, Ordering::SeqCst);
        }
    }

    #[tokio::test]
    async fn dropping_subscription_aborts_reader_and_drops_stream() {
        let polled_pending = Arc::new(AtomicBool::new(false));
        let dropped = Arc::new(AtomicBool::new(false));
        let stream = DropNotifyStream {
            chunks: VecDeque::from([Bytes::from(open_bytes())]),
            polled_pending: polled_pending.clone(),
            dropped: dropped.clone(),
        };

        let sub = Subscription::from_byte_stream(stream).await.unwrap();
        wait_until(
            || polled_pending.load(Ordering::SeqCst),
            "reader should poll the upstream stream",
        )
        .await;

        drop(sub);

        wait_until(
            || dropped.load(Ordering::SeqCst),
            "dropping subscription should abort the reader task",
        )
        .await;
    }
}