use crate::codec::CodecState;
use crate::frame::{Frame, MessageType};
#[cfg(feature = "valkey")]
use crate::frame::{ValkeyFrame, valkey::valkey_query_type};
#[cfg(feature = "cassandra")]
use crate::frame::{cassandra, cassandra::CassandraMetadata};
use anyhow::{Context, Result, anyhow};
use bytes::Bytes;
use derivative::Derivative;
use fnv::FnvBuildHasher;
use nonzero_ext::nonzero;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::num::NonZeroU32;
use std::time::Instant;
/// A `HashMap` keyed by [`MessageId`] that uses the FNV hasher instead of the std default
/// (SipHash) — presumably chosen because message ids are plain integers.
pub type MessageIdMap<T> = HashMap<MessageId, T, FnvBuildHasher>;
/// A `HashSet` of [`MessageId`]s that uses the FNV hasher instead of the std default (SipHash).
pub type MessageIdSet = HashSet<MessageId, FnvBuildHasher>;
/// Protocol-specific metadata of a [`Message`], as returned by `Message::metadata`.
///
/// Each variant only exists when its protocol feature is enabled. Only the Cassandra
/// variant carries data; the remaining variants act purely as protocol tags.
pub enum Metadata {
/// Metadata read from a Cassandra frame (either parsed or straight from raw bytes).
#[cfg(feature = "cassandra")]
Cassandra(CassandraMetadata),
#[cfg(feature = "valkey")]
Valkey,
#[cfg(feature = "kafka")]
Kafka,
#[cfg(feature = "opensearch")]
OpenSearch,
}
impl Metadata {
    /// Builds an error response, in this metadata's protocol, that carries `error`.
    ///
    /// * Valkey: an `ERR`-prefixed simple error. All line breaks (`\r\n`, lone `\n`, lone `\r`)
    ///   are replaced with spaces because a RESP simple error must not contain CR or LF.
    /// * Cassandra: delegates to `CassandraMetadata::to_error_response`.
    ///
    /// # Errors
    ///
    /// Returns an error for protocols that cannot (or do not yet) express a generic error
    /// response: Kafka and OpenSearch.
    pub fn to_error_response(&self, error: String) -> Result<Message> {
        // When only diverging arms are compiled in, the surrounding `Ok(...)` is unreachable.
        #[allow(unreachable_code)]
        Ok(Message::from_frame(match self {
            #[cfg(feature = "valkey")]
            Metadata::Valkey => {
                // Collapse CRLF first so it becomes a single space rather than two, then replace
                // any remaining lone CR or LF, either of which would break RESP framing.
                let message = format!("ERR {error}")
                    .replace("\r\n", " ")
                    .replace(['\r', '\n'], " ");
                Frame::Valkey(ValkeyFrame::Error(message.into()))
            }
            #[cfg(feature = "cassandra")]
            Metadata::Cassandra(meta) => Frame::Cassandra(meta.to_error_response(error)),
            #[cfg(feature = "kafka")]
            Metadata::Kafka => return Err(anyhow!(error).context(
                "A generic error cannot be formed because the kafka protocol does not support it",
            )),
            // Report through the Result instead of panicking: this function is called on error
            // paths, where a panic would take down the whole connection task.
            #[cfg(feature = "opensearch")]
            Metadata::OpenSearch => return Err(anyhow!(error).context(
                "A generic error cannot be formed because it is not yet implemented for the opensearch protocol",
            )),
        }))
    }
}
/// An ordered batch of [`Message`]s.
pub type Messages = Vec<Message>;
/// Identifier of a [`Message`]; the constructors below fill it with `rand::random()`.
pub type MessageId = u128;
/// A single protocol message that is parsed lazily: it can hold raw bytes, a parsed frame
/// together with its original bytes, or a frame that has been modified.
///
/// Equality only compares `inner` and `codec_state`; timing and id fields are ignored.
#[derive(Derivative, Debug, Clone)]
#[derivative(PartialEq)]
pub struct Message {
// Wrapped in `Option` so methods can `take()` it by value and put it back (see `frame`);
// it is expected to be `Some` at all other times.
inner: Option<MessageInner>,
// When the message was received from a source or sink, if known.
#[derivative(PartialEq = "ignore")]
pub(crate) received_from_source_or_sink_at: Option<Instant>,
// Codec state passed to `Frame::from_bytes` when raw bytes are parsed.
pub(crate) codec_state: CodecState,
// Identifier of this message.
#[derivative(PartialEq = "ignore")]
pub(crate) id: MessageId,
// For a response, the id of the request it answers (see `set_request_id`).
#[derivative(PartialEq = "ignore")]
pub(crate) request_id: Option<MessageId>,
}
impl Message {
pub fn from_bytes_at_instant(
bytes: Bytes,
codec_state: CodecState,
received_from_source_or_sink_at: Option<Instant>,
) -> Self {
Message {
inner: Some(MessageInner::RawBytes {
bytes,
message_type: MessageType::from(&codec_state),
}),
codec_state,
received_from_source_or_sink_at,
id: rand::random(),
request_id: None,
}
}
pub fn from_bytes_and_frame_at_instant(
bytes: Bytes,
frame: Frame,
received_from_source_or_sink_at: Option<Instant>,
) -> Self {
Message {
codec_state: frame.as_codec_state(),
inner: Some(MessageInner::Parsed {
bytes,
frame: Box::new(frame),
}),
received_from_source_or_sink_at,
id: rand::random(),
request_id: None,
}
}
pub fn from_frame_at_instant(
frame: Frame,
received_from_source_or_sink_at: Option<Instant>,
) -> Self {
Message {
codec_state: frame.as_codec_state(),
inner: Some(MessageInner::Modified {
frame: Box::new(frame),
}),
received_from_source_or_sink_at,
id: rand::random(),
request_id: None,
}
}
pub fn from_frame_diverged(frame: Frame, diverged_from: &Message) -> Self {
Message {
codec_state: frame.as_codec_state(),
inner: Some(MessageInner::Modified {
frame: Box::new(frame),
}),
received_from_source_or_sink_at: diverged_from.received_from_source_or_sink_at,
id: diverged_from.id(),
request_id: None,
}
}
pub fn from_bytes(bytes: Bytes, codec_state: CodecState) -> Self {
Self::from_bytes_at_instant(bytes, codec_state, None)
}
pub fn from_frame(frame: Frame) -> Self {
Self::from_frame_at_instant(frame, None)
}
}
impl Message {
    /// Returns a mutable reference to the frame, parsing the raw bytes first if needed.
    ///
    /// Returns `None`, after logging the error, if parsing fails; the message then keeps
    /// holding its raw bytes.
    pub fn frame(&mut self) -> Option<&mut Frame> {
        let (inner, result) = self.inner.take().unwrap().ensure_parsed(self.codec_state);
        // Restore `inner` before inspecting the result so a failed parse keeps the raw bytes.
        self.inner = Some(inner);
        if let Err(err) = result {
            tracing::error!("{:?}", err.context("Failed to parse frame"));
            return None;
        }
        match self.inner.as_mut().unwrap() {
            MessageInner::RawBytes { .. } => {
                unreachable!("Cannot be RawBytes because ensure_parsed was called")
            }
            MessageInner::Parsed { frame, .. } | MessageInner::Modified { frame } => Some(frame),
        }
    }

