1use crate::client::{QueueProvider, SessionProvider};
61use crate::error::QueueError;
62use crate::message::{
63 Message, MessageId, QueueName, ReceiptHandle, ReceivedMessage, SessionId, Timestamp,
64};
65use crate::provider::{ProviderType, SessionSupport};
66use async_nats::jetstream::{
67 self, consumer::pull::Config as ConsumerConfig, stream::Config as StreamConfig, AckKind,
68 Context as JetStreamContext,
69};
70use async_trait::async_trait;
71use bytes::Bytes;
72use chrono::Duration;
73use futures::StreamExt;
74use serde::{Deserialize, Serialize};
75use std::collections::HashMap;
76use std::sync::Arc;
77use tokio::sync::Mutex;
78use tracing::{debug, instrument, warn};
79
80#[cfg(test)]
81#[path = "nats_tests.rs"]
82mod tests;
83
/// Configuration for the NATS JetStream queue provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NatsConfig {
    /// Server URL handed to the NATS connect call. Credentials embedded in it
    /// are masked before the URL is logged or placed in an error message.
    pub url: String,
    /// Prefix used when deriving subjects, stream names and consumer names.
    /// `-` and space characters are replaced with `_` before use.
    pub stream_prefix: String,
    /// Value copied to the consumer's `max_deliver` setting; `None` is sent as `-1`
    /// (presumably "unlimited" on the server side — confirm against JetStream docs).
    pub max_deliver: Option<i64>,
    /// Acknowledgement wait configured on consumers. Falls back to 30 seconds
    /// when the value cannot be converted to a `std::time::Duration`.
    pub ack_wait: Duration,
    /// Duration added to "now" to compute the lock expiry of received messages
    /// and of session clients. Consumers are created with an inactive
    /// threshold of twice this value (300 seconds is used as the base when the
    /// conversion to a `std::time::Duration` fails).
    pub session_lock_duration: Duration,
    /// When `false`, no dead-letter subject is derived, so no DLQ stream is
    /// created and dead-lettered messages are only terminated.
    pub enable_dead_letter: bool,
    /// Prefix of the dead-letter subject (`<prefix>.<queue>`). `None` also
    /// disables dead-letter publishing.
    pub dead_letter_subject_prefix: Option<String>,
    /// Optional path to a credentials file loaded before connecting.
    pub credentials_path: Option<String>,
}
126
127impl Default for NatsConfig {
128 fn default() -> Self {
129 Self {
130 url: "nats://localhost:4222".to_string(),
131 stream_prefix: "queue-runtime".to_string(),
132 max_deliver: Some(3),
133 ack_wait: Duration::seconds(30),
134 session_lock_duration: Duration::minutes(5),
135 enable_dead_letter: true,
136 dead_letter_subject_prefix: Some("dlq".to_string()),
137 credentials_path: None,
138 }
139 }
140}
141
/// Error produced by this module while setting up the NATS connection.
///
/// Carries only a human-readable message; use [`NatsError::to_queue_error`]
/// to convert it into a [`QueueError`].
#[derive(Debug)]
pub struct NatsError {
    /// Human-readable description of the failure.
    message: String,
}
151
152impl NatsError {
153 fn new(message: impl Into<String>) -> Self {
154 Self {
155 message: message.into(),
156 }
157 }
158
159 pub fn to_queue_error(&self) -> QueueError {
161 QueueError::ProviderError {
162 provider: "nats".to_string(),
163 code: "NATS_ERROR".to_string(),
164 message: self.message.clone(),
165 }
166 }
167}
168
169impl std::fmt::Display for NatsError {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 write!(f, "NATS error: {}", self.message)
172 }
173}
174
// Marker impl: `NatsError` wraps no source error, so the default trait methods suffice.
impl std::error::Error for NatsError {}
176
/// Book-keeping record for a message that has been handed to a caller but not
/// yet settled (completed, abandoned or dead-lettered). Entries are keyed by
/// the receipt id in the provider's `in_flight` map.
struct InFlightEntry {
    /// The JetStream message, kept so it can be acked, nak'ed or terminated later.
    js_message: async_nats::jetstream::Message,
    /// Local deadline after which settlement attempts are rejected.
    lock_expires_at: Timestamp,
    /// Subject to publish a copy to on dead-lettering; `None` when no DLQ is configured.
    dead_letter_subject: Option<String>,
}
190
191fn redact_url(url: &str) -> String {
201 match url::Url::parse(url) {
202 Ok(mut parsed) => {
203 let has_credentials = !parsed.username().is_empty() || parsed.password().is_some();
204 if has_credentials {
205 let _ = parsed.set_username("***");
206 let _ = parsed.set_password(Some("***"));
207 }
208 parsed.to_string()
209 }
210 Err(_) => "<invalid-url>".to_string(),
211 }
212}
213
/// Replaces every `-` and space in `s` with `_`; all other characters are
/// passed through unchanged.
fn nats_safe(s: &str) -> String {
    s.chars()
        .map(|ch| if matches!(ch, '-' | ' ') { '_' } else { ch })
        .collect()
}
220
221fn queue_subject(config: &NatsConfig, queue: &QueueName) -> String {
223 format!(
224 "{}.{}",
225 nats_safe(&config.stream_prefix),
226 nats_safe(queue.as_str())
227 )
228}
229
/// Maps a session id onto a restricted alphabet: ASCII letters, ASCII digits
/// and `_` are kept, every other character becomes `_`.
///
/// Note that distinct ids can map to the same output (e.g. `a-b` and `a.b`).
fn nats_safe_session_id(id: &str) -> String {
    let mut sanitized = String::with_capacity(id.len());
    for ch in id.chars() {
        match ch {
            '_' => sanitized.push('_'),
            keep if keep.is_ascii_alphanumeric() => sanitized.push(keep),
            _ => sanitized.push('_'),
        }
    }
    sanitized
}
247
248fn session_subject(config: &NatsConfig, queue: &QueueName, session_id: &SessionId) -> String {
250 let safe_session = nats_safe_session_id(session_id.as_str());
251 format!(
252 "{}.{}.session.{}",
253 nats_safe(&config.stream_prefix),
254 nats_safe(queue.as_str()),
255 safe_session
256 )
257}
258
259fn stream_name(config: &NatsConfig, queue: &QueueName) -> String {
261 format!(
263 "{}-{}",
264 nats_safe(&config.stream_prefix),
265 nats_safe(queue.as_str())
266 )
267}
268
269fn consumer_name(config: &NatsConfig, queue: &QueueName) -> String {
277 format!(
278 "{}-{}-consumer",
279 nats_safe(&config.stream_prefix),
280 nats_safe(queue.as_str())
281 )
282}
283
284fn session_consumer_name(config: &NatsConfig, queue: &QueueName, session_id: &SessionId) -> String {
293 let safe_sid = nats_safe_session_id(session_id.as_str());
294 format!(
295 "{}-{}-session-{}-consumer",
296 nats_safe(&config.stream_prefix),
297 nats_safe(queue.as_str()),
298 safe_sid
299 )
300}
301
302fn dead_letter_subject(config: &NatsConfig, queue: &QueueName) -> Option<String> {
304 if !config.enable_dead_letter {
305 return None;
306 }
307 config
308 .dead_letter_subject_prefix
309 .as_ref()
310 .map(|prefix| format!("{}.{}", nats_safe(prefix), nats_safe(queue.as_str())))
311}
312
/// Queue provider backed by NATS JetStream.
pub struct NatsProvider {
    /// Core NATS client; used to publish dead-lettered messages and cloned into session clients.
    client: async_nats::Client,
    /// JetStream context used for stream/consumer management and publishing.
    jetstream: JetStreamContext,
    /// Provider configuration.
    config: NatsConfig,
    /// Received-but-unsettled messages keyed by receipt id; shared with session clients.
    in_flight: Arc<Mutex<HashMap<String, InFlightEntry>>>,
}
328
329impl NatsProvider {
330 pub async fn new(config: NatsConfig) -> Result<Self, NatsError> {
349 let connect_options = if let Some(ref creds_path) = config.credentials_path {
350 async_nats::ConnectOptions::with_credentials_file(creds_path.as_str())
351 .await
352 .map_err(|e| NatsError::new(format!("failed to load NATS credentials: {}", e)))?
353 } else {
354 async_nats::ConnectOptions::new()
355 };
356
357 let client = connect_options.connect(&config.url).await.map_err(|e| {
358 NatsError::new(format!(
359 "failed to connect to NATS at '{}': {}",
360 redact_url(&config.url),
361 e
362 ))
363 })?;
364
365 let jetstream = jetstream::new(client.clone());
366
367 debug!(url = %redact_url(&config.url), "Connected to NATS");
368
369 Ok(Self {
370 client,
371 jetstream,
372 config,
373 in_flight: Arc::new(Mutex::new(HashMap::new())),
374 })
375 }
376
377 async fn ensure_stream(&self, queue: &QueueName) -> Result<(), QueueError> {
382 let name = stream_name(&self.config, queue);
383 let subject = queue_subject(&self.config, queue);
384
385 let subjects = vec![subject.clone(), format!("{}.session.>", subject)];
387
388 let stream_config = StreamConfig {
389 name: name.clone(),
390 subjects,
391 retention: async_nats::jetstream::stream::RetentionPolicy::WorkQueue,
392 storage: async_nats::jetstream::stream::StorageType::File,
393 ..Default::default()
394 };
395
396 self.jetstream
397 .get_or_create_stream(stream_config)
398 .await
399 .map_err(|e| QueueError::ProviderError {
400 provider: "nats".to_string(),
401 code: "STREAM_CREATE_FAILED".to_string(),
402 message: format!("failed to ensure JetStream stream '{}': {}", name, e),
403 })?;
404
405 self.ensure_dlq_stream(queue).await?;
407
408 Ok(())
409 }
410
411 async fn ensure_dlq_stream(&self, queue: &QueueName) -> Result<(), QueueError> {
413 let dlq_subject = match dead_letter_subject(&self.config, queue) {
414 Some(s) => s,
415 None => return Ok(()),
416 };
417
418 let dlq_stream_name = format!(
419 "dlq-{}-{}",
420 nats_safe(&self.config.stream_prefix),
421 nats_safe(queue.as_str())
422 );
423
424 let stream_config = StreamConfig {
425 name: dlq_stream_name.clone(),
426 subjects: vec![dlq_subject],
427 storage: async_nats::jetstream::stream::StorageType::File,
428 ..Default::default()
429 };
430
431 self.jetstream
432 .get_or_create_stream(stream_config)
433 .await
434 .map_err(|e| QueueError::ProviderError {
435 provider: "nats".to_string(),
436 code: "DLQ_STREAM_CREATE_FAILED".to_string(),
437 message: format!("failed to ensure DLQ stream '{}': {}", dlq_stream_name, e),
438 })?;
439
440 Ok(())
441 }
442
443 async fn create_consumer(
454 &self,
455 queue: &QueueName,
456 name: &str,
457 filter_subject: &str,
458 ) -> Result<async_nats::jetstream::consumer::Consumer<ConsumerConfig>, QueueError> {
459 let stream_name = stream_name(&self.config, queue);
460 let ack_wait_std = self
461 .config
462 .ack_wait
463 .to_std()
464 .unwrap_or(std::time::Duration::from_secs(30));
465
466 let consumer_config = ConsumerConfig {
467 name: Some(name.to_string()),
468 durable_name: Some(name.to_string()),
469 filter_subject: filter_subject.to_string(),
470 ack_policy: async_nats::jetstream::consumer::AckPolicy::Explicit,
471 ack_wait: ack_wait_std,
472 max_deliver: self.config.max_deliver.unwrap_or(-1),
473 inactive_threshold: self
480 .config
481 .session_lock_duration
482 .to_std()
483 .unwrap_or(std::time::Duration::from_secs(300))
484 .saturating_mul(2),
485 ..Default::default()
486 };
487
488 let stream = self.jetstream.get_stream(&stream_name).await.map_err(|e| {
489 QueueError::ProviderError {
490 provider: "nats".to_string(),
491 code: "STREAM_GET_FAILED".to_string(),
492 message: format!("failed to get stream '{}': {}", stream_name, e),
493 }
494 })?;
495
496 let consumer = stream
497 .get_or_create_consumer(name, consumer_config)
498 .await
499 .map_err(|e| QueueError::ProviderError {
500 provider: "nats".to_string(),
501 code: "CONSUMER_CREATE_FAILED".to_string(),
502 message: format!(
503 "failed to get or create pull consumer on '{}': {}",
504 stream_name, e
505 ),
506 })?;
507
508 Ok(consumer)
509 }
510
511 fn build_headers(message: &Message) -> async_nats::header::HeaderMap {
513 let mut headers = async_nats::header::HeaderMap::new();
514
515 if let Some(ref sid) = message.session_id {
516 headers.insert("x-session-id", sid.as_str());
517 }
518 if let Some(ref corr_id) = message.correlation_id {
519 headers.insert("x-correlation-id", corr_id.as_str());
520 }
521 for (k, v) in &message.attributes {
522 headers.insert(format!("x-attr-{}", k).as_str(), v.as_str());
524 }
525
526 headers
527 }
528
529 fn extract_attributes(
531 headers: &Option<async_nats::header::HeaderMap>,
532 ) -> HashMap<String, String> {
533 let mut attrs = HashMap::new();
534 if let Some(hm) = headers {
535 for (name, values) in hm.iter() {
536 let key: &str = name.as_ref();
538 if let Some(attr_key) = key.strip_prefix("x-attr-") {
539 if let Some(val) = values.first() {
540 attrs.insert(attr_key.to_string(), val.as_str().to_string());
541 }
542 }
543 }
544 }
545 attrs
546 }
547
548 fn extract_session_id(headers: &Option<async_nats::header::HeaderMap>) -> Option<SessionId> {
550 if let Some(hm) = headers {
551 if let Some(val) = hm.get("x-session-id") {
552 let id = val.as_str().to_string();
553 return SessionId::new(id).ok();
554 }
555 }
556 None
557 }
558
559 fn extract_correlation_id(headers: &Option<async_nats::header::HeaderMap>) -> Option<String> {
561 if let Some(hm) = headers {
562 if let Some(val) = hm.get("x-correlation-id") {
563 return Some(val.as_str().to_string());
564 }
565 }
566 None
567 }
568
569 async fn register_js_message(
571 &self,
572 js_message: async_nats::jetstream::Message,
573 queue: &QueueName,
574 ) -> ReceivedMessage {
575 let headers = js_message.message.headers.clone();
576 let session_id = Self::extract_session_id(&headers);
577 let attributes = Self::extract_attributes(&headers);
578 let correlation_id = Self::extract_correlation_id(&headers);
579 let delivery_count = js_message.info().map(|i| i.delivered as u32).unwrap_or(1);
580 let body = Bytes::copy_from_slice(&js_message.message.payload);
581
582 let now = Timestamp::now();
583 let lock_expires_at =
584 Timestamp::from_datetime(now.as_datetime() + self.config.session_lock_duration);
585
586 let receipt_id = uuid::Uuid::new_v4().to_string();
587 let message_id = MessageId::new();
588
589 let dlq_subject = dead_letter_subject(&self.config, queue);
590
591 self.in_flight.lock().await.insert(
592 receipt_id.clone(),
593 InFlightEntry {
594 js_message,
595 lock_expires_at,
596 dead_letter_subject: dlq_subject,
597 },
598 );
599
600 ReceivedMessage {
601 message_id,
602 body,
603 attributes,
604 session_id,
605 correlation_id,
606 receipt_handle: ReceiptHandle::new(receipt_id, lock_expires_at, ProviderType::Nats),
607 delivery_count,
608 first_delivered_at: now,
609 delivered_at: now,
610 }
611 }
612}
613
614#[async_trait]
619impl QueueProvider for NatsProvider {
620 #[instrument(skip(self, message), fields(queue = %queue))]
621 async fn send_message(
622 &self,
623 queue: &QueueName,
624 message: &Message,
625 ) -> Result<MessageId, QueueError> {
626 let size = message.body.len();
627 let max_size = self.provider_type().max_message_size();
628 if size > max_size {
629 return Err(QueueError::MessageTooLarge { size, max_size });
630 }
631
632 self.ensure_stream(queue).await?;
633
634 let subject = if let Some(ref sid) = message.session_id {
636 session_subject(&self.config, queue, sid)
637 } else {
638 queue_subject(&self.config, queue)
639 };
640
641 let headers = Self::build_headers(message);
642 let payload = Bytes::copy_from_slice(&message.body);
643
644 self.jetstream
645 .publish_with_headers(subject.clone(), headers, payload)
646 .await
647 .map_err(|e| QueueError::ProviderError {
648 provider: "nats".to_string(),
649 code: "PUBLISH_FAILED".to_string(),
650 message: format!("failed to publish to subject '{}': {}", subject, e),
651 })?
652 .await
653 .map_err(|e| QueueError::ProviderError {
654 provider: "nats".to_string(),
655 code: "PUBLISH_ACK_FAILED".to_string(),
656 message: format!("JetStream publish ack failed: {}", e),
657 })?;
658
659 let message_id = MessageId::new();
660 debug!(%message_id, %queue, "Published message to NATS JetStream");
661 Ok(message_id)
662 }
663
664 #[instrument(skip(self, messages), fields(queue = %queue, count = messages.len()))]
665 async fn send_messages(
666 &self,
667 queue: &QueueName,
668 messages: &[Message],
669 ) -> Result<Vec<MessageId>, QueueError> {
670 if messages.len() > self.max_batch_size() as usize {
671 return Err(QueueError::BatchTooLarge {
672 size: messages.len(),
673 max_size: self.max_batch_size() as usize,
674 });
675 }
676
677 let mut ids = Vec::with_capacity(messages.len());
684 for message in messages {
685 ids.push(self.send_message(queue, message).await?);
686 }
687 Ok(ids)
688 }
689
690 #[instrument(skip(self), fields(queue = %queue))]
691 async fn receive_message(
692 &self,
693 queue: &QueueName,
694 timeout: Duration,
695 ) -> Result<Option<ReceivedMessage>, QueueError> {
696 self.ensure_stream(queue).await?;
697
698 let subject = queue_subject(&self.config, queue);
699 let name = consumer_name(&self.config, queue);
700 let consumer = self.create_consumer(queue, &name, &subject).await?;
701
702 let timeout_std = timeout
703 .to_std()
704 .unwrap_or(std::time::Duration::from_secs(30));
705
706 let mut messages = consumer
707 .fetch()
708 .max_messages(1)
709 .expires(timeout_std)
710 .messages()
711 .await
712 .map_err(|e| QueueError::ProviderError {
713 provider: "nats".to_string(),
714 code: "FETCH_FAILED".to_string(),
715 message: format!("failed to fetch from JetStream: {}", e),
716 })?;
717
718 match tokio::time::timeout(timeout_std, messages.next()).await {
719 Ok(Some(Ok(js_msg))) => {
720 let msg = self.register_js_message(js_msg, queue).await;
721 Ok(Some(msg))
722 }
723 Ok(Some(Err(e))) => Err(QueueError::ProviderError {
724 provider: "nats".to_string(),
725 code: "MESSAGE_ERROR".to_string(),
726 message: format!("error reading JetStream message: {}", e),
727 }),
728 Ok(None) | Err(_) => Ok(None),
729 }
730 }
731
732 #[instrument(skip(self), fields(queue = %queue, max = max_messages))]
733 async fn receive_messages(
734 &self,
735 queue: &QueueName,
736 max_messages: u32,
737 timeout: Duration,
738 ) -> Result<Vec<ReceivedMessage>, QueueError> {
739 self.ensure_stream(queue).await?;
740
741 let subject = queue_subject(&self.config, queue);
742 let name = consumer_name(&self.config, queue);
743 let consumer = self.create_consumer(queue, &name, &subject).await?;
744
745 let timeout_std = timeout
746 .to_std()
747 .unwrap_or(std::time::Duration::from_secs(30));
748
749 let mut js_messages = consumer
750 .fetch()
751 .max_messages(max_messages as usize)
752 .expires(timeout_std)
753 .messages()
754 .await
755 .map_err(|e| QueueError::ProviderError {
756 provider: "nats".to_string(),
757 code: "FETCH_FAILED".to_string(),
758 message: format!("failed to fetch from JetStream: {}", e),
759 })?;
760
761 let mut result = Vec::new();
762 let deadline = tokio::time::Instant::now() + timeout_std;
763
764 loop {
765 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
766 if remaining.is_zero() || result.len() >= max_messages as usize {
767 break;
768 }
769
770 match tokio::time::timeout(remaining, js_messages.next()).await {
771 Ok(Some(Ok(js_msg))) => {
772 let msg = self.register_js_message(js_msg, queue).await;
773 result.push(msg);
774 }
775 Ok(Some(Err(e))) => {
776 return Err(QueueError::ProviderError {
777 provider: "nats".to_string(),
778 code: "MESSAGE_ERROR".to_string(),
779 message: format!("error reading JetStream message: {}", e),
780 });
781 }
782 Ok(None) | Err(_) => break,
783 }
784 }
785
786 Ok(result)
787 }
788
789 #[instrument(skip(self, receipt))]
790 async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
791 let mut in_flight = self.in_flight.lock().await;
792
793 match in_flight.get(receipt.handle()) {
796 None => {
797 return Err(QueueError::MessageNotFound {
798 receipt: receipt.handle().to_string(),
799 });
800 }
801 Some(entry) if Timestamp::now() > entry.lock_expires_at => {
802 in_flight.remove(receipt.handle());
803 return Err(QueueError::MessageNotFound {
804 receipt: format!("{}(expired)", receipt.handle()),
805 });
806 }
807 Some(_) => {}
808 }
809
810 let entry = in_flight
811 .remove(receipt.handle())
812 .expect("entry present after pre-check");
813
814 entry
815 .js_message
816 .ack()
817 .await
818 .map_err(|e| QueueError::ProviderError {
819 provider: "nats".to_string(),
820 code: "ACK_FAILED".to_string(),
821 message: format!("JetStream ack failed: {}", e),
822 })?;
823
824 Ok(())
825 }
826
827 #[instrument(skip(self, receipt))]
828 async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
829 let mut in_flight = self.in_flight.lock().await;
830
831 match in_flight.get(receipt.handle()) {
832 None => {
833 return Err(QueueError::MessageNotFound {
834 receipt: receipt.handle().to_string(),
835 });
836 }
837 Some(entry) if Timestamp::now() > entry.lock_expires_at => {
838 in_flight.remove(receipt.handle());
839 return Err(QueueError::MessageNotFound {
840 receipt: format!("{}(expired)", receipt.handle()),
841 });
842 }
843 Some(_) => {}
844 }
845
846 let entry = in_flight
847 .remove(receipt.handle())
848 .expect("entry present after pre-check");
849
850 entry
851 .js_message
852 .ack_with(AckKind::Nak(None))
853 .await
854 .map_err(|e| QueueError::ProviderError {
855 provider: "nats".to_string(),
856 code: "NAK_FAILED".to_string(),
857 message: format!("JetStream nak failed: {}", e),
858 })?;
859
860 Ok(())
861 }
862
863 #[instrument(skip(self, receipt), fields(reason = %reason))]
864 async fn dead_letter_message(
865 &self,
866 receipt: &ReceiptHandle,
867 reason: &str,
868 ) -> Result<(), QueueError> {
869 let mut in_flight = self.in_flight.lock().await;
870
871 match in_flight.get(receipt.handle()) {
872 None => {
873 return Err(QueueError::MessageNotFound {
874 receipt: receipt.handle().to_string(),
875 });
876 }
877 Some(entry) if Timestamp::now() > entry.lock_expires_at => {
878 in_flight.remove(receipt.handle());
879 return Err(QueueError::MessageNotFound {
880 receipt: format!("{}(expired)", receipt.handle()),
881 });
882 }
883 Some(_) => {}
884 }
885
886 let entry = in_flight
887 .remove(receipt.handle())
888 .expect("entry present after pre-check");
889
890 entry
894 .js_message
895 .ack_with(async_nats::jetstream::AckKind::Term)
896 .await
897 .map_err(|e| QueueError::ProviderError {
898 provider: "nats".to_string(),
899 code: "TERM_FAILED".to_string(),
900 message: format!("JetStream term ack failed: {}", e),
901 })?;
902
903 if let Some(ref dlq_subject) = entry.dead_letter_subject {
909 let mut headers = async_nats::header::HeaderMap::new();
910 headers.insert("x-dead-letter-reason", reason);
911 let payload = entry.js_message.message.payload.clone();
912 if let Some(msg_headers) = &entry.js_message.message.headers {
913 for (name, values) in msg_headers.iter() {
914 let key: &str = name.as_ref();
916 for val in values.iter() {
917 headers.insert(key, val.as_str());
918 }
919 }
920 }
921
922 if let Err(e) = self
923 .client
924 .publish_with_headers(dlq_subject.clone(), headers, payload)
925 .await
926 {
927 warn!(
930 reason,
931 dlq_subject,
932 error = %e,
933 "Failed to publish dead-lettered message to DLQ (message already terminated)"
934 );
935 } else {
936 debug!(
937 reason,
938 dlq_subject, "Message dead-lettered and published to DLQ"
939 );
940 }
941 } else {
942 debug!(reason, "Message terminated (no DLQ configured)");
943 }
944
945 Ok(())
946 }
947
948 #[instrument(skip(self), fields(queue = %queue))]
949 async fn create_session_client(
950 &self,
951 queue: &QueueName,
952 session_id: Option<SessionId>,
953 ) -> Result<Box<dyn SessionProvider>, QueueError> {
954 let sid = match session_id {
955 Some(id) => id,
956 None => {
957 return Err(QueueError::SessionNotFound {
959 session_id: "<any>".to_string(),
960 });
961 }
962 };
963
964 self.ensure_stream(queue).await?;
965
966 let subject = session_subject(&self.config, queue, &sid);
967 let name = session_consumer_name(&self.config, queue, &sid);
968 let consumer = self.create_consumer(queue, &name, &subject).await?;
969
970 let now = Timestamp::now();
971 let lock_expires_at =
972 Timestamp::from_datetime(now.as_datetime() + self.config.session_lock_duration);
973
974 Ok(Box::new(NatsSessionProvider {
975 consumer: Arc::new(Mutex::new(consumer)),
976 client: self.client.clone(),
977 session_id: sid,
978 queue_name: queue.clone(),
979 in_flight: self.in_flight.clone(),
980 lock_expires_at: Arc::new(std::sync::Mutex::new(lock_expires_at)),
981 config: self.config.clone(),
982 }))
983 }
984
    /// Identifies this provider as `ProviderType::Nats`.
    fn provider_type(&self) -> ProviderType {
        ProviderType::Nats
    }
