taskserver_protocol/
request.rs

1use std::error::Error;
2use std::fmt;
3use tokio::prelude::*;
4
/// Version of the taskserver protocol announced in the `protocol` header.
#[derive(Clone, Debug, PartialEq)]
pub enum RequestProtocol {
    /// Protocol version 1 (`protocol: v1`); currently the only supported version.
    V1,
}
11
/// Authentication-related headers of a request (`org`, `user` and `key`).
#[derive(PartialEq, Debug)]
pub struct RequestAuthHeader<'a> {
    /// Organisation the user belongs to (`org` header).
    pub org: &'a str,
    /// User name (`user` header).
    pub user: &'a str,
    /// Secret key, the password equivalent for task (`key` header).
    pub key: &'a str,
}
22
/// Request subtype taken from the `subtype` header.
///
/// Currently only meaningful for sync requests.
#[derive(Clone, Debug, PartialEq)]
pub enum RequestSubtype {
    /// `subtype: init`
    Init,
}
28
/// Kind of request, taken from the `type` header.
///
/// Lists every request kind a client may send.
#[derive(Clone, Debug, PartialEq)]
pub enum RequestType {
    /// `type: statistics`
    Statistics,
    /// `type: sync`
    Sync,
}
35
36/// Type safe header struct.
37/// The relevant headers will be present as fields on the [RequestHeaders](struct.RequestHeaders.html) struct
38#[derive(Debug, PartialEq, Clone)]
39pub enum RequestHeader<'a> {
40    Client(&'a str),
41    Org(&'a str),
42    User(&'a str),
43    Key(&'a str),
44    Protocol(RequestProtocol),
45    Type(RequestType),
46    Other(&'a str),
47    Subtype(RequestSubtype),
48}
49
50/// Group containing the most necessary headers. All fields here are required
51#[derive(Debug, PartialEq)]
52pub struct RequestHeaders<'a> {
53    pub protocol: RequestProtocol,
54    pub client: &'a str,
55    pub request_type: RequestType,
56    pub request_subtype: Option<RequestSubtype>,
57    pub auth: RequestAuthHeader<'a>,
58}
59
60/// Parsed Request.
61#[derive(PartialEq)]
62pub struct Request<'a, P> {
63    pub headers: RequestHeaders<'a>,
64    pub raw_headers: Vec<RequestHeader<'a>>,
65    /// this field usually contains an iterator yielding the lines of the payload, but Request needs the type Parameter because of Limitations of impl Trait
66    pub payload: P,
67}
68
/// Errors produced while reading or parsing a request.
#[derive(Debug)]
pub enum RequestError {
    /// A header line is malformed; carries the offending line.
    InvalidHeader(String),
    /// A required header is absent; carries its protocol name (e.g. `"key"`).
    MissingHeader(String),
    /// Reading the request failed.
    ///
    /// `std::io::Error` is the same type that `tokio::io::Error` re-exports.
    IOError(std::io::Error),
    /// The request data is not valid UTF-8.
    EncodingError(std::str::Utf8Error),
    /// NOTE(review): not constructed in this module; presumably signals a sync request
    /// without a sync key. Confirm with callers.
    MissingSyncKey,
    /// The request is invalid in a way none of the other variants describes.
    InvalidRequest(String),
}
84
85impl fmt::Display for RequestError {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        write!(f, "{:?}", &self)
88    }
89}
90impl Error for RequestError {}
91
92impl From<tokio::io::Error> for RequestError {
93    fn from(e: tokio::io::Error) -> Self {
94        RequestError::IOError(e)
95    }
96}
97impl From<std::str::Utf8Error> for RequestError {
98    fn from(e: std::str::Utf8Error) -> Self {
99        RequestError::EncodingError(e)
100    }
101}
102
103/// parses the header into the typesafe struct
104fn parse_header(raw: &str) -> Result<RequestHeader<'_>, RequestError> {
105    let mut s = raw.split(": ");
106    fn make_err(r: &str) -> RequestError {
107        RequestError::InvalidHeader(r.into())
108    }
109    let name = s.next().ok_or_else(|| make_err(raw))?;
110    let value = s.next().ok_or_else(|| make_err(raw))?;
111    if s.next().is_some() {
112        return Err(make_err(raw));
113    }
114
115    let v = match name {
116        "client" => RequestHeader::Client(value),
117        "org" => RequestHeader::Org(value),
118        "user" => RequestHeader::User(value),
119        "key" => RequestHeader::Key(value),
120        "protocol" => match value {
121            "v1" => RequestHeader::Protocol(RequestProtocol::V1),
122            _ => return Err(make_err(raw)),
123        },
124        "type" => match value {
125            "sync" => RequestHeader::Type(RequestType::Sync),
126            "statistics" => RequestHeader::Type(RequestType::Statistics),
127            _ => return Err(make_err(raw)),
128        },
129        "subtype" => match value {
130            "init" => RequestHeader::Subtype(RequestSubtype::Init),
131            _ => return Err(make_err(raw)),
132        },
133        _ => RequestHeader::Other(value),
134    };
135    Ok(v)
136}
137
138/// splits the request string into headers (everything until the first empty line)
139/// and payload (everything after)
140/// Then assembles the Request struct while validating the existence of all required headers
141pub fn parse_request(req: &str) -> Result<Request<'_, impl Iterator<Item = &str>>, RequestError> {
142    let mut lines = req.lines();
143
144    let mut protocol = None;
145    let mut request_type = None;
146    let mut request_subtype = None;
147    let mut client = None;
148    let mut auth_org = None;
149    let mut auth_user = None;
150    let mut auth_key = None;
151    let mut raw_headers = Vec::new();
152    for line in &mut lines {
153        // the empty line marks the end of the headers and the beginning of the payload
154        if line.is_empty() {
155            break;
156        }
157
158        let header = parse_header(line)?;
159        match &header {
160            RequestHeader::Protocol(p) => protocol = Some(p.clone()),
161            RequestHeader::Type(t) => request_type = Some(t.clone()),
162            RequestHeader::Subtype(t) => request_subtype = Some(t.clone()),
163            RequestHeader::Client(c) => client = Some(*c),
164            RequestHeader::Org(o) => auth_org = Some(*o),
165            RequestHeader::User(u) => auth_user = Some(*u),
166            RequestHeader::Key(k) => auth_key = Some(*k),
167            _ => {}
168        }
169        raw_headers.push(header);
170    }
171
172    let parsed_header = match (
173        protocol,
174        request_type,
175        client,
176        auth_org,
177        auth_user,
178        auth_key,
179    ) {
180        (None, _, _, _, _, _) => Err(RequestError::MissingHeader("protocol".into())),
181        (_, None, _, _, _, _) => Err(RequestError::MissingHeader("type".into())),
182        (_, _, None, _, _, _) => Err(RequestError::MissingHeader("client".into())),
183        (_, _, _, None, _, _) => Err(RequestError::MissingHeader("org".into())),
184        (_, _, _, _, None, _) => Err(RequestError::MissingHeader("user".into())),
185        (_, _, _, _, _, None) => Err(RequestError::MissingHeader("key".into())),
186        (Some(protocol), Some(request_type), Some(client), Some(org), Some(user), Some(key)) => {
187            Ok(Request {
188                headers: RequestHeaders {
189                    protocol,
190                    client,
191                    request_type,
192                    request_subtype,
193                    auth: RequestAuthHeader { org, user, key },
194                },
195                raw_headers,
196                payload: lines.filter(|a| !a.is_empty()),
197            })
198        }
199    }?;
200
201    Ok(parsed_header)
202}
203
204/// reads the entire request into the provided buffer
205pub async fn get_request_data<R>(buf: &mut Vec<u8>, mut con: R) -> Result<(), RequestError>
206where
207    R: AsyncRead + Unpin,
208{
209    let len = con.read_u32().await?;
210    let len = (len as usize) - std::mem::size_of::<u32>();
211
212    buf.clear();
213    let capacity = buf.capacity();
214    if capacity < len {
215        buf.reserve_exact(len - buf.capacity());
216    }
217
218    // this is only valid because u8 does not have a drop implementation
219    unsafe {
220        buf.set_len(len);
221    }
222    con.read_exact(buf).await?;
223
224    Ok(())
225}