use anyhow::Context;
use base64::{engine::general_purpose::STANDARD_NO_PAD, Engine};
use bytes::{Buf, BufMut, BytesMut};
use rmp::Marker;
use tokio_util::codec::{Decoder, Encoder};
use tracing::trace;
use self::{request::EncodedRequest, response::Response};
use crate::{
errors::{CodecDecodeError, CodecEncodeError, DecodingError},
Error,
};
pub mod consts;
pub mod request;
pub mod response;
pub mod utils;
/// Incremental decoding state for the length prefix that precedes every frame.
///
/// The prefix is an unsigned integer introduced by a [`Marker`] byte. Because
/// the bytes of the prefix may arrive in separate reads, the progress made so
/// far is remembered between calls.
#[derive(Default)]
enum LengthDecoder {
    /// No byte of the prefix has been consumed yet.
    #[default]
    NoMarker,
    /// The marker byte has been consumed, but the integer payload that follows
    /// it has not been read yet.
    Marker(Marker),
    /// The prefix has been fully decoded; holds the frame length in bytes.
    Value(usize),
}
impl LengthDecoder {
    /// Tries to decode the length prefix of the next frame from `src`.
    ///
    /// Bytes belonging to the prefix are consumed from `src` as soon as they
    /// can be read; progress is stored in `self`, so the method may be called
    /// repeatedly as more data arrives.
    ///
    /// Returns:
    /// * `Ok(Some(len))` once the prefix is known. The value stays cached and
    ///   is returned again on later calls (when `src` is non-empty) until
    ///   [`reset`](Self::reset) is called.
    /// * `Ok(None)` when `src` is empty or does not yet hold the whole prefix.
    ///
    /// # Errors
    ///
    /// Returns a type-mismatch [`DecodingError`] when the marker byte does not
    /// denote an unsigned integer.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<usize>, DecodingError> {
        if src.is_empty() {
            return Ok(None);
        }
        let marker = match self {
            LengthDecoder::NoMarker => {
                let marker = Marker::from_u8(src.get_u8());
                // Remember the marker: its payload may only arrive with a later read.
                *self = Self::Marker(marker);
                trace!("decoded length marker: {:?}", marker);
                marker
            }
            LengthDecoder::Marker(x) => *x,
            LengthDecoder::Value(x) => return Ok(Some(*x)),
        };
        // The marker byte is already consumed at this point, so `src` only has
        // to contain the integer payload itself: exactly 1, 2, 4 or 8 bytes
        // are enough (`>=`, not `>`).
        let length = match marker {
            Marker::FixPos(x) => x as usize,
            Marker::U8 if !src.is_empty() => src.get_u8() as usize,
            Marker::U16 if src.len() >= 2 => src.get_u16() as usize,
            Marker::U32 if src.len() >= 4 => src.get_u32() as usize,
            Marker::U64 if src.len() >= 8 => src.get_u64() as usize,
            // Payload has not fully arrived yet; state stays `Marker(..)` so the
            // next call resumes from here.
            Marker::U8 | Marker::U16 | Marker::U32 | Marker::U64 => return Ok(None),
            rest => {
                return Err(DecodingError::type_mismatch(
                    "unsigned integer",
                    format!("{:?}", rest),
                ))
            }
        };
        trace!("decoded frame length: {}", length);
        *self = LengthDecoder::Value(length);
        Ok(Some(length))
    }

    /// Forgets any decoded state so that the next call to
    /// [`decode`](Self::decode) starts with a fresh marker byte.
    fn reset(&mut self) {
        *self = LengthDecoder::NoMarker
    }
}
/// Codec for length-prefixed frames: decodes incoming [`Response`]s and
/// encodes outgoing [`EncodedRequest`]s.
#[derive(Default)]
pub(crate) struct ClientCodec {
// Decoding state of the length prefix of the frame currently being received.
length_decoder: LengthDecoder,
}
impl Decoder for ClientCodec {
    type Item = Response;
    type Error = CodecDecodeError;

