google_cloud_pubsub/subscriber/
session.rs

1// Copyright 2026 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::builder::StreamingPull;
16use super::handler::{AckResult, AtLeastOnce, Handler};
17use super::lease_loop::LeaseLoop;
18use super::lease_state::LeaseOptions;
19use super::leaser::DefaultLeaser;
20use super::stream::open_stream;
21use super::stub::{Stub, TonicStreaming};
22use super::transport::Transport;
23use crate::google::pubsub::v1::StreamingPullRequest;
24use crate::model::PubsubMessage;
25use crate::{Error, Result};
26use gaxi::grpc::from_status::to_gax_error;
27use gaxi::prost::FromProto as _;
28use std::collections::VecDeque;
29use std::sync::Arc;
30use tokio::sync::mpsc::UnboundedSender;
31use tokio_util::sync::{CancellationToken, DropGuard};
32
/// Represents an open subscribe session.
///
/// This is a stream-like struct for serving messages to an application.
/// Messages are served one at a time, in the order they were received, via
/// [Session::next].
///
/// # Example
/// ```
/// # use google_cloud_pubsub::client::Subscriber;
/// # async fn sample(client: Subscriber) -> anyhow::Result<()> {
/// let mut session = client
///     .streaming_pull("projects/my-project/subscriptions/my-subscription")
///     .start();
/// while let Some((m, h)) = session.next().await.transpose()? {
///     println!("Received message m={m:?}");
///     h.ack();
/// }
/// # Ok(()) }
/// ```
#[derive(Debug)]
pub struct Session {
    /// The stub implementing this struct.
    ///
    /// It is used to open the streaming pull, and a clone of it is shared with
    /// the leaser that runs inside the lease management task.
    inner: Arc<Transport>,

    /// The initial request used to start a stream.
    ///
    /// It carries the subscription, the stream ack deadline, the flow control
    /// settings, and the client ID.
    initial_req: StreamingPullRequest,

    /// The bidirectional stream.
    ///
    /// We choose to lazy-initialize the stream when the application asks for a
    /// message because tonic will not yield the stream to us until the first
    /// response is available.[^1]
    ///
    /// The usability of the `Session` API would suffer if creating an instance
    /// of `Session` is blocked on the first message being available.
    ///
    /// [^1]: <https://github.com/hyperium/tonic/issues/515>
    stream: Option<<Transport as Stub>::Stream>,

    /// Applications ask for messages one at a time. Individual stream responses
    /// can contain multiple messages. We use `pool` to hold the extra messages
    /// while we wait to serve them to applications.
    ///
    /// A FIFO queue is necessary to preserve ordering.
    pool: VecDeque<(PubsubMessage, Handler)>,

    /// A sender for sending new messages from the stream into the lease
    /// management task.
    ///
    /// Only the ack IDs are sent. Closing this channel signals the lease
    /// management task to shut down.
    message_tx: UnboundedSender<String>,

    /// A sender for forwarding acks/nacks from the application to the lease
    /// management task. Each `Handler` holds a clone of this.
    ack_tx: UnboundedSender<AckResult>,

    /// A cancellation token for signalling a shutdown to the task sending
    /// keepalive pings.
    ///
    /// A clone of it is handed to `open_stream` whenever a stream is opened.
    shutdown: CancellationToken,

    /// A guard which signals a shutdown to the task sending keepalive pings
    /// when it is dropped. It is more convenient to hold a `DropGuard` than to
    /// have a custom `impl Drop for Session`.
    _keepalive_guard: DropGuard,

