1use std::env::var;
2
3use google_cloud_gax::conn::{ConnectionOptions, Environment};
4use google_cloud_gax::grpc::Status;
5use google_cloud_gax::retry::RetrySetting;
6use google_cloud_googleapis::pubsub::v1::{
7 DetachSubscriptionRequest, ListSnapshotsRequest, ListSubscriptionsRequest, ListTopicsRequest, Snapshot,
8};
9use token_source::NoopTokenSourceProvider;
10
11use crate::apiv1::conn_pool::{ConnectionManager, PUBSUB};
12use crate::apiv1::publisher_client::PublisherClient;
13use crate::apiv1::subscriber_client::SubscriberClient;
14use crate::subscription::{Subscription, SubscriptionConfig};
15use crate::topic::{Topic, TopicConfig};
16
/// Settings consumed by `Client::new`.
#[derive(Debug)]
pub struct ClientConfig {
    /// Pool size handed to every `ConnectionManager` that `Client::new` creates.
    /// `ClientConfig::default()` sets this to `Some(4)`.
    pub pool_size: Option<usize>,
    /// Project used to build fully qualified resource names.
    /// `Client::new` fails with `Error::ProjectIdNotFound` when this is `None`.
    pub project_id: Option<String>,
    /// Target environment: an emulator host or Google Cloud with a token source provider.
    pub environment: Environment,
    /// Endpoint passed to the connection managers; defaults to the `PUBSUB` constant.
    pub endpoint: String,
    /// Connection options forwarded unchanged to the connection managers.
    pub connection_option: ConnectionOptions,
}
30
31impl Default for ClientConfig {
33 fn default() -> Self {
34 let emulator = var("PUBSUB_EMULATOR_HOST").ok();
35 let default_project_id = emulator.as_ref().map(|_| "local-project".to_string());
36 Self {
37 pool_size: Some(4),
38 environment: match emulator {
39 Some(v) => Environment::Emulator(v),
40 None => Environment::GoogleCloud(Box::new(NoopTokenSourceProvider {})),
41 },
42 project_id: default_project_id,
43 endpoint: PUBSUB.to_string(),
44 connection_option: ConnectionOptions::default(),
45 }
46 }
47}
48
49#[cfg(feature = "auth")]
50pub use google_cloud_auth;
51
#[cfg(feature = "auth")]
impl ClientConfig {
    /// Replaces the token source with a `DefaultTokenSourceProvider` built from
    /// [`Self::auth_config`].
    ///
    /// The configuration is returned untouched when the environment is not
    /// `Environment::GoogleCloud` (for example when the emulator is configured), so no
    /// credentials are looked up in that case. An already configured `project_id` wins
    /// over the one reported by the token source provider.
    ///
    /// # Errors
    ///
    /// Returns the `google_cloud_auth` error produced while creating the provider.
    pub async fn with_auth(self) -> Result<Self, google_cloud_auth::error::Error> {
        if !matches!(self.environment, Environment::GoogleCloud(_)) {
            return Ok(self);
        }
        let ts = google_cloud_auth::token::DefaultTokenSourceProvider::new(Self::auth_config()).await?;
        Ok(self.with_token_source(ts))
    }

    /// Same as [`Self::with_auth`], but the provider is created from the given
    /// `credentials` through `DefaultTokenSourceProvider::new_with_credentials`.
    ///
    /// # Errors
    ///
    /// Returns the `google_cloud_auth` error produced while creating the provider.
    pub async fn with_credentials(
        self,
        credentials: google_cloud_auth::credentials::CredentialsFile,
    ) -> Result<Self, google_cloud_auth::error::Error> {
        if !matches!(self.environment, Environment::GoogleCloud(_)) {
            return Ok(self);
        }
        let ts = google_cloud_auth::token::DefaultTokenSourceProvider::new_with_credentials(
            Self::auth_config(),
            Box::new(credentials),
        )
        .await?;
        Ok(self.with_token_source(ts))
    }

    /// Installs `ts` as the Google Cloud token source and fills in a missing project ID.
    fn with_token_source(mut self, ts: google_cloud_auth::token::DefaultTokenSourceProvider) -> Self {
        // `or_else` clones the provider's project ID only when none was configured.
        self.project_id = self.project_id.or_else(|| ts.project_id.clone());
        self.environment = Environment::GoogleCloud(Box::new(ts));
        self
    }

