1use std::marker::PhantomData;
68use std::pin::Pin;
69use std::task::{Context, Poll};
70
71use crate::StreamMessage;
72use async_nats::jetstream::context::traits::{ClientProvider, RequestSender, TimeoutProvider};
73use async_nats::{Message, Subject, Subscriber};
74use bytes::Bytes;
75use futures::{FutureExt, Stream, StreamExt};
76use serde::Serialize;
77use time::OffsetDateTime;
78use time::serde::rfc3339;
79use tracing::debug;
80
/// Type-state marker: no sequence bound has been chosen on a builder yet.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoSeq;
/// Type-state marker: a sequence bound (`sequence` / `up_to_seq`) has been chosen.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct WithSeq;
/// Type-state marker: no time bound has been chosen on a builder yet.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoTime;
/// Type-state marker: a time bound (`start_time` / `up_to_time`) has been chosen.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct WithTime;
86
87pub struct GetBatchBuilder<T, SEQ = NoSeq, TIME = NoTime>
89where
90 T: ClientProvider + TimeoutProvider + RequestSender,
91{
92 context: T,
93 stream: String,
94 batch: usize,
95 seq: Option<u64>,
96 subject: Option<String>,
97 max_bytes: Option<usize>,
98 start_time: Option<OffsetDateTime>,
99 _phantom: PhantomData<(SEQ, TIME)>,
100}
101
102pub struct GetLastBuilder<T, SEQ = NoSeq, TIME = NoTime>
104where
105 T: ClientProvider + TimeoutProvider + RequestSender,
106{
107 context: T,
108 stream: String,
109 subjects: Option<Vec<String>>,
110 up_to_seq: Option<u64>,
111 up_to_time: Option<OffsetDateTime>,
112 batch: Option<usize>,
113 _phantom: PhantomData<(SEQ, TIME)>,
114}
115
116pub trait BatchFetchExt: ClientProvider + TimeoutProvider + RequestSender + Clone {
118 fn get_batch(&self, stream: &str, batch: usize) -> GetBatchBuilder<Self, NoSeq, NoTime>;
124
125 fn get_last_messages_for(&self, stream: &str) -> GetLastBuilder<Self, NoSeq, NoTime>;
131}
132
133pub type BatchFetchError = async_nats::error::Error<BatchFetchErrorKind>;
135
/// Classifies the ways a batch or multi-last direct get can fail.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BatchFetchErrorKind {
    /// A reply lacked the `Nats-Num-Pending` header that batched replies carry.
    UnsupportedByServer,
    /// An empty reply carried status `404`.
    NoMessages,
    /// A reply arrived without any headers.
    InvalidResponse,
    /// The request body could not be serialized to JSON.
    Serialization,
    /// Subscribing to the reply inbox failed.
    Subscription,
    /// Publishing the request failed.
    Publish,
    /// A required header (subject, sequence or timestamp) was missing.
    MissingHeader,
    /// A header was present but could not be parsed.
    InvalidHeader,
    /// An empty reply carried status `408`.
    InvalidRequest,
    /// More than 1024 subjects were requested, or a reply carried status `413`.
    TooManySubjects,
    /// The requested batch size exceeded 1000.
    BatchSizeTooLarge,
    /// A batch size was required but not provided.
    BatchSizeRequired,
    /// A multi-last request had no subjects.
    SubjectsRequired,
    /// The stream name was rejected.
    InvalidStreamName,
    /// An option had an invalid value (for example a zero sequence, batch or `max_bytes`).
    InvalidOption,
    /// The overall fetch deadline elapsed.
    TimedOut,
    /// Any other failure.
    Other,
}

