use super::keepalive;
use super::retry_policy::StreamRetryPolicy;
use super::stub::{Stub, TonicStreaming};
use crate::RequestOptions;
use crate::google::pubsub::v1::{StreamingPullRequest, StreamingPullResponse};
use crate::{Error, Result};
use gaxi::grpc::tonic::Result as TonicResult;
use google_cloud_gax::backoff_policy::BackoffPolicy;
use google_cloud_gax::exponential_backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
use google_cloud_gax::retry_loop_internal::retry_loop;
use google_cloud_gax::retry_throttler::CircuitBreaker;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_util::sync::{CancellationToken, DropGuard};
/// First delay used by the default backoff policy when reconnecting the stream.
pub(super) const INITIAL_DELAY: Duration = Duration::from_millis(100);
/// Upper bound on any single delay produced by the default backoff policy
/// (60 seconds).
pub(super) const MAXIMUM_DELAY: Duration = Duration::from_millis(60_000);
/// An open `StreamingPull` RPC bundled with the guard for its keepalive task.
///
/// Struct fields are dropped in declaration order, so the keepalive guard is
/// released before the underlying gRPC stream is dropped.
#[derive(Debug)]
pub(super) struct Stream<T: Stub> {
    /// Cancels the `CancellationToken` shared with the keepalive task when
    /// dropped. Held only for that `Drop` side effect.
    _keepalive_guard: DropGuard,
    /// The response stream returned by the stub's `streaming_pull` call.
    pub(super) stream: <T as Stub>::Stream,
}
impl<T: Stub + 'static> TonicStreaming for Stream<T>
where
    <T as Stub>::Stream: TonicStreaming,
{
    /// Reads the next response by delegating to the wrapped gRPC stream.
    ///
    /// The keepalive guard plays no part in reading responses; it only has to
    /// stay alive as long as this value does.
    async fn next_message(&mut self) -> TonicResult<Option<StreamingPullResponse>> {
        let inner = &mut self.stream;
        inner.next_message().await
    }
}
impl<T: Stub> Stream<T> {
    /// Opens a `StreamingPull` stream, retrying failed attempts.
    ///
    /// Uses this module's default backoff policy, together with the default
    /// retry policy and retry throttler applied in `new_with_backoff`.
    pub(super) async fn new(inner: Arc<T>, initial_req: StreamingPullRequest) -> Result<Self> {
        let backoff = default_backoff_policy();
        Self::new_with_backoff(inner, initial_req, backoff).await
    }

