use super::handler::AckResult;
use super::retry_policy::at_least_once_options;
use super::stub::Stub;
use crate::RequestOptions;
use crate::error::AckError;
use crate::model::{AcknowledgeRequest, ModifyAckDeadlineRequest};
use google_cloud_gax::exponential_backoff::ExponentialBackoff;
use google_cloud_gax::retry_loop_internal::retry_loop;
use google_cloud_gax::retry_policy::NeverRetry;
use google_cloud_gax::retry_throttler::CircuitBreaker;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::mpsc::UnboundedSender;
/// Operations used to manage the leases of a batch of messages, identified
/// by their ack IDs.
///
/// The plain variants (`ack`, `nack`, `extend`) return nothing to the caller.
/// The `confirmed_*` variants also return nothing directly; in the
/// `DefaultLeaser` implementation in this file their per-ID outcome is
/// delivered out of band through a channel of `ConfirmedAcks`.
#[async_trait::async_trait]
pub(super) trait Leaser {
/// Acknowledges the messages identified by `ack_ids`.
async fn ack(&self, ack_ids: Vec<String>);
/// Negatively acknowledges the messages identified by `ack_ids`.
///
/// `DefaultLeaser` implements this by setting the ack deadline to zero.
async fn nack(&self, ack_ids: Vec<String>);
/// Extends the ack deadline of the messages identified by `ack_ids`.
async fn extend(&self, ack_ids: Vec<String>);
/// Acknowledges the messages identified by `ack_ids`, reporting the outcome
/// for each ID.
async fn confirmed_ack(&self, ack_ids: Vec<String>);
/// Negatively acknowledges the messages identified by `ack_ids`, reporting
/// the outcome for each ID.
async fn confirmed_nack(&self, ack_ids: Vec<String>);
}
/// The outcome of a confirmed ack or nack request, keyed by ack ID.
///
/// `DefaultLeaser` sends one of these maps over its `confirmed_tx` channel
/// after each confirmed ack attempt and each confirmed nack.
pub(super) type ConfirmedAcks = HashMap<String, AckResult>;
/// A `Leaser` that forwards lease operations to a `Stub`.
///
/// Every request is addressed to `subscription` and sent with `options`.
pub(super) struct DefaultLeaser<T>
where
T: Stub + 'static,
{
// The stub that performs the `acknowledge` and `modify_ack_deadline` calls.
inner: Arc<T>,
// Receives the per-ID results of confirmed acks and confirmed nacks.
confirmed_tx: UnboundedSender<ConfirmedAcks>,
// Request options passed to every call; built by `at_least_once_options()`
// in `new()`.
options: RequestOptions,
// The subscription name set on every request.
subscription: String,
// The ack deadline, in seconds, requested by `extend()`.
ack_deadline_seconds: i32,
}
impl<T> Clone for DefaultLeaser<T>
where
T: Stub + 'static,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
confirmed_tx: self.confirmed_tx.clone(),
options: self.options.clone(),
subscription: self.subscription.clone(),
ack_deadline_seconds: self.ack_deadline_seconds,
}
}
}
impl<T> DefaultLeaser<T>
where
    T: Stub + 'static,
{
    /// Creates a leaser for `subscription` that sends its requests through
    /// `inner`.
    ///
    /// # Parameters
    /// * `inner` - the stub used to make the requests.
    /// * `confirmed_tx` - where the results of confirmed acks and nacks are
    ///   sent.
    /// * `subscription` - the subscription name set on every request.
    /// * `ack_deadline_seconds` - the deadline requested by `extend()`.
    /// * `grpc_subchannel_count` - forwarded to `at_least_once_options()` to
    ///   build the request options used for every call.
    pub(super) fn new(
        inner: Arc<T>,
        confirmed_tx: UnboundedSender<ConfirmedAcks>,
        subscription: String,
        ack_deadline_seconds: i32,
        grpc_subchannel_count: usize,
    ) -> Self {
        let options = at_least_once_options(grpc_subchannel_count);
        Self {
            inner,
            confirmed_tx,
            options,
            subscription,
            ack_deadline_seconds,
        }
    }
}
#[async_trait::async_trait]
impl<T> Leaser for DefaultLeaser<T>
where
    T: Stub + 'static,
{
    /// Sends a single `acknowledge` request for `ack_ids`.
    ///
    /// This is best effort: the result of the call is discarded.
    async fn ack(&self, ack_ids: Vec<String>) {
        let req = AcknowledgeRequest::new()
            .set_subscription(self.subscription.clone())
            .set_ack_ids(ack_ids);
        let _ = self.inner.acknowledge(req, self.options.clone()).await;
    }

    /// Sends a single `modify_ack_deadline` request that sets the deadline of
    /// `ack_ids` to zero.
    ///
    /// This is best effort: the result of the call is discarded.
    async fn nack(&self, ack_ids: Vec<String>) {
        let req = ModifyAckDeadlineRequest::new()
            .set_subscription(self.subscription.clone())
            .set_ack_ids(ack_ids)
            .set_ack_deadline_seconds(0);
        let _ = self
            .inner
            .modify_ack_deadline(req, self.options.clone())
            .await;
    }

    /// Sends a single `modify_ack_deadline` request that sets the deadline of
    /// `ack_ids` to the configured `ack_deadline_seconds`.
    ///
    /// This is best effort: the result of the call is discarded.
    async fn extend(&self, ack_ids: Vec<String>) {
        let req = ModifyAckDeadlineRequest::new()
            .set_subscription(self.subscription.clone())
            .set_ack_ids(ack_ids)
            .set_ack_deadline_seconds(self.ack_deadline_seconds);
        let _ = self
            .inner
            .modify_ack_deadline(req, self.options.clone())
            .await;
    }

    /// Acknowledges `ack_ids` inside a retry loop.
    ///
    /// Each attempt calls `confirmed_ack_attempt()`, which reports results on
    /// the `confirmed_tx` channel and returns the IDs that still need to be
    /// sent. An attempt with leftover IDs is reported to the retry loop as a
    /// failure. The final result of the loop is discarded.
    async fn confirmed_ack(&self, ack_ids: Vec<String>) {
        let leaser = self.clone();
        let mut ack_ids = ack_ids;
        let attempt = async move |_| {
            let ids = std::mem::take(&mut ack_ids);
            // Assign to the captured variable (do not shadow it), so that the
            // IDs that remain unconfirmed are the ones sent by the next
            // attempt. Shadowing here would leave the captured list empty
            // after the first attempt.
            ack_ids = leaser.confirmed_ack_attempt(ids).await;
            if ack_ids.is_empty() {
                Ok(())
            } else {
                Err(crate::Error::timeout("retry me"))
            }
        };
        let sleep = async |d| tokio::time::sleep(d).await;
        let _ = retry_loop(
            attempt,
            sleep,
            // NOTE(review): presumably the `idempotent` flag; confirm against
            // the `retry_loop` signature.
            true,
            retry_throttler(&self.options),
            retry_policy(),
            backoff_policy(),
        )
        .await;
    }

    /// Sends a single `modify_ack_deadline` request that sets the deadline of
    /// `ack_ids` to zero, and reports the outcome.
    ///
    /// Every ID in `ack_ids` receives the same result: success if the call
    /// succeeded, or an `AckError::Rpc` sharing the call's error otherwise.
    /// The results are sent on the `confirmed_tx` channel; a failure to send
    /// (the receiver is gone) is ignored.
    async fn confirmed_nack(&self, ack_ids: Vec<String>) {
        let req = ModifyAckDeadlineRequest::new()
            .set_subscription(self.subscription.clone())
            .set_ack_ids(ack_ids.clone())
            .set_ack_deadline_seconds(0);
        let response = self
            .inner
            .modify_ack_deadline(req, self.options.clone())
            .await;
        // Wrap the error in an `Arc` so it can be shared by all the IDs.
        let shared_result = response.map(|_| ()).map_err(Arc::new);
        let confirmed_acks = ack_ids
            .into_iter()
            .map(|id| {
                (
                    id,
                    shared_result
                        .clone()
                        .map_err(|source| AckError::Rpc { source }),
                )
            })
            .collect();
        let _ = self.confirmed_tx.send(confirmed_acks);
    }
}
/// Returns the retry policy used by the confirmed ack retry loop: never retry.
fn retry_policy() -> Arc<NeverRetry> {
    let policy = NeverRetry;
    Arc::new(policy)
}
/// Returns the backoff policy used by the confirmed ack retry loop: an
/// `ExponentialBackoff` with its default settings.
fn backoff_policy() -> Arc<ExponentialBackoff> {
    let backoff = ExponentialBackoff::default();
    Arc::new(backoff)
}
/// Returns the retry throttler used by the confirmed ack retry loop.
///
/// Uses the throttler configured in `options`, if any. Otherwise falls back
/// to a new `CircuitBreaker`.
fn retry_throttler(
    options: &RequestOptions,
) -> google_cloud_gax::retry_throttler::SharedRetryThrottler {
    if let Some(configured) = options.retry_throttler().clone() {
        return configured;
    }
    // Only build the fallback when nothing is configured.
    // NOTE(review): the meaning of the `(1000, 0, 0)` arguments is not visible
    // here; confirm against `CircuitBreaker::new`.
    let breaker = CircuitBreaker::new(1000, 0, 0).expect("This is a valid configuration");
    Arc::new(Mutex::new(breaker))
}
impl<T> DefaultLeaser<T>
where
T: Stub + 'static,
{
async fn confirmed_ack_attempt(&self, ack_ids: Vec<String>) -> Vec<String> {
let req = AcknowledgeRequest::new()
.set_subscription(self.subscription.clone())
.set_ack_ids(ack_ids.clone());
let response = self.inner.acknowledge(req, self.options.clone()).await;
let shared_result = response.map(|_| ()).map_err(Arc::new);
let confirmed_acks = ack_ids
.into_iter()
.map(|id| {
(
id,
shared_result
.clone()
.map_err(|source| AckError::Rpc { source }),
)
})
.collect();
let _ = self.confirmed_tx.send(confirmed_acks);
Vec::new()
}
}
#[cfg(test)]
pub(super) mod tests {
    use super::super::lease_state::tests::{sorted, test_ids};
    use super::super::retry_policy::tests::verify_policies;
    use super::super::stub::tests::MockStub;
    use super::*;
    use crate::{Error, Response};
    use google_cloud_gax::error::rpc::{Code, Status};
    use std::sync::Arc;
    use tokio::sync::Mutex;
    use tokio::sync::mpsc::unbounded_channel;