    /// Consumes the message and returns its frame, parsing the raw bytes first if needed.
    ///
    /// Returns `None`, after logging the error, if parsing fails.
    pub fn into_frame(mut self) -> Option<Box<Frame>> {
        let (inner, result) = self.inner.take().unwrap().ensure_parsed(self.codec_state);
        if let Err(err) = result {
            tracing::error!("{:?}", err.context("Failed to parse frame"));
            return None;
        }
        match inner {
            MessageInner::RawBytes { .. } => {
                unreachable!("Cannot be RawBytes because ensure_parsed was called")
            }
            MessageInner::Parsed { frame, .. } | MessageInner::Modified { frame } => Some(frame),
        }
    }

    /// Returns this message's id.
    pub fn id(&self) -> MessageId {
        self.id
    }

    /// Returns the id of the request this message responds to, if one has been set.
    pub fn request_id(&self) -> Option<MessageId> {
        self.request_id
    }

    /// Records `request_id` as the id of the request this message responds to.
    pub fn set_request_id(&mut self, request_id: MessageId) {
        self.request_id = Some(request_id);
    }

    /// Clones the message but gives the clone a freshly generated id.
    ///
    /// The received timestamp is dropped; the request id is kept.
    pub fn clone_with_new_id(&self) -> Self {
        Message {
            inner: self.inner.clone(),
            received_from_source_or_sink_at: None,
            codec_state: self.codec_state,
            id: rand::random(),
            request_id: self.request_id,
        }
    }

