#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::needless_continue)]
use std::net::SocketAddr;
use crate::core::types::MsgId;
use crate::io::mbuf::{Mbuf, MbufQueue};
use crate::msg::message::Msg;
use crate::msg::message::MsgParseResult;
/// Magic string that opens every dnode header line.
pub const MAGIC: &[u8] = b"$2014$";
/// Protocol version written into headers by `build_header`.
pub const VERSION_10: u8 = 1;
/// Line terminator that ends a dnode header.
pub const CRLF: &[u8] = b"\r\n";
/// One-byte data placeholder used by `dmsg_write` when no AES key payload is supplied.
pub const HANDSHAKE_PLACEHOLDER_DATA: u8 = b'd';
/// One-byte data placeholder used by `dmsg_write_mbuf` when no AES key payload is supplied.
pub const GOSSIP_PLACEHOLDER_DATA: u8 = b'a';
/// Position of the dnode header state machine (`DnodeParser`), also stored on a `Msg`
/// between parse calls.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub enum DynParseState {
    /// Skipping leading spaces and matching `MAGIC`.
    #[default]
    Start,
    /// Magic matched; expecting the space before the message id.
    MagicString,
    /// Reading the decimal message id.
    MsgId,
    /// Reading the decimal message type id.
    TypeId,
    /// Reading the decimal flags field.
    BitField,
    /// Reading the decimal version field.
    Version,
    /// Reading the same-DC digit.
    SameDc,
    /// Shares the `DataLen` handling in `DnodeParser::step`; not entered by the parser in this file.
    Star,
    /// Reading `*<mlen>`, the length of the inline data.
    DataLen,
    /// Copying `mlen` bytes of inline data.
    Data,
    /// Skipping spaces until the `*` that introduces the payload length.
    SpacesBeforePayloadLen,
    /// Reading the decimal payload length up to `\r`.
    PayloadLen,
    /// Expecting the `\n` that ends the header.
    CrlfBeforeDone,
    /// Header fully parsed.
    Done,
    /// Not assigned by the code in this file; handled alongside `Done` by the parser.
    PostDone,
    /// Set on a `Msg` by `parse_msg` after a header parse error.
    Unknown,
}
/// Kind of dnode message. The explicit `u8` discriminant is exactly the decimal type id
/// carried in the header, so [`DmsgType::as_u8`] and [`DmsgType::from_u8`] are inverses.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)]
#[repr(u8)]
pub enum DmsgType {
    #[default]
    Unknown = 0,
    Debug = 1,
    ParseError = 2,
    Req = 3,
    ReqForward = 4,
    Res = 5,
    CryptoHandshake = 6,
    GossipSyn = 7,
    GossipSynReply = 8,
    GossipAck = 9,
    GossipDigestSyn = 10,
    GossipDigestAck = 11,
    GossipDigestAck2 = 12,
    GossipShutdown = 13,
    HandoffChunk = 14,
    FtSearchReq = 15,
    FtSearchRep = 16,
}

impl DmsgType {
    /// Every variant, stored at the index equal to its discriminant
    /// (`BY_ID[i].as_u8() == i`), so decoding is a bounds-checked lookup.
    const BY_ID: [DmsgType; 17] = [
        DmsgType::Unknown,
        DmsgType::Debug,
        DmsgType::ParseError,
        DmsgType::Req,
        DmsgType::ReqForward,
        DmsgType::Res,
        DmsgType::CryptoHandshake,
        DmsgType::GossipSyn,
        DmsgType::GossipSynReply,
        DmsgType::GossipAck,
        DmsgType::GossipDigestSyn,
        DmsgType::GossipDigestAck,
        DmsgType::GossipDigestAck2,
        DmsgType::GossipShutdown,
        DmsgType::HandoffChunk,
        DmsgType::FtSearchReq,
        DmsgType::FtSearchRep,
    ];

    /// Decodes a wire type id, returning `None` for ids with no variant (`17..=255`).
    #[must_use]
    pub fn from_u8(v: u8) -> Option<Self> {
        Self::BY_ID.get(usize::from(v)).copied()
    }

