use super::builder::Subscribe;
use super::handler::{AckResult, Action, AtLeastOnce, ExactlyOnce, Handler};
use super::lease_loop::LeaseLoop;
use super::lease_state::{AtLeastOnceInfo, ExactlyOnceInfo, LeaseInfo, LeaseOptions, NewMessage};
use super::leaser::DefaultLeaser;
use super::retry_policy::StreamRetryPolicy;
use super::shutdown_token::ShutdownToken;
use super::stream::Stream;
use super::stub::TonicStreaming as _;
use super::transport::Transport;
use crate::google::pubsub::v1::{StreamingPullRequest, StreamingPullResponse};
use crate::model::Message;
use crate::{Error, Result};
use futures::FutureExt;
use futures::future::{BoxFuture, Shared};
use gaxi::grpc::from_status::to_gax_error;
use gaxi::prost::FromProto as _;
use google_cloud_gax::retry_result::RetryResult;
use std::collections::VecDeque;
use std::sync::Arc;
use tokio::sync::mpsc::{UnboundedSender, WeakUnboundedSender, unbounded_channel};
use tokio::sync::oneshot::Receiver;
use tokio::time::Duration;
use tokio_util::sync::{CancellationToken, DropGuard};
/// A stream of messages received from a subscription.
///
/// Each successful call to `next()` yields a [`Message`] paired with the
/// [`Handler`] created for it.
#[derive(Debug)]
pub struct MessageStream {
/// Produces the `(Message, Handler)` pairs returned by `next()`.
inner: MessageStreamImpl,
/// The lease loop's handle, wrapped as a shareable future. It is cloned
/// into every `ShutdownToken` handed out by `shutdown_token()`.
lease_loop: Shared<BoxFuture<'static, ()>>,
/// Once cancelled, `next()` closes the inner stream and returns `None`.
shutdown: CancellationToken,
/// Cancels `shutdown` when this stream is dropped.
_shutdown_guard: DropGuard,
}
/// The state behind a [`MessageStream`]: the underlying pull stream plus a
/// buffer of messages that were received but not yet returned to the caller.
#[derive(Debug)]
pub struct MessageStreamImpl {
/// Transport handed to `Stream::new` whenever a stream is opened.
stub: Arc<Transport>,
/// Request handed (as a clone) to `Stream::new` whenever a stream is opened.
initial_req: StreamingPullRequest,
/// `None` means no stream is open and one is opened on the next read;
/// `Some(Active)` is a live stream; `Some(Closed)` ends the message stream.
stream: Option<StreamState>,
/// Messages received in a response and waiting to be returned by `next()`.
pool: VecDeque<(Message, HandlerInfo)>,
/// Weak sender used to announce each received message (`NewMessage`) to the
/// lease loop. Upgrading fails once the strong senders have been dropped.
message_tx: WeakUnboundedSender<NewMessage>,
/// Weak sender that is upgraded and given to each `Handler` that is created.
ack_tx: WeakUnboundedSender<Action>,
/// Cancelled by `close()`.
shutdown: CancellationToken,
}
/// State of the underlying pull stream once it has been opened or closed.
///
/// Wrapped in an `Option` by `MessageStreamImpl`, where `None` means "not
/// open yet, open on demand".
// NOTE(review): the lint is silenced instead of boxing `Stream`; presumably
// because only one value of this enum exists per stream — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
enum StreamState {
/// Set by `close()`. Reads return `None` and the stream is never reopened.
Closed,
/// A live stream from which responses are read.
Active(Stream<Transport>),
}
impl MessageStream {
    /// Builds a message stream from a configured [`Subscribe`] builder.
    ///
    /// This creates the leaser and the lease loop, and spawns a task that
    /// keeps the strong channel senders alive until shutdown is requested.
    /// The pull stream itself is not opened here; `stream` starts as `None`
    /// and is opened by the first read.
    pub(super) fn new(builder: Subscribe) -> Self {
        let transport = builder.inner;
        let subscription_name = builder.subscription;

        let (confirmed_tx, confirmed_rx) = unbounded_channel();
        let leaser = DefaultLeaser::new(
            transport.clone(),
            confirmed_tx,
            subscription_name.clone(),
            builder.ack_deadline_seconds,
            builder.grpc_subchannel_count,
        );
        let lease_options = LeaseOptions {
            max_lease: builder.max_lease,
            max_lease_extension: Duration::from_secs(builder.ack_deadline_seconds as u64),
            shutdown_behavior: builder.shutdown_behavior,
            ..Default::default()
        };
        let LeaseLoop {
            handle,
            message_tx,
            ack_tx,
        } = LeaseLoop::new(leaser, confirmed_rx, lease_options);

        // The stream only keeps weak senders. The strong ones live in the
        // task spawned below and are released when shutdown is requested.
        let shutdown = CancellationToken::new();
        let inner = MessageStreamImpl {
            stub: transport,
            initial_req: StreamingPullRequest {
                subscription: subscription_name,
                stream_ack_deadline_seconds: builder.ack_deadline_seconds,
                max_outstanding_messages: builder.max_outstanding_messages,
                max_outstanding_bytes: builder.max_outstanding_bytes,
                client_id: builder.client_id,
                protocol_version: 1,
                ..Default::default()
            },
            stream: None,
            pool: VecDeque::new(),
            message_tx: message_tx.downgrade(),
            ack_tx: ack_tx.downgrade(),
            shutdown: shutdown.clone(),
        };

        let shutdown_guard = shutdown.clone().drop_guard();
        let shutdown_requested = shutdown.clone();
        tokio::spawn(async move {
            shutdown_requested.cancelled().await;
            drop(message_tx);
            drop(ack_tx);
        });

        Self {
            inner,
            lease_loop: handle.map(|_| ()).boxed().shared(),
            shutdown,
            _shutdown_guard: shutdown_guard,
        }
    }

    /// Returns the next message and its handler.
    ///
    /// Yields `None` when the stream has ended. If shutdown has been
    /// requested, the inner stream is closed and `None` is returned; the
    /// shutdown branch is polled first because the select is `biased`.
    /// Yields `Some(Err(_))` when the inner stream reports an error.
    pub async fn next(&mut self) -> Option<Result<(Message, Handler)>> {
        tokio::select! {
            biased;
            _ = self.shutdown.cancelled() => {
                self.inner.close();
                None
            },
            item = self.inner.next() => item,
        }
    }

    /// Converts this value into a `futures::Stream` that yields the same
    /// items as repeated calls to `next()`, ending when `next()` returns
    /// `None`.
    #[cfg(feature = "unstable-stream")]
    #[cfg_attr(docsrs, doc(cfg(feature = "unstable-stream")))]
    pub fn into_stream(self) -> impl futures::Stream<Item = Result<(Message, Handler)>> + Unpin {
        let items = futures::stream::unfold(self, |mut this| async move {
            let item = this.next().await?;
            Some((item, this))
        });
        Box::pin(items)
    }