    /// Returns the protocol type of the message without parsing it.
    pub fn message_type(&self) -> MessageType {
        match self.inner.as_ref().unwrap() {
            MessageInner::RawBytes { message_type, .. } => *message_type,
            MessageInner::Parsed { frame, .. } | MessageInner::Modified { frame } => {
                frame.get_type()
            }
        }
    }

    /// Checks that the message is of `expected_message_type`.
    ///
    /// Dummy messages are accepted for any expected type.
    ///
    /// # Errors
    ///
    /// Returns an error naming both types if they differ.
    pub fn ensure_message_type(&self, expected_message_type: MessageType) -> Result<()> {
        match self.inner.as_ref().unwrap() {
            MessageInner::RawBytes { message_type, .. } => {
                if *message_type == expected_message_type || *message_type == MessageType::Dummy {
                    Ok(())
                } else {
                    Err(anyhow!(
                        "Expected message of type {:?} but was of type {:?}",
                        expected_message_type,
                        message_type
                    ))
                }
            }
            // Parsed and Modified are checked identically, so they share one arm.
            MessageInner::Parsed { frame, .. } | MessageInner::Modified { frame } => {
                let message_type = frame.get_type();
                if message_type == expected_message_type || message_type == MessageType::Dummy {
                    Ok(())
                } else {
                    Err(anyhow!(
                        "Expected message of type {:?} but was of type {:?}",
                        expected_message_type,
                        frame.name()
                    ))
                }
            }
        }
    }

    /// Converts the message into the form handed to the encoder.
    ///
    /// Returns the original bytes while the frame is unmodified, empty bytes for a dummy
    /// frame, and otherwise the frame itself, which must then be encoded.
    pub fn into_encodable(self) -> Encodable {
        match self.inner.unwrap() {
            MessageInner::RawBytes { bytes, .. } => Encodable::Bytes(bytes),
            MessageInner::Parsed { bytes, .. } => Encodable::Bytes(bytes),
            MessageInner::Modified { frame } => match frame.as_ref() {
                Frame::Dummy => Encodable::Bytes(Bytes::new()),
                _ => Encodable::Frame(frame),
            },
        }
    }

    /// Returns the number of cells in the message.
    ///
    /// Only Cassandra messages are actually counted; Valkey and dummy messages count as one
    /// cell. Kafka and OpenSearch are not implemented yet and panic via `todo!()`.
    ///
    /// # Errors
    ///
    /// Propagates failures from the Cassandra cell counting.
    pub fn cell_count(&self) -> Result<NonZeroU32> {
        Ok(match self.inner.as_ref().unwrap() {
            MessageInner::RawBytes {
                #[cfg(feature = "cassandra")]
                bytes,
                message_type,
                ..
            } => match message_type {
                #[cfg(feature = "valkey")]
                MessageType::Valkey => nonzero!(1u32),
                #[cfg(feature = "cassandra")]
                MessageType::Cassandra => cassandra::raw_frame::cell_count(bytes)?,
                #[cfg(feature = "kafka")]
                MessageType::Kafka => todo!(),
                MessageType::Dummy => nonzero!(1u32),
                #[cfg(feature = "opensearch")]
                MessageType::OpenSearch => todo!(),
            },
            MessageInner::Modified { frame } | MessageInner::Parsed { frame, .. } => {
                match frame.as_ref() {
                    #[cfg(feature = "cassandra")]
                    Frame::Cassandra(frame) => frame.cell_count()?,
                    #[cfg(feature = "valkey")]
                    Frame::Valkey(_) => nonzero!(1u32),
                    #[cfg(feature = "kafka")]
                    Frame::Kafka(_) => todo!(),
                    Frame::Dummy => nonzero!(1u32),
                    #[cfg(feature = "opensearch")]
                    Frame::OpenSearch(_) => todo!(),
                }
            }
        })
    }

