// google_cloud_pubsub/client.rs

1use std::env::var;
2
3use google_cloud_gax::conn::{ConnectionOptions, Environment};
4use google_cloud_gax::grpc::Status;
5use google_cloud_gax::retry::RetrySetting;
6use google_cloud_googleapis::pubsub::v1::{
7    DetachSubscriptionRequest, ListSnapshotsRequest, ListSubscriptionsRequest, ListTopicsRequest, Snapshot,
8};
9use google_cloud_token::NopeTokenSourceProvider;
10
11use crate::apiv1::conn_pool::{ConnectionManager, PUBSUB};
12use crate::apiv1::publisher_client::PublisherClient;
13use crate::apiv1::subscriber_client::SubscriberClient;
14use crate::subscription::{Subscription, SubscriptionConfig};
15use crate::topic::{Topic, TopicConfig};
16
17#[derive(Debug)]
18pub struct ClientConfig {
19    /// gRPC channel pool size
20    pub pool_size: Option<usize>,
21    /// Pub/Sub project_id
22    pub project_id: Option<String>,
23    /// Runtime project info
24    pub environment: Environment,
25    /// Overriding service endpoint
26    pub endpoint: String,
27    /// gRPC connection option
28    pub connection_option: ConnectionOptions,
29}
30
31/// ClientConfigs created by default will prefer to use `PUBSUB_EMULATOR_HOST`
32impl Default for ClientConfig {
33    fn default() -> Self {
34        let emulator = var("PUBSUB_EMULATOR_HOST").ok();
35        let default_project_id = emulator.as_ref().map(|_| "local-project".to_string());
36        Self {
37            pool_size: Some(4),
38            environment: match emulator {
39                Some(v) => Environment::Emulator(v),
40                None => Environment::GoogleCloud(Box::new(NopeTokenSourceProvider {})),
41            },
42            project_id: default_project_id,
43            endpoint: PUBSUB.to_string(),
44            connection_option: ConnectionOptions::default(),
45        }
46    }
47}
48
49#[cfg(feature = "auth")]
50pub use google_cloud_auth;
51
#[cfg(feature = "auth")]
impl ClientConfig {
    /// Replaces the token source of a Google Cloud configuration with one built by
    /// `DefaultTokenSourceProvider::new` (presumably Application Default Credentials).
    ///
    /// A `project_id` that is already set is kept. Otherwise the provider's project ID
    /// is used. Emulator configurations are returned unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the token source provider.
    pub async fn with_auth(self) -> Result<Self, google_cloud_auth::error::Error> {
        if !matches!(self.environment, Environment::GoogleCloud(_)) {
            return Ok(self);
        }
        let provider = google_cloud_auth::token::DefaultTokenSourceProvider::new(Self::auth_config()).await?;
        Ok(self.install_token_source(provider))
    }

    /// Like [`ClientConfig::with_auth`], but builds the token source from the given
    /// `credentials` file instead.
    ///
    /// Emulator configurations are returned unchanged and `credentials` is dropped.
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the token source provider.
    pub async fn with_credentials(
        self,
        credentials: google_cloud_auth::credentials::CredentialsFile,
    ) -> Result<Self, google_cloud_auth::error::Error> {
        if !matches!(self.environment, Environment::GoogleCloud(_)) {
            return Ok(self);
        }
        let provider = google_cloud_auth::token::DefaultTokenSourceProvider::new_with_credentials(
            Self::auth_config(),
            Box::new(credentials),
        )
        .await?;
        Ok(self.install_token_source(provider))
    }

    /// Swaps `provider` in as the Google Cloud token source. If no project ID was
    /// configured, the provider's project ID is adopted.
    fn install_token_source(mut self, provider: google_cloud_auth::token::DefaultTokenSourceProvider) -> Self {
        if self.project_id.is_none() {
            self.project_id = provider.project_id.clone();
        }
        self.environment = Environment::GoogleCloud(Box::new(provider));
        self
    }