    /// The wire type id (the enum discriminant).
    #[must_use]
    pub const fn as_u8(self) -> u8 {
        self as u8
    }
}
/// `Dmsg::flags` bit tested by `Dmsg::is_encrypted`.
pub const DMSG_FLAG_ENCRYPTED: u8 = 0x1;
/// `Dmsg::flags` bit tested by `Dmsg::is_compressed`.
pub const DMSG_FLAG_COMPRESSED: u8 = 0x2;
/// Fields of one dnode header, as decoded by `DnodeParser`.
///
/// NOTE: the derived `Default` yields `version: 0` and `same_dc: false`, whereas
/// [`Dmsg::new`] starts from `VERSION_10` and `same_dc: true`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Dmsg {
    /// Message id field.
    pub id: MsgId,
    /// Message type field.
    pub ty: DmsgType,
    /// Flag bits (`DMSG_FLAG_*`); the parser keeps only the low four bits.
    pub flags: u8,
    /// Version field.
    pub version: u8,
    /// Same-DC field (`0` parses as `false`, any other digit as `true`).
    pub same_dc: bool,
    /// Peer address; not populated by the code in this file.
    pub source_address: Option<SocketAddr>,
    /// Declared length of `data`.
    pub mlen: u32,
    /// Inline header data (an AES key payload or a one-byte placeholder when written here).
    pub data: Vec<u8>,
    /// Declared length of the payload following the header line.
    pub plen: u32,
    /// Payload bytes; not populated by the code in this file.
    pub payload: Vec<u8>,
}

impl Dmsg {
    /// Empty header: `version = VERSION_10`, `same_dc = true`, every other field
    /// zero / `Unknown` / `None` / empty.
    #[must_use]
    pub fn new() -> Self {
        Self {
            version: VERSION_10,
            same_dc: true,
            ..Self::default()
        }
    }

    /// Whether the `DMSG_FLAG_ENCRYPTED` bit is set in `flags`.
    #[must_use]
    pub fn is_encrypted(&self) -> bool {
        self.has_flag(DMSG_FLAG_ENCRYPTED)
    }

    /// Whether the `DMSG_FLAG_COMPRESSED` bit is set in `flags`.
    #[must_use]
    pub fn is_compressed(&self) -> bool {
        self.has_flag(DMSG_FLAG_COMPRESSED)
    }

