1use std::marker::PhantomData;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use futures::TryStreamExt;
6use futures::stream::BoxStream;
7use serde::de::DeserializeOwned;
8
9use crate::r#enum::Format;
10use crate::error::{SubError, UnSubError};
11use crate::traits::{
12 AckTrait, SubCtxTrait, SubOptTrait, SubTrait, UnSubTrait,
13};
14
15pub struct Sub<T> {
68 ctx: Arc<dyn SubCtxTrait + Send + Sync>,
69 unsub: Option<Arc<dyn UnSubTrait + Send + Sync>>,
70 options: Arc<dyn SubOptTrait + Send + Sync>,
71 _marker: PhantomData<T>,
72}
73
74impl<T> Sub<T>
75where
76 T: DeserializeOwned + Send + Sync,
77{
78 pub fn new(
86 ctx: Arc<dyn SubCtxTrait + Send + Sync>,
87 unsub: Option<Arc<dyn UnSubTrait + Send + Sync>>,
88 options: Arc<dyn SubOptTrait + Send + Sync>,
89 ) -> Self {
90 Self {
91 ctx,
92 unsub,
93 options,
94 _marker: PhantomData,
95 }
96 }
97}
98
99#[async_trait]
100impl<T> SubTrait for Sub<T>
101where
102 T: DeserializeOwned + Send + Sync,
103{
104 type Item = T;
105 async fn subscribe(
109 &self,
110 ) -> Result<
111 BoxStream<Result<(Self::Item, Arc<dyn AckTrait + Send + Sync>), SubError>>,
112 SubError,
113 > {
114 let messages = self.ctx.subscribe().await?;
115 let stream = messages.and_then(async move |(msg, acker)| {
116 if self.options.get_auto_ack() {
117 acker.ack().await?;
118 }
119 let data = match self.options.get_format() {
120 Format::MessagePack => {
121 rmp_serde::from_slice::<T>(&msg).map_err(SubError::MessagePackDecode)
122 }
123 Format::JSON => {
124 serde_json::from_slice::<T>(&msg).map_err(SubError::Json)
125 }
126 }?;
127 Ok((data, acker))
128 });
129 return Ok(Box::pin(stream));
130 }
131}
132
133#[async_trait]
134impl<T> UnSubTrait for Sub<T>
135where
136 T: DeserializeOwned + Send + Sync,
137{
138 async fn unsubscribe(&self) -> Result<(), UnSubError> {
140 if let Some(unsub) = &self.unsub {
141 unsub.unsubscribe().await?;
142 }
143 return Ok(());
144 }
145}
146
#[cfg(test)]
mod test {
    use ::bytes::Bytes;
    use ::futures::stream::StreamExt;
    use ::rmp_serde::to_vec as to_msgpack;
    use ::serde_json::to_vec as jsonify;

    use crate::error::AckError;
    use crate::tests::{entity::TestEntity, subscribe::SubscribeMock};
    use crate::traits::{MockAckTrait, MockSubOptTrait};

    use super::*;

    /// Ack mock that expects exactly one successful `ack` call when `auto_ack`
    /// is set, and no `ack` call at all otherwise.
    fn auto_ack_mock(auto_ack: bool) -> MockAckTrait {
        let mut mock = MockAckTrait::new();
        if auto_ack {
            mock.expect_ack().returning(|| Ok(())).once();
        } else {
            mock.expect_ack().never();
        }
        mock
    }

    /// Valid payloads are decoded in order; options are queried once per message
    /// and the acker is only invoked when auto-ack is enabled.
    async fn test_subscribe(format: Format, auto_ack: bool) {
        let entities = vec![
            TestEntity::new(1, "Test1"),
            TestEntity::new(2, "Test2"),
            TestEntity::new(3, "Test3"),
        ];
        let data: Vec<(Bytes, Arc<dyn AckTrait + Send + Sync>)> = entities
            .iter()
            .map(|e| {
                let payload = match format {
                    Format::MessagePack => to_msgpack(e).unwrap(),
                    Format::JSON => jsonify(e).unwrap(),
                };
                (
                    Bytes::from(payload),
                    Arc::new(auto_ack_mock(auto_ack)) as Arc<dyn AckTrait + Send + Sync>,
                )
            })
            .collect();
        let ctx: Arc<dyn SubCtxTrait + Send + Sync> = Arc::new(SubscribeMock::new(data));
        let mut options = MockSubOptTrait::new();
        options
            .expect_get_auto_ack()
            .return_const(auto_ack)
            .times(entities.len());
        options
            .expect_get_format()
            .return_const(format)
            .times(entities.len());
        let subscribe: Sub<TestEntity> = Sub::new(
            ctx,
            None,
            Arc::new(options) as Arc<dyn SubOptTrait + Send + Sync>,
        );
        let stream = subscribe.subscribe().await.unwrap();
        let obtained: Vec<TestEntity> = stream
            .try_collect::<Vec<_>>()
            .await
            .unwrap()
            .into_iter()
            .map(|(entity, _ack)| entity)
            .collect();
        assert_eq!(obtained, entities);
    }

