use bytes::{Bytes, BytesMut, BufMut, Buf};
use crate::bufvec::BufList;
use log::{debug};
use std::iter::{FromIterator, Extend, IntoIterator};
/// Fixed eight-byte header that precedes every FastCGI record on the wire.
///
/// `write_into` / `parse` serialize the fields in declaration order followed
/// by a single reserved zero byte.
#[derive(Debug, Clone)]
struct Header {
    /// Protocol version byte.
    version: u8,
    /// Record type (one of the `Record::*` type constants).
    rtype: u8,
    /// Request identifier the record belongs to.
    request_id: u16,
    /// Number of content bytes following the header.
    content_length: u16,
    /// Number of padding bytes following the content.
    padding_length: u8,
}
/// Content of a `BEGIN_REQUEST` record: the requested role and flag bits
/// (see the `BeginRequestBody::*` constants).
pub struct BeginRequestBody {
    role: u16,
    flags: u8,
}

/// Content of an `END_REQUEST` record, as decoded by `EndRequestBody::parse`.
pub struct EndRequestBody {
    /// Application status code reported by the peer.
    pub app_status: u32,
    /// Protocol status; compare against the `EndRequestBody::*` constants.
    pub protocol_status: u8,
}

/// Content of an `UNKNOWN_TYPE` record: the record type the peer rejected.
pub struct UnknownTypeBody {
    rtype: u8,
}
/// One FastCGI name/value pair.
pub struct NameValuePair {
    /// Raw bytes of the name.
    pub name_data: Bytes,
    /// Raw bytes of the value.
    pub value_data: Bytes,
}

/// Content of a name/value record (`PARAMS`, `GET_VALUES`, `GET_VALUES_RESULT`).
pub struct NVBody {
    /// Encoded pairs, kept as a list of buffer chunks.
    pairs: BufList<Bytes>,
    /// Total encoded length of `pairs` in bytes.
    len: u16,
}

/// A sequence of [`NVBody`]s, each small enough for a single record.
pub struct NVBodyList {
    bodies: Vec<NVBody>,
}

/// Namespace for building `STDIN` records; see `STDINBody::new`.
pub struct STDINBody {}
/// Typed content of a FastCGI record, one variant per handled record type.
pub enum Body {
    BeginRequest(BeginRequestBody),
    EndRequest(EndRequestBody),
    UnknownType(UnknownTypeBody),
    Params(NVBody),
    /// Raw bytes of a `STDIN` record.
    StdIn(Bytes),
    /// Raw bytes of a `STDERR` record.
    StdErr(Bytes),
    /// Raw bytes of a `STDOUT` record.
    StdOut(Bytes),
    /// `ABORT_REQUEST` carries no content.
    Abort,
    GetValues(NVBody),
    GetValuesResult(NVBody),
}

/// A FastCGI record: its header plus the typed body.
pub struct Record {
    header: Header,
    pub body: Body,
}
/// File descriptor number of the listening socket (FastCGI `FCGI_LISTENSOCK_FILENO`).
pub const LISTENSOCK_FILENO: u8 = 0x00;
impl Body {
    /// Largest content a single record can carry (its length field is a `u16`).
    const MAX_LENGTH: usize = 0xFFFF;
}
impl Header {
    /// Size of an encoded record header in bytes.
    const HEADER_LEN: usize = 8;
    /// Version byte written into every header built in this module.
    const VERSION_1: u8 = 1;
}
impl Record {
    /// Request id used by management records.
    pub const MGMT_REQUEST_ID: u16 = 0;

    // Record type identifiers.
    const BEGIN_REQUEST: u8 = 1;
    const ABORT_REQUEST: u8 = 2;
    const END_REQUEST: u8 = 3;
    pub const PARAMS: u8 = 4;
    const STDIN: u8 = 5;
    const STDOUT: u8 = 6;
    const STDERR: u8 = 7;
    const DATA: u8 = 8;
    pub const GET_VALUES: u8 = 9;
    pub const GET_VALUES_RESULT: u8 = 10;
    const UNKNOWN_TYPE: u8 = 11;
}
impl BeginRequestBody {
    /// `flags` bit (FastCGI `FCGI_KEEP_CONN`).
    pub const KEEP_CONN: u8 = 1;