    /// Like [`Stream::new`], but with a caller-supplied backoff policy.
    ///
    /// Every attempt calls `open_stream` with its own clone of the stub handle
    /// and of `initial_req`, so each retry starts from the same initial request.
    async fn new_with_backoff(
        inner: Arc<T>,
        initial_req: StreamingPullRequest,
        backoff: Arc<dyn BackoffPolicy>,
    ) -> Result<Self> {
        // The argument the retry loop passes to each attempt is ignored.
        let open_once = move |_| {
            let stub = Arc::clone(&inner);
            let request = initial_req.clone();
            async move { open_stream(stub, request).await }
        };
        let pause = async |delay| tokio::time::sleep(delay).await;
        // NOTE(review): `true` presumably marks the operation as idempotent so
        // that transient failures are eligible for retry — confirm against the
        // `retry_loop` signature.
        retry_loop(
            open_once,
            pause,
            true,
            default_retry_throttler(),
            default_retry_policy(),
            backoff,
        )
        .await
    }
}
/// Opens a single `StreamingPull` RPC and starts its keepalive task.
///
/// The initial request is queued on the request channel before the RPC is
/// started, so it is the first write the stub observes. A keepalive task is
/// then spawned with the request sender and a `CancellationToken`; the token is
/// cancelled when the returned [`Stream`] (which owns the token's drop guard)
/// is dropped.
///
/// # Errors
///
/// Returns an error if the initial request cannot be queued, or the error
/// produced by the stub's `streaming_pull` call. On the error path the drop
/// guard is dropped too, so the keepalive task is cancelled rather than left
/// running after the failed attempt.
async fn open_stream<T>(inner: Arc<T>, initial_req: StreamingPullRequest) -> Result<Stream<T>>
where
    T: Stub,
{
    let (request_tx, request_rx) = mpsc::channel(1);
    let request_params = format!("subscription={}", initial_req.subscription);
    // The channel has capacity for one message and `request_rx` is still in
    // scope, so this send completes without waiting.
    request_tx.send(initial_req).await.map_err(Error::io)?;
    let shutdown = CancellationToken::new();
    // Arm the guard *before* awaiting the RPC. If `streaming_pull` fails, or
    // this future is dropped mid-await, the guard is dropped and cancels the
    // token, instead of the token being dropped un-cancelled while the
    // keepalive task keeps running.
    let keepalive_guard = shutdown.clone().drop_guard();
    keepalive::spawn(request_tx, shutdown);
    let stream = inner
        .streaming_pull(&request_params, request_rx, RequestOptions::default())
        .await?
        .into_inner();
    Ok(Stream {
        _keepalive_guard: keepalive_guard,
        stream,
    })
}
/// Returns the retry policy used when (re)opening the stream.
///
/// The policy decides which errors from `open_stream` are retried.
fn default_retry_policy() -> Arc<StreamRetryPolicy> {
    let policy = StreamRetryPolicy;
    Arc::new(policy)
}
/// Returns the retry throttler used when (re)opening the stream.
///
/// NOTE(review): assuming `CircuitBreaker::new` takes
/// `(tokens, min_tokens, error_cost)`, an error cost of 0 means failures never
/// drain tokens, so this breaker never throttles retries — confirm against the
/// gax documentation.
fn default_retry_throttler() -> Arc<Mutex<CircuitBreaker>> {
    let breaker = CircuitBreaker::new(1000, 0, 0).expect("This is a valid configuration");
    Arc::new(Mutex::new(breaker))
}
/// Builds the exponential backoff policy used between reconnect attempts.
///
/// The policy starts at `INITIAL_DELAY`, uses a scaling factor of 4, and is
/// capped at `MAXIMUM_DELAY`.
fn default_backoff_policy() -> Arc<ExponentialBackoff> {
    let policy = ExponentialBackoffBuilder::new()
        .with_initial_delay(INITIAL_DELAY)
        .with_maximum_delay(MAXIMUM_DELAY)
        .with_scaling(4)
        .build()
        .expect("This is a valid configuration");
    Arc::new(policy)
}
// Unit tests for opening the streaming pull RPC, its keepalives, and the
// retry behavior of `Stream::new_with_backoff`.
#[cfg(test)]
mod tests {
    use super::super::keepalive::KEEPALIVE_PERIOD;
    use super::super::lease_state::tests::test_ids;
    use super::super::stub::tests::MockStub;
    use super::*;
    use crate::google::pubsub::v1::{ReceivedMessage, StreamingPullResponse};
    use gaxi::grpc::tonic::Response as TonicResponse;
    use google_cloud_gax::backoff_policy::BackoffPolicy;
    use google_cloud_gax::error::rpc::{Code, Status};
    use google_cloud_gax::retry_state::RetryState;
    use google_cloud_test_macros::tokio_test_no_panics;

    // Mock of the gax `BackoffPolicy` trait. It lets tests check the
    // `attempt_count` reported by the retry loop and return zero delays.
    mockall::mock! {
        #[derive(Debug)]
        BackoffPolicy {}
        impl BackoffPolicy for BackoffPolicy {
            fn on_failure(&self, state: &RetryState) -> Duration;
        }
    }

    // A service error with code `Unavailable`; the retry tests below expect it
    // to be retried.
    fn transient_error() -> Error {
        Error::service(
            Status::default()
                .set_code(Code::Unavailable)
                .set_message("try again"),
        )
    }

    // A service error with code `FailedPrecondition`; the retry tests below
    // expect it to end the retry loop.
    fn permanent_error() -> Error {
        Error::service(
            Status::default()
                .set_code(Code::FailedPrecondition)
                .set_message("fail"),
        )
    }