    #[tokio::test]
    async fn test_subscribe_json() {
        test_subscribe(Format::JSON, true).await;
    }

    #[tokio::test]
    async fn test_subscribe_messagepack() {
        test_subscribe(Format::MessagePack, true).await;
    }

    #[tokio::test]
    async fn test_subscribe_json_no_auto_ack() {
        test_subscribe(Format::JSON, false).await;
    }

    #[tokio::test]
    async fn test_subscribe_messagepack_no_auto_ack() {
        test_subscribe(Format::MessagePack, false).await;
    }

    /// A failing auto-ack is surfaced as `SubError::AckError` and the payload is
    /// never decoded (`get_format` must not be called).
    async fn test_ack_err(format: Format) {
        let mut ack_mock = MockAckTrait::new();
        ack_mock
            .expect_ack()
            .returning(|| Err(AckError::ErrorTest))
            .once();
        let data: Vec<(Bytes, Arc<dyn AckTrait + Send + Sync>)> = vec![(
            Bytes::new(),
            Arc::new(ack_mock) as Arc<dyn AckTrait + Send + Sync>,
        )];
        let ctx: Arc<dyn SubCtxTrait + Send + Sync> = Arc::new(SubscribeMock::new(data));
        let mut options = MockSubOptTrait::new();
        options.expect_get_auto_ack().return_const(true).once();
        options.expect_get_format().return_const(format).never();
        let subscribe: Sub<TestEntity> = Sub::new(
            ctx,
            None,
            Arc::new(options) as Arc<dyn SubOptTrait + Send + Sync>,
        );
        let stream = subscribe.subscribe().await.unwrap();
        let obtained: Vec<String> = stream
            .collect::<Vec<_>>()
            .await
            .iter()
            .filter_map(|res| res.as_ref().map_err(|err| err.to_string()).err())
            .collect();
        assert_eq!(
            obtained,
            vec![SubError::AckError(AckError::ErrorTest).to_string()]
        );
    }

    #[tokio::test]
    async fn test_ack_json_err() {
        test_ack_err(Format::JSON).await;
    }

    #[tokio::test]
    async fn test_ack_messagepack_err() {
        test_ack_err(Format::MessagePack).await;
    }

    /// An undecodable payload yields the format-specific decode error. With
    /// auto-ack enabled the message is still acked before decoding is attempted.
    async fn test_decode_err(format: Format, auto_ack: bool) {
        // An empty payload is not a valid encoding of `TestEntity` in either format.
        let data: Vec<(Bytes, Arc<dyn AckTrait + Send + Sync>)> = vec![(
            Bytes::new(),
            Arc::new(auto_ack_mock(auto_ack)) as Arc<dyn AckTrait + Send + Sync>,
        )];
        let ctx: Arc<dyn SubCtxTrait + Send + Sync> = Arc::new(SubscribeMock::new(data));
        let mut options = MockSubOptTrait::new();
        options.expect_get_auto_ack().return_const(auto_ack).once();
        options.expect_get_format().return_const(format).once();
        let subscribe: Sub<TestEntity> = Sub::new(
            ctx,
            None,
            Arc::new(options) as Arc<dyn SubOptTrait + Send + Sync>,
        );
        let stream = subscribe.subscribe().await.unwrap();
        let results = stream.collect::<Vec<_>>().await;
        assert_eq!(results.len(), 1);
        let is_decode_err = match format {
            Format::MessagePack => {
                matches!(results[0], Err(SubError::MessagePackDecode(_)))
            }
            Format::JSON => matches!(results[0], Err(SubError::Json(_))),
        };
        assert!(is_decode_err, "expected a format-specific decode error");
    }

    #[tokio::test]
    async fn test_decode_json_err() {
        test_decode_err(Format::JSON, true).await;
    }

    #[tokio::test]
    async fn test_decode_messagepack_err() {
        test_decode_err(Format::MessagePack, true).await;
    }

    #[tokio::test]
    async fn test_decode_json_err_no_auto_ack() {
        test_decode_err(Format::JSON, false).await;
    }

    #[tokio::test]
    async fn test_decode_messagepack_err_no_auto_ack() {
        test_decode_err(Format::MessagePack, false).await;
    }
}