// gcloud_pubsub/subscription.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::{Duration, SystemTime};
6
7use prost_types::{DurationError, FieldMask};
8use tokio_util::sync::CancellationToken;
9
10use google_cloud_gax::grpc::codegen::tokio_stream::Stream;
11use google_cloud_gax::grpc::{Code, Status};
12use google_cloud_gax::retry::RetrySetting;
13use google_cloud_googleapis::pubsub::v1::seek_request::Target;
14use google_cloud_googleapis::pubsub::v1::subscription::AnalyticsHubSubscriptionInfo;
15use google_cloud_googleapis::pubsub::v1::{
16    BigQueryConfig, CloudStorageConfig, CreateSnapshotRequest, DeadLetterPolicy, DeleteSnapshotRequest,
17    DeleteSubscriptionRequest, ExpirationPolicy, GetSnapshotRequest, GetSubscriptionRequest, MessageTransform,
18    PullRequest, PushConfig, RetryPolicy, SeekRequest, Snapshot, Subscription as InternalSubscription,
19    UpdateSubscriptionRequest,
20};
21
22use crate::apiv1::subscriber_client::SubscriberClient;
23
24use crate::subscriber::{ack, ReceivedMessage, Subscriber, SubscriberConfig};
25
26#[derive(Debug, Clone, Default)]
27pub struct SubscriptionConfig {
28    pub push_config: Option<PushConfig>,
29    pub ack_deadline_seconds: i32,
30    pub retain_acked_messages: bool,
31    pub message_retention_duration: Option<Duration>,
32    pub labels: HashMap<String, String>,
33    pub enable_message_ordering: bool,
34    pub expiration_policy: Option<ExpirationPolicy>,
35    pub filter: String,
36    pub dead_letter_policy: Option<DeadLetterPolicy>,
37    pub retry_policy: Option<RetryPolicy>,
38    pub detached: bool,
39    pub topic_message_retention_duration: Option<Duration>,
40    pub enable_exactly_once_delivery: bool,
41    pub bigquery_config: Option<BigQueryConfig>,
42    pub state: i32,
43    pub cloud_storage_config: Option<CloudStorageConfig>,
44    pub analytics_hub_subscription_info: Option<AnalyticsHubSubscriptionInfo>,
45    pub message_transforms: Vec<MessageTransform>,
46}
47impl From<InternalSubscription> for SubscriptionConfig {
48    fn from(f: InternalSubscription) -> Self {
49        Self {
50            push_config: f.push_config,
51            bigquery_config: f.bigquery_config,
52            ack_deadline_seconds: f.ack_deadline_seconds,
53            retain_acked_messages: f.retain_acked_messages,
54            message_retention_duration: f
55                .message_retention_duration
56                .map(|v| std::time::Duration::new(v.seconds as u64, v.nanos as u32)),
57            labels: f.labels,
58            enable_message_ordering: f.enable_message_ordering,
59            expiration_policy: f.expiration_policy,
60            filter: f.filter,
61            dead_letter_policy: f.dead_letter_policy,
62            retry_policy: f.retry_policy,
63            detached: f.detached,
64            topic_message_retention_duration: f
65                .topic_message_retention_duration
66                .map(|v| std::time::Duration::new(v.seconds as u64, v.nanos as u32)),
67            enable_exactly_once_delivery: f.enable_exactly_once_delivery,
68            state: f.state,
69            cloud_storage_config: f.cloud_storage_config,
70            analytics_hub_subscription_info: f.analytics_hub_subscription_info,
71            message_transforms: f.message_transforms,
72        }
73    }
74}
75
76#[derive(Debug, Clone, Default)]
77pub struct SubscriptionConfigToUpdate {
78    pub push_config: Option<PushConfig>,
79    pub bigquery_config: Option<BigQueryConfig>,
80    pub ack_deadline_seconds: Option<i32>,
81    pub retain_acked_messages: Option<bool>,
82    pub message_retention_duration: Option<Duration>,
83    pub labels: Option<HashMap<String, String>>,
84    pub expiration_policy: Option<ExpirationPolicy>,
85    pub dead_letter_policy: Option<DeadLetterPolicy>,
86    pub retry_policy: Option<RetryPolicy>,
87}
88
89#[derive(Debug, Clone, Default)]
90pub struct SubscribeConfig {
91    enable_multiple_subscriber: bool,
92    channel_capacity: Option<usize>,
93    subscriber_config: Option<SubscriberConfig>,
94}
95
96impl SubscribeConfig {
97    pub fn with_enable_multiple_subscriber(mut self, v: bool) -> Self {
98        self.enable_multiple_subscriber = v;
99        self
100    }
101    pub fn with_subscriber_config(mut self, v: SubscriberConfig) -> Self {
102        self.subscriber_config = Some(v);
103        self
104    }
105    pub fn with_channel_capacity(mut self, v: usize) -> Self {
106        self.channel_capacity = Some(v);
107        self
108    }
109}
110
111#[derive(Debug, Clone)]
112pub struct ReceiveConfig {
113    pub worker_count: usize,
114    pub channel_capacity: Option<usize>,
115    pub subscriber_config: Option<SubscriberConfig>,
116}
117
118impl Default for ReceiveConfig {
119    fn default() -> Self {
120        Self {
121            worker_count: 10,
122            subscriber_config: None,
123            channel_capacity: None,
124        }
125    }
126}
127
/// Destination of a [`Subscription::seek`] call.
#[derive(Debug, Clone)]
pub enum SeekTo {
    /// Seek to a point in time.
    Timestamp(SystemTime),
    /// Seek to a snapshot, identified by its id or fully qualified name.
    Snapshot(String),
}
133
134impl From<SeekTo> for Target {
135    fn from(to: SeekTo) -> Target {
136        use SeekTo::*;
137        match to {
138            Timestamp(t) => Target::Time(prost_types::Timestamp::from(t)),
139            Snapshot(s) => Target::Snapshot(s),
140        }
141    }
142}
143
144pub struct MessageStream {
145    queue: async_channel::Receiver<ReceivedMessage>,
146    cancel: CancellationToken,
147    tasks: Vec<Subscriber>,
148}
149
150impl MessageStream {
151    pub fn cancellable(&self) -> CancellationToken {
152        self.cancel.clone()
153    }
154
155    pub async fn dispose(&mut self) {
156        // Close streaming pull task
157        if !self.cancel.is_cancelled() {
158            self.cancel.cancel();
159        }
160
161        // Wait for all the streaming pull close.
162        for task in &mut self.tasks {
163            task.done().await;
164        }
165
166        // Nack for remaining messages.
167        while let Ok(message) = self.queue.recv().await {
168            if let Err(err) = message.nack().await {
169                tracing::warn!("failed to nack message messageId={} {:?}", message.message.message_id, err);
170            }
171        }
172    }
173
174    /// Immediately Nack on cancel
175    pub async fn read(&mut self) -> Option<ReceivedMessage> {
176        let message = tokio::select! {
177            msg = self.queue.recv() => msg.ok(),
178            _ = self.cancel.cancelled() => None
179        };
180        if message.is_none() {
181            self.dispose().await;
182        }
183        message
184    }
185}
186
187impl Drop for MessageStream {
188    fn drop(&mut self) {
189        if !self.queue.is_empty() {
190            tracing::warn!("Call 'dispose' before drop in order to call nack for remaining messages");
191        }
192        if !self.cancel.is_cancelled() {
193            self.cancel.cancel();
194        }
195    }
196}
197
198impl Stream for MessageStream {
199    type Item = ReceivedMessage;
200
201    /// Return None unless the queue is open.
202    /// Use CancellationToken for SubscribeConfig to get None
203    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
204        Pin::new(&mut self.get_mut().queue).poll_next(cx)
205    }
206}
207
208/// Subscription is a reference to a PubSub subscription.
209#[derive(Clone, Debug)]
210pub struct Subscription {
211    fqsn: String,
212    subc: SubscriberClient,
213}
214
215impl Subscription {
216    pub(crate) fn new(fqsn: String, subc: SubscriberClient) -> Self {
217        Self { fqsn, subc }
218    }
219
220    pub(crate) fn streaming_pool_size(&self) -> usize {
221        self.subc.streaming_pool_size()
222    }
223
224    /// id returns the unique identifier of the subscription within its project.
225    pub fn id(&self) -> String {
226        self.fqsn
227            .rfind('/')
228            .map_or("".to_string(), |i| self.fqsn[(i + 1)..].to_string())
229    }
230
231    /// fully_qualified_name returns the globally unique printable name of the subscription.
232    pub fn fully_qualified_name(&self) -> &str {
233        self.fqsn.as_str()
234    }
235
236    /// fully_qualified_snapshot_name returns the globally unique printable name of the snapshot.
237    pub fn fully_qualified_snapshot_name(&self, id: &str) -> String {
238        if id.contains('/') {
239            id.to_string()
240        } else {
241            format!("{}/snapshots/{}", self.fully_qualified_project_name(), id)
242        }
243    }
244
245    fn fully_qualified_project_name(&self) -> String {
246        let parts: Vec<_> = self
247            .fqsn
248            .split('/')
249            .enumerate()
250            .filter(|&(i, _)| i < 2)
251            .map(|e| e.1)
252            .collect();
253        parts.join("/")
254    }
255
256    pub fn get_client(&self) -> SubscriberClient {
257        self.subc.clone()
258    }
259
260    /// create creates the subscription.
261    pub async fn create(&self, fqtn: &str, cfg: SubscriptionConfig, retry: Option<RetrySetting>) -> Result<(), Status> {
262        self.subc
263            .create_subscription(
264                InternalSubscription {
265                    name: self.fully_qualified_name().to_string(),
266                    topic: fqtn.to_string(),
267                    push_config: cfg.push_config,
268                    bigquery_config: cfg.bigquery_config,
269                    cloud_storage_config: cfg.cloud_storage_config,
270                    ack_deadline_seconds: cfg.ack_deadline_seconds,
271                    labels: cfg.labels,
272                    enable_message_ordering: cfg.enable_message_ordering,
273                    expiration_policy: cfg.expiration_policy,
274                    filter: cfg.filter,
275                    dead_letter_policy: cfg.dead_letter_policy,
276                    retry_policy: cfg.retry_policy,
277                    detached: cfg.detached,
278                    message_retention_duration: cfg
279                        .message_retention_duration
280                        .map(Duration::try_into)
281                        .transpose()
282                        .map_err(|err: DurationError| Status::internal(err.to_string()))?,
283                    retain_acked_messages: cfg.retain_acked_messages,
284                    topic_message_retention_duration: cfg
285                        .topic_message_retention_duration
286                        .map(Duration::try_into)
287                        .transpose()
288                        .map_err(|err: DurationError| Status::internal(err.to_string()))?,
289                    enable_exactly_once_delivery: cfg.enable_exactly_once_delivery,
290                    state: cfg.state,
291                    analytics_hub_subscription_info: cfg.analytics_hub_subscription_info,
292                    message_transforms: cfg.message_transforms,
293                },
294                retry,
295            )
296            .await
297            .map(|_v| ())
298    }
299
300    /// delete deletes the subscription.
301    pub async fn delete(&self, retry: Option<RetrySetting>) -> Result<(), Status> {
302        let req = DeleteSubscriptionRequest {
303            subscription: self.fqsn.to_string(),
304        };
305        self.subc.delete_subscription(req, retry).await.map(|v| v.into_inner())
306    }
307
308    /// exists reports whether the subscription exists on the server.
309    pub async fn exists(&self, retry: Option<RetrySetting>) -> Result<bool, Status> {
310        let req = GetSubscriptionRequest {
311            subscription: self.fqsn.to_string(),
312        };
313        match self.subc.get_subscription(req, retry).await {
314            Ok(_) => Ok(true),
315            Err(e) => {
316                if e.code() == Code::NotFound {
317                    Ok(false)
318                } else {
319                    Err(e)
320                }
321            }
322        }
323    }
324
325    /// config fetches the current configuration for the subscription.
326    pub async fn config(&self, retry: Option<RetrySetting>) -> Result<(String, SubscriptionConfig), Status> {
327        let req = GetSubscriptionRequest {
328            subscription: self.fqsn.to_string(),
329        };
330        self.subc.get_subscription(req, retry).await.map(|v| {
331            let inner = v.into_inner();
332            (inner.topic.to_string(), inner.into())
333        })
334    }
335
336    /// update changes an existing subscription according to the fields set in updating.
337    /// It returns the new SubscriptionConfig.
338    pub async fn update(
339        &self,
340        updating: SubscriptionConfigToUpdate,
341        retry: Option<RetrySetting>,
342    ) -> Result<(String, SubscriptionConfig), Status> {
343        let req = GetSubscriptionRequest {
344            subscription: self.fqsn.to_string(),
345        };
346        let mut config = self.subc.get_subscription(req, retry.clone()).await?.into_inner();
347
348        let mut paths = vec![];
349        if updating.push_config.is_some() {
350            config.push_config = updating.push_config;
351            paths.push("push_config".to_string());
352        }
353        if updating.bigquery_config.is_some() {
354            config.bigquery_config = updating.bigquery_config;
355            paths.push("bigquery_config".to_string());
356        }
357        if let Some(v) = updating.ack_deadline_seconds {
358            config.ack_deadline_seconds = v;
359            paths.push("ack_deadline_seconds".to_string());
360        }
361        if let Some(v) = updating.retain_acked_messages {
362            config.retain_acked_messages = v;
363            paths.push("retain_acked_messages".to_string());
364        }
365        if updating.message_retention_duration.is_some() {
366            config.message_retention_duration = updating
367                .message_retention_duration
368                .map(prost_types::Duration::try_from)
369                .transpose()
370                .map_err(|err| Status::internal(err.to_string()))?;
371            paths.push("message_retention_duration".to_string());
372        }
373        if updating.expiration_policy.is_some() {
374            config.expiration_policy = updating.expiration_policy;
375            paths.push("expiration_policy".to_string());
376        }
377        if let Some(v) = updating.labels {
378            config.labels = v;
379            paths.push("labels".to_string());
380        }
381        if updating.retry_policy.is_some() {
382            config.retry_policy = updating.retry_policy;
383            paths.push("retry_policy".to_string());
384        }
385
386        let update_req = UpdateSubscriptionRequest {
387            subscription: Some(config),
388            update_mask: Some(FieldMask { paths }),
389        };
390        self.subc.update_subscription(update_req, retry).await.map(|v| {
391            let inner = v.into_inner();
392            (inner.topic.to_string(), inner.into())
393        })
394    }
395
396    /// pull pulls messages from the server.
397    pub async fn pull(&self, max_messages: i32, retry: Option<RetrySetting>) -> Result<Vec<ReceivedMessage>, Status> {
398        #[allow(deprecated)]
399        let req = PullRequest {
400            subscription: self.fqsn.clone(),
401            return_immediately: false,
402            max_messages,
403        };
404        let messages = self.subc.pull(req, retry).await?.into_inner().received_messages;
405        Ok(messages
406            .into_iter()
407            .filter(|m| m.message.is_some())
408            .map(|m| {
409                ReceivedMessage::new(
410                    self.fqsn.clone(),
411                    self.subc.clone(),
412                    m.message.unwrap(),
413                    m.ack_id,
414                    (m.delivery_attempt > 0).then_some(m.delivery_attempt as usize),
415                )
416            })
417            .collect())
418    }
419
420    /// subscribe creates a `Stream` of `ReceivedMessage`
421    /// ```
422    /// use google_cloud_pubsub::subscription::{SubscribeConfig, Subscription};
423    /// use tokio::select;
424    /// use google_cloud_gax::grpc::Status;
425    ///
426    /// async fn run(subscription: Subscription) -> Result<(), Status> {
427    ///     let mut iter = subscription.subscribe(None).await?;
428    ///     let ctx = iter.cancellable();
429    ///     let handler = tokio::spawn(async move {
430    ///         while let Some(message) = iter.read().await {
431    ///             let _ = message.ack().await;
432    ///         }
433    ///     });
434    ///     // Cancel and wait for nack all the pulled messages.
435    ///     ctx.cancel();
436    ///     let _ = handler.await;
437    ///     Ok(())
438    ///  }
439    /// ```
440    ///
441    /// ```
442    /// use google_cloud_pubsub::subscription::{SubscribeConfig, Subscription};
443    /// use futures_util::StreamExt;
444    /// use tokio::select;
445    /// use google_cloud_gax::grpc::Status;
446    ///
447    /// async fn run(subscription: Subscription) -> Result<(), Status> {
448    ///     let mut iter = subscription.subscribe(None).await?;
449    ///     let ctx = iter.cancellable();
450    ///     let handler = tokio::spawn(async move {
451    ///         while let Some(message) = iter.next().await {
452    ///             let _ = message.ack().await;
453    ///         }
454    ///     });
455    ///     // Cancel and wait for receive all the pulled messages.
456    ///     ctx.cancel();
457    ///     let _ = handler.await;
458    ///     Ok(())
459    ///  }
460    /// ```
461    pub async fn subscribe(&self, opt: Option<SubscribeConfig>) -> Result<MessageStream, Status> {
462        let opt = opt.unwrap_or_default();
463        let (tx, rx) = create_channel(opt.channel_capacity);
464        let cancel = CancellationToken::new();
465        let sub_opt = self.unwrap_subscribe_config(opt.subscriber_config).await?;
466
467        // spawn a separate subscriber task for each connection in the pool
468        let subscribers = if opt.enable_multiple_subscriber {
469            self.streaming_pool_size()
470        } else {
471            1
472        };
473        let mut tasks = Vec::with_capacity(subscribers);
474        for _ in 0..subscribers {
475            tasks.push(Subscriber::start(
476                cancel.clone(),
477                self.fqsn.clone(),
478                self.subc.clone(),
479                tx.clone(),
480                sub_opt.clone(),
481            ));
482        }
483
484        Ok(MessageStream {
485            queue: rx,
486            cancel,
487            tasks,
488        })
489    }
490
491    /// receive calls f with the outstanding messages from the subscription.
492    /// It blocks until cancellation token is cancelled, or the service returns a non-retryable error.
493    /// The standard way to terminate a receive is to use CancellationToken.
494    pub async fn receive<F>(
495        &self,
496        f: impl Fn(ReceivedMessage, CancellationToken) -> F + Send + 'static + Sync + Clone,
497        cancel: CancellationToken,
498        config: Option<ReceiveConfig>,
499    ) -> Result<(), Status>
500    where
501        F: Future<Output = ()> + Send + 'static,
502    {
503        let op = config.unwrap_or_default();
504        let mut receivers = Vec::with_capacity(op.worker_count);
505        let mut senders = Vec::with_capacity(receivers.len());
506        let sub_opt = self.unwrap_subscribe_config(op.subscriber_config).await?;
507
508        if self
509            .config(sub_opt.retry_setting.clone())
510            .await?
511            .1
512            .enable_message_ordering
513        {
514            (0..op.worker_count).for_each(|_v| {
515                let (sender, receiver) = create_channel(op.channel_capacity);
516                receivers.push(receiver);
517                senders.push(sender);
518            });
519        } else {
520            let (sender, receiver) = create_channel(op.channel_capacity);
521            (0..op.worker_count).for_each(|_v| {
522                receivers.push(receiver.clone());
523                senders.push(sender.clone());
524            });
525        }
526
527        //same ordering key is in same stream.
528        let subscribers: Vec<Subscriber> = senders
529            .into_iter()
530            .map(|queue| {
531                Subscriber::start(cancel.clone(), self.fqsn.clone(), self.subc.clone(), queue, sub_opt.clone())
532            })
533            .collect();
534
535        let mut message_receivers = Vec::with_capacity(receivers.len());
536        for receiver in receivers {
537            let f_clone = f.clone();
538            let cancel_clone = cancel.clone();
539            let name = self.fqsn.clone();
540            message_receivers.push(tokio::spawn(async move {
541                while let Ok(message) = receiver.recv().await {
542                    f_clone(message, cancel_clone.clone()).await;
543                }
544                // queue is closed by subscriber when the cancellation token is cancelled
545                tracing::trace!("stop message receiver : {}", name);
546            }));
547        }
548        cancel.cancelled().await;
549
550        // wait for all the threads finish.
551        for mut subscriber in subscribers {
552            subscriber.done().await;
553        }
554
555        // wait for all the receivers process received messages
556        for mr in message_receivers {
557            let _ = mr.await;
558        }
559        Ok(())
560    }
561
562    /// Ack acknowledges the messages associated with the ack_ids in the AcknowledgeRequest.
563    /// The Pub/Sub system can remove the relevant messages from the subscription.
564    /// This method is for batch acking.
565    ///
566    /// ```
567    /// use google_cloud_pubsub::client::Client;
568    /// use google_cloud_pubsub::subscription::Subscription;
569    /// use google_cloud_gax::grpc::Status;
570    /// use std::time::Duration;
571    /// use tokio_util::sync::CancellationToken;;
572    ///
573    /// #[tokio::main]
574    /// async fn run(client: Client) -> Result<(), Status> {
575    ///     let subscription = client.subscription("test-subscription");
576    ///     let ctx = CancellationToken::new();
577    ///     let (sender, mut receiver)  = tokio::sync::mpsc::unbounded_channel();
578    ///     let subscription_for_receive = subscription.clone();
579    ///     let ctx_for_receive = ctx.clone();
580    ///     let ctx_for_ack_manager = ctx.clone();
581    ///
582    ///     // receive
583    ///     let handle = tokio::spawn(async move {
584    ///         let _ = subscription_for_receive.receive(move |message, _ctx| {
585    ///             let sender = sender.clone();
586    ///             async move {
587    ///                 let _ = sender.send(message.ack_id().to_string());
588    ///             }
589    ///         }, ctx_for_receive.clone(), None).await;
590    ///     });
591    ///
592    ///     // batch ack manager
593    ///     let ack_manager = tokio::spawn( async move {
594    ///         let mut ack_ids = Vec::new();
595    ///         loop {
596    ///             tokio::select! {
597    ///                 _ = ctx_for_ack_manager.cancelled() => {
598    ///                     return subscription.ack(ack_ids).await;
599    ///                 },
600    ///                 r = tokio::time::timeout(Duration::from_secs(10), receiver.recv()) => match r {
601    ///                     Ok(ack_id) => {
602    ///                         if let Some(ack_id) = ack_id {
603    ///                             ack_ids.push(ack_id);
604    ///                             if ack_ids.len() > 10 {
605    ///                                 let _ = subscription.ack(ack_ids).await;
606    ///                                 ack_ids = Vec::new();
607    ///                             }
608    ///                         }
609    ///                     },
610    ///                     Err(_e) => {
611    ///                         // timeout
612    ///                         let _ = subscription.ack(ack_ids).await;
613    ///                         ack_ids = Vec::new();
614    ///                     }
615    ///                 }
616    ///             }
617    ///         }
618    ///     });
619    ///
620    ///     ctx.cancel();
621    ///     Ok(())
622    ///  }
623    /// ```
624    pub async fn ack(&self, ack_ids: Vec<String>) -> Result<(), Status> {
625        ack(&self.subc, self.fqsn.to_string(), ack_ids).await
626    }
627
628    /// seek seeks the subscription a past timestamp or a saved snapshot.
629    pub async fn seek(&self, to: SeekTo, retry: Option<RetrySetting>) -> Result<(), Status> {
630        let to = match to {
631            SeekTo::Timestamp(t) => SeekTo::Timestamp(t),
632            SeekTo::Snapshot(name) => SeekTo::Snapshot(self.fully_qualified_snapshot_name(name.as_str())),
633        };
634
635        let req = SeekRequest {
636            subscription: self.fqsn.to_owned(),
637            target: Some(to.into()),
638        };
639
640        let _ = self.subc.seek(req, retry).await?;
641        Ok(())
642    }
643
644    /// get_snapshot fetches an existing pubsub snapshot.
645    pub async fn get_snapshot(&self, name: &str, retry: Option<RetrySetting>) -> Result<Snapshot, Status> {
646        let req = GetSnapshotRequest {
647            snapshot: self.fully_qualified_snapshot_name(name),
648        };
649        Ok(self.subc.get_snapshot(req, retry).await?.into_inner())
650    }
651
652    /// create_snapshot creates a new pubsub snapshot from the subscription's state at the time of calling.
653    /// The snapshot retains the messages for the topic the subscription is subscribed to, with the acknowledgment
654    /// states consistent with the subscriptions.
655    /// The created snapshot is guaranteed to retain:
656    /// - The message backlog on the subscription -- or to be specific, messages that are unacknowledged
657    ///   at the time of the subscription's creation.
658    /// - All messages published to the subscription's topic after the snapshot's creation.
659    ///   Snapshots have a finite lifetime -- a maximum of 7 days from the time of creation, beyond which
660    ///   they are discarded and any messages being retained solely due to the snapshot dropped.
661    pub async fn create_snapshot(
662        &self,
663        name: &str,
664        labels: HashMap<String, String>,
665        retry: Option<RetrySetting>,
666    ) -> Result<Snapshot, Status> {
667        let req = CreateSnapshotRequest {
668            name: self.fully_qualified_snapshot_name(name),
669            labels,
670            subscription: self.fqsn.to_owned(),
671        };
672        Ok(self.subc.create_snapshot(req, retry).await?.into_inner())
673    }
674
675    /// delete_snapshot deletes an existing pubsub snapshot.
676    pub async fn delete_snapshot(&self, name: &str, retry: Option<RetrySetting>) -> Result<(), Status> {
677        let req = DeleteSnapshotRequest {
678            snapshot: self.fully_qualified_snapshot_name(name),
679        };
680        let _ = self.subc.delete_snapshot(req, retry).await?;
681        Ok(())
682    }
683
684    async fn unwrap_subscribe_config(&self, cfg: Option<SubscriberConfig>) -> Result<SubscriberConfig, Status> {
685        if let Some(cfg) = cfg {
686            return Ok(cfg);
687        }
688        let cfg = self.config(None).await?;
689        let mut default_cfg = SubscriberConfig {
690            stream_ack_deadline_seconds: cfg.1.ack_deadline_seconds.clamp(10, 600),
691            ..Default::default()
692        };
693        if cfg.1.enable_exactly_once_delivery {
694            default_cfg.max_outstanding_messages = 5;
695        }
696        Ok(default_cfg)
697    }
698}
699
700fn create_channel(
701    channel_capacity: Option<usize>,
702) -> (async_channel::Sender<ReceivedMessage>, async_channel::Receiver<ReceivedMessage>) {
703    match channel_capacity {
704        None => async_channel::unbounded(),
705        Some(cap) => async_channel::bounded(cap),
706    }
707}
708
709#[cfg(test)]
710mod tests {
711    use std::collections::HashMap;
712    use std::sync::atomic::AtomicU32;
713    use std::sync::atomic::Ordering::SeqCst;
714    use std::sync::{Arc, Mutex};
715    use std::time::Duration;
716
717    use futures_util::StreamExt;
718    use serial_test::serial;
719    use tokio_util::sync::CancellationToken;
720    use uuid::Uuid;
721
722    use google_cloud_gax::conn::{ConnectionOptions, Environment};
723    use google_cloud_googleapis::pubsub::v1::{PublishRequest, PubsubMessage};
724
725    use crate::apiv1::conn_pool::ConnectionManager;
726    use crate::apiv1::publisher_client::PublisherClient;
727    use crate::apiv1::subscriber_client::SubscriberClient;
728    use crate::subscriber::ReceivedMessage;
729    use crate::subscription::{
730        ReceiveConfig, SeekTo, SubscribeConfig, Subscription, SubscriptionConfig, SubscriptionConfigToUpdate,
731    };
732
733    const PROJECT_NAME: &str = "local-project";
734    const EMULATOR: &str = "localhost:8681";
735
736    #[ctor::ctor]
737    fn init() {
738        let _ = tracing_subscriber::fmt().try_init();
739    }
740
741    async fn create_subscription(enable_exactly_once_delivery: bool) -> Subscription {
742        let cm = ConnectionManager::new(
743            4,
744            "",
745            &Environment::Emulator(EMULATOR.to_string()),
746            &ConnectionOptions::default(),
747        )
748        .await
749        .unwrap();
750        let cm2 = ConnectionManager::new(
751            4,
752            "",
753            &Environment::Emulator(EMULATOR.to_string()),
754            &ConnectionOptions::default(),
755        )
756        .await
757        .unwrap();
758        let client = SubscriberClient::new(cm, cm2);
759
760        let uuid = Uuid::new_v4().hyphenated().to_string();
761        let subscription_name = format!("projects/{}/subscriptions/s{}", PROJECT_NAME, &uuid);
762        let topic_name = format!("projects/{PROJECT_NAME}/topics/test-topic1");
763        let subscription = Subscription::new(subscription_name, client);
764        let config = SubscriptionConfig {
765            enable_exactly_once_delivery,
766            ..Default::default()
767        };
768        if !subscription.exists(None).await.unwrap() {
769            subscription.create(topic_name.as_str(), config, None).await.unwrap();
770        }
771        subscription
772    }
773
774    async fn publish(messages: Option<Vec<PubsubMessage>>) {
775        let pubc = PublisherClient::new(
776            ConnectionManager::new(
777                4,
778                "",
779                &Environment::Emulator(EMULATOR.to_string()),
780                &ConnectionOptions::default(),
781            )
782            .await
783            .unwrap(),
784        );
785        let messages = messages.unwrap_or(vec![PubsubMessage {
786            data: "test_message".into(),
787            ..Default::default()
788        }]);
789        let req = PublishRequest {
790            topic: format!("projects/{PROJECT_NAME}/topics/test-topic1"),
791            messages,
792        };
793        let _ = pubc.publish(req, None).await;
794    }
795
796    async fn test_subscription(enable_exactly_once_delivery: bool) {
797        let subscription = create_subscription(enable_exactly_once_delivery).await;
798
799        let topic_name = format!("projects/{PROJECT_NAME}/topics/test-topic1");
800        let config = subscription.config(None).await.unwrap();
801        assert_eq!(config.0, topic_name);
802
803        let updating = SubscriptionConfigToUpdate {
804            ack_deadline_seconds: Some(100),
805            ..Default::default()
806        };
807        let new_config = subscription.update(updating, None).await.unwrap();
808        assert_eq!(new_config.0, topic_name);
809        assert_eq!(new_config.1.ack_deadline_seconds, 100);
810
811        let receiver_ctx = CancellationToken::new();
812        let cancel_receiver = receiver_ctx.clone();
813        let handle = tokio::spawn(async move {
814            let _ = subscription
815                .receive(
816                    |message, _ctx| async move {
817                        println!("{}", message.message.message_id);
818                        let _ = message.ack().await;
819                    },
820                    cancel_receiver,
821                    None,
822                )
823                .await;
824            subscription.delete(None).await.unwrap();
825            assert!(!subscription.exists(None).await.unwrap())
826        });
827        tokio::time::sleep(Duration::from_secs(3)).await;
828        receiver_ctx.cancel();
829        let _ = handle.await;
830    }
831
832    #[tokio::test(flavor = "multi_thread")]
833    #[serial]
834    async fn test_pull() {
835        let subscription = create_subscription(false).await;
836        let base = PubsubMessage {
837            data: "test_message".into(),
838            ..Default::default()
839        };
840        publish(Some(vec![base.clone(), base.clone(), base])).await;
841        let messages = subscription.pull(2, None).await.unwrap();
842        assert_eq!(messages.len(), 2);
843        for m in messages {
844            m.ack().await.unwrap();
845        }
846        subscription.delete(None).await.unwrap();
847    }
848
849    #[tokio::test]
850    #[serial]
851    async fn test_subscription_exactly_once() {
852        test_subscription(true).await;
853    }
854
855    #[tokio::test]
856    #[serial]
857    async fn test_subscription_at_least_once() {
858        test_subscription(false).await;
859    }
860
861    #[tokio::test(flavor = "multi_thread")]
862    #[serial]
863    async fn test_multi_subscriber_single_subscription_unbound() {
864        test_multi_subscriber_single_subscription(None).await;
865    }
866
867    #[tokio::test(flavor = "multi_thread")]
868    #[serial]
869    async fn test_multi_subscriber_single_subscription_bound() {
870        let opt = Some(ReceiveConfig {
871            channel_capacity: Some(1),
872            ..Default::default()
873        });
874        test_multi_subscriber_single_subscription(opt).await;
875    }
876
877    async fn test_multi_subscriber_single_subscription(opt: Option<ReceiveConfig>) {
878        let msg = PubsubMessage {
879            data: "test".into(),
880            ..Default::default()
881        };
882        let msg_size = 10;
883        let msgs: Vec<PubsubMessage> = (0..msg_size).map(|_v| msg.clone()).collect();
884        let subscription = create_subscription(false).await;
885        let cancellation_token = CancellationToken::new();
886        let cancel_receiver = cancellation_token.clone();
887        let v = Arc::new(AtomicU32::new(0));
888        let v2 = v.clone();
889        let handle = tokio::spawn(async move {
890            let _ = subscription
891                .receive(
892                    move |message, _ctx| {
893                        let v2 = v2.clone();
894                        async move {
895                            tracing::info!("received {}", message.message.message_id);
896                            v2.fetch_add(1, SeqCst);
897                            let _ = message.ack().await;
898                        }
899                    },
900                    cancel_receiver,
901                    opt,
902                )
903                .await;
904        });
905        publish(Some(msgs)).await;
906        tokio::time::sleep(Duration::from_secs(5)).await;
907        cancellation_token.cancel();
908        let _ = handle.await;
909        assert_eq!(v.load(SeqCst), msg_size);
910    }
911
912    #[tokio::test(flavor = "multi_thread")]
913    #[serial]
914    async fn test_multi_subscriber_multi_subscription() {
915        let mut subscriptions = vec![];
916
917        let ctx = CancellationToken::new();
918        for _ in 0..3 {
919            let subscription = create_subscription(false).await;
920            let v = Arc::new(AtomicU32::new(0));
921            let ctx = ctx.clone();
922            let v2 = v.clone();
923            let handle = tokio::spawn(async move {
924                let _ = subscription
925                    .receive(
926                        move |message, _ctx| {
927                            let v2 = v2.clone();
928                            async move {
929                                v2.fetch_add(1, SeqCst);
930                                let _ = message.ack().await;
931                            }
932                        },
933                        ctx,
934                        None,
935                    )
936                    .await;
937            });
938            subscriptions.push((handle, v))
939        }
940
941        publish(None).await;
942        tokio::time::sleep(Duration::from_secs(5)).await;
943
944        ctx.cancel();
945        for (task, v) in subscriptions {
946            let _ = task.await;
947            assert_eq!(v.load(SeqCst), 1);
948        }
949    }
950
951    #[tokio::test(flavor = "multi_thread")]
952    #[serial]
953    async fn test_batch_acking() {
954        let ctx = CancellationToken::new();
955        let subscription = create_subscription(false).await;
956        let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
957        let subscription_for_receive = subscription.clone();
958        let ctx_for_receive = ctx.clone();
959        let handle = tokio::spawn(async move {
960            let _ = subscription_for_receive
961                .receive(
962                    move |message, _ctx| {
963                        let sender = sender.clone();
964                        async move {
965                            let _ = sender.send(message.ack_id().to_string());
966                        }
967                    },
968                    ctx_for_receive.clone(),
969                    None,
970                )
971                .await;
972        });
973
974        let ctx_for_ack_manager = ctx.clone();
975        let ack_manager = tokio::spawn(async move {
976            let mut ack_ids = Vec::new();
977            while !ctx_for_ack_manager.is_cancelled() {
978                match tokio::time::timeout(Duration::from_secs(10), receiver.recv()).await {
979                    Ok(ack_id) => {
980                        if let Some(ack_id) = ack_id {
981                            ack_ids.push(ack_id);
982                            if ack_ids.len() > 10 {
983                                subscription.ack(ack_ids).await.unwrap();
984                                ack_ids = Vec::new();
985                            }
986                        }
987                    }
988                    Err(_e) => {
989                        // timeout
990                        subscription.ack(ack_ids).await.unwrap();
991                        ack_ids = Vec::new();
992                    }
993                }
994            }
995            // flush
996            subscription.ack(ack_ids).await
997        });
998
999        publish(None).await;
1000        tokio::time::sleep(Duration::from_secs(5)).await;
1001
1002        ctx.cancel();
1003        let _ = handle.await;
1004        assert!(ack_manager.await.is_ok());
1005    }
1006
1007    #[tokio::test]
1008    #[serial]
1009    async fn test_snapshots() {
1010        let subscription = create_subscription(false).await;
1011
1012        let snapshot_name = format!("snapshot-{}", rand::random::<u64>());
1013        let labels: HashMap<String, String> =
1014            HashMap::from_iter([("label-1".into(), "v1".into()), ("label-2".into(), "v2".into())]);
1015        let expected_fq_snap_name = format!("projects/{PROJECT_NAME}/snapshots/{snapshot_name}");
1016
1017        // cleanup; TODO: remove?
1018        let _response = subscription.delete_snapshot(snapshot_name.as_str(), None).await;
1019
1020        // create
1021        let created_snapshot = subscription
1022            .create_snapshot(snapshot_name.as_str(), labels.clone(), None)
1023            .await
1024            .unwrap();
1025
1026        assert_eq!(created_snapshot.name, expected_fq_snap_name);
1027        // NOTE: we don't assert the labels due to lack of label support in the pubsub emulator.
1028
1029        // get
1030        let retrieved_snapshot = subscription.get_snapshot(snapshot_name.as_str(), None).await.unwrap();
1031        assert_eq!(created_snapshot, retrieved_snapshot);
1032
1033        // delete
1034        subscription
1035            .delete_snapshot(snapshot_name.as_str(), None)
1036            .await
1037            .unwrap();
1038
1039        let _deleted_snapshot_status = subscription
1040            .get_snapshot(snapshot_name.as_str(), None)
1041            .await
1042            .expect_err("snapshot should have been deleted");
1043
1044        let _delete_again = subscription
1045            .delete_snapshot(snapshot_name.as_str(), None)
1046            .await
1047            .expect_err("snapshot should already be deleted");
1048    }
1049
1050    async fn ack_all(messages: &[ReceivedMessage]) {
1051        for message in messages.iter() {
1052            message.ack().await.unwrap();
1053        }
1054    }
1055
1056    #[tokio::test]
1057    #[serial]
1058    async fn test_seek_snapshot() {
1059        let subscription = create_subscription(false).await;
1060        let snapshot_name = format!("snapshot-{}", rand::random::<u64>());
1061
1062        // publish and receive a message
1063        publish(None).await;
1064        let messages = subscription.pull(100, None).await.unwrap();
1065        ack_all(&messages).await;
1066        assert_eq!(messages.len(), 1);
1067
1068        // snapshot at received = 1
1069        let _snapshot = subscription
1070            .create_snapshot(snapshot_name.as_str(), HashMap::new(), None)
1071            .await
1072            .unwrap();
1073
1074        // publish and receive another message
1075        publish(None).await;
1076        let messages = subscription.pull(100, None).await.unwrap();
1077        assert_eq!(messages.len(), 1);
1078        ack_all(&messages).await;
1079
1080        // rewind to snapshot at received = 1
1081        subscription
1082            .seek(SeekTo::Snapshot(snapshot_name.clone()), None)
1083            .await
1084            .unwrap();
1085
1086        // assert we receive the 1 message we should receive again
1087        let messages = subscription.pull(100, None).await.unwrap();
1088        assert_eq!(messages.len(), 1);
1089        ack_all(&messages).await;
1090
1091        // cleanup
1092        subscription
1093            .delete_snapshot(snapshot_name.as_str(), None)
1094            .await
1095            .unwrap();
1096        subscription.delete(None).await.unwrap();
1097    }
1098
1099    #[tokio::test]
1100    #[serial]
1101    async fn test_seek_timestamp() {
1102        let subscription = create_subscription(false).await;
1103
1104        // enable acked message retention on subscription -- required for timestamp-based seeks
1105        subscription
1106            .update(
1107                SubscriptionConfigToUpdate {
1108                    retain_acked_messages: Some(true),
1109                    message_retention_duration: Some(Duration::new(60 * 60 * 2, 0)),
1110                    ..Default::default()
1111                },
1112                None,
1113            )
1114            .await
1115            .unwrap();
1116
1117        // publish and receive a message
1118        publish(None).await;
1119        let messages = subscription.pull(100, None).await.unwrap();
1120        ack_all(&messages).await;
1121        assert_eq!(messages.len(), 1);
1122
1123        let message_publish_time = messages.first().unwrap().message.publish_time.to_owned().unwrap();
1124
1125        // rewind to a timestamp where message was just published
1126        subscription
1127            .seek(SeekTo::Timestamp(message_publish_time.to_owned().try_into().unwrap()), None)
1128            .await
1129            .unwrap();
1130
1131        // consume -- should receive the first message again
1132        let messages = subscription.pull(100, None).await.unwrap();
1133        ack_all(&messages).await;
1134        assert_eq!(messages.len(), 1);
1135        let seek_message_publish_time = messages.first().unwrap().message.publish_time.to_owned().unwrap();
1136        assert_eq!(seek_message_publish_time, message_publish_time);
1137
1138        // cleanup
1139        subscription.delete(None).await.unwrap();
1140    }
1141
1142    #[tokio::test(flavor = "multi_thread")]
1143    #[serial]
1144    async fn test_subscribe_single_subscriber() {
1145        test_subscribe(None).await;
1146    }
1147
1148    #[tokio::test(flavor = "multi_thread")]
1149    #[serial]
1150    async fn test_subscribe_multiple_subscriber() {
1151        test_subscribe(Some(SubscribeConfig::default().with_enable_multiple_subscriber(true))).await;
1152    }
1153
1154    #[tokio::test(flavor = "multi_thread")]
1155    #[serial]
1156    async fn test_subscribe_multiple_subscriber_bound() {
1157        test_subscribe(Some(
1158            SubscribeConfig::default()
1159                .with_enable_multiple_subscriber(true)
1160                .with_channel_capacity(1),
1161        ))
1162        .await;
1163    }
1164
1165    async fn test_subscribe(opt: Option<SubscribeConfig>) {
1166        let msg = PubsubMessage {
1167            data: "test".into(),
1168            ..Default::default()
1169        };
1170        let msg_count = 10;
1171        let msg: Vec<PubsubMessage> = (0..msg_count).map(|_v| msg.clone()).collect();
1172        let subscription = create_subscription(false).await;
1173        let received = Arc::new(Mutex::new(0));
1174        let checking = received.clone();
1175        let mut iter = subscription.subscribe(opt).await.unwrap();
1176        let cancellable = iter.cancellable();
1177        let handler = tokio::spawn(async move {
1178            while let Some(message) = iter.next().await {
1179                tracing::info!("received {}", message.message.message_id);
1180                *received.lock().unwrap() += 1;
1181                tokio::time::sleep(Duration::from_millis(500)).await;
1182                let _ = message.ack().await;
1183            }
1184        });
1185        publish(Some(msg)).await;
1186        tokio::time::sleep(Duration::from_secs(8)).await;
1187        cancellable.cancel();
1188        let _ = handler.await;
1189        assert_eq!(*checking.lock().unwrap(), msg_count);
1190    }
1191
1192    #[tokio::test(flavor = "multi_thread")]
1193    #[serial]
1194    async fn test_subscribe_nack_on_cancel_read() {
1195        subscribe_nack_on_cancel_read(10, true).await;
1196        subscribe_nack_on_cancel_read(0, true).await;
1197        subscribe_nack_on_cancel_read(10, false).await;
1198        subscribe_nack_on_cancel_read(0, false).await;
1199    }
1200
1201    #[tokio::test(flavor = "multi_thread")]
1202    #[serial]
1203    async fn test_subscribe_nack_on_cancel_next() {
1204        // cancel after subscribe all message
1205        subscribe_nack_on_cancel_next(10, Duration::from_secs(3)).await;
1206        // cancel after process all message
1207        subscribe_nack_on_cancel_next(10, Duration::from_millis(0)).await;
1208        // no message
1209        subscribe_nack_on_cancel_next(0, Duration::from_secs(3)).await;
1210    }
1211
1212    async fn subscribe_nack_on_cancel_read(msg_count: usize, should_cancel: bool) {
1213        let opt = Some(SubscribeConfig::default().with_enable_multiple_subscriber(true));
1214
1215        let msg = PubsubMessage {
1216            data: "test".into(),
1217            ..Default::default()
1218        };
1219        let msg: Vec<PubsubMessage> = (0..msg_count).map(|_v| msg.clone()).collect();
1220        let subscription = create_subscription(false).await;
1221        let received = Arc::new(Mutex::new(0));
1222        let checking = received.clone();
1223
1224        let mut iter = subscription.subscribe(opt).await.unwrap();
1225        let ctx = iter.cancellable();
1226        let handler = tokio::spawn(async move {
1227            while let Some(message) = iter.read().await {
1228                tracing::info!("received {}", message.message.message_id);
1229                *received.lock().unwrap() += 1;
1230                if should_cancel {
1231                    // expect cancel
1232                    tokio::time::sleep(Duration::from_secs(10)).await;
1233                } else {
1234                    tokio::time::sleep(Duration::from_millis(1)).await;
1235                }
1236                let _ = message.ack().await;
1237            }
1238        });
1239        publish(Some(msg)).await;
1240        tokio::time::sleep(Duration::from_secs(10)).await;
1241        ctx.cancel();
1242        handler.await.unwrap();
1243        if should_cancel && msg_count > 0 {
1244            // expect nack
1245            assert!(*checking.lock().unwrap() < msg_count);
1246        } else {
1247            // all delivered
1248            assert_eq!(*checking.lock().unwrap(), msg_count);
1249        }
1250    }
1251
1252    async fn subscribe_nack_on_cancel_next(msg_count: usize, recv_time: Duration) {
1253        let opt = Some(SubscribeConfig::default().with_enable_multiple_subscriber(true));
1254
1255        let msg = PubsubMessage {
1256            data: "test".into(),
1257            ..Default::default()
1258        };
1259        let msg: Vec<PubsubMessage> = (0..msg_count).map(|_v| msg.clone()).collect();
1260        let subscription = create_subscription(false).await;
1261        let received = Arc::new(Mutex::new(0));
1262        let checking = received.clone();
1263
1264        let mut iter = subscription.subscribe(opt).await.unwrap();
1265        let ctx = iter.cancellable();
1266        let handler = tokio::spawn(async move {
1267            while let Some(message) = iter.next().await {
1268                tracing::info!("received {}", message.message.message_id);
1269                *received.lock().unwrap() += 1;
1270                tokio::time::sleep(recv_time).await;
1271                let _ = message.ack().await;
1272            }
1273        });
1274        publish(Some(msg)).await;
1275        tokio::time::sleep(Duration::from_secs(10)).await;
1276        ctx.cancel();
1277        handler.await.unwrap();
1278        assert_eq!(*checking.lock().unwrap(), msg_count);
1279    }
1280
1281    #[tokio::test(flavor = "multi_thread")]
1282    #[serial]
1283    async fn test_message_stream_dispose() {
1284        let subscription = create_subscription(false).await;
1285        let mut iter = subscription.subscribe(None).await.unwrap();
1286        iter.dispose().await;
1287        // no effect
1288        iter.dispose().await;
1289        assert!(iter.next().await.is_none());
1290    }
1291}