    /// Auth configuration carrying the Pub/Sub audience and scopes from `conn_pool`.
    fn auth_config() -> google_cloud_auth::project::Config<'static> {
        google_cloud_auth::project::Config::default()
            .with_audience(crate::apiv1::conn_pool::AUDIENCE)
            .with_scopes(&crate::apiv1::conn_pool::SCOPES)
    }
}
85
/// Errors returned by `Client::new`.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Error from `google_cloud_gax::conn`, surfaced transparently; `Client::new`
    /// produces it when building a `ConnectionManager` fails.
    #[error(transparent)]
    GAX(#[from] google_cloud_gax::conn::Error),
    /// `ClientConfig::project_id` was `None` when the client was built.
    #[error("Project ID was not found")]
    ProjectIdNotFound,
}
93
/// Pub/Sub client scoped to one project.
///
/// Holds the project ID together with the publisher and subscriber API clients that
/// the `Topic` and `Subscription` handles are built from.
#[derive(Clone, Debug)]
pub struct Client {
    // Project that short topic/subscription IDs are resolved against.
    project_id: String,
    // Publisher API client (topics, detach subscription).
    pubc: PublisherClient,
    // Subscriber API client (subscriptions, snapshots).
    subc: SubscriberClient,
}
104
105impl Client {
106 pub async fn new(config: ClientConfig) -> Result<Self, Error> {
108 let pool_size = config.pool_size.unwrap_or_default();
109
110 let pubc = PublisherClient::new(
111 ConnectionManager::new(
112 pool_size,
113 config.endpoint.as_str(),
114 &config.environment,
115 &config.connection_option,
116 )
117 .await?,
118 );
119 let subc = SubscriberClient::new(
120 ConnectionManager::new(
121 pool_size,
122 config.endpoint.as_str(),
123 &config.environment,
124 &config.connection_option,
125 )
126 .await?,
127 ConnectionManager::new(
128 pool_size,
129 config.endpoint.as_str(),
130 &config.environment,
131 &config.connection_option,
132 )
133 .await?,
134 );
135 Ok(Self {
136 project_id: config.project_id.ok_or(Error::ProjectIdNotFound)?,
137 pubc,
138 subc,
139 })
140 }
141
142 pub async fn create_subscription(
162 &self,
163 id: &str,
164 topic_id: &str,
165 cfg: SubscriptionConfig,
166 retry: Option<RetrySetting>,
167 ) -> Result<Subscription, Status> {
168 let subscription = self.subscription(id);
169 subscription
170 .create(self.fully_qualified_topic_name(topic_id).as_str(), cfg, retry)
171 .await
172 .map(|_v| subscription)
173 }
174
175 pub async fn get_subscriptions(&self, retry: Option<RetrySetting>) -> Result<Vec<Subscription>, Status> {
177 let req = ListSubscriptionsRequest {
178 project: self.fully_qualified_project_name(),
179 page_size: 0,
180 page_token: "".to_string(),
181 };
182 self.subc.list_subscriptions(req, retry).await.map(|v| {
183 v.into_iter()
184 .map(|x| Subscription::new(x.name, self.subc.clone()))
185 .collect()
186 })
187 }
188
189 pub fn subscription(&self, id: &str) -> Subscription {
191 Subscription::new(self.fully_qualified_subscription_name(id), self.subc.clone())
192 }
193
194 pub async fn detach_subscription(&self, fqsn: &str, retry: Option<RetrySetting>) -> Result<(), Status> {
199 let req = DetachSubscriptionRequest {
200 subscription: fqsn.to_string(),
201 };
202 self.pubc.detach_subscription(req, retry).await.map(|_v| ())
203 }
204
205 pub async fn create_topic(
215 &self,
216 id: &str,
217 cfg: Option<TopicConfig>,
218 retry: Option<RetrySetting>,
219 ) -> Result<Topic, Status> {
220 let topic = self.topic(id);
221 topic.create(cfg, retry).await.map(|_v| topic)
222 }
223
224 pub async fn get_topics(&self, retry: Option<RetrySetting>) -> Result<Vec<String>, Status> {
226 let req = ListTopicsRequest {
227 project: self.fully_qualified_project_name(),
228 page_size: 0,
229 page_token: "".to_string(),
230 };
231 self.pubc
232 .list_topics(req, retry)
233 .await
234 .map(|v| v.into_iter().map(|x| x.name).collect())
235 }
236
237 pub fn topic(&self, id: &str) -> Topic {
244 Topic::new(self.fully_qualified_topic_name(id), self.pubc.clone(), self.subc.clone())
245 }
246
247 pub async fn get_snapshots(&self, retry: Option<RetrySetting>) -> Result<Vec<Snapshot>, Status> {
252 let req = ListSnapshotsRequest {
253 project: self.fully_qualified_project_name(),
254 page_size: 0,
255 page_token: "".to_string(),
256 };
257 self.subc.list_snapshots(req, retry).await
258 }
259
260 pub fn fully_qualified_topic_name(&self, id: &str) -> String {
261 if id.contains('/') {
262 id.to_string()
263 } else {
264 format!("projects/{}/topics/{}", self.project_id, id)
265 }
266 }
267
268 pub fn fully_qualified_subscription_name(&self, id: &str) -> String {
269 if id.contains('/') {
270 id.to_string()
271 } else {
272 format!("projects/{}/subscriptions/{}", self.project_id, id)
273 }
274 }
275
276 fn fully_qualified_project_name(&self) -> String {
277 format!("projects/{}", self.project_id)
278 }
279}
280
#[allow(deprecated)]
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::thread;
    use std::time::Duration;

    use serial_test::serial;
    use tokio_util::sync::CancellationToken;
    use uuid::Uuid;

    use google_cloud_googleapis::pubsub::v1::PubsubMessage;

    use crate::client::Client;
    use crate::subscriber::SubscriberConfig;
    use crate::subscription::{ReceiveConfig, SubscriptionConfig};

    /// Installs the tracing subscriber; a failed installation is ignored.
    #[ctor::ctor]
    fn init() {
        let _ = tracing_subscriber::fmt().try_init();
    }

    /// Builds a client with the default configuration, pointed at the emulator on
    /// `localhost:8681`.
    async fn create_client() -> Client {
        std::env::set_var("PUBSUB_EMULATOR_HOST", "localhost:8681");

        let config = Default::default();
        Client::new(config).await.unwrap()
    }

    /// Publishes 100 messages (one by one, or in one bulk call when `bulk` is set) to
    /// a fresh topic and checks that a subscriber receives all of them. With a
    /// non-empty `ordering_key` the arrival order is asserted as well.
    async fn do_publish_and_subscribe(ordering_key: &str, bulk: bool) {
        const MESSAGE_COUNT: usize = 100;

        let client = create_client().await;

        let ordered = !ordering_key.is_empty();
        let suffix = Uuid::new_v4().hyphenated().to_string();
        let topic_id = format!("t{suffix}");
        let subscription_id = format!("s{suffix}");

        let topic = client.create_topic(topic_id.as_str(), None, None).await.unwrap();
        let publisher = topic.new_publisher(None);
        let subscription_config = SubscriptionConfig {
            enable_message_ordering: ordered,
            ..Default::default()
        };
        let subscription = client
            .create_subscription(subscription_id.as_str(), topic_id.as_str(), subscription_config, None)
            .await
            .unwrap();

        let cancellation_token = CancellationToken::new();
        let receive_config = ReceiveConfig {
            worker_count: 2,
            channel_capacity: None,
            subscriber_config: Some(SubscriberConfig {
                ping_interval: Duration::from_secs(1),
                ..Default::default()
            }),
        };
        let receiver_token = cancellation_token.clone();
        let (tx, mut rx) = tokio::sync::mpsc::channel(100);
        let handle = tokio::spawn(async move {
            let _ = subscription
                .receive(
                    move |received, _ctx| {
                        let tx = tx.clone();
                        async move {
                            // Ack first, then forward the payload to the asserting side.
                            let _ = received.ack().await;
                            let data = std::str::from_utf8(&received.message.data).unwrap().to_string();
                            tracing::info!(
                                "tid={:?} id={} data={}",
                                thread::current().id(),
                                received.message.message_id,
                                data
                            );
                            let _ = tx.send(data).await;
                        }
                    },
                    receiver_token,
                    Some(receive_config),
                )
                .await;
        });

        // Payloads are numbered so that the ordered variants can check the sequence.
        let build_message = |index: usize| PubsubMessage {
            data: format!("abc_{index}").into(),
            ordering_key: ordering_key.to_string(),
            ..Default::default()
        };

        let awaiters = if bulk {
            let batch = (0..MESSAGE_COUNT).map(&build_message).collect();
            publisher.publish_bulk(batch).await
        } else {
            let mut pending = Vec::with_capacity(MESSAGE_COUNT);
            for index in 0..MESSAGE_COUNT {
                pending.push(publisher.publish(build_message(index)).await);
            }
            pending
        };
        for awaiter in awaiters {
            tracing::info!("sent message_id = {}", awaiter.get().await.unwrap());
        }

        // Wait, cancel the receiver, then wait again (presumably to let it wind down).
        tokio::time::sleep(Duration::from_secs(3)).await;
        cancellation_token.cancel();
        tokio::time::sleep(Duration::from_secs(10)).await;

        let mut received_count = 0;
        while let Some(payload) = rx.recv().await {
            tracing::debug!("{}", payload);
            if ordered {
                assert_eq!(format!("abc_{received_count}"), payload);
            }
            received_count += 1;
        }
        assert_eq!(received_count, MESSAGE_COUNT);
        let _ = handle.await;

        let mut publisher = publisher;
        publisher.shutdown().await;
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_publish_subscribe_ordered() {
        do_publish_and_subscribe("ordering", false).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_publish_subscribe_ordered_bulk() {
        do_publish_and_subscribe("ordering", true).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_publish_subscribe_random() {
        do_publish_and_subscribe("", false).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_publish_subscribe_random_bulk() {
        do_publish_and_subscribe("", true).await;
    }

    /// Creating one topic, one subscription and one snapshot must grow each listing
    /// by exactly one entry.
    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_lifecycle() {
        let client = create_client().await;

        let suffix = Uuid::new_v4().hyphenated().to_string();
        let topic_id = format!("t{suffix}");
        let subscription_id = format!("s{suffix}");
        let snapshot_id = format!("snap{suffix}");

        let topics_before = client.get_topics(None).await.unwrap();
        let subs_before = client.get_subscriptions(None).await.unwrap();
        let snapshots_before = client.get_snapshots(None).await.unwrap();

        let _topic = client.create_topic(topic_id.as_str(), None, None).await.unwrap();
        let subscription = client
            .create_subscription(subscription_id.as_str(), topic_id.as_str(), SubscriptionConfig::default(), None)
            .await
            .unwrap();

        let _ = subscription
            .create_snapshot(&snapshot_id, HashMap::default(), None)
            .await
            .unwrap();

        let topics_after = client.get_topics(None).await.unwrap();
        let subs_after = client.get_subscriptions(None).await.unwrap();
        let snapshots_after = client.get_snapshots(None).await.unwrap();
        assert_eq!(1, topics_after.len() - topics_before.len());
        assert_eq!(1, subs_after.len() - subs_before.len());
        assert_eq!(1, snapshots_after.len() - snapshots_before.len());
    }
}
463
#[cfg(test)]
mod tests_in_gcp {
    use crate::client::{Client, ClientConfig};
    use crate::publisher::PublisherConfig;
    use google_cloud_gax::conn::Environment;
    use google_cloud_gax::grpc::codegen::tokio_stream::StreamExt;
    use google_cloud_googleapis::pubsub::v1::PubsubMessage;
    use serial_test::serial;
    use std::collections::HashMap;

    use std::time::Duration;
    use tokio::select;
    use tokio_util::sync::CancellationToken;

    /// Builds a message whose ordering key is `key` and whose payload is the key
    /// itself, or `"empty"` for an empty key.
    fn make_msg(key: &str) -> PubsubMessage {
        let payload = if key.is_empty() { "empty" } else { key };
        PubsubMessage {
            data: payload.to_string().into(),
            ordering_key: key.into(),
            ..Default::default()
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_with_auth() {
        let config = ClientConfig::default().with_auth().await.unwrap();
        if matches!(config.environment, Environment::Emulator(_)) {
            unreachable!()
        }
    }

    #[tokio::test]
    #[serial]
    #[ignore]
    async fn test_publish_ordering_in_gcp_flush_buffer() {
        let config = ClientConfig::default().with_auth().await.unwrap();
        let client = Client::new(config).await.unwrap();
        let topic = client.topic("test-topic2");
        let publisher = topic.new_publisher(Some(PublisherConfig {
            flush_interval: Duration::from_secs(3),
            workers: 3,
            ..Default::default()
        }));

        // Two identical rounds: publish every key, then wait for all message IDs.
        for _round in 0..2 {
            let mut pending = vec![];
            for key in ["", "key1", "key2", "key3", "key3"] {
                pending.push(publisher.publish(make_msg(key)).await);
            }
            for awaiter in pending {
                tracing::info!("msg id {}", awaiter.get().await.unwrap());
            }
        }
    }

    #[tokio::test]
    #[serial]
    #[ignore]
    async fn test_publish_ordering_in_gcp_limit_exceed() {
        let config = ClientConfig::default().with_auth().await.unwrap();
        let client = Client::new(config).await.unwrap();
        let topic = client.topic("test-topic2");
        let publisher = topic.new_publisher(Some(PublisherConfig {
            flush_interval: Duration::from_secs(30),
            workers: 1,
            bundle_size: 8,
            ..Default::default()
        }));

        // Two identical rounds: publish every key, then wait for all message IDs.
        for _round in 0..2 {
            let mut pending = vec![];
            for key in ["", "key1", "key2", "key3", "key1", "key2", "key3", ""] {
                pending.push(publisher.publish(make_msg(key)).await);
            }
            for awaiter in pending {
                tracing::info!("msg id {}", awaiter.get().await.unwrap());
            }
        }
    }

    #[tokio::test]
    #[serial]
    #[ignore]
    async fn test_publish_ordering_in_gcp_bulk() {
        let config = ClientConfig::default().with_auth().await.unwrap();
        let client = Client::new(config).await.unwrap();
        let topic = client.topic("test-topic2");
        let publisher = topic.new_publisher(Some(PublisherConfig {
            flush_interval: Duration::from_secs(30),
            workers: 2,
            bundle_size: 8,
            ..Default::default()
        }));

        let keys = ["", "", "key1", "key1", "key2", "key2", "key3", "key3"];
        // Two identical rounds: one bulk publish, then wait for all message IDs.
        for _round in 0..2 {
            let batch = keys.iter().copied().map(make_msg).collect::<Vec<_>>();
            for awaiter in publisher.publish_bulk(batch).await {
                tracing::info!("msg id {}", awaiter.get().await.unwrap());
            }
        }
    }

    /// Publishes continuously for 60 seconds while a subscriber acks every message
    /// after a one-second delay, then checks that no message ID was seen twice.
    #[tokio::test]
    #[serial]
    #[ignore]
    async fn test_subscribe_exactly_once_delivery() {
        let client_config = ClientConfig::default().with_auth().await.unwrap();
        let client = Client::new(client_config).await.unwrap();

        let subscription = client.subscription("eod-test");
        let (_, subscription_config) = subscription.config(None).await.unwrap();
        assert!(subscription_config.enable_exactly_once_delivery);

        let ctx = CancellationToken::new();

        let ctx_pub = ctx.clone();
        let publisher = client.topic("eod-test").new_publisher(None);
        let pub_task = tokio::spawn(async move {
            tracing::info!("start publisher");
            while !ctx_pub.is_cancelled() {
                publisher
                    .publish_blocking(PubsubMessage {
                        data: "msg".into(),
                        ..Default::default()
                    })
                    .get()
                    .await
                    .unwrap();
            }
            tracing::info!("finish publisher");
        });

        let ctx_sub = ctx.child_token();
        let sub_task = tokio::spawn(async move {
            tracing::info!("start subscriber");
            let mut stream = subscription.subscribe(None).await.unwrap();
            // message ID -> number of deliveries observed
            let mut deliveries = HashMap::new();
            while let Some(received) = select! {
                msg = stream.next() => msg,
                _ = ctx_sub.cancelled() => None
            } {
                let msg_id = received.message.message_id.clone();
                tokio::time::sleep(Duration::from_secs(1)).await;
                *deliveries.entry(msg_id).or_insert(0) += 1;
                received.ack().await.unwrap();
            }
            stream.dispose().await;
            tracing::info!("finish subscriber");
            deliveries
        });

        tokio::time::sleep(Duration::from_secs(60)).await;

        ctx.cancel();
        pub_task.await.unwrap();
        let received_msgs = sub_task.await.unwrap();
        assert!(!received_msgs.is_empty());

        tracing::info!("Number of received messages = {}", received_msgs.len());
        for (msg_id, count) in received_msgs {
            assert_eq!(count, 1, "msg_id = {msg_id}, count = {count}");
        }
    }

    #[tokio::test]
    #[serial]
    #[ignore]
    async fn test_pull_empty() {
        let config = ClientConfig::default().with_auth().await.unwrap();
        let client = Client::new(config).await.unwrap();
        let subscription = client.subscription("pull-test");
        let pulled = subscription.pull(10, None).await.unwrap();
        assert!(pulled.is_empty());
    }
}