    /// Marks a parsed frame as modified, so it is encoded from the frame instead of from the
    /// cached original bytes. Call this after mutating the frame returned by [`Message::frame`].
    pub fn invalidate_cache(&mut self) {
        self.inner = self.inner.take().map(|x| x.invalidate_cache());
    }

    /// Classifies the message's query, parsing it if needed.
    ///
    /// Falls back to [`QueryType::ReadWrite`] when the frame cannot be parsed. Kafka, dummy
    /// and OpenSearch frames are not implemented yet and panic via `todo!()`.
    pub fn get_query_type(&mut self) -> QueryType {
        match self.frame() {
            #[cfg(feature = "cassandra")]
            Some(Frame::Cassandra(cassandra)) => cassandra.get_query_type(),
            #[cfg(feature = "valkey")]
            Some(Frame::Valkey(valkey)) => valkey_query_type(valkey),
            #[cfg(feature = "kafka")]
            Some(Frame::Kafka(_)) => todo!(),
            Some(Frame::Dummy) => todo!(),
            #[cfg(feature = "opensearch")]
            Some(Frame::OpenSearch(_)) => todo!(),
            None => QueryType::ReadWrite,
        }
    }

    /// Builds an error response in place of this message, which is itself a response.
    ///
    /// The error response takes over this message's request id, if it has one.
    ///
    /// # Errors
    ///
    /// Fails if metadata cannot be read, or if the protocol cannot express a generic error.
    pub fn from_response_to_error_response(&self, error: String) -> Result<Message> {
        let mut response = self
            .metadata()
            .context("Failed to parse metadata of request or response when producing an error")?
            .to_error_response(error)?;
        if let Some(request_id) = self.request_id() {
            response.set_request_id(request_id)
        }
        Ok(response)
    }

    /// Builds an error response to this message, which is a request.
    ///
    /// The error response's request id is set to this message's id.
    ///
    /// # Errors
    ///
    /// Fails if metadata cannot be read, or if the protocol cannot express a generic error.
    pub fn from_request_to_error_response(&self, error: String) -> Result<Message> {
        let mut response = self
            .metadata()
            .context("Failed to parse metadata of request or response when producing an error")?
            .to_error_response(error)?;
        response.set_request_id(self.id());
        Ok(response)
    }

    /// Extracts protocol metadata from the message.
    ///
    /// For raw Cassandra bytes the metadata is read directly from the bytes, without building
    /// a full frame.
    ///
    /// # Errors
    ///
    /// Dummy and OpenSearch messages have no metadata. Raw Cassandra bytes may also fail to
    /// yield metadata.
    pub fn metadata(&self) -> Result<Metadata> {
        match self.inner.as_ref().unwrap() {
            MessageInner::RawBytes {
                #[cfg(feature = "cassandra")]
                bytes,
                message_type,
                ..
            } => match message_type {
                #[cfg(feature = "cassandra")]
                MessageType::Cassandra => {
                    Ok(Metadata::Cassandra(cassandra::raw_frame::metadata(bytes)?))
                }
                #[cfg(feature = "valkey")]
                MessageType::Valkey => Ok(Metadata::Valkey),
                #[cfg(feature = "kafka")]
                MessageType::Kafka => Ok(Metadata::Kafka),
                MessageType::Dummy => Err(anyhow!("Dummy has no metadata")),
                #[cfg(feature = "opensearch")]
                MessageType::OpenSearch => Err(anyhow!("OpenSearch has no metadata")),
            },
            MessageInner::Parsed { frame, .. } | MessageInner::Modified { frame } => {
                match frame.as_ref() {
                    #[cfg(feature = "cassandra")]
                    Frame::Cassandra(frame) => Ok(Metadata::Cassandra(frame.metadata())),
                    #[cfg(feature = "kafka")]
                    Frame::Kafka(_) => Ok(Metadata::Kafka),
                    #[cfg(feature = "valkey")]
                    Frame::Valkey(_) => Ok(Metadata::Valkey),
                    Frame::Dummy => Err(anyhow!("dummy has no metadata")),
                    #[cfg(feature = "opensearch")]
                    Frame::OpenSearch(_) => Err(anyhow!("OpenSearch has no metadata")),
                }
            }
        }
    }

