Skip to main content

google_cloud_pubsub/subscriber/
message_stream.rs

1// Copyright 2026 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::builder::Subscribe;
16use super::handler::{Action, AtLeastOnce, ExactlyOnce, Handler};
17use super::lease_loop::LeaseLoop;
18use super::lease_state::{ExactlyOnceInfo, LeaseInfo, LeaseOptions, NewMessage};
19use super::leaser::DefaultLeaser;
20use super::retry_policy::StreamRetryPolicy;
21use super::stream::Stream;
22use super::stub::TonicStreaming as _;
23use super::transport::Transport;
24use crate::google::pubsub::v1::{StreamingPullRequest, StreamingPullResponse};
25use crate::model::Message;
26use crate::{Error, Result};
27use futures::FutureExt;
28use futures::future::{BoxFuture, Shared};
29use gaxi::grpc::from_status::to_gax_error;
30use gaxi::prost::FromProto as _;
31use google_cloud_gax::retry_result::RetryResult;
32use std::collections::VecDeque;
33use std::sync::Arc;
34use tokio::sync::mpsc::{UnboundedSender, unbounded_channel};
35use tokio::time::Instant;
36use tokio_util::sync::CancellationToken;
37
/// Represents an open subscribe stream.
///
/// This is a stream-like struct for serving messages to an application. Each
/// call to [`MessageStream::next`] yields one message together with a
/// [`Handler`] used to acknowledge it.
///
/// # Example
/// ```
/// # use google_cloud_pubsub::client::Subscriber;
/// # async fn sample(client: Subscriber) -> anyhow::Result<()> {
/// let mut stream = client
///     .subscribe("projects/my-project/subscriptions/my-subscription")
///     .build();
/// while let Some((m, h)) = stream.next().await.transpose()? {
///     println!("Received message m={m:?}");
///     h.ack();
/// }
/// # Ok(()) }
/// ```
#[derive(Debug)]
pub struct MessageStream {
    // NOTE(review): Rust drops fields in declaration order, so `inner` (which
    // owns the senders feeding the lease management task) is dropped before
    // `lease_loop` and `shutdown`. Confirm nothing relies on this ordering
    // before reordering or inserting fields.

    /// Implementation of the `MessageStream`.
    ///
    /// To avoid atomic increments in the critical path, we separate the
    /// shutdown token from the rest of the struct. This way we can hold a
    /// mutable reference to `self.inner`, and a reference to `self.shutdown` at
    /// the same time.
    inner: MessageStreamImpl,

    // At the moment this future is only awaited by the test-only `close`
    // method, which is why the field needs `allow(dead_code)`.
    #[allow(dead_code)] // TODO(#5024) - implementation in progress...
    /// This future is ready when the lease loop shutdown completes.
    lease_loop: Shared<BoxFuture<'static, ()>>,

    /// A token that can detect a shutdown from the application.
    ///
    /// `MessageStream::next` races this token against reading the next
    /// message, giving the token priority.
    shutdown: CancellationToken,
}
72
/// The mutable state behind a [`MessageStream`].
///
/// This state is kept apart from the shutdown token so that
/// `MessageStream::next` can borrow it mutably while also waiting on the
/// token.
#[derive(Debug)]
pub struct MessageStreamImpl {
    /// The stub implementing this struct.
    ///
    /// It is used to open the streaming pull, including re-opening it after a
    /// transient error.
    stub: Arc<Transport>,

    /// The initial request used to start a stream.
    ///
    /// A clone of this request is handed to `Stream::new` on every attempt to
    /// open the stream.
    initial_req: StreamingPullRequest,

    /// The bidirectional stream.
    ///
    /// `None` means no stream is open yet, or the previous one was reset after
    /// a transient error. In either case the next read opens a new stream.
    /// `Some(StreamState::Closed)` is terminal: the stream is never re-opened.
    ///
    /// We choose to lazy-initialize the stream when the application asks for a
    /// message because tonic will not yield the stream to us until the first
    /// response is available.[^1]
    ///
    /// The usability of the `MessageStream` API would suffer if creating an instance
    /// of `MessageStream` is blocked on the first message being available.
    ///
    /// [^1]: <https://github.com/hyperium/tonic/issues/515>
    stream: Option<StreamState>,

    /// Applications ask for messages one at a time. Individual stream responses
    /// can contain multiple messages. We use `pool` to hold the extra messages
    /// while we wait to serve them to applications.
    ///
    /// A FIFO queue is necessary to preserve ordering.
    ///
    /// The pool is cleared, dropping the pending handlers, when
    /// `MessageStream::next` observes the end of the stream.
    pool: VecDeque<(Message, Handler)>,

    /// A sender for sending new messages from the stream into the lease
    /// management task.
    ///
    /// Each received message's ack ID is sent here before the message is added
    /// to `pool`.
    message_tx: UnboundedSender<NewMessage>,

