use google_cloud_gax::grpc::{Code, Status};
use google_cloud_gax::retry::RetrySetting;
use google_cloud_googleapis::pubsub::v1::{
AcknowledgeRequest, ModifyAckDeadlineRequest, PubsubMessage, StreamingPullRequest,
};
use std::ops::{Deref, DerefMut};
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;
use crate::apiv1::default_retry_setting as base_retry_setting;
use crate::apiv1::subscriber_client::{create_empty_streaming_pull_request, SubscriberClient};
/// A message pulled from a subscription, bundled with what is needed to
/// acknowledge, nack, or extend the ack deadline of this particular delivery.
#[derive(Debug, Clone)]
pub struct ReceivedMessage {
    /// The Pub/Sub message itself.
    pub message: PubsubMessage,
    /// Ack id of this delivery; used by `ack`, `nack` and `modify_ack_deadline`.
    pub ack_id: String,
    /// Subscription the message was pulled from.
    pub subscription: String,
    /// Client used to send the acknowledgement requests for this message.
    pub subscriber_client: SubscriberClient,
    /// Delivery attempt count. The streaming-pull loop in this module only sets it
    /// when the server reports a positive count.
    pub delivery_attempt: Option<usize>,
}
impl ReceivedMessage {
    /// Bundles a pulled message with the subscription and client needed to acknowledge it.
    pub(crate) fn new(
        subscription: String,
        subc: SubscriberClient,
        message: PubsubMessage,
        ack_id: String,
        delivery_attempt: Option<usize>,
    ) -> Self {
        Self {
            subscription,
            subscriber_client: subc,
            message,
            ack_id,
            delivery_attempt,
        }
    }

    /// Returns the ack id of this delivery.
    pub fn ack_id(&self) -> &str {
        &self.ack_id
    }

    /// Acknowledges this message on its subscription.
    ///
    /// # Errors
    /// Returns the gRPC `Status` when the acknowledge request fails.
    pub async fn ack(&self) -> Result<(), Status> {
        let ids = vec![self.ack_id.clone()];
        ack(&self.subscriber_client, self.subscription.clone(), ids).await
    }

    /// Nacks this message (its ack deadline is set to zero).
    ///
    /// # Errors
    /// Returns the gRPC `Status` when the request fails.
    pub async fn nack(&self) -> Result<(), Status> {
        let ids = vec![self.ack_id.clone()];
        nack(&self.subscriber_client, self.subscription.clone(), ids).await
    }

    /// Sets the ack deadline of this message to `ack_deadline_seconds`.
    ///
    /// # Errors
    /// Returns the gRPC `Status` when the request fails.
    pub async fn modify_ack_deadline(&self, ack_deadline_seconds: i32) -> Result<(), Status> {
        let ids = vec![self.ack_id.clone()];
        modify_ack_deadline(
            &self.subscriber_client,
            self.subscription.clone(),
            ids,
            ack_deadline_seconds,
        )
        .await
    }

    /// Returns the delivery attempt count, if one was recorded.
    pub fn delivery_attempt(&self) -> Option<usize> {
        self.delivery_attempt
    }
}
/// Retry setting used by subscribers: the crate-wide default retry setting with
/// `Code::Cancelled` appended to its retryable codes.
fn default_retry_setting() -> RetrySetting {
    let mut setting = base_retry_setting();
    setting.codes.extend([Code::Cancelled]);
    setting
}
/// Settings for a streaming-pull subscriber.
#[derive(Debug, Clone)]
pub struct SubscriberConfig {
    /// Passed to `SubscriberClient::streaming_pull` together with each request.
    pub ping_interval: Duration,
    /// Retry setting passed to `streaming_pull`. Its `codes` also decide which stream
    /// errors make the subscriber reconnect. When `None`, reconnect decisions fall back
    /// to this module's default retry codes.
    pub retry_setting: Option<RetrySetting>,
    /// Copied into `StreamingPullRequest::stream_ack_deadline_seconds`.
    pub stream_ack_deadline_seconds: i32,
    /// Copied into `StreamingPullRequest::max_outstanding_messages`.
    pub max_outstanding_messages: i64,
    /// Copied into `StreamingPullRequest::max_outstanding_bytes`.
    pub max_outstanding_bytes: i64,
}
impl Default for SubscriberConfig {
    /// The defaults are:
    /// - a 10 s ping interval;
    /// - this module's default retry setting;
    /// - a 60 s stream ack deadline;
    /// - at most 50 messages / 1_000_000_000 bytes outstanding.
    fn default() -> Self {
        const MAX_OUTSTANDING_BYTES: i64 = 1_000_000_000;
        Self {
            retry_setting: Some(default_retry_setting()),
            ping_interval: Duration::from_secs(10),
            max_outstanding_messages: 50,
            max_outstanding_bytes: MAX_OUTSTANDING_BYTES,
            stream_ack_deadline_seconds: 60,
        }
    }
}
/// Ack ids of messages that were pulled but not yet handed to the consumer queue.
///
/// Dereferences to the underlying `Vec<String>`. When dropped, the remaining ack ids
/// are sent through the oneshot channel, so the holder of the receiving end can nack them.
struct UnprocessedMessages {
    tx: Option<oneshot::Sender<Option<Vec<String>>>>,
    // A plain Vec (not Option) so Deref/DerefMut need no unwrap; Drop moves it out
    // with `std::mem::take`.
    ack_ids: Vec<String>,
}

