1use std::env::var;
2
3use google_cloud_gax::conn::{ConnectionOptions, Environment};
4use google_cloud_gax::grpc::Status;
5use google_cloud_gax::retry::RetrySetting;
6use google_cloud_googleapis::pubsub::v1::{
7 DetachSubscriptionRequest, ListSnapshotsRequest, ListSubscriptionsRequest, ListTopicsRequest, Snapshot,
8};
9use google_cloud_token::NopeTokenSourceProvider;
10
11use crate::apiv1::conn_pool::{ConnectionManager, PUBSUB};
12use crate::apiv1::publisher_client::PublisherClient;
13use crate::apiv1::subscriber_client::SubscriberClient;
14use crate::subscription::{Subscription, SubscriptionConfig};
15use crate::topic::{Topic, TopicConfig};
16
/// Configuration consumed by [`Client::new`].
///
/// `ClientConfig::default()` reads `PUBSUB_EMULATOR_HOST` to choose between the
/// emulator and Google Cloud.
#[derive(Debug)]
pub struct ClientConfig {
    /// Number of connections in each connection pool. `Client::new` builds three
    /// pools of this size: one for the publisher client and two for the
    /// subscriber client. `ClientConfig::default()` sets `Some(4)`.
    pub pool_size: Option<usize>,
    /// Project used to build resource names such as `projects/{id}/topics/{topic}`.
    /// `Client::new` fails with `Error::ProjectIdNotFound` when this is `None`.
    /// The default is `"local-project"` when `PUBSUB_EMULATOR_HOST` is set and
    /// `None` otherwise (under the `auth` feature, `with_auth` /
    /// `with_credentials` may fill it in from the credentials).
    pub project_id: Option<String>,
    /// Target environment: `Environment::Emulator(host)` or Google Cloud with a
    /// token source provider (`NopeTokenSourceProvider` by default).
    pub environment: Environment,
    /// Service endpoint handed to every connection manager; defaults to `PUBSUB`.
    pub endpoint: String,
    /// Connection options handed to every connection manager.
    pub connection_option: ConnectionOptions,
}
30
31impl Default for ClientConfig {
33 fn default() -> Self {
34 let emulator = var("PUBSUB_EMULATOR_HOST").ok();
35 let default_project_id = emulator.as_ref().map(|_| "local-project".to_string());
36 Self {
37 pool_size: Some(4),
38 environment: match emulator {
39 Some(v) => Environment::Emulator(v),
40 None => Environment::GoogleCloud(Box::new(NopeTokenSourceProvider {})),
41 },
42 project_id: default_project_id,
43 endpoint: PUBSUB.to_string(),
44 connection_option: ConnectionOptions::default(),
45 }
46 }
47}
48
49#[cfg(feature = "auth")]
50pub use google_cloud_auth;
51
#[cfg(feature = "auth")]
impl ClientConfig {
    /// Installs a `DefaultTokenSourceProvider` built from the default credential lookup
    /// (presumably Application Default Credentials — see `google_cloud_auth`).
    ///
    /// Only applies when the environment is Google Cloud; an emulator configuration is
    /// returned untouched. If `project_id` is unset it is taken from the provider.
    ///
    /// # Errors
    ///
    /// Propagates any error from creating the token source provider.
    pub async fn with_auth(mut self) -> Result<Self, google_cloud_auth::error::Error> {
        if matches!(self.environment, Environment::GoogleCloud(_)) {
            let provider = google_cloud_auth::token::DefaultTokenSourceProvider::new(Self::auth_config()).await?;
            self.install_token_source(provider);
        }
        Ok(self)
    }

    /// Like [`ClientConfig::with_auth`], but builds the token source from an explicit
    /// credentials file instead of the default lookup.
    ///
    /// # Errors
    ///
    /// Propagates any error from creating the token source provider.
    pub async fn with_credentials(
        mut self,
        credentials: google_cloud_auth::credentials::CredentialsFile,
    ) -> Result<Self, google_cloud_auth::error::Error> {
        if matches!(self.environment, Environment::GoogleCloud(_)) {
            let provider = google_cloud_auth::token::DefaultTokenSourceProvider::new_with_credentials(
                Self::auth_config(),
                Box::new(credentials),
            )
            .await?;
            self.install_token_source(provider);
        }
        Ok(self)
    }