    /// A sender for forwarding acks/nacks from the application to the lease
    /// management task. Each `Handler` holds a clone of this.
    ack_tx: UnboundedSender<Action>,
}
108
109// We would rather always allocate enough space to hold the stream on the stack
110// than add a layer of indirection by `Box`ing it.
111#[allow(clippy::large_enum_variant)]
112#[derive(Debug)]
113enum StreamState {
114    /// The stream was cancelled or failed with a permanent error. It should not
115    /// be re-opened.
116    Closed,
117    /// The stream is active.
118    Active(Stream<Transport>),
119}
120
121impl MessageStream {
122    pub(super) fn new(builder: Subscribe) -> Self {
123        let stub = builder.inner;
124        let subscription = builder.subscription;
125
126        let (confirmed_tx, confirmed_rx) = unbounded_channel();
127        let leaser = DefaultLeaser::new(
128            stub.clone(),
129            confirmed_tx,
130            subscription.clone(),
131            builder.ack_deadline_seconds,
132            builder.grpc_subchannel_count,
133        );
134        let options = LeaseOptions {
135            max_lease: builder.max_lease,
136            shutdown_behavior: builder.shutdown_behavior,
137            ..Default::default()
138        };
139        let LeaseLoop {
140            handle,
141            message_tx,
142            ack_tx,
143        } = LeaseLoop::new(leaser, confirmed_rx, options);
144        let lease_loop = handle.map(|_| ()).boxed().shared();
145
146        let initial_req = StreamingPullRequest {
147            subscription,
148            stream_ack_deadline_seconds: builder.ack_deadline_seconds,
149            max_outstanding_messages: builder.max_outstanding_messages,
150            max_outstanding_bytes: builder.max_outstanding_bytes,
151            client_id: builder.client_id,
152            // `protocol_version == 1` means we support receiving heartbeats
153            // (empty `StreamingPullResponse`s) from the server.
154            protocol_version: 1,
155            ..Default::default()
156        };
157
158        let inner = MessageStreamImpl {
159            stub,
160            initial_req,
161            stream: None,
162            pool: VecDeque::new(),
163            message_tx,
164            ack_tx,
165        };
166        Self {
167            inner,
168            lease_loop,
169            shutdown: CancellationToken::new(),
170        }
171    }
172
173    /// Returns the next message received on this subscription.
174    ///
175    /// # Example
176    /// ```
177    /// # use google_cloud_pubsub::subscriber::MessageStream;
178    /// # async fn sample(mut stream: MessageStream) -> anyhow::Result<()> {
179    /// while let Some((m, h)) = stream.next().await.transpose()? {
180    ///     println!("Received message m={m:?}");
181    ///     h.ack();
182    /// }
183    /// # Ok(()) }
184    /// ```
185    ///
186    /// Returns the message data along with a [Handler] to acknowledge (ack) the
187    /// message.
188    ///
189    /// If the underlying stream encounters a permanent error, an `Error` is
190    /// returned instead.
191    ///
192    /// `None` represents the end of a stream, but in practice, the stream stays
193    /// open until it is cancelled or encounters a permanent error.
194    pub async fn next(&mut self) -> Option<Result<(Message, Handler)>> {
195        let next = tokio::select! {
196            biased;
197            _ = self.shutdown.cancelled() => None,
198            n = self.inner.next() => n,
199        };
200        if next.is_none() {
201            // Permanently close the stream, and drop messages that we haven't
202            // delivered to the application yet.
203            self.inner.stream = Some(StreamState::Closed);
204            self.inner.pool.clear();
205        }
206        next
207    }
208
209    #[cfg(feature = "unstable-stream")]
210    #[cfg_attr(docsrs, doc(cfg(feature = "unstable-stream")))]
211    /// Converts the `MessageStream` to a [`futures::Stream`].
212    ///
213    /// # Example
214    /// ```
215    /// # use google_cloud_pubsub::subscriber::MessageStream;
216    /// # async fn sample(stream: MessageStream) -> anyhow::Result<()> {
217    /// use futures::TryStreamExt;
218    /// let mut stream = stream.into_stream();
219    /// while let Some((m, h)) = stream.try_next().await? { /* ... */ }
220    /// # Ok(()) }
221    /// ```
222    pub fn into_stream(self) -> impl futures::Stream<Item = Result<(Message, Handler)>> + Unpin {
223        use futures::stream::unfold;
224        Box::pin(unfold(Some(self), move |state| async move {
225            if let Some(mut this) = state {
226                if let Some(chunk) = this.next().await {
227                    return Some((chunk, Some(this)));
228                }
229            };
230            None
231        }))
232    }
233
234    #[cfg(test)]
235    /// Close the stream, awaiting all pending acks and nacks.
236    ///
237    /// This is a useful method for setting clean test expectations.
238    async fn close(self) {
239        // Shutdown the stream and its keepalive task.
240        drop(self.inner.stream);
241
242        // Signal a shutdown to the lease management background task.
243        drop(self.inner.message_tx);
244        drop(self.inner.ack_tx);
245
246        // Wait for the lease management task to complete.
247        self.lease_loop.await;
248    }
249}
250
251impl MessageStreamImpl {
252    async fn next(&mut self) -> Option<Result<(Message, Handler)>> {
253        loop {
254            // Serve a message if we have one ready.
255            if let Some(item) = self.pool.pop_front() {
256                return Some(Ok(item));
257            }
258
259            // Otherwise, read the next response from the stream, which will
260            // likely populate the message pool.
261            //
262            // Note that a successful read does not necessarily mean there is a
263            // message in the pool. The server occasionally sends heartbeats
264            // (responses with an empty message list). Hence the loop.
265            if let Err(e) = self.populate_pool().await? {
266                // Handle errors opening or reading from the stream.
267                match StreamRetryPolicy::on_midstream_error(e) {
268                    RetryResult::Continue(_) => {
269                        // The stream failed with a transient error. Reset the stream.
270                        self.stream = None;
271                        continue;
272                    }
273                    RetryResult::Permanent(e) | RetryResult::Exhausted(e) => {
274                        // The stream failed with a permanent error. Return the error.
275                        self.stream = Some(StreamState::Closed);
276                        return Some(Err(e));
277                    }
278                }
279            }
280        }
281    }
282
283    /// Make a new attempt to open the underlying gRPC stream.
284    async fn open_stream(&mut self) -> Result<()> {
285        let stream = Stream::<Transport>::new(self.stub.clone(), self.initial_req.clone()).await?;
286        self.stream = Some(StreamState::Active(stream));
287        Ok(())
288    }
289
290    /// Reads the next response from the stream.
291    ///
292    /// If necessary, this method will open a new stream.
293    ///
294    /// If we receive an error either opening or reading from the stream, we
295    /// return it.
296    async fn next_response(&mut self) -> Option<Result<StreamingPullResponse>> {
297        if self.stream.is_none() {
298            // Open the stream, if necessary.
299            if let Err(e) = self.open_stream().await {
300                return Some(Err(e));
301            }
302        }
303
304        let stream = match self.stream.as_mut()? {
305            StreamState::Closed => return None,
306            StreamState::Active(s) => s,
307        };
308        stream
309            .next_message()
310            .await
311            .map_err(to_gax_error)
312            .transpose()
313    }
314
315    /// Populate the message pool by reading from the stream.
316    ///
317    /// Read the next response from the stream. If necessary, this method will
318    /// open a new stream.
319    ///
320    /// If we receive a response, we store the messages in `self.pool` and
321    /// forward the ack IDs to the lease management task.
322    ///
323    /// If we receive an error reading from the stream, we return it.
324    async fn populate_pool(&mut self) -> Option<Result<()>> {
325        // Read the next response from the stream.
326        let resp = match self.next_response().await? {
327            Ok(resp) => resp,
328            Err(e) => return Some(Err(e)),
329        };
330
331        let exactly_once = resp
332            .subscription_properties
333            .is_some_and(|m| m.exactly_once_delivery_enabled);
334
335        // Process the received messages in the response.
336        for rm in resp.received_messages {
337            let Some(message) = rm.message else {
338                // The message field should always be present. If not, the proto
339                // message was corrupted while in transit, or there is a bug in
340                // the service.
341                //
342                // The client can just ignore an ack ID without an associated
343                // message.
344                continue;
345            };
346
347            let (lease_info, handler) = if exactly_once {
348                let (result_tx, result_rx) = tokio::sync::oneshot::channel();
349                (
350                    LeaseInfo::ExactlyOnce(ExactlyOnceInfo::new(result_tx)),
351                    Handler::ExactlyOnce(ExactlyOnce::new(
352                        rm.ack_id.clone(),
353                        self.ack_tx.clone(),
354                        result_rx,
355                    )),
356                )
357            } else {
358                (
359                    LeaseInfo::AtLeastOnce(Instant::now()),
360                    Handler::AtLeastOnce(AtLeastOnce::new(rm.ack_id.clone(), self.ack_tx.clone())),
361                )
362            };
363
364            let _ = self.message_tx.send(NewMessage {
365                ack_id: rm.ack_id,
366                lease_info,
367            });
368            let message = match message.cnv().map_err(Error::deser) {
369                Ok(message) => message,
370                Err(e) => return Some(Err(e)),
371            };
372            self.pool.push_back((message, handler));
373        }
374        Some(Ok(()))
375    }
376}
377
378#[cfg(test)]
379mod tests {
380    use super::super::ShutdownBehavior;
381    use super::super::client::Subscriber;
382    use super::super::keepalive::KEEPALIVE_PERIOD;
383    use super::super::lease_state::tests::{test_id, test_ids};
384    use super::super::stream::{INITIAL_DELAY, MAXIMUM_DELAY};
385    use super::*;
386    use gaxi::grpc::tonic::{Response as TonicResponse, Status as TonicStatus};
387    use google_cloud_auth::credentials::anonymous::Builder as Anonymous;
388    use google_cloud_test_macros::tokio_test_no_panics;
389    use pubsub_grpc_mock::google::pubsub::v1;
390    use pubsub_grpc_mock::{MockSubscriber, start};
391    use tokio::sync::mpsc::{channel, unbounded_channel};
392    use tokio::task::{JoinHandle, JoinSet};
393    use tokio::time::{Duration, Instant};
394
395    fn sorted(mut v: Vec<String>) -> Vec<String> {
396        v.sort();
397        v
398    }
399
400    fn test_data(v: i32) -> bytes::Bytes {
401        bytes::Bytes::from(format!("data-{}", test_id(v)))
402    }
403
404    fn test_response(range: std::ops::Range<i32>) -> v1::StreamingPullResponse {
405        v1::StreamingPullResponse {
406            received_messages: range
407                .into_iter()
408                .map(|i| v1::ReceivedMessage {
409                    ack_id: test_id(i),
410                    message: Some(v1::PubsubMessage {
411                        data: test_data(i).to_vec(),
412                        ..Default::default()
413                    }),
414                    ..Default::default()
415                })
416                .collect(),
417            ..Default::default()
418        }
419    }
420
421    fn test_exactly_once_response(range: std::ops::Range<i32>) -> v1::StreamingPullResponse {
422        v1::StreamingPullResponse {
423            subscription_properties: Some(v1::streaming_pull_response::SubscriptionProperties {
424                exactly_once_delivery_enabled: true,
425                ..Default::default()
426            }),
427            received_messages: range
428                .into_iter()
429                .map(|i| v1::ReceivedMessage {
430                    ack_id: test_id(i),
431                    message: Some(v1::PubsubMessage {
432                        data: test_data(i).to_vec(),
433                        ..Default::default()
434                    }),
435                    ..Default::default()
436                })
437                .collect(),
438            ..Default::default()
439        }
440    }
441
442    async fn test_client(endpoint: String) -> anyhow::Result<Subscriber> {
443        Ok(Subscriber::builder()
444            .with_endpoint(endpoint)
445            .with_credentials(Anonymous::new().build())
446            .build()
447            .await?)
448    }
449
450    #[tokio_test_no_panics]
451    async fn error_starting_stream() -> anyhow::Result<()> {
452        let mut mock = MockSubscriber::new();
453        mock.expect_streaming_pull()
454            .return_once(|_| Err(TonicStatus::failed_precondition("fail")));
455        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
456        let client = test_client(endpoint).await?;
457        let mut stream = client.subscribe("projects/p/subscriptions/s").build();
458        let err = stream
459            .next()
460            .await
461            .expect("stream should not be empty")
462            .expect_err("the first streamed item should be an error");
463        assert!(err.status().is_some(), "{err:?}");
464        let status = err.status().unwrap();
465        assert_eq!(
466            status.code,
467            google_cloud_gax::error::rpc::Code::FailedPrecondition
468        );
469        assert_eq!(status.message, "fail");
470
471        Ok(())
472    }
473
474    #[tokio_test_no_panics]
475    async fn permanent_error_ends_stream() -> anyhow::Result<()> {
476        let mut mock = MockSubscriber::new();
477        mock.expect_streaming_pull()
478            .returning(|_| Err(TonicStatus::failed_precondition("fail")));
479        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
480        let client = test_client(endpoint).await?;
481        let mut stream = client.subscribe("projects/p/subscriptions/s").build();
482        let next = stream.next().await;
483        assert!(
484            matches!(next, Some(Err(_))),
485            "expected permanent error, got {next:?}"
486        );
487
488        let next = stream.next().await;
489        assert!(next.is_none(), "expected end of stream, got {next:?}");
490
491        Ok(())
492    }
493
494    #[tokio_test_no_panics]
495    async fn initial_request() -> anyhow::Result<()> {
496        const MIB: i64 = 1024 * 1024;
497
498        // We use this channel to surface writes (requests) from outside our
499        // mock expectation.
500        let (recover_writes_tx, mut recover_writes_rx) = channel(1);
501
502        let mut mock = MockSubscriber::new();
503        mock.expect_streaming_pull().return_once(move |request| {
504            tokio::spawn(async move {
505                // Note that this task stays alive as long as we hold
506                // `recover_writes_rx`.
507                let mut request_rx = request.into_inner();
508                while let Some(request) = request_rx.recv().await {
509                    recover_writes_tx
510                        .send(request)
511                        .await
512                        .expect("forwarding writes always succeeds");
513                }
514            });
515            Err(TonicStatus::failed_precondition("fail"))
516        });
517
518        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
519        let client = test_client(endpoint).await?;
520        let _ = client
521            .subscribe("projects/p/subscriptions/s")
522            .set_max_lease_extension(Duration::from_secs(20))
523            .set_max_outstanding_messages(2000)
524            .set_max_outstanding_bytes(200 * MIB)
525            .build()
526            .next()
527            .await;
528
529        let initial_req = recover_writes_rx
530            .recv()
531            .await
532            .expect("should receive a request")?;
533        assert_eq!(initial_req.subscription, "projects/p/subscriptions/s");
534        assert_eq!(initial_req.stream_ack_deadline_seconds, 20);
535        assert_eq!(initial_req.max_outstanding_messages, 2000);
536        assert_eq!(initial_req.max_outstanding_bytes, 200 * MIB);
537        assert!(
538            !initial_req.client_id.is_empty(),
539            "initial request has empty client id: {initial_req:?}"
540        );
541        assert!(
542            initial_req.protocol_version >= 1,
543            "protocol_version={}",
544            initial_req.protocol_version
545        );
546
547        Ok(())
548    }
549
550    #[tokio_test_no_panics(start_paused = true)]
551    async fn basic_success() -> anyhow::Result<()> {
552        let (response_tx, response_rx) = channel(10);
553        let (ack_tx, mut ack_rx) = unbounded_channel();
554
555        let mut mock = MockSubscriber::new();
556        mock.expect_streaming_pull()
557            .return_once(|_| Ok(TonicResponse::from(response_rx)));
558        mock.expect_acknowledge().returning(move |r| {
559            ack_tx
560                .send(r.into_inner())
561                .expect("sending on channel always succeeds");
562            Ok(TonicResponse::from(()))
563        });
564        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
565        let client = test_client(endpoint).await?;
566        let mut stream = client.subscribe("projects/p/subscriptions/s").build();
567
568        response_tx.send(Ok(test_response(1..2))).await?;
569        response_tx.send(Ok(test_response(2..4))).await?;
570        response_tx.send(Ok(test_response(4..7))).await?;
571        drop(response_tx);
572
573        for i in 1..7 {
574            let Some((m, h)) = stream.next().await.transpose()? else {
575                anyhow::bail!("expected message {i}/6")
576            };
577            assert_eq!(m.data, test_data(i));
578            assert_eq!(h.ack_id(), test_id(i));
579            h.ack();
580        }
581        let end = stream.next().await.transpose()?;
582        assert!(end.is_none(), "Received extra message: {end:?}");
583
584        // Wait for the stream to join its background tasks.
585        stream.close().await;
586
587        // Verify the acks went through.
588        let ack_req = ack_rx.try_recv()?;
589        assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
590        assert_eq!(sorted(ack_req.ack_ids), test_ids(1..7));
591
592        Ok(())
593    }
594
595    #[tokio_test_no_panics(start_paused = true)]
596    async fn basic_success_exactly_once() -> anyhow::Result<()> {
597        let (response_tx, response_rx) = channel(10);
598        let (ack_tx, mut ack_rx) = unbounded_channel();
599
600        let mut mock = MockSubscriber::new();
601        mock.expect_streaming_pull()
602            .return_once(|_| Ok(TonicResponse::from(response_rx)));
603        mock.expect_acknowledge().returning(move |r| {
604            ack_tx
605                .send(r.into_inner())
606                .expect("sending on channel always succeeds");
607            Ok(TonicResponse::from(()))
608        });
609        mock.expect_modify_ack_deadline()
610            .returning(|_| Ok(TonicResponse::from(())));
611        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
612        let client = test_client(endpoint).await?;
613        let mut stream = client
614            .subscribe("projects/p/subscriptions/s")
615            .set_shutdown_behavior(ShutdownBehavior::WaitForProcessing)
616            .build();
617
618        response_tx
619            .send(Ok(test_exactly_once_response(1..2)))
620            .await?;
621        response_tx
622            .send(Ok(test_exactly_once_response(2..4)))
623            .await?;
624        response_tx
625            .send(Ok(test_exactly_once_response(4..7)))
626            .await?;
627        drop(response_tx);
628
629        let mut acks = JoinSet::new();
630        for i in 1..7 {
631            let Some((m, Handler::ExactlyOnce(h))) = stream.next().await.transpose()? else {
632                anyhow::bail!("expected message {i}/6")
633            };
634            assert_eq!(m.data, test_data(i));
635            assert_eq!(h.ack_id(), test_id(i));
636            acks.spawn(h.confirmed_ack());
637        }
638        let end = stream.next().await.transpose()?;
639        assert!(end.is_none(), "Received extra message: {end:?}");
640
641        // Wait for the stream to join its background tasks.
642        stream.close().await;
643
644        // Verify the acks went through.
645        while let Some(r) = acks.join_next().await {
646            r??;
647        }
648        let ack_req = ack_rx.try_recv()?;
649        assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
650        assert_eq!(sorted(ack_req.ack_ids), test_ids(1..7));
651
652        Ok(())
653    }
654
655    #[tokio_test_no_panics(start_paused = true)]
656    async fn basic_lease_management() -> anyhow::Result<()> {
657        let (response_tx, response_rx) = channel(10);
658        let (ack_tx, mut ack_rx) = unbounded_channel();
659        let (nack_tx, mut nack_rx) = unbounded_channel();
660        let (extend_tx, mut extend_rx) = unbounded_channel();
661
662        let mut mock = MockSubscriber::new();
663        mock.expect_streaming_pull()
664            .return_once(|_| Ok(TonicResponse::from(response_rx)));
665        mock.expect_acknowledge().returning(move |r| {
666            ack_tx
667                .send(r.into_inner())
668                .expect("sending on channel always succeeds");
669            Ok(TonicResponse::from(()))
670        });
671        mock.expect_modify_ack_deadline().returning(move |r| {
672            let r = r.into_inner();
673            if r.ack_deadline_seconds == 0 {
674                nack_tx.send(r).expect("sending on channel always succeeds");
675            } else {
676                extend_tx
677                    .send(r)
678                    .expect("sending on channel always succeeds");
679            }
680            Ok(TonicResponse::from(()))
681        });
682        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
683        let client = test_client(endpoint).await?;
684        let mut stream = client
685            .subscribe("projects/p/subscriptions/s")
686            .set_max_lease_extension(Duration::from_secs(10))
687            .set_shutdown_behavior(ShutdownBehavior::NackImmediately)
688            .build();
689
690        response_tx.send(Ok(test_response(0..30))).await?;
691        drop(response_tx);
692
693        // Ack some messages
694        for i in 0..10 {
695            let Some((_, Handler::AtLeastOnce(h))) = stream.next().await.transpose()? else {
696                anyhow::bail!("expected message {i}")
697            };
698            h.ack();
699        }
700        // Nack some messages
701        for i in 10..20 {
702            let Some((_, Handler::AtLeastOnce(h))) = stream.next().await.transpose()? else {
703                anyhow::bail!("expected message {i}")
704            };
705            drop(h);
706        }
707        // Take a long time to process some messages
708        let mut hold = Vec::new();
709        for i in 20..30 {
710            let Some((_, Handler::AtLeastOnce(h))) = stream.next().await.transpose()? else {
711                anyhow::bail!("expected message {i}")
712            };
713            hold.push(h);
714        }
715
716        // Advance the clock 10s, which is the stream ack deadline. In this
717        // time, we should attempt at least one lease extension RPC.
718        tokio::time::advance(Duration::from_secs(10)).await;
719
720        // Close the stream, to make sure pending operations complete.
721        stream.close().await;
722
723        // Verify the acks went through.
724        let ack_req = ack_rx.try_recv()?;
725        assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
726        assert_eq!(sorted(ack_req.ack_ids), test_ids(0..10));
727        assert!(ack_rx.is_empty(), "{ack_rx:?}");
728
729        // Verify the initial nacks went through.
730        let nack_req = nack_rx.try_recv()?;
731        assert_eq!(nack_req.subscription, "projects/p/subscriptions/s");
732        assert_eq!(nack_req.ack_deadline_seconds, 0);
733        assert_eq!(sorted(nack_req.ack_ids), test_ids(10..20));
734
735        // Verify that we nack the leftover messages when the stream shuts down.
736        let nack_req = nack_rx.try_recv()?;
737        assert_eq!(nack_req.subscription, "projects/p/subscriptions/s");
738        assert_eq!(nack_req.ack_deadline_seconds, 0);
739        assert_eq!(sorted(nack_req.ack_ids), test_ids(20..30));
740        assert!(nack_rx.is_empty(), "{nack_rx:?}");
741
742        // Verify at least one lease extension attempt was made.
743        let extend_req = extend_rx.try_recv()?;
744        assert_eq!(extend_req.subscription, "projects/p/subscriptions/s");
745        assert_eq!(extend_req.ack_deadline_seconds, 10);
746        assert_eq!(sorted(extend_req.ack_ids), test_ids(20..30));
747
748        Ok(())
749    }
750
751    #[tokio_test_no_panics(start_paused = true)]
752    async fn delayed_responses() -> anyhow::Result<()> {
753        // In this test, we verify the case where an application asks for a
754        // message, but a response is not immediately available on the stream.
755
756        let (response_tx, response_rx) = channel(10);
757        let handle: JoinHandle<anyhow::Result<()>> = tokio::spawn(async move {
758            tokio::time::sleep(Duration::from_millis(20)).await;
759            response_tx.send(Ok(test_response(1..2))).await?;
760            Ok(())
761        });
762
763        let mut mock = MockSubscriber::new();
764        mock.expect_streaming_pull()
765            .return_once(|_| Ok(TonicResponse::from(response_rx)));
766        mock.expect_modify_ack_deadline()
767            .returning(|_| Ok(TonicResponse::from(())));
768        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
769        let client = test_client(endpoint).await?;
770        let mut stream = client.subscribe("projects/p/subscriptions/s").build();
771        let (m, h) = stream
772            .next()
773            .await
774            .transpose()?
775            .expect("stream should wait for a message");
776        assert_eq!(m.data, test_data(1));
777        assert_eq!(h.ack_id(), test_id(1));
778
779        handle.await??;
780
781        Ok(())
782    }
783
784    #[tokio_test_no_panics]
785    async fn serves_messages_immediately() -> anyhow::Result<()> {
786        // This test verifies we do not do something crazy like draining the
787        // stream (which would never end) before serving messages to the
788        // application.
789
790        let (response_tx, response_rx) = channel(10);
791
792        let mut mock = MockSubscriber::new();
793        mock.expect_streaming_pull()
794            .return_once(|_| Ok(TonicResponse::from(response_rx)));
795        mock.expect_modify_ack_deadline()
796            .returning(|_| Ok(TonicResponse::from(())));
797        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
798        let client = test_client(endpoint).await?;
799        let mut stream = client.subscribe("projects/p/subscriptions/s").build();
800
801        for i in 1..7 {
802            response_tx.send(Ok(test_response(i..i + 1))).await?;
803
804            let Some((m, h)) = stream.next().await.transpose()? else {
805                anyhow::bail!("expected message {i}/6")
806            };
807            assert_eq!(m.data, test_data(i));
808            assert_eq!(h.ack_id(), test_id(i));
809        }
810        drop(response_tx);
811        let end = stream.next().await.transpose()?;
812        assert!(end.is_none(), "Received extra message: {end:?}");
813
814        Ok(())
815    }
816
817    #[tokio_test_no_panics]
818    async fn handles_empty_response() -> anyhow::Result<()> {
819        let (response_tx, response_rx) = channel(10);
820
821        let mut mock = MockSubscriber::new();
822        mock.expect_streaming_pull()
823            .return_once(|_| Ok(TonicResponse::from(response_rx)));
824        mock.expect_modify_ack_deadline()
825            .returning(|_| Ok(TonicResponse::from(())));
826        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
827        let client = test_client(endpoint).await?;
828        let mut stream = client.subscribe("projects/p/subscriptions/s").build();
829
830        response_tx.send(Ok(test_response(1..2))).await?;
831        // See if we can handle an empty range
832        response_tx.send(Ok(test_response(2..2))).await?;
833        response_tx.send(Ok(test_response(2..3))).await?;
834        drop(response_tx);
835
836        for i in 1..3 {
837            let Some((m, h)) = stream.next().await.transpose()? else {
838                anyhow::bail!("expected message {i}/2")
839            };
840            assert_eq!(m.data, test_data(i));
841            assert_eq!(h.ack_id(), test_id(i));
842        }
843        let end = stream.next().await.transpose()?;
844        assert!(end.is_none(), "Received extra message: {end:?}");
845
846        Ok(())
847    }
848
849    #[tokio_test_no_panics(start_paused = true)]
850    async fn handles_missing_message_field() -> anyhow::Result<()> {
851        let (response_tx, response_rx) = channel(10);
852        let (extend_tx, mut extend_rx) = unbounded_channel();
853
854        let bad = v1::StreamingPullResponse {
855            received_messages: vec![v1::ReceivedMessage {
856                ack_id: "ignored-ack-id".to_string(),
857                message: None,
858                ..Default::default()
859            }],
860            ..Default::default()
861        };
862
863        let mut mock = MockSubscriber::new();
864        mock.expect_streaming_pull()
865            .return_once(|_| Ok(TonicResponse::from(response_rx)));
866        mock.expect_modify_ack_deadline().returning(move |r| {
867            let r = r.into_inner();
868            if r.ack_deadline_seconds != 0 {
869                extend_tx
870                    .send(r)
871                    .expect("sending on channel always succeeds");
872            }
873            Ok(TonicResponse::from(()))
874        });
875        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
876        let client = test_client(endpoint).await?;
877        let mut stream = client
878            .subscribe("projects/p/subscriptions/s")
879            .set_max_lease_extension(Duration::from_secs(10))
880            .set_shutdown_behavior(ShutdownBehavior::NackImmediately)
881            .build();
882
883        response_tx.send(Ok(test_response(1..4))).await?;
884        // See if we can handle an empty range
885        response_tx.send(Ok(bad)).await?;
886        response_tx.send(Ok(test_response(4..7))).await?;
887        drop(response_tx);
888
889        let mut handlers = Vec::new();
890        for i in 1..7 {
891            let Some((m, h)) = stream.next().await.transpose()? else {
892                anyhow::bail!("expected message {i}/6")
893            };
894            assert_eq!(m.data, test_data(i));
895            assert_eq!(h.ack_id(), test_id(i));
896            handlers.push(h);
897        }
898        let end = stream.next().await.transpose()?;
899        assert!(end.is_none(), "Received extra message: {end:?}");
900
901        // Advance the clock 10s, which is the stream ack deadline. In this
902        // time, we should attempt at least one lease extension RPC.
903        tokio::time::advance(Duration::from_secs(10)).await;
904
905        // Close the stream, to make sure pending operations complete.
906        stream.close().await;
907
908        // Verify at least one lease extension attempt was made.
909        let extend_req = extend_rx.try_recv()?;
910        assert_eq!(extend_req.subscription, "projects/p/subscriptions/s");
911        assert_eq!(extend_req.ack_deadline_seconds, 10);
912        // Note that we do not expect to see "ignored-ack-id".
913        assert_eq!(sorted(extend_req.ack_ids), test_ids(1..7));
914
915        Ok(())
916    }
917
918    #[tokio_test_no_panics]
919    async fn permanent_error_midstream() -> anyhow::Result<()> {
920        let (response_tx, response_rx) = channel(10);
921
922        let mut mock = MockSubscriber::new();
923        mock.expect_streaming_pull()
924            .return_once(|_| Ok(TonicResponse::from(response_rx)));
925        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
926        let client = test_client(endpoint).await?;
927        let mut stream = client.subscribe("projects/p/subscriptions/s").build();
928
929        response_tx.send(Ok(test_response(1..4))).await?;
930        response_tx
931            .send(Err(TonicStatus::failed_precondition("fail")))
932            .await?;
933        drop(response_tx);
934
935        for i in 1..4 {
936            let Some((m, h)) = stream.next().await.transpose()? else {
937                anyhow::bail!("expected message {i}/3")
938            };
939            assert_eq!(m.data, test_data(i));
940            assert_eq!(h.ack_id(), test_id(i));
941        }
942        let err = stream
943            .next()
944            .await
945            .transpose()
946            .expect_err("expected an error from stream");
947        assert!(err.status().is_some(), "{err:?}");
948        let status = err.status().unwrap();
949        assert_eq!(
950            status.code,
951            google_cloud_gax::error::rpc::Code::FailedPrecondition
952        );
953        assert_eq!(status.message, "fail");
954
955        Ok(())
956    }
957
    #[tokio_test_no_panics(start_paused = true)]
    async fn keepalives() -> anyhow::Result<()> {
        // Verifies that an open stream writes default (empty)
        // `StreamingPullRequest`s every `KEEPALIVE_PERIOD`, and that these
        // writes stop once the `MessageStream` is dropped.
        //
        // We use this channel to surface writes (requests) from outside our
        // mock expectation.
        let (recover_writes_tx, mut recover_writes_rx) = channel(1);
        // `response_tx` is never dropped before the end of the test, so the
        // response stream stays open throughout.
        let (response_tx, response_rx) = channel(10);

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull().return_once(move |request| {
            // Forward every request the client writes on the stream to
            // `recover_writes_rx`, so the test body can inspect them.
            tokio::spawn(async move {
                // Note that this task stays alive as long as we hold
                // `recover_writes_rx`.
                let mut request_rx = request.into_inner();
                while let Some(request) = request_rx.recv().await {
                    recover_writes_tx
                        .send(request)
                        .await
                        .expect("forwarding writes always succeeds");
                }
            });
            Ok(TonicResponse::from(response_rx))
        });
        mock.expect_modify_ack_deadline()
            .returning(|_| Ok(TonicResponse::from(())));

        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut stream = client.subscribe("projects/p/subscriptions/s").build();
        // Pull once from the stream before inspecting the writes it produced.
        response_tx.send(Ok(test_response(1..4))).await?;
        let _ = stream.next().await;

        // The first write is the initial request, which names the
        // subscription.
        let initial_req = recover_writes_rx
            .recv()
            .await
            .expect("should receive an initial request")?;
        assert_eq!(initial_req.subscription, "projects/p/subscriptions/s");

        // Verify that we receive at least one keepalive request on the stream.
        tokio::time::advance(KEEPALIVE_PERIOD).await;
        let keepalive_req = recover_writes_rx
            .recv()
            .await
            .expect("should receive a keepalive request")?;
        assert_eq!(keepalive_req, v1::StreamingPullRequest::default());

        // Drop the stream, which should signal a shutdown of the keepalive
        // task.
        drop(stream);

        // Advance the time far enough to expect a keepalive ping, if the
        // keepalive task was still running.
        tokio::time::advance(4 * KEEPALIVE_PERIOD).await;
        assert!(recover_writes_rx.is_empty(), "{recover_writes_rx:?}");

        Ok(())
    }
1014
1015    #[tokio_test_no_panics]
1016    async fn client_id() -> anyhow::Result<()> {
1017        // We use this channel to surface writes (requests) from outside our
1018        // mock expectation.
1019        let (recover_writes_tx, mut recover_writes_rx) = channel(10);
1020        let recover_writes_tx = std::sync::Arc::new(tokio::sync::Mutex::new(recover_writes_tx));
1021
1022        let mut mock = MockSubscriber::new();
1023        mock.expect_streaming_pull()
1024            .times(3)
1025            .returning(move |request| {
1026                let tx = recover_writes_tx.clone();
1027                tokio::spawn(async move {
1028                    // Note that this task stays alive as long as we hold
1029                    // `recover_writes_rx`.
1030                    let mut request_rx = request.into_inner();
1031                    while let Some(request) = request_rx.recv().await {
1032                        tx.lock()
1033                            .await
1034                            .send(request)
1035                            .await
1036                            .expect("forwarding writes always succeeds");
1037                    }
1038                });
1039                Err(TonicStatus::failed_precondition("fail"))
1040            });
1041
1042        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
1043
1044        // Make two requests with the same client. The requests should have the
1045        // same client ID.
1046        let c1 = test_client(endpoint.clone()).await?;
1047        let _ = c1
1048            .subscribe("projects/p/subscriptions/s")
1049            .build()
1050            .next()
1051            .await;
1052        let req1 = recover_writes_rx
1053            .recv()
1054            .await
1055            .expect("should receive a request")?;
1056        let _ = c1
1057            .subscribe("projects/p/subscriptions/s")
1058            .build()
1059            .next()
1060            .await;
1061        let req2 = recover_writes_rx
1062            .recv()
1063            .await
1064            .expect("should receive a request")?;
1065        assert_eq!(req1.client_id, req2.client_id);
1066
1067        // Make a third request with a different client. This request should
1068        // have a different client ID.
1069        let c2 = test_client(endpoint).await?;
1070        let _ = c2
1071            .subscribe("projects/p/subscriptions/s")
1072            .build()
1073            .next()
1074            .await;
1075        let req3 = recover_writes_rx
1076            .recv()
1077            .await
1078            .expect("should receive a request")?;
1079        assert_ne!(req1.client_id, req3.client_id);
1080
1081        Ok(())
1082    }
1083
1084    #[tokio_test_no_panics(start_paused = true)]
1085    async fn no_immediate_message() -> anyhow::Result<()> {
1086        const TEST_TIMEOUT: Duration = Duration::from_secs(42);
1087
1088        let (_response_tx, response_rx) = channel(10);
1089
1090        let mut mock = MockSubscriber::new();
1091        mock.expect_streaming_pull()
1092            .return_once(move |_| Ok(TonicResponse::from(response_rx)));
1093
1094        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
1095        let client = test_client(endpoint).await?;
1096        let mut stream = client.subscribe("projects/p/subscriptions/s").build();
1097
1098        let _ = tokio::time::timeout(TEST_TIMEOUT, stream.next())
1099            .await
1100            .expect_err("next() should never yield.");
1101
1102        Ok(())
1103    }
1104
1105    #[tokio_test_no_panics(start_paused = true)]
1106    async fn retry_transient_when_starting_stream() -> anyhow::Result<()> {
1107        // The policy should retry forever. Our default retry policies have an
1108        // attempt limit of 10. So we arbitrarily pick a number greater than 10
1109        // for this test.
1110        const NUM_RETRIES: u32 = 20;
1111
1112        let start_time = Instant::now();
1113        let mut seq = mockall::Sequence::new();
1114        let mut mock = MockSubscriber::new();
1115
1116        // Simulate N transient errors
1117        mock.expect_streaming_pull()
1118            .times(NUM_RETRIES as usize)
1119            .in_sequence(&mut seq)
1120            .returning(|_| Err(TonicStatus::unavailable("try again")));
1121        // Simulate a permanent error. Otherwise, we would retry forever.
1122        mock.expect_streaming_pull()
1123            .times(1)
1124            .in_sequence(&mut seq)
1125            .return_once(|_| Err(TonicStatus::failed_precondition("fail")));
1126        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
1127        let client = test_client(endpoint).await?;
1128        let mut stream = client.subscribe("projects/p/subscriptions/s").build();
1129        let err = stream
1130            .next()
1131            .await
1132            .expect("stream should not be empty")
1133            .expect_err("the first streamed item should be an error");
1134        assert!(err.status().is_some(), "{err:?}");
1135        let status = err.status().unwrap();
1136        assert_eq!(
1137            status.code,
1138            google_cloud_gax::error::rpc::Code::FailedPrecondition
1139        );
1140        assert_eq!(status.message, "fail");
1141
1142        let elapsed = start_time.elapsed();
1143        assert!(
1144            elapsed <= MAXIMUM_DELAY * NUM_RETRIES,
1145            "elapsed={elapsed:?}"
1146        );
1147        assert!(
1148            elapsed >= INITIAL_DELAY * NUM_RETRIES,
1149            "elapsed={elapsed:?}"
1150        );
1151
1152        Ok(())
1153    }
1154
1155    #[tokio_test_no_panics(start_paused = true)]
1156    async fn resume_midstream_success() -> anyhow::Result<()> {
1157        let (response_tx_1, response_rx_1) = channel(10);
1158        let (response_tx_2, response_rx_2) = channel(10);
1159        let (response_tx_3, response_rx_3) = channel(10);
1160        let (ack_tx, mut ack_rx) = unbounded_channel();
1161
1162        let mut seq = mockall::Sequence::new();
1163        let mut mock = MockSubscriber::new();
1164        mock.expect_streaming_pull()
1165            .times(1)
1166            .in_sequence(&mut seq)
1167            .return_once(|_| Ok(TonicResponse::from(response_rx_1)));
1168        mock.expect_streaming_pull()
1169            .times(1)
1170            .in_sequence(&mut seq)
1171            .return_once(move |_| Ok(TonicResponse::from(response_rx_2)));
1172        mock.expect_streaming_pull()
1173            .times(1)
1174            .in_sequence(&mut seq)
1175            .return_once(|_| Ok(TonicResponse::from(response_rx_3)));
1176        mock.expect_acknowledge().times(1..).returning(move |r| {
1177            ack_tx
1178                .send(r.into_inner())
1179                .expect("sending on channel always succeeds");
1180            Ok(TonicResponse::from(()))
1181        });
1182        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
1183        let client = test_client(endpoint).await?;
1184        let mut stream = client.subscribe("projects/p/subscriptions/s").build();
1185
1186        response_tx_1.send(Ok(test_response(0..10))).await?;
1187        response_tx_1.send(Ok(test_response(10..20))).await?;
1188        response_tx_1
1189            .send(Err(TonicStatus::unavailable("GFE disconnect. try again")))
1190            .await?;
1191        drop(response_tx_1);
1192        response_tx_2.send(Ok(test_response(20..30))).await?;
1193        response_tx_2.send(Ok(test_response(30..40))).await?;
1194        response_tx_2
1195            .send(Err(TonicStatus::unavailable("GFE disconnect. try again")))
1196            .await?;
1197        drop(response_tx_2);
1198        response_tx_3.send(Ok(test_response(40..50))).await?;
1199        drop(response_tx_3);
1200
1201        for i in 0..50 {
1202            let (m, h) = stream
1203                .next()
1204                .await
1205                .unwrap_or_else(|| panic!("expected message {}/50", i + 1))?;
1206            assert_eq!(m.data, test_data(i));
1207            h.ack();
1208        }
1209        let end = stream.next().await.transpose()?;
1210        assert!(end.is_none(), "Received extra message: {end:?}");
1211
1212        // Wait for the stream to join its background tasks.
1213        stream.close().await;
1214
1215        // Verify the acks went through.
1216        let mut got = Vec::new();
1217        while let Ok(ack_req) = ack_rx.try_recv() {
1218            assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
1219            got.extend(ack_req.ack_ids);
1220        }
1221        assert_eq!(sorted(got), test_ids(0..50));
1222
1223        Ok(())
1224    }
1225
1226    #[tokio_test_no_panics(start_paused = true)]
1227    async fn resume_midstream_hits_permanent_error() -> anyhow::Result<()> {
1228        let (response_tx, response_rx) = channel(10);
1229        let (ack_tx, mut ack_rx) = unbounded_channel();
1230
1231        let mut seq = mockall::Sequence::new();
1232        let mut mock = MockSubscriber::new();
1233        // Start a successful stream, which will eventually disconnect.
1234        mock.expect_streaming_pull()
1235            .times(1)
1236            .in_sequence(&mut seq)
1237            .return_once(|_| Ok(TonicResponse::from(response_rx)));
1238        // Simulate transient errors attempting to resume the stream.
1239        mock.expect_streaming_pull()
1240            .times(3)
1241            .in_sequence(&mut seq)
1242            .returning(|_| Err(TonicStatus::unavailable("try again")));
1243        // Simulate a permanent error attempting to resume the stream.
1244        mock.expect_streaming_pull()
1245            .times(1)
1246            .in_sequence(&mut seq)
1247            .return_once(|_| Err(TonicStatus::failed_precondition("fail")));
1248        mock.expect_acknowledge().times(1..).returning(move |r| {
1249            ack_tx
1250                .send(r.into_inner())
1251                .expect("sending on channel always succeeds");
1252            Ok(TonicResponse::from(()))
1253        });
1254        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
1255        let client = test_client(endpoint).await?;
1256        let mut stream = client.subscribe("projects/p/subscriptions/s").build();
1257
1258        response_tx.send(Ok(test_response(0..10))).await?;
1259        response_tx.send(Ok(test_response(10..20))).await?;
1260        response_tx
1261            .send(Err(TonicStatus::unavailable("GFE disconnect. try again")))
1262            .await?;
1263        drop(response_tx);
1264
1265        for i in 0..20 {
1266            let (m, h) = stream
1267                .next()
1268                .await
1269                .unwrap_or_else(|| panic!("expected message {}/20", i + 1))?;
1270            assert_eq!(m.data, test_data(i));
1271            h.ack();
1272        }
1273        let err = stream
1274            .next()
1275            .await
1276            .transpose()
1277            .expect_err("expected an error from stream");
1278        assert!(err.status().is_some(), "{err:?}");
1279        let status = err.status().unwrap();
1280        assert_eq!(
1281            status.code,
1282            google_cloud_gax::error::rpc::Code::FailedPrecondition
1283        );
1284        assert_eq!(status.message, "fail");
1285
1286        // Wait for the stream to join its background tasks.
1287        stream.close().await;
1288
1289        // Verify the acks went through.
1290        let mut got = Vec::new();
1291        while let Ok(ack_req) = ack_rx.try_recv() {
1292            assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
1293            got.extend(ack_req.ack_ids);
1294        }
1295        assert_eq!(sorted(got), test_ids(0..20));
1296
1297        Ok(())
1298    }
1299
1300    #[tokio_test_no_panics]
1301    async fn routing_header() -> anyhow::Result<()> {
1302        let mut mock = MockSubscriber::new();
1303
1304        mock.expect_streaming_pull().return_once(move |request| {
1305            let metadata = request.metadata();
1306            assert_eq!(
1307                metadata
1308                    .get("x-goog-request-params")
1309                    .expect("routing header missing"),
1310                "subscription=projects/p/subscriptions/s"
1311            );
1312            Err(TonicStatus::failed_precondition("ignored"))
1313        });
1314
1315        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
1316        let client = test_client(endpoint).await?;
1317
1318        let _ = client
1319            .subscribe("projects/p/subscriptions/s")
1320            .build()
1321            .next()
1322            .await;
1323
1324        Ok(())
1325    }
1326
1327    #[cfg(feature = "unstable-stream")]
1328    #[tokio_test_no_panics(start_paused = true)]
1329    async fn into_stream() -> anyhow::Result<()> {
1330        use futures::TryStreamExt;
1331        let (response_tx, response_rx) = channel(10);
1332        let (ack_tx, mut ack_rx) = unbounded_channel();
1333
1334        let mut mock = MockSubscriber::new();
1335        mock.expect_streaming_pull()
1336            .return_once(|_| Ok(TonicResponse::from(response_rx)));
1337        mock.expect_acknowledge().returning(move |r| {
1338            ack_tx
1339                .send(r.into_inner())
1340                .expect("sending on channel always succeeds");
1341            Ok(TonicResponse::from(()))
1342        });
1343
1344        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
1345        let client = test_client(endpoint).await?;
1346
1347        let stream = client
1348            .subscribe("projects/p/subscriptions/s")
1349            .build()
1350            .into_stream();
1351
1352        response_tx.send(Ok(test_response(1..3))).await?;
1353        drop(response_tx);
1354
1355        let got: Vec<_> = stream
1356            .map_ok(|(m, h)| {
1357                h.ack();
1358                m.data
1359            })
1360            .try_collect()
1361            .await?;
1362        assert_eq!(got, vec![test_data(1), test_data(2)]);
1363
1364        let ack_req = ack_rx
1365            .recv()
1366            .await
1367            .expect("should receive acknowledgements");
1368        assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
1369        assert_eq!(sorted(ack_req.ack_ids), test_ids(1..3));
1370
1371        Ok(())
1372    }
1373
    #[tokio_test_no_panics(start_paused = true)]
    async fn basic_lease_expiration() -> anyhow::Result<()> {
        // Verifies that a message the application holds without acking stops
        // receiving lease extensions once its configured max lease is reached.
        const MAX_LEASE_EXTENSION: Duration = Duration::from_secs(10);
        const MAX_LEASE: Duration = Duration::from_secs(30);
        // We configure a max lease for this test (30s) that differs from the
        // default (600s) to verify that an application's configuration
        // overrides the default.

        let start_time = Instant::now();
        let (response_tx, response_rx) = channel(10);
        // Every `ModifyAckDeadline` request the mock receives is forwarded
        // here for inspection.
        let (extend_tx, mut extend_rx) = unbounded_channel();

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Ok(TonicResponse::from(response_rx)));
        mock.expect_modify_ack_deadline().returning(move |r| {
            extend_tx
                .send(r.into_inner())
                .expect("sending on channel always succeeds");
            Ok(TonicResponse::from(()))
        });
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut stream = client
            .subscribe("projects/p/subscriptions/s")
            .set_max_lease(MAX_LEASE)
            .set_max_lease_extension(MAX_LEASE_EXTENSION)
            .set_shutdown_behavior(ShutdownBehavior::NackImmediately)
            .build();

        response_tx.send(Ok(test_response(0..1))).await?;
        drop(response_tx);

        // Bind the handler (never acking it) so the message stays leased
        // for the rest of the test.
        let (_m, _h) = stream
            .next()
            .await
            .expect("stream should yield a message")?;

        // Advance the clock well past the expected message expiration,
        // recording the time at which we sent the last lease extension.
        // Requests are drained once per simulated second, so `latest` has
        // one-second resolution.
        let mut latest = None;
        for _ in 0..MAX_LEASE.as_secs() * 2 {
            while let Ok(r) = extend_rx.try_recv() {
                // A zero deadline is a nack; none is expected before the
                // stream is closed.
                assert_ne!(r.ack_deadline_seconds, 0, "unexpectedly received a nack");
                latest = Some(start_time.elapsed());
            }
            tokio::time::advance(Duration::from_secs(1)).await;
            // Yield so other tasks (e.g. the lease loop) can run after each
            // clock tick.
            tokio::task::yield_now().await;
        }

        // Verify when we stop sending lease extensions: the last one is
        // expected within one `MAX_LEASE_EXTENSION` before `MAX_LEASE`.
        let expected_range = (MAX_LEASE - MAX_LEASE_EXTENSION)..=MAX_LEASE;
        assert!(
            latest.is_some_and(|t| expected_range.contains(&t)),
            "{latest:?}"
        );

        // Close the stream, to make sure pending operations complete.
        stream.close().await;

        Ok(())
    }