impl UnprocessedMessages {
    /// Creates an empty set that reports to `tx` when dropped.
    fn new(tx: oneshot::Sender<Option<Vec<String>>>) -> Self {
        Self {
            tx: Some(tx),
            ack_ids: Vec::new(),
        }
    }
}

impl Deref for UnprocessedMessages {
    type Target = Vec<String>;
    fn deref(&self) -> &Self::Target {
        &self.ack_ids
    }
}

impl DerefMut for UnprocessedMessages {
    fn deref_mut(&mut self) -> &mut Vec<String> {
        &mut self.ack_ids
    }
}

impl Drop for UnprocessedMessages {
    fn drop(&mut self) {
        if let Some(tx) = self.tx.take() {
            // The receiver may already be gone; in that case there is nobody to report to.
            let _ = tx.send(Some(std::mem::take(&mut self.ack_ids)));
        }
    }
}
/// Handle to the background streaming-pull task of one subscription.
#[derive(Debug)]
pub(crate) struct Subscriber {
    /// Client used to nack undelivered messages in `dispose` / `Drop`.
    client: SubscriberClient,
    /// Subscription being pulled.
    subscription: String,
    /// Background receive task; `None` once taken and aborted by `dispose` or `Drop`.
    task_to_receive: Option<JoinHandle<()>>,
    /// Receives the ack ids that were pulled but never delivered to the queue; they
    /// are sent when the receive task's `UnprocessedMessages` is dropped.
    unprocessed_messages_receiver: Option<oneshot::Receiver<Option<Vec<String>>>>,
}
impl Drop for Subscriber {
fn drop(&mut self) {
if let Some(task) = self.task_to_receive.take() {
task.abort();
}
let rx = match self.unprocessed_messages_receiver.take() {
None => return,
Some(rx) => rx,
};
let subscription = self.subscription.clone();
let client = self.client.clone();
tracing::warn!(
"Subscriber is not disposed. Call dispose() to properly clean up resources. subscription={}",
&subscription
);
let task = async move {
if let Ok(Some(messages)) = rx.await {
if messages.is_empty() {
return;
}
tracing::debug!("nack {} unprocessed messages", messages.len());
if let Err(err) = nack(&client, subscription, messages).await {
tracing::error!("failed to nack message: {:?}", err);
}
}
};
let _forget = tokio::spawn(task);
}
}
impl Subscriber {
    /// Spawns a background task that streams messages from `subscription` into `queue`.
    ///
    /// The task reconnects when the stream fails with a status code listed in
    /// `config.retry_setting`, or in this module's default retry setting when that is
    /// `None`. It stops on any other error, when the server ends the stream, or when
    /// `queue` is closed.
    ///
    /// Ack ids of messages that were received but never handed to `queue` are reported
    /// back so that `dispose` (or `Drop`) can nack them.
    pub fn spawn(
        subscription: String,
        client: SubscriberClient,
        queue: mpsc::Sender<ReceivedMessage>,
        config: SubscriberConfig,
    ) -> Self {
        let subscription_clone = subscription.clone();
        let client_clone = client.clone();
        let (tx, rx) = oneshot::channel();
        let task_to_receive = async move {
            tracing::debug!("start subscriber: {}", subscription);
            let retryable_codes = match &config.retry_setting {
                Some(v) => v.codes.clone(),
                None => default_retry_setting().codes,
            };
            // Owned by this task. When the task ends or is aborted (which drops the
            // future), dropping it sends the still-undelivered ack ids through `tx`.
            let mut unprocessed_messages = UnprocessedMessages::new(tx);
            loop {
                let mut request = create_empty_streaming_pull_request();
                request.subscription = subscription.to_string();
                request.stream_ack_deadline_seconds = config.stream_ack_deadline_seconds;
                request.max_outstanding_messages = config.max_outstanding_messages;
                request.max_outstanding_bytes = config.max_outstanding_bytes;
                let response = Self::receive(
                    client.clone(),
                    request,
                    config.clone(),
                    queue.clone(),
                    &mut unprocessed_messages,
                )
                .await;
                if let Err(e) = response {
                    if retryable_codes.contains(&e.code()) {
                        tracing::trace!("refresh connection: subscriber will reconnect {:?} : {}", e, subscription);
                        continue;
                    } else {
                        tracing::error!("failed to receive message: subscriber will stop {:?} : {}", e, subscription);
                        break;
                    }
                } else {
                    tracing::debug!("stopped to receive message: {}", subscription);
                    break;
                }
            }
            tracing::debug!("stop subscriber: {}", subscription);
        };
        Self {
            client: client_clone,
            subscription: subscription_clone,
            task_to_receive: Some(tokio::spawn(task_to_receive)),
            unprocessed_messages_receiver: Some(rx),
        }
    }