    // Builds a response with one `ReceivedMessage` per ack id from `test_ids`.
    fn test_response(range: std::ops::Range<i32>) -> StreamingPullResponse {
        StreamingPullResponse {
            received_messages: test_ids(range)
                .into_iter()
                .map(|ack_id| ReceivedMessage {
                    ack_id,
                    ..Default::default()
                })
                .collect(),
            ..Default::default()
        }
    }

    // The first request written on every stream in these tests.
    fn initial_request() -> StreamingPullRequest {
        StreamingPullRequest {
            subscription: "projects/my-project/subscriptions/my-subscription".to_string(),
            stream_ack_deadline_seconds: 10,
            ..Default::default()
        }
    }

    // The request the `keepalives` test expects the keepalive task to write:
    // an all-default `StreamingPullRequest`.
    fn keepalive_request() -> StreamingPullRequest {
        StreamingPullRequest::default()
    }

    // The stream yields every queued response in order, then `None` once the
    // server side (the response sender) is closed.
    #[tokio_test_no_panics(start_paused = true)]
    async fn success() -> anyhow::Result<()> {
        let (response_tx, response_rx) = mpsc::channel(10);
        let mut mock = MockStub::new();
        mock.expect_streaming_pull()
            .withf(|s, _, _| s == "subscription=projects/my-project/subscriptions/my-subscription")
            .times(1)
            .return_once(move |_s, _r, _o| Ok(TonicResponse::from(response_rx)));
        response_tx.send(Ok(test_response(1..10))).await?;
        response_tx.send(Ok(test_response(11..20))).await?;
        response_tx.send(Ok(test_response(21..30))).await?;
        // Closing the sender ends the response stream after the three messages.
        drop(response_tx);
        let mut stream = open_stream(Arc::new(mock), initial_request()).await?;
        assert_eq!(stream.next_message().await?, Some(test_response(1..10)));
        assert_eq!(stream.next_message().await?, Some(test_response(11..20)));
        assert_eq!(stream.next_message().await?, Some(test_response(21..30)));
        assert_eq!(stream.next_message().await?, None);
        Ok(())
    }

    // The initial request is written first, keepalive requests follow once a
    // keepalive period elapses, and dropping the stream stops the writes.
    #[tokio_test_no_panics(start_paused = true)]
    async fn keepalives() -> anyhow::Result<()> {
        let (response_tx, response_rx) = mpsc::channel(10);
        let (recover_writes_tx, mut recover_writes_rx) = mpsc::channel(1);
        let mut mock = MockStub::new();
        mock.expect_streaming_pull()
            .withf(|s, _, _| s == "subscription=projects/my-project/subscriptions/my-subscription")
            .times(1)
            .return_once(move |_s, mut request_rx, _o| {
                // Forward every client write to `recover_writes_rx` so the test
                // can observe the request stream. The task exits (dropping
                // `recover_writes_tx`) once the request sender is gone.
                tokio::spawn(async move {
                    while let Some(request) = request_rx.recv().await {
                        recover_writes_tx
                            .send(request)
                            .await
                            .expect("forwarding writes always succeeds");
                    }
                });
                Ok(TonicResponse::from(response_rx))
            });
        let mut stream = open_stream(Arc::new(mock), initial_request()).await?;
        assert_eq!(recover_writes_rx.recv().await, Some(initial_request()));
        // Time is paused; advancing by one period should trigger a keepalive.
        tokio::time::advance(KEEPALIVE_PERIOD).await;
        assert_eq!(recover_writes_rx.recv().await, Some(keepalive_request()));
        response_tx.send(Ok(test_response(1..10))).await?;
        assert_eq!(stream.next_message().await?, Some(test_response(1..10)));
        // Dropping the stream drops its keepalive guard. The test expects this to
        // release the request sender, which ends the forwarding task and closes
        // the channel.
        drop(stream);
        assert_eq!(recover_writes_rx.recv().await, None);
        Ok(())
    }

