object_transfer/
sub.rs

1use std::marker::PhantomData;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use futures::TryStreamExt;
6use futures::stream::BoxStream;
7use serde::de::DeserializeOwned;
8
9use crate::r#enum::Format;
10use crate::error::{SubError, UnSubError};
11use crate::traits::{
12  AckTrait, SubCtxTrait, SubOptTrait, SubTrait, UnSubTrait,
13};
14
15/// Subscriber wrapper that deserializes messages and optionally acknowledges
16/// them.
17///
18/// The subscriber relies on a [`SubCtxTrait`] implementation for message
19/// retrieval and a [`SubOptTrait`] provider for decoding and acknowledgment
20/// behavior.
21///
22/// # Example
23///
24/// ```rust,no_run
25/// use std::sync::Arc;
26/// use futures::StreamExt;
27/// use serde::Deserialize;
28/// use object_transfer::{Format, Sub};
29/// use object_transfer::nats::{AckSubOptions, SubFetcher};
30/// use object_transfer::traits::{SubTrait, UnSubTrait};
31///
32/// #[derive(Deserialize, Debug)]
33/// struct Event {
34///   id: u64,
35///   name: String,
36/// }
37///
38/// #[tokio::main]
39/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
40///   // Build a JetStream context and configure a durable pull consumer.
41///   let client = async_nats::connect("demo.nats.io").await?;
42///   let js = Arc::new(async_nats::jetstream::new(client));
43///
44///   let options = Arc::new(
45///     AckSubOptions::new(Format::JSON, Arc::from("events"))
46///       .subjects(vec!["events.user_created"])
47///       .durable_name("user-created")
48///       .auto_ack(false),
49///   );
50///
51///   // SubFetcher implements both SubCtxTrait and UnSubTrait.
52///   let fetcher = Arc::new(SubFetcher::new(js, options.clone()).await?);
53///   let unsub = Some(fetcher.clone() as Arc<dyn UnSubTrait + Send + Sync>);
54///
55///   let subscriber: Sub<Event> = Sub::new(fetcher, unsub, options);
56///   let mut stream = subscriber.subscribe().await?;
57///
58///   while let Some(Ok((event, ack))) = stream.next().await {
59///     println!("received {:?}", event);
60///     // Manually ack since auto_ack(false).
61///     ack.ack().await?;
62///   }
63///
64///   Ok(())
65/// }
66/// ```
67pub struct Sub<T> {
68  ctx: Arc<dyn SubCtxTrait + Send + Sync>,
69  unsub: Option<Arc<dyn UnSubTrait + Send + Sync>>,
70  options: Arc<dyn SubOptTrait + Send + Sync>,
71  _marker: PhantomData<T>,
72}
73
74impl<T> Sub<T>
75where
76  T: DeserializeOwned + Send + Sync,
77{
78  /// Creates a new subscriber using the provided context, optional
79  /// unsubscribe handler, and subscription options.
80  ///
81  /// # Parameters
82  /// - `ctx`: Message retrieval context responsible for producing raw items.
83  /// - `unsub`: Optional handler to cancel the subscription when requested.
84  /// - `options`: Subscription behavior such as auto-ack and payload format.
85  pub fn new(
86    ctx: Arc<dyn SubCtxTrait + Send + Sync>,
87    unsub: Option<Arc<dyn UnSubTrait + Send + Sync>>,
88    options: Arc<dyn SubOptTrait + Send + Sync>,
89  ) -> Self {
90    Self {
91      ctx,
92      unsub,
93      options,
94      _marker: PhantomData,
95    }
96  }
97}
98
99#[async_trait]
100impl<T> SubTrait for Sub<T>
101where
102  T: DeserializeOwned + Send + Sync,
103{
104  type Item = T;
105  /// Returns a stream of decoded messages alongside their acknowledgment
106  /// handles. When auto-acknowledgment is enabled, messages are acknowledged
107  /// before being yielded to the consumer.
108  async fn subscribe(
109    &self,
110  ) -> Result<
111    BoxStream<Result<(Self::Item, Arc<dyn AckTrait + Send + Sync>), SubError>>,
112    SubError,
113  > {
114    let messages = self.ctx.subscribe().await?;
115    let stream = messages.and_then(async move |(msg, acker)| {
116      if self.options.get_auto_ack() {
117        acker.ack().await?;
118      }
119      let data = match self.options.get_format() {
120        Format::MessagePack => {
121          rmp_serde::from_slice::<T>(&msg).map_err(SubError::MessagePackDecode)
122        }
123        Format::JSON => {
124          serde_json::from_slice::<T>(&msg).map_err(SubError::Json)
125        }
126      }?;
127      Ok((data, acker))
128    });
129    return Ok(Box::pin(stream));
130  }
131}
132
133#[async_trait]
134impl<T> UnSubTrait for Sub<T>
135where
136  T: DeserializeOwned + Send + Sync,
137{
138  /// Invokes the optional unsubscribe handler, if present.
139  async fn unsubscribe(&self) -> Result<(), UnSubError> {
140    if let Some(unsub) = &self.unsub {
141      unsub.unsubscribe().await?;
142    }
143    return Ok(());
144  }
145}
146
#[cfg(test)]
mod test {
  use ::bytes::Bytes;
  use ::futures::stream::StreamExt;
  use ::rmp_serde::to_vec as to_msgpack;
  use ::serde_json::to_vec as jsonify;