    /// Auth settings using the Pub/Sub audience and scopes defined in `conn_pool`.
    fn auth_config() -> google_cloud_auth::project::Config<'static> {
        let base = google_cloud_auth::project::Config::default();
        base.with_audience(crate::apiv1::conn_pool::AUDIENCE)
            .with_scopes(&crate::apiv1::conn_pool::SCOPES)
    }
}
85
86#[derive(thiserror::Error, Debug)]
87pub enum Error {
88    #[error(transparent)]
89    GAX(#[from] google_cloud_gax::conn::Error),
90    #[error("Project ID was not found")]
91    ProjectIdNotFound,
92}
93
94/// Client is a Google Pub/Sub client scoped to a single project.
95///
96/// Clients should be reused rather than being created as needed.
97/// A Client may be shared by multiple tasks.
98#[derive(Clone, Debug)]
99pub struct Client {
100    project_id: String,
101    pubc: PublisherClient,
102    subc: SubscriberClient,
103}
104
105impl Client {
106    /// new creates a Pub/Sub client. See [`ClientConfig`] for more information.
107    pub async fn new(config: ClientConfig) -> Result<Self, Error> {
108        let pool_size = config.pool_size.unwrap_or_default();
109
110        let pubc = PublisherClient::new(
111            ConnectionManager::new(
112                pool_size,
113                config.endpoint.as_str(),
114                &config.environment,
115                &config.connection_option,
116            )
117            .await?,
118        );
119        let subc = SubscriberClient::new(
120            ConnectionManager::new(
121                pool_size,
122                config.endpoint.as_str(),
123                &config.environment,
124                &config.connection_option,
125            )
126            .await?,
127            ConnectionManager::new(
128                pool_size,
129                config.endpoint.as_str(),
130                &config.environment,
131                &config.connection_option,
132            )
133            .await?,
134        );
135        Ok(Self {
136            project_id: config.project_id.ok_or(Error::ProjectIdNotFound)?,
137            pubc,
138            subc,
139        })
140    }
141
142    /// create_subscription creates a new subscription on a topic.
143    ///
144    /// id is the name of the subscription to create. It must start with a letter,
145    /// and contain only letters ([A-Za-z]), numbers ([0-9]), dashes (-),
146    /// underscores (_), periods (.), tildes (~), plus (+) or percent signs (%). It
147    /// must be between 3 and 255 characters in length, and must not start with
148    /// "goog".
149    ///
150    /// cfg.ack_deadline is the maximum time after a subscriber receives a message before
151    /// the subscriber should acknowledge the message. It must be between 10 and 600
152    /// seconds (inclusive), and is rounded down to the nearest second. If the
153    /// provided ackDeadline is 0, then the default value of 10 seconds is used.
154    /// Note: messages which are obtained via Subscription.Receive need not be
155    /// acknowledged within this deadline, as the deadline will be automatically
156    /// extended.
157    ///
158    /// cfg.push_config may be set to configure this subscription for push delivery.
159    ///
160    /// If the subscription already exists an error will be returned.
161    pub async fn create_subscription(
162        &self,
163        id: &str,
164        topic_id: &str,
165        cfg: SubscriptionConfig,
166        retry: Option<RetrySetting>,
167    ) -> Result<Subscription, Status> {
168        let subscription = self.subscription(id);
169        subscription
170            .create(self.fully_qualified_topic_name(topic_id).as_str(), cfg, retry)
171            .await
172            .map(|_v| subscription)
173    }
174
175    /// subscriptions returns an iterator which returns all of the subscriptions for the client's project.
176    pub async fn get_subscriptions(&self, retry: Option<RetrySetting>) -> Result<Vec<Subscription>, Status> {
177        let req = ListSubscriptionsRequest {
178            project: self.fully_qualified_project_name(),
179            page_size: 0,
180            page_token: "".to_string(),
181        };
182        self.subc.list_subscriptions(req, retry).await.map(|v| {
183            v.into_iter()
184                .map(|x| Subscription::new(x.name, self.subc.clone()))
185                .collect()
186        })
187    }
188
189    /// subscription creates a reference to a subscription.
190    pub fn subscription(&self, id: &str) -> Subscription {
191        Subscription::new(self.fully_qualified_subscription_name(id), self.subc.clone())
192    }
193
194    /// detach_subscription detaches a subscription from its topic. All messages
195    /// retained in the subscription are dropped. Subsequent `Pull` and `StreamingPull`
196    /// requests will return FAILED_PRECONDITION. If the subscription is a push
197    /// subscription, pushes to the endpoint will stop.
198    pub async fn detach_subscription(&self, fqsn: &str, retry: Option<RetrySetting>) -> Result<(), Status> {
199        let req = DetachSubscriptionRequest {
200            subscription: fqsn.to_string(),
201        };
202        self.pubc.detach_subscription(req, retry).await.map(|_v| ())
203    }
204
205    /// create_topic creates a new topic.
206    ///
207    /// The specified topic ID must start with a letter, and contain only letters
208    /// ([A-Za-z]), numbers ([0-9]), dashes (-), underscores (_), periods (.),
209    /// tildes (~), plus (+) or percent signs (%). It must be between 3 and 255
210    /// characters in length, and must not start with "goog". For more information,
211    /// see: https://cloud.google.com/pubsub/docs/admin#resource_names
212    ///
213    /// If the topic already exists an error will be returned.
214    pub async fn create_topic(
215        &self,
216        id: &str,
217        cfg: Option<TopicConfig>,
218        retry: Option<RetrySetting>,
219    ) -> Result<Topic, Status> {
220        let topic = self.topic(id);
221        topic.create(cfg, retry).await.map(|_v| topic)
222    }
223
224    /// topics returns an iterator which returns all of the topics for the client's project.
225    pub async fn get_topics(&self, retry: Option<RetrySetting>) -> Result<Vec<String>, Status> {
226        let req = ListTopicsRequest {
227            project: self.fully_qualified_project_name(),
228            page_size: 0,
229            page_token: "".to_string(),
230        };
231        self.pubc
232            .list_topics(req, retry)
233            .await
234            .map(|v| v.into_iter().map(|x| x.name).collect())
235    }
236
237    /// topic creates a reference to a topic in the client's project.
238    ///
239    /// If a Topic's Publish method is called, it has background tasks
240    /// associated with it. Clean them up by calling topic.stop.
241    ///
242    /// Avoid creating many Topic instances if you use them to publish.
243    pub fn topic(&self, id: &str) -> Topic {
244        Topic::new(self.fully_qualified_topic_name(id), self.pubc.clone(), self.subc.clone())
245    }
246
247    /// get_snapshots lists the existing snapshots. Snapshots are used in Seek (at https://cloud.google.com/pubsub/docs/replay-overview) operations, which
248    /// allow you to manage message acknowledgments in bulk. That is, you can set
249    /// the acknowledgment state of messages in an existing subscription to the
250    /// state captured by a snapshot.
251    pub async fn get_snapshots(&self, retry: Option<RetrySetting>) -> Result<Vec<Snapshot>, Status> {
252        let req = ListSnapshotsRequest {
253            project: self.fully_qualified_project_name(),
254            page_size: 0,
255            page_token: "".to_string(),
256        };
257        self.subc.list_snapshots(req, retry).await
258    }
259
260    pub fn fully_qualified_topic_name(&self, id: &str) -> String {
261        if id.contains('/') {
262            id.to_string()
263        } else {
264            format!("projects/{}/topics/{}", self.project_id, id)
265        }
266    }
267
268    pub fn fully_qualified_subscription_name(&self, id: &str) -> String {
269        if id.contains('/') {
270            id.to_string()
271        } else {
272            format!("projects/{}/subscriptions/{}", self.project_id, id)
273        }
274    }
275
276    fn fully_qualified_project_name(&self) -> String {
277        format!("projects/{}", self.project_id)
278    }
279}
280
#[cfg(test)]
mod tests {
    // Emulator-backed integration tests. `create_client` points the client at a Pub/Sub
    // emulator on `localhost:8681`, which must be running for these tests to pass. Each
    // test is `#[serial]`; note that `create_client` also mutates the process environment.
    use std::collections::HashMap;
    use std::thread;
    use std::time::Duration;

    use serial_test::serial;
    use tokio_util::sync::CancellationToken;
    use uuid::Uuid;

    use google_cloud_googleapis::pubsub::v1::PubsubMessage;

    use crate::client::Client;
    use crate::subscriber::SubscriberConfig;
    use crate::subscription::{ReceiveConfig, SubscriptionConfig};

    /// Installs a global `tracing` subscriber when the test binary loads (via `ctor`).
    /// The result of `try_init` is ignored so an already-installed subscriber is not fatal.
    #[ctor::ctor]
    fn init() {
        let _ = tracing_subscriber::fmt().try_init();
    }

    /// Builds a client for the emulator. Setting `PUBSUB_EMULATOR_HOST` before
    /// `ClientConfig::default()` runs makes the default config select the emulator
    /// environment and the `local-project` project ID.
    async fn create_client() -> Client {
        std::env::set_var("PUBSUB_EMULATOR_HOST", "localhost:8681");

        Client::new(Default::default()).await.unwrap()
    }

    /// Round-trips 100 messages (`abc_0` ..= `abc_99`) through a freshly created topic and
    /// subscription and asserts that exactly 100 are received.
    ///
    /// A non-empty `ordering_key` enables message ordering on the subscription, and the
    /// test then also asserts that messages arrive in publish order. `bulk` selects
    /// `publish_bulk` instead of one `publish` call per message.
    async fn do_publish_and_subscribe(ordering_key: &str, bulk: bool) {
        let client = create_client().await;

        let order = !ordering_key.is_empty();
        // create
        // A fresh UUID per run keeps topic/subscription names unique on a shared emulator.
        let uuid = Uuid::new_v4().hyphenated().to_string();
        let topic_id = &format!("t{}", &uuid);
        let subscription_id = &format!("s{}", &uuid);
        let topic = client.create_topic(topic_id.as_str(), None, None).await.unwrap();
        let publisher = topic.new_publisher(None);
        let config = SubscriptionConfig {
            enable_message_ordering: !ordering_key.is_empty(),
            ..Default::default()
        };
        let subscription = client
            .create_subscription(subscription_id.as_str(), topic_id.as_str(), config, None)
            .await
            .unwrap();

        let cancellation_token = CancellationToken::new();
        //subscribe
        let config = ReceiveConfig {
            worker_count: 2,
            channel_capacity: None,
            subscriber_config: Some(SubscriberConfig {
                ping_interval: Duration::from_secs(1),
                ..Default::default()
            }),
        };
        let cancel_receiver = cancellation_token.clone();
        // Each received payload is acked and then forwarded to `r` for the assertions below.
        let (s, mut r) = tokio::sync::mpsc::channel(100);
        let handle = tokio::spawn(async move {
            let _ = subscription
                .receive(
                    move |v, _ctx| {
                        let s2 = s.clone();
                        async move {
                            let _ = v.ack().await;
                            let data = std::str::from_utf8(&v.message.data).unwrap().to_string();
                            tracing::info!(
                                "tid={:?} id={} data={}",
                                thread::current().id(),
                                v.message.message_id,
                                data
                            );
                            let _ = s2.send(data).await;
                        }
                    },
                    cancel_receiver,
                    Some(config),
                )
                .await;
        });

        //publish
        let awaiters = if bulk {
            let messages = (0..100)
                .map(|key| PubsubMessage {
                    data: format!("abc_{key}").into(),
                    ordering_key: ordering_key.to_string(),
                    ..Default::default()
                })
                .collect();
            publisher.publish_bulk(messages).await
        } else {
            let mut awaiters = Vec::with_capacity(100);
            for key in 0..100 {
                let message = PubsubMessage {
                    data: format!("abc_{key}").into(),
                    ordering_key: ordering_key.into(),
                    ..Default::default()
                };
                awaiters.push(publisher.publish(message).await);
            }
            awaiters
        };
        // Wait for every publish to resolve to a message ID; a failed publish panics here.
        for v in awaiters {
            tracing::info!("sent message_id = {}", v.get().await.unwrap());
        }

        // Fixed sleeps: give the subscriber time to receive everything, then cancel it and
        // give the receive loop time to stop.
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
        cancellation_token.cancel();
        tokio::time::sleep(std::time::Duration::from_secs(10)).await;

        // `recv` yields `None` only after every sender is dropped. The sender lives in the
        // receive callback owned by the spawned task, so this loop presumably ends once that
        // task has finished.
        let mut count = 0;
        while let Some(data) = r.recv().await {
            tracing::debug!("{}", data);
            if order {
                assert_eq!(format!("abc_{count}"), data);
            }
            count += 1;
        }
        assert_eq!(count, 100);
        let _ = handle.await;

        let mut publisher = publisher;
        publisher.shutdown().await;
    }

    /// Ordered delivery with one `publish` call per message.
    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_publish_subscribe_ordered() {
        do_publish_and_subscribe("ordering", false).await;
    }

    /// Ordered delivery with a single `publish_bulk` call.
    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_publish_subscribe_ordered_bulk() {
        do_publish_and_subscribe("ordering", true).await;
    }

    /// Unordered delivery (empty ordering key) with one `publish` call per message.
    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_publish_subscribe_random() {
        do_publish_and_subscribe("", false).await;
    }

    /// Unordered delivery (empty ordering key) with a single `publish_bulk` call.
    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_publish_subscribe_random_bulk() {
        do_publish_and_subscribe("", true).await;
    }

    /// Creates a topic, a subscription and a snapshot. Checks that the topic, subscription
    /// and snapshot list calls each return exactly one more entry than before.
    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_lifecycle() {
        let client = create_client().await;

        let uuid = Uuid::new_v4().hyphenated().to_string();
        let topic_id = &format!("t{}", &uuid);
        let subscription_id = &format!("s{}", &uuid);
        let snapshot_id = &format!("snap{}", &uuid);
        // Baselines, taken before anything is created.
        let topics = client.get_topics(None).await.unwrap();
        let subs = client.get_subscriptions(None).await.unwrap();
        let snapshots = client.get_snapshots(None).await.unwrap();
        let _topic = client.create_topic(topic_id.as_str(), None, None).await.unwrap();
        let subscription = client
            .create_subscription(subscription_id.as_str(), topic_id.as_str(), SubscriptionConfig::default(), None)
            .await
            .unwrap();

        let _ = subscription
            .create_snapshot(snapshot_id, HashMap::default(), None)
            .await
            .unwrap();

        let topics_after = client.get_topics(None).await.unwrap();
        let subs_after = client.get_subscriptions(None).await.unwrap();
        let snapshots_after = client.get_snapshots(None).await.unwrap();
        // usize subtraction: relies on nothing being deleted concurrently.
        assert_eq!(1, topics_after.len() - topics.len());
        assert_eq!(1, subs_after.len() - subs.len());
        assert_eq!(1, snapshots_after.len() - snapshots.len());
    }
}
462
#[cfg(test)]
mod tests_in_gcp {
    // Tests against real Google Cloud. They are `#[ignore]`d by default. They call
    // `ClientConfig::with_auth` (only compiled with the `auth` feature) and use the
    // pre-existing resources `test-topic2` and `eod-test`.
    use crate::client::{Client, ClientConfig};
    use crate::publisher::PublisherConfig;
    use google_cloud_gax::conn::Environment;
    use google_cloud_gax::grpc::codegen::tokio_stream::StreamExt;
    use google_cloud_googleapis::pubsub::v1::PubsubMessage;
    use serial_test::serial;
    use std::collections::HashMap;

    use std::time::Duration;
    use tokio::select;
    use tokio_util::sync::CancellationToken;

    /// Builds a message whose ordering key is `key` and whose payload is `key`
    /// (or `"empty"` when `key` is empty).
    fn make_msg(key: &str) -> PubsubMessage {
        PubsubMessage {
            data: if key.is_empty() {
                "empty".into()
            } else {
                key.to_string().into()
            },
            ordering_key: key.into(),
            ..Default::default()
        }
    }

    /// Checks that `with_auth` leaves a non-emulator environment in place. This hits
    /// `unreachable!` if `PUBSUB_EMULATOR_HOST` makes the default config use the emulator.
    #[tokio::test]
    #[ignore]
    async fn test_with_auth() {
        let config = ClientConfig::default().with_auth().await.unwrap();
        if let Environment::Emulator(_) = config.environment {
            unreachable!()
        }
    }

    /// Publishes two rounds of keyed messages (including an empty key and a repeated key)
    /// with a 3s flush interval and 3 workers, and waits for every message ID.
    #[tokio::test]
    #[serial]
    #[ignore]
    async fn test_publish_ordering_in_gcp_flush_buffer() {
        let client = Client::new(ClientConfig::default().with_auth().await.unwrap())
            .await
            .unwrap();
        let topic = client.topic("test-topic2");
        let publisher = topic.new_publisher(Some(PublisherConfig {
            flush_interval: Duration::from_secs(3),
            workers: 3,
            ..Default::default()
        }));

        let mut awaiters = vec![];
        for key in ["", "key1", "key2", "key3", "key3"] {
            awaiters.push(publisher.publish(make_msg(key)).await);
        }
        for awaiter in awaiters.into_iter() {
            tracing::info!("msg id {}", awaiter.get().await.unwrap());
        }

        // check same key
        let mut awaiters = vec![];
        for key in ["", "key1", "key2", "key3", "key3"] {
            awaiters.push(publisher.publish(make_msg(key)).await);
        }
        for awaiter in awaiters.into_iter() {
            tracing::info!("msg id {}", awaiter.get().await.unwrap());
        }
    }

    /// Publishes two rounds of 8 keyed messages with `bundle_size: 8`, 1 worker and a 30s
    /// flush interval. Presumably each round fills a bundle so it is sent on size rather
    /// than on the timer (TODO confirm).
    #[tokio::test]
    #[serial]
    #[ignore]
    async fn test_publish_ordering_in_gcp_limit_exceed() {
        let client = Client::new(ClientConfig::default().with_auth().await.unwrap())
            .await
            .unwrap();
        let topic = client.topic("test-topic2");
        let publisher = topic.new_publisher(Some(PublisherConfig {
            flush_interval: Duration::from_secs(30),
            workers: 1,
            bundle_size: 8,
            ..Default::default()
        }));

        let mut awaiters = vec![];
        for key in ["", "key1", "key2", "key3", "key1", "key2", "key3", ""] {
            awaiters.push(publisher.publish(make_msg(key)).await);
        }
        for awaiter in awaiters.into_iter() {
            tracing::info!("msg id {}", awaiter.get().await.unwrap());
        }

        // check same key twice
        let mut awaiters = vec![];
        for key in ["", "key1", "key2", "key3", "key1", "key2", "key3", ""] {
            awaiters.push(publisher.publish(make_msg(key)).await);
        }
        for awaiter in awaiters.into_iter() {
            tracing::info!("msg id {}", awaiter.get().await.unwrap());
        }
    }

    /// Like the limit-exceed test, but publishes each round of 8 messages with one
    /// `publish_bulk` call and uses 2 workers.
    #[tokio::test]
    #[serial]
    #[ignore]
    async fn test_publish_ordering_in_gcp_bulk() {
        let client = Client::new(ClientConfig::default().with_auth().await.unwrap())
            .await
            .unwrap();
        let topic = client.topic("test-topic2");
        let publisher = topic.new_publisher(Some(PublisherConfig {
            flush_interval: Duration::from_secs(30),
            workers: 2,
            bundle_size: 8,
            ..Default::default()
        }));

        let msgs = ["", "", "key1", "key1", "key2", "key2", "key3", "key3"]
            .map(make_msg)
            .to_vec();
        for awaiter in publisher.publish_bulk(msgs).await.into_iter() {
            tracing::info!("msg id {}", awaiter.get().await.unwrap());
        }

        // check same key twice
        let msgs = ["", "", "key1", "key1", "key2", "key2", "key3", "key3"]
            .map(make_msg)
            .to_vec();
        for awaiter in publisher.publish_bulk(msgs).await.into_iter() {
            tracing::info!("msg id {}", awaiter.get().await.unwrap());
        }
    }
    /// Publishes continuously to `eod-test` for 60s while a slow subscriber (1s of work
    /// per message) acks each message. Asserts that no message ID was received more than
    /// once. Requires the `eod-test` subscription to have exactly-once delivery enabled.
    #[tokio::test]
    #[serial]
    #[ignore]
    async fn test_subscribe_exactly_once_delivery() {
        let client = Client::new(ClientConfig::default().with_auth().await.unwrap())
            .await
            .unwrap();

        // Check if the subscription is exactly_once_delivery
        let subscription = client.subscription("eod-test");
        let config = subscription.config(None).await.unwrap().1;
        assert!(config.enable_exactly_once_delivery);

        // publish message
        // The publisher loop checks `ctx_pub` between publishes and stops once it is cancelled.
        let ctx = CancellationToken::new();
        let ctx_pub = ctx.clone();
        let publisher = client.topic("eod-test").new_publisher(None);
        let pub_task = tokio::spawn(async move {
            tracing::info!("start publisher");
            loop {
                if ctx_pub.is_cancelled() {
                    tracing::info!("finish publisher");
                    return;
                }
                publisher
                    .publish_blocking(PubsubMessage {
                        data: "msg".into(),
                        ..Default::default()
                    })
                    .get()
                    .await
                    .unwrap();
            }
        });

        // subscribe message
        // Counts deliveries per message ID until the child token is cancelled or the stream ends.
        let ctx_sub = ctx.child_token();
        let sub_task = tokio::spawn(async move {
            tracing::info!("start subscriber");
            let mut stream = subscription.subscribe(None).await.unwrap();
            let mut msgs = HashMap::new();
            while let Some(message) = select! {
                msg = stream.next() => msg,
                _ = ctx_sub.cancelled() => None
            } {
                let msg_id = &message.message.message_id;
                // heavy task
                tokio::time::sleep(Duration::from_secs(1)).await;
                *msgs.entry(msg_id.clone()).or_insert(0) += 1;
                message.ack().await.unwrap();
            }
            stream.dispose().await;
            tracing::info!("finish subscriber");
            msgs
        });

        tokio::time::sleep(Duration::from_secs(60)).await;

        // check redelivered messages
        ctx.cancel();
        pub_task.await.unwrap();
        let received_msgs = sub_task.await.unwrap();
        assert!(!received_msgs.is_empty());

        tracing::info!("Number of received messages = {}", received_msgs.len());
        for (msg_id, count) in received_msgs {
            assert_eq!(count, 1, "msg_id = {msg_id}, count = {count}");
        }
    }
}