    /// Runs one streaming-pull session and forwards every received message to `queue`.
    ///
    /// Returns `Ok(())` when the server ends the stream or when `queue` has been closed,
    /// and `Err` when opening or reading the stream fails.
    async fn receive(
        client: SubscriberClient,
        request: StreamingPullRequest,
        config: SubscriberConfig,
        queue: mpsc::Sender<ReceivedMessage>,
        unprocessed_messages: &mut Vec<String>,
    ) -> Result<(), Status> {
        let subscription = request.subscription.clone();
        let response = client
            .streaming_pull(request, config.ping_interval, config.retry_setting.clone())
            .await?;
        let mut stream = response.into_inner();
        loop {
            let messages = match stream.message().await? {
                Some(m) => m.received_messages,
                None => return Ok(()),
            };
            // Register every ack id of the batch before the first send. If the task is
            // aborted while a send is pending, all undelivered ids are then still tracked.
            let mut msgs = Vec::with_capacity(messages.len());
            for received_message in messages {
                if let Some(message) = received_message.message {
                    let id = message.message_id.clone();
                    tracing::trace!("message received: msg_id={id}");
                    let msg = ReceivedMessage::new(
                        subscription.clone(),
                        client.clone(),
                        message,
                        received_message.ack_id.clone(),
                        (received_message.delivery_attempt > 0).then_some(received_message.delivery_attempt as usize),
                    );
                    unprocessed_messages.push(msg.ack_id.clone());
                    msgs.push(msg);
                }
            }
            for msg in msgs {
                let ack_id = msg.ack_id.clone();
                if queue.send(msg).await.is_err() {
                    // The receiving half of `queue` is closed, so no further message can
                    // ever be consumed. Stop instead of pulling more batches that would
                    // only pile up in `unprocessed_messages`. The ids left there are
                    // nacked by `dispose` / `Drop`.
                    tracing::debug!("queue closed: subscriber will stop : {}", subscription);
                    return Ok(());
                }
                unprocessed_messages.retain(|e| *e != ack_id);
            }
        }
    }

    /// Stops the background task and nacks every message that was received but never
    /// delivered to the queue.
    ///
    /// Returns the number of nacked messages. It is `0` when there was nothing to nack
    /// or when the nack request failed (the failure is logged).
    pub async fn dispose(mut self) -> usize {
        if let Some(task) = self.task_to_receive.take() {
            task.abort();
        }
        let mut count = 0;
        let rx = match self.unprocessed_messages_receiver.take() {
            None => return count,
            Some(rx) => rx,
        };
        if let Ok(Some(messages)) = rx.await {
            if messages.is_empty() {
                return count;
            }
            let size = messages.len();
            tracing::debug!("nack {} unprocessed messages", size);
            let result = nack(&self.client, self.subscription.clone(), messages).await;
            match result {
                Ok(_) => count = size,
                Err(err) => tracing::error!("failed to nack message: {:?}", err),
            }
        }
        count
    }
}
/// Sets the ack deadline of `ack_ids` on `subscription` to `ack_deadline_seconds`.
///
/// An empty `ack_ids` is a no-op that sends no request.
async fn modify_ack_deadline(
    subscriber_client: &SubscriberClient,
    subscription: String,
    ack_ids: Vec<String>,
    ack_deadline_seconds: i32,
) -> Result<(), Status> {
    if ack_ids.is_empty() {
        return Ok(());
    }
    let request = ModifyAckDeadlineRequest {
        ack_ids,
        ack_deadline_seconds,
        subscription,
    };
    subscriber_client.modify_ack_deadline(request, None).await?;
    Ok(())
}
/// Nacks `ack_ids` on `subscription` by setting their ack deadline to zero.
///
/// Ids are sent in order, in batches of at most 100 per request. The first failing
/// batch stops the remaining ones, and its `Status` is returned.
pub(crate) async fn nack(
    subscriber_client: &SubscriberClient,
    subscription: String,
    ack_ids: Vec<String>,
) -> Result<(), Status> {
    const BATCH_SIZE: usize = 100;
    const NACK_DEADLINE_SECONDS: i32 = 0;
    for batch in ack_ids.chunks(BATCH_SIZE).map(<[String]>::to_vec) {
        modify_ack_deadline(
            subscriber_client,
            subscription.clone(),
            batch,
            NACK_DEADLINE_SECONDS,
        )
        .await?;
    }
    Ok(())
}
/// Acknowledges `ack_ids` on `subscription` in a single request.
///
/// An empty `ack_ids` is a no-op that sends no request.
pub(crate) async fn ack(
    subscriber_client: &SubscriberClient,
    subscription: String,
    ack_ids: Vec<String>,
) -> Result<(), Status> {
    if !ack_ids.is_empty() {
        let request = AcknowledgeRequest {
            ack_ids,
            subscription,
        };
        subscriber_client.acknowledge(request, None).await?;
    }
    Ok(())
}