Skip to main content

amaters_sdk_rust/
streaming.rs

1//! Streaming query support for AmateRS SDK
2//!
3//! Provides [`QueryStream`] — a bounded, cancellable stream of [`Row`]s produced
4//! by a background task.  The stream implements [`futures::Stream`] so callers
5//! can use standard combinators (`map`, `filter`, `collect`, …).
6//!
7//! # Design
8//!
9//! * **Backpressure** — The producer writes into a bounded
10//!   [`tokio::sync::mpsc`] channel whose capacity is set by
11//!   [`StreamConfig::buffer_size`].  The producer is forced to `await` once the
12//!   channel is full, which naturally throttles generation rate to consumption
13//!   rate.
14//!
15//! * **Cancellation** — A [`tokio_util::sync::CancellationToken`] is shared
16//!   between the consumer and the background producer task.  Dropping the
17//!   [`QueryStream`] cancels the token, and the producer checks it before every
18//!   `send`.
19//!
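//! * **Producer-driven** — [`QueryStream::new`] does not spawn anything; the
//!   caller spawns the producer task with the returned [`RowSender`].  Once
//!   spawned, the producer runs ahead of the consumer until the channel holds
//!   [`StreamConfig::buffer_size`] rows, and then waits for the stream to be
//!   polled.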
22//!
23//! The [`spawn_stub_producer`] helper is retained as a `#[doc(hidden)]`
24//! test utility; production code uses the real `execute_stream` RPC.
25
26use crate::error::SdkError;
27use futures::Stream;
28use std::pin::Pin;
29use std::sync::Arc;
30use std::sync::atomic::{AtomicUsize, Ordering};
31use std::task::{Context, Poll};
32use std::time::Duration;
33use tokio::sync::mpsc;
34use tokio_util::sync::CancellationToken;
35
36// ---------------------------------------------------------------------------
37// Public types
38// ---------------------------------------------------------------------------
39
/// Configuration controlling backpressure and optional timeout for a
/// [`QueryStream`].
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Channel capacity — the maximum number of un-consumed [`Row`]s that
    /// can be buffered in memory.  When the channel is full the producer
    /// waits until the consumer drains at least one item.  Should be at
    /// least `1`: tokio's bounded `mpsc` channel cannot be created with a
    /// capacity of zero.  Defaults to **64**.
    pub buffer_size: usize,

    /// Optional per-stream timeout in seconds; `None` means no timeout.
    ///
    /// [`QueryStream::new`] does not enforce this value itself.  It is up to
    /// the producer task to honour it, for example by forwarding it to the
    /// producer it spawns (as the stub producer's `timeout_secs` argument
    /// does).
    pub timeout_secs: Option<u64>,
}
55
56impl Default for StreamConfig {
57    fn default() -> Self {
58        Self {
59            buffer_size: 64,
60            timeout_secs: None,
61        }
62    }
63}
64
65impl StreamConfig {
66    /// Create a new config with the given buffer size.
67    pub fn new(buffer_size: usize) -> Self {
68        Self {
69            buffer_size,
70            timeout_secs: None,
71        }
72    }
73
74    /// Set an optional timeout (seconds) after which the stream is cancelled.
75    #[must_use]
76    pub fn with_timeout(mut self, secs: u64) -> Self {
77        self.timeout_secs = Some(secs);
78        self
79    }
80}
81
82// ---------------------------------------------------------------------------
83// Row
84// ---------------------------------------------------------------------------
85
/// One key/value pair produced by a streaming query.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Row {
    /// Key bytes, exactly as returned.
    pub key: Vec<u8>,
    /// Value bytes; these may be ciphertext rather than plaintext.
    pub value: Vec<u8>,
}
94
95impl Row {
96    /// Create a new row from raw key and value bytes.
97    pub fn new(key: impl Into<Vec<u8>>, value: impl Into<Vec<u8>>) -> Self {
98        Self {
99            key: key.into(),
100            value: value.into(),
101        }
102    }
103}
104
105// ---------------------------------------------------------------------------
106// RowSender — handle given to the producer task
107// ---------------------------------------------------------------------------
108
109/// A handle used by the background producer task to send rows to the
110/// consumer and to detect cancellation.
111pub struct RowSender {
112    tx: mpsc::Sender<Result<Row, SdkError>>,
113    cancel: CancellationToken,
114    /// Tracks the total number of rows sent, exposed in tests.
115    pub sent: Arc<AtomicUsize>,
116}
117
118impl RowSender {
119    /// Send a row, blocking until there is capacity in the channel.
120    ///
121    /// Returns `false` if the stream was cancelled or the receiver was
122    /// dropped (either means the producer should stop).
123    pub async fn send_row(&self, row: Row) -> bool {
124        if self.cancel.is_cancelled() {
125            return false;
126        }
127        self.sent.fetch_add(1, Ordering::Relaxed);
128        self.tx.send(Ok(row)).await.is_ok()
129    }
130
131    /// Returns `true` if the stream has been cancelled.
132    pub fn is_cancelled(&self) -> bool {
133        self.cancel.is_cancelled()
134    }
135
136    /// Return a clone of the cancellation token so that producers can
137    /// `select!` against it while awaiting a slow I/O operation.
138    pub fn cancel_token(&self) -> CancellationToken {
139        self.cancel.clone()
140    }
141
142    /// Send an error to the consumer.
143    ///
144    /// Returns `false` if the stream was cancelled or the receiver was
145    /// dropped.
146    pub async fn send_error(&self, err: SdkError) -> bool {
147        if self.cancel.is_cancelled() {
148            return false;
149        }
150        self.tx.send(Err(err)).await.is_ok()
151    }
152}
153
154// ---------------------------------------------------------------------------
155// QueryStream
156// ---------------------------------------------------------------------------
157
158/// A cancellable, backpressure-aware stream of [`Row`]s from a query.
159///
160/// Implements [`futures::Stream`] with `Item = Result<Row, SdkError>`.
161/// Dropping the stream cancels the background producer task.
162pub struct QueryStream {
163    /// Receiving end of the bounded channel.
164    rx: mpsc::Receiver<Result<Row, SdkError>>,
165    /// Token used to signal cancellation to the producer task.
166    cancel: CancellationToken,
167}
168
169impl QueryStream {
170    /// Spawn a background producer task and return the paired consumer stream
171    /// together with a [`RowSender`] for the task to use.
172    ///
173    /// The caller is responsible for spawning a `tokio::task` that uses the
174    /// returned [`RowSender`].
175    pub fn new(config: &StreamConfig) -> (Self, RowSender) {
176        let (tx, rx) = mpsc::channel(config.buffer_size);
177        let cancel = CancellationToken::new();
178        let sent = Arc::new(AtomicUsize::new(0));
179
180        let sender = RowSender {
181            tx,
182            cancel: cancel.clone(),
183            sent,
184        };
185
186        let stream = Self { rx, cancel };
187
188        (stream, sender)
189    }
190
191    /// Cancel the background task immediately.
192    pub fn cancel(&self) {
193        self.cancel.cancel();
194    }
195}
196
197impl Drop for QueryStream {
198    fn drop(&mut self) {
199        // Signal the producer to stop when the stream is dropped.
200        self.cancel.cancel();
201    }
202}
203
204impl Stream for QueryStream {
205    type Item = Result<Row, SdkError>;
206
207    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
208        self.rx.poll_recv(cx)
209    }
210}
211
212// ---------------------------------------------------------------------------
213// Stub generator — used by client::stream_query
214// ---------------------------------------------------------------------------
215
216/// Spawn a simulated producer that generates `total_rows` mock rows derived
217/// from the query collection name.
218///
219/// This is a **test-only** helper used to exercise streaming infrastructure
220/// (backpressure, cancellation, etc.) without a live server.  Production code
221/// uses the real `execute_stream` gRPC RPC instead.
222#[doc(hidden)]
223pub fn spawn_stub_producer(
224    query_collection: String,
225    total_rows: usize,
226    sender: RowSender,
227    timeout_secs: Option<u64>,
228) -> tokio::task::JoinHandle<()> {
229    tokio::spawn(async move {
230        let deadline = timeout_secs.map(|s| tokio::time::Instant::now() + Duration::from_secs(s));
231
232        for i in 0..total_rows {
233            // Check cancellation before each send.
234            if sender.is_cancelled() {
235                break;
236            }
237
238            // Honour optional deadline.
239            if let Some(dl) = deadline {
240                if tokio::time::Instant::now() >= dl {
241                    break;
242                }
243            }
244
245            let key =
246                format!("{collection}:row:{i}", collection = query_collection, i = i).into_bytes();
247            let value = (i as u64).to_le_bytes().to_vec();
248            let row = Row::new(key, value);
249
250            if !sender.send_row(row).await {
251                break;
252            }
253        }
254        // Task finishes cleanly; channel closes when sender is dropped.
255    })
256}
257
258// ---------------------------------------------------------------------------
259// Tests
260// ---------------------------------------------------------------------------
261
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_stream_config_defaults() {
        let cfg = StreamConfig::default();
        assert!(cfg.timeout_secs.is_none());
        assert_eq!(cfg.buffer_size, 64);
    }

    #[tokio::test]
    async fn test_row_construction() {
        let row = Row::new(b"key".to_vec(), b"value".to_vec());
        assert_eq!(
            row,
            Row {
                key: b"key".to_vec(),
                value: b"value".to_vec(),
            }
        );
    }

    #[tokio::test]
    async fn test_stream_collects_rows() {
        let (stream, sender) = QueryStream::new(&StreamConfig::new(16));
        let _producer = spawn_stub_producer("test".to_string(), 5, sender, None);

        let rows: Vec<Result<Row, SdkError>> = stream.collect().await;
        assert_eq!(rows.len(), 5);
        assert!(rows.iter().all(Result::is_ok));
    }

    #[tokio::test]
    async fn test_stream_cancellation_stops_producer() {
        let (mut stream, sender) = QueryStream::new(&StreamConfig::new(4));
        let producer = spawn_stub_producer("cancel_test".to_string(), 1_000, sender, None);

        // Read just two rows, then drop the stream to cancel the producer.
        let _ = stream.next().await;
        let _ = stream.next().await;
        drop(stream);

        // The producer task must finish (cleanly or not) within one second.
        let finished = tokio::time::timeout(Duration::from_secs(1), producer).await;
        assert!(
            finished.is_ok(),
            "producer task did not stop within 1 second after stream was dropped"
        );
    }
}