1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::{Duration, SystemTime};
6
7use prost_types::{DurationError, FieldMask};
8use tokio_util::sync::CancellationToken;
9
10use google_cloud_gax::grpc::codegen::tokio_stream::Stream;
11use google_cloud_gax::grpc::{Code, Status};
12use google_cloud_gax::retry::RetrySetting;
13use google_cloud_googleapis::pubsub::v1::seek_request::Target;
14use google_cloud_googleapis::pubsub::v1::subscription::AnalyticsHubSubscriptionInfo;
15use google_cloud_googleapis::pubsub::v1::{
16 BigQueryConfig, CloudStorageConfig, CreateSnapshotRequest, DeadLetterPolicy, DeleteSnapshotRequest,
17 DeleteSubscriptionRequest, ExpirationPolicy, GetSnapshotRequest, GetSubscriptionRequest, PullRequest, PushConfig,
18 RetryPolicy, SeekRequest, Snapshot, Subscription as InternalSubscription, UpdateSubscriptionRequest,
19};
20
21use crate::apiv1::subscriber_client::SubscriberClient;
22
23use crate::subscriber::{ack, ReceivedMessage, Subscriber, SubscriberConfig};
24
/// Client-side view of a Pub/Sub subscription's settings.
///
/// The fields mirror the wire-level `Subscription` message field-for-field (see the
/// `From<InternalSubscription>` conversion and `Subscription::create`). The exception is
/// the two retention durations, which are exposed as `std::time::Duration` instead of the
/// protobuf duration type.
#[derive(Debug, Clone, Default)]
pub struct SubscriptionConfig {
    pub push_config: Option<PushConfig>,
    /// Also used (clamped to 10..=600) as the default stream ack deadline when no
    /// `SubscriberConfig` is supplied to `subscribe`/`receive`.
    pub ack_deadline_seconds: i32,
    pub retain_acked_messages: bool,
    pub message_retention_duration: Option<Duration>,
    pub labels: HashMap<String, String>,
    /// `Subscription::receive` checks this flag to decide whether each worker gets its own
    /// channel.
    pub enable_message_ordering: bool,
    pub expiration_policy: Option<ExpirationPolicy>,
    pub filter: String,
    pub dead_letter_policy: Option<DeadLetterPolicy>,
    pub retry_policy: Option<RetryPolicy>,
    pub detached: bool,
    pub topic_message_retention_duration: Option<Duration>,
    /// When set, the default subscriber config limits outstanding messages to 5.
    pub enable_exactly_once_delivery: bool,
    pub bigquery_config: Option<BigQueryConfig>,
    /// Raw protobuf enum value of the subscription state (kept as `i32`, as on the wire).
    pub state: i32,
    pub cloud_storage_config: Option<CloudStorageConfig>,
    pub analytics_hub_subscription_info: Option<AnalyticsHubSubscriptionInfo>,
}
45impl From<InternalSubscription> for SubscriptionConfig {
46 fn from(f: InternalSubscription) -> Self {
47 Self {
48 push_config: f.push_config,
49 bigquery_config: f.bigquery_config,
50 ack_deadline_seconds: f.ack_deadline_seconds,
51 retain_acked_messages: f.retain_acked_messages,
52 message_retention_duration: f
53 .message_retention_duration
54 .map(|v| std::time::Duration::new(v.seconds as u64, v.nanos as u32)),
55 labels: f.labels,
56 enable_message_ordering: f.enable_message_ordering,
57 expiration_policy: f.expiration_policy,
58 filter: f.filter,
59 dead_letter_policy: f.dead_letter_policy,
60 retry_policy: f.retry_policy,
61 detached: f.detached,
62 topic_message_retention_duration: f
63 .topic_message_retention_duration
64 .map(|v| std::time::Duration::new(v.seconds as u64, v.nanos as u32)),
65 enable_exactly_once_delivery: f.enable_exactly_once_delivery,
66 state: f.state,
67 cloud_storage_config: f.cloud_storage_config,
68 analytics_hub_subscription_info: f.analytics_hub_subscription_info,
69 }
70 }
71}
72
/// Partial set of subscription settings to pass to `Subscription::update`.
///
/// Every field is optional. A field left as `None` is not modified by the update.
#[derive(Debug, Clone, Default)]
pub struct SubscriptionConfigToUpdate {
    pub push_config: Option<PushConfig>,
    pub bigquery_config: Option<BigQueryConfig>,
    pub ack_deadline_seconds: Option<i32>,
    pub retain_acked_messages: Option<bool>,
    pub message_retention_duration: Option<Duration>,
    /// When `Some`, replaces the whole label map (not merged with existing labels).
    pub labels: Option<HashMap<String, String>>,
    pub expiration_policy: Option<ExpirationPolicy>,
    pub dead_letter_policy: Option<DeadLetterPolicy>,
    pub retry_policy: Option<RetryPolicy>,
}
85
/// Options for `Subscription::subscribe`, set through the `with_*` builder methods.
#[derive(Debug, Clone, Default)]
pub struct SubscribeConfig {
    /// When `true`, `subscribe` starts one subscriber per streaming connection in the
    /// client's pool instead of a single subscriber.
    enable_multiple_subscriber: bool,
    /// Capacity of the message channel. `None` means an unbounded channel.
    channel_capacity: Option<usize>,
    /// Explicit subscriber settings. When `None`, they are derived from the
    /// subscription's server-side configuration.
    subscriber_config: Option<SubscriberConfig>,
}
92
93impl SubscribeConfig {
94 pub fn with_enable_multiple_subscriber(mut self, v: bool) -> Self {
95 self.enable_multiple_subscriber = v;
96 self
97 }
98 pub fn with_subscriber_config(mut self, v: SubscriberConfig) -> Self {
99 self.subscriber_config = Some(v);
100 self
101 }
102 pub fn with_channel_capacity(mut self, v: usize) -> Self {
103 self.channel_capacity = Some(v);
104 self
105 }
106}
107
/// Options for `Subscription::receive`.
#[derive(Debug, Clone)]
pub struct ReceiveConfig {
    /// Number of worker tasks that invoke the user callback (the `Default` impl uses 10).
    pub worker_count: usize,
    /// Capacity of each message channel. `None` means unbounded.
    pub channel_capacity: Option<usize>,
    /// Explicit subscriber settings. When `None`, they are derived from the
    /// subscription's server-side configuration.
    pub subscriber_config: Option<SubscriberConfig>,
}
114
115impl Default for ReceiveConfig {
116 fn default() -> Self {
117 Self {
118 worker_count: 10,
119 subscriber_config: None,
120 channel_capacity: None,
121 }
122 }
123}
124
/// Target of `Subscription::seek`.
#[derive(Debug, Clone)]
pub enum SeekTo {
    /// Seek to a point in time.
    Timestamp(SystemTime),
    /// Seek to a snapshot. `Subscription::seek` expands a name without `/` to
    /// `projects/{project}/snapshots/{name}`.
    Snapshot(String),
}
130
131impl From<SeekTo> for Target {
132 fn from(to: SeekTo) -> Target {
133 use SeekTo::*;
134 match to {
135 Timestamp(t) => Target::Time(prost_types::Timestamp::from(t)),
136 Snapshot(s) => Target::Snapshot(s),
137 }
138 }
139}
140
/// Stream of messages produced by `Subscription::subscribe`.
///
/// Call `dispose` before dropping it so that messages still queued are nacked.
pub struct MessageStream {
    /// Receiving end of the channel the subscribers push messages into.
    queue: async_channel::Receiver<ReceivedMessage>,
    /// Token shared with every subscriber; cancelling it stops them.
    cancel: CancellationToken,
    /// Subscribers started for this stream; awaited in `dispose`.
    tasks: Vec<Subscriber>,
}
146
147impl MessageStream {
148 pub fn cancellable(&self) -> CancellationToken {
149 self.cancel.clone()
150 }
151
152 pub async fn dispose(&mut self) {
153 if !self.cancel.is_cancelled() {
155 self.cancel.cancel();
156 }
157
158 for task in &mut self.tasks {
160 task.done().await;
161 }
162
163 while let Ok(message) = self.queue.recv().await {
165 if let Err(err) = message.nack().await {
166 tracing::warn!("failed to nack message messageId={} {:?}", message.message.message_id, err);
167 }
168 }
169 }
170
171 pub async fn read(&mut self) -> Option<ReceivedMessage> {
173 let message = tokio::select! {
174 msg = self.queue.recv() => msg.ok(),
175 _ = self.cancel.cancelled() => None
176 };
177 if message.is_none() {
178 self.dispose().await;
179 }
180 message
181 }
182}
183
184impl Drop for MessageStream {
185 fn drop(&mut self) {
186 if !self.queue.is_empty() {
187 tracing::warn!("Call 'dispose' before drop in order to call nack for remaining messages");
188 }
189 if !self.cancel.is_cancelled() {
190 self.cancel.cancel();
191 }
192 }
193}
194
195impl Stream for MessageStream {
196 type Item = ReceivedMessage;
197
198 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
201 Pin::new(&mut self.get_mut().queue).poll_next(cx)
202 }
203}
204
/// Handle to a single Pub/Sub subscription.
///
/// It holds the fully qualified subscription name and the client used for every RPC.
/// Cloning it clones the client handle.
#[derive(Clone, Debug)]
pub struct Subscription {
    /// Fully qualified subscription name, e.g. `projects/{project}/subscriptions/{id}`.
    fqsn: String,
    subc: SubscriberClient,
}
211
212impl Subscription {
213 pub(crate) fn new(fqsn: String, subc: SubscriberClient) -> Self {
214 Self { fqsn, subc }
215 }
216
217 pub(crate) fn streaming_pool_size(&self) -> usize {
218 self.subc.streaming_pool_size()
219 }
220
221 pub fn id(&self) -> String {
223 self.fqsn
224 .rfind('/')
225 .map_or("".to_string(), |i| self.fqsn[(i + 1)..].to_string())
226 }
227
228 pub fn fully_qualified_name(&self) -> &str {
230 self.fqsn.as_str()
231 }
232
233 pub fn fully_qualified_snapshot_name(&self, id: &str) -> String {
235 if id.contains('/') {
236 id.to_string()
237 } else {
238 format!("{}/snapshots/{}", self.fully_qualified_project_name(), id)
239 }
240 }
241
242 fn fully_qualified_project_name(&self) -> String {
243 let parts: Vec<_> = self
244 .fqsn
245 .split('/')
246 .enumerate()
247 .filter(|&(i, _)| i < 2)
248 .map(|e| e.1)
249 .collect();
250 parts.join("/")
251 }
252
253 pub fn get_client(&self) -> SubscriberClient {
254 self.subc.clone()
255 }
256
257 pub async fn create(&self, fqtn: &str, cfg: SubscriptionConfig, retry: Option<RetrySetting>) -> Result<(), Status> {
259 self.subc
260 .create_subscription(
261 InternalSubscription {
262 name: self.fully_qualified_name().to_string(),
263 topic: fqtn.to_string(),
264 push_config: cfg.push_config,
265 bigquery_config: cfg.bigquery_config,
266 cloud_storage_config: cfg.cloud_storage_config,
267 ack_deadline_seconds: cfg.ack_deadline_seconds,
268 labels: cfg.labels,
269 enable_message_ordering: cfg.enable_message_ordering,
270 expiration_policy: cfg.expiration_policy,
271 filter: cfg.filter,
272 dead_letter_policy: cfg.dead_letter_policy,
273 retry_policy: cfg.retry_policy,
274 detached: cfg.detached,
275 message_retention_duration: cfg
276 .message_retention_duration
277 .map(Duration::try_into)
278 .transpose()
279 .map_err(|err: DurationError| Status::internal(err.to_string()))?,
280 retain_acked_messages: cfg.retain_acked_messages,
281 topic_message_retention_duration: cfg
282 .topic_message_retention_duration
283 .map(Duration::try_into)
284 .transpose()
285 .map_err(|err: DurationError| Status::internal(err.to_string()))?,
286 enable_exactly_once_delivery: cfg.enable_exactly_once_delivery,
287 state: cfg.state,
288 analytics_hub_subscription_info: cfg.analytics_hub_subscription_info,
289 },
290 retry,
291 )
292 .await
293 .map(|_v| ())
294 }
295
296 pub async fn delete(&self, retry: Option<RetrySetting>) -> Result<(), Status> {
298 let req = DeleteSubscriptionRequest {
299 subscription: self.fqsn.to_string(),
300 };
301 self.subc.delete_subscription(req, retry).await.map(|v| v.into_inner())
302 }
303
304 pub async fn exists(&self, retry: Option<RetrySetting>) -> Result<bool, Status> {
306 let req = GetSubscriptionRequest {
307 subscription: self.fqsn.to_string(),
308 };
309 match self.subc.get_subscription(req, retry).await {
310 Ok(_) => Ok(true),
311 Err(e) => {
312 if e.code() == Code::NotFound {
313 Ok(false)
314 } else {
315 Err(e)
316 }
317 }
318 }
319 }
320
321 pub async fn config(&self, retry: Option<RetrySetting>) -> Result<(String, SubscriptionConfig), Status> {
323 let req = GetSubscriptionRequest {
324 subscription: self.fqsn.to_string(),
325 };
326 self.subc.get_subscription(req, retry).await.map(|v| {
327 let inner = v.into_inner();
328 (inner.topic.to_string(), inner.into())
329 })
330 }
331
332 pub async fn update(
335 &self,
336 updating: SubscriptionConfigToUpdate,
337 retry: Option<RetrySetting>,
338 ) -> Result<(String, SubscriptionConfig), Status> {
339 let req = GetSubscriptionRequest {
340 subscription: self.fqsn.to_string(),
341 };
342 let mut config = self.subc.get_subscription(req, retry.clone()).await?.into_inner();
343
344 let mut paths = vec![];
345 if updating.push_config.is_some() {
346 config.push_config = updating.push_config;
347 paths.push("push_config".to_string());
348 }
349 if updating.bigquery_config.is_some() {
350 config.bigquery_config = updating.bigquery_config;
351 paths.push("bigquery_config".to_string());
352 }
353 if let Some(v) = updating.ack_deadline_seconds {
354 config.ack_deadline_seconds = v;
355 paths.push("ack_deadline_seconds".to_string());
356 }
357 if let Some(v) = updating.retain_acked_messages {
358 config.retain_acked_messages = v;
359 paths.push("retain_acked_messages".to_string());
360 }
361 if updating.message_retention_duration.is_some() {
362 config.message_retention_duration = updating
363 .message_retention_duration
364 .map(prost_types::Duration::try_from)
365 .transpose()
366 .map_err(|err| Status::internal(err.to_string()))?;
367 paths.push("message_retention_duration".to_string());
368 }
369 if updating.expiration_policy.is_some() {
370 config.expiration_policy = updating.expiration_policy;
371 paths.push("expiration_policy".to_string());
372 }
373 if let Some(v) = updating.labels {
374 config.labels = v;
375 paths.push("labels".to_string());
376 }
377 if updating.retry_policy.is_some() {
378 config.retry_policy = updating.retry_policy;
379 paths.push("retry_policy".to_string());
380 }
381
382 let update_req = UpdateSubscriptionRequest {
383 subscription: Some(config),
384 update_mask: Some(FieldMask { paths }),
385 };
386 self.subc.update_subscription(update_req, retry).await.map(|v| {
387 let inner = v.into_inner();
388 (inner.topic.to_string(), inner.into())
389 })
390 }
391
392 pub async fn pull(&self, max_messages: i32, retry: Option<RetrySetting>) -> Result<Vec<ReceivedMessage>, Status> {
395 #[allow(deprecated)]
396 let req = PullRequest {
397 subscription: self.fqsn.clone(),
398 return_immediately: false,
399 max_messages,
400 };
401 let messages = self.subc.pull(req, retry).await?.into_inner().received_messages;
402 Ok(messages
403 .into_iter()
404 .filter(|m| m.message.is_some())
405 .map(|m| {
406 ReceivedMessage::new(
407 self.fqsn.clone(),
408 self.subc.clone(),
409 m.message.unwrap(),
410 m.ack_id,
411 (m.delivery_attempt > 0).then_some(m.delivery_attempt as usize),
412 )
413 })
414 .collect())
415 }
416
417 pub async fn subscribe(&self, opt: Option<SubscribeConfig>) -> Result<MessageStream, Status> {
459 let opt = opt.unwrap_or_default();
460 let (tx, rx) = create_channel(opt.channel_capacity);
461 let cancel = CancellationToken::new();
462 let sub_opt = self.unwrap_subscribe_config(opt.subscriber_config).await?;
463
464 let subscribers = if opt.enable_multiple_subscriber {
466 self.streaming_pool_size()
467 } else {
468 1
469 };
470 let mut tasks = Vec::with_capacity(subscribers);
471 for _ in 0..subscribers {
472 tasks.push(Subscriber::start(
473 cancel.clone(),
474 self.fqsn.clone(),
475 self.subc.clone(),
476 tx.clone(),
477 sub_opt.clone(),
478 ));
479 }
480
481 Ok(MessageStream {
482 queue: rx,
483 cancel,
484 tasks,
485 })
486 }
487
488 pub async fn receive<F>(
492 &self,
493 f: impl Fn(ReceivedMessage, CancellationToken) -> F + Send + 'static + Sync + Clone,
494 cancel: CancellationToken,
495 config: Option<ReceiveConfig>,
496 ) -> Result<(), Status>
497 where
498 F: Future<Output = ()> + Send + 'static,
499 {
500 let op = config.unwrap_or_default();
501 let mut receivers = Vec::with_capacity(op.worker_count);
502 let mut senders = Vec::with_capacity(receivers.len());
503 let sub_opt = self.unwrap_subscribe_config(op.subscriber_config).await?;
504
505 if self
506 .config(sub_opt.retry_setting.clone())
507 .await?
508 .1
509 .enable_message_ordering
510 {
511 (0..op.worker_count).for_each(|_v| {
512 let (sender, receiver) = create_channel(op.channel_capacity);
513 receivers.push(receiver);
514 senders.push(sender);
515 });
516 } else {
517 let (sender, receiver) = create_channel(op.channel_capacity);
518 (0..op.worker_count).for_each(|_v| {
519 receivers.push(receiver.clone());
520 senders.push(sender.clone());
521 });
522 }
523
524 let subscribers: Vec<Subscriber> = senders
526 .into_iter()
527 .map(|queue| {
528 Subscriber::start(cancel.clone(), self.fqsn.clone(), self.subc.clone(), queue, sub_opt.clone())
529 })
530 .collect();
531
532 let mut message_receivers = Vec::with_capacity(receivers.len());
533 for receiver in receivers {
534 let f_clone = f.clone();
535 let cancel_clone = cancel.clone();
536 let name = self.fqsn.clone();
537 message_receivers.push(tokio::spawn(async move {
538 while let Ok(message) = receiver.recv().await {
539 f_clone(message, cancel_clone.clone()).await;
540 }
541 tracing::trace!("stop message receiver : {}", name);
543 }));
544 }
545 cancel.cancelled().await;
546
547 for mut subscriber in subscribers {
549 subscriber.done().await;
550 }
551
552 for mr in message_receivers {
554 let _ = mr.await;
555 }
556 Ok(())
557 }
558
559 pub async fn ack(&self, ack_ids: Vec<String>) -> Result<(), Status> {
622 ack(&self.subc, self.fqsn.to_string(), ack_ids).await
623 }
624
625 pub async fn seek(&self, to: SeekTo, retry: Option<RetrySetting>) -> Result<(), Status> {
627 let to = match to {
628 SeekTo::Timestamp(t) => SeekTo::Timestamp(t),
629 SeekTo::Snapshot(name) => SeekTo::Snapshot(self.fully_qualified_snapshot_name(name.as_str())),
630 };
631
632 let req = SeekRequest {
633 subscription: self.fqsn.to_owned(),
634 target: Some(to.into()),
635 };
636
637 let _ = self.subc.seek(req, retry).await?;
638 Ok(())
639 }
640
641 pub async fn get_snapshot(&self, name: &str, retry: Option<RetrySetting>) -> Result<Snapshot, Status> {
643 let req = GetSnapshotRequest {
644 snapshot: self.fully_qualified_snapshot_name(name),
645 };
646 Ok(self.subc.get_snapshot(req, retry).await?.into_inner())
647 }
648
649 pub async fn create_snapshot(
659 &self,
660 name: &str,
661 labels: HashMap<String, String>,
662 retry: Option<RetrySetting>,
663 ) -> Result<Snapshot, Status> {
664 let req = CreateSnapshotRequest {
665 name: self.fully_qualified_snapshot_name(name),
666 labels,
667 subscription: self.fqsn.to_owned(),
668 };
669 Ok(self.subc.create_snapshot(req, retry).await?.into_inner())
670 }
671
672 pub async fn delete_snapshot(&self, name: &str, retry: Option<RetrySetting>) -> Result<(), Status> {
674 let req = DeleteSnapshotRequest {
675 snapshot: self.fully_qualified_snapshot_name(name),
676 };
677 let _ = self.subc.delete_snapshot(req, retry).await?;
678 Ok(())
679 }
680
681 async fn unwrap_subscribe_config(&self, cfg: Option<SubscriberConfig>) -> Result<SubscriberConfig, Status> {
682 if let Some(cfg) = cfg {
683 return Ok(cfg);
684 }
685 let cfg = self.config(None).await?;
686 let mut default_cfg = SubscriberConfig {
687 stream_ack_deadline_seconds: cfg.1.ack_deadline_seconds.clamp(10, 600),
688 ..Default::default()
689 };
690 if cfg.1.enable_exactly_once_delivery {
691 default_cfg.max_outstanding_messages = 5;
692 }
693 Ok(default_cfg)
694 }
695}
696
697fn create_channel(
698 channel_capacity: Option<usize>,
699) -> (async_channel::Sender<ReceivedMessage>, async_channel::Receiver<ReceivedMessage>) {
700 match channel_capacity {
701 None => async_channel::unbounded(),
702 Some(cap) => async_channel::bounded(cap),
703 }
704}
705
#[cfg(test)]
mod tests {
    //! Emulator-backed integration tests for `Subscription`.
    //!
    //! They connect to the Pub/Sub emulator at `EMULATOR` and publish to
    //! `projects/local-project/topics/test-topic1`, which is assumed to exist there.
    //! They run one at a time via `#[serial]`.

