use std::error::Error;
use std::fmt;
use tokio::prelude::*;
/// Wire protocol version announced by a request's `protocol` header.
#[derive(Clone, Debug, PartialEq)]
pub enum RequestProtocol {
    /// Protocol version 1, sent as `protocol: v1`.
    V1,
}
/// Credentials collected from a request's `org`, `user` and `key` headers.
///
/// All fields borrow from the original request text. Note that the derived
/// `Debug` output includes `key` verbatim, so avoid logging this value as-is.
#[derive(PartialEq, Debug)]
pub struct RequestAuthHeader<'a> {
    /// Value of the `org` header.
    pub org: &'a str,
    /// Value of the `user` header.
    pub user: &'a str,
    /// Value of the `key` header.
    pub key: &'a str,
}
/// Optional refinement of the request type, from the `subtype` header.
#[derive(Clone, Debug, PartialEq)]
pub enum RequestSubtype {
    /// Sent as `subtype: init`.
    Init,
}
/// Kind of request, taken from the `type` header.
#[derive(Clone, Debug, PartialEq)]
pub enum RequestType {
    /// Sent as `type: statistics`.
    Statistics,
    /// Sent as `type: sync`.
    Sync,
}
/// A single parsed `name: value` header line.
///
/// String payloads borrow from the request text.
#[derive(Clone, Debug, PartialEq)]
pub enum RequestHeader<'a> {
    /// `client: <value>`.
    Client(&'a str),
    /// `org: <value>`.
    Org(&'a str),
    /// `user: <value>`.
    User(&'a str),
    /// `key: <value>`.
    Key(&'a str),
    /// `protocol: <version>`, restricted to the known protocol versions.
    Protocol(RequestProtocol),
    /// `type: <kind>`, restricted to the known request types.
    Type(RequestType),
    /// Value of a header whose name is not recognized; the name itself is
    /// not retained.
    Other(&'a str),
    /// `subtype: <kind>`, restricted to the known subtypes.
    Subtype(RequestSubtype),
}
/// The typed view of a request's headers.
#[derive(PartialEq, Debug)]
pub struct RequestHeaders<'a> {
    /// Value of the `protocol` header.
    pub protocol: RequestProtocol,
    /// Value of the `client` header.
    pub client: &'a str,
    /// Value of the `type` header.
    pub request_type: RequestType,
    /// Value of the `subtype` header; `None` when the header is absent.
    pub request_subtype: Option<RequestSubtype>,
    /// Credentials from the `org`, `user` and `key` headers.
    pub auth: RequestAuthHeader<'a>,
}
/// A parsed request: typed headers, every raw header, and the payload.
///
/// `P` is the payload type. `Debug` is derived so requests can be inspected
/// in logs and test failures. The derive only adds a `P: Debug` bound, so
/// requests whose payload is not `Debug` (e.g. an iterator built from a
/// closure) are unaffected.
#[derive(Debug, PartialEq)]
pub struct Request<'a, P> {
    /// The recognized headers, extracted into typed fields.
    pub headers: RequestHeaders<'a>,
    /// Every parsed header, in input order, including unrecognized ones.
    pub raw_headers: Vec<RequestHeader<'a>>,
    /// The request body.
    pub payload: P,
}
/// Everything that can go wrong while reading or parsing a request.
#[derive(Debug)]
pub enum RequestError {
    /// A header line that could not be parsed; carries the whole raw line.
    InvalidHeader(String),
    /// A required header was absent; carries the header name.
    MissingHeader(String),
    /// Reading from the connection failed.
    IOError(tokio::io::Error),
    /// Presumably raised when request bytes are not valid UTF-8 (built via
    /// `From<Utf8Error>`); no producer is visible in this file section.
    EncodingError(std::str::Utf8Error),
    /// NOTE(review): not produced in the code visible here; presumably a
    /// sync request lacking its key — confirm against callers.
    MissingSyncKey,
    /// A free-form description of a malformed request.
    InvalidRequest(String),
}
/// User-facing rendering of a [`RequestError`].
impl fmt::Display for RequestError {
    /// Emits the same text as the `Debug` representation (variant name plus
    /// payload). Formatter flags such as `{:#}` are not forwarded.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}
impl Error for RequestError {}
impl From<tokio::io::Error> for RequestError {
fn from(e: tokio::io::Error) -> Self {
RequestError::IOError(e)
}
}
impl From<std::str::Utf8Error> for RequestError {
fn from(e: std::str::Utf8Error) -> Self {
RequestError::EncodingError(e)
}
}
/// Parses a single `name: value` header line.
///
/// The line must split on `": "` into exactly two parts. Recognized names
/// (`client`, `org`, `user`, `key`, `protocol`, `type`, `subtype`) map to their
/// dedicated [`RequestHeader`] variants. `protocol`, `type` and `subtype` must
/// additionally carry one of their known values. Any other name yields
/// [`RequestHeader::Other`] holding only the value. Returned string slices
/// borrow from `raw`.
///
/// # Errors
///
/// Returns [`RequestError::InvalidHeader`] with the whole line when the line
/// does not contain exactly one `": "`, or when an enumerated header has an
/// unknown value.
fn parse_header(raw: &str) -> Result<RequestHeader<'_>, RequestError> {
    let invalid = || RequestError::InvalidHeader(raw.into());

    // Exactly one separator: a name, a value, and nothing after.
    let mut parts = raw.split(": ");
    let (name, value) = match (parts.next(), parts.next(), parts.next()) {
        (Some(name), Some(value), None) => (name, value),
        _ => return Err(invalid()),
    };

    let header = match (name, value) {
        ("client", v) => RequestHeader::Client(v),
        ("org", v) => RequestHeader::Org(v),
        ("user", v) => RequestHeader::User(v),
        ("key", v) => RequestHeader::Key(v),
        ("protocol", "v1") => RequestHeader::Protocol(RequestProtocol::V1),
        ("type", "sync") => RequestHeader::Type(RequestType::Sync),
        ("type", "statistics") => RequestHeader::Type(RequestType::Statistics),
        ("subtype", "init") => RequestHeader::Subtype(RequestSubtype::Init),
        // Known enumerated header, but with a value we do not recognize.
        ("protocol", _) | ("type", _) | ("subtype", _) => return Err(invalid()),
        (_, v) => RequestHeader::Other(v),
    };
    Ok(header)
}
/// Parses request text into typed headers plus a lazily evaluated payload.
///
/// Lines up to (and consuming) the first empty line are header lines, each
/// parsed with `parse_header`. If there is no empty line, every line is a
/// header. The returned `payload` yields the remaining non-empty lines.
///
/// If a header repeats, the last occurrence wins in `headers`, while
/// `raw_headers` keeps every parsed header in input order. The `subtype`
/// header is optional.
///
/// # Errors
///
/// - Any error from `parse_header` for a malformed header line.
/// - [`RequestError::MissingHeader`] naming the first absent required header,
///   checked in the order `protocol`, `type`, `client`, `org`, `user`, `key`.
pub fn parse_request(req: &str) -> Result<Request<'_, impl Iterator<Item = &str>>, RequestError> {
    let mut lines = req.lines();

    let mut protocol = None;
    let mut request_type = None;
    let mut request_subtype = None;
    let mut client = None;
    let mut org = None;
    let mut user = None;
    let mut key = None;
    let mut raw_headers = Vec::new();

    // `take_while` also consumes the blank separator line, leaving `lines`
    // positioned at the first payload line.
    for line in lines.by_ref().take_while(|line| !line.is_empty()) {
        let header = parse_header(line)?;
        match &header {
            RequestHeader::Client(c) => client = Some(*c),
            RequestHeader::Org(o) => org = Some(*o),
            RequestHeader::User(u) => user = Some(*u),
            RequestHeader::Key(k) => key = Some(*k),
            RequestHeader::Protocol(p) => protocol = Some(p.clone()),
            RequestHeader::Type(t) => request_type = Some(t.clone()),
            RequestHeader::Subtype(s) => request_subtype = Some(s.clone()),
            RequestHeader::Other(_) => {}
        }
        raw_headers.push(header);
    }

    // The order of these checks fixes which missing header gets reported.
    let missing = |name: &str| RequestError::MissingHeader(name.into());
    let protocol = protocol.ok_or_else(|| missing("protocol"))?;
    let request_type = request_type.ok_or_else(|| missing("type"))?;
    let client = client.ok_or_else(|| missing("client"))?;
    let org = org.ok_or_else(|| missing("org"))?;
    let user = user.ok_or_else(|| missing("user"))?;
    let key = key.ok_or_else(|| missing("key"))?;

    let headers = RequestHeaders {
        protocol,
        client,
        request_type,
        request_subtype,
        auth: RequestAuthHeader { org, user, key },
    };
    Ok(Request {
        headers,
        raw_headers,
        payload: lines.filter(|line| !line.is_empty()),
    })
}
pub async fn get_request_data<R>(buf: &mut Vec<u8>, mut con: R) -> Result<(), RequestError>
where
R: AsyncRead + Unpin,
{
let len = con.read_u32().await?;
let len = (len as usize) - std::mem::size_of::<u32>();
buf.clear();
let capacity = buf.capacity();
if capacity < len {
buf.reserve_exact(len - buf.capacity());
}
unsafe {
buf.set_len(len);
}
con.read_exact(buf).await?;
Ok(())
}