988
    /// Reports session support as `Emulated`: sessions are built in this
    /// module from per-session subjects and consumers.
    fn supports_sessions(&self) -> SessionSupport {
        SessionSupport::Emulated
    }
992
    /// Always `true`; batch sends are accepted (and carried out message by message).
    fn supports_batching(&self) -> bool {
        true
    }
996
    /// Maximum number of messages accepted by a single `send_messages` call.
    fn max_batch_size(&self) -> u32 {
        100
    }
1000}
1001
/// Session-scoped client that receives and settles messages of one session.
pub struct NatsSessionProvider {
    /// Pull consumer filtered to this session's subject; the mutex serialises fetches.
    consumer: Arc<Mutex<async_nats::jetstream::consumer::Consumer<ConsumerConfig>>>,
    /// Core NATS client used to publish dead-lettered messages.
    client: async_nats::Client,
    /// The session this client is bound to.
    session_id: SessionId,
    /// Queue the session belongs to; used to derive the dead-letter subject.
    queue_name: QueueName,
    /// In-flight map shared with the `NatsProvider` that created this client.
    in_flight: Arc<Mutex<HashMap<String, InFlightEntry>>>,
    /// Current expiry of the session lock; advanced by `renew_session_lock`.
    lock_expires_at: Arc<std::sync::Mutex<Timestamp>>,
    /// Copy of the provider configuration.
    config: NatsConfig,
}
1021
1022#[async_trait]
1023impl SessionProvider for NatsSessionProvider {
1024 #[instrument(skip(self), fields(session_id = %self.session_id))]
1025 async fn receive_message(
1026 &self,
1027 timeout: Duration,
1028 ) -> Result<Option<ReceivedMessage>, QueueError> {
1029 self.check_lock()?;
1030
1031 let timeout_std = timeout
1032 .to_std()
1033 .unwrap_or(std::time::Duration::from_secs(30));
1034
1035 let consumer = self.consumer.lock().await;
1036
1037 let mut messages = consumer
1038 .fetch()
1039 .max_messages(1)
1040 .expires(timeout_std)
1041 .messages()
1042 .await
1043 .map_err(|e| QueueError::ProviderError {
1044 provider: "nats".to_string(),
1045 code: "FETCH_FAILED".to_string(),
1046 message: format!("session fetch failed: {}", e),
1047 })?;
1048
1049 match tokio::time::timeout(timeout_std, messages.next()).await {
1050 Ok(Some(Ok(js_msg))) => {
1051 let msg = self.register_session_message(js_msg).await;
1052 Ok(Some(msg))
1053 }
1054 Ok(Some(Err(e))) => Err(QueueError::ProviderError {
1055 provider: "nats".to_string(),
1056 code: "MESSAGE_ERROR".to_string(),
1057 message: format!("session message error: {}", e),
1058 }),
1059 Ok(None) | Err(_) => Ok(None),
1060 }
1061 }
1062
1063 #[instrument(skip(self, receipt))]
1064 async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
1065 self.check_lock()?;
1066 self.ack_message(receipt, SettlementKind::Ack).await
1067 }
1068
1069 #[instrument(skip(self, receipt))]
1070 async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
1071 self.check_lock()?;
1072 self.ack_message(receipt, SettlementKind::Nak).await
1073 }
1074
1075 #[instrument(skip(self, receipt), fields(reason = %reason))]
1076 async fn dead_letter_message(
1077 &self,
1078 receipt: &ReceiptHandle,
1079 reason: &str,
1080 ) -> Result<(), QueueError> {
1081 self.check_lock()?;
1082
1083 let mut in_flight = self.in_flight.lock().await;
1084
1085 match in_flight.get(receipt.handle()) {
1087 None => {
1088 return Err(QueueError::MessageNotFound {
1089 receipt: receipt.handle().to_string(),
1090 });
1091 }
1092 Some(entry) if Timestamp::now() > entry.lock_expires_at => {
1093 in_flight.remove(receipt.handle());
1094 return Err(QueueError::MessageNotFound {
1095 receipt: format!("{}(expired)", receipt.handle()),
1096 });
1097 }
1098 Some(_) => {}
1099 }
1100
1101 let entry = in_flight
1102 .remove(receipt.handle())
1103 .expect("entry present after pre-check");
1104
1105 entry
1106 .js_message
1107 .ack_with(async_nats::jetstream::AckKind::Term)
1108 .await
1109 .map_err(|e| QueueError::ProviderError {
1110 provider: "nats".to_string(),
1111 code: "TERM_FAILED".to_string(),
1112 message: format!("session term ack failed: {}", e),
1113 })?;
1114
1115 if let Some(ref dlq_subject) = entry.dead_letter_subject {
1118 let mut headers = async_nats::header::HeaderMap::new();
1119 headers.insert("x-dead-letter-reason", reason);
1120 let payload = entry.js_message.message.payload.clone();
1121
1122 if let Err(e) = self
1123 .client
1124 .publish_with_headers(dlq_subject.clone(), headers, payload)
1125 .await
1126 {
1127 warn!(
1128 reason,
1129 dlq_subject,
1130 error = %e,
1131 "Session: failed to publish dead-lettered message to DLQ (message already terminated)"
1132 );
1133 } else {
1134 debug!(reason, dlq_subject, "Session message dead-lettered");
1135 }
1136 }
1137
1138 Ok(())
1139 }
1140
1141 async fn renew_session_lock(&self) -> Result<(), QueueError> {
1142 advance_session_lock(&self.lock_expires_at, self.config.session_lock_duration)?;
1143 debug!(session_id = %self.session_id, "NATS session lock renewed");
1144 Ok(())
1145 }
1146
    /// Closes the session. This implementation performs no work and always
    /// succeeds; in-flight entries registered through the session are left untouched.
    async fn close_session(&self) -> Result<(), QueueError> {
        Ok(())
    }
