use std::io::{self, Cursor, ErrorKind, Read, Write};
use crate::{Error, WireError};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use bytes::{BufMut, Bytes, BytesMut};
/// Upper bound, in bytes, on a whole frame including its header (1 MiB).
const FRAME_MAX_LEN: u64 = 1 << 20;
/// A single decoded wire frame.
///
/// Wire layout (integers big-endian): `id: u64`, `body_len: u64`, then `body_len` body
/// bytes. `data` keeps the whole frame, header included.
#[derive(Debug)]
pub struct Frame {
    /// Frame id read from the header.
    pub id: u64,
    /// Complete frame bytes: the 16-byte header followed by the body.
    data: Bytes,
}

impl Frame {
    /// Size of the frame header: `id` (u64) followed by the body length (u64).
    const HEADER_LEN: usize = 16;
    /// Offset of a response payload: frame header, `u8` type tag, `u64` payload length.
    const RSP_PAYLOAD_OFFSET: usize = Self::HEADER_LEN + 1 + 8;

    /// Reads one frame from `r`, using `buf` as backing storage for the returned bytes.
    ///
    /// `buf` is scratch storage: it is cleared first, the frame is split off its front,
    /// and whatever spare capacity remains stays in `buf` for the next call.
    ///
    /// # Errors
    /// Any I/O error from `r` (e.g. `UnexpectedEof` on a truncated frame), or
    /// `ErrorKind::InvalidInput` when the frame length exceeds `FRAME_MAX_LEN`.
    pub fn decode_from<R: Read>(r: &mut R, buf: &mut BytesMut) -> io::Result<Self> {
        let id = r.read_u64::<BigEndian>()?;
        info!("decode id = {:?}", id);
        let body_len = r.read_u64::<BigEndian>()?;
        // The length is untrusted: saturate so a value near u64::MAX cannot wrap past the check.
        let len = body_len.saturating_add(Self::HEADER_LEN as u64);
        info!("decode len = {:?}", len);
        if len > FRAME_MAX_LEN {
            let s = format!("decode too big frame length. len={len}");
            error!("{s}");
            return Err(io::Error::new(ErrorKind::InvalidInput, s));
        }
        // Lossless: len <= FRAME_MAX_LEN.
        let frame_len = len as usize;

        // The frame is split off the front of `buf`, so stale bytes must not be there.
        buf.clear();
        // Zero-fill instead of exposing uninitialised memory to an arbitrary `Read` impl.
        buf.resize(frame_len, 0);
        if let Err(e) = r.read_exact(&mut buf[Self::HEADER_LEN..]) {
            buf.clear();
            return Err(e);
        }
        // Re-create the header in front of the body so `data` holds the full frame.
        let mut header = &mut buf[..Self::HEADER_LEN];
        header.put_u64(id);
        header.put_u64(body_len);

        let data = buf.split_to(frame_len).freeze();
        Ok(Frame { id, data })
    }

    /// Returns the request body, i.e. everything after the 16-byte header.
    pub fn decode_req(&self) -> &[u8] {
        &self.data[Self::HEADER_LEN..]
    }

    /// Interprets this frame as a response and returns its payload.
    ///
    /// Response body layout: `ty: u8`, `len: u64` (big-endian), then `len` payload bytes.
    /// `ty == 0` is success; `1`, `2` and `3` carry a server-side error message.
    ///
    /// # Errors
    /// `ServerDeserialize` / `ServerSerialize` / `Status` for error responses (messages that
    /// are not valid UTF-8 are decoded lossily); `ClientDeserialize` for an unknown type tag
    /// or a payload length that runs past the end of the frame; and whatever `?` converts
    /// the `io::Error` into when the frame is too short to hold `ty` and `len`.
    pub fn decode_rsp(&self) -> Result<&[u8], Error> {
        use Error::*;
        let mut r = Cursor::new(&self.data[..]);
        r.set_position(Self::HEADER_LEN as u64);
        let ty = r.read_u8()?;
        let len = r.read_u64::<BigEndian>()?;
        let frame = r.into_inner();
        // Both reads succeeded, so the frame holds at least RSP_PAYLOAD_OFFSET bytes.
        let rest = &frame[Self::RSP_PAYLOAD_OFFSET..];
        // `len` comes off the wire; reject it instead of letting the slice panic.
        if len > rest.len() as u64 {
            let s = format!(
                "invalid response length. len={len}, available={}",
                rest.len()
            );
            error!("{s}");
            return Err(ClientDeserialize(s));
        }
        let data = &rest[..len as usize];
        let text = |bytes: &[u8]| String::from_utf8_lossy(bytes).into_owned();
        match ty {
            0 => Ok(data),
            1 => Err(ServerDeserialize(text(data))),
            2 => Err(ServerSerialize(text(data))),
            3 => Err(Status(text(data))),
            _ => {
                let s = format!("invalid response type. ty={ty}");
                error!("{s}");
                Err(ClientDeserialize(s))
            }
        }
    }
}
/// Growable buffer for building an outgoing request frame.
///
/// The first 16 bytes are reserved for the frame header (`id`, body length), which
/// [`ReqBuf::finish`] fills in; everything written through [`Write`] becomes the body.
pub struct ReqBuf(Cursor<Vec<u8>>);