impl std::fmt::Display for BatchFetchErrorKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::UnsupportedByServer => "batch get not supported by server",
            Self::NoMessages => "no messages found",
            Self::InvalidResponse => "invalid response from server",
            Self::Serialization => "serialization error",
            Self::Subscription => "subscription error",
            Self::Publish => "publish error",
            Self::MissingHeader => "missing required header",
            Self::InvalidHeader => "invalid header value",
            Self::InvalidRequest => "invalid request parameters",
            Self::TooManySubjects => "too many subjects (max 1024)",
            Self::BatchSizeTooLarge => "batch size too large (max 1000)",
            Self::BatchSizeRequired => "batch size is required",
            Self::SubjectsRequired => "subjects are required for multi_last",
            Self::InvalidStreamName => "invalid stream name",
            Self::InvalidOption => "invalid option",
            Self::TimedOut => "batch fetch operation timed out",
            Self::Other => "batch fetch error",
        };
        // write_str (like write! with a literal) ignores width/fill flags.
        f.write_str(text)
    }
}
182
183#[derive(Debug, Serialize)]
185struct GetBatchRequest {
186 #[serde(skip_serializing_if = "Option::is_none")]
187 seq: Option<u64>,
188 #[serde(skip_serializing_if = "Option::is_none")]
189 next_by_subj: Option<String>,
190 batch: usize,
191 #[serde(skip_serializing_if = "Option::is_none")]
192 max_bytes: Option<usize>,
193 #[serde(skip_serializing_if = "Option::is_none", with = "rfc3339::option")]
194 start_time: Option<time::OffsetDateTime>,
195}
196
197#[derive(Debug, Serialize)]
199struct GetLastRequest {
200 multi_last: Vec<String>,
201 #[serde(skip_serializing_if = "Option::is_none")]
202 batch: Option<usize>,
203 #[serde(skip_serializing_if = "Option::is_none")]
204 up_to_seq: Option<u64>,
205 #[serde(skip_serializing_if = "Option::is_none", with = "rfc3339::option")]
206 up_to_time: Option<time::OffsetDateTime>,
207}
208
209pub struct BatchStream {
211 subscriber: Subscriber,
212 timeout: std::time::Duration,
213 timeout_at: Option<Pin<Box<tokio::time::Sleep>>>,
214 terminated: bool,
215}
216
217impl Stream for BatchStream {
218 type Item = Result<StreamMessage, BatchFetchError>;
219
220 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
221 if self.terminated {
222 return Poll::Ready(None);
223 }
224
225 let timeout = self.timeout;
226 match self
227 .timeout_at
228 .get_or_insert_with(|| Box::pin(tokio::time::sleep(timeout)))
229 .poll_unpin(cx)
230 {
231 Poll::Ready(_) => {
232 debug!("Batch fetch operation timed out after {:?}", timeout);
233 self.terminated = true;
234 return Poll::Ready(Some(Err(BatchFetchError::new(
235 BatchFetchErrorKind::TimedOut,
236 ))));
237 }
238 Poll::Pending => {}
239 }
240
241 match self.subscriber.next().poll_unpin(cx) {
242 Poll::Ready(Some(msg)) => {
243 if msg.payload.is_empty()
248 && let Some(headers) = &msg.headers
249 {
250 let status = headers.get("Status").map(|v| v.as_str());
251 let desc = headers.get("Description").map(|v| v.as_str());
252
253 if status == Some("204") && desc == Some("EOB") {
255 self.terminated = true;
256 return Poll::Ready(None);
257 }
258
259 if headers.is_empty() {
261 self.terminated = true;
262 return Poll::Ready(None);
263 }
264
265 if headers.get(async_nats::header::NATS_SEQUENCE).is_some()
268 || headers.get("Nats-Num-Pending").is_some()
269 || headers.get("Nats-UpTo-Sequence").is_some()
270 {
271 self.terminated = true;
272 return Poll::Ready(None);
273 }
274 }
275
276 match convert_to_stream_message(msg) {
277 Ok(raw_msg) => Poll::Ready(Some(Ok(raw_msg))),
278 Err(e) => Poll::Ready(Some(Err(e))),
279 }
280 }
281 Poll::Ready(None) => {
282 self.terminated = true;
283 Poll::Ready(None)
284 }
285 Poll::Pending => Poll::Pending,
286 }
287 }
288}
289
290impl<T> BatchFetchExt for T
291where
292 T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
293{
294 fn get_batch(&self, stream: &str, batch: usize) -> GetBatchBuilder<Self, NoSeq, NoTime> {
295 GetBatchBuilder {
296 context: self.clone(),
297 stream: stream.to_string(),
298 batch,
299 seq: None,
300 subject: None,
301 max_bytes: None,
302 start_time: None,
303 _phantom: PhantomData,
304 }
305 }
306
307 fn get_last_messages_for(&self, stream: &str) -> GetLastBuilder<Self, NoSeq, NoTime> {
308 GetLastBuilder {
309 context: self.clone(),
310 stream: stream.to_string(),
311 subjects: None,
312 up_to_seq: None,
313 up_to_time: None,
314 batch: None,
315 _phantom: PhantomData,
316 }
317 }
318}
319
320impl<T, SEQ, TIME> GetBatchBuilder<T, SEQ, TIME>
322where
323 T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
324{
325 pub fn subject<S: Into<String>>(mut self, subject: S) -> Self {
327 self.subject = Some(subject.into());
328 self
329 }
330
331 pub fn max_bytes(mut self, max_bytes: usize) -> Self {
333 self.max_bytes = Some(max_bytes);
334 self
335 }
336}
337
338impl<T> GetBatchBuilder<T, NoSeq, NoTime>
340where
341 T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
342{
343 pub fn sequence(mut self, seq: u64) -> GetBatchBuilder<T, WithSeq, NoTime> {
346 self.seq = Some(seq);
347 GetBatchBuilder {
348 context: self.context,
349 stream: self.stream,
350 seq: self.seq,
351 batch: self.batch,
352 subject: self.subject,
353 max_bytes: self.max_bytes,
354 start_time: self.start_time,
355 _phantom: PhantomData,
356 }
357 }
358
359 pub fn start_time(mut self, start_time: OffsetDateTime) -> GetBatchBuilder<T, NoSeq, WithTime> {
362 self.start_time = Some(start_time);
363 GetBatchBuilder {
364 context: self.context,
365 stream: self.stream,
366 batch: self.batch,
367 seq: self.seq,
368 subject: self.subject,
369 max_bytes: self.max_bytes,
370 start_time: self.start_time,
371 _phantom: PhantomData,
372 }
373 }
374}
375
376impl<T> GetBatchBuilder<T, WithSeq, NoTime>
378where
379 T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
380{
381 }
383
384impl<T> GetBatchBuilder<T, NoSeq, WithTime>
386where
387 T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
388{
389 }
391
392impl<T, SEQ, TIME> GetBatchBuilder<T, SEQ, TIME>
394where
395 T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
396{
397 pub async fn send(self) -> Result<BatchStream, BatchFetchError> {
399 if self.stream.is_empty() {
401 return Err(BatchFetchError::new(BatchFetchErrorKind::InvalidStreamName));
402 }
403
404 if self.batch > 1000 {
406 return Err(BatchFetchError::new(BatchFetchErrorKind::BatchSizeTooLarge));
407 }
408
409 if let Some(seq) = self.seq
411 && seq == 0
412 {
413 return Err(BatchFetchError::new(BatchFetchErrorKind::InvalidOption));
414 }
415
416 if let Some(max_bytes) = self.max_bytes
418 && max_bytes == 0
419 {
420 return Err(BatchFetchError::new(BatchFetchErrorKind::InvalidOption));
421 }
422
423 let request = GetBatchRequest {
425 seq: if self.seq.is_some() {
426 self.seq
427 } else if self.start_time.is_none() {
428 Some(1) } else {
430 None
431 },
432 next_by_subj: self.subject,
433 batch: self.batch,
434 max_bytes: self.max_bytes,
435 start_time: self.start_time,
436 };
437
438 let payload = serde_json::to_vec(&request)
439 .map_err(|e| BatchFetchError::with_source(BatchFetchErrorKind::Serialization, e))?
440 .into();
441 let subject = format!("DIRECT.GET.{}", self.stream);
443
444 send_batch_request(&self.context, subject, payload).await
445 }
446}
447
448impl<T, SEQ, TIME> GetLastBuilder<T, SEQ, TIME>
450where
451 T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
452{
453 pub fn subjects(mut self, subjects: Vec<String>) -> Self {
459 self.subjects = Some(subjects);
460 self
461 }
462
463 pub fn batch(mut self, batch: usize) -> Self {
469 self.batch = Some(batch);
470 self
471 }
472}
473
474impl<T> GetLastBuilder<T, NoSeq, NoTime>
476where
477 T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
478{
479 pub fn up_to_seq(mut self, seq: u64) -> GetLastBuilder<T, WithSeq, NoTime> {
482 self.up_to_seq = Some(seq);
483 GetLastBuilder {
484 context: self.context,
485 stream: self.stream,
486 subjects: self.subjects,
487 up_to_seq: self.up_to_seq,
488 up_to_time: self.up_to_time,
489 batch: self.batch,
490 _phantom: PhantomData,
491 }
492 }
493
494 pub fn up_to_time(mut self, time: OffsetDateTime) -> GetLastBuilder<T, NoSeq, WithTime> {
497 self.up_to_time = Some(time);
498 GetLastBuilder {
499 context: self.context,
500 stream: self.stream,
501 subjects: self.subjects,
502 up_to_seq: self.up_to_seq,
503 up_to_time: self.up_to_time,
504 batch: self.batch,
505 _phantom: PhantomData,
506 }
507 }
508}
509
510impl<T> GetLastBuilder<T, WithSeq, NoTime>
512where
513 T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
514{
515 }
517
518impl<T> GetLastBuilder<T, NoSeq, WithTime>
520where
521 T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
522{
523 }
525
526impl<T, SEQ, TIME> GetLastBuilder<T, SEQ, TIME>
528where
529 T: ClientProvider + TimeoutProvider + RequestSender + Clone + Send + Sync,
530{
531 pub async fn send(self) -> Result<BatchStream, BatchFetchError> {
533 if self.stream.is_empty() {
534 return Err(BatchFetchError::new(BatchFetchErrorKind::InvalidStreamName));
535 }
536
537 let subjects = self
538 .subjects
539 .ok_or_else(|| BatchFetchError::new(BatchFetchErrorKind::SubjectsRequired))?;
540
541 if subjects.len() > 1024 {
543 return Err(BatchFetchError::new(BatchFetchErrorKind::TooManySubjects));
544 }
545
546 if let Some(batch) = self.batch {
548 if batch == 0 {
549 return Err(BatchFetchError::new(BatchFetchErrorKind::InvalidOption));
550 }
551 if batch > 1000 {
552 return Err(BatchFetchError::new(BatchFetchErrorKind::BatchSizeTooLarge));
553 }
554 }
555
556 if subjects.is_empty() {
557 return Err(BatchFetchError::new(BatchFetchErrorKind::SubjectsRequired));
558 }
559
560 let request = GetLastRequest {
561 multi_last: subjects,
562 batch: self.batch,
563 up_to_seq: self.up_to_seq,
564 up_to_time: self.up_to_time,
565 };
566
567 let payload = serde_json::to_vec(&request)
568 .map_err(|e| BatchFetchError::with_source(BatchFetchErrorKind::Serialization, e))?
569 .into();
570 let subject = format!("DIRECT.GET.{}", self.stream);
571
572 send_batch_request(&self.context, subject, payload).await
573 }
574}
575
576async fn send_batch_request<T>(
577 context: &T,
578 subject: String,
579 payload: Bytes,
580) -> Result<BatchStream, BatchFetchError>
581where
582 T: ClientProvider + TimeoutProvider + RequestSender,
583{
584 let client = context.client();
586 let inbox = client.new_inbox();
587 let subscriber = client
588 .subscribe(inbox.clone())
589 .await
590 .map_err(|e| BatchFetchError::with_source(BatchFetchErrorKind::Subscription, e))?;
591
592 let request = async_nats::Request {
593 inbox: Some(inbox),
594 payload: Some(payload),
595 headers: None,
596 timeout: None,
597 };
598 context
599 .send_request(subject, request)
600 .await
601 .map_err(|e| BatchFetchError::with_source(BatchFetchErrorKind::Publish, e))?;
602
603 Ok(BatchStream {
604 subscriber,
605 terminated: false,
606 timeout: context.timeout(),
607 timeout_at: None,
608 })
609}
610
611fn convert_to_stream_message(msg: Message) -> Result<StreamMessage, BatchFetchError> {
612 if msg.payload.is_empty()
613 && let Some(headers) = &msg.headers
614 {
615 let status = headers.get("Status").map(|v| v.as_str());
616 match status {
617 Some("404") => return Err(BatchFetchError::new(BatchFetchErrorKind::NoMessages)),
618 Some("408") => return Err(BatchFetchError::new(BatchFetchErrorKind::InvalidRequest)),
619 Some("413") => return Err(BatchFetchError::new(BatchFetchErrorKind::TooManySubjects)),
620 _ => {}
621 }
622 }
623
624 let headers = msg
625 .headers
626 .ok_or_else(|| BatchFetchError::new(BatchFetchErrorKind::InvalidResponse))?;
627
628 if headers.get("Nats-Num-Pending").is_none() {
631 return Err(BatchFetchError::new(
632 BatchFetchErrorKind::UnsupportedByServer,
633 ));
634 }
635
636 let subject = headers
637 .get(async_nats::header::NATS_SUBJECT)
638 .ok_or_else(|| BatchFetchError::new(BatchFetchErrorKind::MissingHeader))?
639 .to_string();
640
641 let sequence = headers
642 .get(async_nats::header::NATS_SEQUENCE)
643 .ok_or_else(|| BatchFetchError::new(BatchFetchErrorKind::MissingHeader))?
644 .as_str()
645 .parse::<u64>()
646 .map_err(|e| BatchFetchError::with_source(BatchFetchErrorKind::InvalidHeader, e))?;
647
648 let time_str = headers
649 .get(async_nats::header::NATS_TIME_STAMP)
650 .ok_or_else(|| BatchFetchError::new(BatchFetchErrorKind::MissingHeader))?
651 .as_str();
652
653 let time =
655 time::OffsetDateTime::parse(time_str, &time::format_description::well_known::Rfc3339)
656 .map_err(|e| BatchFetchError::with_source(BatchFetchErrorKind::InvalidHeader, e))?;
657
658 Ok(StreamMessage {
659 subject: Subject::from(subject),
660 sequence,
661 payload: msg.payload,
662 headers,
663 time,
664 })
665}