    use std::collections::HashMap;
    use std::sync::atomic::AtomicU32;
    use std::sync::atomic::Ordering::SeqCst;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    use futures_util::StreamExt;
    use serial_test::serial;
    use tokio_util::sync::CancellationToken;
    use uuid::Uuid;

    use google_cloud_gax::conn::{ConnectionOptions, Environment};
    use google_cloud_googleapis::pubsub::v1::{PublishRequest, PubsubMessage};

    use crate::apiv1::conn_pool::ConnectionManager;
    use crate::apiv1::publisher_client::PublisherClient;
    use crate::apiv1::subscriber_client::SubscriberClient;
    use crate::subscriber::ReceivedMessage;
    use crate::subscription::{
        ReceiveConfig, SeekTo, SubscribeConfig, Subscription, SubscriptionConfig, SubscriptionConfigToUpdate,
    };

    const PROJECT_NAME: &str = "local-project";
    const EMULATOR: &str = "localhost:8681";

    #[ctor::ctor]
    fn init() {
        let _ = tracing_subscriber::fmt().try_init();
    }

    /// Creates a uniquely named subscription on `test-topic1` (if it does not exist yet).
    async fn create_subscription(enable_exactly_once_delivery: bool) -> Subscription {
        let cm = ConnectionManager::new(
            4,
            "",
            &Environment::Emulator(EMULATOR.to_string()),
            &ConnectionOptions::default(),
        )
        .await
        .unwrap();
        let cm2 = ConnectionManager::new(
            4,
            "",
            &Environment::Emulator(EMULATOR.to_string()),
            &ConnectionOptions::default(),
        )
        .await
        .unwrap();
        let client = SubscriberClient::new(cm, cm2);

        let uuid = Uuid::new_v4().hyphenated().to_string();
        let subscription_name = format!("projects/{}/subscriptions/s{}", PROJECT_NAME, &uuid);
        let topic_name = format!("projects/{PROJECT_NAME}/topics/test-topic1");
        let subscription = Subscription::new(subscription_name, client);
        let config = SubscriptionConfig {
            enable_exactly_once_delivery,
            ..Default::default()
        };
        if !subscription.exists(None).await.unwrap() {
            subscription.create(topic_name.as_str(), config, None).await.unwrap();
        }
        subscription
    }