1151
    /// Returns the id of the session this client is bound to.
    fn session_id(&self) -> &SessionId {
        &self.session_id
    }
1155
1156 fn session_expires_at(&self) -> Timestamp {
1157 *self
1159 .lock_expires_at
1160 .lock()
1161 .unwrap_or_else(|e| e.into_inner())
1162 }
1163}
1164
1165fn check_session_lock(
1174 lock_expires_at: &std::sync::Mutex<Timestamp>,
1175 session_id: &SessionId,
1176) -> Result<(), QueueError> {
1177 let expires = *lock_expires_at
1178 .lock()
1179 .map_err(|_| QueueError::ProviderError {
1180 provider: "nats".to_string(),
1181 code: "INTERNAL_ERROR".to_string(),
1182 message: "session lock mutex poisoned".to_string(),
1183 })?;
1184 if Timestamp::now() > expires {
1185 return Err(QueueError::SessionLocked {
1186 session_id: session_id.as_str().to_string(),
1187 locked_until: expires,
1188 });
1189 }
1190 Ok(())
1191}
1192
1193fn advance_session_lock(
1197 lock_expires_at: &std::sync::Mutex<Timestamp>,
1198 duration: Duration,
1199) -> Result<Timestamp, QueueError> {
1200 let new_expiry = Timestamp::from_datetime(Timestamp::now().as_datetime() + duration);
1201 *lock_expires_at
1202 .lock()
1203 .map_err(|_| QueueError::ProviderError {
1204 provider: "nats".to_string(),
1205 code: "INTERNAL_ERROR".to_string(),
1206 message: "session lock mutex poisoned".to_string(),
1207 })? = new_expiry;
1208 Ok(new_expiry)
1209}
1210
/// Selects how `NatsSessionProvider::ack_message` settles a message.
enum SettlementKind {
    /// Positive acknowledgement (`ack`).
    Ack,
    /// Negative acknowledgement (`AckKind::Nak` with no delay).
    Nak,
}
1216
1217impl NatsSessionProvider {
    /// Checks this session's lock via `check_session_lock`; errors if it has expired.
    fn check_lock(&self) -> Result<(), QueueError> {
        check_session_lock(&self.lock_expires_at, &self.session_id)
    }
1222
1223 async fn ack_message(
1225 &self,
1226 receipt: &ReceiptHandle,
1227 kind: SettlementKind,
1228 ) -> Result<(), QueueError> {
1229 let mut in_flight = self.in_flight.lock().await;
1230
1231 match in_flight.get(receipt.handle()) {
1233 None => {
1234 return Err(QueueError::MessageNotFound {
1235 receipt: receipt.handle().to_string(),
1236 });
1237 }
1238 Some(entry) if Timestamp::now() > entry.lock_expires_at => {
1239 in_flight.remove(receipt.handle());
1240 return Err(QueueError::MessageNotFound {
1241 receipt: format!("{}(expired)", receipt.handle()),
1242 });
1243 }
1244 Some(_) => {}
1245 }
1246
1247 let entry = in_flight
1248 .remove(receipt.handle())
1249 .expect("entry present after pre-check");
1250
1251 match kind {
1252 SettlementKind::Ack => {
1253 entry
1254 .js_message
1255 .ack()
1256 .await
1257 .map_err(|e| QueueError::ProviderError {
1258 provider: "nats".to_string(),
1259 code: "ACK_FAILED".to_string(),
1260 message: format!("session ack failed: {}", e),
1261 })
1262 }
1263 SettlementKind::Nak => {
1264 entry
1265 .js_message
1266 .ack_with(AckKind::Nak(None))
1267 .await
1268 .map_err(|e| QueueError::ProviderError {
1269 provider: "nats".to_string(),
1270 code: "NAK_FAILED".to_string(),
1271 message: format!("session nak failed: {}", e),
1272 })
1273 }
1274 }
1275 }
1276
1277 async fn register_session_message(
1279 &self,
1280 js_message: async_nats::jetstream::Message,
1281 ) -> ReceivedMessage {
1282 let headers = js_message.message.headers.clone();
1283 let attributes = NatsProvider::extract_attributes(&headers);
1284 let correlation_id = NatsProvider::extract_correlation_id(&headers);
1285 let delivery_count = js_message.info().map(|i| i.delivered as u32).unwrap_or(1);
1286 let body = Bytes::copy_from_slice(&js_message.message.payload);
1287
1288 let now = Timestamp::now();
1289 let lock_expires_at =
1290 Timestamp::from_datetime(now.as_datetime() + self.config.session_lock_duration);
1291
1292 let receipt_id = uuid::Uuid::new_v4().to_string();
1293 let message_id = MessageId::new();
1294
1295 let dlq_subject = dead_letter_subject(&self.config, &self.queue_name);
1296
1297 self.in_flight.lock().await.insert(
1298 receipt_id.clone(),
1299 InFlightEntry {
1300 js_message,
1301 lock_expires_at,
1302 dead_letter_subject: dlq_subject,
1303 },
1304 );
1305
1306 ReceivedMessage {
1307 message_id,
1308 body,
1309 attributes,
1310 session_id: Some(self.session_id.clone()),
1311 correlation_id,
1312 receipt_handle: ReceiptHandle::new(receipt_id, lock_expires_at, ProviderType::Nats),
1313 delivery_count,
1314 first_delivered_at: now,
1315 delivered_at: now,
1316 }
1317 }
1318}