    /// Whether any bit of `mask` is set in `flags`.
    fn has_flag(&self, mask: u8) -> bool {
        self.flags & mask != 0
    }
}
/// Outcome of one `DnodeParser::step` call.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ParseStep {
    /// Input ran out before the header ended; `consumed` is the number of input bytes used
    /// (all of them). Call `step` again with further bytes.
    NeedMore {
        consumed: usize,
    },
    /// The header is complete; `consumed` is the number of input bytes used, so anything
    /// after the header starts at `input[consumed..]`.
    HeaderDone {
        consumed: usize,
    },
    /// The input does not match the header grammar; `consumed` is the offset of the
    /// offending byte.
    Error {
        consumed: usize,
    },
}
/// Errors from the dnode header writers.
///
/// Only `OutOfSpace` is produced by the code in this file; the remaining variants are not
/// constructed here.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum DnodeError {
    /// The destination mbuf cannot hold the whole encoded header.
    OutOfSpace,
    BadMagic,
    BadNumber,
    MissingCrlf,
    BadType,
    TruncatedData,
}
#[derive(Debug)]
pub struct DnodeParser {
state: DynParseState,
num: u64,
dmsg: Dmsg,
data_remaining: u32,
magic_progress: u8,
prev_was_digit: bool,
}
impl DnodeParser {
#[must_use]
pub fn new() -> Self {
Self {
state: DynParseState::Start,
num: 0,
dmsg: Dmsg::new(),
data_remaining: 0,
magic_progress: 0,
prev_was_digit: false,
}
}
pub fn reset(&mut self) {
*self = Self::new();
}
#[must_use]
pub fn state(&self) -> DynParseState {
self.state
}
#[must_use]
pub fn dmsg(&self) -> &Dmsg {
&self.dmsg
}
pub fn take_dmsg(&mut self) -> Dmsg {
let mut out = Dmsg::new();
std::mem::swap(&mut out, &mut self.dmsg);
self.state = DynParseState::Start;
self.num = 0;
self.data_remaining = 0;
self.magic_progress = 0;
self.prev_was_digit = false;
out
}
#[allow(clippy::too_many_lines)]
pub fn step(&mut self, input: &[u8]) -> ParseStep {
let mut idx = 0usize;
while idx < input.len() {
let ch = input[idx];
match self.state {
DynParseState::Start => {
if self.magic_progress == 0 {
if ch == b' ' {
idx += 1;
continue;
}
if ch != b'$' {
return ParseStep::Error { consumed: idx };
}
}
let want = MAGIC[usize::from(self.magic_progress)];
if ch != want {
return ParseStep::Error { consumed: idx };
}
self.magic_progress += 1;
idx += 1;
if usize::from(self.magic_progress) == MAGIC.len() {
self.state = DynParseState::MagicString;
self.magic_progress = 0;
}
continue;
}
DynParseState::MagicString => {
if ch == b' ' {
self.state = DynParseState::MsgId;
self.num = 0;
idx += 1;
continue;
}
return ParseStep::Error { consumed: idx };
}
DynParseState::MsgId => {
if ch.is_ascii_digit() {
self.num = self.num.wrapping_mul(10) + u64::from(ch - b'0');
self.prev_was_digit = true;
idx += 1;
continue;
}
if ch == b' ' && self.prev_was_digit {
self.dmsg.id = self.num;
self.state = DynParseState::TypeId;
self.num = 0;
self.prev_was_digit = false;
idx += 1;
continue;
}
return ParseStep::Error { consumed: idx };
}
DynParseState::TypeId => {
if ch.is_ascii_digit() {
self.num = self.num.wrapping_mul(10) + u64::from(ch - b'0');
self.prev_was_digit = true;
idx += 1;
continue;
}
if ch == b' ' && self.prev_was_digit {
self.dmsg.ty = match DmsgType::from_u8(self.num as u8) {
Some(t) => t,
None => return ParseStep::Error { consumed: idx },
};
self.state = DynParseState::BitField;
self.num = 0;
self.prev_was_digit = false;
idx += 1;
continue;
}
return ParseStep::Error { consumed: idx };
}
DynParseState::BitField => {
if ch.is_ascii_digit() {
self.num = self.num.wrapping_mul(10) + u64::from(ch - b'0');
self.prev_was_digit = true;
idx += 1;
continue;
}
if ch == b' ' && self.prev_was_digit {
self.dmsg.flags = (self.num as u8) & 0xF;
self.state = DynParseState::Version;
self.num = 0;
self.prev_was_digit = false;
idx += 1;
continue;
}
return ParseStep::Error { consumed: idx };
}
DynParseState::Version => {
if ch.is_ascii_digit() {
self.num = self.num.wrapping_mul(10) + u64::from(ch - b'0');
self.prev_was_digit = true;
idx += 1;
continue;
}
if ch == b' ' && self.prev_was_digit {
self.dmsg.version = self.num as u8;
self.state = DynParseState::SameDc;
self.num = 0;
self.prev_was_digit = false;
idx += 1;
continue;
}
return ParseStep::Error { consumed: idx };
}
DynParseState::SameDc => {
if ch.is_ascii_digit() {
self.dmsg.same_dc = ch != b'0';
self.prev_was_digit = true;
idx += 1;
continue;
}
if ch == b' ' && self.prev_was_digit {
self.state = DynParseState::DataLen;
self.num = 0;
self.prev_was_digit = false;
idx += 1;
continue;
}
return ParseStep::Error { consumed: idx };
}
DynParseState::Star | DynParseState::DataLen => {
if ch == b'*' {
idx += 1;
continue;
}
if ch.is_ascii_digit() {
self.num = self.num.wrapping_mul(10) + u64::from(ch - b'0');
idx += 1;
continue;
}
if ch == b' ' && self.state == DynParseState::DataLen {
self.dmsg.mlen = self.num as u32;
self.data_remaining = self.dmsg.mlen;
self.dmsg.data.clear();
self.dmsg.data.reserve(self.data_remaining as usize);
self.state = DynParseState::Data;
self.num = 0;
idx += 1;
continue;
}
return ParseStep::Error { consumed: idx };
}
DynParseState::Data => {
if self.data_remaining == 0 {
self.state = DynParseState::SpacesBeforePayloadLen;
continue;
}
let take = std::cmp::min(self.data_remaining as usize, input.len() - idx);
self.dmsg.data.extend_from_slice(&input[idx..idx + take]);
self.data_remaining -= take as u32;
idx += take;
if self.data_remaining == 0 {
self.state = DynParseState::SpacesBeforePayloadLen;
}
continue;
}
DynParseState::SpacesBeforePayloadLen => {
if ch == b' ' {
idx += 1;
continue;
}
if ch == b'*' {
self.state = DynParseState::PayloadLen;
self.num = 0;
idx += 1;
continue;
}
return ParseStep::Error { consumed: idx };
}
DynParseState::PayloadLen => {
if ch.is_ascii_digit() {
self.num = self.num.wrapping_mul(10) + u64::from(ch - b'0');
idx += 1;
continue;
}
if ch == b'\r' {
self.dmsg.plen = self.num as u32;
self.state = DynParseState::CrlfBeforeDone;
self.num = 0;
idx += 1;
continue;
}
return ParseStep::Error { consumed: idx };
}
DynParseState::CrlfBeforeDone => {
if ch == b'\n' {
self.state = DynParseState::Done;
idx += 1;
return ParseStep::HeaderDone { consumed: idx };
}
return ParseStep::Error { consumed: idx };
}
DynParseState::Done | DynParseState::PostDone | DynParseState::Unknown => {
return ParseStep::HeaderDone { consumed: idx };
}
}
}
ParseStep::NeedMore { consumed: idx }
}
}
impl Default for DnodeParser {
fn default() -> Self {
Self::new()
}
}
/// Appends a dnode header to `mbuf`. When `aes_key_payload` is `None`, the inline data is
/// the single byte `HANDSHAKE_PLACEHOLDER_DATA`.
///
/// # Errors
///
/// Returns `DnodeError::OutOfSpace` if `mbuf` cannot hold the whole header; nothing is
/// written in that case.
pub fn dmsg_write(
    mbuf: &mut Mbuf,
    msg_id: MsgId,
    ty: DmsgType,
    flags: u8,
    same_dc: bool,
    aes_key_payload: Option<&[u8]>,
    plen: u32,
) -> Result<(), DnodeError> {
    let gossip_placeholder = false;
    write_chain(
        mbuf,
        &build_header(msg_id, ty, flags, same_dc, aes_key_payload, plen, gossip_placeholder),
    )
}
/// Appends a dnode header to `mbuf`. When `aes_key_payload` is `None`, the inline data is
/// the single byte `GOSSIP_PLACEHOLDER_DATA`.
///
/// # Errors
///
/// Returns `DnodeError::OutOfSpace` if `mbuf` cannot hold the whole header; nothing is
/// written in that case.
pub fn dmsg_write_mbuf(
    mbuf: &mut Mbuf,
    msg_id: MsgId,
    ty: DmsgType,
    flags: u8,
    same_dc: bool,
    aes_key_payload: Option<&[u8]>,
    plen: u32,
) -> Result<(), DnodeError> {
    let gossip_placeholder = true;
    write_chain(
        mbuf,
        &build_header(msg_id, ty, flags, same_dc, aes_key_payload, plen, gossip_placeholder),
    )
}
fn build_header(
msg_id: MsgId,
ty: DmsgType,
flags: u8,
same_dc: bool,
aes_key_payload: Option<&[u8]>,
plen: u32,
gossip_placeholder: bool,
) -> Vec<u8> {
use std::io::Write as _;
let mut buf: Vec<u8> = Vec::with_capacity(64);
buf.extend_from_slice(b" $2014$ ");
let _ = write!(buf, "{msg_id}");
buf.push(b' ');
let _ = write!(buf, "{}", ty.as_u8());
buf.push(b' ');
let _ = write!(buf, "{}", flags & 0xF);
buf.push(b' ');
let _ = write!(buf, "{VERSION_10}");
buf.push(b' ');
buf.push(if same_dc { b'1' } else { b'0' });
buf.push(b' ');
buf.push(b'*');
if let Some(payload) = aes_key_payload {
let _ = write!(buf, "{}", payload.len());
buf.push(b' ');
buf.extend_from_slice(payload);
} else {
buf.extend_from_slice(b"1 ");
buf.push(if gossip_placeholder {
GOSSIP_PLACEHOLDER_DATA
} else {
HANDSHAKE_PLACEHOLDER_DATA
});
}
buf.push(b' ');
buf.push(b'*');
let _ = write!(buf, "{plen}");
buf.extend_from_slice(CRLF);
buf
}
/// Copies `payload` into `mbuf` in one piece, or fails without writing anything.
///
/// NOTE(review): assumes `Mbuf::remaining()` is the free space available for writing and
/// that `Mbuf::recv` appends the given bytes — confirm against the `Mbuf` API.
///
/// # Errors
///
/// Returns `DnodeError::OutOfSpace` when `payload` is longer than `mbuf.remaining()`.
fn write_chain(mbuf: &mut Mbuf, payload: &[u8]) -> Result<(), DnodeError> {
    let needed = payload.len();
    if needed > mbuf.remaining() {
        return Err(DnodeError::OutOfSpace);
    }
    let written = mbuf.recv(payload);
    debug_assert_eq!(written, needed);
    Ok(())
}
/// Parses the dnode header of a request message; see `parse_msg`.
pub fn parse_req(msg: &mut Msg) -> MsgParseResult {
    let is_response = false;
    parse_msg(msg, is_response)
}
/// Parses the dnode header of a response message; see `parse_msg`.
pub fn parse_rsp(msg: &mut Msg) -> MsgParseResult {
    let is_response = true;
    parse_msg(msg, is_response)
}
/// Parses the dnode header at the start of `msg`'s buffered bytes and records the outcome on
/// `msg`.
///
/// Every call concatenates the readable bytes of all of `msg`'s mbufs and scans them from the
/// first byte, so a fresh parser starting at `DynParseState::Start` is used each time. Seeding
/// it with the stored mid-header state would be wrong: only the state enum is stored, not the
/// partial number and data that go with it. The re-fed magic bytes would then be read as, for
/// example, a message id and rejected, so a header split across reads could never complete.
///
/// Outcomes recorded on `msg`:
/// - already `Done`/`PostDone`: result `Ok`; the stored `Dmsg` and state are left untouched;
/// - header complete: the decoded `Dmsg`, state `Done`, result `Ok`;
/// - header incomplete: the state reached so far, result `Again`;
/// - malformed: state `Unknown`, result `Error`.
///
/// `_is_response` is unused: requests and responses share one header format here.
fn parse_msg(msg: &mut Msg, _is_response: bool) -> MsgParseResult {
    if matches!(
        msg.dyn_parse_state(),
        DynParseState::Done | DynParseState::PostDone
    ) {
        // Header already decoded and stored; re-running would replace the stored dmsg.
        msg.set_parse_result(MsgParseResult::Ok);
        return MsgParseResult::Ok;
    }

    let mut bytes: Vec<u8> = Vec::with_capacity(msg.mbufs().total_len());
    for mbuf in msg.mbufs() {
        bytes.extend_from_slice(mbuf.readable());
    }

    // Fresh parser: `bytes` restarts at the first byte of the message (see above).
    let mut parser = DnodeParser::new();
    match parser.step(&bytes) {
        ParseStep::HeaderDone { .. } => {
            let dmsg = parser.take_dmsg();
            msg.set_dyn_parse_state(DynParseState::Done);
            msg.set_dmsg(dmsg);
            msg.set_parse_result(MsgParseResult::Ok);
            MsgParseResult::Ok
        }
        ParseStep::NeedMore { .. } => {
            msg.set_dyn_parse_state(parser.state());
            msg.set_parse_result(MsgParseResult::Again);
            MsgParseResult::Again
        }
        ParseStep::Error { .. } => {
            msg.set_dyn_parse_state(DynParseState::Unknown);
            msg.set_parse_result(MsgParseResult::Error);
            MsgParseResult::Error
        }
    }
}
/// Routing decision returned by `dmsg_process`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum DmsgDispatch {
    /// Returned for the control-plane types listed in `dmsg_process`.
    Bypass,
    /// Returned for every other message type.
    Forward,
}
/// Classifies a decoded header for routing.
///
/// Crypto handshakes, gossip syn/syn-reply, handoff chunks and full-text-search
/// requests/replies yield [`DmsgDispatch::Bypass`]; every other type yields
/// [`DmsgDispatch::Forward`].
#[must_use]
pub fn dmsg_process(dmsg: &Dmsg) -> DmsgDispatch {
    let bypass = matches!(
        dmsg.ty,
        DmsgType::CryptoHandshake
            | DmsgType::GossipSyn
            | DmsgType::GossipSynReply
            | DmsgType::HandoffChunk
            | DmsgType::FtSearchReq
            | DmsgType::FtSearchRep
    );
    if bypass {
        DmsgDispatch::Bypass
    } else {
        DmsgDispatch::Forward
    }
}
/// Drains every mbuf from `chain`, front to back, and returns their readable bytes
/// concatenated. `chain` is left empty.
pub fn flatten_chain(chain: &mut MbufQueue) -> Vec<u8> {
    let capacity = chain.total_len();
    std::iter::from_fn(|| chain.pop_front()).fold(
        Vec::with_capacity(capacity),
        |mut bytes, mbuf| {
            bytes.extend_from_slice(mbuf.readable());
            bytes
        },
    )
}
/// Capability handshake frame.
///
/// Layout: `MAGIC` (`b"DHS1"`), a little-endian `u16` flags word (written as 0, and rejected
/// on decode unless 0), then the bytes of `CapabilityAd::encode`.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct Handshake {
    capabilities: crate::cluster::capability::CapabilityAd,
}