    /// Decodes a single length-prefixed frame from `src` into a [`Response`].
    ///
    /// Yields `Ok(None)` while either the length prefix or the frame body is
    /// still incomplete; in the latter case room for the missing bytes is
    /// reserved in `src`.
    ///
    /// # Errors
    ///
    /// Failures of the length prefix decoder and of `Response::decode` are
    /// wrapped into `CodecDecodeError::Decode`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        let frame_len = match self.length_decoder.decode(src) {
            Ok(Some(len)) => len,
            Ok(None) => return Ok(None),
            Err(err) => return Err(CodecDecodeError::Decode(err)),
        };

        let available = src.len();
        if available < frame_len {
            // Body is not complete yet: make room for the rest and wait.
            src.reserve(frame_len - available);
            return Ok(None);
        }

        // Whole frame is buffered; the next call has to start with a new prefix.
        self.length_decoder.reset();
        let frame = src.split_to(frame_len);
        match Response::decode(frame.reader()) {
            Ok(response) => Ok(Some(response)),
            Err(err) => Err(CodecDecodeError::Decode(err)),
        }
    }
}
impl Encoder<EncodedRequest> for ClientCodec {
    type Error = CodecEncodeError;

    /// Appends `item` to `dst` as a single frame: a length prefix written as a
    /// `u64` followed by the encoded request.
    ///
    /// The prefix is first written as a zero placeholder and patched once the
    /// size of the encoded request is known.
    ///
    /// # Errors
    ///
    /// Returns `CodecEncodeError::Encode` when writing the prefix or encoding
    /// the request fails. In that case `dst` is rolled back to the length it
    /// had on entry, so no partially written frame is left in the buffer.
    fn encode(&mut self, item: EncodedRequest, dst: &mut BytesMut) -> Result<(), Self::Error> {
        /// Bytes occupied by the length prefix: one marker byte plus eight
        /// bytes of `u64` payload.
        const LENGTH_PREFIX_SIZE: usize = 9;

        /// Writes the placeholder prefix and the request, then patches the prefix.
        fn write_frame(
            item: EncodedRequest,
            dst: &mut BytesMut,
            begin_idx: usize,
        ) -> Result<(), CodecEncodeError> {
            let mut writer = dst.writer();
            // Placeholder; the real length is only known after encoding the body.
            rmp::encode::write_u64(&mut writer, 0)
                .map_err(|err| CodecEncodeError::Encode(err.into()))?;
            item.encode(&mut writer).map_err(CodecEncodeError::Encode)?;
            let dst = writer.into_inner();
            let data_len = dst.len() - begin_idx - LENGTH_PREFIX_SIZE;
            // Overwrite the placeholder in place with the actual body length.
            let mut len_writer = dst[begin_idx..].writer();
            rmp::encode::write_u64(&mut len_writer, data_len as u64)
                .map_err(|err| CodecEncodeError::Encode(err.into()))?;
            Ok(())
        }

        let begin_idx = dst.len();
        let result = write_frame(item, dst, begin_idx);
        if result.is_err() {
            // Drop whatever was written for this frame so that a failed request
            // cannot corrupt the frames encoded after it.
            dst.truncate(begin_idx);
        }
        result
    }
}
/// Greeting parsed from a fixed-size buffer of [`Greeting::SIZE`] bytes.
#[derive(Debug)]
pub struct Greeting {
/// First line of the greeting (bytes `0..62`), converted to a string lossily.
pub server: String,
/// Salt decoded from the base64 text on the second line of the greeting.
pub salt: Vec<u8>,
}
impl Greeting {
    /// Size of the greeting buffer in bytes.
    pub const SIZE: usize = 128;

    /// Parses a greeting from a buffer of exactly [`Self::SIZE`] bytes.
    ///
    /// Bytes `0..62` are taken as the server line and bytes `64..126` as the
    /// line carrying the base64-encoded salt, padded on the right with spaces.
    ///
    /// # Errors
    ///
    /// Returns `Error::Other` when the salt is not valid unpadded base64.
    pub fn decode(buffer: [u8; Self::SIZE]) -> Result<Self, Error> {
        let server_line = &buffer[0..62];
        let salt_line = &buffer[64..126];

        // Index of the last byte that is not a space, if there is one.
        let last_non_space = salt_line.iter().rposition(|byte| *byte != b' ');
        // NOTE(review): the range is exclusive, so the last non-space byte
        // itself is NOT handed to the decoder. Presumably this drops a trailing
        // `=` padding character, which the no-padding engine would reject;
        // confirm against the greeting format before changing it.
        let salt_b64 = match last_non_space {
            Some(idx) => &salt_line[..idx],
            None => &b""[..],
        };

        let salt = STANDARD_NO_PAD
            .decode(salt_b64)
            .context("Failed to decode salt from base64")
            .map_err(Error::Other)?;
        let server = String::from_utf8_lossy(server_line).into_owned();

        Ok(Self { server, salt })
    }
}