google_cloud_pubsub/subscriber/
session.rs1use super::builder::StreamingPull;
16use super::handler::{AckResult, AtLeastOnce, Handler};
17use super::lease_loop::LeaseLoop;
18use super::lease_state::LeaseOptions;
19use super::leaser::DefaultLeaser;
20use super::stream::open_stream;
21use super::stub::{Stub, TonicStreaming};
22use super::transport::Transport;
23use crate::google::pubsub::v1::StreamingPullRequest;
24use crate::model::PubsubMessage;
25use crate::{Error, Result};
26use gaxi::grpc::from_status::to_gax_error;
27use gaxi::prost::FromProto as _;
28use std::collections::VecDeque;
29use std::sync::Arc;
30use tokio::sync::mpsc::UnboundedSender;
31use tokio_util::sync::{CancellationToken, DropGuard};
32
/// A streaming pull session for a single subscription.
///
/// Created from a [`StreamingPull`] builder via `Session::new`. Call
/// [`Session::next`] to receive messages, each paired with the [`Handler`]
/// used to acknowledge (or not) that message.
///
/// Note: fields are dropped in declaration order. Dropping the session drops
/// `stream` first and later `_keepalive_guard`, which cancels `shutdown`.
#[derive(Debug)]
pub struct Session {
    /// The transport used to open the stream. A clone is also handed to the
    /// `DefaultLeaser` in `new`.
    inner: Arc<Transport>,

    /// The first request written when the stream is opened.
    ///
    /// Built once from the builder's settings (subscription, ack deadline,
    /// flow-control limits, client ID) and cloned every time a stream is
    /// opened.
    initial_req: StreamingPullRequest,

    /// The open stream, or `None` until the first call to `next()` opens it.
    stream: Option<<Transport as Stub>::Stream>,

    /// Messages already received but not yet returned by `next()`.
    ///
    /// A single response can carry several messages; they are buffered here
    /// and handed out one at a time.
    pool: VecDeque<(PubsubMessage, Handler)>,

    /// Sends the ack ID of each newly received message to the lease loop,
    /// which then manages (extends) that message's lease.
    message_tx: UnboundedSender<String>,

    /// Cloned into every `Handler`; the handler reports its ack/nack result
    /// to the lease loop through this channel.
    ack_tx: UnboundedSender<AckResult>,

    /// Cancellation token passed to `open_stream`.
    shutdown: CancellationToken,

    /// Cancels `shutdown` when dropped, i.e. when the session is dropped.
    ///
    /// NOTE(review): named for stopping keepalives; presumably the keepalive
    /// logic behind `open_stream` watches `shutdown`. Confirm in `stream.rs`.
    _keepalive_guard: DropGuard,