    /// The subscription name used by every test in this module.
    const SUBSCRIPTION: &str = "projects/my-project/subscriptions/my-subscription";

    mockall::mock! {
        #[derive(Debug)]
        pub(in super::super) Leaser {}
        #[async_trait::async_trait]
        impl Leaser for Leaser {
            async fn ack(&self, ack_ids: Vec<String>);
            async fn nack(&self, ack_ids: Vec<String>);
            async fn extend(&self, ack_ids: Vec<String>);
            async fn confirmed_ack(&self, ack_ids: Vec<String>);
            async fn confirmed_nack(&self, ack_ids: Vec<String>);
        }
    }

    // Lets a shared mock be used wherever a `Leaser` is expected.
    #[async_trait::async_trait]
    impl Leaser for Arc<MockLeaser> {
        async fn ack(&self, ack_ids: Vec<String>) {
            MockLeaser::ack(self, ack_ids).await
        }
        async fn nack(&self, ack_ids: Vec<String>) {
            MockLeaser::nack(self, ack_ids).await
        }
        async fn extend(&self, ack_ids: Vec<String>) {
            MockLeaser::extend(self, ack_ids).await
        }
        async fn confirmed_ack(&self, ack_ids: Vec<String>) {
            MockLeaser::confirmed_ack(self, ack_ids).await
        }
        async fn confirmed_nack(&self, ack_ids: Vec<String>) {
            MockLeaser::confirmed_nack(self, ack_ids).await
        }
    }