1436
1437    #[tokio_test_no_panics(start_paused = true)]
1438    async fn shutdown_wait_for_processing() -> anyhow::Result<()> {
1439        let (response_tx, response_rx) = channel(10);
1440
1441        let mut mock = MockSubscriber::new();
1442        mock.expect_streaming_pull()
1443            .return_once(|_| Ok(TonicResponse::from(response_rx)));
1444        mock.expect_acknowledge()
1445            .times(1)
1446            .returning(|_| Ok(TonicResponse::from(())));
1447        mock.expect_modify_ack_deadline()
1448            .returning(|_| Ok(TonicResponse::from(())));
1449        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
1450        let client = test_client(endpoint).await?;
1451        let mut stream = client
1452            .subscribe("projects/p/subscriptions/s")
1453            .set_shutdown_behavior(ShutdownBehavior::WaitForProcessing)
1454            .build();
1455
1456        response_tx.send(Ok(test_response(0..1))).await?;
1457        drop(response_tx);
1458
1459        let (_m, h) = stream
1460            .next()
1461            .await
1462            .expect("stream should yield a message")?;
1463
1464        tokio::spawn(async move {
1465            // Delay the ack until after the shutdown is signaled. It should
1466            // still go through.
1467            tokio::time::sleep(Duration::from_secs(5)).await;
1468            h.ack();
1469        });
1470
1471        // Close the stream, to make sure pending operations complete.
1472        stream.close().await;
1473
1474        Ok(())
1475    }
1476
1477    #[tokio_test_no_panics(start_paused = true)]
1478    async fn at_least_once_and_exactly_once() -> anyhow::Result<()> {
1479        let (response_tx, response_rx) = channel(10);
1480
1481        let mut mock = MockSubscriber::new();
1482        mock.expect_streaming_pull()
1483            .return_once(|_| Ok(TonicResponse::from(response_rx)));
1484        mock.expect_modify_ack_deadline()
1485            .returning(|_| Ok(TonicResponse::from(())));
1486        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
1487        let client = test_client(endpoint).await?;
1488        let mut stream = client
1489            .subscribe("projects/p/subscriptions/s")
1490            .set_shutdown_behavior(ShutdownBehavior::NackImmediately)
1491            .build();
1492
1493        response_tx.send(Ok(test_response(0..1))).await?;
1494        response_tx
1495            .send(Ok(test_exactly_once_response(1..2)))
1496            .await?;
1497        response_tx.send(Ok(test_response(2..3))).await?;
1498        response_tx
1499            .send(Ok(test_exactly_once_response(3..4)))
1500            .await?;
1501        drop(response_tx);
1502
1503        let (m, h) = stream.next().await.expect("should yield a message")?;
1504        assert_eq!(m.data, test_data(0));
1505        assert_eq!(h.ack_id(), test_id(0));
1506        assert!(matches!(h, Handler::AtLeastOnce(_)), "{h:?}");
1507
1508        let (m, h) = stream.next().await.expect("should yield a message")?;
1509        assert_eq!(m.data, test_data(1));
1510        assert_eq!(h.ack_id(), test_id(1));
1511        assert!(matches!(h, Handler::ExactlyOnce(_)), "{h:?}");
1512
1513        let (m, h) = stream.next().await.expect("should yield a message")?;
1514        assert_eq!(m.data, test_data(2));
1515        assert_eq!(h.ack_id(), test_id(2));
1516        assert!(matches!(h, Handler::AtLeastOnce(_)), "{h:?}");
1517
1518        let (m, h) = stream.next().await.expect("should yield a message")?;
1519        assert_eq!(m.data, test_data(3));
1520        assert_eq!(h.ack_id(), test_id(3));
1521        assert!(matches!(h, Handler::ExactlyOnce(_)), "{h:?}");
1522
1523        let end = stream.next().await.transpose()?;
1524        assert!(end.is_none(), "Received extra message: {end:?}");
1525
1526        // Wait for the stream to join its background tasks.
1527        stream.close().await;
1528
1529        Ok(())
1530    }
1531
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn cancel_before_open() -> anyhow::Result<()> {
        // Verifies that cancelling the shutdown token while the stream is
        // still trying to open ends the stream with `None`. Every open
        // attempt fails with a transient error, so it would otherwise keep
        // retrying.
        //
        // NOTE(review): this test uses a multi-threaded runtime instead of
        // `tokio_test_no_panics`, presumably so the spawned `next()` can run
        // in parallel with the cancellation. Confirm.
        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .returning(|_| Err(TonicStatus::unavailable("try again")));
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut stream = client.subscribe("projects/p/subscriptions/s").build();
        // TODO(#5024) - use public functions when available.
        let shutdown_token = stream.shutdown.clone();

        // Pull on a separate task so the token can be cancelled while
        // `next()` is pending (or before it starts; either must yield `None`).
        let next = tokio::spawn(async move { stream.next().await });
        shutdown_token.cancel();

        let end = next.await?;
        assert!(end.is_none(), "Shutdown should end the stream, got {end:?}");

        Ok(())
    }
