jetstream_extra/
batch_fetch.rs

1// Copyright 2025 Synadia Communications Inc.
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! Batch fetch operations for JetStream streams.
15//!
16//! This module provides efficient batch fetching of messages from JetStream streams
17//! using the DIRECT.GET API as specified in ADR-31.
18//!
19//! # Examples
20//!
21//! ## Fetch a batch of messages
22//!
23//! ```no_run
24//! # use jetstream_extra::batch_fetch::BatchFetchExt;
25//! # use futures::StreamExt;
26//! # async fn example(context: async_nats::jetstream::Context) -> Result<(), Box<dyn std::error::Error>> {
27//! use jetstream_extra::batch_fetch::BatchFetchExt;
28//!
29//! // Fetch 100 messages starting from sequence 1
30//! let mut messages = context
31//!     .get_batch("my_stream", 100)
32//!     .send()
33//!     .await?;
34//!
35//! while let Some(msg) = messages.next().await {
36//!     let msg = msg?;
37//!     println!("Message at seq {}: {:?}", msg.sequence, msg.payload);
38//! }
39//! # Ok(())
40//! # }
41//! ```
42//!
43//! ## Get last messages for multiple subjects
44//!
45//! ```no_run
46//! # use jetstream_extra::batch_fetch::BatchFetchExt;
47//! # use futures::StreamExt;
48//! # async fn example(context: async_nats::jetstream::Context) -> Result<(), Box<dyn std::error::Error>> {
49//! use jetstream_extra::batch_fetch::BatchFetchExt;
50//!
51//! // Get the last message for each subject
52//! let subjects = vec!["events.user.1".to_string(), "events.user.2".to_string()];
53//! let mut messages = context
54//!     .get_last_messages_for("my_stream")
55//!     .subjects(subjects)
56//!     .send()
57//!     .await?;
58//!
59//! while let Some(msg) = messages.next().await {
60//!     let msg = msg?;
61//!     println!("Last message for {}: {:?}", msg.subject, msg.payload);
62//! }
63//! # Ok(())
64//! # }
65//! ```
66
67use std::marker::PhantomData;
68use std::pin::Pin;
69use std::task::{Context, Poll};
70
71use crate::StreamMessage;
72use async_nats::jetstream::context::traits::{ClientProvider, RequestSender, TimeoutProvider};
73use async_nats::{Message, Subject, Subscriber};
74use bytes::Bytes;
75use futures::{FutureExt, Stream, StreamExt};
76use serde::Serialize;
77use time::OffsetDateTime;
78use time::serde::rfc3339;
79use tracing::debug;
80
// Typestate markers used by the builders below to make the sequence-based and
// time-based options mutually exclusive at compile time. They carry no data and
// are only ever used inside `PhantomData`.

/// Typestate marker: no (start / up-to) sequence has been set on the builder.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoSeq;

/// Typestate marker: a (start / up-to) sequence has been set on the builder.
#[derive(Debug, Clone, Copy, Default)]
pub struct WithSeq;

/// Typestate marker: no (start / up-to) time has been set on the builder.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoTime;

