use bytes::Bytes;
use redis_protocol::resp2::types::BytesFrame as Resp2Frame;
use redis_protocol::resp3::types::BytesFrame as Resp3Frame;
/// A pub/sub push message decoded from a RESP2 array or RESP3 push frame.
///
/// `Clone` is derived so decoded messages can be handed to several consumers;
/// the [`Bytes`] payloads are reference-counted, so cloning does not copy data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// Reply to `subscribe`; carries the channel count the server reported
    /// (the third element of the push).
    SubConfirmation(usize),
    /// Reply to `unsubscribe`; carries the channel count the server reported
    /// (the third element of the push).
    UnSubConfirmation(usize),
    /// A `message` push: the channel name followed by the published payload.
    Publish(Bytes, Bytes),
    /// A frame that is not one of the pub/sub pushes recognised above.
    Unknown,
}
/// Reasons a push frame could not be decoded into a [`Message`].
///
/// Implements [`std::error::Error`] so it composes with `?`,
/// `Box<dyn Error>` and error-reporting crates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecodeError {
    /// The frame had an unexpected shape: missing elements, or an element of
    /// the wrong type (e.g. a non-string where a string was required).
    ProtocolViolation,
    /// An integer in the frame does not fit into the target integer type.
    IntegerOverflow,
}

impl std::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            DecodeError::ProtocolViolation => "protocol violation in push message",
            DecodeError::IntegerOverflow => "integer overflow in push message",
        };
        f.write_str(description)
    }
}

impl std::error::Error for DecodeError {}
/// A RESP frame type that can be interpreted as a pub/sub push message.
///
/// Implementors supply three small accessors; the decoding itself is shared
/// through the provided [`ToPushMessage::decode_push`] method.
pub trait ToPushMessage {
    /// Consumes the frame and decodes it into a [`Message`].
    ///
    /// A frame for which [`ToPushMessage::as_array`] returns `None` decodes to
    /// [`Message::Unknown`].
    ///
    /// # Errors
    ///
    /// Returns a [`DecodeError`] when the frame's elements are malformed.
    fn decode_push(self) -> Result<Message, DecodeError>
    where
        Self: Sized,
    {
        let decoder = Decoder::new(self);
        decoder.decode()
    }

    /// Returns the frame's elements if this frame is the aggregate kind that
    /// carries push messages, otherwise `None`.
    fn as_array(&self) -> Option<&[Self]>
    where
        Self: Sized;

    /// Returns a copy of `frame`'s bytes if `frame` is a string frame.
    ///
    /// The receiver is not consulted by the implementations in this module;
    /// they return [`DecodeError::ProtocolViolation`] for non-string frames.
    fn clone_byte_string(&self, frame: &Self) -> Result<Bytes, DecodeError>;

    /// Returns `frame`'s value if `frame` is an integer frame.
    ///
    /// The receiver is not consulted by the implementations in this module;
    /// they return [`DecodeError::ProtocolViolation`] for non-integer frames.
    fn get_number(&self, frame: &Self) -> Result<i64, DecodeError>;
}
/// RESP3 pub/sub messages arrive as push frames.
impl ToPushMessage for Resp3Frame {
    /// Only RESP3 push frames are treated as pub/sub candidates; every other
    /// frame kind yields `None`.
    fn as_array(&self) -> Option<&[Self]> {
        match self {
            Self::Push { data, .. } => Some(&data[..]),
            _ => None,
        }
    }

    /// Accepts both blob strings and simple strings.
    fn clone_byte_string(&self, frame: &Self) -> Result<Bytes, DecodeError> {
        if let Self::BlobString { data, .. } | Self::SimpleString { data, .. } = frame {
            Ok(data.clone())
        } else {
            Err(DecodeError::ProtocolViolation)
        }
    }

    /// Accepts only RESP3 number frames.
    fn get_number(&self, frame: &Self) -> Result<i64, DecodeError> {
        match frame {
            Self::Number { data, .. } => Ok(*data),
            _ => Err(DecodeError::ProtocolViolation),
        }
    }
}
/// RESP2 has no push type, so pub/sub messages arrive as plain arrays.
impl ToPushMessage for Resp2Frame {
    /// Any RESP2 array is treated as a pub/sub candidate.
    fn as_array(&self) -> Option<&[Self]> {
        match self {
            Self::Array(items) => Some(&items[..]),
            _ => None,
        }
    }

