1use std::error::Error;
2use std::fmt;
3use tokio::prelude::*;
4
/// Wire-protocol version announced by the `protocol` header.
#[derive(Clone, Debug, PartialEq)]
pub enum RequestProtocol {
    /// Version 1, sent as `protocol: v1`.
    V1,
}
11
/// Credentials assembled from the `org`, `user` and `key` request headers.
///
/// Every field is a borrowed string slice with lifetime `'a`.
///
/// NOTE(review): the derived `Debug` prints `key` verbatim. If the key is a
/// secret, consider a hand-written `Debug` that redacts it before it can end
/// up in logs.
#[derive(Debug, PartialEq)]
pub struct RequestAuthHeader<'a> {
    /// Value of the `org` header.
    pub org: &'a str,
    /// Value of the `user` header.
    pub user: &'a str,
    /// Value of the `key` header.
    pub key: &'a str,
}
22
/// Optional refinement of the request kind, taken from the `subtype` header.
#[derive(Clone, Debug, PartialEq)]
pub enum RequestSubtype {
    /// Sent as `subtype: init`.
    Init,
}
28
/// Kind of request, taken from the `type` header.
#[derive(Clone, Debug, PartialEq)]
pub enum RequestType {
    /// Sent as `type: statistics`.
    Statistics,
    /// Sent as `type: sync`.
    Sync,
}
35
36#[derive(Debug, PartialEq, Clone)]
39pub enum RequestHeader<'a> {
40 Client(&'a str),
41 Org(&'a str),
42 User(&'a str),
43 Key(&'a str),
44 Protocol(RequestProtocol),
45 Type(RequestType),
46 Other(&'a str),
47 Subtype(RequestSubtype),
48}
49
50#[derive(Debug, PartialEq)]
52pub struct RequestHeaders<'a> {
53 pub protocol: RequestProtocol,
54 pub client: &'a str,
55 pub request_type: RequestType,
56 pub request_subtype: Option<RequestSubtype>,
57 pub auth: RequestAuthHeader<'a>,
58}
59
60#[derive(PartialEq)]
62pub struct Request<'a, P> {
63 pub headers: RequestHeaders<'a>,
64 pub raw_headers: Vec<RequestHeader<'a>>,
65 pub payload: P,
67}
68
/// Errors produced while reading or parsing a request.
#[derive(Debug)]
pub enum RequestError {
    /// A header line was malformed or carried an unsupported value; holds the
    /// complete offending line.
    InvalidHeader(String),
    /// A required header was absent; holds the header's name.
    MissingHeader(String),
    /// An I/O operation failed (e.g. reading a request from the connection).
    /// `tokio::io::Error` is a re-export of this same type.
    IOError(std::io::Error),
    /// Bytes could not be decoded as UTF-8.
    EncodingError(std::str::Utf8Error),
    /// Not constructed in this section of the module; presumably signals a
    /// sync request without its key — confirm against the callers.
    MissingSyncKey,
    /// The request as a whole is invalid; holds a human-readable reason.
    InvalidRequest(String),
}

impl fmt::Display for RequestError {
    /// Human-readable description. Use `{:?}` for the structural form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            RequestError::InvalidHeader(line) => write!(f, "invalid header line {:?}", line),
            RequestError::MissingHeader(name) => write!(f, "missing required header `{}`", name),
            RequestError::IOError(e) => write!(f, "I/O error: {}", e),
            RequestError::EncodingError(e) => write!(f, "request is not valid UTF-8: {}", e),
            RequestError::MissingSyncKey => write!(f, "missing sync key"),
            RequestError::InvalidRequest(reason) => write!(f, "invalid request: {}", reason),
        }
    }
}

impl Error for RequestError {
    /// Exposes the wrapped I/O or UTF-8 error so error-chain reporters can
    /// show the underlying cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            RequestError::IOError(e) => Some(e),
            RequestError::EncodingError(e) => Some(e),
            RequestError::InvalidHeader(_)
            | RequestError::MissingHeader(_)
            | RequestError::MissingSyncKey
            | RequestError::InvalidRequest(_) => None,
        }
    }
}

impl From<std::io::Error> for RequestError {
    fn from(e: std::io::Error) -> Self {
        RequestError::IOError(e)
    }
}