impl Handshake {
    /// Tag that opens every handshake frame.
    pub const MAGIC: [u8; 4] = *b"DHS1";

    /// Wraps the given capability advertisement.
    #[must_use]
    pub fn new(capabilities: crate::cluster::capability::CapabilityAd) -> Self {
        Handshake { capabilities }
    }

    /// Borrows the carried capability advertisement.
    #[must_use]
    pub fn capabilities(&self) -> &crate::cluster::capability::CapabilityAd {
        &self.capabilities
    }

    /// Consumes the handshake, returning the capability advertisement.
    #[must_use]
    pub fn into_capabilities(self) -> crate::cluster::capability::CapabilityAd {
        self.capabilities
    }

    /// Encodes the frame: magic, zero flags word, capability bytes.
    #[must_use]
    pub fn encode(&self) -> Vec<u8> {
        let body = self.capabilities.encode();
        let mut frame = Vec::with_capacity(Self::header_len() + body.len());
        frame.extend_from_slice(&Self::MAGIC);
        frame.extend_from_slice(&0u16.to_le_bytes());
        frame.extend_from_slice(&body);
        frame
    }

    /// Decodes a frame produced by [`Handshake::encode`].
    ///
    /// # Errors
    ///
    /// - `Truncated` if `bytes` is shorter than [`Handshake::header_len`];
    /// - `BadMagic` if the tag differs from `MAGIC` or the flags word is non-zero;
    /// - any error from `CapabilityAd::decode` on the remaining bytes.
    pub fn decode(bytes: &[u8]) -> Result<Self, crate::cluster::capability::CapabilityCodecError> {
        use crate::cluster::capability::{CapabilityAd, CapabilityCodecError};

        let header_len = Self::header_len();
        if bytes.len() < header_len {
            return Err(CapabilityCodecError::Truncated);
        }
        let (header, body) = bytes.split_at(header_len);
        let (magic, flag_bytes) = header.split_at(Self::MAGIC.len());
        if *magic != Self::MAGIC {
            return Err(CapabilityCodecError::BadMagic);
        }
        let flags = u16::from_le_bytes([flag_bytes[0], flag_bytes[1]]);
        if flags != 0 {
            // Non-zero flags are rejected, reported through the `BadMagic` variant.
            return Err(CapabilityCodecError::BadMagic);
        }
        let capabilities = CapabilityAd::decode(body)?;
        Ok(Self { capabilities })
    }

