amaters_sdk_rust/streaming.rs
1//! Streaming query support for AmateRS SDK
2//!
3//! Provides [`QueryStream`] — a bounded, cancellable stream of [`Row`]s produced
4//! by a background task. The stream implements [`futures::Stream`] so callers
5//! can use standard combinators (`map`, `filter`, `collect`, …).
6//!
7//! # Design
8//!
9//! * **Backpressure** — The producer writes into a bounded
10//! [`tokio::sync::mpsc`] channel whose capacity is set by
11//! [`StreamConfig::buffer_size`]. The producer is forced to `await` once the
12//! channel is full, which naturally throttles generation rate to consumption
13//! rate.
14//!
15//! * **Cancellation** — A [`tokio_util::sync::CancellationToken`] is shared
16//! between the consumer and the background producer task. Dropping the
17//! [`QueryStream`] cancels the token, and the producer checks it before every
18//! `send`.
19//!
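//! * **Caller-spawned** — [`QueryStream::new`] only creates the channel and
//! the cancellation token; the caller spawns the producer task. Once running,
//! the producer can buffer at most [`StreamConfig::buffer_size`] rows ahead of
//! the consumer before it has to wait.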
22//!
23//! The [`spawn_stub_producer`] helper is retained as a `#[doc(hidden)]`
24//! test utility; production code uses the real `execute_stream` RPC.
25
26use crate::error::SdkError;
27use futures::Stream;
28use std::pin::Pin;
29use std::sync::Arc;
30use std::sync::atomic::{AtomicUsize, Ordering};
31use std::task::{Context, Poll};
32use std::time::Duration;
33use tokio::sync::mpsc;
34use tokio_util::sync::CancellationToken;
35
36// ---------------------------------------------------------------------------
37// Public types
38// ---------------------------------------------------------------------------
39
/// Tuning knobs for a [`QueryStream`]: how many items may be buffered and an
/// optional time limit.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Capacity of the bounded channel between producer and consumer, i.e.
    /// the maximum number of un-consumed [`Row`]s held in memory. Once the
    /// channel is full the producer has to wait until the consumer takes at
    /// least one item. The default is **64**.
    pub buffer_size: usize,

    /// Optional per-stream time limit, in seconds, after which the producer
    /// is expected to stop even if rows remain unread. `None` disables the
    /// limit.
    pub timeout_secs: Option<u64>,
}

impl Default for StreamConfig {
    /// A 64-slot buffer with no timeout.
    fn default() -> Self {
        Self::new(64)
    }
}

impl StreamConfig {
    /// Build a config with the given buffer size and no timeout.
    pub fn new(buffer_size: usize) -> Self {
        let timeout_secs = None;
        Self {
            buffer_size,
            timeout_secs,
        }
    }

    /// Return this config with the timeout set to `secs` seconds.
    #[must_use]
    pub fn with_timeout(self, secs: u64) -> Self {
        Self {
            timeout_secs: Some(secs),
            ..self
        }
    }
}
81
82// ---------------------------------------------------------------------------
83// Row
84// ---------------------------------------------------------------------------
85
/// One key/value pair yielded by a streaming query.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Row {
    /// Raw bytes of the row key.
    pub key: Vec<u8>,
    /// Raw bytes of the row value (possibly ciphertext).
    pub value: Vec<u8>,
}

impl Row {
    /// Build a row from anything convertible into owned key and value bytes.
    pub fn new(key: impl Into<Vec<u8>>, value: impl Into<Vec<u8>>) -> Self {
        let (key, value) = (key.into(), value.into());
        Self { key, value }
    }
}
104
105// ---------------------------------------------------------------------------
106// RowSender — handle given to the producer task
107// ---------------------------------------------------------------------------
108
109/// A handle used by the background producer task to send rows to the
110/// consumer and to detect cancellation.
111pub struct RowSender {
112 tx: mpsc::Sender<Result<Row, SdkError>>,
113 cancel: CancellationToken,
114 /// Tracks the total number of rows sent, exposed in tests.
115 pub sent: Arc<AtomicUsize>,
116}
117
118impl RowSender {
119 /// Send a row, blocking until there is capacity in the channel.
120 ///
121 /// Returns `false` if the stream was cancelled or the receiver was
122 /// dropped (either means the producer should stop).
123 pub async fn send_row(&self, row: Row) -> bool {
124 if self.cancel.is_cancelled() {
125 return false;
126 }
127 self.sent.fetch_add(1, Ordering::Relaxed);
128 self.tx.send(Ok(row)).await.is_ok()
129 }
130
131 /// Returns `true` if the stream has been cancelled.
132 pub fn is_cancelled(&self) -> bool {
133 self.cancel.is_cancelled()
134 }
135
136 /// Return a clone of the cancellation token so that producers can
137 /// `select!` against it while awaiting a slow I/O operation.
138 pub fn cancel_token(&self) -> CancellationToken {
139 self.cancel.clone()
140 }
141
142 /// Send an error to the consumer.
143 ///
144 /// Returns `false` if the stream was cancelled or the receiver was
145 /// dropped.
146 pub async fn send_error(&self, err: SdkError) -> bool {
147 if self.cancel.is_cancelled() {
148 return false;
149 }
150 self.tx.send(Err(err)).await.is_ok()
151 }
152}
153
154// ---------------------------------------------------------------------------
155// QueryStream
156// ---------------------------------------------------------------------------
157
158/// A cancellable, backpressure-aware stream of [`Row`]s from a query.
159///
160/// Implements [`futures::Stream`] with `Item = Result<Row, SdkError>`.
161/// Dropping the stream cancels the background producer task.
162pub struct QueryStream {
163 /// Receiving end of the bounded channel.
164 rx: mpsc::Receiver<Result<Row, SdkError>>,
165 /// Token used to signal cancellation to the producer task.
166 cancel: CancellationToken,
167}
168
169impl QueryStream {
170 /// Spawn a background producer task and return the paired consumer stream
171 /// together with a [`RowSender`] for the task to use.
172 ///
173 /// The caller is responsible for spawning a `tokio::task` that uses the
174 /// returned [`RowSender`].
175 pub fn new(config: &StreamConfig) -> (Self, RowSender) {
176 let (tx, rx) = mpsc::channel(config.buffer_size);
177 let cancel = CancellationToken::new();
178 let sent = Arc::new(AtomicUsize::new(0));
179
180 let sender = RowSender {
181 tx,
182 cancel: cancel.clone(),
183 sent,
184 };
185
186 let stream = Self { rx, cancel };
187
188 (stream, sender)
189 }
190
191 /// Cancel the background task immediately.
192 pub fn cancel(&self) {
193 self.cancel.cancel();
194 }
195}
196
197impl Drop for QueryStream {
198 fn drop(&mut self) {
199 // Signal the producer to stop when the stream is dropped.
200 self.cancel.cancel();
201 }
202}
203
204impl Stream for QueryStream {
205 type Item = Result<Row, SdkError>;
206
207 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
208 self.rx.poll_recv(cx)
209 }
210}
211
212// ---------------------------------------------------------------------------
213// Stub generator — used by client::stream_query
214// ---------------------------------------------------------------------------
215
216/// Spawn a simulated producer that generates `total_rows` mock rows derived
217/// from the query collection name.
218///
219/// This is a **test-only** helper used to exercise streaming infrastructure
220/// (backpressure, cancellation, etc.) without a live server. Production code
221/// uses the real `execute_stream` gRPC RPC instead.
222#[doc(hidden)]
223pub fn spawn_stub_producer(
224 query_collection: String,
225 total_rows: usize,
226 sender: RowSender,
227 timeout_secs: Option<u64>,
228) -> tokio::task::JoinHandle<()> {
229 tokio::spawn(async move {
230 let deadline = timeout_secs.map(|s| tokio::time::Instant::now() + Duration::from_secs(s));
231
232 for i in 0..total_rows {
233 // Check cancellation before each send.
234 if sender.is_cancelled() {
235 break;
236 }
237
238 // Honour optional deadline.
239 if let Some(dl) = deadline {
240 if tokio::time::Instant::now() >= dl {
241 break;
242 }
243 }
244
245 let key =
246 format!("{collection}:row:{i}", collection = query_collection, i = i).into_bytes();
247 let value = (i as u64).to_le_bytes().to_vec();
248 let row = Row::new(key, value);
249
250 if !sender.send_row(row).await {
251 break;
252 }
253 }
254 // Task finishes cleanly; channel closes when sender is dropped.
255 })
256}
257
258// ---------------------------------------------------------------------------
259// Tests
260// ---------------------------------------------------------------------------
261
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    // Purely synchronous checks: no async runtime needed.
    #[test]
    fn test_stream_config_defaults() {
        let cfg = StreamConfig::default();
        assert_eq!(cfg.buffer_size, 64);
        assert!(cfg.timeout_secs.is_none());
    }

    #[test]
    fn test_row_construction() {
        let row = Row::new(b"key".to_vec(), b"value".to_vec());
        assert_eq!(row.key, b"key");
        assert_eq!(row.value, b"value");
    }

    /// The stub producer's rows arrive complete, in order, and with the
    /// expected key/value encoding.
    #[tokio::test]
    async fn test_stream_collects_rows() {
        let config = StreamConfig::new(16);
        let (stream, sender) = QueryStream::new(&config);

        let handle = spawn_stub_producer("test".to_string(), 5, sender, None);

        let rows: Vec<_> = stream.collect().await;
        assert_eq!(rows.len(), 5);
        for (i, item) in rows.into_iter().enumerate() {
            match item {
                Ok(row) => {
                    assert_eq!(row.key, format!("test:row:{i}").into_bytes());
                    assert_eq!(row.value, (i as u64).to_le_bytes().to_vec());
                }
                Err(_) => panic!("row {i} was an error"),
            }
        }

        assert!(handle.await.is_ok(), "producer task panicked");
    }

    /// Dropping the stream must make the producer task finish promptly.
    #[tokio::test]
    async fn test_stream_cancellation_stops_producer() {
        use tokio::time::{Duration, timeout};

        let config = StreamConfig::new(4);
        let (mut stream, sender) = QueryStream::new(&config);
        let handle = spawn_stub_producer("cancel_test".to_string(), 1_000, sender, None);

        // Drop after receiving just 2 rows.
        let _ = stream.next().await;
        let _ = stream.next().await;
        drop(stream); // triggers CancellationToken::cancel()

        // Awaiting the join handle under a timeout replaces a manual poll loop.
        match timeout(Duration::from_secs(1), handle).await {
            Ok(joined) => assert!(joined.is_ok(), "producer task panicked"),
            Err(_) => {
                panic!("producer task did not stop within 1 second after stream was dropped")
            }
        }
    }
}
328}