    /// Handle to the background lease-loop task.
    ///
    /// Dropping a tokio `JoinHandle` detaches the task rather than aborting
    /// it; the test-only `close()` awaits it instead.
    _lease_loop: tokio::task::JoinHandle<()>,
}
101
102impl Session {
103 pub(super) fn new(builder: StreamingPull) -> Self {
104 let shutdown = CancellationToken::new();
105 let inner = builder.inner;
106 let subscription = builder.subscription;
107
108 let leaser = DefaultLeaser::new(
109 inner.clone(),
110 subscription.clone(),
111 builder.ack_deadline_seconds,
112 );
113 let LeaseLoop {
114 handle: _lease_loop,
115 message_tx,
116 ack_tx,
117 } = LeaseLoop::new(leaser, LeaseOptions::default());
118
119 let initial_req = StreamingPullRequest {
120 subscription,
121 stream_ack_deadline_seconds: builder.ack_deadline_seconds,
122 max_outstanding_messages: builder.max_outstanding_messages,
123 max_outstanding_bytes: builder.max_outstanding_bytes,
124 client_id: builder.client_id,
125 ..Default::default()
126 };
127
128 Self {
129 inner,
130 initial_req,
131 stream: None,
132 pool: VecDeque::new(),
133 message_tx,
134 ack_tx,
135 shutdown: shutdown.clone(),
136 _keepalive_guard: shutdown.drop_guard(),
137 _lease_loop,
138 }
139 }
140
141 pub async fn next(&mut self) -> Option<Result<(PubsubMessage, Handler)>> {
163 loop {
164 if let Some(item) = self.pool.pop_front() {
166 return Some(Ok(item));
167 }
168 if let Err(e) = self.stream_next().await? {
170 return Some(Err(e));
171 }
172 }
173 }
174
175 async fn mut_stream(&mut self) -> Result<&mut <Transport as Stub>::Stream> {
179 if self.stream.is_none() {
180 let stream = open_stream(
181 self.inner.clone(),
182 self.initial_req.clone(),
183 self.shutdown.clone(),
184 )
185 .await?;
186 self.stream = Some(stream);
187 }
188 Ok(self
189 .stream
190 .as_mut()
191 .expect("`self.stream.is_some()` must be true"))
192 }
193
194 async fn stream_next(&mut self) -> Option<Result<()>> {
195 let resp = {
196 let stream = match self.mut_stream().await {
197 Ok(s) => s,
198 Err(e) => return Some(Err(e)),
200 };
201
202 match stream.next_message().await.transpose()? {
203 Ok(resp) => resp,
204 Err(e) => return Some(Err(to_gax_error(e))),
205 }
206 };
207
208 for rm in resp.received_messages {
209 let Some(message) = rm.message else {
210 continue;
217 };
218 let _ = self.message_tx.send(rm.ack_id.clone());
219 let message = match message.cnv().map_err(Error::deser) {
220 Ok(message) => message,
221 Err(e) => return Some(Err(e)),
222 };
223 self.pool.push_back((
224 message,
225 Handler::AtLeastOnce(AtLeastOnce {
226 ack_id: rm.ack_id,
227 ack_tx: self.ack_tx.clone(),
228 }),
229 ));
230 }
231 Some(Ok(()))
232 }
233
234 #[cfg(test)]
235 async fn close(self) -> anyhow::Result<()> {
239 drop(self._keepalive_guard);
241
242 drop(self.message_tx);
244
245 self._lease_loop.await?;
247
248 Ok(())
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::super::client::Subscriber;
    use super::super::keepalive::KEEPALIVE_PERIOD;
    use super::super::lease_state::tests::{test_id, test_ids};
    use super::*;
    use auth::credentials::anonymous::Builder as Anonymous;
    use pubsub_grpc_mock::google::pubsub::v1;
    use pubsub_grpc_mock::{MockSubscriber, start};
    use tokio::sync::mpsc::{channel, unbounded_channel};
    use tokio::task::JoinHandle;
    use tokio::time::Duration;

    /// Returns `v` sorted, so ack ID lists compare equal regardless of order.
    fn sorted(mut v: Vec<String>) -> Vec<String> {
        v.sort();
        v
    }

    /// The payload used for the `v`-th test message.
    fn test_data(v: i32) -> bytes::Bytes {
        bytes::Bytes::from(format!("data-{}", test_id(v)))
    }

    /// Builds a response with one message per value in `range`.
    fn test_response(range: std::ops::Range<i32>) -> v1::StreamingPullResponse {
        v1::StreamingPullResponse {
            received_messages: range
                .into_iter()
                .map(|i| v1::ReceivedMessage {
                    ack_id: test_id(i),
                    message: Some(v1::PubsubMessage {
                        data: test_data(i).to_vec(),
                        ..Default::default()
                    }),
                    ..Default::default()
                })
                .collect(),
            ..Default::default()
        }
    }

    /// Creates a subscriber client that connects to the mock at `endpoint`.
    async fn test_client(endpoint: String) -> anyhow::Result<Subscriber> {
        Ok(Subscriber::builder()
            .with_endpoint(endpoint)
            .with_credentials(Anonymous::new().build())
            .build()
            .await?)
    }

    #[tokio::test]
    async fn error_starting_stream() -> anyhow::Result<()> {
        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Err(tonic::Status::internal("fail")));
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();
        let err = session
            .next()
            .await
            .expect("stream should not be empty")
            .expect_err("the first streamed item should be an error");
        assert!(err.status().is_some(), "{err:?}");
        let status = err.status().unwrap();
        assert_eq!(status.code, gax::error::rpc::Code::Internal);
        assert_eq!(status.message, "fail");

        Ok(())
    }

    #[tokio::test]
    async fn initial_request() -> anyhow::Result<()> {
        const MIB: i64 = 1024 * 1024;

        // Captures the requests the client writes on the stream.
        let (recover_writes_tx, mut recover_writes_rx) = channel(1);

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull().return_once(move |request| {
            tokio::spawn(async move {
                let mut request_rx = request.into_inner();
                // Forward every request the client writes, even though the
                // RPC itself fails below.
                while let Some(request) = request_rx.recv().await {
                    recover_writes_tx
                        .send(request)
                        .await
                        .expect("forwarding writes always succeeds");
                }
            });
            Err(tonic::Status::internal("fail"))
        });

        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let _ = client
            .streaming_pull("projects/p/subscriptions/s")
            .set_ack_deadline_seconds(20)
            .set_max_outstanding_messages(2000)
            .set_max_outstanding_bytes(200 * MIB)
            .start()
            .next()
            .await;

        let initial_req = recover_writes_rx
            .recv()
            .await
            .expect("should receive a request")?;
        assert_eq!(initial_req.subscription, "projects/p/subscriptions/s");
        assert_eq!(initial_req.stream_ack_deadline_seconds, 20);
        assert_eq!(initial_req.max_outstanding_messages, 2000);
        assert_eq!(initial_req.max_outstanding_bytes, 200 * MIB);
        assert!(!initial_req.client_id.is_empty());

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    async fn basic_success() -> anyhow::Result<()> {
        let (response_tx, response_rx) = channel(10);
        let (ack_tx, mut ack_rx) = unbounded_channel();

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Ok(tonic::Response::from(response_rx)));
        mock.expect_acknowledge().returning(move |r| {
            ack_tx
                .send(r.into_inner())
                .expect("sending on channel always succeeds");
            Ok(tonic::Response::from(()))
        });
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();

        response_tx.send(Ok(test_response(1..2))).await?;
        response_tx.send(Ok(test_response(2..4))).await?;
        response_tx.send(Ok(test_response(4..7))).await?;
        drop(response_tx);

        for i in 1..7 {
            let (m, Handler::AtLeastOnce(h)) = session
                .next()
                .await
                .transpose()?
                .unwrap_or_else(|| panic!("message {i}/6"));
            assert_eq!(m.data, test_data(i));
            assert_eq!(h.ack_id, test_id(i));
            h.ack();
        }
        let end = session.next().await.transpose()?;
        assert!(end.is_none(), "Received extra message: {end:?}");

        // Once the session is closed, the acks are expected to have been sent.
        session.close().await?;

        let ack_req = ack_rx.try_recv()?;
        assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
        assert_eq!(sorted(ack_req.ack_ids), test_ids(1..7));

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    async fn basic_lease_management() -> anyhow::Result<()> {
        let (response_tx, response_rx) = channel(10);
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let (nack_tx, mut nack_rx) = unbounded_channel();
        let (extend_tx, mut extend_rx) = unbounded_channel();

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Ok(tonic::Response::from(response_rx)));
        mock.expect_acknowledge().returning(move |r| {
            ack_tx
                .send(r.into_inner())
                .expect("sending on channel always succeeds");
            Ok(tonic::Response::from(()))
        });
        // A deadline of 0 is a nack; anything else is a lease extension.
        mock.expect_modify_ack_deadline().returning(move |r| {
            let r = r.into_inner();
            if r.ack_deadline_seconds == 0 {
                nack_tx.send(r).expect("sending on channel always succeeds");
            } else {
                extend_tx
                    .send(r)
                    .expect("sending on channel always succeeds");
            }
            Ok(tonic::Response::from(()))
        });
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();

        response_tx.send(Ok(test_response(0..30))).await?;
        drop(response_tx);

        // Ack the first 10 messages.
        for i in 0..10 {
            let Some((_, Handler::AtLeastOnce(h))) = session.next().await.transpose()? else {
                anyhow::bail!("expected message {i}")
            };
            h.ack();
        }
        // Nack the next 10 messages.
        for i in 10..20 {
            let Some((_, Handler::AtLeastOnce(h))) = session.next().await.transpose()? else {
                anyhow::bail!("expected message {i}")
            };
            h.nack();
        }
        // Hold on to the last 10 messages without acking or nacking them.
        let mut hold = Vec::new();
        for i in 20..30 {
            let Some((_, Handler::AtLeastOnce(h))) = session.next().await.transpose()? else {
                anyhow::bail!("expected message {i}")
            };
            hold.push(h);
        }

        // Advance time so the lease loop extends the held messages.
        tokio::time::advance(Duration::from_secs(10)).await;

        session.close().await?;

        // The acked messages.
        let ack_req = ack_rx.try_recv()?;
        assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
        assert_eq!(sorted(ack_req.ack_ids), test_ids(0..10));
        assert!(ack_rx.is_empty());

        // The explicitly nacked messages.
        let nack_req = nack_rx.try_recv()?;
        assert_eq!(nack_req.subscription, "projects/p/subscriptions/s");
        assert_eq!(nack_req.ack_deadline_seconds, 0);
        assert_eq!(sorted(nack_req.ack_ids), test_ids(10..20));

        // Messages still held at shutdown are expected to be nacked.
        let nack_req = nack_rx.try_recv()?;
        assert_eq!(nack_req.subscription, "projects/p/subscriptions/s");
        assert_eq!(nack_req.ack_deadline_seconds, 0);
        assert_eq!(sorted(nack_req.ack_ids), test_ids(20..30));
        assert!(nack_rx.is_empty());

        // Only the held messages had their leases extended.
        let extend_req = extend_rx.try_recv()?;
        assert_eq!(extend_req.subscription, "projects/p/subscriptions/s");
        assert_eq!(extend_req.ack_deadline_seconds, 10);
        assert_eq!(sorted(extend_req.ack_ids), test_ids(20..30));

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    async fn delayed_responses() -> anyhow::Result<()> {
        let (response_tx, response_rx) = channel(10);

        // Deliver the only response after a delay; `next()` must wait for it.
        let handle: JoinHandle<anyhow::Result<()>> = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(20)).await;
            response_tx.send(Ok(test_response(1..2))).await?;
            Ok(())
        });

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Ok(tonic::Response::from(response_rx)));
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();
        let (m, Handler::AtLeastOnce(h)) = session
            .next()
            .await
            .transpose()?
            .expect("stream should wait for a message");
        assert_eq!(m.data, test_data(1));
        assert_eq!(h.ack_id, test_id(1));

        handle.await??;

        Ok(())
    }

    #[tokio::test]
    async fn serves_messages_immediately() -> anyhow::Result<()> {
        // Each message is sent only after the previous one was received, so
        // `next()` must return as soon as a response arrives.
        let (response_tx, response_rx) = channel(10);

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Ok(tonic::Response::from(response_rx)));
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();

        for i in 1..7 {
            response_tx.send(Ok(test_response(i..i + 1))).await?;

            let (m, Handler::AtLeastOnce(h)) = session
                .next()
                .await
                .transpose()?
                .unwrap_or_else(|| panic!("message {i}/6"));
            assert_eq!(m.data, test_data(i));
            assert_eq!(h.ack_id, test_id(i));
        }
        drop(response_tx);
        let end = session.next().await.transpose()?;
        assert!(end.is_none(), "Received extra message: {end:?}");

        Ok(())
    }

    #[tokio::test]
    async fn handles_empty_response() -> anyhow::Result<()> {
        let (response_tx, response_rx) = channel(10);

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Ok(tonic::Response::from(response_rx)));
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();

        response_tx.send(Ok(test_response(1..2))).await?;
        // A response with no messages must not end the session.
        response_tx.send(Ok(test_response(2..2))).await?;
        response_tx.send(Ok(test_response(2..3))).await?;
        drop(response_tx);

        for i in 1..3 {
            let (m, Handler::AtLeastOnce(h)) = session
                .next()
                .await
                .transpose()?
                .unwrap_or_else(|| panic!("message {i}/2"));
            assert_eq!(m.data, test_data(i));
            assert_eq!(h.ack_id, test_id(i));
        }
        let end = session.next().await.transpose()?;
        assert!(end.is_none(), "Received extra message: {end:?}");

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    async fn handles_missing_message_field() -> anyhow::Result<()> {
        let (response_tx, response_rx) = channel(10);
        let (extend_tx, mut extend_rx) = unbounded_channel();

        let bad = v1::StreamingPullResponse {
            received_messages: vec![v1::ReceivedMessage {
                ack_id: "ignored-ack-id".to_string(),
                message: None,
                ..Default::default()
            }],
            ..Default::default()
        };

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Ok(tonic::Response::from(response_rx)));
        mock.expect_acknowledge()
            .returning(|_| Ok(tonic::Response::from(())));
        mock.expect_modify_ack_deadline().returning(move |r| {
            extend_tx
                .send(r.into_inner())
                .expect("sending on channel always succeeds");
            Ok(tonic::Response::from(()))
        });
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();

        response_tx.send(Ok(test_response(1..4))).await?;
        // The entry without a message is skipped; later messages still flow.
        response_tx.send(Ok(bad)).await?;
        response_tx.send(Ok(test_response(4..7))).await?;
        drop(response_tx);

        for i in 1..7 {
            let (m, Handler::AtLeastOnce(h)) = session
                .next()
                .await
                .transpose()?
                .unwrap_or_else(|| panic!("message {i}/6"));
            assert_eq!(m.data, test_data(i));
            assert_eq!(h.ack_id, test_id(i));
        }
        let end = session.next().await.transpose()?;
        assert!(end.is_none(), "Received extra message: {end:?}");

        // Advance time so the lease loop extends the outstanding messages.
        tokio::time::advance(Duration::from_secs(10)).await;

        session.close().await?;

        // The ignored ack ID must not appear in the lease extension.
        let extend_req = extend_rx.try_recv()?;
        assert_eq!(extend_req.subscription, "projects/p/subscriptions/s");
        assert_eq!(extend_req.ack_deadline_seconds, 10);
        assert_eq!(sorted(extend_req.ack_ids), test_ids(1..7));

        Ok(())
    }

    #[tokio::test]
    async fn permanent_error_midstream() -> anyhow::Result<()> {
        let (response_tx, response_rx) = channel(10);

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Ok(tonic::Response::from(response_rx)));
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();

        response_tx.send(Ok(test_response(1..4))).await?;
        response_tx
            .send(Err(tonic::Status::internal("fail")))
            .await?;
        drop(response_tx);

        for i in 1..4 {
            let (m, Handler::AtLeastOnce(h)) = session
                .next()
                .await
                .transpose()?
                .unwrap_or_else(|| panic!("message {i}/3"));
            assert_eq!(m.data, test_data(i));
            assert_eq!(h.ack_id, test_id(i));
        }
        let err = session
            .next()
            .await
            .transpose()
            .expect_err("expected an error from stream");
        assert!(err.status().is_some(), "{err:?}");
        let status = err.status().unwrap();
        assert_eq!(status.code, gax::error::rpc::Code::Internal);
        assert_eq!(status.message, "fail");

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    async fn keepalives() -> anyhow::Result<()> {
        // Captures the requests the client writes on the stream.
        let (recover_writes_tx, mut recover_writes_rx) = channel(1);
        let (response_tx, response_rx) = channel(10);

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull().return_once(move |request| {
            tokio::spawn(async move {
                let mut request_rx = request.into_inner();
                // Forward every request the client writes.
                while let Some(request) = request_rx.recv().await {
                    recover_writes_tx
                        .send(request)
                        .await
                        .expect("forwarding writes always succeeds");
                }
            });
            Ok(tonic::Response::from(response_rx))
        });

        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();
        response_tx.send(Ok(test_response(1..4))).await?;
        let _ = session.next().await;

        let initial_req = recover_writes_rx
            .recv()
            .await
            .expect("should receive an initial request")?;
        assert_eq!(initial_req.subscription, "projects/p/subscriptions/s");

        // After one keepalive period the client writes an empty request.
        tokio::time::advance(KEEPALIVE_PERIOD).await;
        let keepalive_req = recover_writes_rx
            .recv()
            .await
            .expect("should receive a keepalive request")?;
        assert_eq!(keepalive_req, v1::StreamingPullRequest::default());

        // Dropping the session should stop the keepalives.
        drop(session);

        // No further writes arrive, even after several keepalive periods.
        tokio::time::advance(4 * KEEPALIVE_PERIOD).await;
        assert!(recover_writes_rx.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn client_id() -> anyhow::Result<()> {
        // Captures the requests written on each of the streams. The mock is
        // invoked once per stream, so the sender is shared between calls.
        let (recover_writes_tx, mut recover_writes_rx) = channel(10);
        let recover_writes_tx = std::sync::Arc::new(tokio::sync::Mutex::new(recover_writes_tx));

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .times(3)
            .returning(move |request| {
                let tx = recover_writes_tx.clone();
                tokio::spawn(async move {
                    let mut request_rx = request.into_inner();
                    // Forward every request the client writes, even though
                    // the RPC itself fails below.
                    while let Some(request) = request_rx.recv().await {
                        tx.lock()
                            .await
                            .send(request)
                            .await
                            .expect("forwarding writes always succeeds");
                    }
                });
                Err(tonic::Status::internal("fail"))
            });

        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;

        // Two streams from the same client share the client ID.
        let c1 = test_client(endpoint.clone()).await?;
        let _ = c1
            .streaming_pull("projects/p/subscriptions/s")
            .start()
            .next()
            .await;
        let req1 = recover_writes_rx
            .recv()
            .await
            .expect("should receive a request")?;
        let _ = c1
            .streaming_pull("projects/p/subscriptions/s")
            .start()
            .next()
            .await;
        let req2 = recover_writes_rx
            .recv()
            .await
            .expect("should receive a request")?;
        assert_eq!(req1.client_id, req2.client_id);

        // A different client uses a different client ID.
        let c2 = test_client(endpoint).await?;
        let _ = c2
            .streaming_pull("projects/p/subscriptions/s")
            .start()
            .next()
            .await;
        let req3 = recover_writes_rx
            .recv()
            .await
            .expect("should receive a request")?;
        assert_ne!(req1.client_id, req3.client_id);

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    async fn no_immediate_message() -> anyhow::Result<()> {
        const TEST_TIMEOUT: Duration = Duration::from_secs(42);

        // Keep the sender alive so the stream stays open but never yields.
        let (_response_tx, response_rx) = channel(10);

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(move |_| Ok(tonic::Response::from(response_rx)));

        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();

        let _ = tokio::time::timeout(TEST_TIMEOUT, session.next())
            .await
            .expect_err("next() should never yield.");

        Ok(())
    }
}