    /// Returns a [`ShutdownToken`] that shares this stream's cancellation
    /// token and its lease-loop future.
    pub fn shutdown_token(&self) -> ShutdownToken {
        let inner = self.shutdown.clone();
        let fut = self.lease_loop.clone();
        ShutdownToken { inner, fut }
    }
}
impl MessageStreamImpl {
async fn next(&mut self) -> Option<Result<(Message, Handler)>> {
loop {
if let Some((m, hi)) = self.pool.pop_front() {
return Some(Ok((m, hi.into_handler(self.ack_tx.upgrade()?))));
}
if let Err(e) = self.populate_pool().await? {
match StreamRetryPolicy::on_midstream_error(e) {
RetryResult::Continue(_) => {
self.stream = None;
continue;
}
RetryResult::Permanent(e) | RetryResult::Exhausted(e) => {
self.close();
return Some(Err(e));
}
}
}
}
}
async fn open_stream(&mut self) -> Result<()> {
let stream = Stream::<Transport>::new(self.stub.clone(), self.initial_req.clone()).await?;
self.stream = Some(StreamState::Active(stream));
Ok(())
}
async fn next_response(&mut self) -> Option<Result<StreamingPullResponse>> {
if self.stream.is_none() {
if let Err(e) = self.open_stream().await {
return Some(Err(e));
}
}
let stream = match self.stream.as_mut()? {
StreamState::Closed => return None,
StreamState::Active(s) => s,
};
stream
.next_message()
.await
.map_err(to_gax_error)
.transpose()
}
async fn populate_pool(&mut self) -> Option<Result<()>> {
let resp = match self.next_response().await? {
Ok(resp) => resp,
Err(e) => return Some(Err(e)),
};
let exactly_once = resp
.subscription_properties
.is_some_and(|m| m.exactly_once_delivery_enabled);
for rm in resp.received_messages {
let Some(message) = rm.message else {
continue;
};
let (lease_info, handler_info) = if exactly_once {
let (result_tx, result_rx) = tokio::sync::oneshot::channel();
(
LeaseInfo::ExactlyOnce(ExactlyOnceInfo::new(result_tx)),
HandlerInfo::ExactlyOnce {
ack_id: rm.ack_id.clone(),
result_rx,
},
)
} else {
(
LeaseInfo::AtLeastOnce(AtLeastOnceInfo::new()),
HandlerInfo::AtLeastOnce {
ack_id: rm.ack_id.clone(),
},
)
};
let _ = self.message_tx.upgrade()?.send(NewMessage {
ack_id: rm.ack_id,
lease_info,
});
let message = match message.cnv().map_err(Error::deser) {
Ok(message) => message,
Err(e) => return Some(Err(e)),
};
self.pool.push_back((message, handler_info));
}
Some(Ok(()))
}
fn close(&mut self) {
self.stream = Some(StreamState::Closed);
self.pool.clear();
self.shutdown.cancel();
}
}
/// What is needed to build a [`Handler`] for a buffered message. The
/// `Action` sender is supplied later, in `into_handler`.
#[derive(Debug)]
enum HandlerInfo {
/// Used when the response did not enable exactly-once delivery.
AtLeastOnce {
/// Ack ID of the received message.
ack_id: String,
},
/// Used when the response's subscription properties enable exactly-once
/// delivery.
ExactlyOnce {
/// Ack ID of the received message.
ack_id: String,
/// Receiving half of the one-shot channel whose sender was given to the
/// message's `ExactlyOnceInfo`.
result_rx: Receiver<AckResult>,
},
}
impl HandlerInfo {
    /// Builds the [`Handler`] for this message, handing it the `Action`
    /// sender `ack_tx`.
    fn into_handler(self, ack_tx: UnboundedSender<Action>) -> Handler {
        match self {
            Self::AtLeastOnce { ack_id } => {
                let handler = AtLeastOnce::new(ack_id, ack_tx);
                Handler::AtLeastOnce(handler)
            }
            Self::ExactlyOnce { ack_id, result_rx } => {
                let handler = ExactlyOnce::new(ack_id, ack_tx, result_rx);
                Handler::ExactlyOnce(handler)
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::super::ShutdownBehavior;
use super::super::client::Subscriber;
use super::super::keepalive::KEEPALIVE_PERIOD;
use super::super::lease_state::tests::{test_id, test_ids};
use super::super::stream::{INITIAL_DELAY, MAXIMUM_DELAY};
use super::*;
use gaxi::grpc::tonic::{Response as TonicResponse, Status as TonicStatus};
use google_cloud_auth::credentials::anonymous::Builder as Anonymous;
use google_cloud_test_macros::tokio_test_no_panics;
use pubsub_grpc_mock::google::pubsub::v1;
use pubsub_grpc_mock::{MockSubscriber, start};
use test_case::test_case;
use tokio::sync::mpsc::{channel, unbounded_channel};
use tokio::task::{JoinHandle, JoinSet};
use tokio::time::{Duration, Instant};
/// Returns `v` with its elements in ascending order, so that ack IDs can be
/// compared regardless of the order in which they were sent.
fn sorted(mut v: Vec<String>) -> Vec<String> {
    // Equal strings are indistinguishable, so an unstable sort gives the
    // same result as a stable one.
    v.sort_unstable();
    v
}
/// Payload used for test message `v`: `"data-"` followed by `test_id(v)`.
fn test_data(v: i32) -> bytes::Bytes {
    let payload = format!("data-{}", test_id(v));
    bytes::Bytes::from(payload)
}
/// Builds a response with one received message per value in `range`, using
/// `test_id(i)` as the ack ID and `test_data(i)` as the payload. An empty
/// range gives a response with no messages.
fn test_response(range: std::ops::Range<i32>) -> v1::StreamingPullResponse {
    let received_messages = range
        .map(|i| {
            let ack_id = test_id(i);
            let message = v1::PubsubMessage {
                data: test_data(i).to_vec(),
                ..Default::default()
            };
            v1::ReceivedMessage {
                ack_id,
                message: Some(message),
                ..Default::default()
            }
        })
        .collect();
    v1::StreamingPullResponse {
        received_messages,
        ..Default::default()
    }
}
fn test_exactly_once_response(range: std::ops::Range<i32>) -> v1::StreamingPullResponse {
v1::StreamingPullResponse {
subscription_properties: Some(v1::streaming_pull_response::SubscriptionProperties {
exactly_once_delivery_enabled: true,
..Default::default()
}),
received_messages: range
.into_iter()
.map(|i| v1::ReceivedMessage {
ack_id: test_id(i),
message: Some(v1::PubsubMessage {
data: test_data(i).to_vec(),
..Default::default()
}),
..Default::default()
})
.collect(),
..Default::default()
}
}
/// Creates a `Subscriber` that talks to `endpoint` with anonymous
/// credentials.
async fn test_client(endpoint: String) -> anyhow::Result<Subscriber> {
    let builder = Subscriber::builder()
        .with_endpoint(endpoint)
        .with_credentials(Anonymous::new().build());
    let client = builder.build().await?;
    Ok(client)
}
// When the streaming pull fails with a `FailedPrecondition` status, the first
// item returned by the stream is that error, with its code and message intact.
#[tokio_test_no_panics]
async fn error_starting_stream() -> anyhow::Result<()> {
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull()
.return_once(|_| Err(TonicStatus::failed_precondition("fail")));
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
let mut stream = client.subscribe("projects/p/subscriptions/s").build();
let err = stream
.next()
.await
.expect("stream should not be empty")
.expect_err("the first streamed item should be an error");
assert!(err.status().is_some(), "{err:?}");
let status = err.status().unwrap();
assert_eq!(
status.code,
google_cloud_gax::error::rpc::Code::FailedPrecondition
);
assert_eq!(status.message, "fail");
Ok(())
}
// After the stream has returned an error for a `FailedPrecondition` failure,
// the following call to `next()` returns `None`.
#[tokio_test_no_panics]
async fn permanent_error_ends_stream() -> anyhow::Result<()> {
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull()
.returning(|_| Err(TonicStatus::failed_precondition("fail")));
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
let mut stream = client.subscribe("projects/p/subscriptions/s").build();
let next = stream.next().await;
assert!(
matches!(next, Some(Err(_))),
"expected permanent error, got {next:?}"
);
// The error closed the stream, so nothing more is produced.
let next = stream.next().await;
assert!(next.is_none(), "expected end of stream, got {next:?}");
Ok(())
}
// The first request written on the stream carries the subscription name and
// the options configured on the builder.
#[tokio_test_no_panics]
async fn initial_request() -> anyhow::Result<()> {
const MIB: i64 = 1024 * 1024;
let (recover_writes_tx, mut recover_writes_rx) = channel(1);
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull().return_once(move |request| {
// Forward every request the client writes so the test can inspect it.
tokio::spawn(async move {
let mut request_rx = request.into_inner();
while let Some(request) = request_rx.recv().await {
recover_writes_tx
.send(request)
.await
.expect("forwarding writes always succeeds");
}
});
Err(TonicStatus::failed_precondition("fail"))
});
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
// The result of `next()` is irrelevant; only the written request matters.
let _ = client
.subscribe("projects/p/subscriptions/s")
.set_max_lease_extension(Duration::from_secs(20))
.set_max_outstanding_messages(2000)
.set_max_outstanding_bytes(200 * MIB)
.build()
.next()
.await;
let initial_req = recover_writes_rx
.recv()
.await
.expect("should receive a request")?;
assert_eq!(initial_req.subscription, "projects/p/subscriptions/s");
// The 20s max lease extension shows up as the stream ack deadline.
assert_eq!(initial_req.stream_ack_deadline_seconds, 20);
assert_eq!(initial_req.max_outstanding_messages, 2000);
assert_eq!(initial_req.max_outstanding_bytes, 200 * MIB);
assert!(
!initial_req.client_id.is_empty(),
"initial request has empty client id: {initial_req:?}"
);
assert!(
initial_req.protocol_version >= 1,
"protocol_version={}",
initial_req.protocol_version
);
Ok(())
}
// Six messages spread over three responses are returned in order. After
// acking all of them and shutting down, one acknowledge request carries all
// six ack IDs.
#[tokio_test_no_panics(start_paused = true)]
async fn basic_success() -> anyhow::Result<()> {
let (response_tx, response_rx) = channel(10);
let (ack_tx, mut ack_rx) = unbounded_channel();
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull()
.return_once(|_| Ok(TonicResponse::from(response_rx)));
// Record every acknowledge request so it can be checked after shutdown.
mock.expect_acknowledge().returning(move |r| {
ack_tx
.send(r.into_inner())
.expect("sending on channel always succeeds");
Ok(TonicResponse::from(()))
});
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
let mut stream = client.subscribe("projects/p/subscriptions/s").build();
response_tx.send(Ok(test_response(1..2))).await?;
response_tx.send(Ok(test_response(2..4))).await?;
response_tx.send(Ok(test_response(4..7))).await?;
// Dropping the sender ends the response stream after the queued responses.
drop(response_tx);
for i in 1..7 {
let Some((m, h)) = stream.next().await.transpose()? else {
anyhow::bail!("expected message {i}/6")
};
assert_eq!(m.data, test_data(i));
assert_eq!(h.ack_id(), test_id(i));
h.ack();
}
let end = stream.next().await.transpose()?;
assert!(end.is_none(), "Received extra message: {end:?}");
stream.shutdown_token().shutdown().await;
let ack_req = ack_rx.try_recv()?;
assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
assert_eq!(sorted(ack_req.ack_ids), test_ids(1..7));
Ok(())
}
// Same flow as `basic_success`, but the responses enable exactly-once
// delivery. Every message must come with an `ExactlyOnce` handler, and every
// `confirmed_ack()` must complete successfully.
#[tokio_test_no_panics(start_paused = true)]
async fn basic_success_exactly_once() -> anyhow::Result<()> {
let (response_tx, response_rx) = channel(10);
let (ack_tx, mut ack_rx) = unbounded_channel();
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull()
.return_once(|_| Ok(TonicResponse::from(response_rx)));
mock.expect_acknowledge().returning(move |r| {
ack_tx
.send(r.into_inner())
.expect("sending on channel always succeeds");
Ok(TonicResponse::from(()))
});
mock.expect_modify_ack_deadline()
.returning(|_| Ok(TonicResponse::from(())));
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
let mut stream = client
.subscribe("projects/p/subscriptions/s")
.set_shutdown_behavior(ShutdownBehavior::WaitForProcessing)
.build();
response_tx
.send(Ok(test_exactly_once_response(1..2)))
.await?;
response_tx
.send(Ok(test_exactly_once_response(2..4)))
.await?;
response_tx
.send(Ok(test_exactly_once_response(4..7)))
.await?;
drop(response_tx);
// The confirmed acks run as separate tasks and are joined after shutdown.
let mut acks = JoinSet::new();
for i in 1..7 {
let Some((m, Handler::ExactlyOnce(h))) = stream.next().await.transpose()? else {
anyhow::bail!("expected message {i}/6")
};
assert_eq!(m.data, test_data(i));
assert_eq!(h.ack_id(), test_id(i));
acks.spawn(h.confirmed_ack());
}
let end = stream.next().await.transpose()?;
assert!(end.is_none(), "Received extra message: {end:?}");
stream.shutdown_token().shutdown().await;
while let Some(r) = acks.join_next().await {
// The outer `?` is the join result, the inner `?` the ack result.
r??;
}
let ack_req = ack_rx.try_recv()?;
assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
assert_eq!(sorted(ack_req.ack_ids), test_ids(1..7));
Ok(())
}
// Thirty messages: 0..10 are acked, 10..20 are nacked and 20..30 are held.
// After advancing the paused clock by 10s and shutting down, the test checks
// the acknowledge, nack and extension requests seen by the server.
#[tokio_test_no_panics(start_paused = true)]
async fn basic_lease_management() -> anyhow::Result<()> {
let (response_tx, response_rx) = channel(10);
let (ack_tx, mut ack_rx) = unbounded_channel();
let (nack_tx, mut nack_rx) = unbounded_channel();
let (extend_tx, mut extend_rx) = unbounded_channel();
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull()
.return_once(|_| Ok(TonicResponse::from(response_rx)));
mock.expect_acknowledge().returning(move |r| {
ack_tx
.send(r.into_inner())
.expect("sending on channel always succeeds");
Ok(TonicResponse::from(()))
});
// Requests with a zero deadline are recorded as nacks; all others are
// recorded as lease extensions.
mock.expect_modify_ack_deadline().returning(move |r| {
let r = r.into_inner();
if r.ack_deadline_seconds == 0 {
nack_tx.send(r).expect("sending on channel always succeeds");
} else {
extend_tx
.send(r)
.expect("sending on channel always succeeds");
}
Ok(TonicResponse::from(()))
});
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
let mut stream = client
.subscribe("projects/p/subscriptions/s")
.set_max_lease_extension(Duration::from_secs(10))
.set_shutdown_behavior(ShutdownBehavior::NackImmediately)
.build();
response_tx.send(Ok(test_response(0..30))).await?;
drop(response_tx);
for i in 0..10 {
let Some((_, Handler::AtLeastOnce(h))) = stream.next().await.transpose()? else {
anyhow::bail!("expected message {i}")
};
h.ack();
}
for i in 10..20 {
let Some((_, Handler::AtLeastOnce(h))) = stream.next().await.transpose()? else {
anyhow::bail!("expected message {i}")
};
h.nack();
}
// Keep these handlers alive without acking or nacking them.
let mut hold = Vec::new();
for i in 20..30 {
let Some((_, Handler::AtLeastOnce(h))) = stream.next().await.transpose()? else {
anyhow::bail!("expected message {i}")
};
hold.push(h);
}
tokio::time::advance(Duration::from_secs(10)).await;
stream.shutdown_token().shutdown().await;
let ack_req = ack_rx.try_recv()?;
assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
assert_eq!(sorted(ack_req.ack_ids), test_ids(0..10));
assert!(ack_rx.is_empty(), "{ack_rx:?}");
// The first nack request covers the messages nacked explicitly.
let nack_req = nack_rx.try_recv()?;
assert_eq!(nack_req.subscription, "projects/p/subscriptions/s");
assert_eq!(nack_req.ack_deadline_seconds, 0);
assert_eq!(sorted(nack_req.ack_ids), test_ids(10..20));
// The second covers the held messages; presumably this comes from the
// `NackImmediately` shutdown behavior.
let nack_req = nack_rx.try_recv()?;
assert_eq!(nack_req.subscription, "projects/p/subscriptions/s");
assert_eq!(nack_req.ack_deadline_seconds, 0);
assert_eq!(sorted(nack_req.ack_ids), test_ids(20..30));
assert!(nack_rx.is_empty(), "{nack_rx:?}");
// The held messages were extended with the configured 10s deadline.
let extend_req = extend_rx.try_recv()?;
assert_eq!(extend_req.subscription, "projects/p/subscriptions/s");
assert_eq!(extend_req.ack_deadline_seconds, 10);
assert_eq!(sorted(extend_req.ack_ids), test_ids(20..30));
Ok(())
}
// `next()` waits for a response that the server only sends after a delay.
#[tokio_test_no_panics(start_paused = true)]
async fn delayed_responses() -> anyhow::Result<()> {
let (response_tx, response_rx) = channel(10);
// Send the only response 20ms (of paused time) after the test starts.
let handle: JoinHandle<anyhow::Result<()>> = tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(20)).await;
response_tx.send(Ok(test_response(1..2))).await?;
Ok(())
});
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull()
.return_once(|_| Ok(TonicResponse::from(response_rx)));
mock.expect_modify_ack_deadline()
.returning(|_| Ok(TonicResponse::from(())));
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
let mut stream = client.subscribe("projects/p/subscriptions/s").build();
let (m, h) = stream
.next()
.await
.transpose()?
.expect("stream should wait for a message");
assert_eq!(m.data, test_data(1));
assert_eq!(h.ack_id(), test_id(1));
handle.await??;
Ok(())
}
// Each message is returned by `next()` right after its response is sent,
// without needing any later response to arrive first.
#[tokio_test_no_panics]
async fn serves_messages_immediately() -> anyhow::Result<()> {
let (response_tx, response_rx) = channel(10);
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull()
.return_once(|_| Ok(TonicResponse::from(response_rx)));
mock.expect_modify_ack_deadline()
.returning(|_| Ok(TonicResponse::from(())));
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
let mut stream = client.subscribe("projects/p/subscriptions/s").build();
for i in 1..7 {
// Send one single-message response, then read it back.
response_tx.send(Ok(test_response(i..i + 1))).await?;
let Some((m, h)) = stream.next().await.transpose()? else {
anyhow::bail!("expected message {i}/6")
};
assert_eq!(m.data, test_data(i));
assert_eq!(h.ack_id(), test_id(i));
}
drop(response_tx);
let end = stream.next().await.transpose()?;
assert!(end.is_none(), "Received extra message: {end:?}");
Ok(())
}
// A response with no messages between two non-empty responses is skipped;
// the messages around it are still returned in order.
#[tokio_test_no_panics]
async fn handles_empty_response() -> anyhow::Result<()> {
let (response_tx, response_rx) = channel(10);
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull()
.return_once(|_| Ok(TonicResponse::from(response_rx)));
mock.expect_modify_ack_deadline()
.returning(|_| Ok(TonicResponse::from(())));
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
let mut stream = client.subscribe("projects/p/subscriptions/s").build();
response_tx.send(Ok(test_response(1..2))).await?;
// `2..2` is an empty range, so this response has no messages.
response_tx.send(Ok(test_response(2..2))).await?;
response_tx.send(Ok(test_response(2..3))).await?;
drop(response_tx);
for i in 1..3 {
let Some((m, h)) = stream.next().await.transpose()? else {
anyhow::bail!("expected message {i}/2")
};
assert_eq!(m.data, test_data(i));
assert_eq!(h.ack_id(), test_id(i));
}
let end = stream.next().await.transpose()?;
assert!(end.is_none(), "Received extra message: {end:?}");
Ok(())
}
// A received message whose `message` field is unset is not returned to the
// caller, and its ack ID does not appear in the lease extension request.
#[tokio_test_no_panics(start_paused = true)]
async fn handles_missing_message_field() -> anyhow::Result<()> {
let (response_tx, response_rx) = channel(10);
let (extend_tx, mut extend_rx) = unbounded_channel();
let bad = v1::StreamingPullResponse {
received_messages: vec![v1::ReceivedMessage {
ack_id: "ignored-ack-id".to_string(),
message: None,
..Default::default()
}],
..Default::default()
};
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull()
.return_once(|_| Ok(TonicResponse::from(response_rx)));
// Record only lease extensions; zero-deadline requests (nacks) are dropped.
mock.expect_modify_ack_deadline().returning(move |r| {
let r = r.into_inner();
if r.ack_deadline_seconds != 0 {
extend_tx
.send(r)
.expect("sending on channel always succeeds");
}
Ok(TonicResponse::from(()))
});
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
let mut stream = client
.subscribe("projects/p/subscriptions/s")
.set_max_lease_extension(Duration::from_secs(10))
.set_shutdown_behavior(ShutdownBehavior::NackImmediately)
.build();
response_tx.send(Ok(test_response(1..4))).await?;
response_tx.send(Ok(bad)).await?;
response_tx.send(Ok(test_response(4..7))).await?;
drop(response_tx);
// Hold on to the handlers so the messages are neither acked nor nacked.
let mut handlers = Vec::new();
for i in 1..7 {
let Some((m, h)) = stream.next().await.transpose()? else {
anyhow::bail!("expected message {i}/6")
};
assert_eq!(m.data, test_data(i));
assert_eq!(h.ack_id(), test_id(i));
handlers.push(h);
}
tokio::time::advance(Duration::from_secs(10)).await;
stream.shutdown_token().shutdown().await;
let extend_req = extend_rx.try_recv()?;
assert_eq!(extend_req.subscription, "projects/p/subscriptions/s");
assert_eq!(extend_req.ack_deadline_seconds, 10);
// Only the six valid messages are listed; "ignored-ack-id" is absent.
assert_eq!(sorted(extend_req.ack_ids), test_ids(1..7));
Ok(())
}
// Messages received before a `FailedPrecondition` error are still returned;
// the error itself is returned afterwards.
#[tokio_test_no_panics]
async fn permanent_error_midstream() -> anyhow::Result<()> {
let (response_tx, response_rx) = channel(10);
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull()
.return_once(|_| Ok(TonicResponse::from(response_rx)));
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
let mut stream = client.subscribe("projects/p/subscriptions/s").build();
response_tx.send(Ok(test_response(1..4))).await?;
response_tx
.send(Err(TonicStatus::failed_precondition("fail")))
.await?;
drop(response_tx);
for i in 1..4 {
let Some((m, h)) = stream.next().await.transpose()? else {
anyhow::bail!("expected message {i}/3")
};
assert_eq!(m.data, test_data(i));
assert_eq!(h.ack_id(), test_id(i));
}
let err = stream
.next()
.await
.transpose()
.expect_err("expected an error from stream");
assert!(err.status().is_some(), "{err:?}");
let status = err.status().unwrap();
assert_eq!(
status.code,
google_cloud_gax::error::rpc::Code::FailedPrecondition
);
assert_eq!(status.message, "fail");
Ok(())
}
// After the initial request, an empty request is written once
// `KEEPALIVE_PERIOD` has elapsed. No further requests are written after the
// stream is dropped.
#[tokio_test_no_panics(start_paused = true)]
async fn keepalives() -> anyhow::Result<()> {
let (recover_writes_tx, mut recover_writes_rx) = channel(1);
let (response_tx, response_rx) = channel(10);
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull().return_once(move |request| {
// Forward every request the client writes so the test can inspect it.
tokio::spawn(async move {
let mut request_rx = request.into_inner();
while let Some(request) = request_rx.recv().await {
recover_writes_tx
.send(request)
.await
.expect("forwarding writes always succeeds");
}
});
Ok(TonicResponse::from(response_rx))
});
mock.expect_modify_ack_deadline()
.returning(|_| Ok(TonicResponse::from(())));
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
let mut stream = client.subscribe("projects/p/subscriptions/s").build();
response_tx.send(Ok(test_response(1..4))).await?;
// Reading one message makes the client open the stream.
let _ = stream.next().await;
let initial_req = recover_writes_rx
.recv()
.await
.expect("should receive an initial request")?;
assert_eq!(initial_req.subscription, "projects/p/subscriptions/s");
tokio::time::advance(KEEPALIVE_PERIOD).await;
let keepalive_req = recover_writes_rx
.recv()
.await
.expect("should receive a keepalive request")?;
// A keepalive is a request with every field at its default value.
assert_eq!(keepalive_req, v1::StreamingPullRequest::default());
drop(stream);
tokio::time::advance(4 * KEEPALIVE_PERIOD).await;
assert!(recover_writes_rx.is_empty(), "{recover_writes_rx:?}");
Ok(())
}
// Two streams built from the same client send the same `client_id`; a stream
// built from a different client sends a different one.
#[tokio_test_no_panics]
async fn client_id() -> anyhow::Result<()> {
let (recover_writes_tx, mut recover_writes_rx) = channel(10);
// The expectation closure runs three times, so the sender is shared.
let recover_writes_tx = std::sync::Arc::new(tokio::sync::Mutex::new(recover_writes_tx));
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull()
.times(3)
.returning(move |request| {
let tx = recover_writes_tx.clone();
tokio::spawn(async move {
let mut request_rx = request.into_inner();
while let Some(request) = request_rx.recv().await {
tx.lock()
.await
.send(request)
.await
.expect("forwarding writes always succeeds");
}
});
Err(TonicStatus::failed_precondition("fail"))
});
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let c1 = test_client(endpoint.clone()).await?;
let _ = c1
.subscribe("projects/p/subscriptions/s")
.build()
.next()
.await;
let req1 = recover_writes_rx
.recv()
.await
.expect("should receive a request")?;
let _ = c1
.subscribe("projects/p/subscriptions/s")
.build()
.next()
.await;
let req2 = recover_writes_rx
.recv()
.await
.expect("should receive a request")?;
assert_eq!(req1.client_id, req2.client_id);
// A second client must identify itself differently.
let c2 = test_client(endpoint).await?;
let _ = c2
.subscribe("projects/p/subscriptions/s")
.build()
.next()
.await;
let req3 = recover_writes_rx
.recv()
.await
.expect("should receive a request")?;
assert_ne!(req1.client_id, req3.client_id);
Ok(())
}
// With an open stream that never produces a response, `next()` stays pending:
// it must still be unresolved when the 42s (paused-clock) timeout fires.
#[tokio_test_no_panics(start_paused = true)]
async fn no_immediate_message() -> anyhow::Result<()> {
const TEST_TIMEOUT: Duration = Duration::from_secs(42);
// Keep the sender alive so the response stream stays open but silent.
let (_response_tx, response_rx) = channel(10);
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull()
.return_once(move |_| Ok(TonicResponse::from(response_rx)));
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
let mut stream = client.subscribe("projects/p/subscriptions/s").build();
let _ = tokio::time::timeout(TEST_TIMEOUT, stream.next())
.await
.expect_err("next() should never yield.");
Ok(())
}
// Twenty `Unavailable` failures while starting the stream are retried. The
// `FailedPrecondition` failure that follows is returned to the caller, and
// the elapsed (paused-clock) time is checked against the backoff bounds.
#[tokio_test_no_panics(start_paused = true)]
async fn retry_transient_when_starting_stream() -> anyhow::Result<()> {
const NUM_RETRIES: u32 = 20;
let start_time = Instant::now();
let mut seq = mockall::Sequence::new();
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull()
.times(NUM_RETRIES as usize)
.in_sequence(&mut seq)
.returning(|_| Err(TonicStatus::unavailable("try again")));
mock.expect_streaming_pull()
.times(1)
.in_sequence(&mut seq)
.return_once(|_| Err(TonicStatus::failed_precondition("fail")));
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
let mut stream = client.subscribe("projects/p/subscriptions/s").build();
let err = stream
.next()
.await
.expect("stream should not be empty")
.expect_err("the first streamed item should be an error");
assert!(err.status().is_some(), "{err:?}");
let status = err.status().unwrap();
assert_eq!(
status.code,
google_cloud_gax::error::rpc::Code::FailedPrecondition
);
assert_eq!(status.message, "fail");
// The total wait must lie between NUM_RETRIES times the initial delay and
// NUM_RETRIES times the maximum delay.
let elapsed = start_time.elapsed();
assert!(
elapsed <= MAXIMUM_DELAY * NUM_RETRIES,
"elapsed={elapsed:?}"
);
assert!(
elapsed >= INITIAL_DELAY * NUM_RETRIES,
"elapsed={elapsed:?}"
);
Ok(())
}
// The first two streams end with an `Unavailable` error and the client
// resumes on the next one. All 50 messages across the three streams are
// returned in order, and every ack ID reaches the server.
#[tokio_test_no_panics(start_paused = true)]
async fn resume_midstream_success() -> anyhow::Result<()> {
let (response_tx_1, response_rx_1) = channel(10);
let (response_tx_2, response_rx_2) = channel(10);
let (response_tx_3, response_rx_3) = channel(10);
let (ack_tx, mut ack_rx) = unbounded_channel();
let mut seq = mockall::Sequence::new();
let mut mock = MockSubscriber::new();
// Each streaming pull call, in order, gets the next response channel.
mock.expect_streaming_pull()
.times(1)
.in_sequence(&mut seq)
.return_once(|_| Ok(TonicResponse::from(response_rx_1)));
mock.expect_streaming_pull()
.times(1)
.in_sequence(&mut seq)
.return_once(move |_| Ok(TonicResponse::from(response_rx_2)));
mock.expect_streaming_pull()
.times(1)
.in_sequence(&mut seq)
.return_once(|_| Ok(TonicResponse::from(response_rx_3)));
mock.expect_acknowledge().times(1..).returning(move |r| {
ack_tx
.send(r.into_inner())
.expect("sending on channel always succeeds");
Ok(TonicResponse::from(()))
});
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
let mut stream = client.subscribe("projects/p/subscriptions/s").build();
response_tx_1.send(Ok(test_response(0..10))).await?;
response_tx_1.send(Ok(test_response(10..20))).await?;
response_tx_1
.send(Err(TonicStatus::unavailable("GFE disconnect. try again")))
.await?;
drop(response_tx_1);
response_tx_2.send(Ok(test_response(20..30))).await?;
response_tx_2.send(Ok(test_response(30..40))).await?;
response_tx_2
.send(Err(TonicStatus::unavailable("GFE disconnect. try again")))
.await?;
drop(response_tx_2);
// The third stream ends without an error.
response_tx_3.send(Ok(test_response(40..50))).await?;
drop(response_tx_3);
for i in 0..50 {
let (m, h) = stream
.next()
.await
.unwrap_or_else(|| panic!("expected message {}/50", i + 1))?;
assert_eq!(m.data, test_data(i));
h.ack();
}
let end = stream.next().await.transpose()?;
assert!(end.is_none(), "Received extra message: {end:?}");
stream.shutdown_token().shutdown().await;
// Acks may arrive in several requests; collect the IDs from all of them.
let mut got = Vec::new();
while let Ok(ack_req) = ack_rx.try_recv() {
assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
got.extend(ack_req.ack_ids);
}
assert_eq!(sorted(got), test_ids(0..50));
Ok(())
}
// The stream fails midstream with `Unavailable`. Reopening it fails three
// times with `Unavailable` and then with `FailedPrecondition`, which is
// returned to the caller. The 20 messages received earlier are still acked.
#[tokio_test_no_panics(start_paused = true)]
async fn resume_midstream_hits_permanent_error() -> anyhow::Result<()> {
let (response_tx, response_rx) = channel(10);
let (ack_tx, mut ack_rx) = unbounded_channel();
let mut seq = mockall::Sequence::new();
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull()
.times(1)
.in_sequence(&mut seq)
.return_once(|_| Ok(TonicResponse::from(response_rx)));
mock.expect_streaming_pull()
.times(3)
.in_sequence(&mut seq)
.returning(|_| Err(TonicStatus::unavailable("try again")));
mock.expect_streaming_pull()
.times(1)
.in_sequence(&mut seq)
.return_once(|_| Err(TonicStatus::failed_precondition("fail")));
mock.expect_acknowledge().times(1..).returning(move |r| {
ack_tx
.send(r.into_inner())
.expect("sending on channel always succeeds");
Ok(TonicResponse::from(()))
});
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
let mut stream = client.subscribe("projects/p/subscriptions/s").build();
response_tx.send(Ok(test_response(0..10))).await?;
response_tx.send(Ok(test_response(10..20))).await?;
response_tx
.send(Err(TonicStatus::unavailable("GFE disconnect. try again")))
.await?;
drop(response_tx);
for i in 0..20 {
let (m, h) = stream
.next()
.await
.unwrap_or_else(|| panic!("expected message {}/20", i + 1))?;
assert_eq!(m.data, test_data(i));
h.ack();
}
let err = stream
.next()
.await
.transpose()
.expect_err("expected an error from stream");
assert!(err.status().is_some(), "{err:?}");
let status = err.status().unwrap();
assert_eq!(
status.code,
google_cloud_gax::error::rpc::Code::FailedPrecondition
);
assert_eq!(status.message, "fail");
stream.shutdown_token().shutdown().await;
// Acks may arrive in several requests; collect the IDs from all of them.
let mut got = Vec::new();
while let Ok(ack_req) = ack_rx.try_recv() {
assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
got.extend(ack_req.ack_ids);
}
assert_eq!(sorted(got), test_ids(0..20));
Ok(())
}
// The streaming pull request carries an `x-goog-request-params` metadata
// entry naming the subscription. The assertion runs inside the mock.
#[tokio_test_no_panics]
async fn routing_header() -> anyhow::Result<()> {
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull().return_once(move |request| {
let metadata = request.metadata();
assert_eq!(
metadata
.get("x-goog-request-params")
.expect("routing header missing"),
"subscription=projects/p/subscriptions/s"
);
Err(TonicStatus::failed_precondition("ignored"))
});
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
// The returned error is irrelevant; the call only triggers the mock.
let _ = client
.subscribe("projects/p/subscriptions/s")
.build()
.next()
.await;
Ok(())
}
// The `futures::Stream` returned by `into_stream()` yields the same messages
// and ends when the responses end. Acks sent through its handlers reach the
// server.
#[cfg(feature = "unstable-stream")]
#[tokio_test_no_panics(start_paused = true)]
async fn into_stream() -> anyhow::Result<()> {
use futures::TryStreamExt;
let (response_tx, response_rx) = channel(10);
let (ack_tx, mut ack_rx) = unbounded_channel();
let mut mock = MockSubscriber::new();
mock.expect_streaming_pull()
.return_once(|_| Ok(TonicResponse::from(response_rx)));
mock.expect_acknowledge().returning(move |r| {
ack_tx
.send(r.into_inner())
.expect("sending on channel always succeeds");
Ok(TonicResponse::from(()))
});
let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
let client = test_client(endpoint).await?;
let stream = client
.subscribe("projects/p/subscriptions/s")
.build()
.into_stream();
response_tx.send(Ok(test_response(1..3))).await?;
drop(response_tx);
// Ack each message and keep only its payload.
let got: Vec<_> = stream
.map_ok(|(m, h)| {
h.ack();
m.data
})
.try_collect()
.await?;
assert_eq!(got, vec![test_data(1), test_data(2)]);
let ack_req = ack_rx
.recv()
.await
.expect("should receive acknowledgements");
assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
assert_eq!(sorted(ack_req.ack_ids), test_ids(1..3));
Ok(())
}
// Holds one message without acking it and records when the last lease
// extension (a modify-ack-deadline request with a non-zero deadline) is
// observed. That instant must lie in the window
// `MAX_LEASE - MAX_LEASE_EXTENSION ..= MAX_LEASE` after the test started.
#[tokio_test_no_panics(start_paused = true)]
async fn basic_lease_expiration() -> anyhow::Result<()> {
    const MAX_LEASE_EXTENSION: Duration = Duration::from_secs(10);
    const MAX_LEASE: Duration = Duration::from_secs(30);
    let start_time = Instant::now();

    let (response_tx, response_rx) = channel(10);
    let (extend_tx, mut extend_rx) = unbounded_channel();

    let mut mock = MockSubscriber::new();
    mock.expect_streaming_pull()
        .return_once(|_| Ok(TonicResponse::from(response_rx)));
    // Forward every modify-ack-deadline request to the test body.
    mock.expect_modify_ack_deadline().returning(move |request| {
        let modify = request.into_inner();
        extend_tx
            .send(modify)
            .expect("sending on channel always succeeds");
        Ok(TonicResponse::from(()))
    });

    let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
    let client = test_client(endpoint).await?;
    let mut stream = client
        .subscribe("projects/p/subscriptions/s")
        .set_max_lease(MAX_LEASE)
        .set_max_lease_extension(MAX_LEASE_EXTENSION)
        .set_shutdown_behavior(ShutdownBehavior::NackImmediately)
        .build();

    response_tx.send(Ok(test_response(0..1))).await?;
    drop(response_tx);

    // Keep the message and its handler alive for the rest of the test.
    let (_m, _h) = stream
        .next()
        .await
        .expect("stream should yield a message")?;

    // Step the paused clock one second at a time for twice the maximum lease,
    // draining any requests seen so far before each step.
    let mut latest = None;
    let ticks = MAX_LEASE.as_secs() * 2;
    for _ in 0..ticks {
        while let Ok(extension) = extend_rx.try_recv() {
            assert_ne!(
                extension.ack_deadline_seconds, 0,
                "unexpectedly received a nack"
            );
            latest = Some(start_time.elapsed());
        }
        tokio::time::advance(Duration::from_secs(1)).await;
        tokio::task::yield_now().await;
    }

    let earliest_allowed = MAX_LEASE - MAX_LEASE_EXTENSION;
    let in_window = latest.is_some_and(|t| (earliest_allowed..=MAX_LEASE).contains(&t));
    assert!(in_window, "{latest:?}");

    stream.shutdown_token().shutdown().await;
    Ok(())
}
// With `ShutdownBehavior::WaitForProcessing`, shutdown is awaited while one
// handler is still outstanding. A background task acks that handler after a
// 5 second sleep, and the mock expects exactly one acknowledge call.
#[tokio_test_no_panics(start_paused = true)]
async fn shutdown_wait_for_processing() -> anyhow::Result<()> {
    let (response_tx, response_rx) = channel(10);

    let mut mock = MockSubscriber::new();
    mock.expect_streaming_pull()
        .return_once(|_| Ok(TonicResponse::from(response_rx)));
    mock.expect_acknowledge()
        .times(1)
        .returning(|_| Ok(TonicResponse::from(())));
    mock.expect_modify_ack_deadline()
        .returning(|_| Ok(TonicResponse::from(())));

    let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
    let client = test_client(endpoint).await?;
    let mut stream = client
        .subscribe("projects/p/subscriptions/s")
        .set_shutdown_behavior(ShutdownBehavior::WaitForProcessing)
        .build();

    response_tx.send(Ok(test_response(0..1))).await?;
    drop(response_tx);

    let (_m, handler) = stream
        .next()
        .await
        .expect("stream should yield a message")?;

    // Ack later, from another task, while shutdown is already in progress.
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_secs(5)).await;
        handler.ack();
    });

    stream.shutdown_token().shutdown().await;
    Ok(())
}
// Interleaves regular and exactly-once responses on one stream and checks
// that each message comes back with the matching `Handler` variant, in order.
#[tokio_test_no_panics(start_paused = true)]
async fn at_least_once_and_exactly_once() -> anyhow::Result<()> {
    let (response_tx, response_rx) = channel(10);

    let mut mock = MockSubscriber::new();
    mock.expect_streaming_pull()
        .return_once(|_| Ok(TonicResponse::from(response_rx)));
    mock.expect_modify_ack_deadline()
        .returning(|_| Ok(TonicResponse::from(())));

    let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
    let client = test_client(endpoint).await?;
    let mut stream = client
        .subscribe("projects/p/subscriptions/s")
        .set_shutdown_behavior(ShutdownBehavior::NackImmediately)
        .build();

    // Alternate the two response kinds: ids 0 and 2 are regular, 1 and 3 are
    // exactly-once.
    let responses = [
        test_response(0..1),
        test_exactly_once_response(1..2),
        test_response(2..3),
        test_exactly_once_response(3..4),
    ];
    for response in responses {
        response_tx.send(Ok(response)).await?;
    }
    drop(response_tx);

    // Message 0: regular delivery.
    let (m, h) = stream.next().await.expect("should yield a message")?;
    assert_eq!(m.data, test_data(0));
    assert_eq!(h.ack_id(), test_id(0));
    assert!(matches!(h, Handler::AtLeastOnce(_)), "{h:?}");

    // Message 1: exactly-once delivery.
    let (m, h) = stream.next().await.expect("should yield a message")?;
    assert_eq!(m.data, test_data(1));
    assert_eq!(h.ack_id(), test_id(1));
    assert!(matches!(h, Handler::ExactlyOnce(_)), "{h:?}");

    // Message 2: regular delivery.
    let (m, h) = stream.next().await.expect("should yield a message")?;
    assert_eq!(m.data, test_data(2));
    assert_eq!(h.ack_id(), test_id(2));
    assert!(matches!(h, Handler::AtLeastOnce(_)), "{h:?}");

    // Message 3: exactly-once delivery.
    let (m, h) = stream.next().await.expect("should yield a message")?;
    assert_eq!(m.data, test_data(3));
    assert_eq!(h.ack_id(), test_id(3));
    assert!(matches!(h, Handler::ExactlyOnce(_)), "{h:?}");

    // The sender was dropped, so nothing else may be yielded.
    let end = stream.next().await.transpose()?;
    assert!(end.is_none(), "Received extra message: {end:?}");

    stream.shutdown_token().shutdown().await;
    Ok(())
}
// The mock fails every streaming pull with `unavailable("try again")`.
// Shutting down while a `next()` call is pending on another task must make
// that call resolve to `None`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn cancel_before_open() -> anyhow::Result<()> {
    let mut mock = MockSubscriber::new();
    mock.expect_streaming_pull()
        .returning(|_| Err(TonicStatus::unavailable("try again")));

    let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
    let client = test_client(endpoint).await?;
    let mut stream = client.subscribe("projects/p/subscriptions/s").build();

    // Grab the token first; the stream itself moves into the spawned task.
    let token = stream.shutdown_token();
    let pending = tokio::spawn(async move { stream.next().await });

    token.shutdown().await;

    let end = pending.await?;
    assert!(end.is_none(), "Shutdown should end the stream, got {end:?}");
    Ok(())
}
// Receives nine messages, acks the first five, then starts a shutdown on a
// separate task. The next `next()` call must return `None`, the five acked
// ids must be acknowledged, and the remaining four must be nacked (deadline
// of zero).
#[tokio_test_no_panics(start_paused = true)]
async fn cancel_midstream() -> anyhow::Result<()> {
    let (response_tx, response_rx) = channel(10);
    let (ack_tx, mut ack_rx) = unbounded_channel();
    let (nack_tx, mut nack_rx) = unbounded_channel();

    let mut mock = MockSubscriber::new();
    mock.expect_streaming_pull()
        .return_once(|_| Ok(TonicResponse::from(response_rx)));
    // Exactly one acknowledge and one modify-ack-deadline call are expected;
    // both are forwarded to the test body.
    mock.expect_acknowledge().times(1).returning(move |request| {
        let ack = request.into_inner();
        ack_tx
            .send(ack)
            .expect("sending on channel always succeeds");
        Ok(TonicResponse::from(()))
    });
    mock.expect_modify_ack_deadline()
        .times(1)
        .returning(move |request| {
            let modify = request.into_inner();
            nack_tx
                .send(modify)
                .expect("sending on channel always succeeds");
            Ok(TonicResponse::from(()))
        });

    let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
    let client = test_client(endpoint).await?;
    let mut stream = client
        .subscribe("projects/p/subscriptions/s")
        .set_shutdown_behavior(ShutdownBehavior::WaitForProcessing)
        .build();
    let token = stream.shutdown_token();

    response_tx.send(Ok(test_response(1..10))).await?;

    // Consume and ack messages 1 through 5 only.
    for i in 1..6 {
        let (m, h) = match stream.next().await.transpose()? {
            Some(delivery) => delivery,
            None => anyhow::bail!("expected message {i}/5"),
        };
        assert_eq!(m.data, test_data(i));
        h.ack();
    }

    let shutdown = tokio::spawn(async move {
        token.shutdown().await;
    });
    // Give the spawned shutdown task a chance to run before polling again.
    tokio::task::yield_now().await;

    let end = stream.next().await.transpose()?;
    assert!(end.is_none(), "Shutdown should end the stream, got {end:?}");
    shutdown.await?;

    let ack_req = ack_rx.try_recv()?;
    assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
    assert_eq!(sorted(ack_req.ack_ids), test_ids(1..6));

    let nack_req = nack_rx.try_recv()?;
    assert_eq!(nack_req.subscription, "projects/p/subscriptions/s");
    assert_eq!(nack_req.ack_deadline_seconds, 0);
    assert_eq!(sorted(nack_req.ack_ids), test_ids(6..10));
    Ok(())
}
// Receives nine messages, acks the first five, then awaits shutdown without
// polling the stream again. For both shutdown behaviors the five acked ids
// must be acknowledged and the remaining four nacked (deadline of zero).
#[test_case(ShutdownBehavior::NackImmediately)]
#[test_case(ShutdownBehavior::WaitForProcessing)]
#[tokio_test_no_panics(start_paused = true)]
async fn shutdown_without_next(shutdown_behavior: ShutdownBehavior) -> anyhow::Result<()> {
    let (response_tx, response_rx) = channel(10);
    let (ack_tx, mut ack_rx) = unbounded_channel();
    let (nack_tx, mut nack_rx) = unbounded_channel();

    let mut mock = MockSubscriber::new();
    mock.expect_streaming_pull()
        .return_once(|_| Ok(TonicResponse::from(response_rx)));
    // Exactly one acknowledge and one modify-ack-deadline call are expected;
    // both are forwarded to the test body.
    mock.expect_acknowledge().times(1).returning(move |request| {
        let ack = request.into_inner();
        ack_tx
            .send(ack)
            .expect("sending on channel always succeeds");
        Ok(TonicResponse::from(()))
    });
    mock.expect_modify_ack_deadline()
        .times(1)
        .returning(move |request| {
            let modify = request.into_inner();
            nack_tx
                .send(modify)
                .expect("sending on channel always succeeds");
            Ok(TonicResponse::from(()))
        });

    let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
    let client = test_client(endpoint).await?;
    let mut stream = client
        .subscribe("projects/p/subscriptions/s")
        .set_shutdown_behavior(shutdown_behavior)
        .build();
    let token = stream.shutdown_token();

    response_tx.send(Ok(test_response(1..10))).await?;

    // Consume and ack messages 1 through 5 only.
    for i in 1..6 {
        let (m, h) = match stream.next().await.transpose()? {
            Some(delivery) => delivery,
            None => anyhow::bail!("expected message {i}/5"),
        };
        assert_eq!(m.data, test_data(i));
        h.ack();
    }

    token.shutdown().await;

    let ack_req = ack_rx.try_recv()?;
    assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
    assert_eq!(sorted(ack_req.ack_ids), test_ids(1..6));

    let nack_req = nack_rx.try_recv()?;
    assert_eq!(nack_req.subscription, "projects/p/subscriptions/s");
    assert_eq!(nack_req.ack_deadline_seconds, 0);
    assert_eq!(sorted(nack_req.ack_ids), test_ids(6..10));
    Ok(())
}
// The server sends one message followed by a `failed_precondition` status.
// After the error is yielded, `wait_for_shutdown()` must complete without the
// test ever calling `shutdown()`, and the ack for the handled message must
// have been sent.
#[test_case(ShutdownBehavior::NackImmediately)]
#[test_case(ShutdownBehavior::WaitForProcessing)]
#[tokio_test_no_panics(start_paused = true)]
async fn stream_error_initiates_shutdown(
    shutdown_behavior: ShutdownBehavior,
) -> anyhow::Result<()> {
    let (response_tx, response_rx) = channel(10);
    let (ack_tx, mut ack_rx) = unbounded_channel();

    let mut mock = MockSubscriber::new();
    mock.expect_streaming_pull()
        .return_once(|_| Ok(TonicResponse::from(response_rx)));
    // Exactly one acknowledge call is expected; forward it to the test body.
    mock.expect_acknowledge().times(1).returning(move |request| {
        let ack = request.into_inner();
        ack_tx
            .send(ack)
            .expect("sending on channel always succeeds");
        Ok(TonicResponse::from(()))
    });

    let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
    let client = test_client(endpoint).await?;
    let mut stream = client
        .subscribe("projects/p/subscriptions/s")
        .set_shutdown_behavior(shutdown_behavior)
        .build();
    let token = stream.shutdown_token();

    // Queue a message, then an error, then close the sending side.
    response_tx.send(Ok(test_response(0..1))).await?;
    response_tx
        .send(Err(TonicStatus::failed_precondition("fail")))
        .await?;
    drop(response_tx);

    let (m, h) = stream.next().await.expect("should yield a message")?;
    assert_eq!(m.data, test_data(0));
    h.ack();

    let err = stream.next().await.expect("should yield an error");
    assert!(err.is_err(), "{err:?}");

    // Only waits; nothing in this test requests the shutdown explicitly.
    token.wait_for_shutdown().await;

    let ack_req = ack_rx.try_recv()?;
    assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
    assert_eq!(ack_req.ack_ids, test_ids(0..1));
    Ok(())
}
// Receives nine messages, acks the first five, then drops the stream. After
// the drop, `wait_for_shutdown()` must complete, the five acked ids must be
// acknowledged, and the remaining four nacked (deadline of zero).
#[test_case(ShutdownBehavior::NackImmediately)]
#[test_case(ShutdownBehavior::WaitForProcessing)]
#[tokio_test_no_panics(start_paused = true)]
async fn drop_cancels(shutdown_behavior: ShutdownBehavior) -> anyhow::Result<()> {
    let (response_tx, response_rx) = channel(10);
    let (ack_tx, mut ack_rx) = unbounded_channel();
    let (nack_tx, mut nack_rx) = unbounded_channel();

    let mut mock = MockSubscriber::new();
    mock.expect_streaming_pull()
        .return_once(|_| Ok(TonicResponse::from(response_rx)));
    // Exactly one acknowledge and one modify-ack-deadline call are expected;
    // both are forwarded to the test body.
    mock.expect_acknowledge().times(1).returning(move |request| {
        let ack = request.into_inner();
        ack_tx
            .send(ack)
            .expect("sending on channel always succeeds");
        Ok(TonicResponse::from(()))
    });
    mock.expect_modify_ack_deadline()
        .times(1)
        .returning(move |request| {
            let modify = request.into_inner();
            nack_tx
                .send(modify)
                .expect("sending on channel always succeeds");
            Ok(TonicResponse::from(()))
        });

    let (endpoint, _server) = start("0.0.0.0:0", mock).await?;
    let client = test_client(endpoint).await?;
    let mut stream = client
        .subscribe("projects/p/subscriptions/s")
        .set_shutdown_behavior(shutdown_behavior)
        .build();
    let token = stream.shutdown_token();

    response_tx.send(Ok(test_response(1..10))).await?;

    // Consume and ack messages 1 through 5 only.
    for i in 1..6 {
        let (m, h) = match stream.next().await.transpose()? {
            Some(delivery) => delivery,
            None => anyhow::bail!("expected message {i}/5"),
        };
        assert_eq!(m.data, test_data(i));
        h.ack();
    }

    // No explicit shutdown request: the drop is the only trigger in this test.
    drop(stream);
    token.wait_for_shutdown().await;

    let ack_req = ack_rx.try_recv()?;
    assert_eq!(ack_req.subscription, "projects/p/subscriptions/s");
    assert_eq!(sorted(ack_req.ack_ids), test_ids(1..6));

    let nack_req = nack_rx.try_recv()?;
    assert_eq!(nack_req.subscription, "projects/p/subscriptions/s");
    assert_eq!(nack_req.ack_deadline_seconds, 0);
    assert_eq!(sorted(nack_req.ack_ids), test_ids(6..10));
    Ok(())
}
}