/// Typestate marker: a (start / up-to) time has been set on the builder.
#[derive(Debug, Clone, Copy, Default)]
pub struct WithTime;
86
87/// Builder for batch fetching messages from a stream.
88pub struct GetBatchBuilder<T, SEQ = NoSeq, TIME = NoTime>
89where
90    T: ClientProvider + TimeoutProvider + RequestSender,
91{
92    context: T,
93    stream: String,
94    batch: usize,
95    seq: Option<u64>,
96    subject: Option<String>,
97    max_bytes: Option<usize>,
98    start_time: Option<OffsetDateTime>,
99    _phantom: PhantomData<(SEQ, TIME)>,
100}
101
102/// Builder for fetching last messages for multiple subjects.
103pub struct GetLastBuilder<T, SEQ = NoSeq, TIME = NoTime>
104where
105    T: ClientProvider + TimeoutProvider + RequestSender,
106{
107    context: T,
108    stream: String,
109    subjects: Option<Vec<String>>,
110    up_to_seq: Option<u64>,
111    up_to_time: Option<OffsetDateTime>,
112    batch: Option<usize>,
113    _phantom: PhantomData<(SEQ, TIME)>,
114}
115
116/// Extension trait for batch fetching messages from JetStream streams.
117pub trait BatchFetchExt: ClientProvider + TimeoutProvider + RequestSender + Clone {
118    /// Create a builder for fetching a batch of messages from a stream.
119    ///
120    /// Uses the DIRECT.GET API to efficiently retrieve multiple messages
121    /// in a single request. The server sends messages without flow control
122    /// up to the specified batch size or max_bytes limit.
123    fn get_batch(&self, stream: &str, batch: usize) -> GetBatchBuilder<Self, NoSeq, NoTime>;
124
125    /// Create a builder for fetching the last messages for multiple subjects.
126    ///
127    /// Retrieves the most recent message for each of the specified subjects
128    /// from the stream. Supports consistent point-in-time reads across
129    /// multiple subjects using `up_to_seq` or `up_to_time` options.
130    fn get_last_messages_for(&self, stream: &str) -> GetLastBuilder<Self, NoSeq, NoTime>;
131}
132
133/// Error type for batch fetch operations.
134pub type BatchFetchError = async_nats::error::Error<BatchFetchErrorKind>;
135
/// Kinds of errors that can occur during batch fetch operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BatchFetchErrorKind {
    /// The server does not support batch get operations.
    UnsupportedByServer,
    /// The server reported that no message matched the request (status 404).
    NoMessages,
    /// A server reply could not be interpreted (for example it had no headers).
    InvalidResponse,
    /// The request body could not be serialized to JSON.
    Serialization,
    /// Subscribing to the reply inbox failed.
    Subscription,
    /// Publishing the request failed.
    Publish,
    /// A required header was absent from a server reply.
    MissingHeader,
    /// A header value could not be parsed.
    InvalidHeader,
    /// The server rejected the request as invalid (status 408).
    InvalidRequest,
    /// More than 1024 subjects were requested.
    TooManySubjects,
    /// The batch size exceeds the limit of 1000.
    BatchSizeTooLarge,
    /// A non-zero batch size is required.
    BatchSizeRequired,
    /// No subjects were supplied for a multi-last request.
    SubjectsRequired,
    /// The stream name is empty or otherwise unusable.
    InvalidStreamName,
    /// An option has an invalid value (for example a zero sequence).
    InvalidOption,
    /// The batch did not complete before the configured timeout elapsed.
    TimedOut,
    /// Any other batch fetch failure.
    Other,
}