    // Lets a shared mock behind an async mutex be used wherever a `Leaser` is
    // expected. The lock is held for the duration of each call.
    #[async_trait::async_trait]
    impl Leaser for Arc<Mutex<MockLeaser>> {
        async fn ack(&self, ack_ids: Vec<String>) {
            self.lock().await.ack(ack_ids).await
        }
        async fn nack(&self, ack_ids: Vec<String>) {
            self.lock().await.nack(ack_ids).await
        }
        async fn extend(&self, ack_ids: Vec<String>) {
            self.lock().await.extend(ack_ids).await
        }
        async fn confirmed_ack(&self, ack_ids: Vec<String>) {
            self.lock().await.confirmed_ack(ack_ids).await
        }
        async fn confirmed_nack(&self, ack_ids: Vec<String>) {
            self.lock().await.confirmed_nack(ack_ids).await
        }
    }

    /// Builds a leaser for `SUBSCRIPTION` with a 10 second ack deadline.
    fn new_leaser(
        mock: MockStub,
        confirmed_tx: UnboundedSender<ConfirmedAcks>,
        grpc_subchannel_count: usize,
    ) -> DefaultLeaser<MockStub> {
        DefaultLeaser::new(
            Arc::new(mock),
            confirmed_tx,
            SUBSCRIPTION.to_string(),
            10,
            grpc_subchannel_count,
        )
    }