    /// Replaces the message contents with [`Frame::Dummy`].
    ///
    /// The id, request id, codec state and received timestamp are kept.
    pub fn replace_with_dummy(&mut self) {
        self.inner = Some(MessageInner::Modified {
            frame: Box::new(Frame::Dummy),
        });
    }

    /// Returns true if no real response is expected for this message.
    ///
    /// This covers dummy messages and Kafka produce requests with `acks == 0`; Kafka brokers
    /// do not reply to those. Kafka messages may be parsed as a side effect.
    pub(crate) fn response_is_dummy(&mut self) -> bool {
        match self.message_type() {
            #[cfg(feature = "valkey")]
            MessageType::Valkey => false,
            #[cfg(feature = "cassandra")]
            MessageType::Cassandra => false,
            #[cfg(feature = "kafka")]
            MessageType::Kafka => match self.frame() {
                Some(Frame::Kafka(crate::frame::kafka::KafkaFrame::Request {
                    body: crate::frame::kafka::RequestBody::Produce(produce),
                    ..
                })) => produce.acks == 0,
                _ => false,
            },
            #[cfg(feature = "opensearch")]
            MessageType::OpenSearch => false,
            MessageType::Dummy => true,
        }
    }

    /// Returns true if the message holds a modified [`Frame::Dummy`], for example after
    /// [`Message::replace_with_dummy`]. Never parses the message.
    pub fn is_dummy(&self) -> bool {
        matches!(
            &self.inner,
            Some(MessageInner::Modified { frame }) if matches!(frame.as_ref(), Frame::Dummy)
        )
    }

    /// Builds a backpressure response from this message's metadata, keeping its received
    /// timestamp.
    ///
    /// Only Cassandra is supported; the other protocols panic via `unimplemented!()`.
    ///
    /// # Errors
    ///
    /// Fails if the message's metadata cannot be read.
    pub fn to_backpressure(&mut self) -> Result<Message> {
        let metadata = self.metadata()?;
        Ok(Message::from_frame_at_instant(
            match metadata {
                #[cfg(feature = "cassandra")]
                Metadata::Cassandra(metadata) => Frame::Cassandra(metadata.backpressure_response()),
                #[cfg(feature = "valkey")]
                Metadata::Valkey => unimplemented!(),
                #[cfg(feature = "kafka")]
                Metadata::Kafka => unimplemented!(),
                #[cfg(feature = "opensearch")]
                Metadata::OpenSearch => unimplemented!(),
            },
            // Only reachable when the cassandra arm is compiled in; every other arm diverges.
            #[allow(unreachable_code)]
            self.received_from_source_or_sink_at,
        ))
    }

    /// Returns the Cassandra stream id, or `None` for non-Cassandra messages.
    ///
    /// For raw bytes the id is read straight from the frame header without parsing. In the
    /// 9-byte header of CQL native protocol v3+, the stream id is the big-endian i16 at bytes
    /// 2..4. Returns `None` if the bytes are shorter than the header.
    pub(crate) fn stream_id(&self) -> Option<i16> {
        match &self.inner {
            #[cfg(feature = "cassandra")]
            Some(MessageInner::RawBytes {
                bytes,
                message_type: MessageType::Cassandra,
            }) => {
                use bytes::Buf;
                const HEADER_LEN: usize = 9;
                if bytes.len() >= HEADER_LEN {
                    Some((&bytes[2..4]).get_i16())
                } else {
                    None
                }
            }
            Some(MessageInner::RawBytes { .. }) => None,
            Some(MessageInner::Parsed { frame, .. } | MessageInner::Modified { frame }) => {
                match frame.as_ref() {
                    #[cfg(feature = "cassandra")]
                    Frame::Cassandra(cassandra) => Some(cassandra.stream_id),
                    #[cfg(feature = "valkey")]
                    Frame::Valkey(_) => None,
                    #[cfg(feature = "kafka")]
                    Frame::Kafka(_) => None,
                    Frame::Dummy => None,
                    #[cfg(feature = "opensearch")]
                    Frame::OpenSearch(_) => None,
                }
            }
            None => None,
        }
    }