1551
1552    #[tokio_test_no_panics(start_paused = true)]
1553    async fn cancel_midstream() -> anyhow::Result<()> {
1554        let (response_tx, response_rx) = channel(10);
1555        let (ack_tx, mut ack_rx) = unbounded_channel();
1556        let (nack_tx, mut nack_rx) = unbounded_channel();
1557
1558        let mut mock = MockSubscriber::new();
1559        mock.expect_streaming_pull()
1560            .return_once(|_| Ok(TonicResponse::from(response_rx)));
1561        mock.expect_acknowledge().times(1).returning(move |r| {
1562            ack_tx
1563                .send(r.into_inner())
1564                .expect("sending on channel always succeeds");
1565            Ok(TonicResponse::from(()))
1566        });
1567        mock.expect_modify_ack_deadline()
1568            .times(1)
1569            .returning(move |r| {
1570                nack_tx
1571                    .send(r.into_inner())
1572                    .expect("sending on channel always succeeds");
1573                Ok(TonicResponse::from(()))
1574            });
1575        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
1576        let client = test_client(endpoint).await?;
1577        let mut stream = client
1578            .subscribe("projects/p/subscriptions/s")
1579            .set_shutdown_behavior(ShutdownBehavior::WaitForProcessing)
1580            .build();
1581        // TODO(#5024) - use public functions when available.
1582        let shutdown_token = stream.shutdown.clone();
1583
1584        response_tx.send(Ok(test_response(1..10))).await?;
1585        for i in 1..6 {
1586            let Some((m, h)) = stream.next().await.transpose()? else {
1587                anyhow::bail!("expected message {i}/5")
1588            };
1589            assert_eq!(m.data, test_data(i));
1590            h.ack();
1591        }
1592        shutdown_token.cancel();
1593        let end = stream.next().await.transpose()?;
1594        assert!(end.is_none(), "Shutdown should end the stream, got {end:?}");
1595
1596        // Verify that we drop the messages and handles in the pool that we have
1597        // not returned to the application yet.
1598        stream.close().await;
1599
1600        let ack_req = ack_rx.try_recv()?;
1601        assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
1602        assert_eq!(sorted(ack_req.ack_ids), test_ids(1..6));
1603
1604        let nack_req = nack_rx.try_recv()?;
1605        assert_eq!(nack_req.subscription, "projects/p/subscriptions/s");
1606        assert_eq!(nack_req.ack_deadline_seconds, 0);
1607        assert_eq!(sorted(nack_req.ack_ids), test_ids(6..10));
1608
1609        Ok(())
1610    }
1611}