    /// The error returned by the mocks in the failure tests.
    fn failed_precondition() -> Error {
        Error::service(
            Status::default()
                .set_code(Code::FailedPrecondition)
                .set_message("fail"),
        )
    }

    /// Verifies the results cover exactly `test_ids(0..10)`.
    fn assert_covers_test_ids(confirmed_acks: &ConfirmedAcks) {
        let ack_ids: Vec<_> = confirmed_acks.keys().cloned().collect();
        assert_eq!(sorted(&ack_ids), test_ids(0..10));
    }

    /// Verifies every ack ID was confirmed successfully.
    fn assert_all_ok(confirmed_acks: &ConfirmedAcks) {
        for (ack_id, result) in confirmed_acks {
            assert!(
                result.is_ok(),
                "Expected success for {ack_id}, got {result:?}"
            );
        }
    }

    /// Verifies every ack ID failed with the error from
    /// `failed_precondition()`.
    fn assert_all_failed_precondition(confirmed_acks: &ConfirmedAcks) {
        for (ack_id, result) in confirmed_acks {
            match result {
                Err(AckError::Rpc { source, .. }) => {
                    let status = source.status().expect("RPC source should have a status");
                    assert_eq!(status.code, Code::FailedPrecondition);
                    assert_eq!(status.message, "fail");
                }
                _ => panic!("Expected RPC error for {ack_id}, got {result:?}"),
            }
        }
    }

    #[test]
    fn clone() {
        let (confirmed_tx, _confirmed_rx) = unbounded_channel();
        let leaser = new_leaser(MockStub::new(), confirmed_tx, 1_usize);
        let clone = leaser.clone();
        assert!(Arc::ptr_eq(&leaser.inner, &clone.inner));
        assert!(leaser.confirmed_tx.same_channel(&clone.confirmed_tx));
        assert_eq!(leaser.subscription, clone.subscription);
        assert_eq!(leaser.ack_deadline_seconds, clone.ack_deadline_seconds);
    }

    #[tokio::test]
    async fn ack() {
        let (confirmed_tx, _confirmed_rx) = unbounded_channel();
        let mut mock = MockStub::new();
        mock.expect_acknowledge().times(1).return_once(|r, o| {
            assert_eq!(r.subscription, SUBSCRIPTION);
            assert_eq!(r.ack_ids, test_ids(0..10));
            verify_policies(o, 16);
            Ok(Response::from(()))
        });
        let leaser = new_leaser(mock, confirmed_tx, 16_usize);
        leaser.ack(test_ids(0..10)).await;
    }

    #[tokio::test]
    async fn nack() {
        let (confirmed_tx, _confirmed_rx) = unbounded_channel();
        let mut mock = MockStub::new();
        mock.expect_modify_ack_deadline()
            .times(1)
            .return_once(|r, o| {
                assert_eq!(r.ack_deadline_seconds, 0);
                assert_eq!(r.subscription, SUBSCRIPTION);
                assert_eq!(r.ack_ids, test_ids(0..10));
                verify_policies(o, 16);
                Ok(Response::from(()))
            });
        let leaser = new_leaser(mock, confirmed_tx, 16_usize);
        leaser.nack(test_ids(0..10)).await;
    }