    /// Renders the message for humans.
    ///
    /// Uses the frame's `Display` output when the message can be parsed, and falls back to
    /// the raw bytes otherwise.
    pub fn to_high_level_string(&mut self) -> String {
        if let Some(response) = self.frame() {
            response.to_string()
        } else if let Some(MessageInner::RawBytes {
            bytes,
            message_type,
        }) = &self.inner
        {
            format!("Unparseable {message_type:?} message {bytes:?}")
        } else {
            unreachable!("self.frame() failed so MessageInner must still be RawBytes")
        }
    }
}
/// Internal storage for a [`Message`], recording how far its contents have been parsed.
#[derive(PartialEq, Debug, Clone)]
enum MessageInner {
/// Unparsed bytes, tagged with the protocol they will be parsed as.
RawBytes {
bytes: Bytes,
message_type: MessageType,
},
/// A parsed frame together with the original bytes; `into_encodable` reuses these bytes.
Parsed {
bytes: Bytes,
frame: Box<Frame>,
},
/// A frame without trustworthy backing bytes (built from a frame, or moved here by
/// `invalidate_cache`); `into_encodable` hands the frame itself to the encoder.
Modified {
frame: Box<Frame>,
},
}
impl MessageInner {
    /// Parses `RawBytes` into `Parsed` using `codec_state`, and passes other variants through.
    ///
    /// A usable `MessageInner` always comes back alongside the result. When parsing fails it
    /// is the untouched `RawBytes`, so the caller can put it back without losing the message.
    fn ensure_parsed(self, codec_state: CodecState) -> (Self, Result<()>) {
        match self {
            already_parsed @ (MessageInner::Parsed { .. } | MessageInner::Modified { .. }) => {
                (already_parsed, Ok(()))
            }
            MessageInner::RawBytes {
                bytes,
                message_type,
            } => {
                let outcome = Frame::from_bytes(bytes.clone(), message_type, codec_state);
                match outcome {
                    Ok(frame) => (MessageInner::Parsed { bytes, frame }, Ok(())),
                    Err(err) => {
                        let untouched = MessageInner::RawBytes {
                            bytes,
                            message_type,
                        };
                        (untouched, Err(err))
                    }
                }
            }
        }
    }

    /// Drops the cached original bytes of a parsed frame, turning it into `Modified`.
    ///
    /// Raw bytes have nothing to invalidate, so the call is logged as an error and they are
    /// returned unchanged. `Modified` is already uncached.
    fn invalidate_cache(self) -> Self {
        match self {
            MessageInner::Parsed { frame, .. } => MessageInner::Modified { frame },
            MessageInner::Modified { .. } => self,
            MessageInner::RawBytes { .. } => {
                tracing::error!("Invalidated cache but the frame was not parsed");
                self
            }
        }
    }
}
/// What `Message::into_encodable` hands to an encoder.
#[derive(Debug)]
pub enum Encodable {
/// Bytes that are already encoded: the original bytes of an unmodified message, or empty
/// bytes for a dummy frame. Presumably the encoder writes these unchanged.
Bytes(Bytes),
/// A modified frame that still needs to be encoded.
Frame(Box<Frame>),
}
/// Classification of a message's query, as returned by `Message::get_query_type`.
///
/// Derives serde `Serialize`/`Deserialize`, so variant names are part of the serialized form.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub enum QueryType {
Read,
Write,
// Also the fallback `Message::get_query_type` returns when the frame cannot be parsed.
ReadWrite,
SchemaChange,
PubSubMessage,
}