impl std::fmt::Display for BatchFetchErrorKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnsupportedByServer => write!(f, "batch get not supported by server"),
            Self::NoMessages => write!(f, "no messages found"),
            Self::InvalidResponse => write!(f, "invalid response from server"),
            Self::Serialization => write!(f, "serialization error"),
            Self::Subscription => write!(f, "subscription error"),
            Self::Publish => write!(f, "publish error"),
            Self::MissingHeader => write!(f, "missing required header"),
            Self::InvalidHeader => write!(f, "invalid header value"),
            Self::InvalidRequest => write!(f, "invalid request parameters"),
            Self::TooManySubjects => write!(f, "too many subjects (max 1024)"),
            Self::BatchSizeTooLarge => write!(f, "batch size too large (max 1000)"),
            Self::BatchSizeRequired => write!(f, "batch size is required"),
            Self::SubjectsRequired => write!(f, "subjects are required for multi_last"),
            Self::InvalidStreamName => write!(f, "invalid stream name"),
            Self::InvalidOption => write!(f, "invalid option"),
            Self::TimedOut => write!(f, "batch fetch operation timed out"),
            Self::Other => write!(f, "batch fetch error"),
        }
    }
}
182
183/// Request for batch get operations
184#[derive(Debug, Serialize)]
185struct GetBatchRequest {
186    #[serde(skip_serializing_if = "Option::is_none")]
187    seq: Option<u64>,
188    #[serde(skip_serializing_if = "Option::is_none")]
189    next_by_subj: Option<String>,
190    batch: usize,
191    #[serde(skip_serializing_if = "Option::is_none")]
192    max_bytes: Option<usize>,
193    #[serde(skip_serializing_if = "Option::is_none", with = "rfc3339::option")]
194    start_time: Option<time::OffsetDateTime>,
195}
196
197/// Request for multi-last get operations
198#[derive(Debug, Serialize)]
199struct GetLastRequest {
200    multi_last: Vec<String>,
201    #[serde(skip_serializing_if = "Option::is_none")]
202    batch: Option<usize>,
203    #[serde(skip_serializing_if = "Option::is_none")]
204    up_to_seq: Option<u64>,
205    #[serde(skip_serializing_if = "Option::is_none", with = "rfc3339::option")]
206    up_to_time: Option<time::OffsetDateTime>,
207}
208
209/// Stream of messages from batch fetch operations.
210pub struct BatchStream {
211    subscriber: Subscriber,
212    timeout: std::time::Duration,
213    timeout_at: Option<Pin<Box<tokio::time::Sleep>>>,
214    terminated: bool,
215}
216
217impl Stream for BatchStream {
218    type Item = Result<StreamMessage, BatchFetchError>;
219
220    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
221        if self.terminated {
222            return Poll::Ready(None);
223        }
224
225        let timeout = self.timeout;
226        match self
227            .timeout_at
228            .get_or_insert_with(|| Box::pin(tokio::time::sleep(timeout)))
229            .poll_unpin(cx)
230        {
231            Poll::Ready(_) => {
232                debug!("Batch fetch operation timed out after {:?}", timeout);
233                self.terminated = true;
234                return Poll::Ready(Some(Err(BatchFetchError::new(
235                    BatchFetchErrorKind::TimedOut,
236                ))));
237            }
238            Poll::Pending => {}
239        }
240
241        match self.subscriber.next().poll_unpin(cx) {
242            Poll::Ready(Some(msg)) => {
243                // Check for End-Of-Batch marker
244                // EOB can be detected in two ways:
245                // 1. ADR-31 spec: Empty payload with Status: 204, Description: EOB
246                // 2. Current server impl: Empty payload with missing essential headers
247                if msg.payload.is_empty()
248                    && let Some(headers) = &msg.headers
249                {
250                    let status = headers.get("Status").map(|v| v.as_str());
251                    let desc = headers.get("Description").map(|v| v.as_str());
252
253                    // Termination by EOB.
254                    if status == Some("204") && desc == Some("EOB") {
255                        self.terminated = true;
256                        return Poll::Ready(None);
257                    }
258
259                    // Termination by empty message with headers missing.
260                    if headers.is_empty() {
261                        self.terminated = true;
262                        return Poll::Ready(None);
263                    }
264
265                    // Termination by end of batch.
266                    // TODO(jrm): we should consider hinting those to the user.
267                    if headers.get(async_nats::header::NATS_SEQUENCE).is_some()
268                        || headers.get("Nats-Num-Pending").is_some()
269                        || headers.get("Nats-UpTo-Sequence").is_some()
270                    {
271                        self.terminated = true;
272                        return Poll::Ready(None);
273                    }
274                }
275
276                match convert_to_stream_message(msg) {
277                    Ok(raw_msg) => Poll::Ready(Some(Ok(raw_msg))),
278                    Err(e) => Poll::Ready(Some(Err(e))),
279                }
280            }
281            Poll::Ready(None) => {
282                self.terminated = true;
283                Poll::Ready(None)
284            }
285            Poll::Pending => Poll::Pending,
286        }
287    }
288}
289
290impl<T> BatchFetchExt for T
291where
292    T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
293{
294    fn get_batch(&self, stream: &str, batch: usize) -> GetBatchBuilder<Self, NoSeq, NoTime> {
295        GetBatchBuilder {
296            context: self.clone(),
297            stream: stream.to_string(),
298            batch,
299            seq: None,
300            subject: None,
301            max_bytes: None,
302            start_time: None,
303            _phantom: PhantomData,
304        }
305    }
306
307    fn get_last_messages_for(&self, stream: &str) -> GetLastBuilder<Self, NoSeq, NoTime> {
308        GetLastBuilder {
309            context: self.clone(),
310            stream: stream.to_string(),
311            subjects: None,
312            up_to_seq: None,
313            up_to_time: None,
314            batch: None,
315            _phantom: PhantomData,
316        }
317    }
318}
319
320// Implementation for all states - common methods
321impl<T, SEQ, TIME> GetBatchBuilder<T, SEQ, TIME>
322where
323    T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
324{
325    /// Set the subject filter (may include wildcards).
326    pub fn subject<S: Into<String>>(mut self, subject: S) -> Self {
327        self.subject = Some(subject.into());
328        self
329    }
330
331    /// Set the maximum bytes to return.
332    pub fn max_bytes(mut self, max_bytes: usize) -> Self {
333        self.max_bytes = Some(max_bytes);
334        self
335    }
336}
337
338// Methods only available when no time has been set
339impl<T> GetBatchBuilder<T, NoSeq, NoTime>
340where
341    T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
342{
343    /// Set the starting sequence number.
344    /// This is mutually exclusive with `start_time`.
345    pub fn sequence(mut self, seq: u64) -> GetBatchBuilder<T, WithSeq, NoTime> {
346        self.seq = Some(seq);
347        GetBatchBuilder {
348            context: self.context,
349            stream: self.stream,
350            seq: self.seq,
351            batch: self.batch,
352            subject: self.subject,
353            max_bytes: self.max_bytes,
354            start_time: self.start_time,
355            _phantom: PhantomData,
356        }
357    }
358
359    /// Set the start time for time-based fetching.
360    /// This is mutually exclusive with `seq`.
361    pub fn start_time(mut self, start_time: OffsetDateTime) -> GetBatchBuilder<T, NoSeq, WithTime> {
362        self.start_time = Some(start_time);
363        GetBatchBuilder {
364            context: self.context,
365            stream: self.stream,
366            batch: self.batch,
367            seq: self.seq,
368            subject: self.subject,
369            max_bytes: self.max_bytes,
370            start_time: self.start_time,
371            _phantom: PhantomData,
372        }
373    }
374}
375
// Note: `sequence()` and `start_time()` are only implemented for the initial
// `GetBatchBuilder<T, NoSeq, NoTime>` state. Calling either one moves the builder
// to `WithSeq` or `WithTime`, where neither setter exists. Setting both options,
// or setting one twice, is therefore a compile-time error, and those states need
// no methods of their own.
391
392// Send method available for all states
393impl<T, SEQ, TIME> GetBatchBuilder<T, SEQ, TIME>
394where
395    T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
396{
397    /// Send the batch fetch request and return a stream of messages.
398    pub async fn send(self) -> Result<BatchStream, BatchFetchError> {
399        // Validate stream name
400        if self.stream.is_empty() {
401            return Err(BatchFetchError::new(BatchFetchErrorKind::InvalidStreamName));
402        }
403
404        // Validate batch size against server limit
405        if self.batch > 1000 {
406            return Err(BatchFetchError::new(BatchFetchErrorKind::BatchSizeTooLarge));
407        }
408
409        // Validate seq if specified (must be > 0)
410        if let Some(seq) = self.seq
411            && seq == 0
412        {
413            return Err(BatchFetchError::new(BatchFetchErrorKind::InvalidOption));
414        }
415
416        // Validate max_bytes if specified (must be > 0)
417        if let Some(max_bytes) = self.max_bytes
418            && max_bytes == 0
419        {
420            return Err(BatchFetchError::new(BatchFetchErrorKind::InvalidOption));
421        }
422
423        // Build the batch request per ADR-31
424        let request = GetBatchRequest {
425            seq: if self.seq.is_some() {
426                self.seq
427            } else if self.start_time.is_none() {
428                Some(1) // Default to sequence 1 if neither seq nor start_time is specified
429            } else {
430                None
431            },
432            next_by_subj: self.subject,
433            batch: self.batch,
434            max_bytes: self.max_bytes,
435            start_time: self.start_time,
436        };
437
438        let payload = serde_json::to_vec(&request)
439            .map_err(|e| BatchFetchError::with_source(BatchFetchErrorKind::Serialization, e))?
440            .into();
441        // RequestSender will add the proper prefix ($JS.API. or custom)
442        let subject = format!("DIRECT.GET.{}", self.stream);
443
444        send_batch_request(&self.context, subject, payload).await
445    }
446}
447
448// Implementation for all states - common methods
449impl<T, SEQ, TIME> GetLastBuilder<T, SEQ, TIME>
450where
451    T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
452{
453    /// Set the subjects to fetch last messages for.
454    ///
455    /// # Limits
456    /// - Maximum: 1024 subjects per request (server limit)
457    /// - Returns `BatchFetchErrorKind::TooManySubjects` if exceeded
458    pub fn subjects(mut self, subjects: Vec<String>) -> Self {
459        self.subjects = Some(subjects);
460        self
461    }
462
463    /// Set the optional batch size.
464    ///
465    /// # Limits
466    /// - Maximum: 1000 messages per request (server limit)
467    /// - Returns `BatchFetchErrorKind::BatchSizeTooLarge` if exceeded
468    pub fn batch(mut self, batch: usize) -> Self {
469        self.batch = Some(batch);
470        self
471    }
472}
473
474// Methods only available when no up_to_seq has been set
475impl<T> GetLastBuilder<T, NoSeq, NoTime>
476where
477    T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
478{
479    /// Set the sequence number to fetch up to (inclusive).
480    /// This is mutually exclusive with `up_to_time`.
481    pub fn up_to_seq(mut self, seq: u64) -> GetLastBuilder<T, WithSeq, NoTime> {
482        self.up_to_seq = Some(seq);
483        GetLastBuilder {
484            context: self.context,
485            stream: self.stream,
486            subjects: self.subjects,
487            up_to_seq: self.up_to_seq,
488            up_to_time: self.up_to_time,
489            batch: self.batch,
490            _phantom: PhantomData,
491        }
492    }
493
494    /// Set the time to fetch up to.
495    /// This is mutually exclusive with `up_to_seq`.
496    pub fn up_to_time(mut self, time: OffsetDateTime) -> GetLastBuilder<T, NoSeq, WithTime> {
497        self.up_to_time = Some(time);
498        GetLastBuilder {
499            context: self.context,
500            stream: self.stream,
501            subjects: self.subjects,
502            up_to_seq: self.up_to_seq,
503            up_to_time: self.up_to_time,
504            batch: self.batch,
505            _phantom: PhantomData,
506        }
507    }
508}
509
// Note: `up_to_seq()` and `up_to_time()` are only implemented for the initial
// `GetLastBuilder<T, NoSeq, NoTime>` state. Calling either one moves the builder
// to `WithSeq` or `WithTime`, where neither setter exists. Combining the two
// bounds is therefore a compile-time error, and those states need no methods of
// their own.
525
526// Send method available for all states
527impl<T, SEQ, TIME> GetLastBuilder<T, SEQ, TIME>
528where
529    T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
530{
531    /// Send the request to get last messages and return a stream.
532    pub async fn send(self) -> Result<BatchStream, BatchFetchError> {
533        if self.stream.is_empty() {
534            return Err(BatchFetchError::new(BatchFetchErrorKind::InvalidStreamName));
535        }
536
537        let subjects = self
538            .subjects
539            .ok_or_else(|| BatchFetchError::new(BatchFetchErrorKind::SubjectsRequired))?;
540
541        // Validate subject count against server limit
542        if subjects.len() > 1024 {
543            return Err(BatchFetchError::new(BatchFetchErrorKind::TooManySubjects));
544        }
545
546        // Validate batch size if specified
547        if let Some(batch) = self.batch {
548            if batch == 0 {
549                return Err(BatchFetchError::new(BatchFetchErrorKind::InvalidOption));
550            }
551            if batch > 1000 {
552                return Err(BatchFetchError::new(BatchFetchErrorKind::BatchSizeTooLarge));
553            }
554        }
555
556        if subjects.is_empty() {
557            return Err(BatchFetchError::new(BatchFetchErrorKind::SubjectsRequired));
558        }
559
560        let request = GetLastRequest {
561            multi_last: subjects,
562            batch: self.batch,
563            up_to_seq: self.up_to_seq,
564            up_to_time: self.up_to_time,
565        };
566
567        let payload = serde_json::to_vec(&request)
568            .map_err(|e| BatchFetchError::with_source(BatchFetchErrorKind::Serialization, e))?
569            .into();
570        let subject = format!("DIRECT.GET.{}", self.stream);
571
572        send_batch_request(&self.context, subject, payload).await
573    }
574}
575
576async fn send_batch_request<T>(
577    context: &T,
578    subject: String,
579    payload: Bytes,
580) -> Result<BatchStream, BatchFetchError>
581where
582    T: ClientProvider + TimeoutProvider + RequestSender,
583{
584    // Create inbox and subscribe to it for responses
585    let client = context.client();
586    let inbox = client.new_inbox();
587    let subscriber = client
588        .subscribe(inbox.clone())
589        .await
590        .map_err(|e| BatchFetchError::with_source(BatchFetchErrorKind::Subscription, e))?;
591
592    let request = async_nats::Request {
593        inbox: Some(inbox),
594        payload: Some(payload),
595        headers: None,
596        timeout: None,
597    };
598    context
599        .send_request(subject, request)
600        .await
601        .map_err(|e| BatchFetchError::with_source(BatchFetchErrorKind::Publish, e))?;
602
603    Ok(BatchStream {
604        subscriber,
605        terminated: false,
606        timeout: context.timeout(),
607        timeout_at: None,
608    })
609}
610
611fn convert_to_stream_message(msg: Message) -> Result<StreamMessage, BatchFetchError> {
612    if msg.payload.is_empty()
613        && let Some(headers) = &msg.headers
614    {
615        let status = headers.get("Status").map(|v| v.as_str());
616        match status {
617            Some("404") => return Err(BatchFetchError::new(BatchFetchErrorKind::NoMessages)),
618            Some("408") => return Err(BatchFetchError::new(BatchFetchErrorKind::InvalidRequest)),
619            Some("413") => return Err(BatchFetchError::new(BatchFetchErrorKind::TooManySubjects)),
620            _ => {}
621        }
622    }
623
624    let headers = msg
625        .headers
626        .ok_or_else(|| BatchFetchError::new(BatchFetchErrorKind::InvalidResponse))?;
627
628    // Check if server supports batch get by looking for Nats-Num-Pending header
629    // Servers that don't support batch get won't include this header
630    if headers.get("Nats-Num-Pending").is_none() {
631        return Err(BatchFetchError::new(
632            BatchFetchErrorKind::UnsupportedByServer,
633        ));
634    }
635
636    let subject = headers
637        .get(async_nats::header::NATS_SUBJECT)
638        .ok_or_else(|| BatchFetchError::new(BatchFetchErrorKind::MissingHeader))?
639        .to_string();
640
641    let sequence = headers
642        .get(async_nats::header::NATS_SEQUENCE)
643        .ok_or_else(|| BatchFetchError::new(BatchFetchErrorKind::MissingHeader))?
644        .as_str()
645        .parse::<u64>()
646        .map_err(|e| BatchFetchError::with_source(BatchFetchErrorKind::InvalidHeader, e))?;
647
648    let time_str = headers
649        .get(async_nats::header::NATS_TIME_STAMP)
650        .ok_or_else(|| BatchFetchError::new(BatchFetchErrorKind::MissingHeader))?
651        .as_str();
652
653    // Parse RFC3339 timestamp
654    let time =
655        time::OffsetDateTime::parse(time_str, &time::format_description::well_known::Rfc3339)
656            .map_err(|e| BatchFetchError::with_source(BatchFetchErrorKind::InvalidHeader, e))?;
657
658    Ok(StreamMessage {
659        subject: Subject::from(subject),
660        sequence,
661        payload: msg.payload,
662        headers,
663        time,
664    })
665}