    /// A handle on the lease loop task.
    ///
    /// We hold onto this handle so we can await pending lease operations. While awaiting pending
    /// lease operations is useful for setting expectations in our unit tests, it is not that
    /// helpful to applications in practice.
    _lease_loop: tokio::task::JoinHandle<()>,
}
101
102impl Session {
103    pub(super) fn new(builder: StreamingPull) -> Self {
104        let shutdown = CancellationToken::new();
105        let inner = builder.inner;
106        let subscription = builder.subscription;
107
108        let leaser = DefaultLeaser::new(
109            inner.clone(),
110            subscription.clone(),
111            builder.ack_deadline_seconds,
112        );
113        let LeaseLoop {
114            handle: _lease_loop,
115            message_tx,
116            ack_tx,
117        } = LeaseLoop::new(leaser, LeaseOptions::default());
118
119        let initial_req = StreamingPullRequest {
120            subscription,
121            stream_ack_deadline_seconds: builder.ack_deadline_seconds,
122            max_outstanding_messages: builder.max_outstanding_messages,
123            max_outstanding_bytes: builder.max_outstanding_bytes,
124            client_id: builder.client_id,
125            ..Default::default()
126        };
127
128        Self {
129            inner,
130            initial_req,
131            stream: None,
132            pool: VecDeque::new(),
133            message_tx,
134            ack_tx,
135            shutdown: shutdown.clone(),
136            _keepalive_guard: shutdown.drop_guard(),
137            _lease_loop,
138        }
139    }
140
141    /// Returns the next message received on this subscription.
142    ///
143    /// The message data is returned along with a [Handler] for acknowledging
144    /// (ack) or rejecting (nack) the message.
145    ///
146    /// If the underlying stream encounters a permanent error, an `Error` is
147    /// returned instead.
148    ///
149    /// `None` represents the end of a stream, but in practice, the stream stays
150    /// open until it is cancelled or encounters a permanent error.
151    ///
152    /// # Example
153    /// ```
154    /// # use google_cloud_pubsub::subscriber::session::Session;
155    /// # async fn sample(mut session: Session) -> anyhow::Result<()> {
156    /// while let Some((m, h)) = session.next().await.transpose()? {
157    ///     println!("Received message m={m:?}");
158    ///     h.ack();
159    /// }
160    /// # Ok(()) }
161    /// ```
162    pub async fn next(&mut self) -> Option<Result<(PubsubMessage, Handler)>> {
163        loop {
164            // Serve a message if we have one ready.
165            if let Some(item) = self.pool.pop_front() {
166                return Some(Ok(item));
167            }
168            // Otherwise, read more messages from the stream.
169            if let Err(e) = self.stream_next().await? {
170                return Some(Err(e));
171            }
172        }
173    }
174
175    /// Returns a mutable reference to the underlying stream.
176    ///
177    /// If a stream is not yet open, this method opens the stream.
178    async fn mut_stream(&mut self) -> Result<&mut <Transport as Stub>::Stream> {
179        if self.stream.is_none() {
180            let stream = open_stream(
181                self.inner.clone(),
182                self.initial_req.clone(),
183                self.shutdown.clone(),
184            )
185            .await?;
186            self.stream = Some(stream);
187        }
188        Ok(self
189            .stream
190            .as_mut()
191            .expect("`self.stream.is_some()` must be true"))
192    }
193
194    async fn stream_next(&mut self) -> Option<Result<()>> {
195        let resp = {
196            let stream = match self.mut_stream().await {
197                Ok(s) => s,
198                // TODO(#4097) - support stream retries / resumes.
199                Err(e) => return Some(Err(e)),
200            };
201
202            match stream.next_message().await.transpose()? {
203                Ok(resp) => resp,
204                Err(e) => return Some(Err(to_gax_error(e))),
205            }
206        };
207
208        for rm in resp.received_messages {
209            let Some(message) = rm.message else {
210                // The message field should always be present. If not, the proto
211                // message was corrupted while in transit, or there is a bug in
212                // the service.
213                //
214                // The client can just ignore an ack ID without an associated
215                // message.
216                continue;
217            };
218            let _ = self.message_tx.send(rm.ack_id.clone());
219            let message = match message.cnv().map_err(Error::deser) {
220                Ok(message) => message,
221                Err(e) => return Some(Err(e)),
222            };
223            self.pool.push_back((
224                message,
225                Handler::AtLeastOnce(AtLeastOnce {
226                    ack_id: rm.ack_id,
227                    ack_tx: self.ack_tx.clone(),
228                }),
229            ));
230        }
231        Some(Ok(()))
232    }
233
234    #[cfg(test)]
235    /// Close the session, awaiting all pending acks and nacks.
236    ///
237    /// This is a useful method for setting clean test expectations.
238    async fn close(self) -> anyhow::Result<()> {
239        // Signal a shutdown to the keepalive task.
240        drop(self._keepalive_guard);
241
242        // Signal a shutdown to the lease management background task.
243        drop(self.message_tx);
244
245        // Wait for the lease management task to complete.
246        self._lease_loop.await?;
247
248        Ok(())
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::super::client::Subscriber;
    use super::super::keepalive::KEEPALIVE_PERIOD;
    use super::super::lease_state::tests::{test_id, test_ids};
    use super::*;
    use auth::credentials::anonymous::Builder as Anonymous;
    use pubsub_grpc_mock::google::pubsub::v1;
    use pubsub_grpc_mock::{MockSubscriber, start};
    use tokio::sync::mpsc::{channel, unbounded_channel};
    use tokio::task::JoinHandle;
    use tokio::time::Duration;

    /// Returns `v` sorted, so ack ID lists compare equal regardless of the
    /// order in which they were batched.
    fn sorted(mut v: Vec<String>) -> Vec<String> {
        v.sort();
        v
    }

    /// Returns the payload of the test message with index `v`.
    fn test_data(v: i32) -> bytes::Bytes {
        bytes::Bytes::from(format!("data-{}", test_id(v)))
    }

    /// Returns a response containing one message per index in `range`.
    fn test_response(range: std::ops::Range<i32>) -> v1::StreamingPullResponse {
        v1::StreamingPullResponse {
            // `Range` is already an iterator; no `into_iter()` needed.
            received_messages: range
                .map(|i| v1::ReceivedMessage {
                    ack_id: test_id(i),
                    message: Some(v1::PubsubMessage {
                        data: test_data(i).to_vec(),
                        ..Default::default()
                    }),
                    ..Default::default()
                })
                .collect(),
            ..Default::default()
        }
    }

    /// Returns a client pointed at the mock server at `endpoint`.
    async fn test_client(endpoint: String) -> anyhow::Result<Subscriber> {
        Ok(Subscriber::builder()
            .with_endpoint(endpoint)
            .with_credentials(Anonymous::new().build())
            .build()
            .await?)
    }

    #[tokio::test]
    async fn error_starting_stream() -> anyhow::Result<()> {
        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Err(tonic::Status::internal("fail")));
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();
        let err = session
            .next()
            .await
            .expect("stream should not be empty")
            .expect_err("the first streamed item should be an error");
        assert!(err.status().is_some(), "{err:?}");
        let status = err.status().unwrap();
        assert_eq!(status.code, gax::error::rpc::Code::Internal);
        assert_eq!(status.message, "fail");

        Ok(())
    }

    #[tokio::test]
    async fn initial_request() -> anyhow::Result<()> {
        const MIB: i64 = 1024 * 1024;

        // We use this channel to surface writes (requests) from outside our
        // mock expectation.
        let (recover_writes_tx, mut recover_writes_rx) = channel(1);

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull().return_once(move |request| {
            tokio::spawn(async move {
                // Note that this task stays alive as long as we hold
                // `recover_writes_rx`.
                let mut request_rx = request.into_inner();
                while let Some(request) = request_rx.recv().await {
                    recover_writes_tx
                        .send(request)
                        .await
                        .expect("forwarding writes always succeeds");
                }
            });
            Err(tonic::Status::internal("fail"))
        });

        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let _ = client
            .streaming_pull("projects/p/subscriptions/s")
            .set_ack_deadline_seconds(20)
            .set_max_outstanding_messages(2000)
            .set_max_outstanding_bytes(200 * MIB)
            .start()
            .next()
            .await;

        let initial_req = recover_writes_rx
            .recv()
            .await
            .expect("should receive a request")?;
        assert_eq!(initial_req.subscription, "projects/p/subscriptions/s");
        assert_eq!(initial_req.stream_ack_deadline_seconds, 20);
        assert_eq!(initial_req.max_outstanding_messages, 2000);
        assert_eq!(initial_req.max_outstanding_bytes, 200 * MIB);
        assert!(!initial_req.client_id.is_empty());

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    async fn basic_success() -> anyhow::Result<()> {
        let (response_tx, response_rx) = channel(10);
        let (ack_tx, mut ack_rx) = unbounded_channel();

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Ok(tonic::Response::from(response_rx)));
        mock.expect_acknowledge().returning(move |r| {
            ack_tx
                .send(r.into_inner())
                .expect("sending on channel always succeeds");
            Ok(tonic::Response::from(()))
        });
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();

        response_tx.send(Ok(test_response(1..2))).await?;
        response_tx.send(Ok(test_response(2..4))).await?;
        response_tx.send(Ok(test_response(4..7))).await?;
        drop(response_tx);

        for i in 1..7 {
            // `Option::expect` does not interpolate `{i}`; `panic!` does.
            let (m, Handler::AtLeastOnce(h)) = session
                .next()
                .await
                .transpose()?
                .unwrap_or_else(|| panic!("message {i}/6"));
            assert_eq!(m.data, test_data(i));
            assert_eq!(h.ack_id, test_id(i));
            h.ack();
        }
        let end = session.next().await.transpose()?;
        assert!(end.is_none(), "Received extra message: {end:?}");

        // Wait for the session to join its background tasks.
        session.close().await?;

        // Verify the acks went through.
        let ack_req = ack_rx.try_recv()?;
        assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
        assert_eq!(sorted(ack_req.ack_ids), test_ids(1..7));

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    async fn basic_lease_management() -> anyhow::Result<()> {
        let (response_tx, response_rx) = channel(10);
        let (ack_tx, mut ack_rx) = unbounded_channel();
        let (nack_tx, mut nack_rx) = unbounded_channel();
        let (extend_tx, mut extend_rx) = unbounded_channel();

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Ok(tonic::Response::from(response_rx)));
        mock.expect_acknowledge().returning(move |r| {
            ack_tx
                .send(r.into_inner())
                .expect("sending on channel always succeeds");
            Ok(tonic::Response::from(()))
        });
        mock.expect_modify_ack_deadline().returning(move |r| {
            let r = r.into_inner();
            if r.ack_deadline_seconds == 0 {
                nack_tx.send(r).expect("sending on channel always succeeds");
            } else {
                extend_tx
                    .send(r)
                    .expect("sending on channel always succeeds");
            }
            Ok(tonic::Response::from(()))
        });
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();

        response_tx.send(Ok(test_response(0..30))).await?;
        drop(response_tx);

        // Ack some messages
        for i in 0..10 {
            let Some((_, Handler::AtLeastOnce(h))) = session.next().await.transpose()? else {
                anyhow::bail!("expected message {i}")
            };
            h.ack();
        }
        // Nack some messages
        for i in 10..20 {
            let Some((_, Handler::AtLeastOnce(h))) = session.next().await.transpose()? else {
                anyhow::bail!("expected message {i}")
            };
            h.nack();
        }
        // Take a long time to process some messages
        let mut hold = Vec::new();
        for i in 20..30 {
            let Some((_, Handler::AtLeastOnce(h))) = session.next().await.transpose()? else {
                anyhow::bail!("expected message {i}")
            };
            hold.push(h);
        }

        // Advance the clock 10s, which is the default stream ack deadline. In
        // this time, we should attempt at least one lease extension RPC.
        tokio::time::advance(Duration::from_secs(10)).await;

        // Close the session, to make sure pending operations complete.
        session.close().await?;

        // Verify the acks went through.
        let ack_req = ack_rx.try_recv()?;
        assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
        assert_eq!(sorted(ack_req.ack_ids), test_ids(0..10));
        assert!(ack_rx.is_empty());

        // Verify the initial nacks went through.
        let nack_req = nack_rx.try_recv()?;
        assert_eq!(nack_req.subscription, "projects/p/subscriptions/s");
        assert_eq!(nack_req.ack_deadline_seconds, 0);
        assert_eq!(sorted(nack_req.ack_ids), test_ids(10..20));

        // Verify that we nack the leftover messages when the stream shuts down.
        let nack_req = nack_rx.try_recv()?;
        assert_eq!(nack_req.subscription, "projects/p/subscriptions/s");
        assert_eq!(nack_req.ack_deadline_seconds, 0);
        assert_eq!(sorted(nack_req.ack_ids), test_ids(20..30));
        assert!(nack_rx.is_empty());

        // Verify at least one lease extension attempt was made.
        let extend_req = extend_rx.try_recv()?;
        assert_eq!(extend_req.subscription, "projects/p/subscriptions/s");
        assert_eq!(extend_req.ack_deadline_seconds, 10);
        assert_eq!(sorted(extend_req.ack_ids), test_ids(20..30));

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    async fn delayed_responses() -> anyhow::Result<()> {
        // In this test, we verify the case where an application asks for a
        // message, but a response is not immediately available on the stream.

        let (response_tx, response_rx) = channel(10);
        let handle: JoinHandle<anyhow::Result<()>> = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(20)).await;
            response_tx.send(Ok(test_response(1..2))).await?;
            Ok(())
        });

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Ok(tonic::Response::from(response_rx)));
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();
        let (m, Handler::AtLeastOnce(h)) = session
            .next()
            .await
            .transpose()?
            .expect("stream should wait for a message");
        assert_eq!(m.data, test_data(1));
        assert_eq!(h.ack_id, test_id(1));

        handle.await??;

        Ok(())
    }

    #[tokio::test]
    async fn serves_messages_immediately() -> anyhow::Result<()> {
        // This test verifies we do not do something crazy like draining the
        // stream (which would never end) before serving messages to the
        // application.

        let (response_tx, response_rx) = channel(10);

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Ok(tonic::Response::from(response_rx)));
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();

        for i in 1..7 {
            response_tx.send(Ok(test_response(i..i + 1))).await?;

            let (m, Handler::AtLeastOnce(h)) = session
                .next()
                .await
                .transpose()?
                .unwrap_or_else(|| panic!("message {i}/6"));
            assert_eq!(m.data, test_data(i));
            assert_eq!(h.ack_id, test_id(i));
        }
        drop(response_tx);
        let end = session.next().await.transpose()?;
        assert!(end.is_none(), "Received extra message: {end:?}");

        Ok(())
    }

    #[tokio::test]
    async fn handles_empty_response() -> anyhow::Result<()> {
        let (response_tx, response_rx) = channel(10);

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Ok(tonic::Response::from(response_rx)));
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();

        response_tx.send(Ok(test_response(1..2))).await?;
        // See if we can handle an empty range
        response_tx.send(Ok(test_response(2..2))).await?;
        response_tx.send(Ok(test_response(2..3))).await?;
        drop(response_tx);

        for i in 1..3 {
            let (m, Handler::AtLeastOnce(h)) = session
                .next()
                .await
                .transpose()?
                .unwrap_or_else(|| panic!("message {i}/2"));
            assert_eq!(m.data, test_data(i));
            assert_eq!(h.ack_id, test_id(i));
        }
        let end = session.next().await.transpose()?;
        assert!(end.is_none(), "Received extra message: {end:?}");

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    async fn handles_missing_message_field() -> anyhow::Result<()> {
        let (response_tx, response_rx) = channel(10);
        let (extend_tx, mut extend_rx) = unbounded_channel();

        let bad = v1::StreamingPullResponse {
            received_messages: vec![v1::ReceivedMessage {
                ack_id: "ignored-ack-id".to_string(),
                message: None,
                ..Default::default()
            }],
            ..Default::default()
        };

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Ok(tonic::Response::from(response_rx)));
        mock.expect_acknowledge()
            .returning(|_| Ok(tonic::Response::from(())));
        mock.expect_modify_ack_deadline().returning(move |r| {
            extend_tx
                .send(r.into_inner())
                .expect("sending on channel always succeeds");
            Ok(tonic::Response::from(()))
        });
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();

        response_tx.send(Ok(test_response(1..4))).await?;
        // See if we can handle a received message without a `message` field.
        response_tx.send(Ok(bad)).await?;
        response_tx.send(Ok(test_response(4..7))).await?;
        drop(response_tx);

        for i in 1..7 {
            let (m, Handler::AtLeastOnce(h)) = session
                .next()
                .await
                .transpose()?
                .unwrap_or_else(|| panic!("message {i}/6"));
            assert_eq!(m.data, test_data(i));
            assert_eq!(h.ack_id, test_id(i));
        }
        let end = session.next().await.transpose()?;
        assert!(end.is_none(), "Received extra message: {end:?}");

        // Advance the clock 10s, which is the default stream ack deadline. In
        // this time, we should attempt at least one lease extension RPC.
        tokio::time::advance(Duration::from_secs(10)).await;

        // Close the session, to make sure pending operations complete.
        session.close().await?;

        // Verify at least one lease extension attempt was made.
        let extend_req = extend_rx.try_recv()?;
        assert_eq!(extend_req.subscription, "projects/p/subscriptions/s");
        assert_eq!(extend_req.ack_deadline_seconds, 10);
        // Note that we do not expect to see "ignored-ack-id".
        assert_eq!(sorted(extend_req.ack_ids), test_ids(1..7));

        Ok(())
    }

    #[tokio::test]
    async fn permanent_error_midstream() -> anyhow::Result<()> {
        let (response_tx, response_rx) = channel(10);

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(|_| Ok(tonic::Response::from(response_rx)));
        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();

        response_tx.send(Ok(test_response(1..4))).await?;
        response_tx
            .send(Err(tonic::Status::internal("fail")))
            .await?;
        drop(response_tx);

        for i in 1..4 {
            let (m, Handler::AtLeastOnce(h)) = session
                .next()
                .await
                .transpose()?
                .unwrap_or_else(|| panic!("message {i}/3"));
            assert_eq!(m.data, test_data(i));
            assert_eq!(h.ack_id, test_id(i));
        }
        let err = session
            .next()
            .await
            .transpose()
            .expect_err("expected an error from stream");
        assert!(err.status().is_some(), "{err:?}");
        let status = err.status().unwrap();
        assert_eq!(status.code, gax::error::rpc::Code::Internal);
        assert_eq!(status.message, "fail");

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    async fn keepalives() -> anyhow::Result<()> {
        // We use this channel to surface writes (requests) from outside our
        // mock expectation.
        let (recover_writes_tx, mut recover_writes_rx) = channel(1);
        let (response_tx, response_rx) = channel(10);

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull().return_once(move |request| {
            tokio::spawn(async move {
                // Note that this task stays alive as long as we hold
                // `recover_writes_rx`.
                let mut request_rx = request.into_inner();
                while let Some(request) = request_rx.recv().await {
                    recover_writes_tx
                        .send(request)
                        .await
                        .expect("forwarding writes always succeeds");
                }
            });
            Ok(tonic::Response::from(response_rx))
        });

        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();
        response_tx.send(Ok(test_response(1..4))).await?;
        let _ = session.next().await;

        let initial_req = recover_writes_rx
            .recv()
            .await
            .expect("should receive an initial request")?;
        assert_eq!(initial_req.subscription, "projects/p/subscriptions/s");

        // Verify that we receive at least one keepalive request on the stream.
        tokio::time::advance(KEEPALIVE_PERIOD).await;
        let keepalive_req = recover_writes_rx
            .recv()
            .await
            .expect("should receive a keepalive request")?;
        assert_eq!(keepalive_req, v1::StreamingPullRequest::default());

        // Drop the session, which should signal a shutdown of the keepalive
        // task.
        drop(session);

        // Advance the time far enough to expect a keepalive ping, if the
        // keepalive task was still running.
        tokio::time::advance(4 * KEEPALIVE_PERIOD).await;
        assert!(recover_writes_rx.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn client_id() -> anyhow::Result<()> {
        // We use this channel to surface writes (requests) from outside our
        // mock expectation.
        let (recover_writes_tx, mut recover_writes_rx) = channel(10);

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .times(3)
            .returning(move |request| {
                // `mpsc::Sender` is `Clone`; each stream gets its own handle,
                // so no shared lock is needed.
                let tx = recover_writes_tx.clone();
                tokio::spawn(async move {
                    // Note that this task stays alive as long as we hold
                    // `recover_writes_rx`.
                    let mut request_rx = request.into_inner();
                    while let Some(request) = request_rx.recv().await {
                        tx.send(request)
                            .await
                            .expect("forwarding writes always succeeds");
                    }
                });
                Err(tonic::Status::internal("fail"))
            });

        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;

        // Make two requests with the same client. The requests should have the
        // same client ID.
        let c1 = test_client(endpoint.clone()).await?;
        let _ = c1
            .streaming_pull("projects/p/subscriptions/s")
            .start()
            .next()
            .await;
        let req1 = recover_writes_rx
            .recv()
            .await
            .expect("should receive a request")?;
        let _ = c1
            .streaming_pull("projects/p/subscriptions/s")
            .start()
            .next()
            .await;
        let req2 = recover_writes_rx
            .recv()
            .await
            .expect("should receive a request")?;
        assert_eq!(req1.client_id, req2.client_id);

        // Make a third request with a different client. This request should
        // have a different client ID.
        let c2 = test_client(endpoint).await?;
        let _ = c2
            .streaming_pull("projects/p/subscriptions/s")
            .start()
            .next()
            .await;
        let req3 = recover_writes_rx
            .recv()
            .await
            .expect("should receive a request")?;
        assert_ne!(req1.client_id, req3.client_id);

        Ok(())
    }

    #[tokio::test(start_paused = true)]
    async fn no_immediate_message() -> anyhow::Result<()> {
        const TEST_TIMEOUT: Duration = Duration::from_secs(42);

        let (_response_tx, response_rx) = channel(10);

        let mut mock = MockSubscriber::new();
        mock.expect_streaming_pull()
            .return_once(move |_| Ok(tonic::Response::from(response_rx)));

        let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
        let client = test_client(endpoint).await?;
        let mut session = client.streaming_pull("projects/p/subscriptions/s").start();

        let _ = tokio::time::timeout(TEST_TIMEOUT, session.next())
            .await
            .expect_err("next() should never yield.");

        Ok(())
    }
}