  use crate::error::AckError;
  use crate::tests::{entity::TestEntity, subscribe::SubscribeMock};
  use crate::traits::{MockAckTrait, MockSubOptTrait};

  use super::*;

  /// Round-trips three entities through `Sub` encoded in `format`.
  ///
  /// Checks that:
  /// - each message is acked exactly once when `auto_ack` is set, and never
  ///   otherwise;
  /// - each option getter is consulted once per message;
  /// - the decoded entities come back in order.
  async fn test_subscribe(format: Format, auto_ack: bool) {
    let entities = vec![
      TestEntity::new(1, "Test1"),
      TestEntity::new(2, "Test2"),
      TestEntity::new(3, "Test3"),
    ];
    let data: Vec<(Bytes, Arc<dyn AckTrait + Send + Sync>)> = entities
      .iter()
      .map(|e| {
        let mut ack_mock = MockAckTrait::new();
        if auto_ack {
          ack_mock.expect_ack().returning(|| Ok(())).once();
        } else {
          ack_mock.expect_ack().never();
        }
        return (
          Bytes::from(match format {
            Format::MessagePack => to_msgpack(e).unwrap(),
            Format::JSON => jsonify(e).unwrap(),
          }),
          Arc::new(ack_mock) as Arc<dyn AckTrait + Send + Sync>,
        );
      })
      .collect();
    let ctx: Arc<dyn SubCtxTrait + Send + Sync> =
      Arc::new(SubscribeMock::new(data));
    let mut options = MockSubOptTrait::new();
    options
      .expect_get_auto_ack()
      .return_const(auto_ack)
      .times(entities.len());
    options
      .expect_get_format()
      .return_const(format)
      .times(entities.len());
    let subscribe: Sub<TestEntity> = Sub::new(
      ctx,
      None,
      Arc::new(options) as Arc<dyn SubOptTrait + Send + Sync>,
    );
    let stream = subscribe.subscribe().await.unwrap();
    let obtained: Vec<TestEntity> = stream
      .try_collect::<Vec<_>>()
      .await
      .unwrap()
      .into_iter()
      .map(|(entity, _ack)| entity)
      .collect();
    assert_eq!(obtained, entities);
  }

  #[tokio::test]
  async fn test_subscribe_json() {
    test_subscribe(Format::JSON, true).await;
  }

  #[tokio::test]
  async fn test_subscribe_messagepack() {
    test_subscribe(Format::MessagePack, true).await;
  }

  #[tokio::test]
  async fn test_subscribe_json_no_auto_ack() {
    test_subscribe(Format::JSON, false).await;
  }

  #[tokio::test]
  async fn test_subscribe_messagepack_no_auto_ack() {
    test_subscribe(Format::MessagePack, false).await;
  }

