use std::time::SystemTime;
use uuid::Uuid;
use crate::{Headers, ValidatedMessage};
#[derive(thiserror::Error, Debug)]
#[error("unable to encode the protobuf payload")]
#[cfg_attr(docsrs, doc(cfg(feature = "prost")))]
pub struct ProstValidatorError(#[source] prost::EncodeError);
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
#[cfg_attr(docsrs, doc(cfg(feature = "prost")))]
pub enum ProstDecodeError<E: std::error::Error + 'static> {
#[error("invalid schema for decoded message type")]
InvalidSchema(#[source] E),
#[error(transparent)]
Decode(#[from] prost::DecodeError),
}
/// Private marker field that prevents [`ProstValidator`] from being built with a
/// struct literal outside this module; use `ProstValidator::new` or `Default`.
#[derive(Debug, Default, Clone, Copy)]
struct UseNewToConstruct;

/// Validator that encodes [`prost::Message`] values into [`ValidatedMessage`]s.
///
/// The validator carries no state (its only field is a unit marker). It therefore
/// implements `Debug`, `Clone` and `Copy`, so it can be embedded in other types
/// that derive those traits.
#[derive(Debug, Default, Clone, Copy)]
#[cfg_attr(docsrs, doc(cfg(feature = "prost")))]
pub struct ProstValidator(UseNewToConstruct);
impl ProstValidator {
    /// Creates a new validator.
    pub fn new() -> Self {
        ProstValidator(UseNewToConstruct)
    }

    /// Encodes `data` with `prost` and wraps the resulting bytes in a
    /// [`ValidatedMessage`].
    ///
    /// `id`, `timestamp`, `schema` and `headers` are passed through unchanged to
    /// [`ValidatedMessage::new`].
    ///
    /// # Errors
    ///
    /// Returns [`ProstValidatorError`] if `prost` reports an encoding failure.
    pub fn validate<M, S>(
        &self,
        id: Uuid,
        timestamp: SystemTime,
        schema: S,
        headers: Headers,
        data: &M,
    ) -> Result<ValidatedMessage, ProstValidatorError>
    where
        M: prost::Message,
        S: Into<std::borrow::Cow<'static, str>>,
    {
        // Reserve the exact encoded size up front so encoding writes into a single
        // allocation instead of repeatedly growing an empty buffer.
        let mut bytes = bytes::BytesMut::with_capacity(data.encoded_len());
        data.encode(&mut bytes).map_err(ProstValidatorError)?;
        Ok(ValidatedMessage::new(id, timestamp, schema, headers, bytes))
    }
}
/// Decoder that checks a [`ValidatedMessage`]'s schema with a [`SchemaMatcher`]
/// before decoding its payload as a [`prost::Message`].
///
/// `Debug` and `Clone` are derived and are available whenever the schema matcher
/// `S` implements them.
#[derive(Debug, Clone)]
#[cfg_attr(docsrs, doc(cfg(feature = "prost")))]
pub struct ProstDecoder<S> {
    /// Consulted for every message before its payload is decoded.
    schema_matcher: S,
}
impl<S> ProstDecoder<S> {
pub fn new(schema_matcher: S) -> Self {
Self { schema_matcher }
}
pub fn decode<M>(
&self,
msg: ValidatedMessage,
) -> Result<M, ProstDecodeError<S::InvalidSchemaError>>
where
S: SchemaMatcher<M>,
S::InvalidSchemaError: std::error::Error + 'static,
M: prost::Message + Default,
{
self.schema_matcher
.try_match_schema(msg.schema())
.map_err(ProstDecodeError::InvalidSchema)?;
Ok(M::decode(msg.into_data())?)
}
}
/// Decides whether a message's schema string is acceptable when decoding into
/// `MessageType`.
///
/// Any `Fn(&str) -> Result<(), E>` (closure or function) is a `SchemaMatcher` for
/// every message type, with `E` as its rejection error.
pub trait SchemaMatcher<MessageType> {
    /// Error returned when a schema is rejected.
    type InvalidSchemaError;

    /// Returns `Ok(())` when `schema` is accepted, or the matcher's error otherwise.
    fn try_match_schema(&self, schema: &str) -> Result<(), Self::InvalidSchemaError>;
}

impl<M, F, E> SchemaMatcher<M> for F
where
    F: Fn(&str) -> Result<(), E>,
{
    type InvalidSchemaError = E;

    fn try_match_schema(&self, schema: &str) -> Result<(), E> {
        // Simply delegate to the wrapped callable.
        self(schema)
    }
}
#[derive(Debug, Clone, Eq, PartialEq, thiserror::Error)]
#[error("deserialized schema {encountered} does not match expected schema {expected} for type {message_type}")]
pub struct SchemaMismatchError {
expected: &'static str,
encountered: String,
message_type: &'static str,
}
impl SchemaMismatchError {
pub fn new<MessageType>(expected: &'static str, encountered: String) -> Self {
SchemaMismatchError {
expected,
encountered,
message_type: std::any::type_name::<MessageType>(),
}
}
}
/// [`SchemaMatcher`] that accepts exactly one schema string for message type `T`.
#[cfg_attr(docsrs, doc(cfg(feature = "prost")))]
pub struct ExactSchemaMatcher<T> {
    /// The only schema this matcher accepts.
    expected_schema: &'static str,
    // `fn(T)` rather than `T`: the matcher never owns a `T`, and function pointers
    // are always `Send + Sync`, so the matcher is too, whatever `T` is.
    _message_type: std::marker::PhantomData<fn(T)>,
}

impl<T> ExactSchemaMatcher<T> {
    /// Creates a matcher that accepts only `expected_schema`.
    pub fn new(expected_schema: &'static str) -> Self {
        Self {
            expected_schema,
            _message_type: std::marker::PhantomData,
        }
    }
}

// Manual impls: `#[derive]` would add `T: Clone` / `T: Debug` bounds. Those bounds
// are unnecessary because no `T` value is ever stored.
impl<T> Clone for ExactSchemaMatcher<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for ExactSchemaMatcher<T> {}

impl<T> std::fmt::Debug for ExactSchemaMatcher<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ExactSchemaMatcher")
            .field("expected_schema", &self.expected_schema)
            .finish()
    }
}
impl<T> SchemaMatcher<T> for ExactSchemaMatcher<T> {
    type InvalidSchemaError = SchemaMismatchError;

    /// Accepts `schema` only if it equals the expected schema exactly.
    ///
    /// On a mismatch it returns a [`SchemaMismatchError`] carrying the expected
    /// schema, the encountered schema and the name of `T`.
    fn try_match_schema(&self, schema: &str) -> Result<(), Self::InvalidSchemaError> {
        if schema != self.expected_schema {
            let encountered = String::from(schema);
            return Err(SchemaMismatchError::new::<T>(
                self.expected_schema,
                encountered,
            ));
        }
        Ok(())
    }
}