google_cloud_pubsub/
subscription.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::{Duration, SystemTime};
6
7use prost_types::{DurationError, FieldMask};
8use tokio_util::sync::CancellationToken;
9
10use google_cloud_gax::grpc::codegen::tokio_stream::Stream;
11use google_cloud_gax::grpc::{Code, Status};
12use google_cloud_gax::retry::RetrySetting;
13use google_cloud_googleapis::pubsub::v1::seek_request::Target;
14use google_cloud_googleapis::pubsub::v1::subscription::AnalyticsHubSubscriptionInfo;
15use google_cloud_googleapis::pubsub::v1::{
16    BigQueryConfig, CloudStorageConfig, CreateSnapshotRequest, DeadLetterPolicy, DeleteSnapshotRequest,
17    DeleteSubscriptionRequest, ExpirationPolicy, GetSnapshotRequest, GetSubscriptionRequest, PullRequest, PushConfig,
18    RetryPolicy, SeekRequest, Snapshot, Subscription as InternalSubscription, UpdateSubscriptionRequest,
19};
20
21use crate::apiv1::subscriber_client::SubscriberClient;
22
23use crate::subscriber::{ack, ReceivedMessage, Subscriber, SubscriberConfig};
24
25#[derive(Debug, Clone, Default)]
26pub struct SubscriptionConfig {
27    pub push_config: Option<PushConfig>,
28    pub ack_deadline_seconds: i32,
29    pub retain_acked_messages: bool,
30    pub message_retention_duration: Option<Duration>,
31    pub labels: HashMap<String, String>,
32    pub enable_message_ordering: bool,
33    pub expiration_policy: Option<ExpirationPolicy>,
34    pub filter: String,
35    pub dead_letter_policy: Option<DeadLetterPolicy>,
36    pub retry_policy: Option<RetryPolicy>,
37    pub detached: bool,
38    pub topic_message_retention_duration: Option<Duration>,
39    pub enable_exactly_once_delivery: bool,
40    pub bigquery_config: Option<BigQueryConfig>,
41    pub state: i32,
42    pub cloud_storage_config: Option<CloudStorageConfig>,
43    pub analytics_hub_subscription_info: Option<AnalyticsHubSubscriptionInfo>,
44}
45impl From<InternalSubscription> for SubscriptionConfig {
46    fn from(f: InternalSubscription) -> Self {
47        Self {
48            push_config: f.push_config,
49            bigquery_config: f.bigquery_config,
50            ack_deadline_seconds: f.ack_deadline_seconds,
51            retain_acked_messages: f.retain_acked_messages,
52            message_retention_duration: f
53                .message_retention_duration
54                .map(|v| std::time::Duration::new(v.seconds as u64, v.nanos as u32)),
55            labels: f.labels,
56            enable_message_ordering: f.enable_message_ordering,
57            expiration_policy: f.expiration_policy,
58            filter: f.filter,
59            dead_letter_policy: f.dead_letter_policy,
60            retry_policy: f.retry_policy,
61            detached: f.detached,
62            topic_message_retention_duration: f
63                .topic_message_retention_duration
64                .map(|v| std::time::Duration::new(v.seconds as u64, v.nanos as u32)),
65            enable_exactly_once_delivery: f.enable_exactly_once_delivery,
66            state: f.state,
67            cloud_storage_config: f.cloud_storage_config,
68            analytics_hub_subscription_info: f.analytics_hub_subscription_info,
69        }
70    }
71}
72
73#[derive(Debug, Clone, Default)]
74pub struct SubscriptionConfigToUpdate {
75    pub push_config: Option<PushConfig>,
76    pub bigquery_config: Option<BigQueryConfig>,
77    pub ack_deadline_seconds: Option<i32>,
78    pub retain_acked_messages: Option<bool>,
79    pub message_retention_duration: Option<Duration>,
80    pub labels: Option<HashMap<String, String>>,
81    pub expiration_policy: Option<ExpirationPolicy>,
82    pub dead_letter_policy: Option<DeadLetterPolicy>,
83    pub retry_policy: Option<RetryPolicy>,
84}
85
86#[derive(Debug, Clone, Default)]
87pub struct SubscribeConfig {
88    enable_multiple_subscriber: bool,
89    channel_capacity: Option<usize>,
90    subscriber_config: Option<SubscriberConfig>,
91}
92
93impl SubscribeConfig {
94    pub fn with_enable_multiple_subscriber(mut self, v: bool) -> Self {
95        self.enable_multiple_subscriber = v;
96        self
97    }
98    pub fn with_subscriber_config(mut self, v: SubscriberConfig) -> Self {
99        self.subscriber_config = Some(v);
100        self
101    }
102    pub fn with_channel_capacity(mut self, v: usize) -> Self {
103        self.channel_capacity = Some(v);
104        self
105    }
106}
107
108#[derive(Debug, Clone)]
109pub struct ReceiveConfig {
110    pub worker_count: usize,
111    pub channel_capacity: Option<usize>,
112    pub subscriber_config: Option<SubscriberConfig>,
113}
114
115impl Default for ReceiveConfig {
116    fn default() -> Self {
117        Self {
118            worker_count: 10,
119            subscriber_config: None,
120            channel_capacity: None,
121        }
122    }
123}
124
/// Target of [`Subscription::seek`].
#[derive(Debug, Clone)]
pub enum SeekTo {
    /// Seek to a point in time.
    Timestamp(SystemTime),
    /// Seek to a snapshot, given either as a bare snapshot id or as a fully qualified snapshot name.
    Snapshot(String),
}
130
131impl From<SeekTo> for Target {
132    fn from(to: SeekTo) -> Target {
133        use SeekTo::*;
134        match to {
135            Timestamp(t) => Target::Time(prost_types::Timestamp::from(t)),
136            Snapshot(s) => Target::Snapshot(s),
137        }
138    }
139}
140
141pub struct MessageStream {
142    queue: async_channel::Receiver<ReceivedMessage>,
143    cancel: CancellationToken,
144    tasks: Vec<Subscriber>,
145}
146
147impl MessageStream {
148    pub fn cancellable(&self) -> CancellationToken {
149        self.cancel.clone()
150    }
151
152    pub async fn dispose(&mut self) {
153        // Close streaming pull task
154        if !self.cancel.is_cancelled() {
155            self.cancel.cancel();
156        }
157
158        // Wait for all the streaming pull close.
159        for task in &mut self.tasks {
160            task.done().await;
161        }
162
163        // Nack for remaining messages.
164        while let Ok(message) = self.queue.recv().await {
165            if let Err(err) = message.nack().await {
166                tracing::warn!("failed to nack message messageId={} {:?}", message.message.message_id, err);
167            }
168        }
169    }
170
171    /// Immediately Nack on cancel
172    pub async fn read(&mut self) -> Option<ReceivedMessage> {
173        let message = tokio::select! {
174            msg = self.queue.recv() => msg.ok(),
175            _ = self.cancel.cancelled() => None
176        };
177        if message.is_none() {
178            self.dispose().await;
179        }
180        message
181    }
182}
183
184impl Drop for MessageStream {
185    fn drop(&mut self) {
186        if !self.queue.is_empty() {
187            tracing::warn!("Call 'dispose' before drop in order to call nack for remaining messages");
188        }
189        if !self.cancel.is_cancelled() {
190            self.cancel.cancel();
191        }
192    }
193}
194
195impl Stream for MessageStream {
196    type Item = ReceivedMessage;
197
198    /// Return None unless the queue is open.
199    /// Use CancellationToken for SubscribeConfig to get None
200    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
201        Pin::new(&mut self.get_mut().queue).poll_next(cx)
202    }
203}
204
205/// Subscription is a reference to a PubSub subscription.
206#[derive(Clone, Debug)]
207pub struct Subscription {
208    fqsn: String,
209    subc: SubscriberClient,
210}
211
212impl Subscription {
213    pub(crate) fn new(fqsn: String, subc: SubscriberClient) -> Self {
214        Self { fqsn, subc }
215    }
216
217    pub(crate) fn streaming_pool_size(&self) -> usize {
218        self.subc.streaming_pool_size()
219    }
220
221    /// id returns the unique identifier of the subscription within its project.
222    pub fn id(&self) -> String {
223        self.fqsn
224            .rfind('/')
225            .map_or("".to_string(), |i| self.fqsn[(i + 1)..].to_string())
226    }
227
228    /// fully_qualified_name returns the globally unique printable name of the subscription.
229    pub fn fully_qualified_name(&self) -> &str {
230        self.fqsn.as_str()
231    }
232
233    /// fully_qualified_snapshot_name returns the globally unique printable name of the snapshot.
234    pub fn fully_qualified_snapshot_name(&self, id: &str) -> String {
235        if id.contains('/') {
236            id.to_string()
237        } else {
238            format!("{}/snapshots/{}", self.fully_qualified_project_name(), id)
239        }
240    }
241
242    fn fully_qualified_project_name(&self) -> String {
243        let parts: Vec<_> = self
244            .fqsn
245            .split('/')
246            .enumerate()
247            .filter(|&(i, _)| i < 2)
248            .map(|e| e.1)
249            .collect();
250        parts.join("/")
251    }
252
253    pub fn get_client(&self) -> SubscriberClient {
254        self.subc.clone()
255    }
256
257    /// create creates the subscription.
258    pub async fn create(&self, fqtn: &str, cfg: SubscriptionConfig, retry: Option<RetrySetting>) -> Result<(), Status> {
259        self.subc
260            .create_subscription(
261                InternalSubscription {
262                    name: self.fully_qualified_name().to_string(),
263                    topic: fqtn.to_string(),
264                    push_config: cfg.push_config,
265                    bigquery_config: cfg.bigquery_config,
266                    cloud_storage_config: cfg.cloud_storage_config,
267                    ack_deadline_seconds: cfg.ack_deadline_seconds,
268                    labels: cfg.labels,
269                    enable_message_ordering: cfg.enable_message_ordering,
270                    expiration_policy: cfg.expiration_policy,
271                    filter: cfg.filter,
272                    dead_letter_policy: cfg.dead_letter_policy,
273                    retry_policy: cfg.retry_policy,
274                    detached: cfg.detached,
275                    message_retention_duration: cfg
276                        .message_retention_duration
277                        .map(Duration::try_into)
278                        .transpose()
279                        .map_err(|err: DurationError| Status::internal(err.to_string()))?,
280                    retain_acked_messages: cfg.retain_acked_messages,
281                    topic_message_retention_duration: cfg
282                        .topic_message_retention_duration
283                        .map(Duration::try_into)
284                        .transpose()
285                        .map_err(|err: DurationError| Status::internal(err.to_string()))?,
286                    enable_exactly_once_delivery: cfg.enable_exactly_once_delivery,
287                    state: cfg.state,
288                    analytics_hub_subscription_info: cfg.analytics_hub_subscription_info,
289                },
290                retry,
291            )
292            .await
293            .map(|_v| ())
294    }
295
296    /// delete deletes the subscription.
297    pub async fn delete(&self, retry: Option<RetrySetting>) -> Result<(), Status> {
298        let req = DeleteSubscriptionRequest {
299            subscription: self.fqsn.to_string(),
300        };
301        self.subc.delete_subscription(req, retry).await.map(|v| v.into_inner())
302    }
303
304    /// exists reports whether the subscription exists on the server.
305    pub async fn exists(&self, retry: Option<RetrySetting>) -> Result<bool, Status> {
306        let req = GetSubscriptionRequest {
307            subscription: self.fqsn.to_string(),
308        };
309        match self.subc.get_subscription(req, retry).await {
310            Ok(_) => Ok(true),
311            Err(e) => {
312                if e.code() == Code::NotFound {
313                    Ok(false)
314                } else {
315                    Err(e)
316                }
317            }
318        }
319    }
320
321    /// config fetches the current configuration for the subscription.
322    pub async fn config(&self, retry: Option<RetrySetting>) -> Result<(String, SubscriptionConfig), Status> {
323        let req = GetSubscriptionRequest {
324            subscription: self.fqsn.to_string(),
325        };
326        self.subc.get_subscription(req, retry).await.map(|v| {
327            let inner = v.into_inner();
328            (inner.topic.to_string(), inner.into())
329        })
330    }
331
332    /// update changes an existing subscription according to the fields set in updating.
333    /// It returns the new SubscriptionConfig.
334    pub async fn update(
335        &self,
336        updating: SubscriptionConfigToUpdate,
337        retry: Option<RetrySetting>,
338    ) -> Result<(String, SubscriptionConfig), Status> {
339        let req = GetSubscriptionRequest {
340            subscription: self.fqsn.to_string(),
341        };
342        let mut config = self.subc.get_subscription(req, retry.clone()).await?.into_inner();
343
344        let mut paths = vec![];
345        if updating.push_config.is_some() {
346            config.push_config = updating.push_config;
347            paths.push("push_config".to_string());
348        }
349        if updating.bigquery_config.is_some() {
350            config.bigquery_config = updating.bigquery_config;
351            paths.push("bigquery_config".to_string());
352        }
353        if let Some(v) = updating.ack_deadline_seconds {
354            config.ack_deadline_seconds = v;
355            paths.push("ack_deadline_seconds".to_string());
356        }
357        if let Some(v) = updating.retain_acked_messages {
358            config.retain_acked_messages = v;
359            paths.push("retain_acked_messages".to_string());
360        }
361        if updating.message_retention_duration.is_some() {
362            config.message_retention_duration = updating
363                .message_retention_duration
364                .map(prost_types::Duration::try_from)
365                .transpose()
366                .map_err(|err| Status::internal(err.to_string()))?;
367            paths.push("message_retention_duration".to_string());
368        }
369        if updating.expiration_policy.is_some() {
370            config.expiration_policy = updating.expiration_policy;
371            paths.push("expiration_policy".to_string());
372        }
373        if let Some(v) = updating.labels {
374            config.labels = v;
375            paths.push("labels".to_string());
376        }
377        if updating.retry_policy.is_some() {
378            config.retry_policy = updating.retry_policy;
379            paths.push("retry_policy".to_string());
380        }
381
382        let update_req = UpdateSubscriptionRequest {
383            subscription: Some(config),
384            update_mask: Some(FieldMask { paths }),
385        };
386        self.subc.update_subscription(update_req, retry).await.map(|v| {
387            let inner = v.into_inner();
388            (inner.topic.to_string(), inner.into())
389        })
390    }
391
392    /// pull get message synchronously.
393    /// It blocks until at least one message is available.
394    pub async fn pull(&self, max_messages: i32, retry: Option<RetrySetting>) -> Result<Vec<ReceivedMessage>, Status> {
395        #[allow(deprecated)]
396        let req = PullRequest {
397            subscription: self.fqsn.clone(),
398            return_immediately: false,
399            max_messages,
400        };
401        let messages = self.subc.pull(req, retry).await?.into_inner().received_messages;
402        Ok(messages
403            .into_iter()
404            .filter(|m| m.message.is_some())
405            .map(|m| {
406                ReceivedMessage::new(
407                    self.fqsn.clone(),
408                    self.subc.clone(),
409                    m.message.unwrap(),
410                    m.ack_id,
411                    (m.delivery_attempt > 0).then_some(m.delivery_attempt as usize),
412                )
413            })
414            .collect())
415    }
416
417    /// subscribe creates a `Stream` of `ReceivedMessage`
418    /// ```
419    /// use google_cloud_pubsub::subscription::{SubscribeConfig, Subscription};
420    /// use tokio::select;
421    /// use google_cloud_gax::grpc::Status;
422    ///
423    /// async fn run(subscription: Subscription) -> Result<(), Status> {
424    ///     let mut iter = subscription.subscribe(None).await?;
425    ///     let ctx = iter.cancellable();
426    ///     let handler = tokio::spawn(async move {
427    ///         while let Some(message) = iter.read().await {
428    ///             let _ = message.ack().await;
429    ///         }
430    ///     });
431    ///     // Cancel and wait for nack all the pulled messages.
432    ///     ctx.cancel();
433    ///     let _ = handler.await;
434    ///     Ok(())
435    ///  }
436    /// ```
437    ///
438    /// ```
439    /// use google_cloud_pubsub::subscription::{SubscribeConfig, Subscription};
440    /// use futures_util::StreamExt;
441    /// use tokio::select;
442    /// use google_cloud_gax::grpc::Status;
443    ///
444    /// async fn run(subscription: Subscription) -> Result<(), Status> {
445    ///     let mut iter = subscription.subscribe(None).await?;
446    ///     let ctx = iter.cancellable();
447    ///     let handler = tokio::spawn(async move {
448    ///         while let Some(message) = iter.next().await {
449    ///             let _ = message.ack().await;
450    ///         }
451    ///     });
452    ///     // Cancel and wait for receive all the pulled messages.
453    ///     ctx.cancel();
454    ///     let _ = handler.await;
455    ///     Ok(())
456    ///  }
457    /// ```
458    pub async fn subscribe(&self, opt: Option<SubscribeConfig>) -> Result<MessageStream, Status> {
459        let opt = opt.unwrap_or_default();
460        let (tx, rx) = create_channel(opt.channel_capacity);
461        let cancel = CancellationToken::new();
462        let sub_opt = self.unwrap_subscribe_config(opt.subscriber_config).await?;
463
464        // spawn a separate subscriber task for each connection in the pool
465        let subscribers = if opt.enable_multiple_subscriber {
466            self.streaming_pool_size()
467        } else {
468            1
469        };
470        let mut tasks = Vec::with_capacity(subscribers);
471        for _ in 0..subscribers {
472            tasks.push(Subscriber::start(
473                cancel.clone(),
474                self.fqsn.clone(),
475                self.subc.clone(),
476                tx.clone(),
477                sub_opt.clone(),
478            ));
479        }
480
481        Ok(MessageStream {
482            queue: rx,
483            cancel,
484            tasks,
485        })
486    }
487
488    /// receive calls f with the outstanding messages from the subscription.
489    /// It blocks until cancellation token is cancelled, or the service returns a non-retryable error.
490    /// The standard way to terminate a receive is to use CancellationToken.
491    pub async fn receive<F>(
492        &self,
493        f: impl Fn(ReceivedMessage, CancellationToken) -> F + Send + 'static + Sync + Clone,
494        cancel: CancellationToken,
495        config: Option<ReceiveConfig>,
496    ) -> Result<(), Status>
497    where
498        F: Future<Output = ()> + Send + 'static,
499    {
500        let op = config.unwrap_or_default();
501        let mut receivers = Vec::with_capacity(op.worker_count);
502        let mut senders = Vec::with_capacity(receivers.len());
503        let sub_opt = self.unwrap_subscribe_config(op.subscriber_config).await?;
504
505        if self
506            .config(sub_opt.retry_setting.clone())
507            .await?
508            .1
509            .enable_message_ordering
510        {
511            (0..op.worker_count).for_each(|_v| {
512                let (sender, receiver) = create_channel(op.channel_capacity);
513                receivers.push(receiver);
514                senders.push(sender);
515            });
516        } else {
517            let (sender, receiver) = create_channel(op.channel_capacity);
518            (0..op.worker_count).for_each(|_v| {
519                receivers.push(receiver.clone());
520                senders.push(sender.clone());
521            });
522        }
523
524        //same ordering key is in same stream.
525        let subscribers: Vec<Subscriber> = senders
526            .into_iter()
527            .map(|queue| {
528                Subscriber::start(cancel.clone(), self.fqsn.clone(), self.subc.clone(), queue, sub_opt.clone())
529            })
530            .collect();
531
532        let mut message_receivers = Vec::with_capacity(receivers.len());
533        for receiver in receivers {
534            let f_clone = f.clone();
535            let cancel_clone = cancel.clone();
536            let name = self.fqsn.clone();
537            message_receivers.push(tokio::spawn(async move {
538                while let Ok(message) = receiver.recv().await {
539                    f_clone(message, cancel_clone.clone()).await;
540                }
541                // queue is closed by subscriber when the cancellation token is cancelled
542                tracing::trace!("stop message receiver : {}", name);
543            }));
544        }
545        cancel.cancelled().await;
546
547        // wait for all the threads finish.
548        for mut subscriber in subscribers {
549            subscriber.done().await;
550        }
551
552        // wait for all the receivers process received messages
553        for mr in message_receivers {
554            let _ = mr.await;
555        }
556        Ok(())
557    }
558
559    /// Ack acknowledges the messages associated with the ack_ids in the AcknowledgeRequest.
560    /// The Pub/Sub system can remove the relevant messages from the subscription.
561    /// This method is for batch acking.
562    ///
563    /// ```
564    /// use google_cloud_pubsub::client::Client;
565    /// use google_cloud_pubsub::subscription::Subscription;
566    /// use google_cloud_gax::grpc::Status;
567    /// use std::time::Duration;
568    /// use tokio_util::sync::CancellationToken;;
569    ///
570    /// #[tokio::main]
571    /// async fn run(client: Client) -> Result<(), Status> {
572    ///     let subscription = client.subscription("test-subscription");
573    ///     let ctx = CancellationToken::new();
574    ///     let (sender, mut receiver)  = tokio::sync::mpsc::unbounded_channel();
575    ///     let subscription_for_receive = subscription.clone();
576    ///     let ctx_for_receive = ctx.clone();
577    ///     let ctx_for_ack_manager = ctx.clone();
578    ///
579    ///     // receive
580    ///     let handle = tokio::spawn(async move {
581    ///         let _ = subscription_for_receive.receive(move |message, _ctx| {
582    ///             let sender = sender.clone();
583    ///             async move {
584    ///                 let _ = sender.send(message.ack_id().to_string());
585    ///             }
586    ///         }, ctx_for_receive.clone(), None).await;
587    ///     });
588    ///
589    ///     // batch ack manager
590    ///     let ack_manager = tokio::spawn( async move {
591    ///         let mut ack_ids = Vec::new();
592    ///         loop {
593    ///             tokio::select! {
594    ///                 _ = ctx_for_ack_manager.cancelled() => {
595    ///                     return subscription.ack(ack_ids).await;
596    ///                 },
597    ///                 r = tokio::time::timeout(Duration::from_secs(10), receiver.recv()) => match r {
598    ///                     Ok(ack_id) => {
599    ///                         if let Some(ack_id) = ack_id {
600    ///                             ack_ids.push(ack_id);
601    ///                             if ack_ids.len() > 10 {
602    ///                                 let _ = subscription.ack(ack_ids).await;
603    ///                                 ack_ids = Vec::new();
604    ///                             }
605    ///                         }
606    ///                     },
607    ///                     Err(_e) => {
608    ///                         // timeout
609    ///                         let _ = subscription.ack(ack_ids).await;
610    ///                         ack_ids = Vec::new();
611    ///                     }
612    ///                 }
613    ///             }
614    ///         }
615    ///     });
616    ///
617    ///     ctx.cancel();
618    ///     Ok(())
619    ///  }
620    /// ```
621    pub async fn ack(&self, ack_ids: Vec<String>) -> Result<(), Status> {
622        ack(&self.subc, self.fqsn.to_string(), ack_ids).await
623    }
624
625    /// seek seeks the subscription a past timestamp or a saved snapshot.
626    pub async fn seek(&self, to: SeekTo, retry: Option<RetrySetting>) -> Result<(), Status> {
627        let to = match to {
628            SeekTo::Timestamp(t) => SeekTo::Timestamp(t),
629            SeekTo::Snapshot(name) => SeekTo::Snapshot(self.fully_qualified_snapshot_name(name.as_str())),
630        };
631
632        let req = SeekRequest {
633            subscription: self.fqsn.to_owned(),
634            target: Some(to.into()),
635        };
636
637        let _ = self.subc.seek(req, retry).await?;
638        Ok(())
639    }
640
641    /// get_snapshot fetches an existing pubsub snapshot.
642    pub async fn get_snapshot(&self, name: &str, retry: Option<RetrySetting>) -> Result<Snapshot, Status> {
643        let req = GetSnapshotRequest {
644            snapshot: self.fully_qualified_snapshot_name(name),
645        };
646        Ok(self.subc.get_snapshot(req, retry).await?.into_inner())
647    }
648
649    /// create_snapshot creates a new pubsub snapshot from the subscription's state at the time of calling.
650    /// The snapshot retains the messages for the topic the subscription is subscribed to, with the acknowledgment
651    /// states consistent with the subscriptions.
652    /// The created snapshot is guaranteed to retain:
653    /// - The message backlog on the subscription -- or to be specific, messages that are unacknowledged
654    ///   at the time of the subscription's creation.
655    /// - All messages published to the subscription's topic after the snapshot's creation.
656    ///   Snapshots have a finite lifetime -- a maximum of 7 days from the time of creation, beyond which
657    ///   they are discarded and any messages being retained solely due to the snapshot dropped.
658    pub async fn create_snapshot(
659        &self,
660        name: &str,
661        labels: HashMap<String, String>,
662        retry: Option<RetrySetting>,
663    ) -> Result<Snapshot, Status> {
664        let req = CreateSnapshotRequest {
665            name: self.fully_qualified_snapshot_name(name),
666            labels,
667            subscription: self.fqsn.to_owned(),
668        };
669        Ok(self.subc.create_snapshot(req, retry).await?.into_inner())
670    }
671
672    /// delete_snapshot deletes an existing pubsub snapshot.
673    pub async fn delete_snapshot(&self, name: &str, retry: Option<RetrySetting>) -> Result<(), Status> {
674        let req = DeleteSnapshotRequest {
675            snapshot: self.fully_qualified_snapshot_name(name),
676        };
677        let _ = self.subc.delete_snapshot(req, retry).await?;
678        Ok(())
679    }
680
681    async fn unwrap_subscribe_config(&self, cfg: Option<SubscriberConfig>) -> Result<SubscriberConfig, Status> {
682        if let Some(cfg) = cfg {
683            return Ok(cfg);
684        }
685        let cfg = self.config(None).await?;
686        let mut default_cfg = SubscriberConfig {
687            stream_ack_deadline_seconds: cfg.1.ack_deadline_seconds.clamp(10, 600),
688            ..Default::default()
689        };
690        if cfg.1.enable_exactly_once_delivery {
691            default_cfg.max_outstanding_messages = 5;
692        }
693        Ok(default_cfg)
694    }
695}
696
697fn create_channel(
698    channel_capacity: Option<usize>,
699) -> (async_channel::Sender<ReceivedMessage>, async_channel::Receiver<ReceivedMessage>) {
700    match channel_capacity {
701        None => async_channel::unbounded(),
702        Some(cap) => async_channel::bounded(cap),
703    }
704}
705
706#[cfg(test)]
707mod tests {
708    use std::collections::HashMap;
709    use std::sync::atomic::AtomicU32;
710    use std::sync::atomic::Ordering::SeqCst;
711    use std::sync::{Arc, Mutex};
712    use std::time::Duration;
713
714    use futures_util::StreamExt;
715    use serial_test::serial;
716    use tokio_util::sync::CancellationToken;
717    use uuid::Uuid;
718
719    use google_cloud_gax::conn::{ConnectionOptions, Environment};
720    use google_cloud_googleapis::pubsub::v1::{PublishRequest, PubsubMessage};
721
722    use crate::apiv1::conn_pool::ConnectionManager;
723    use crate::apiv1::publisher_client::PublisherClient;
724    use crate::apiv1::subscriber_client::SubscriberClient;
725    use crate::subscriber::ReceivedMessage;
726    use crate::subscription::{
727        ReceiveConfig, SeekTo, SubscribeConfig, Subscription, SubscriptionConfig, SubscriptionConfigToUpdate,
728    };
729
730    const PROJECT_NAME: &str = "local-project";
731    const EMULATOR: &str = "localhost:8681";
732
733    #[ctor::ctor]
734    fn init() {
735        let _ = tracing_subscriber::fmt().try_init();
736    }
737
738    async fn create_subscription(enable_exactly_once_delivery: bool) -> Subscription {
739        let cm = ConnectionManager::new(
740            4,
741            "",
742            &Environment::Emulator(EMULATOR.to_string()),
743            &ConnectionOptions::default(),
744        )
745        .await
746        .unwrap();
747        let cm2 = ConnectionManager::new(
748            4,
749            "",
750            &Environment::Emulator(EMULATOR.to_string()),
751            &ConnectionOptions::default(),
752        )
753        .await
754        .unwrap();
755        let client = SubscriberClient::new(cm, cm2);
756
757        let uuid = Uuid::new_v4().hyphenated().to_string();
758        let subscription_name = format!("projects/{}/subscriptions/s{}", PROJECT_NAME, &uuid);
759        let topic_name = format!("projects/{PROJECT_NAME}/topics/test-topic1");
760        let subscription = Subscription::new(subscription_name, client);
761        let config = SubscriptionConfig {
762            enable_exactly_once_delivery,
763            ..Default::default()
764        };
765        if !subscription.exists(None).await.unwrap() {
766            subscription.create(topic_name.as_str(), config, None).await.unwrap();
767        }
768        subscription
769    }
770
771    async fn publish(messages: Option<Vec<PubsubMessage>>) {
772        let pubc = PublisherClient::new(
773            ConnectionManager::new(
774                4,
775                "",
776                &Environment::Emulator(EMULATOR.to_string()),
777                &ConnectionOptions::default(),
778            )
779            .await
780            .unwrap(),
781        );
782        let messages = messages.unwrap_or(vec![PubsubMessage {
783            data: "test_message".into(),
784            ..Default::default()
785        }]);
786        let req = PublishRequest {
787            topic: format!("projects/{PROJECT_NAME}/topics/test-topic1"),
788            messages,
789        };
790        let _ = pubc.publish(req, None).await;
791    }
792
793    async fn test_subscription(enable_exactly_once_delivery: bool) {
794        let subscription = create_subscription(enable_exactly_once_delivery).await;
795
796        let topic_name = format!("projects/{PROJECT_NAME}/topics/test-topic1");
797        let config = subscription.config(None).await.unwrap();
798        assert_eq!(config.0, topic_name);
799
800        let updating = SubscriptionConfigToUpdate {
801            ack_deadline_seconds: Some(100),
802            ..Default::default()
803        };
804        let new_config = subscription.update(updating, None).await.unwrap();
805        assert_eq!(new_config.0, topic_name);
806        assert_eq!(new_config.1.ack_deadline_seconds, 100);
807
808        let receiver_ctx = CancellationToken::new();
809        let cancel_receiver = receiver_ctx.clone();
810        let handle = tokio::spawn(async move {
811            let _ = subscription
812                .receive(
813                    |message, _ctx| async move {
814                        println!("{}", message.message.message_id);
815                        let _ = message.ack().await;
816                    },
817                    cancel_receiver,
818                    None,
819                )
820                .await;
821            subscription.delete(None).await.unwrap();
822            assert!(!subscription.exists(None).await.unwrap())
823        });
824        tokio::time::sleep(Duration::from_secs(3)).await;
825        receiver_ctx.cancel();
826        let _ = handle.await;
827    }
828
829    #[tokio::test(flavor = "multi_thread")]
830    #[serial]
831    async fn test_pull() {
832        let subscription = create_subscription(false).await;
833        let base = PubsubMessage {
834            data: "test_message".into(),
835            ..Default::default()
836        };
837        publish(Some(vec![base.clone(), base.clone(), base])).await;
838        let messages = subscription.pull(2, None).await.unwrap();
839        assert_eq!(messages.len(), 2);
840        for m in messages {
841            m.ack().await.unwrap();
842        }
843        subscription.delete(None).await.unwrap();
844    }
845
846    #[tokio::test]
847    #[serial]
848    async fn test_subscription_exactly_once() {
849        test_subscription(true).await;
850    }
851
852    #[tokio::test]
853    #[serial]
854    async fn test_subscription_at_least_once() {
855        test_subscription(false).await;
856    }
857
858    #[tokio::test(flavor = "multi_thread")]
859    #[serial]
860    async fn test_multi_subscriber_single_subscription_unbound() {
861        test_multi_subscriber_single_subscription(None).await;
862    }
863
864    #[tokio::test(flavor = "multi_thread")]
865    #[serial]
866    async fn test_multi_subscriber_single_subscription_bound() {
867        let opt = Some(ReceiveConfig {
868            channel_capacity: Some(1),
869            ..Default::default()
870        });
871        test_multi_subscriber_single_subscription(opt).await;
872    }
873
874    async fn test_multi_subscriber_single_subscription(opt: Option<ReceiveConfig>) {
875        let msg = PubsubMessage {
876            data: "test".into(),
877            ..Default::default()
878        };
879        let msg_size = 10;
880        let msgs: Vec<PubsubMessage> = (0..msg_size).map(|_v| msg.clone()).collect();
881        let subscription = create_subscription(false).await;
882        let cancellation_token = CancellationToken::new();
883        let cancel_receiver = cancellation_token.clone();
884        let v = Arc::new(AtomicU32::new(0));
885        let v2 = v.clone();
886        let handle = tokio::spawn(async move {
887            let _ = subscription
888                .receive(
889                    move |message, _ctx| {
890                        let v2 = v2.clone();
891                        async move {
892                            tracing::info!("received {}", message.message.message_id);
893                            v2.fetch_add(1, SeqCst);
894                            let _ = message.ack().await;
895                        }
896                    },
897                    cancel_receiver,
898                    opt,
899                )
900                .await;
901        });
902        publish(Some(msgs)).await;
903        tokio::time::sleep(Duration::from_secs(5)).await;
904        cancellation_token.cancel();
905        let _ = handle.await;
906        assert_eq!(v.load(SeqCst), msg_size);
907    }
908
909    #[tokio::test(flavor = "multi_thread")]
910    #[serial]
911    async fn test_multi_subscriber_multi_subscription() {
912        let mut subscriptions = vec![];
913
914        let ctx = CancellationToken::new();
915        for _ in 0..3 {
916            let subscription = create_subscription(false).await;
917            let v = Arc::new(AtomicU32::new(0));
918            let ctx = ctx.clone();
919            let v2 = v.clone();
920            let handle = tokio::spawn(async move {
921                let _ = subscription
922                    .receive(
923                        move |message, _ctx| {
924                            let v2 = v2.clone();
925                            async move {
926                                v2.fetch_add(1, SeqCst);
927                                let _ = message.ack().await;
928                            }
929                        },
930                        ctx,
931                        None,
932                    )
933                    .await;
934            });
935            subscriptions.push((handle, v))
936        }
937
938        publish(None).await;
939        tokio::time::sleep(Duration::from_secs(5)).await;
940
941        ctx.cancel();
942        for (task, v) in subscriptions {
943            let _ = task.await;
944            assert_eq!(v.load(SeqCst), 1);
945        }
946    }
947
948    #[tokio::test(flavor = "multi_thread")]
949    #[serial]
950    async fn test_batch_acking() {
951        let ctx = CancellationToken::new();
952        let subscription = create_subscription(false).await;
953        let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
954        let subscription_for_receive = subscription.clone();
955        let ctx_for_receive = ctx.clone();
956        let handle = tokio::spawn(async move {
957            let _ = subscription_for_receive
958                .receive(
959                    move |message, _ctx| {
960                        let sender = sender.clone();
961                        async move {
962                            let _ = sender.send(message.ack_id().to_string());
963                        }
964                    },
965                    ctx_for_receive.clone(),
966                    None,
967                )
968                .await;
969        });
970
971        let ctx_for_ack_manager = ctx.clone();
972        let ack_manager = tokio::spawn(async move {
973            let mut ack_ids = Vec::new();
974            while !ctx_for_ack_manager.is_cancelled() {
975                match tokio::time::timeout(Duration::from_secs(10), receiver.recv()).await {
976                    Ok(ack_id) => {
977                        if let Some(ack_id) = ack_id {
978                            ack_ids.push(ack_id);
979                            if ack_ids.len() > 10 {
980                                subscription.ack(ack_ids).await.unwrap();
981                                ack_ids = Vec::new();
982                            }
983                        }
984                    }
985                    Err(_e) => {
986                        // timeout
987                        subscription.ack(ack_ids).await.unwrap();
988                        ack_ids = Vec::new();
989                    }
990                }
991            }
992            // flush
993            subscription.ack(ack_ids).await
994        });
995
996        publish(None).await;
997        tokio::time::sleep(Duration::from_secs(5)).await;
998
999        ctx.cancel();
1000        let _ = handle.await;
1001        assert!(ack_manager.await.is_ok());
1002    }
1003
1004    #[tokio::test]
1005    #[serial]
1006    async fn test_snapshots() {
1007        let subscription = create_subscription(false).await;
1008
1009        let snapshot_name = format!("snapshot-{}", rand::random::<u64>());
1010        let labels: HashMap<String, String> =
1011            HashMap::from_iter([("label-1".into(), "v1".into()), ("label-2".into(), "v2".into())]);
1012        let expected_fq_snap_name = format!("projects/{PROJECT_NAME}/snapshots/{snapshot_name}");
1013
1014        // cleanup; TODO: remove?
1015        let _response = subscription.delete_snapshot(snapshot_name.as_str(), None).await;
1016
1017        // create
1018        let created_snapshot = subscription
1019            .create_snapshot(snapshot_name.as_str(), labels.clone(), None)
1020            .await
1021            .unwrap();
1022
1023        assert_eq!(created_snapshot.name, expected_fq_snap_name);
1024        // NOTE: we don't assert the labels due to lack of label support in the pubsub emulator.
1025
1026        // get
1027        let retrieved_snapshot = subscription.get_snapshot(snapshot_name.as_str(), None).await.unwrap();
1028        assert_eq!(created_snapshot, retrieved_snapshot);
1029
1030        // delete
1031        subscription
1032            .delete_snapshot(snapshot_name.as_str(), None)
1033            .await
1034            .unwrap();
1035
1036        let _deleted_snapshot_status = subscription
1037            .get_snapshot(snapshot_name.as_str(), None)
1038            .await
1039            .expect_err("snapshot should have been deleted");
1040
1041        let _delete_again = subscription
1042            .delete_snapshot(snapshot_name.as_str(), None)
1043            .await
1044            .expect_err("snapshot should already be deleted");
1045    }
1046
1047    async fn ack_all(messages: &[ReceivedMessage]) {
1048        for message in messages.iter() {
1049            message.ack().await.unwrap();
1050        }
1051    }
1052
1053    #[tokio::test]
1054    #[serial]
1055    async fn test_seek_snapshot() {
1056        let subscription = create_subscription(false).await;
1057        let snapshot_name = format!("snapshot-{}", rand::random::<u64>());
1058
1059        // publish and receive a message
1060        publish(None).await;
1061        let messages = subscription.pull(100, None).await.unwrap();
1062        ack_all(&messages).await;
1063        assert_eq!(messages.len(), 1);
1064
1065        // snapshot at received = 1
1066        let _snapshot = subscription
1067            .create_snapshot(snapshot_name.as_str(), HashMap::new(), None)
1068            .await
1069            .unwrap();
1070
1071        // publish and receive another message
1072        publish(None).await;
1073        let messages = subscription.pull(100, None).await.unwrap();
1074        assert_eq!(messages.len(), 1);
1075        ack_all(&messages).await;
1076
1077        // rewind to snapshot at received = 1
1078        subscription
1079            .seek(SeekTo::Snapshot(snapshot_name.clone()), None)
1080            .await
1081            .unwrap();
1082
1083        // assert we receive the 1 message we should receive again
1084        let messages = subscription.pull(100, None).await.unwrap();
1085        assert_eq!(messages.len(), 1);
1086        ack_all(&messages).await;
1087
1088        // cleanup
1089        subscription
1090            .delete_snapshot(snapshot_name.as_str(), None)
1091            .await
1092            .unwrap();
1093        subscription.delete(None).await.unwrap();
1094    }
1095
1096    #[tokio::test]
1097    #[serial]
1098    async fn test_seek_timestamp() {
1099        let subscription = create_subscription(false).await;
1100
1101        // enable acked message retention on subscription -- required for timestamp-based seeks
1102        subscription
1103            .update(
1104                SubscriptionConfigToUpdate {
1105                    retain_acked_messages: Some(true),
1106                    message_retention_duration: Some(Duration::new(60 * 60 * 2, 0)),
1107                    ..Default::default()
1108                },
1109                None,
1110            )
1111            .await
1112            .unwrap();
1113
1114        // publish and receive a message
1115        publish(None).await;
1116        let messages = subscription.pull(100, None).await.unwrap();
1117        ack_all(&messages).await;
1118        assert_eq!(messages.len(), 1);
1119
1120        let message_publish_time = messages.first().unwrap().message.publish_time.to_owned().unwrap();
1121
1122        // rewind to a timestamp where message was just published
1123        subscription
1124            .seek(SeekTo::Timestamp(message_publish_time.to_owned().try_into().unwrap()), None)
1125            .await
1126            .unwrap();
1127
1128        // consume -- should receive the first message again
1129        let messages = subscription.pull(100, None).await.unwrap();
1130        ack_all(&messages).await;
1131        assert_eq!(messages.len(), 1);
1132        let seek_message_publish_time = messages.first().unwrap().message.publish_time.to_owned().unwrap();
1133        assert_eq!(seek_message_publish_time, message_publish_time);
1134
1135        // cleanup
1136        subscription.delete(None).await.unwrap();
1137    }
1138
1139    #[tokio::test(flavor = "multi_thread")]
1140    #[serial]
1141    async fn test_subscribe_single_subscriber() {
1142        test_subscribe(None).await;
1143    }
1144
1145    #[tokio::test(flavor = "multi_thread")]
1146    #[serial]
1147    async fn test_subscribe_multiple_subscriber() {
1148        test_subscribe(Some(SubscribeConfig::default().with_enable_multiple_subscriber(true))).await;
1149    }
1150
1151    #[tokio::test(flavor = "multi_thread")]
1152    #[serial]
1153    async fn test_subscribe_multiple_subscriber_bound() {
1154        test_subscribe(Some(
1155            SubscribeConfig::default()
1156                .with_enable_multiple_subscriber(true)
1157                .with_channel_capacity(1),
1158        ))
1159        .await;
1160    }
1161
1162    async fn test_subscribe(opt: Option<SubscribeConfig>) {
1163        let msg = PubsubMessage {
1164            data: "test".into(),
1165            ..Default::default()
1166        };
1167        let msg_count = 10;
1168        let msg: Vec<PubsubMessage> = (0..msg_count).map(|_v| msg.clone()).collect();
1169        let subscription = create_subscription(false).await;
1170        let received = Arc::new(Mutex::new(0));
1171        let checking = received.clone();
1172        let mut iter = subscription.subscribe(opt).await.unwrap();
1173        let cancellable = iter.cancellable();
1174        let handler = tokio::spawn(async move {
1175            while let Some(message) = iter.next().await {
1176                tracing::info!("received {}", message.message.message_id);
1177                *received.lock().unwrap() += 1;
1178                tokio::time::sleep(Duration::from_millis(500)).await;
1179                let _ = message.ack().await;
1180            }
1181        });
1182        publish(Some(msg)).await;
1183        tokio::time::sleep(Duration::from_secs(8)).await;
1184        cancellable.cancel();
1185        let _ = handler.await;
1186        assert_eq!(*checking.lock().unwrap(), msg_count);
1187    }
1188
1189    #[tokio::test(flavor = "multi_thread")]
1190    #[serial]
1191    async fn test_subscribe_nack_on_cancel_read() {
1192        subscribe_nack_on_cancel_read(10, true).await;
1193        subscribe_nack_on_cancel_read(0, true).await;
1194        subscribe_nack_on_cancel_read(10, false).await;
1195        subscribe_nack_on_cancel_read(0, false).await;
1196    }
1197
1198    #[tokio::test(flavor = "multi_thread")]
1199    #[serial]
1200    async fn test_subscribe_nack_on_cancel_next() {
1201        // cancel after subscribe all message
1202        subscribe_nack_on_cancel_next(10, Duration::from_secs(3)).await;
1203        // cancel after process all message
1204        subscribe_nack_on_cancel_next(10, Duration::from_millis(0)).await;
1205        // no message
1206        subscribe_nack_on_cancel_next(0, Duration::from_secs(3)).await;
1207    }
1208
1209    async fn subscribe_nack_on_cancel_read(msg_count: usize, should_cancel: bool) {
1210        let opt = Some(SubscribeConfig::default().with_enable_multiple_subscriber(true));
1211
1212        let msg = PubsubMessage {
1213            data: "test".into(),
1214            ..Default::default()
1215        };
1216        let msg: Vec<PubsubMessage> = (0..msg_count).map(|_v| msg.clone()).collect();
1217        let subscription = create_subscription(false).await;
1218        let received = Arc::new(Mutex::new(0));
1219        let checking = received.clone();
1220
1221        let mut iter = subscription.subscribe(opt).await.unwrap();
1222        let ctx = iter.cancellable();
1223        let handler = tokio::spawn(async move {
1224            while let Some(message) = iter.read().await {
1225                tracing::info!("received {}", message.message.message_id);
1226                *received.lock().unwrap() += 1;
1227                if should_cancel {
1228                    // expect cancel
1229                    tokio::time::sleep(Duration::from_secs(10)).await;
1230                } else {
1231                    tokio::time::sleep(Duration::from_millis(1)).await;
1232                }
1233                let _ = message.ack().await;
1234            }
1235        });
1236        publish(Some(msg)).await;
1237        tokio::time::sleep(Duration::from_secs(10)).await;
1238        ctx.cancel();
1239        handler.await.unwrap();
1240        if should_cancel && msg_count > 0 {
1241            // expect nack
1242            assert!(*checking.lock().unwrap() < msg_count);
1243        } else {
1244            // all delivered
1245            assert_eq!(*checking.lock().unwrap(), msg_count);
1246        }
1247    }
1248
1249    async fn subscribe_nack_on_cancel_next(msg_count: usize, recv_time: Duration) {
1250        let opt = Some(SubscribeConfig::default().with_enable_multiple_subscriber(true));
1251
1252        let msg = PubsubMessage {
1253            data: "test".into(),
1254            ..Default::default()
1255        };
1256        let msg: Vec<PubsubMessage> = (0..msg_count).map(|_v| msg.clone()).collect();
1257        let subscription = create_subscription(false).await;
1258        let received = Arc::new(Mutex::new(0));
1259        let checking = received.clone();
1260
1261        let mut iter = subscription.subscribe(opt).await.unwrap();
1262        let ctx = iter.cancellable();
1263        let handler = tokio::spawn(async move {
1264            while let Some(message) = iter.next().await {
1265                tracing::info!("received {}", message.message.message_id);
1266                *received.lock().unwrap() += 1;
1267                tokio::time::sleep(recv_time).await;
1268                let _ = message.ack().await;
1269            }
1270        });
1271        publish(Some(msg)).await;
1272        tokio::time::sleep(Duration::from_secs(10)).await;
1273        ctx.cancel();
1274        handler.await.unwrap();
1275        assert_eq!(*checking.lock().unwrap(), msg_count);
1276    }
1277
1278    #[tokio::test(flavor = "multi_thread")]
1279    #[serial]
1280    async fn test_message_stream_dispose() {
1281        let subscription = create_subscription(false).await;
1282        let mut iter = subscription.subscribe(None).await.unwrap();
1283        iter.dispose().await;
1284        // no effect
1285        iter.dispose().await;
1286        assert!(iter.next().await.is_none());
1287    }
1288}