    // Values for `role`.
    pub const RESPONDER: u16 = 1;
    pub const AUTHORIZER: u16 = 2;
    pub const FILTER: u16 = 3;
}
impl EndRequestBody {
    // Values for `protocol_status`.
    pub const REQUEST_COMPLETE: u8 = 0;
    pub const CANT_MPX_CONN: u8 = 1;
    pub const OVERLOADED: u8 = 2;
    pub const UNKNOWN_ROLE: u8 = 3;
}
// Variable names that can be queried with `GET_VALUES` records.
pub const MAX_CONNS: &[u8] = b"MAX_CONNS";
pub const MAX_REQS: &[u8] = b"MAX_REQS";
pub const MPXS_CONNS: &[u8] = b"MPXS_CONNS";
impl NameValuePair
{
    /// Decodes one name/value pair from the front of `data` and advances
    /// `data` past it.
    ///
    /// The returned name and value are zero-copy slices of `data`.
    ///
    /// # Panics
    ///
    /// Panics if `data` does not start with a complete encoded pair.
    pub fn parse(data: &mut Bytes) -> NameValuePair
    {
        let mut pos: usize = 0;
        let name_length = NameValuePair::param_length(data, &mut pos);
        let value_length = NameValuePair::param_length(data, &mut pos);
        let name = data.slice(pos..pos + name_length);
        pos += name_length;
        let value = data.slice(pos..pos + value_length);
        pos += value_length;
        data.advance(pos);
        NameValuePair {
            name_data: name,
            value_data: value
        }
    }
    /// Creates a pair from owned name and value bytes.
    pub fn new(name_data: Bytes, value_data: Bytes) -> NameValuePair
    {
        NameValuePair {
            name_data,
            value_data
        }
    }
    /// Reads one encoded length starting at `data[*pos]` and advances `*pos`
    /// past it.
    ///
    /// Lengths below 128 occupy a single byte. Longer lengths occupy four
    /// big-endian bytes whose most significant bit is set as a marker and is
    /// masked off here.
    fn param_length(data: &Bytes, pos: &mut usize) -> usize
    {
        let first = data[*pos];
        if first & 0x80 == 0 {
            *pos += 1;
            return first as usize;
        }
        // The four-byte form starts at `*pos` itself: the marker byte is also
        // the most significant byte of the length.
        let length = data.slice(*pos..*pos + 4).get_u32() & 0x7FFF_FFFF;
        *pos += 4;
        length as usize
    }
    /// Number of bytes this pair occupies once encoded: one length prefix
    /// each for name and value (1 byte, or 4 bytes above 127), plus the data.
    fn len(&self) -> usize {
        let ln = self.name_data.len();
        let lv = self.value_data.len();
        let mut lf: usize = ln + lv + 2;
        if ln > 0x7f {
            lf += 3;
        }
        if lv > 0x7f {
            lf += 3;
        }
        lf
    }
}
impl STDINBody
{
    /// Builds a `STDIN` record for `request_id` holding the first
    /// `min(b.remaining(), Body::MAX_LENGTH)` bytes of `b`.
    ///
    /// `b` itself is NOT advanced; the record holds a zero-copy slice of it,
    /// so callers sending `b` in chunks must advance it themselves. An empty
    /// `b` yields a zero-length record.
    pub fn new(request_id: u16, b: &mut Bytes) -> Record
    {
        let chunk_len = b.remaining().min(Body::MAX_LENGTH);
        let header = Header::new(Record::STDIN, request_id, chunk_len as u16);
        let chunk = b.slice(..chunk_len);
        Record {
            header,
            body: Body::StdIn(chunk),
        }
    }
}
impl Header {
    /// Builds a `VERSION_1` header for a record with `len` content bytes.
    ///
    /// The padding is chosen so that content plus padding is a multiple of
    /// eight bytes.
    pub fn new(rtype: u8, request_id: u16, len: u16) -> Header {
        // Zero when `len` is already aligned, otherwise the gap to the next multiple of 8.
        let padding_length = ((8 - len % 8) % 8) as u8;
        Header {
            version: Header::VERSION_1,
            rtype,
            request_id,
            content_length: len,
            padding_length,
        }
    }
}
impl Header
{
    /// Appends the 8-byte encoded header to `data`.
    ///
    /// Multi-byte integers are big-endian, and the last byte is reserved (zero).
    fn write_into(self, data: &mut BytesMut)
    {
        data.put_slice(&[self.version, self.rtype]);
        data.put_slice(&self.request_id.to_be_bytes());
        data.put_slice(&self.content_length.to_be_bytes());
        data.put_slice(&[self.padding_length, 0]);
        debug!("h {} {} -> {:?}",self.request_id, self.content_length, &data);
    }
    /// Decodes a header from the front of `data`, consuming exactly 8 bytes.
    ///
    /// # Panics
    ///
    /// Panics if fewer than 8 bytes are available.
    fn parse(data: &mut Bytes) -> Header
    {
        let version = data.get_u8();
        let rtype = data.get_u8();
        let request_id = data.get_u16();
        let content_length = data.get_u16();
        let padding_length = data.get_u8();
        // Reserved byte.
        data.advance(1);
        Header {
            version,
            rtype,
            request_id,
            content_length,
            padding_length,
        }
    }
}
impl BeginRequestBody
{
    /// Builds a `BEGIN_REQUEST` record that opens `request_id` with the given
    /// `role` and `flags` (see the `BeginRequestBody::*` constants).
    pub fn new(role: u16, flags: u8, request_id: u16) -> Record
    {
        // The body is always 8 bytes, so `Header::new` chooses zero padding.
        let header = Header::new(Record::BEGIN_REQUEST, request_id, 8);
        let body = Body::BeginRequest(BeginRequestBody { role, flags });
        Record { header, body }
    }
}
/// Incremental decoder for FastCGI records arriving in arbitrary pieces.
///
/// Unlike `Record::read`, which needs a whole record in one buffer, this
/// reader keeps state between calls. Content of `STDOUT`/`STDERR` records is
/// handed out chunk by chunk as it arrives. Every other record type is only
/// decoded once its complete content is buffered, because its body cannot be
/// parsed from a fragment.
pub(crate) struct RecordReader {
    /// Header of a record whose content has not fully arrived yet;
    /// `content_length` is the number of content bytes still expected.
    current: Option<Header>,
    /// Padding bytes of an already returned record that were not in the
    /// buffer yet and must be skipped before the next header is parsed.
    pending_padding: usize,
}
impl RecordReader
{
    pub(crate) fn new() -> RecordReader {
        RecordReader {
            current: None,
            pending_padding: 0,
        }
    }
    /// Decodes the next record, or the next chunk of a stream record, from
    /// the front of `data` and consumes the bytes it used.
    ///
    /// Returns `None` when nothing can be produced yet. Bytes consumed before
    /// returning `None` (a parsed header, skipped padding) are remembered.
    /// Bytes left in `data` must be passed again, followed by more stream
    /// data, on the next call.
    ///
    /// # Panics
    ///
    /// Panics (in `Record::parse_body`) on record types it cannot decode.
    pub(crate) fn read(&mut self, data: &mut Bytes) -> Option<Record>
    {
        // Padding left over from the previous record comes first on the wire.
        // It must neither be parsed as a header nor produce an extra (empty)
        // record.
        if self.pending_padding > 0 {
            let skip = self.pending_padding.min(data.remaining());
            data.advance(skip);
            self.pending_padding -= skip;
            if self.pending_padding > 0 {
                return None;
            }
        }
        let mut header = match self.current.take() {
            Some(h) => h,
            None => {
                if data.remaining() < Header::HEADER_LEN {
                    return None;
                }
                debug!("new header");
                Header::parse(data)
            }
        };
        let available = data.remaining();
        if available < header.content_length as usize {
            let streamable = header.rtype == Record::STDOUT || header.rtype == Record::STDERR;
            if available == 0 || !streamable {
                // Either nothing to hand out yet, or a structured body that
                // can only be parsed in one piece: wait for more data.
                self.current = Some(header);
                return None;
            }
            // Return what is here now and remember how much content is left.
            let mut rest = header.clone();
            rest.content_length -= available as u16;
            self.current = Some(rest);
            header.content_length = available as u16;
            // The padding follows the final chunk, not this one.
            header.padding_length = 0;
            debug!("more later, now:");
        }
        let body_len = header.content_length as usize;
        debug!("read type {:?} payload: {:?}", header.rtype, &data.slice(..body_len));
        let body = data.slice(0..body_len);
        data.advance(body_len);
        let padding = header.padding_length as usize;
        let skip = padding.min(data.remaining());
        data.advance(skip);
        if skip < padding {
            // Deliver the record now and drop the rest of its padding next time.
            self.pending_padding = padding - skip;
            debug!("padding {} is still missing", header.padding_length);
        }
        let body = Record::parse_body(body, header.rtype);
        Some(Record {
            header,
            body
        })
    }
}
impl Record
{
pub(crate) fn get_request_id(&self) -> u16 {
self.header.request_id
}
pub(crate) fn read(data: &mut Bytes) -> Option<Record>
{
if data.remaining() < 8 {
return None;
}
let header = Header::parse(&mut data.slice(..));
let len = header.content_length as usize+header.padding_length as usize;
if data.remaining() < len+8 {
return None;
}
data.advance(8);
debug!("read type {:?} payload: {:?}", header.rtype, &data.slice(..len));
let body = data.slice(0..header.content_length as usize);
data.advance(len);
let body = Record::parse_body(body, header.rtype);
Some(Record {
header,
body
})
}
fn parse_body(mut payload: Bytes, ptype: u8) -> Body {
match ptype {
Record::STDOUT => Body::StdOut(payload),
Record::STDERR => Body::StdErr(payload),
Record::END_REQUEST => Body::EndRequest(EndRequestBody::parse(payload)),
Record::UNKNOWN_TYPE => {
let rtype = payload.get_u8();
payload.advance(7);
Body::UnknownType(UnknownTypeBody {
rtype
})
},
Record::GET_VALUES_RESULT => Body::GetValuesResult(NVBody::from_bytes(payload)),
Record::GET_VALUES => Body::GetValues(NVBody::from_bytes(payload)),
Record::PARAMS => Body::Params(NVBody::from_bytes(payload)),
_ => panic!("not impl"),
}
}
pub(crate) fn abort(request_id: u16) -> Record {
Record {
header: Header {
version: Header::VERSION_1,
rtype: Record::ABORT_REQUEST,
request_id,
content_length: 0,
padding_length: 0,
},
body: Body::Abort
}
}
pub(crate) fn append(self, buf: &mut BufList<Bytes>)
{
match self.body
{
Body::BeginRequest(brb) => {
let mut data = BytesMut::with_capacity(Header::HEADER_LEN+8);
self.header.write_into(&mut data);
brb.write_into(&mut data);
buf.push(data.into());
},
Body::Params(nvs) | Body::GetValues(nvs) | Body::GetValuesResult(nvs) => {
let mut data = BytesMut::with_capacity(Header::HEADER_LEN);
let pad = self.header.padding_length as usize;
self.header.write_into(&mut data);
let header = data.freeze();
buf.push(header.clone());
buf.append(nvs.pairs);
if pad>0 {
buf.push(header.slice(0..pad));
}
},
Body::StdIn(b) => {
let mut data = BytesMut::with_capacity(Header::HEADER_LEN);
let pad = self.header.padding_length as usize;
self.header.write_into(&mut data);
let header = data.freeze();
buf.push(header.clone());
if !b.has_remaining()
{
debug_assert!(pad==0);
return;
}
buf.push(b);
if pad>0 {
buf.push(header.slice(0..pad));
}
},
Body::Abort => {
let mut data = BytesMut::with_capacity(Header::HEADER_LEN);
self.header.write_into(&mut data);
let header = data.freeze();
buf.push(header);
},
_ => panic!("not impl"),
}
}
}
impl NVBody
{
    /// Creates an empty body.
    pub fn new() -> NVBody {
        NVBody {
            pairs: BufList::new(),
            len: 0
        }
    }
    /// Wraps the encoded pairs in a record of type `rtype` for `request_id`.
    ///
    /// `rtype` must be `PARAMS`, `GET_VALUES` or `GET_VALUES_RESULT`. The
    /// padding is chosen by `Header::new`.
    ///
    /// # Panics
    ///
    /// Panics for any other record type.
    pub fn to_record(self, rtype: u8, request_id: u16) -> Record {
        let header = Header::new(rtype, request_id, self.len);
        let body = match rtype {
            Record::PARAMS => Body::Params(self),
            Record::GET_VALUES => Body::GetValues(self),
            Record::GET_VALUES_RESULT => Body::GetValuesResult(self),
            _ => panic!("No valid type"),
        };
        Record { header, body }
    }
    /// Returns whether `pair` can be added without exceeding `Body::MAX_LENGTH`.
    pub fn fits(&self, pair: &NameValuePair) -> bool {
        pair.len() + self.len as usize <= Body::MAX_LENGTH
    }
    /// Appends `pair` using the FastCGI name/value encoding.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` and leaves the body unchanged when the encoded pair
    /// would push the content past `Body::MAX_LENGTH` bytes.
    pub fn add(&mut self, pair: NameValuePair) -> Result<(),()> {
        let new_len = pair.len() + self.len as usize;
        if new_len > Body::MAX_LENGTH {
            return Err(());
        }
        // Both lengths are now <= 0xFFFF, well inside the 31-bit range of the
        // four-byte form. At most two four-byte prefixes are written.
        let mut prefix = BytesMut::with_capacity(8);
        NVBody::put_length(&mut prefix, pair.name_data.len());
        NVBody::put_length(&mut prefix, pair.value_data.len());
        self.len = new_len as u16;
        self.pairs.push(prefix.freeze());
        self.pairs.push(pair.name_data);
        if !pair.value_data.is_empty() {
            self.pairs.push(pair.value_data);
        }
        Ok(())
    }
    /// Writes one name/value length.
    ///
    /// Lengths up to 127 take a single byte. Longer lengths take four
    /// big-endian bytes with the most significant bit of the *first* byte set,
    /// i.e. bit 31 of the `u32`.
    fn put_length(buf: &mut BytesMut, len: usize) {
        if len > 0x7f {
            buf.put_u32(len as u32 | 0x8000_0000);
        } else {
            buf.put_u8(len as u8);
        }
    }
    /// Wraps the raw content of a received name/value record without
    /// decoding it. Pairs are decoded by iterating the body.
    pub(crate) fn from_bytes(buf: Bytes) -> NVBody {
        let mut b = NVBody::new();
        b.len = buf.remaining() as u16;
        if b.len > 0 {
            b.pairs.push(buf);
        }
        b
    }
}
impl Iterator for NVBody {
type Item = NameValuePair;
fn next(&mut self) -> Option<NameValuePair> {
if !self.pairs.has_remaining() {
return None;
}
Some(NameValuePair::parse(&mut self.pairs.to_bytes()))
}
fn size_hint(&self) -> (usize, Option<usize>) {
(1, None)
}
}
impl NVBodyList{
    /// Creates a list holding a single empty body.
    pub fn new() -> NVBodyList {
        let first = NVBody::new();
        NVBodyList {
            bodies: vec![first],
        }
    }
    /// Adds `pair` to the last body, first starting a new body when it does
    /// not fit.
    ///
    /// # Panics
    ///
    /// Panics if the pair is too large even for an empty body.
    pub fn add(&mut self, pair: NameValuePair) {
        if !self.bodies.last().unwrap().fits(&pair) {
            self.bodies.push(NVBody::new());
        }
        let target = self.bodies.last_mut().unwrap();
        target.add(pair).expect("KVPair bigger that 0xFFFF");
    }
    /// Encodes every body as a record of type `rtype` for `request_id`,
    /// appending them to `wbuf` in order.
    pub(crate) fn append_records(self, rtype: u8, request_id: u16, wbuf: &mut BufList<Bytes>) {
        self.bodies
            .into_iter()
            .for_each(|body| body.to_record(rtype, request_id).append(wbuf));
    }
}
impl FromIterator<(Bytes, Bytes)> for NVBodyList
{
    /// Collects `(name, value)` pairs into as many bodies as needed.
    fn from_iter<T: IntoIterator<Item = (Bytes, Bytes)>>(iter: T) -> NVBodyList {
        let mut list = NVBodyList::new();
        list.extend(iter);
        list
    }
}
impl Extend<(Bytes, Bytes)> for NVBodyList
{
    /// Adds each `(name, value)` pair via `NVBodyList::add`.
    #[inline]
    fn extend<T: IntoIterator<Item = (Bytes, Bytes)>>(&mut self, iter: T) {
        iter.into_iter()
            .map(|(name, value)| NameValuePair::new(name, value))
            .for_each(|pair| self.add(pair));
    }
}
impl BeginRequestBody
{
    /// Appends the 8-byte body to `data`: role (big-endian `u16`), flags, and
    /// five reserved zero bytes.
    fn write_into(self, data: &mut BytesMut)
    {
        let [role_hi, role_lo] = self.role.to_be_bytes();
        data.put_slice(&[role_hi, role_lo, self.flags, 0, 0, 0, 0, 0]);
        debug!("br {} -> {:?}",self.role, &data);
    }
}
impl EndRequestBody
{
    /// Decodes an `END_REQUEST` body: app status (big-endian `u32`), protocol
    /// status, and three reserved bytes.
    ///
    /// # Panics
    ///
    /// Panics if `data` holds fewer than 8 bytes.
    fn parse(mut data: Bytes) -> EndRequestBody
    {
        let app_status = data.get_u32();
        let protocol_status = data.get_u8();
        // Reserved bytes; consuming them also enforces the full 8-byte length.
        data.advance(3);
        EndRequestBody {
            app_status,
            protocol_status,
        }
    }
}
#[test]
fn encode_simple_get() {
    let mut wire = BufList::new();
    // BEGIN_REQUEST for request 1, responder role, no flags.
    BeginRequestBody::new(BeginRequestBody::RESPONDER, 0, 1).append(&mut wire);
    // One PARAMS record: 45 content bytes ('-') plus 3 bytes of padding.
    let mut params = NVBody::new();
    params
        .add(NameValuePair::new(
            Bytes::from(&b"SCRIPT_FILENAME"[..]),
            Bytes::from(&b"/home/daniel/Public/test.php"[..]),
        ))
        .expect("record full");
    params.to_record(Record::PARAMS, 1).append(&mut wire);
    // Empty PARAMS record closing the parameter stream.
    NVBody::new().to_record(Record::PARAMS, 1).append(&mut wire);
    let expected: &[u8] = &b"\x01\x01\0\x01\0\x08\0\0\0\x01\0\0\0\0\0\0\x01\x04\0\x01\0-\x03\0\x0f\x1cSCRIPT_FILENAME/home/daniel/Public/test.php\x01\x04\0\x01\x04\0\x01\0\0\0\0"[..];
    assert_eq!(wire.to_bytes(), expected);
}
#[test]
fn encode_post() {
    let mut wire = BufList::new();
    // BEGIN_REQUEST for request 1, responder role, no flags.
    BeginRequestBody::new(BeginRequestBody::RESPONDER, 0, 1).append(&mut wire);
    // Parameters, followed by the empty PARAMS record.
    let mut params = NVBody::new();
    params
        .add(NameValuePair::new(
            Bytes::from(&b"SCRIPT_FILENAME"[..]),
            Bytes::from(&b"/home/daniel/Public/test.php"[..]),
        ))
        .expect("record full");
    params.to_record(Record::PARAMS, 1).append(&mut wire);
    NVBody::new().to_record(Record::PARAMS, 1).append(&mut wire);
    // A 3-byte STDIN record (5 bytes of padding), then an empty STDIN record.
    STDINBody::new(1, &mut Bytes::from(&b"a=b"[..])).append(&mut wire);
    STDINBody::new(1, &mut Bytes::new()).append(&mut wire);
    let expected: &[u8] = &b"\x01\x01\0\x01\0\x08\0\0\0\x01\0\0\0\0\0\0\x01\x04\0\x01\0-\x03\0\x0f\x1cSCRIPT_FILENAME/home/daniel/Public/test.php\x01\x04\0\x01\x04\0\x01\0\0\0\0\x01\x05\0\x01\0\x03\x05\0a=b\x01\x05\0\x01\0\x01\x05\0\x01\0\0\0\0"[..];
    assert_eq!(wire.to_bytes(), expected);
}