use crate::{Headers, ValidatedMessage};
use async_trait::async_trait;
use futures_util::stream;
use pin_project::pin_project;
use std::{
borrow::Cow,
fmt::Display,
ops::Bound,
pin::Pin,
str::FromStr,
task::{Context, Poll},
time::{Duration, SystemTime},
};
use tracing::debug;
use uuid::Uuid;
use ya_gcp::{
grpc::{Body, BoxBody, Bytes, DefaultGrpcImpl, GrpcService, StdError},
pubsub,
};
use super::{
retry_policy, AcknowledgeError, BoxError, ModifyAcknowledgeError, PubSubError,
StreamSubscriptionConfig, TopicName,
};
/// The name of a Pub/Sub subscription, stored as a possibly-borrowed string.
///
/// The wrapped value is only the subscription-specific part of the name. When
/// the full resource name is built (see `into_project_subscription_name`), it
/// is embedded as `hedwig-{queue}-{subscription}` under a project.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SubscriptionName<'s>(Cow<'s, str>);
impl<'s> SubscriptionName<'s> {
pub fn new(subscription: impl Into<Cow<'s, str>>) -> Self {
Self(subscription.into())
}
pub fn with_cross_project(
project: impl Into<Cow<'s, str>>,
subscription: impl Into<Cow<'s, str>>,
) -> Self {
Self(format!("{}-{}", project.into(), subscription.into()).into())
}
fn into_project_subscription_name(
self,
project_name: impl Display,
queue_name: impl Display,
) -> pubsub::ProjectSubscriptionName {
pubsub::ProjectSubscriptionName::new(
project_name,
std::format_args!(
"hedwig-{queue}-{subscription}",
queue = queue_name,
subscription = self.0
),
)
}
}
/// A client for consuming from Pub/Sub.
///
/// Wraps a `pubsub::SubscriberClient` together with the project and queue
/// names that are used to format subscription and topic names.
#[derive(Debug, Clone)]
pub struct ConsumerClient<S = DefaultGrpcImpl> {
// the wrapped subscriber client used for all API calls
client: pubsub::SubscriberClient<S>,
// project name passed to `format_subscription` and `format_topic`
project: String,
// queue name embedded in formatted subscription names
queue: String,
}
impl<S> ConsumerClient<S> {
    /// Build a consumer client from an existing subscriber client, plus the
    /// project and queue names used when formatting resource names.
    pub fn from_client(
        client: pubsub::SubscriberClient<S>,
        project: String,
        queue: String,
    ) -> Self {
        Self {
            client,
            project,
            queue,
        }
    }

    /// The project name this client was built with.
    fn project(&self) -> &str {
        self.project.as_str()
    }

    /// The queue name this client was built with.
    fn queue(&self) -> &str {
        self.queue.as_str()
    }

    /// Expand a subscription name into its project-qualified form, using this
    /// client's project and queue.
    pub fn format_subscription(
        &self,
        subscription: SubscriptionName<'_>,
    ) -> pubsub::ProjectSubscriptionName {
        let (project, queue) = (self.project(), self.queue());
        subscription.into_project_subscription_name(project, queue)
    }

    /// Expand a topic name into its project-qualified form, using this
    /// client's project.
    pub fn format_topic(&self, topic: TopicName<'_>) -> pubsub::ProjectTopicName {
        let project = self.project();
        topic.into_project_topic_name(project)
    }

    /// Shared access to the wrapped subscriber client.
    pub fn inner(&self) -> &pubsub::SubscriberClient<S> {
        let Self { client, .. } = self;
        client
    }