impl From<std::str::Utf8Error> for RequestError {
    fn from(e: std::str::Utf8Error) -> Self {
        RequestError::EncodingError(e)
    }
}
102
103fn parse_header(raw: &str) -> Result<RequestHeader<'_>, RequestError> {
105 let mut s = raw.split(": ");
106 fn make_err(r: &str) -> RequestError {
107 RequestError::InvalidHeader(r.into())
108 }
109 let name = s.next().ok_or_else(|| make_err(raw))?;
110 let value = s.next().ok_or_else(|| make_err(raw))?;
111 if s.next().is_some() {
112 return Err(make_err(raw));
113 }
114
115 let v = match name {
116 "client" => RequestHeader::Client(value),
117 "org" => RequestHeader::Org(value),
118 "user" => RequestHeader::User(value),
119 "key" => RequestHeader::Key(value),
120 "protocol" => match value {
121 "v1" => RequestHeader::Protocol(RequestProtocol::V1),
122 _ => return Err(make_err(raw)),
123 },
124 "type" => match value {
125 "sync" => RequestHeader::Type(RequestType::Sync),
126 "statistics" => RequestHeader::Type(RequestType::Statistics),
127 _ => return Err(make_err(raw)),
128 },
129 "subtype" => match value {
130 "init" => RequestHeader::Subtype(RequestSubtype::Init),
131 _ => return Err(make_err(raw)),
132 },
133 _ => RequestHeader::Other(value),
134 };
135 Ok(v)
136}
137
138pub fn parse_request(req: &str) -> Result<Request<'_, impl Iterator<Item = &str>>, RequestError> {
142 let mut lines = req.lines();
143
144 let mut protocol = None;
145 let mut request_type = None;
146 let mut request_subtype = None;
147 let mut client = None;
148 let mut auth_org = None;
149 let mut auth_user = None;
150 let mut auth_key = None;
151 let mut raw_headers = Vec::new();
152 for line in &mut lines {
153 if line.is_empty() {
155 break;
156 }
157
158 let header = parse_header(line)?;
159 match &header {
160 RequestHeader::Protocol(p) => protocol = Some(p.clone()),
161 RequestHeader::Type(t) => request_type = Some(t.clone()),
162 RequestHeader::Subtype(t) => request_subtype = Some(t.clone()),
163 RequestHeader::Client(c) => client = Some(*c),
164 RequestHeader::Org(o) => auth_org = Some(*o),
165 RequestHeader::User(u) => auth_user = Some(*u),
166 RequestHeader::Key(k) => auth_key = Some(*k),
167 _ => {}
168 }
169 raw_headers.push(header);
170 }
171
172 let parsed_header = match (
173 protocol,
174 request_type,
175 client,
176 auth_org,
177 auth_user,
178 auth_key,
179 ) {
180 (None, _, _, _, _, _) => Err(RequestError::MissingHeader("protocol".into())),
181 (_, None, _, _, _, _) => Err(RequestError::MissingHeader("type".into())),
182 (_, _, None, _, _, _) => Err(RequestError::MissingHeader("client".into())),
183 (_, _, _, None, _, _) => Err(RequestError::MissingHeader("org".into())),
184 (_, _, _, _, None, _) => Err(RequestError::MissingHeader("user".into())),
185 (_, _, _, _, _, None) => Err(RequestError::MissingHeader("key".into())),
186 (Some(protocol), Some(request_type), Some(client), Some(org), Some(user), Some(key)) => {
187 Ok(Request {
188 headers: RequestHeaders {
189 protocol,
190 client,
191 request_type,
192 request_subtype,
193 auth: RequestAuthHeader { org, user, key },
194 },
195 raw_headers,
196 payload: lines.filter(|a| !a.is_empty()),
197 })
198 }
199 }?;
200
201 Ok(parsed_header)
202}
203
204pub async fn get_request_data<R>(buf: &mut Vec<u8>, mut con: R) -> Result<(), RequestError>
206where
207 R: AsyncRead + Unpin,
208{
209 let len = con.read_u32().await?;
210 let len = (len as usize) - std::mem::size_of::<u32>();
211
212 buf.clear();
213 let capacity = buf.capacity();
214 if capacity < len {
215 buf.reserve_exact(len - buf.capacity());
216 }
217
218 unsafe {
220 buf.set_len(len);
221 }
222 con.read_exact(buf).await?;
223
224 Ok(())
225}