impl Default for ReqBuf {
    fn default() -> Self {
        Self::new()
    }
}

impl ReqBuf {
    /// Bytes reserved at the front of the buffer for the frame header.
    const HEADER_LEN: usize = 16;

    /// Creates an empty request buffer whose write position is just past the header.
    pub fn new() -> Self {
        let mut bytes = Vec::with_capacity(128);
        bytes.resize(Self::HEADER_LEN, 0);
        let mut inner = Cursor::new(bytes);
        inner.set_position(Self::HEADER_LEN as u64);
        ReqBuf(inner)
    }

    /// Fills in the header for frame `id` and returns the complete frame bytes.
    ///
    /// # Panics
    /// If the whole frame (header included) is larger than `FRAME_MAX_LEN`.
    pub fn finish(self, id: u64) -> Vec<u8> {
        let ReqBuf(mut inner) = self;
        let len = inner.get_ref().len() as u64;
        assert!(len <= FRAME_MAX_LEN);
        // The header records the body length only, excluding the header itself.
        let body_len = len - Self::HEADER_LEN as u64;
        inner.set_position(0);
        inner.write_u64::<BigEndian>(id).unwrap();
        info!("encode id = {:?}", id);
        inner.write_u64::<BigEndian>(body_len).unwrap();
        info!("encode len = {:?}", len);
        inner.into_inner()
    }
}

impl Write for ReqBuf {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let ReqBuf(inner) = self;
        inner.write(buf)
    }

    /// Nothing is buffered beyond the in-memory vector, so flushing is a no-op.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
/// Growable buffer for building an outgoing response frame.
///
/// The first 25 bytes are reserved for the frame header (`id`, body length) plus the
/// response type tag and payload length; [`RspBuf::finish`] fills them in. Bytes written
/// through [`Write`] form the payload of a successful response.
pub struct RspBuf(Cursor<Vec<u8>>);

impl Default for RspBuf {
    fn default() -> Self {
        RspBuf::new()
    }
}

/// Response type tag written by [`RspBuf::finish`] for `WireError::Polling`.
pub const SERVER_POLL_ENCODE: u8 = 200;

impl RspBuf {
    /// Reserved prefix: id (8) + body length (8) + type tag (1) + payload length (8).
    const HEADER_LEN: usize = 25;
    /// Part of the prefix counted in the frame's body length: type tag + payload length.
    const RSP_META_LEN: u64 = 9;

    /// Creates an empty response buffer whose write position is just past the header.
    pub fn new() -> Self {
        let mut buf = Vec::with_capacity(128);
        buf.resize(Self::HEADER_LEN, 0);
        let mut cursor = Cursor::new(buf);
        cursor.set_position(Self::HEADER_LEN as u64);
        RspBuf(cursor)
    }

    /// Fills in the header for response `id` and returns the complete frame bytes.
    ///
    /// On `Ok(())` the payload is whatever was written through [`Write`]. On an error, the
    /// payload is the error's message (empty for `Polling`). Any previously written bytes
    /// are discarded in that case.
    ///
    /// # Panics
    /// If the whole frame (header included) is larger than `FRAME_MAX_LEN`, the limit
    /// the decoder enforces.
    pub fn finish(self, id: u64, ret: Result<(), WireError>) -> Vec<u8> {
        let mut cursor = self.0;
        let (ty, len, message): (u8, usize, &[u8]) = match &ret {
            // The success payload is already in the buffer.
            Ok(()) => (0, cursor.get_ref().len() - Self::HEADER_LEN, &[]),
            Err(WireError::ServerDeserialize(s)) => (1, s.len(), s.as_bytes()),
            Err(WireError::ServerSerialize(s)) => (2, s.len(), s.as_bytes()),
            Err(WireError::Status(s)) => (3, s.len(), s.as_bytes()),
            Err(WireError::Polling) => (SERVER_POLL_ENCODE, 0, &[]),
        };
        let len = len as u64;
        let frame_len = Self::HEADER_LEN as u64 + len;
        // Bound the whole frame, matching ReqBuf::finish and the decoder's check.
        assert!(frame_len <= FRAME_MAX_LEN);
        // Size the buffer to exactly header + payload so the bytes on the wire always agree
        // with the length fields (no-op for Ok; drops stale output for error/poll replies).
        cursor.get_mut().resize(frame_len as usize, 0);
        cursor.set_position(0);
        cursor.write_u64::<BigEndian>(id).unwrap();
        info!("encode id = {:?}", id);
        cursor.write_u64::<BigEndian>(len + Self::RSP_META_LEN).unwrap();
        info!("encode len = {:?}", frame_len);
        cursor.write_u8(ty).unwrap();
        cursor.write_u64::<BigEndian>(len).unwrap();
        // Position is now HEADER_LEN; empty for Ok and Polling.
        cursor.write_all(message).unwrap();
        cursor.into_inner()
    }
}

impl Write for RspBuf {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.0.write(buf)
    }

    /// Nothing is buffered beyond the in-memory vector, so flushing is a no-op.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}