    /// Exclusive access to the wrapped subscriber client.
    pub fn inner_mut(&mut self) -> &mut pubsub::SubscriberClient<S> {
        let Self { client, .. } = self;
        client
    }
}
impl<S> ConsumerClient<S>
where
    S: GrpcService<BoxBody>,
    S::Error: Into<StdError>,
    S::ResponseBody: Body<Data = Bytes> + Send + 'static,
    <S::ResponseBody as Body>::Error: Into<StdError> + Send,
{
    /// Create a subscription described by `config`.
    ///
    /// The configuration's subscription and topic names are formatted with
    /// this client's project and queue before the request is sent.
    ///
    /// # Errors
    ///
    /// Returns the error from the underlying `create_subscription` API call.
    pub async fn create_subscription(
        &mut self,
        config: SubscriptionConfig<'_>,
    ) -> Result<(), PubSubError> {
        let request = config.into_subscription(&*self);
        let api = self.client.raw_api_mut();
        api.create_subscription(request).await?;
        Ok(())
    }

    /// Delete the subscription with the given name.
    ///
    /// # Errors
    ///
    /// Returns the error from the underlying `delete_subscription` API call.
    pub async fn delete_subscription(
        &mut self,
        subscription: SubscriptionName<'_>,
    ) -> Result<(), PubSubError> {
        let mut request = pubsub::api::DeleteSubscriptionRequest::default();
        request.subscription = self.format_subscription(subscription).into();

        let api = self.client.raw_api_mut();
        api.delete_subscription(request).await?;
        Ok(())
    }

    /// Start streaming messages from the given subscription, returning the
    /// wrapped stream.
    pub fn stream_subscription(
        &mut self,
        subscription: SubscriptionName<'_>,
        stream_config: StreamSubscriptionConfig,
    ) -> PubSubStream<S>
    where
        S: Clone,
    {
        let full_name = self.format_subscription(subscription);
        let inner = self.client.stream_subscription(full_name, stream_config);
        PubSubStream(inner)
    }

    /// Issue a seek request on the given subscription with `timestamp` as the
    /// time target.
    ///
    /// # Errors
    ///
    /// Returns the error from the underlying `seek` API call.
    pub async fn seek(
        &mut self,
        subscription: SubscriptionName<'_>,
        timestamp: pubsub::api::Timestamp,
    ) -> Result<(), PubSubError> {
        let mut request = pubsub::api::SeekRequest::default();
        request.subscription = self.format_subscription(subscription).into();
        request.target = Some(pubsub::api::seek_request::Target::Time(timestamp));

        let api = self.client.raw_api_mut();
        api.seek(request).await?;
        Ok(())
    }
}
// Configuration used to create a subscription; `into_subscription` copies each
// field below into the same-named field of `pubsub::api::Subscription`.
//
// NOTE(review): `match_fields!` is defined elsewhere. Presumably it checks that
// the struct's fields plus the `@except` list together cover every field of
// `pubsub::api::Subscription` -- confirm against the macro definition.
match_fields! {
pubsub::api::Subscription =>
#[derive(Debug, Clone)]
pub struct SubscriptionConfig<'s> {
pub name: SubscriptionName<'s>,
pub topic: TopicName<'s>,
pub ack_deadline_seconds: u16,
pub retain_acked_messages: bool,
pub message_retention_duration: Option<pubsub::api::Duration>,
pub labels: std::collections::HashMap<String, String>,
pub enable_message_ordering: bool,
pub expiration_policy: Option<pubsub::api::ExpirationPolicy>,
pub filter: String,
pub dead_letter_policy: Option<pubsub::api::DeadLetterPolicy>,
pub retry_policy: Option<pubsub::api::RetryPolicy>,
@except:
// `into_subscription` sets these three explicitly (to `None`, `false`, `None`)
push_config,
detached,
topic_message_retention_duration,
// `into_subscription` never assigns these, so they keep whatever
// `pubsub::api::Subscription::default()` gives them
bigquery_config,
cloud_storage_config,
enable_exactly_once_delivery,
}
}
impl SubscriptionConfig<'_> {
fn into_subscription<C>(self, client: &ConsumerClient<C>) -> pubsub::api::Subscription {
let mut sub = pubsub::api::Subscription::default();
sub.name = client.format_subscription(self.name).into();
sub.topic = client.format_topic(self.topic).into();
sub.ack_deadline_seconds = self.ack_deadline_seconds.into();
sub.retain_acked_messages = self.retain_acked_messages;
sub.message_retention_duration = self.message_retention_duration;
sub.labels = self.labels;
sub.enable_message_ordering = self.enable_message_ordering;
sub.expiration_policy = self.expiration_policy;
sub.filter = self.filter;
sub.dead_letter_policy = self.dead_letter_policy;
sub.retry_policy = self.retry_policy;
sub.push_config = None; sub.detached = false; sub.topic_message_retention_duration = None;
sub
}
}
impl Default for SubscriptionConfig<'_> {
    /// An empty configuration: empty names and filter, a zero ack deadline,
    /// flags off, no labels, and no optional policies.
    fn default() -> Self {
        let name = SubscriptionName::new(String::new());
        let topic = TopicName::new(String::new());

        Self {
            name,
            topic,
            ack_deadline_seconds: 0,
            retain_acked_messages: false,
            message_retention_duration: None,
            labels: std::collections::HashMap::new(),
            enable_message_ordering: false,
            expiration_policy: None,
            filter: String::new(),
            dead_letter_policy: None,
            retry_policy: None,
        }
    }
}
/// A message of type `T` received from Pub/Sub, paired with the
/// `pubsub::AcknowledgeToken` used to acknowledge it.
#[cfg_attr(docsrs, doc(cfg(feature = "google")))]
pub type PubSubMessage<T> = crate::consumer::AcknowledgeableMessage<pubsub::AcknowledgeToken, T>;
/// Errors yielded by a `PubSubStream`: either a failure of the underlying
/// stream, or a received message whose hedwig attributes are missing or invalid.
#[derive(Debug, thiserror::Error)]
#[cfg_attr(docsrs, doc(cfg(feature = "google")))]
pub enum PubSubStreamError {
/// An error from the underlying Pub/Sub stream; displayed transparently.
#[error(transparent)]
Stream(#[from] PubSubError),
/// A required attribute was not present on the received message.
#[error("missing expected attribute: {key}")]
MissingAttribute {
/// The attribute key that was looked up and not found.
key: &'static str,
},
/// An attribute was present but its value could not be parsed.
#[error("invalid attribute value for {key}: {invalid_value}")]
InvalidAttribute {
/// The attribute key whose value was rejected.
key: &'static str,
/// The raw attribute value that failed to parse.
invalid_value: String,
/// The parse error that caused the rejection.
#[source]
source: BoxError,
},
}
// Implements hedwig's `AcknowledgeToken` trait for the Pub/Sub token type by
// forwarding each trait method to a same-named method on the token.
//
// NOTE(review): the `self.ack()`, `self.nack()` and `self.modify_deadline(..)`
// calls below presumably resolve to inherent methods on
// `pubsub::AcknowledgeToken` (inherent methods take precedence over trait
// methods). If those inherent methods did not exist, the calls would recurse
// into this impl -- confirm against ya_gcp.
#[async_trait]
impl crate::consumer::AcknowledgeToken for pubsub::AcknowledgeToken {
type AckError = AcknowledgeError;
type ModifyError = ModifyAcknowledgeError;
type NackError = AcknowledgeError;
// Acknowledge the message, consuming the token.
async fn ack(self) -> Result<(), Self::AckError> {
self.ack().await
}
// Negatively acknowledge the message, consuming the token.
async fn nack(self) -> Result<(), Self::NackError> {
self.nack().await
}
// Forward a deadline modification of `seconds` to the token.
async fn modify_deadline(&mut self, seconds: u32) -> Result<(), Self::ModifyError> {
self.modify_deadline(seconds).await
}
}
/// A stream of messages from a Pub/Sub subscription, as returned by
/// `ConsumerClient::stream_subscription`.
///
/// This is a thin wrapper around `pubsub::StreamSubscription<S, R>`; its
/// `Stream` implementation converts each received message into a hedwig
/// `ValidatedMessage`. `R` is the retry policy type, which can be replaced
/// with `with_retry_policy`.
#[pin_project]
#[cfg_attr(docsrs, doc(cfg(feature = "google")))]
pub struct PubSubStream<
S = DefaultGrpcImpl,
R = retry_policy::ExponentialBackoff<pubsub::PubSubRetryCheck>,
>(#[pin] pubsub::StreamSubscription<S, R>);
impl<S, OldR> PubSubStream<S, OldR> {
pub fn with_retry_policy<R>(self, retry_policy: R) -> PubSubStream<S, R>
where
R: retry_policy::RetryPolicy<(), PubSubError>,
{
PubSubStream(self.0.with_retry_policy(retry_policy))
}
}
impl<S, R> stream::Stream for PubSubStream<S, R>
where
S: GrpcService<BoxBody> + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<StdError>,
S::ResponseBody: Body<Data = Bytes> + Send + 'static,
<S::ResponseBody as Body>::Error: Into<StdError> + Send,
R: retry_policy::RetryPolicy<(), PubSubError> + Send + 'static,
R::RetryOp: Send + 'static,
<R::RetryOp as retry_policy::RetryOperation<(), PubSubError>>::Sleep: Send + 'static,
{
type Item = Result<PubSubMessage<ValidatedMessage>, PubSubStreamError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
self.project().0.poll_next(cx).map(|opt| {
opt.map(|res| {
let (ack_token, message) = res?;
Ok(PubSubMessage {
ack_token,
message: pubsub_to_hedwig(message)?,
})
})
})
}
}
// `PubSubStream` acts as its own consumer: the associated `Stream` type is the
// same `PubSubStream<S, R>`, and `stream()` hands back `self` unchanged. The
// bounds mirror those on the `stream::Stream` implementation for this type.
impl<S, R> crate::consumer::Consumer for PubSubStream<S, R>
where
S: GrpcService<BoxBody> + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<StdError>,
S::ResponseBody: Body<Data = Bytes> + Send + 'static,
<S::ResponseBody as Body>::Error: Into<StdError> + Send,
R: retry_policy::RetryPolicy<(), PubSubError> + Send + 'static,
R::RetryOp: Send + 'static,
<R::RetryOp as retry_policy::RetryOperation<(), PubSubError>>::Sleep: Send + 'static,
{
type AckToken = pubsub::AcknowledgeToken;
type Error = PubSubStreamError;
type Stream = PubSubStream<S, R>;
// Returns this stream itself; no conversion is needed.
fn stream(self) -> Self::Stream {
self
}
}
/// Key range covering every attribute name that starts with `"hedwig_"`.
// The backtick '`' (0x60) is the character immediately after the underscore '_'
// (0x5F) in ASCII. Using "hedwig`" as the excluded upper bound therefore makes
// the range contain exactly the strings that begin with "hedwig_" under the
// ordering used by `str`.
const HEDWIG_NAME_RANGE: (Bound<&str>, Bound<&str>) =
(Bound::Included("hedwig_"), Bound::Excluded("hedwig`"));
/// Convert a raw Pub/Sub message into a hedwig `ValidatedMessage`.
///
/// The hedwig id, timestamp and schema are taken out of the message attributes
/// and parsed. The publisher and format-version attributes must be present but
/// their values are discarded. Any remaining attribute whose key falls in
/// `HEDWIG_NAME_RANGE` is removed, and what is left becomes the message headers.
///
/// # Errors
///
/// * `PubSubStreamError::MissingAttribute` if a required attribute is absent.
/// * `PubSubStreamError::InvalidAttribute` if the id is not a valid UUID, or
///   the timestamp is not a `u64` millisecond count representable as a
///   `SystemTime`.
fn pubsub_to_hedwig(
    msg: pubsub::api::PubsubMessage,
) -> Result<ValidatedMessage, PubSubStreamError> {
    /// Remove `key` from `map` and parse its value.
    ///
    /// On failure `parse` hands back the rejected string along with the cause,
    /// so both can be reported in the resulting error.
    fn take_attr<F, T>(
        map: &mut Headers,
        key: &'static str,
        parse: F,
    ) -> Result<T, PubSubStreamError>
    where
        F: FnOnce(String) -> Result<T, (String, BoxError)>,
    {
        match map.remove(key) {
            None => Err(PubSubStreamError::MissingAttribute { key }),
            Some(value) => match parse(value) {
                Ok(parsed) => Ok(parsed),
                Err((invalid_value, source)) => Err(PubSubStreamError::InvalidAttribute {
                    key,
                    invalid_value,
                    source,
                }),
            },
        }
    }

    let mut headers = msg.attributes;

    // `match` rather than `map_err` so the raw string can be moved into the error.
    let id = take_attr(&mut headers, crate::HEDWIG_ID, |raw| {
        match Uuid::from_str(&raw) {
            Ok(uuid) => Ok(uuid),
            Err(err) => Err((raw, BoxError::from(err))),
        }
    })?;

    let timestamp = take_attr(&mut headers, crate::HEDWIG_MESSAGE_TIMESTAMP, |raw| {
        let millis_since_epoch = match u64::from_str(&raw) {
            Ok(millis) => millis,
            Err(err) => return Err((raw, BoxError::from(err))),
        };
        let since_epoch = Duration::from_millis(millis_since_epoch);
        match SystemTime::UNIX_EPOCH.checked_add(since_epoch) {
            Some(time) => Ok(time),
            None => Err((
                raw,
                BoxError::from(format!(
                    "time stamp {} is too large for SystemTime",
                    millis_since_epoch
                )),
            )),
        }
    })?;

    let schema = take_attr(&mut headers, crate::HEDWIG_SCHEMA, |schema| Ok(schema))?;

    // These two must be present, but their values are not used.
    for &required in [crate::HEDWIG_PUBLISHER, crate::HEDWIG_FORMAT_VERSION].iter() {
        take_attr(&mut headers, required, |_| Ok(()))?;
    }

    // Collect the leftover keys in the hedwig range first: the map cannot be
    // mutated while the range iterator borrows it.
    let leftover_keys: Vec<String> = headers
        .range::<str, _>(HEDWIG_NAME_RANGE)
        .map(|(key, _)| key.clone())
        .collect();
    for key in leftover_keys {
        debug!(message = "removing unknown hedwig attribute", key = &key[..]);
        headers.remove(&key);
    }

    Ok(ValidatedMessage::new(
        id, timestamp, schema, headers, msg.data,
    ))
}
// Unit tests for attribute parsing in `pubsub_to_hedwig` and for subscription
// name formatting.
#[cfg(test)]
mod test {
use super::*;
use crate::{
HEDWIG_FORMAT_VERSION, HEDWIG_ID, HEDWIG_MESSAGE_TIMESTAMP, HEDWIG_PUBLISHER, HEDWIG_SCHEMA,
};
use pubsub::api::PubsubMessage;
use std::collections::BTreeMap;
// Test-only wrapper that lets a `ValidatedMessage` be compared in `assert_eq!`
// by checking each accessor in turn.
// NOTE(review): presumably needed because `ValidatedMessage` itself does not
// implement `PartialEq` -- confirm.
#[derive(Debug, Clone)]
struct EqValidatedMessage(ValidatedMessage);
impl std::ops::Deref for EqValidatedMessage {
type Target = ValidatedMessage;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl PartialEq<ValidatedMessage> for EqValidatedMessage {
// Two messages are equal when uuid, timestamp, schema, headers and data all match.
fn eq(&self, other: &ValidatedMessage) -> bool {
self.uuid() == other.uuid()
&& self.timestamp() == other.timestamp()
&& self.schema() == other.schema()
&& self.headers() == other.headers()
&& self.data() == other.data()
}
}
// Builds a `BTreeMap<String, String>` from `key => value` pairs, calling
// `to_string()` on both sides so non-string values (e.g. a `Uuid` or an
// integer) can be written directly.
macro_rules! string_btree {
($($key:expr => $val:expr),* $(,)?) => {
{
// `mut` is unused when the macro is invoked with no pairs
#[allow(unused_mut)]
let mut map = BTreeMap::new();
$(
map.insert(($key).to_string(), ($val).to_string());
)*
map
}
}
}
// A message with all required hedwig attributes parses successfully, and only
// the user attributes remain as headers.
#[test]
fn headers_parsed() {
let user_attrs = string_btree! {
"aaa" => "aaa_value",
"zzz" => "zzz_value",
"some_longer_string" => "the value for the longer string",
};
let hedwig_attrs = string_btree! {
HEDWIG_ID => Uuid::nil(),
HEDWIG_MESSAGE_TIMESTAMP => 1000,
HEDWIG_SCHEMA => "my-test-schema",
HEDWIG_PUBLISHER => "my-test-publisher",
HEDWIG_FORMAT_VERSION => "1",
};
let data = "foobar";
let mut attributes = user_attrs.clone();
attributes.extend(hedwig_attrs);
let message = {
let mut m = PubsubMessage::default();
m.data = data.into();
m.attributes = attributes;
m.message_id = String::from("some_unique_id");
m.publish_time = Some(pubsub::api::Timestamp {
seconds: 15,
nanos: 42,
});
m.ordering_key = String::new();
m
};
let validated_message = pubsub_to_hedwig(message).unwrap();
// the timestamp attribute of 1000 is interpreted as milliseconds since the epoch
assert_eq!(
EqValidatedMessage(ValidatedMessage::new(
Uuid::nil(),
SystemTime::UNIX_EPOCH + Duration::from_millis(1000),
"my-test-schema",
user_attrs,
data
)),
validated_message
);
}
// Removing any single required hedwig attribute yields `MissingAttribute`
// naming exactly that attribute.
#[test]
fn headers_error_on_missing() {
let full_hedwig_attrs = string_btree! {
HEDWIG_ID => Uuid::nil(),
HEDWIG_MESSAGE_TIMESTAMP => 1000,
HEDWIG_SCHEMA => "my-test-schema",
HEDWIG_PUBLISHER => "my-test-publisher",
HEDWIG_FORMAT_VERSION => "1",
};
for &missing_header in [
HEDWIG_ID,
HEDWIG_MESSAGE_TIMESTAMP,
HEDWIG_SCHEMA,
HEDWIG_PUBLISHER,
HEDWIG_FORMAT_VERSION,
]
.iter()
{
let mut attributes = full_hedwig_attrs.clone();
attributes.remove(missing_header);
let res = pubsub_to_hedwig({
let mut m = PubsubMessage::default();
m.attributes = attributes;
m
});
match res {
Err(PubSubStreamError::MissingAttribute { key }) => assert_eq!(key, missing_header),
_ => panic!(
"result did not fail on missing attribute {}: {:?}",
missing_header, res
),
}
}
}
// Unknown attributes prefixed with "hedwig_" are stripped from the headers,
// while user attributes that merely start with "hedwig" (hyphenated or
// camelCase) are kept.
#[test]
fn forward_compat_headers_removed() {
let hedwig_attrs = string_btree! {
HEDWIG_ID => Uuid::nil(),
HEDWIG_MESSAGE_TIMESTAMP => 1000,
HEDWIG_SCHEMA => "my-test-schema",
HEDWIG_PUBLISHER => "my-test-publisher",
HEDWIG_FORMAT_VERSION => "1",
"hedwig_some_new_flag" => "boom!",
"hedwig_another_change_from_the_future" => "kablam!",
};
let user_attrs = string_btree! {
"abc" => "123",
"foo" => "bar",
"aaaaaaaaaaaaaaaaaaaaaaaaa" => "bbbbbbbbbbbbbbbbbbbb",
"hedwig-key-but-with-hyphens" => "assumes the restricted format always uses underscores",
"hedwigAsAPrefixToSomeString" => "camelCase",
};
let mut attributes = user_attrs.clone();
attributes.extend(hedwig_attrs);
let validated_message = pubsub_to_hedwig({
let mut m = PubsubMessage::default();
m.attributes = attributes;
m
})
.unwrap();
assert_eq!(&user_attrs, validated_message.headers());
}
// A cross-project subscription name is embedded, together with the queue, in
// the subscription part of the project-qualified name.
#[test]
fn project_subscription_name() {
let subscription_name =
SubscriptionName::with_cross_project("other_project", "my_subscription");
assert_eq!(
String::from(
subscription_name.into_project_subscription_name("my_project", "some_queue")
),
"projects/my_project/subscriptions/hedwig-some_queue-other_project-my_subscription"
);
}
}