    /// Fills a missing `project_id` from `provider`, then makes `provider` the Google
    /// Cloud token source.
    fn install_token_source(&mut self, provider: google_cloud_auth::token::DefaultTokenSourceProvider) {
        // An explicitly configured project always wins over the credentials' project.
        if self.project_id.is_none() {
            self.project_id = provider.project_id.clone();
        }
        self.environment = Environment::GoogleCloud(Box::new(provider));
    }

    /// Token configuration using the Pub/Sub audience and scopes from `conn_pool`.
    fn auth_config() -> google_cloud_auth::project::Config<'static> {
        google_cloud_auth::project::Config::default()
            .with_audience(crate::apiv1::conn_pool::AUDIENCE)
            .with_scopes(&crate::apiv1::conn_pool::SCOPES)
    }
}
85
/// Errors that can occur while constructing a [`Client`].
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// A gRPC connection pool could not be created; the message is forwarded unchanged.
    #[error(transparent)]
    GAX(#[from] google_cloud_gax::conn::Error),
    /// `ClientConfig::project_id` was `None`.
    #[error("Project ID was not found")]
    ProjectIdNotFound,
}
93
/// Pub/Sub client bound to a single Google Cloud project.
///
/// Short topic/subscription IDs passed to its methods are qualified with
/// `project_id`. Cloning copies the inner publisher/subscriber client handles
/// (presumably sharing their underlying connections — TODO confirm).
#[derive(Clone, Debug)]
pub struct Client {
    /// Project used to build `projects/{project_id}/...` resource names.
    project_id: String,
    /// Publisher API client: topic listing, subscription detach, and `Topic` handles.
    pubc: PublisherClient,
    /// Subscriber API client: subscription/snapshot listing and `Subscription` handles.
    subc: SubscriberClient,
}
104
105impl Client {
106 pub async fn new(config: ClientConfig) -> Result<Self, Error> {
108 let pool_size = config.pool_size.unwrap_or_default();
109
110 let pubc = PublisherClient::new(
111 ConnectionManager::new(
112 pool_size,
113 config.endpoint.as_str(),
114 &config.environment,
115 &config.connection_option,
116 )
117 .await?,
118 );
119 let subc = SubscriberClient::new(
120 ConnectionManager::new(
121 pool_size,
122 config.endpoint.as_str(),
123 &config.environment,
124 &config.connection_option,
125 )
126 .await?,
127 ConnectionManager::new(
128 pool_size,
129 config.endpoint.as_str(),
130 &config.environment,
131 &config.connection_option,
132 )
133 .await?,
134 );
135 Ok(Self {
136 project_id: config.project_id.ok_or(Error::ProjectIdNotFound)?,
137 pubc,
138 subc,
139 })
140 }
141
142 pub async fn create_subscription(
162 &self,
163 id: &str,
164 topic_id: &str,
165 cfg: SubscriptionConfig,
166 retry: Option<RetrySetting>,
167 ) -> Result<Subscription, Status> {
168 let subscription = self.subscription(id);
169 subscription
170 .create(self.fully_qualified_topic_name(topic_id).as_str(), cfg, retry)
171 .await
172 .map(|_v| subscription)
173 }
174
175 pub async fn get_subscriptions(&self, retry: Option<RetrySetting>) -> Result<Vec<Subscription>, Status> {
177 let req = ListSubscriptionsRequest {
178 project: self.fully_qualified_project_name(),
179 page_size: 0,
180 page_token: "".to_string(),
181 };
182 self.subc.list_subscriptions(req, retry).await.map(|v| {
183 v.into_iter()
184 .map(|x| Subscription::new(x.name, self.subc.clone()))
185 .collect()
186 })
187 }
188
189 pub fn subscription(&self, id: &str) -> Subscription {
191 Subscription::new(self.fully_qualified_subscription_name(id), self.subc.clone())
192 }
193
194 pub async fn detach_subscription(&self, fqsn: &str, retry: Option<RetrySetting>) -> Result<(), Status> {
199 let req = DetachSubscriptionRequest {
200 subscription: fqsn.to_string(),
201 };
202 self.pubc.detach_subscription(req, retry).await.map(|_v| ())
203 }
204
205 pub async fn create_topic(
215 &self,
216 id: &str,
217 cfg: Option<TopicConfig>,
218 retry: Option<RetrySetting>,
219 ) -> Result<Topic, Status> {
220 let topic = self.topic(id);
221 topic.create(cfg, retry).await.map(|_v| topic)
222 }
223
224 pub async fn get_topics(&self, retry: Option<RetrySetting>) -> Result<Vec<String>, Status> {
226 let req = ListTopicsRequest {
227 project: self.fully_qualified_project_name(),
228 page_size: 0,
229 page_token: "".to_string(),
230 };
231 self.pubc
232 .list_topics(req, retry)
233 .await
234 .map(|v| v.into_iter().map(|x| x.name).collect())
235 }
236
237 pub fn topic(&self, id: &str) -> Topic {
244 Topic::new(self.fully_qualified_topic_name(id), self.pubc.clone(), self.subc.clone())
245 }
246
247 pub async fn get_snapshots(&self, retry: Option<RetrySetting>) -> Result<Vec<Snapshot>, Status> {
252 let req = ListSnapshotsRequest {
253 project: self.fully_qualified_project_name(),
254 page_size: 0,
255 page_token: "".to_string(),
256 };
257 self.subc.list_snapshots(req, retry).await
258 }
259
260 pub fn fully_qualified_topic_name(&self, id: &str) -> String {
261 if id.contains('/') {
262 id.to_string()
263 } else {
264 format!("projects/{}/topics/{}", self.project_id, id)
265 }
266 }
267
268 pub fn fully_qualified_subscription_name(&self, id: &str) -> String {
269 if id.contains('/') {
270 id.to_string()
271 } else {
272 format!("projects/{}/subscriptions/{}", self.project_id, id)
273 }
274 }
275
276 fn fully_qualified_project_name(&self) -> String {
277 format!("projects/{}", self.project_id)
278 }
279}
280
// Integration tests against a local Pub/Sub emulator.
// NOTE(review): they need an emulator on `localhost:8681` (see `create_client`) and
// are not `#[ignore]`d; `#[serial]` keeps them from running concurrently.
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::thread;
    use std::time::Duration;

    use serial_test::serial;
    use tokio_util::sync::CancellationToken;
    use uuid::Uuid;

    use google_cloud_googleapis::pubsub::v1::PubsubMessage;

    use crate::client::Client;
    use crate::subscriber::SubscriberConfig;
    use crate::subscription::{ReceiveConfig, SubscriptionConfig};

    // Installs a tracing subscriber once per test binary; `try_init` tolerates an
    // already-installed subscriber.
    #[ctor::ctor]
    fn init() {
        let _ = tracing_subscriber::fmt().try_init();
    }

    // Points the default config at the emulator (process-wide env var) and builds a client;
    // `ClientConfig::default()` then uses project "local-project".
    async fn create_client() -> Client {
        std::env::set_var("PUBSUB_EMULATOR_HOST", "localhost:8681");

        Client::new(Default::default()).await.unwrap()
    }

    // Publishes 100 messages (individually or via `publish_bulk`) to a fresh topic and
    // checks that a fresh subscription receives all 100. With a non-empty
    // `ordering_key`, ordering is enabled and messages must arrive in publish order.
    async fn do_publish_and_subscribe(ordering_key: &str, bulk: bool) {
        let client = create_client().await;

        let order = !ordering_key.is_empty();
        let uuid = Uuid::new_v4().hyphenated().to_string();
        // Unique resource names so repeated runs don't collide on the emulator.
        let topic_id = &format!("t{}", &uuid);
        let subscription_id = &format!("s{}", &uuid);
        let topic = client.create_topic(topic_id.as_str(), None, None).await.unwrap();
        let publisher = topic.new_publisher(None);
        let config = SubscriptionConfig {
            enable_message_ordering: !ordering_key.is_empty(),
            ..Default::default()
        };
        let subscription = client
            .create_subscription(subscription_id.as_str(), topic_id.as_str(), config, None)
            .await
            .unwrap();

        let cancellation_token = CancellationToken::new();
        let config = ReceiveConfig {
            // Two concurrent handler workers; ordered delivery must still hold.
            worker_count: 2,
            channel_capacity: None,
            subscriber_config: Some(SubscriberConfig {
                ping_interval: Duration::from_secs(1),
                ..Default::default()
            }),
        };
        let cancel_receiver = cancellation_token.clone();
        // Received payloads are forwarded here so the test body can inspect them.
        let (s, mut r) = tokio::sync::mpsc::channel(100);
        let handle = tokio::spawn(async move {
            let _ = subscription
                .receive(
                    move |v, _ctx| {
                        let s2 = s.clone();
                        async move {
                            let _ = v.ack().await;
                            let data = std::str::from_utf8(&v.message.data).unwrap().to_string();
                            tracing::info!(
                                "tid={:?} id={} data={}",
                                thread::current().id(),
                                v.message.message_id,
                                data
                            );
                            let _ = s2.send(data).await;
                        }
                    },
                    cancel_receiver,
                    Some(config),
                )
                .await;
        });

        let awaiters = if bulk {
            // One bulk call carrying all 100 messages.
            let messages = (0..100)
                .map(|key| PubsubMessage {
                    data: format!("abc_{key}").into(),
                    ordering_key: ordering_key.to_string(),
                    ..Default::default()
                })
                .collect();
            publisher.publish_bulk(messages).await
        } else {
            let mut awaiters = Vec::with_capacity(100);
            for key in 0..100 {
                let message = PubsubMessage {
                    data: format!("abc_{key}").into(),
                    ordering_key: ordering_key.into(),
                    ..Default::default()
                };
                awaiters.push(publisher.publish(message).await);
            }
            awaiters
        };
        // Every publish must succeed before we start counting deliveries.
        for v in awaiters {
            tracing::info!("sent message_id = {}", v.get().await.unwrap());
        }

        // NOTE(review): timing-based — give the subscriber time to drain, cancel it,
        // then allow it to shut down. May be flaky on slow hosts.
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
        cancellation_token.cancel();
        tokio::time::sleep(std::time::Duration::from_secs(10)).await;

        // Ends once every sender clone is dropped, presumably when `receive` returns
        // and releases the handler closure that owns `s`.
        let mut count = 0;
        while let Some(data) = r.recv().await {
            tracing::debug!("{}", data);
            if order {
                assert_eq!(format!("abc_{count}"), data);
            }
            count += 1;
        }
        assert_eq!(count, 100);
        let _ = handle.await;

        let mut publisher = publisher;
        publisher.shutdown().await;
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_publish_subscribe_ordered() {
        do_publish_and_subscribe("ordering", false).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_publish_subscribe_ordered_bulk() {
        do_publish_and_subscribe("ordering", true).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_publish_subscribe_random() {
        do_publish_and_subscribe("", false).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_publish_subscribe_random_bulk() {
        do_publish_and_subscribe("", true).await;
    }

    // Creating one topic, one subscription and one snapshot must grow each listing by
    // exactly one. Assumes nothing else modifies the emulator concurrently; the
    // `usize` subtraction would panic if a listing shrank. Created resources are not
    // cleaned up.
    #[tokio::test(flavor = "multi_thread")]
    #[serial]
    async fn test_lifecycle() {
        let client = create_client().await;

        let uuid = Uuid::new_v4().hyphenated().to_string();
        let topic_id = &format!("t{}", &uuid);
        let subscription_id = &format!("s{}", &uuid);
        let snapshot_id = &format!("snap{}", &uuid);
        let topics = client.get_topics(None).await.unwrap();
        let subs = client.get_subscriptions(None).await.unwrap();
        let snapshots = client.get_snapshots(None).await.unwrap();
        let _topic = client.create_topic(topic_id.as_str(), None, None).await.unwrap();
        let subscription = client
            .create_subscription(subscription_id.as_str(), topic_id.as_str(), SubscriptionConfig::default(), None)
            .await
            .unwrap();

        let _ = subscription
            .create_snapshot(snapshot_id, HashMap::default(), None)
            .await
            .unwrap();

        let topics_after = client.get_topics(None).await.unwrap();
        let subs_after = client.get_subscriptions(None).await.unwrap();
        let snapshots_after = client.get_snapshots(None).await.unwrap();
        assert_eq!(1, topics_after.len() - topics.len());
        assert_eq!(1, subs_after.len() - subs.len());
        assert_eq!(1, snapshots_after.len() - snapshots.len());
    }
}
462
// Manual tests against real Google Cloud. All are `#[ignore]`d; they need
// credentials discoverable by `with_auth` and pre-existing resources: topic
// `test-topic2`, plus topic and subscription `eod-test` with exactly-once delivery
// enabled.
// NOTE(review): `tests::create_client` sets PUBSUB_EMULATOR_HOST process-wide, so
// running these in the same process after the emulator tests would route them to
// the emulator.
#[cfg(test)]
mod tests_in_gcp {
    use crate::client::{Client, ClientConfig};
    use crate::publisher::PublisherConfig;
    use google_cloud_gax::conn::Environment;
    use google_cloud_gax::grpc::codegen::tokio_stream::StreamExt;
    use google_cloud_googleapis::pubsub::v1::PubsubMessage;
    use serial_test::serial;
    use std::collections::HashMap;

    use std::time::Duration;
    use tokio::select;
    use tokio_util::sync::CancellationToken;

    // Message whose ordering key is `key`; the payload is the key itself, or "empty"
    // for the empty key.
    fn make_msg(key: &str) -> PubsubMessage {
        PubsubMessage {
            data: if key.is_empty() {
                "empty".into()
            } else {
                key.to_string().into()
            },
            ordering_key: key.into(),
            ..Default::default()
        }
    }

    // `with_auth` must produce a Google Cloud environment. This fails if
    // PUBSUB_EMULATOR_HOST is set, because `ClientConfig::default()` then picks the
    // emulator and `with_auth` leaves it unchanged.
    #[tokio::test]
    #[ignore]
    async fn test_with_auth() {
        let config = ClientConfig::default().with_auth().await.unwrap();
        if let Environment::Emulator(_) = config.environment {
            unreachable!()
        }
    }

    // Two rounds of mixed-key publishes (key3 repeated) with a 3 s flush interval and
    // 3 workers; each round's publishes must all succeed.
    #[tokio::test]
    #[serial]
    #[ignore]
    async fn test_publish_ordering_in_gcp_flush_buffer() {
        let client = Client::new(ClientConfig::default().with_auth().await.unwrap())
            .await
            .unwrap();
        let topic = client.topic("test-topic2");
        let publisher = topic.new_publisher(Some(PublisherConfig {
            flush_interval: Duration::from_secs(3),
            workers: 3,
            ..Default::default()
        }));

        let mut awaiters = vec![];
        for key in ["", "key1", "key2", "key3", "key3"] {
            awaiters.push(publisher.publish(make_msg(key)).await);
        }
        for awaiter in awaiters.into_iter() {
            tracing::info!("msg id {}", awaiter.get().await.unwrap());
        }

        // Second round reuses the same publisher and keys.
        let mut awaiters = vec![];
        for key in ["", "key1", "key2", "key3", "key3"] {
            awaiters.push(publisher.publish(make_msg(key)).await);
        }
        for awaiter in awaiters.into_iter() {
            tracing::info!("msg id {}", awaiter.get().await.unwrap());
        }
    }

    // One worker with bundle_size 8 and a 30 s flush interval; each round publishes
    // exactly 8 messages, presumably so the bundle limit rather than the timer
    // triggers the send.
    #[tokio::test]
    #[serial]
    #[ignore]
    async fn test_publish_ordering_in_gcp_limit_exceed() {
        let client = Client::new(ClientConfig::default().with_auth().await.unwrap())
            .await
            .unwrap();
        let topic = client.topic("test-topic2");
        let publisher = topic.new_publisher(Some(PublisherConfig {
            flush_interval: Duration::from_secs(30),
            workers: 1,
            bundle_size: 8,
            ..Default::default()
        }));

        let mut awaiters = vec![];
        for key in ["", "key1", "key2", "key3", "key1", "key2", "key3", ""] {
            awaiters.push(publisher.publish(make_msg(key)).await);
        }
        for awaiter in awaiters.into_iter() {
            tracing::info!("msg id {}", awaiter.get().await.unwrap());
        }

        // Second round reuses the same publisher and keys.
        let mut awaiters = vec![];
        for key in ["", "key1", "key2", "key3", "key1", "key2", "key3", ""] {
            awaiters.push(publisher.publish(make_msg(key)).await);
        }
        for awaiter in awaiters.into_iter() {
            tracing::info!("msg id {}", awaiter.get().await.unwrap());
        }
    }

    // Same bundle setup as above (2 workers) but using `publish_bulk`, twice.
    #[tokio::test]
    #[serial]
    #[ignore]
    async fn test_publish_ordering_in_gcp_bulk() {
        let client = Client::new(ClientConfig::default().with_auth().await.unwrap())
            .await
            .unwrap();
        let topic = client.topic("test-topic2");
        let publisher = topic.new_publisher(Some(PublisherConfig {
            flush_interval: Duration::from_secs(30),
            workers: 2,
            bundle_size: 8,
            ..Default::default()
        }));

        let msgs = ["", "", "key1", "key1", "key2", "key2", "key3", "key3"]
            .map(make_msg)
            .to_vec();
        for awaiter in publisher.publish_bulk(msgs).await.into_iter() {
            tracing::info!("msg id {}", awaiter.get().await.unwrap());
        }

        let msgs = ["", "", "key1", "key1", "key2", "key2", "key3", "key3"]
            .map(make_msg)
            .to_vec();
        for awaiter in publisher.publish_bulk(msgs).await.into_iter() {
            tracing::info!("msg id {}", awaiter.get().await.unwrap());
        }
    }
    // Publishes continuously for 60 s while a subscriber handles each message
    // slowly (1 s before acking). With exactly-once delivery enabled on
    // `eod-test`, no message ID may be seen more than once.
    #[tokio::test]
    #[serial]
    #[ignore]
    async fn test_subscribe_exactly_once_delivery() {
        let client = Client::new(ClientConfig::default().with_auth().await.unwrap())
            .await
            .unwrap();

        let subscription = client.subscription("eod-test");
        // Precondition: the subscription must already have exactly-once delivery enabled.
        let config = subscription.config(None).await.unwrap().1;
        assert!(config.enable_exactly_once_delivery);

        let ctx = CancellationToken::new();
        // Publisher loop: one blocking publish at a time until cancelled.
        let ctx_pub = ctx.clone();
        let publisher = client.topic("eod-test").new_publisher(None);
        let pub_task = tokio::spawn(async move {
            tracing::info!("start publisher");
            loop {
                if ctx_pub.is_cancelled() {
                    tracing::info!("finish publisher");
                    return;
                }
                publisher
                    .publish_blocking(PubsubMessage {
                        data: "msg".into(),
                        ..Default::default()
                    })
                    .get()
                    .await
                    .unwrap();
            }
        });

        // Subscriber loop: counts deliveries per message ID until cancelled.
        let ctx_sub = ctx.child_token();
        let sub_task = tokio::spawn(async move {
            tracing::info!("start subscriber");
            let mut stream = subscription.subscribe(None).await.unwrap();
            let mut msgs = HashMap::new();
            while let Some(message) = select! {
                msg = stream.next() => msg,
                _ = ctx_sub.cancelled() => None
            } {
                let msg_id = &message.message.message_id;
                // Simulated slow processing before the ack.
                tokio::time::sleep(Duration::from_secs(1)).await;
                *msgs.entry(msg_id.clone()).or_insert(0) += 1;
                message.ack().await.unwrap();
            }
            stream.dispose().await;
            tracing::info!("finish subscriber");
            msgs
        });

        tokio::time::sleep(Duration::from_secs(60)).await;

        // Cancelling the parent token also cancels the subscriber's child token.
        ctx.cancel();
        pub_task.await.unwrap();
        let received_msgs = sub_task.await.unwrap();
        assert!(!received_msgs.is_empty());

        tracing::info!("Number of received messages = {}", received_msgs.len());
        for (msg_id, count) in received_msgs {
            assert_eq!(count, 1, "msg_id = {msg_id}, count = {count}");
        }
    }
}