    // `open_stream` does not retry; a `streaming_pull` failure is returned as-is.
    #[tokio_test_no_panics]
    async fn error() -> anyhow::Result<()> {
        let mut mock = MockStub::new();
        mock.expect_streaming_pull()
            .withf(|s, _, _| s == "subscription=projects/my-project/subscriptions/my-subscription")
            .times(1)
            .return_once(|_, _, _| Err(Error::io("fail")));
        let err = open_stream(Arc::new(mock), initial_request())
            .await
            .expect_err("open_stream should fail");
        assert!(err.is_io(), "{err:?}");
        Ok(())
    }

    // 19 transient failures (attempts 1..=19) are retried, each followed by a
    // backoff query whose `attempt_count` matches the failed attempt, and the
    // 20th attempt succeeds. The shared `Sequence` enforces the strict
    // interleaving of stub calls and backoff queries.
    #[tokio_test_no_panics]
    async fn retry_then_success() -> anyhow::Result<()> {
        let mut seq = mockall::Sequence::new();
        let mut mock_stub = MockStub::new();
        let mut mock_backoff = MockBackoffPolicy::new();
        for attempt in 1..20 {
            mock_stub
                .expect_streaming_pull()
                .withf(|s, _, _| {
                    s == "subscription=projects/my-project/subscriptions/my-subscription"
                })
                .times(1)
                .in_sequence(&mut seq)
                .return_once(|_, _, _| Err(transient_error()));
            mock_backoff
                .expect_on_failure()
                .times(1)
                .withf(move |s| s.attempt_count == attempt)
                .in_sequence(&mut seq)
                .return_const(Duration::ZERO);
        }
        let (response_tx, response_rx) = mpsc::channel(10);
        response_tx.send(Ok(test_response(1..10))).await?;
        drop(response_tx);
        // The final attempt succeeds and returns a stream with one response.
        mock_stub
            .expect_streaming_pull()
            .withf(|s, _, _| s == "subscription=projects/my-project/subscriptions/my-subscription")
            .times(1)
            .in_sequence(&mut seq)
            .return_once(move |_s, _r, _o| Ok(TonicResponse::from(response_rx)));
        let mut stream = Stream::new_with_backoff(
            Arc::new(mock_stub),
            initial_request(),
            Arc::new(mock_backoff),
        )
        .await?;
        assert_eq!(stream.next_message().await?, Some(test_response(1..10)));
        assert_eq!(stream.next_message().await?, None);
        Ok(())
    }

    // 19 transient failures are retried, then a permanent error stops the loop
    // and is returned to the caller unchanged.
    #[tokio_test_no_panics]
    async fn retry_then_permanent_failure() -> anyhow::Result<()> {
        let mut seq = mockall::Sequence::new();
        let mut mock_stub = MockStub::new();
        let mut mock_backoff = MockBackoffPolicy::new();
        for attempt in 1..20 {
            mock_stub
                .expect_streaming_pull()
                .withf(|s, _, _| {
                    s == "subscription=projects/my-project/subscriptions/my-subscription"
                })
                .times(1)
                .in_sequence(&mut seq)
                .return_once(|_, _, _| Err(transient_error()));
            mock_backoff
                .expect_on_failure()
                .times(1)
                .withf(move |s| s.attempt_count == attempt)
                .in_sequence(&mut seq)
                .return_const(Duration::ZERO);
        }
        mock_stub
            .expect_streaming_pull()
            .withf(|s, _, _| s == "subscription=projects/my-project/subscriptions/my-subscription")
            .times(1)
            .in_sequence(&mut seq)
            .return_once(|_, _, _| Err(permanent_error()));
        // The backoff policy is expected to be queried once more after the
        // permanent error. NOTE(review): presumably `retry_loop` asks the backoff
        // policy before consulting the retry policy — confirm.
        mock_backoff
            .expect_on_failure()
            .times(1)
            .in_sequence(&mut seq)
            .return_const(Duration::ZERO);
        let err = Stream::new_with_backoff(
            Arc::new(mock_stub),
            initial_request(),
            Arc::new(mock_backoff),
        )
        .await
        .expect_err("opening stream should fail");
        assert!(err.status().is_some(), "{err:?}");
        let status = err.status().unwrap();
        assert_eq!(
            status.code,
            google_cloud_gax::error::rpc::Code::FailedPrecondition
        );
        assert_eq!(status.message, "fail");
        Ok(())
    }
}