    #[tokio::test]
    async fn extend() {
        let (confirmed_tx, _confirmed_rx) = unbounded_channel();
        let mut mock = MockStub::new();
        mock.expect_modify_ack_deadline()
            .times(1)
            .return_once(|r, o| {
                assert_eq!(r.ack_deadline_seconds, 10);
                assert_eq!(r.subscription, SUBSCRIPTION);
                assert_eq!(r.ack_ids, test_ids(0..10));
                verify_policies(o, 16);
                Ok(Response::from(()))
            });
        let leaser = new_leaser(mock, confirmed_tx, 16_usize);
        leaser.extend(test_ids(0..10)).await;
    }

    #[tokio::test]
    async fn confirmed_ack_success() -> anyhow::Result<()> {
        let (confirmed_tx, mut confirmed_rx) = unbounded_channel();
        let mut mock = MockStub::new();
        mock.expect_acknowledge().times(1).return_once(|r, o| {
            assert_eq!(r.subscription, SUBSCRIPTION);
            assert_eq!(r.ack_ids, test_ids(0..10));
            verify_policies(o, 16);
            Ok(Response::from(()))
        });
        let leaser = new_leaser(mock, confirmed_tx, 16_usize);
        leaser.confirmed_ack(test_ids(0..10)).await;
        let confirmed_acks = confirmed_rx.recv().await.expect("results were not sent");
        assert_covers_test_ids(&confirmed_acks);
        assert_all_ok(&confirmed_acks);
        Ok(())
    }

    #[tokio::test]
    async fn confirmed_ack_failure() -> anyhow::Result<()> {
        let (confirmed_tx, mut confirmed_rx) = unbounded_channel();
        let mut mock = MockStub::new();
        mock.expect_acknowledge().times(1).return_once(|r, o| {
            assert_eq!(r.subscription, SUBSCRIPTION);
            assert_eq!(r.ack_ids, test_ids(0..10));
            verify_policies(o, 16);
            Err(failed_precondition())
        });
        let leaser = new_leaser(mock, confirmed_tx, 16_usize);
        leaser.confirmed_ack(test_ids(0..10)).await;
        let confirmed_acks = confirmed_rx.recv().await.expect("results were not sent");
        assert_covers_test_ids(&confirmed_acks);
        assert_all_failed_precondition(&confirmed_acks);
        Ok(())
    }

    #[tokio::test]
    async fn confirmed_nack_success() -> anyhow::Result<()> {
        let (confirmed_tx, mut confirmed_rx) = unbounded_channel();
        let mut mock = MockStub::new();
        mock.expect_modify_ack_deadline()
            .times(1)
            .return_once(|r, o| {
                assert_eq!(r.ack_deadline_seconds, 0);
                assert_eq!(r.subscription, SUBSCRIPTION);
                assert_eq!(r.ack_ids, test_ids(0..10));
                verify_policies(o, 16);
                Ok(Response::from(()))
            });
        let leaser = new_leaser(mock, confirmed_tx, 16_usize);
        leaser.confirmed_nack(test_ids(0..10)).await;
        let confirmed_acks = confirmed_rx.recv().await.expect("results were not sent");
        assert_covers_test_ids(&confirmed_acks);
        assert_all_ok(&confirmed_acks);
        Ok(())
    }

    #[tokio::test]
    async fn confirmed_nack_failure() -> anyhow::Result<()> {
        let (confirmed_tx, mut confirmed_rx) = unbounded_channel();
        let mut mock = MockStub::new();
        mock.expect_modify_ack_deadline()
            .times(1)
            .return_once(|r, o| {
                assert_eq!(r.ack_deadline_seconds, 0);
                assert_eq!(r.subscription, SUBSCRIPTION);
                assert_eq!(r.ack_ids, test_ids(0..10));
                verify_policies(o, 16);
                Err(failed_precondition())
            });
        let leaser = new_leaser(mock, confirmed_tx, 16_usize);
        leaser.confirmed_nack(test_ids(0..10)).await;
        let confirmed_acks = confirmed_rx.recv().await.expect("results were not sent");
        assert_covers_test_ids(&confirmed_acks);
        assert_all_failed_precondition(&confirmed_acks);
        Ok(())
    }
}