use std::convert::TryFrom;
use std::io;
use std::io::Cursor;
use std::net::SocketAddr;
use tokio::io::AsyncReadExt;
use cassandra_protocol::compression::Compression;
use cassandra_protocol::error;
use cassandra_protocol::frame::message_response::ResponseBody;
use cassandra_protocol::frame::{
Direction, Envelope, Flags, Opcode, Version, LENGTH_LEN, STREAM_LEN,
};
use cassandra_protocol::types::data_serialization_types::decode_timeuuid;
use cassandra_protocol::types::{
from_cursor_string_list, try_i16_from_bytes, try_i32_from_bytes, UUID_LEN,
};
/// Largest envelope body length, in bytes, that the parser accepts (256 MiB).
/// An advertised length above this is rejected before the body buffer is allocated.
const MAX_ENVELOPE_BODY_SIZE: usize = 256 << 20;
async fn parse_raw_envelope<T: AsyncReadExt + Unpin>(
cursor: &mut T,
compressor: Compression,
) -> error::Result<Envelope> {
let mut version_bytes = [0; Version::BYTE_LENGTH];
let mut flag_bytes = [0; Flags::BYTE_LENGTH];
let mut opcode_bytes = [0; Opcode::BYTE_LENGTH];
let mut stream_bytes = [0; STREAM_LEN];
let mut length_bytes = [0; LENGTH_LEN];
cursor.read_exact(&mut version_bytes).await?;
cursor.read_exact(&mut flag_bytes).await?;
cursor.read_exact(&mut stream_bytes).await?;
cursor.read_exact(&mut opcode_bytes).await?;
cursor.read_exact(&mut length_bytes).await?;
let version = Version::try_from(version_bytes[0])?;
let direction = Direction::from(version_bytes[0]);
let flags = Flags::from_bits_truncate(flag_bytes[0]);
let stream_id = try_i16_from_bytes(&stream_bytes)?;
let opcode = Opcode::try_from(opcode_bytes[0])?;
let length_signed = try_i32_from_bytes(&length_bytes)?;
if length_signed < 0 {
return Err(error::Error::Io(io::Error::new(
io::ErrorKind::InvalidData,
format!("negative envelope body length {length_signed}"),
)));
}
let length = length_signed as usize;
if length > MAX_ENVELOPE_BODY_SIZE {
return Err(error::Error::Io(io::Error::new(
io::ErrorKind::InvalidData,
format!("envelope body length {length} exceeds maximum {MAX_ENVELOPE_BODY_SIZE}"),
)));
}
let mut body_bytes = vec![0; length];
cursor.read_exact(&mut body_bytes).await?;
let full_body = if flags.contains(Flags::COMPRESSION) {
compressor.decode(body_bytes)?
} else {
Compression::None.decode(body_bytes)?
};
let body_len = full_body.len();
let mut body_cursor = Cursor::new(full_body.as_slice());
let tracing_id = if flags.contains(Flags::TRACING) && direction == Direction::Response {
let mut tracing_bytes = [0; UUID_LEN];
std::io::Read::read_exact(&mut body_cursor, &mut tracing_bytes)?;
decode_timeuuid(&tracing_bytes).ok()
} else {
None
};
let warnings = if flags.contains(Flags::WARNING) {
from_cursor_string_list(&mut body_cursor)?
} else {
vec![]
};
let mut body = Vec::with_capacity(body_len - body_cursor.position() as usize);
std::io::Read::read_to_end(&mut body_cursor, &mut body)?;
let envelope = Envelope {
version,
direction,
flags,
opcode,
stream_id,
body,
tracing_id,
warnings,
};
Ok(envelope)
}
/// Reads one envelope from `cursor` and turns a server `ERROR` envelope into `Err`.
///
/// This parses the raw envelope (decompressing with `compressor` when the envelope
/// is flagged as compressed). The result is then passed through
/// `convert_envelope_into_result`. `addr` is used only to annotate server errors.
///
/// # Errors
///
/// Propagates any parsing error. Also returns an error when the envelope is a
/// server-side error response.
pub async fn parse_envelope<T: AsyncReadExt + Unpin>(
    cursor: &mut T,
    compressor: Compression,
    addr: SocketAddr,
) -> error::Result<Envelope> {
    parse_raw_envelope(cursor, compressor)
        .await
        .and_then(|raw| convert_envelope_into_result(raw, addr))
}
pub(crate) fn convert_envelope_into_result(
envelope: Envelope,
addr: SocketAddr,
) -> error::Result<Envelope> {
match envelope.opcode {
Opcode::Error => envelope.response_body().and_then(|err| match err {
ResponseBody::Error(err) => Err(error::Error::Server { body: err, addr }),
other => Err(error::Error::General(format!(
"Expected ResponseBody::Error for Opcode::Error envelope, got {other:?}"
))),
}),
_ => Ok(envelope),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request header (V4, no flags, stream 0, `READY` opcode) followed by the
    /// given raw big-endian body-length field. No body bytes are appended.
    fn header_with_length(length_bytes: [u8; 4]) -> Vec<u8> {
        let mut buf = vec![
            u8::from(Version::V4), // version byte (request direction)
            0,                     // flags
            0,                     // stream id, 2 bytes
            0,
            u8::from(Opcode::Ready),
        ];
        buf.extend_from_slice(&length_bytes);
        buf
    }

    #[tokio::test]
    async fn parse_envelope_rejects_negative_body_length() {
        // 0xffffffff is -1 when read as a big-endian i32.
        let payload = header_with_length([0xff, 0xff, 0xff, 0xff]);
        let mut reader = payload.as_slice();
        assert!(parse_raw_envelope(&mut reader, Compression::None)
            .await
            .is_err());
    }

    #[tokio::test]
    async fn parse_envelope_rejects_oversized_body_length() {
        let payload = header_with_length(i32::MAX.to_be_bytes());
        let mut reader = payload.as_slice();
        assert!(parse_raw_envelope(&mut reader, Compression::None)
            .await
            .is_err());
    }

    #[tokio::test]
    async fn parse_envelope_accepts_empty_body() {
        let payload = header_with_length(0i32.to_be_bytes());
        let mut reader = payload.as_slice();
        let envelope = parse_raw_envelope(&mut reader, Compression::None)
            .await
            .expect("a zero-length body is valid");
        assert!(envelope.body.is_empty());
        assert!(envelope.warnings.is_empty());
        assert!(envelope.tracing_id.is_none());
    }

    #[tokio::test]
    async fn parse_envelope_does_not_read_tracing_id_for_request_direction() {
        let body: Vec<u8> = (0..16u8).collect();
        let mut wire = vec![
            u8::from(Version::V4),
            Flags::TRACING.bits(),
            0,
            0,
            u8::from(Opcode::Query),
        ];
        wire.extend_from_slice(&(body.len() as i32).to_be_bytes());
        wire.extend_from_slice(&body);
        let mut reader = wire.as_slice();
        let envelope = parse_raw_envelope(&mut reader, Compression::None)
            .await
            .expect("a request envelope with TRACING flag must still parse");
        assert_eq!(envelope.direction, Direction::Request);
        assert!(
            envelope.tracing_id.is_none(),
            "request envelopes must not carry a tracing UUID"
        );
        assert_eq!(
            envelope.body, body,
            "request body should be preserved verbatim, got {:?}",
            envelope.body
        );
    }

    #[tokio::test]
    async fn parse_envelope_strips_tracing_id_for_response_direction() {
        let payload = vec![0xAB, 0xCD];
        let mut body = vec![0u8; UUID_LEN];
        body.extend_from_slice(&payload);
        let mut wire = vec![
            // Setting the high bit of the version byte marks a response.
            u8::from(Version::V4) | 0x80,
            Flags::TRACING.bits(),
            0,
            0,
            u8::from(Opcode::Ready),
        ];
        wire.extend_from_slice(&(body.len() as i32).to_be_bytes());
        wire.extend_from_slice(&body);
        let mut reader = wire.as_slice();
        let envelope = parse_raw_envelope(&mut reader, Compression::None)
            .await
            .expect("a response envelope with a tracing prefix must parse");
        assert_eq!(envelope.direction, Direction::Response);
        assert_eq!(
            envelope.body, payload,
            "the tracing UUID prefix must be removed from the response body"
        );
    }
}