    /// Length of the fixed prefix: magic plus the 2-byte flags word.
    #[must_use]
    pub const fn header_len() -> usize {
        Self::MAGIC.len() + 2
    }
}
// Unit tests for the dnode header parser, writers and dispatcher.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::mbuf::MbufPool;
    // A minimal request header parses completely and fills every field.
    #[test]
    fn parse_simple_req() {
        let mut p = DnodeParser::new();
        let bytes = b"$2014$ 1 3 0 1 1 *1 d *0\r\n";
        match p.step(bytes) {
            ParseStep::HeaderDone { consumed } => assert_eq!(consumed, bytes.len()),
            other => panic!("unexpected: {other:?}"),
        }
        let d = p.take_dmsg();
        assert_eq!(d.id, 1);
        assert_eq!(d.ty, DmsgType::Req);
        assert_eq!(d.flags, 0);
        assert_eq!(d.version, 1);
        assert!(d.same_dc);
        assert_eq!(d.mlen, 1);
        assert_eq!(d.data, b"d");
        assert_eq!(d.plen, 0);
    }
    // A multi-digit payload length is accumulated correctly.
    #[test]
    fn parse_payload_len() {
        let mut p = DnodeParser::new();
        let bytes = b"$2014$ 2 3 0 1 1 *1 d *413\r\n";
        match p.step(bytes) {
            ParseStep::HeaderDone { consumed } => assert_eq!(consumed, bytes.len()),
            other => panic!("unexpected: {other:?}"),
        }
        assert_eq!(p.dmsg().plen, 413);
    }
    // Three headers in one buffer, separated by non-header bytes and a 3-byte payload;
    // the driver loop skips between headers using `consumed` and `plen`.
    #[test]
    fn parse_three_back_to_back() {
        let mut input: Vec<u8> = Vec::new();
        input.extend_from_slice(b"$2014$ 1 3 0 1 1 *1 d *0\r\n");
        input.extend_from_slice(b"some redis bytes here ignored");
        input.extend_from_slice(b"$2014$ 2 3 0 1 1 *1 d *3\r\nABC");
        input.extend_from_slice(b"$2014$ 3 3 0 1 1 *1 d *0\r\n");
        let mut p = DnodeParser::new();
        let mut idx = 0;
        let mut count = 0;
        while idx < input.len() {
            match p.step(&input[idx..]) {
                ParseStep::HeaderDone { consumed } => {
                    let d = p.take_dmsg();
                    count += 1;
                    let after_header = idx + consumed;
                    if count == 1 {
                        assert_eq!(d.id, 1);
                        // Skip the unrelated bytes up to the next `$`.
                        idx = input[after_header..]
                            .iter()
                            .position(|&b| b == b'$')
                            .map_or(input.len(), |n| after_header + n);
                    } else if count == 2 {
                        assert_eq!(d.id, 2);
                        assert_eq!(d.plen, 3);
                        // Skip the `plen`-byte payload ("ABC").
                        idx = after_header + d.plen as usize;
                    } else {
                        assert_eq!(d.id, 3);
                        idx = after_header;
                    }
                    p.reset();
                }
                ParseStep::NeedMore { .. } => {
                    break;
                }
                ParseStep::Error { consumed } => {
                    idx += consumed.max(1);
                    p.reset();
                }
            }
        }
        assert_eq!(count, 3);
    }
    // A header split across two `step` calls resumes where the first call stopped.
    #[test]
    fn need_more_when_truncated() {
        let mut p = DnodeParser::new();
        let prefix = b"$2014$ 1 3 0 1 1 *1 d *";
        match p.step(prefix) {
            ParseStep::NeedMore { consumed } => assert_eq!(consumed, prefix.len()),
            other => panic!("unexpected: {other:?}"),
        }
        let suffix = b"42\r\n";
        match p.step(suffix) {
            ParseStep::HeaderDone { consumed } => assert_eq!(consumed, suffix.len()),
            other => panic!("unexpected: {other:?}"),
        }
        assert_eq!(p.take_dmsg().plen, 42);
    }
    // Input that does not start with the magic is rejected at offset 0.
    #[test]
    fn parse_error_on_garbage_prefix() {
        let mut p = DnodeParser::new();
        match p.step(b"!nope") {
            ParseStep::Error { consumed } => assert_eq!(consumed, 0),
            other => panic!("unexpected: {other:?}"),
        }
    }
    // `dmsg_write` without a key payload round-trips through the parser (placeholder `d`).
    #[test]
    fn writer_round_trip_unencrypted() {
        let pool = MbufPool::default();
        let mut buf = pool.get();
        dmsg_write(&mut buf, 42, DmsgType::Req, 0, true, None, 0).unwrap();
        let bytes = buf.readable().to_vec();
        let mut p = DnodeParser::new();
        let step = p.step(&bytes);
        assert!(matches!(step, ParseStep::HeaderDone { .. }));
        let d = p.take_dmsg();
        assert_eq!(d.id, 42);
        assert_eq!(d.ty, DmsgType::Req);
        assert_eq!(d.flags, 0);
        assert!(d.same_dc);
        assert_eq!(d.mlen, 1);
        assert_eq!(d.data, b"d");
        assert_eq!(d.plen, 0);
    }
    // `dmsg_write` with a 128-byte key payload round-trips, including flags and same_dc.
    #[test]
    fn writer_round_trip_with_aes_payload() {
        let pool = MbufPool::default();
        let mut buf = pool.get();
        let payload = vec![0xAB; 128];
        dmsg_write(
            &mut buf,
            7,
            DmsgType::CryptoHandshake,
            DMSG_FLAG_ENCRYPTED,
            false,
            Some(&payload),
            512,
        )
        .unwrap();
        let bytes = buf.readable().to_vec();
        let mut p = DnodeParser::new();
        match p.step(&bytes) {
            ParseStep::HeaderDone { consumed } => assert_eq!(consumed, bytes.len()),
            other => panic!("unexpected: {other:?}"),
        }
        let d = p.take_dmsg();
        assert_eq!(d.id, 7);
        assert_eq!(d.ty, DmsgType::CryptoHandshake);
        assert!(d.is_encrypted());
        assert!(!d.same_dc);
        assert_eq!(d.data, payload);
        assert_eq!(d.plen, 512);
    }
    // `dmsg_process` routes each message type to Bypass or Forward as specified.
    #[test]
    fn dispatcher_classifies_control_plane() {
        let mut d = Dmsg::new();
        for ty in [
            DmsgType::CryptoHandshake,
            DmsgType::GossipSyn,
            DmsgType::GossipSynReply,
        ] {
            d.ty = ty;
            assert_eq!(dmsg_process(&d), DmsgDispatch::Bypass);
        }
        for ty in [
            DmsgType::GossipAck,
            DmsgType::GossipDigestSyn,
            DmsgType::GossipDigestAck,
            DmsgType::GossipDigestAck2,
            DmsgType::GossipShutdown,
            DmsgType::Req,
            DmsgType::ReqForward,
            DmsgType::Res,
        ] {
            d.ty = ty;
            assert_eq!(dmsg_process(&d), DmsgDispatch::Forward);
        }
        d.ty = DmsgType::HandoffChunk;
        assert_eq!(dmsg_process(&d), DmsgDispatch::Bypass);
        for ty in [DmsgType::FtSearchReq, DmsgType::FtSearchRep] {
            d.ty = ty;
            assert_eq!(dmsg_process(&d), DmsgDispatch::Bypass);
        }
    }
}