  /// Feeds one message whose ack fails, with auto-ack enabled.
  ///
  /// Checks that the ack error is yielded and that the payload is never
  /// decoded: `get_format` must not be called.
  async fn test_ack_err(format: Format) {
    let mut data: Vec<(Bytes, Arc<dyn AckTrait + Send + Sync>)> = Vec::new();
    data.push((Bytes::new(), {
      let mut ack_mock = MockAckTrait::new();
      ack_mock
        .expect_ack()
        .returning(|| Err(AckError::ErrorTest))
        .once();
      Arc::new(ack_mock)
    }));
    let ctx: Arc<dyn SubCtxTrait + Send + Sync> =
      Arc::new(SubscribeMock::new(data));
    let mut options = MockSubOptTrait::new();
    options.expect_get_auto_ack().return_const(true).once();
    options.expect_get_format().return_const(format).never();
    let subscribe: Sub<TestEntity> = Sub::new(
      ctx,
      None,
      Arc::new(options) as Arc<dyn SubOptTrait + Send + Sync>,
    );
    let stream = subscribe.subscribe().await.unwrap();
    let obtained: Vec<String> = stream
      .collect::<Vec<_>>()
      .await
      .iter()
      .filter_map(|res| res.as_ref().map_err(|err| err.to_string()).err())
      .collect();
    assert_eq!(
      obtained,
      vec![SubError::AckError(AckError::ErrorTest).to_string()]
    );
  }

  #[tokio::test]
  async fn test_ack_json_err() {
    test_ack_err(Format::JSON).await;
  }

  #[tokio::test]
  async fn test_ack_messagepack_err() {
    test_ack_err(Format::MessagePack).await;
  }

  /// Feeds one empty (undecodable) payload, with auto-ack enabled.
  ///
  /// Checks that:
  /// - the message is still acked once, before decoding;
  /// - the format is consulted once;
  /// - the format-specific decode error is yielded.
  async fn test_decode_err(format: Format) {
    // Record the expected error variant before `format` is moved into the mock.
    let expect_json = matches!(format, Format::JSON);
    let mut ack_mock = MockAckTrait::new();
    ack_mock.expect_ack().returning(|| Ok(())).once();
    let data: Vec<(Bytes, Arc<dyn AckTrait + Send + Sync>)> = vec![(
      Bytes::new(),
      Arc::new(ack_mock) as Arc<dyn AckTrait + Send + Sync>,
    )];
    let ctx: Arc<dyn SubCtxTrait + Send + Sync> =
      Arc::new(SubscribeMock::new(data));
    let mut options = MockSubOptTrait::new();
    options.expect_get_auto_ack().return_const(true).once();
    options.expect_get_format().return_const(format).once();
    let subscribe: Sub<TestEntity> = Sub::new(
      ctx,
      None,
      Arc::new(options) as Arc<dyn SubOptTrait + Send + Sync>,
    );
    let stream = subscribe.subscribe().await.unwrap();
    let results = stream.collect::<Vec<_>>().await;
    assert_eq!(results.len(), 1);
    let err = results
      .into_iter()
      .next()
      .and_then(|res| res.err())
      .expect("an empty payload must fail to decode");
    if expect_json {
      assert!(matches!(err, SubError::Json(_)));
    } else {
      assert!(matches!(err, SubError::MessagePackDecode(_)));
    }
  }

  #[tokio::test]
  async fn test_decode_json_err() {
    test_decode_err(Format::JSON).await;
  }

  #[tokio::test]
  async fn test_decode_messagepack_err() {
    test_decode_err(Format::MessagePack).await;
  }

  /// Unsubscribing without a handler must succeed without touching the
  /// context or the options.
  #[tokio::test]
  async fn test_unsubscribe_without_handler() {
    let data: Vec<(Bytes, Arc<dyn AckTrait + Send + Sync>)> = Vec::new();
    let ctx: Arc<dyn SubCtxTrait + Send + Sync> =
      Arc::new(SubscribeMock::new(data));
    // No expectations: any option getter call would fail the test.
    let options = MockSubOptTrait::new();
    let subscribe: Sub<TestEntity> = Sub::new(
      ctx,
      None,
      Arc::new(options) as Arc<dyn SubOptTrait + Send + Sync>,
    );
    assert!(subscribe.unsubscribe().await.is_ok());
  }
}