    /// Publishes `messages` (or a single "test_message" when `None`) to `test-topic1`.
    /// Publish errors are ignored.
    async fn publish(messages: Option<Vec<PubsubMessage>>) {
        let pubc = PublisherClient::new(
            ConnectionManager::new(
                4,
                "",
                &Environment::Emulator(EMULATOR.to_string()),
                &ConnectionOptions::default(),
            )
            .await
            .unwrap(),
        );
        let messages = messages.unwrap_or(vec![PubsubMessage {
            data: "test_message".into(),
            ..Default::default()
        }]);
        let req = PublishRequest {
            topic: format!("projects/{PROJECT_NAME}/topics/test-topic1"),
            messages,
        };
        let _ = pubc.publish(req, None).await;
    }

    async fn test_subscription(enable_exactly_once_delivery: bool) {
        let subscription = create_subscription(enable_exactly_once_delivery).await;

        let topic_name = format!("projects/{PROJECT_NAME}/topics/test-topic1");
        let config = subscription.config(None).await.unwrap();
        assert_eq!(config.0, topic_name);

        let updating = SubscriptionConfigToUpdate {
            ack_deadline_seconds: Some(100),
            ..Default::default()
        };
        let new_config = subscription.update(updating, None).await.unwrap();
        assert_eq!(new_config.0, topic_name);
        assert_eq!(new_config.1.ack_deadline_seconds, 100);

        let receiver_ctx = CancellationToken::new();
        let cancel_receiver = receiver_ctx.clone();
        let handle = tokio::spawn(async move {
            let _ = subscription
                .receive(
                    |message, _ctx| async move {
                        println!("{}", message.message.message_id);
                        let _ = message.ack().await;
                    },
                    cancel_receiver,
                    None,
                )
                .await;
            subscription.delete(None).await.unwrap();
            assert!(!subscription.exists(None).await.unwrap())
        });
        tokio::time::sleep(Duration::from_secs(3)).await;
        receiver_ctx.cancel();
        // Propagate panics from the spawned task: the delete/exists assertions above run
        // inside it and would otherwise be silently discarded.
        handle.await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_pull() {
        let subscription = create_subscription(false).await;
        let base = PubsubMessage {
            data: "test_message".into(),
            ..Default::default()
        };
        publish(Some(vec![base.clone(), base.clone(), base])).await;
        let messages = subscription.pull(2, None).await.unwrap();
        assert_eq!(messages.len(), 2);
        for m in messages {
            m.ack().await.unwrap();
        }
        subscription.delete(None).await.unwrap();
    }

    #[tokio::test]
    #[serial]
    async fn test_subscription_exactly_once() {
        test_subscription(true).await;
    }

    #[tokio::test]
    #[serial]
    async fn test_subscription_at_least_once() {
        test_subscription(false).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_multi_subscriber_single_subscription_unbound() {
        test_multi_subscriber_single_subscription(None).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_multi_subscriber_single_subscription_bound() {
        let opt = Some(ReceiveConfig {
            channel_capacity: Some(1),
            ..Default::default()
        });
        test_multi_subscriber_single_subscription(opt).await;
    }

    async fn test_multi_subscriber_single_subscription(opt: Option<ReceiveConfig>) {
        let msg = PubsubMessage {
            data: "test".into(),
            ..Default::default()
        };
        let msg_size = 10;
        let msgs: Vec<PubsubMessage> = (0..msg_size).map(|_v| msg.clone()).collect();
        let subscription = create_subscription(false).await;
        let cancellation_token = CancellationToken::new();
        let cancel_receiver = cancellation_token.clone();
        let v = Arc::new(AtomicU32::new(0));
        let v2 = v.clone();
        let handle = tokio::spawn(async move {
            let _ = subscription
                .receive(
                    move |message, _ctx| {
                        let v2 = v2.clone();
                        async move {
                            tracing::info!("received {}", message.message.message_id);
                            v2.fetch_add(1, SeqCst);
                            let _ = message.ack().await;
                        }
                    },
                    cancel_receiver,
                    opt,
                )
                .await;
        });
        publish(Some(msgs)).await;
        tokio::time::sleep(Duration::from_secs(5)).await;
        cancellation_token.cancel();
        let _ = handle.await;
        assert_eq!(v.load(SeqCst), msg_size);
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_multi_subscriber_multi_subscription() {
        let mut subscriptions = vec![];

        let ctx = CancellationToken::new();
        for _ in 0..3 {
            let subscription = create_subscription(false).await;
            let v = Arc::new(AtomicU32::new(0));
            let ctx = ctx.clone();
            let v2 = v.clone();
            let handle = tokio::spawn(async move {
                let _ = subscription
                    .receive(
                        move |message, _ctx| {
                            let v2 = v2.clone();
                            async move {
                                v2.fetch_add(1, SeqCst);
                                let _ = message.ack().await;
                            }
                        },
                        ctx,
                        None,
                    )
                    .await;
            });
            subscriptions.push((handle, v))
        }

        publish(None).await;
        tokio::time::sleep(Duration::from_secs(5)).await;

        ctx.cancel();
        for (task, v) in subscriptions {
            let _ = task.await;
            assert_eq!(v.load(SeqCst), 1);
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_batch_acking() {
        let ctx = CancellationToken::new();
        let subscription = create_subscription(false).await;
        let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
        let subscription_for_receive = subscription.clone();
        let ctx_for_receive = ctx.clone();
        let handle = tokio::spawn(async move {
            let _ = subscription_for_receive
                .receive(
                    move |message, _ctx| {
                        let sender = sender.clone();
                        async move {
                            let _ = sender.send(message.ack_id().to_string());
                        }
                    },
                    ctx_for_receive.clone(),
                    None,
                )
                .await;
        });

        let ctx_for_ack_manager = ctx.clone();
        let ack_manager = tokio::spawn(async move {
            let mut ack_ids = Vec::new();
            while !ctx_for_ack_manager.is_cancelled() {
                match tokio::time::timeout(Duration::from_secs(10), receiver.recv()).await {
                    Ok(ack_id) => {
                        if let Some(ack_id) = ack_id {
                            ack_ids.push(ack_id);
                            if ack_ids.len() > 10 {
                                subscription.ack(ack_ids).await.unwrap();
                                ack_ids = Vec::new();
                            }
                        }
                    }
                    Err(_e) => {
                        subscription.ack(ack_ids).await.unwrap();
                        ack_ids = Vec::new();
                    }
                }
            }
            subscription.ack(ack_ids).await
        });

        publish(None).await;
        tokio::time::sleep(Duration::from_secs(5)).await;

        ctx.cancel();
        let _ = handle.await;
        assert!(ack_manager.await.is_ok());
    }

    #[tokio::test]
    #[serial]
    async fn test_snapshots() {
        let subscription = create_subscription(false).await;

        let snapshot_name = format!("snapshot-{}", rand::random::<u64>());
        let labels: HashMap<String, String> =
            HashMap::from_iter([("label-1".into(), "v1".into()), ("label-2".into(), "v2".into())]);
        let expected_fq_snap_name = format!("projects/{PROJECT_NAME}/snapshots/{snapshot_name}");

        let _response = subscription.delete_snapshot(snapshot_name.as_str(), None).await;

        let created_snapshot = subscription
            .create_snapshot(snapshot_name.as_str(), labels.clone(), None)
            .await
            .unwrap();

        assert_eq!(created_snapshot.name, expected_fq_snap_name);
        let retrieved_snapshot = subscription.get_snapshot(snapshot_name.as_str(), None).await.unwrap();
        assert_eq!(created_snapshot, retrieved_snapshot);

        subscription
            .delete_snapshot(snapshot_name.as_str(), None)
            .await
            .unwrap();

        let _deleted_snapshot_status = subscription
            .get_snapshot(snapshot_name.as_str(), None)
            .await
            .expect_err("snapshot should have been deleted");

        let _delete_again = subscription
            .delete_snapshot(snapshot_name.as_str(), None)
            .await
            .expect_err("snapshot should already be deleted");
    }

    /// Acks every message, panicking on the first failure.
    async fn ack_all(messages: &[ReceivedMessage]) {
        for message in messages.iter() {
            message.ack().await.unwrap();
        }
    }

    #[tokio::test]
    #[serial]
    async fn test_seek_snapshot() {
        let subscription = create_subscription(false).await;
        let snapshot_name = format!("snapshot-{}", rand::random::<u64>());

        publish(None).await;
        let messages = subscription.pull(100, None).await.unwrap();
        ack_all(&messages).await;
        assert_eq!(messages.len(), 1);

        let _snapshot = subscription
            .create_snapshot(snapshot_name.as_str(), HashMap::new(), None)
            .await
            .unwrap();

        publish(None).await;
        let messages = subscription.pull(100, None).await.unwrap();
        assert_eq!(messages.len(), 1);
        ack_all(&messages).await;

        subscription
            .seek(SeekTo::Snapshot(snapshot_name.clone()), None)
            .await
            .unwrap();

        let messages = subscription.pull(100, None).await.unwrap();
        assert_eq!(messages.len(), 1);
        ack_all(&messages).await;

        subscription
            .delete_snapshot(snapshot_name.as_str(), None)
            .await
            .unwrap();
        subscription.delete(None).await.unwrap();
    }

    #[tokio::test]
    #[serial]
    async fn test_seek_timestamp() {
        let subscription = create_subscription(false).await;

        subscription
            .update(
                SubscriptionConfigToUpdate {
                    retain_acked_messages: Some(true),
                    message_retention_duration: Some(Duration::new(60 * 60 * 2, 0)),
                    ..Default::default()
                },
                None,
            )
            .await
            .unwrap();

        publish(None).await;
        let messages = subscription.pull(100, None).await.unwrap();
        ack_all(&messages).await;
        assert_eq!(messages.len(), 1);

        let message_publish_time = messages.first().unwrap().message.publish_time.to_owned().unwrap();

        subscription
            .seek(SeekTo::Timestamp(message_publish_time.to_owned().try_into().unwrap()), None)
            .await
            .unwrap();

        let messages = subscription.pull(100, None).await.unwrap();
        ack_all(&messages).await;
        assert_eq!(messages.len(), 1);
        let seek_message_publish_time = messages.first().unwrap().message.publish_time.to_owned().unwrap();
        assert_eq!(seek_message_publish_time, message_publish_time);

        subscription.delete(None).await.unwrap();
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_subscribe_single_subscriber() {
        test_subscribe(None).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_subscribe_multiple_subscriber() {
        test_subscribe(Some(SubscribeConfig::default().with_enable_multiple_subscriber(true))).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_subscribe_multiple_subscriber_bound() {
        test_subscribe(Some(
            SubscribeConfig::default()
                .with_enable_multiple_subscriber(true)
                .with_channel_capacity(1),
        ))
        .await;
    }

    async fn test_subscribe(opt: Option<SubscribeConfig>) {
        let msg = PubsubMessage {
            data: "test".into(),
            ..Default::default()
        };
        let msg_count = 10;
        let msg: Vec<PubsubMessage> = (0..msg_count).map(|_v| msg.clone()).collect();
        let subscription = create_subscription(false).await;
        let received = Arc::new(Mutex::new(0));
        let checking = received.clone();
        let mut iter = subscription.subscribe(opt).await.unwrap();
        let cancellable = iter.cancellable();
        let handler = tokio::spawn(async move {
            while let Some(message) = iter.next().await {
                tracing::info!("received {}", message.message.message_id);
                *received.lock().unwrap() += 1;
                tokio::time::sleep(Duration::from_millis(500)).await;
                let _ = message.ack().await;
            }
        });
        publish(Some(msg)).await;
        tokio::time::sleep(Duration::from_secs(8)).await;
        cancellable.cancel();
        let _ = handler.await;
        assert_eq!(*checking.lock().unwrap(), msg_count);
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_subscribe_nack_on_cancel_read() {
        subscribe_nack_on_cancel_read(10, true).await;
        subscribe_nack_on_cancel_read(0, true).await;
        subscribe_nack_on_cancel_read(10, false).await;
        subscribe_nack_on_cancel_read(0, false).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_subscribe_nack_on_cancel_next() {
        subscribe_nack_on_cancel_next(10, Duration::from_secs(3)).await;
        subscribe_nack_on_cancel_next(10, Duration::from_millis(0)).await;
        subscribe_nack_on_cancel_next(0, Duration::from_secs(3)).await;
    }

    async fn subscribe_nack_on_cancel_read(msg_count: usize, should_cancel: bool) {
        let opt = Some(SubscribeConfig::default().with_enable_multiple_subscriber(true));

        let msg = PubsubMessage {
            data: "test".into(),
            ..Default::default()
        };
        let msg: Vec<PubsubMessage> = (0..msg_count).map(|_v| msg.clone()).collect();
        let subscription = create_subscription(false).await;
        let received = Arc::new(Mutex::new(0));
        let checking = received.clone();

        let mut iter = subscription.subscribe(opt).await.unwrap();
        let ctx = iter.cancellable();
        let handler = tokio::spawn(async move {
            while let Some(message) = iter.read().await {
                tracing::info!("received {}", message.message.message_id);
                *received.lock().unwrap() += 1;
                if should_cancel {
                    tokio::time::sleep(Duration::from_secs(10)).await;
                } else {
                    tokio::time::sleep(Duration::from_millis(1)).await;
                }
                let _ = message.ack().await;
            }
        });
        publish(Some(msg)).await;
        tokio::time::sleep(Duration::from_secs(10)).await;
        ctx.cancel();
        handler.await.unwrap();
        if should_cancel && msg_count > 0 {
            assert!(*checking.lock().unwrap() < msg_count);
        } else {
            assert_eq!(*checking.lock().unwrap(), msg_count);
        }
    }

    async fn subscribe_nack_on_cancel_next(msg_count: usize, recv_time: Duration) {
        let opt = Some(SubscribeConfig::default().with_enable_multiple_subscriber(true));

        let msg = PubsubMessage {
            data: "test".into(),
            ..Default::default()
        };
        let msg: Vec<PubsubMessage> = (0..msg_count).map(|_v| msg.clone()).collect();
        let subscription = create_subscription(false).await;
        let received = Arc::new(Mutex::new(0));
        let checking = received.clone();

        let mut iter = subscription.subscribe(opt).await.unwrap();
        let ctx = iter.cancellable();
        let handler = tokio::spawn(async move {
            while let Some(message) = iter.next().await {
                tracing::info!("received {}", message.message.message_id);
                *received.lock().unwrap() += 1;
                tokio::time::sleep(recv_time).await;
                let _ = message.ack().await;
            }
        });
        publish(Some(msg)).await;
        tokio::time::sleep(Duration::from_secs(10)).await;
        ctx.cancel();
        handler.await.unwrap();
        assert_eq!(*checking.lock().unwrap(), msg_count);
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_message_stream_dispose() {
        let subscription = create_subscription(false).await;
        let mut iter = subscription.subscribe(None).await.unwrap();
        iter.dispose().await;
        iter.dispose().await;
        assert!(iter.next().await.is_none());
    }
}