1#[cfg(test)]
71use std::collections::BTreeMap;
72use std::collections::HashMap;
73use std::str::{from_utf8, Utf8Error};
74
75use futures_lite::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt};
76use headers::{encode_headers, ScgiHeaderParseError};
77use thiserror::Error;
78
/// Errors returned by [`read_request`] while reading an SCGI request.
#[derive(Error, Debug)]
pub enum ScgiReadError {
    /// The netstring length prefix is not a decimal integer that fits in `usize`.
    #[error("Length can't be decoded to an integer")]
    BadLength,
    /// The netstring length prefix is not valid UTF-8.
    ///
    /// UTF-8 errors inside the header block itself are reported through
    /// [`ScgiReadError::BadHeaders`] instead.
    #[error("The length or the headers are not in UTF-8")]
    Utf8(#[from] Utf8Error),
    /// The netstring framing is malformed: the length prefix is not terminated
    /// by `:`, or the header block is not followed by `,`.
    #[error("Netstring sanity checks fail")]
    BadNetstring,
    /// The header block inside the netstring could not be parsed.
    #[error("Error parsing SCGI headers")]
    BadHeaders(#[from] ScgiHeaderParseError),
    /// The underlying stream failed, including hitting EOF mid-request.
    #[error("IO Error")]
    IO(#[from] std::io::Error),
}
98
/// Map of SCGI header names to their values.
///
/// Maps produced by [`read_request`] never contain the `SCGI` or
/// `CONTENT_LENGTH` entries; those are consumed by the reader.
#[cfg(not(test))]
pub type ScgiHeaders = HashMap<String, String>;
/// Test-build variant of the header map.
///
/// A `BTreeMap` iterates in sorted key order, which makes the byte output of
/// `ScgiRequest::encode` deterministic, so tests can compare it against a
/// fixture file.
#[cfg(test)]
pub type ScgiHeaders = BTreeMap<String, String>;
104
/// An SCGI request: its headers plus the raw body.
#[derive(Debug, Default, PartialEq, Eq)]
pub struct ScgiRequest {
    /// Request headers.
    ///
    /// `read_request` strips the `SCGI` and `CONTENT_LENGTH` entries from
    /// this map. `encode` always writes its own values for those two headers.
    pub headers: ScgiHeaders,
    /// Raw request body. Its length is written as `CONTENT_LENGTH` when the
    /// request is encoded.
    pub body: Vec<u8>,
}
116
117impl ScgiRequest {
118 pub fn new() -> Self {
120 Self::default()
121 }
122
123 pub fn from_headers(headers: ScgiHeaders) -> Self {
125 Self {
126 headers,
127 body: Vec::new(),
128 }
129 }
130
131 pub fn encode(&self) -> Vec<u8> {
133 let headers = encode_headers(&self.headers, self.body.len());
134 let mut buf = Vec::with_capacity(headers.len() + 6);
135 buf.extend(headers.len().to_string().as_bytes());
136 buf.push(b':');
137 buf.extend(headers);
138 buf.push(b',');
139 buf.extend(&self.body);
140 buf
141 }
142}
143
144pub async fn read_request<S: AsyncBufRead + Unpin>(
146 stream: &mut S,
147) -> Result<ScgiRequest, ScgiReadError> {
148 let mut len_part = Vec::with_capacity(10);
149 let read = stream.read_until(b':', &mut len_part).await?;
150 if len_part[read - 1] != b':' {
151 return Err(ScgiReadError::BadNetstring);
152 }
153 let length = from_utf8(&len_part[..read - 1])?
154 .parse::<usize>()
155 .map_err(|_| ScgiReadError::BadLength)?;
156 let mut headers = vec![0; length];
157 stream.read_exact(&mut headers).await?;
158 let mut end_delim = 0;
159 stream
160 .read_exact(std::slice::from_mut(&mut end_delim))
161 .await?;
162 if end_delim != b',' {
163 return Err(ScgiReadError::BadNetstring);
164 }
165 let (headers, content_length) = headers::header_string_map(&headers)?;
166 let mut body = vec![0; content_length];
167 stream.read_exact(&mut body).await?;
168 Ok(ScgiRequest { headers, body })
169}
170
171pub mod headers {
173 use super::*;
174 use memchr::memchr;
175
176 #[derive(Error, Debug)]
178 pub enum ScgiHeaderParseError {
179 #[error("The length or the headers are not in UTF-8")]
181 Utf8(#[from] Utf8Error),
182 #[error("Error parsing the null-terminated headers")]
184 BadHeaderVals,
185 #[error("CONTENT_LENGTH can't be decoded to an integer")]
187 BadLength,
188 #[error("CONTENT_LENGTH header was missing")]
190 NoLength,
191 }
192
193 pub fn encode_headers(headers: &ScgiHeaders, content_length: usize) -> Vec<u8> {
197 let mut buf = Vec::new();
198
199 buf.extend(b"SCGI");
201 buf.push(0);
202 buf.extend(b"1");
203 buf.push(0);
204
205 buf.extend(b"CONTENT_LENGTH");
207 buf.push(0);
208 buf.extend(content_length.to_string().as_bytes());
209 buf.push(0);
210
211 for (name, value) in headers.iter() {
212 buf.extend(name.as_bytes());
213 buf.push(0);
214 buf.extend(value.as_bytes());
215 buf.push(0);
216 }
217 buf
218 }
219
220 pub fn parse_headers<'h>(
222 raw_headers: &'h [u8],
223 mut headers_fn: impl FnMut(&'h str, &'h str) -> Result<(), ScgiHeaderParseError>,
224 ) -> Result<(), ScgiHeaderParseError> {
225 let mut pos = 0;
226 while pos < raw_headers.len() {
227 let null = memchr(0, &raw_headers[pos..]).ok_or(ScgiHeaderParseError::BadHeaderVals)?;
228 let header_name = from_utf8(&raw_headers[pos..pos + null])?;
229 pos += null + 1;
230 let null = memchr(0, &raw_headers[pos..]).ok_or(ScgiHeaderParseError::BadHeaderVals)?;
231 let header_value = from_utf8(&raw_headers[pos..pos + null])?;
232 headers_fn(header_name, header_value)?;
233 pos += null + 1;
234 }
235 Ok(())
236 }
237
238 pub fn header_string_map(
244 raw_headers: &[u8],
245 ) -> Result<(ScgiHeaders, usize), ScgiHeaderParseError> {
246 let mut headers_map = ScgiHeaders::new();
247 let mut content_length = None;
248 parse_headers(raw_headers, |name, value| {
249 if name != "SCGI" {
251 if name == "CONTENT_LENGTH" {
252 content_length =
253 Some(value.parse().map_err(|_| ScgiHeaderParseError::BadLength)?);
254 } else {
255 headers_map.insert(name.to_owned(), value.to_owned());
256 }
257 }
258 Ok(())
259 })?;
260 Ok((
261 headers_map,
262 content_length.ok_or(ScgiHeaderParseError::NoLength)?,
263 ))
264 }
265
266 pub fn header_str_map(
272 raw_headers: &[u8],
273 ) -> Result<(HashMap<&str, &str>, usize), ScgiHeaderParseError> {
274 let mut headers_map = HashMap::new();
275 let mut content_length = None;
276 parse_headers(raw_headers, |name, value| {
277 if name != "SCGI" {
279 if name == "CONTENT_LENGTH" {
280 content_length =
281 Some(value.parse().map_err(|_| ScgiHeaderParseError::BadLength)?);
282 } else {
283 headers_map.insert(name, value);
284 }
285 }
286 Ok(())
287 })?;
288 Ok((
289 headers_map,
290 content_length.ok_or(ScgiHeaderParseError::NoLength)?,
291 ))
292 }
293}
294
#[cfg(test)]
mod tests {
    use futures_lite::future::block_on;
    use futures_lite::io::BufReader;

    use super::*;

    /// Wire-format fixture expected to hold `sample_request()` in encoded form.
    const TEST_DATA: &[u8] = include_bytes!("../test_data/dump");

    /// Builds the request that `TEST_DATA` is expected to contain.
    fn sample_request() -> ScgiRequest {
        let headers: ScgiHeaders = [("hello", "world"), ("foo", "bar")]
            .iter()
            .map(|&(name, value)| (name.to_owned(), value.to_owned()))
            .collect();
        ScgiRequest {
            headers,
            body: b"here is some data".to_vec(),
        }
    }

    #[test]
    fn encode_scgi_request() {
        assert_eq!(sample_request().encode(), TEST_DATA);
    }

    #[test]
    fn read_scgi_request() {
        let expected = sample_request();
        let mut reader = BufReader::new(TEST_DATA);
        let actual = block_on(read_request(&mut reader)).unwrap();
        assert_eq!(expected, actual);
    }
}
333}