    /// Accepts both simple strings and bulk strings.
    fn clone_byte_string(&self, frame: &Self) -> Result<Bytes, DecodeError> {
        match frame {
            Self::SimpleString(bytes) | Self::BulkString(bytes) => Ok(bytes.clone()),
            _ => Err(DecodeError::ProtocolViolation),
        }
    }

    /// Accepts only RESP2 integer frames.
    fn get_number(&self, frame: &Self) -> Result<i64, DecodeError> {
        if let Self::Integer(value) = frame {
            Ok(*value)
        } else {
            Err(DecodeError::ProtocolViolation)
        }
    }
}
/// Decodes a single frame into a pub/sub [`Message`].
///
/// The `ToPushMessage` bound lives on the `impl` block that needs it rather
/// than on the struct, so the type definition itself places no requirements
/// on `F`.
struct Decoder<F> {
    /// The frame being decoded.
    frame: F,
}
impl<F: ToPushMessage> Decoder<F> {
    /// Wraps `frame` so it can be decoded with [`Decoder::decode`].
    pub fn new(frame: F) -> Self {
        Self { frame }
    }

    /// Interprets the wrapped frame as a pub/sub push message.
    ///
    /// Frames that are not arrays, and arrays whose first element names a kind
    /// other than `message`, `subscribe` or `unsubscribe`, decode to
    /// [`Message::Unknown`]. Element counts are only enforced for the kinds that
    /// are actually decoded. Short pushes of other kinds are therefore reported
    /// as unknown rather than as errors, e.g. a RESP2 `["pong", ""]` reply in
    /// subscribed mode or a RESP3 `["invalidate", keys]` push.
    ///
    /// # Errors
    ///
    /// * [`DecodeError::ProtocolViolation`] if the array is empty, its first
    ///   element is not a string, or a recognised kind is missing elements or
    ///   has elements of the wrong type (including a negative channel count).
    /// * [`DecodeError::IntegerOverflow`] if a channel count does not fit in
    ///   `usize`.
    pub fn decode(self) -> Result<Message, DecodeError> {
        let data = match self.frame.as_array() {
            Some(data) => data,
            None => return Ok(Message::Unknown),
        };
        // Read the kind before any length check so that short pushes of kinds
        // this decoder does not handle are not mistaken for malformed ones.
        let kind = Self::element(data, 0)?;
        match &self.frame.clone_byte_string(kind)?[..] {
            b"message" => self.decode_message(data),
            b"subscribe" => self.decode_subscribe(data),
            b"unsubscribe" => self.decode_unsubscribe(data),
            _ => Ok(Message::Unknown),
        }
    }

    /// Decodes `["subscribe", channel, count]` into a [`Message::SubConfirmation`].
    fn decode_subscribe(&self, data: &[F]) -> Result<Message, DecodeError> {
        self.channel_count(data).map(Message::SubConfirmation)
    }

    /// Decodes `["unsubscribe", channel, count]` into a [`Message::UnSubConfirmation`].
    ///
    /// The channel element is not inspected, so any frame type is accepted there.
    fn decode_unsubscribe(&self, data: &[F]) -> Result<Message, DecodeError> {
        self.channel_count(data).map(Message::UnSubConfirmation)
    }

    /// Decodes `["message", channel, payload]` into a [`Message::Publish`].
    fn decode_message(&self, data: &[F]) -> Result<Message, DecodeError> {
        let channel = self.frame.clone_byte_string(Self::element(data, 1)?)?;
        let payload = self.frame.clone_byte_string(Self::element(data, 2)?)?;
        Ok(Message::Publish(channel, payload))
    }

    /// Reads the channel count held in the third element of an
    /// (un)subscribe confirmation.
    fn channel_count(&self, data: &[F]) -> Result<usize, DecodeError> {
        let count = self.frame.get_number(Self::element(data, 2)?)?;
        Self::cast_channel_count(count)
    }

    /// Returns `data[index]`, or a protocol violation if the push is too short.
    fn element(data: &[F], index: usize) -> Result<&F, DecodeError> {
        data.get(index).ok_or(DecodeError::ProtocolViolation)
    }

    /// Converts a wire-level channel count into `usize`.
    ///
    /// A negative count is a protocol violation; a count that does not fit in
    /// `usize` is an overflow.
    fn cast_channel_count(count: i64) -> Result<usize, DecodeError> {
        if count < 0 {
            return Err(DecodeError::ProtocolViolation);
        }
        usize::try_from(count).map_err(|_| DecodeError::IntegerOverflow)
    }
}