#[cfg(test)]
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::str::{from_utf8, Utf8Error};
use futures_lite::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt};
use headers::{encode_headers, ScgiHeaderParseError};
use thiserror::Error;
/// Errors that can occur while reading an SCGI request from a stream.
#[derive(Error, Debug)]
pub enum ScgiReadError {
/// The netstring length prefix is valid UTF-8 but does not parse as a `usize`.
#[error("Length can't be decoded to an integer")]
BadLength,
/// A UTF-8 decoding failure; `read_request` produces it when the netstring
/// length prefix is not valid UTF-8.
#[error("The length or the headers are not in UTF-8")]
Utf8(#[from] Utf8Error),
/// The netstring framing is wrong: the `:` after the length or the `,` after
/// the header block is not where it is expected.
#[error("Netstring sanity checks fail")]
BadNetstring,
/// The header block was read in full but could not be parsed.
#[error("Error parsing SCGI headers")]
BadHeaders(#[from] ScgiHeaderParseError),
/// The underlying stream returned an I/O error.
#[error("IO Error")]
IO(#[from] std::io::Error),
}
/// Map of SCGI header names to header values.
///
/// Regular builds use a `HashMap`. Test builds swap in a `BTreeMap`, presumably so
/// that headers iterate (and therefore encode) in sorted key order, which a
/// byte-exact comparison against a fixture needs.
#[cfg(not(test))]
pub type ScgiHeaders = HashMap<String, String>;
/// Test-build variant of the header map, backed by a `BTreeMap`.
#[cfg(test)]
pub type ScgiHeaders = BTreeMap<String, String>;
/// An SCGI request: a set of headers plus a raw body.
#[derive(Debug, Default, PartialEq, Eq)]
pub struct ScgiRequest {
/// Request headers. The `SCGI` and `CONTENT_LENGTH` entries are not kept here:
/// `encode` generates them and `read_request` leaves them out of the map.
pub headers: ScgiHeaders,
/// Request body bytes; their count is what `encode` reports as `CONTENT_LENGTH`.
pub body: Vec<u8>,
}
impl ScgiRequest {
    /// Creates an empty request: no headers and an empty body.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a request that carries `headers` and an empty body.
    pub fn from_headers(headers: ScgiHeaders) -> Self {
        Self {
            headers,
            body: Vec::new(),
        }
    }

    /// Serializes the request into its wire form.
    ///
    /// The layout is `<len>:<headers>,<body>`, where `<headers>` is the block
    /// produced by `encode_headers` for this request's headers and body length,
    /// and `<len>` is the decimal byte count of that block.
    pub fn encode(&self) -> Vec<u8> {
        let headers = encode_headers(&self.headers, self.body.len());
        let length_prefix = headers.len().to_string();

        // Reserve the exact final size (prefix + ':' + headers + ',' + body) so the
        // buffer is allocated once, however large the body is.
        let total = length_prefix.len() + 1 + headers.len() + 1 + self.body.len();
        let mut buf = Vec::with_capacity(total);

        buf.extend_from_slice(length_prefix.as_bytes());
        buf.push(b':');
        buf.extend_from_slice(&headers);
        buf.push(b',');
        buf.extend_from_slice(&self.body);
        buf
    }
}
/// Reads one SCGI request from `stream`.
///
/// The expected input is a netstring (`<len>:<headers>,`) holding the
/// NUL-separated header block, followed by exactly `CONTENT_LENGTH` body bytes.
///
/// # Errors
///
/// * [`ScgiReadError::BadNetstring`] if the stream ends before a `:` is seen
///   (including a stream that is already at end of input), or if the byte after
///   the header block is not `,`.
/// * [`ScgiReadError::Utf8`] / [`ScgiReadError::BadLength`] if the length prefix
///   is not valid UTF-8 or does not parse as a `usize`.
/// * [`ScgiReadError::BadHeaders`] if the header block cannot be parsed.
/// * [`ScgiReadError::IO`] if a read on the stream fails.
pub async fn read_request<S: AsyncBufRead + Unpin>(
    stream: &mut S,
) -> Result<ScgiRequest, ScgiReadError> {
    let mut len_part = Vec::with_capacity(10);
    stream.read_until(b':', &mut len_part).await?;

    // `read_until` also stops at end of input, so the buffer may be empty or may
    // lack the delimiter. `split_last` covers both cases without indexing, which
    // would panic on an empty buffer.
    let digits = match len_part.split_last() {
        Some((&b':', digits)) => digits,
        _ => return Err(ScgiReadError::BadNetstring),
    };
    let length = from_utf8(digits)?
        .parse::<usize>()
        .map_err(|_| ScgiReadError::BadLength)?;

    // NOTE: `length` (and `content_length` below) come from the peer and are used
    // as-is to size the buffers.
    let mut raw_headers = vec![0; length];
    stream.read_exact(&mut raw_headers).await?;

    // A netstring is terminated by a single ','.
    let mut terminator = [0u8; 1];
    stream.read_exact(&mut terminator).await?;
    if terminator[0] != b',' {
        return Err(ScgiReadError::BadNetstring);
    }

    let (headers, content_length) = headers::header_string_map(&raw_headers)?;
    let mut body = vec![0; content_length];
    stream.read_exact(&mut body).await?;
    Ok(ScgiRequest { headers, body })
}
pub mod headers {
use super::*;
use memchr::memchr;
/// Errors that can occur while parsing an SCGI header block.
#[derive(Error, Debug)]
pub enum ScgiHeaderParseError {
/// A header name or value is not valid UTF-8.
#[error("The length or the headers are not in UTF-8")]
Utf8(#[from] Utf8Error),
/// A header name or value is not terminated by a NUL byte.
#[error("Error parsing the null-terminated headers")]
BadHeaderVals,
/// The `CONTENT_LENGTH` value does not parse as an integer.
#[error("CONTENT_LENGTH can't be decoded to an integer")]
BadLength,
/// No `CONTENT_LENGTH` header was found in the block.
#[error("CONTENT_LENGTH header was missing")]
NoLength,
}
/// Serializes an SCGI header block: a sequence of NUL-terminated names and values.
///
/// The output starts with the pair `SCGI` = `1`, then `CONTENT_LENGTH` set to
/// `content_length`, then every entry of `headers` in the map's iteration order.
///
/// Entries of `headers` named `SCGI` or `CONTENT_LENGTH` are skipped. Both pairs
/// are always written by this function, and a second copy would be harmful: the
/// parsers in this module keep the last `CONTENT_LENGTH` they see, so a stale
/// value in the map would override the real body length.
///
/// NOTE(review): the SCGI specification appears to require `CONTENT_LENGTH` to be
/// the first header, while this function writes `SCGI` first. The order is left
/// unchanged because existing encoded output may be relied upon; confirm.
pub fn encode_headers(headers: &ScgiHeaders, content_length: usize) -> Vec<u8> {
    /// Appends `name NUL value NUL` to `buf`.
    fn push_pair(buf: &mut Vec<u8>, name: &str, value: &str) {
        buf.extend_from_slice(name.as_bytes());
        buf.push(0);
        buf.extend_from_slice(value.as_bytes());
        buf.push(0);
    }

    let mut buf = Vec::new();
    push_pair(&mut buf, "SCGI", "1");
    push_pair(&mut buf, "CONTENT_LENGTH", &content_length.to_string());

    for (name, value) in headers.iter() {
        let reserved = name.as_str() == "SCGI" || name.as_str() == "CONTENT_LENGTH";
        if !reserved {
            push_pair(&mut buf, name, value);
        }
    }
    buf
}
/// Walks a raw SCGI header block and hands each name/value pair to `headers_fn`.
///
/// `raw_headers` is read as consecutive NUL-terminated fields, taken two at a
/// time as (name, value). Both strings borrow from `raw_headers`. An empty block
/// yields no callbacks and succeeds.
///
/// # Errors
///
/// * [`ScgiHeaderParseError::BadHeaderVals`] if a name or value has no
///   terminating NUL byte.
/// * [`ScgiHeaderParseError::Utf8`] if a name or value is not valid UTF-8.
/// * Any error returned by `headers_fn`, which stops the walk immediately.
pub fn parse_headers<'h>(
    raw_headers: &'h [u8],
    mut headers_fn: impl FnMut(&'h str, &'h str) -> Result<(), ScgiHeaderParseError>,
) -> Result<(), ScgiHeaderParseError> {
    /// Splits `bytes` at its first NUL into (field, everything after the NUL).
    fn take_field(bytes: &[u8]) -> Result<(&[u8], &[u8]), ScgiHeaderParseError> {
        let end = memchr(0, bytes).ok_or(ScgiHeaderParseError::BadHeaderVals)?;
        Ok((&bytes[..end], &bytes[end + 1..]))
    }

    let mut rest = raw_headers;
    while !rest.is_empty() {
        // The name is decoded before the value is even located, so an invalid
        // name is reported ahead of a missing value terminator.
        let (name_bytes, after_name) = take_field(rest)?;
        let name = from_utf8(name_bytes)?;

        let (value_bytes, after_value) = take_field(after_name)?;
        let value = from_utf8(value_bytes)?;

        headers_fn(name, value)?;
        rest = after_value;
    }
    Ok(())
}
/// Parses a raw header block into an owned header map plus the body length.
///
/// The `SCGI` header is dropped and `CONTENT_LENGTH` is parsed into the returned
/// `usize` rather than stored; every other pair is copied into the map. If a
/// name occurs more than once, the last occurrence wins.
///
/// # Errors
///
/// Returns whatever `parse_headers` reports for a malformed block,
/// [`ScgiHeaderParseError::BadLength`] if `CONTENT_LENGTH` is not an integer, and
/// [`ScgiHeaderParseError::NoLength`] if the block has no `CONTENT_LENGTH`.
pub fn header_string_map(
    raw_headers: &[u8],
) -> Result<(ScgiHeaders, usize), ScgiHeaderParseError> {
    let mut owned = ScgiHeaders::new();
    let mut declared_len: Option<usize> = None;

    parse_headers(raw_headers, |name, value| {
        match name {
            // Protocol marker; carries no information for the caller.
            "SCGI" => {}
            "CONTENT_LENGTH" => {
                let parsed = value
                    .parse::<usize>()
                    .map_err(|_| ScgiHeaderParseError::BadLength)?;
                declared_len = Some(parsed);
            }
            _ => {
                owned.insert(name.to_owned(), value.to_owned());
            }
        }
        Ok(())
    })?;

    let content_length = declared_len.ok_or(ScgiHeaderParseError::NoLength)?;
    Ok((owned, content_length))
}
/// Parses a raw header block into a borrowed header map plus the body length.
///
/// Works like the owned variant, but names and values are `&str` slices that
/// borrow from `raw_headers`, so nothing is copied. The `SCGI` header is dropped
/// and `CONTENT_LENGTH` is returned as the `usize` instead of being stored. If a
/// name occurs more than once, the last occurrence wins.
///
/// # Errors
///
/// Returns whatever `parse_headers` reports for a malformed block,
/// [`ScgiHeaderParseError::BadLength`] if `CONTENT_LENGTH` is not an integer, and
/// [`ScgiHeaderParseError::NoLength`] if the block has no `CONTENT_LENGTH`.
pub fn header_str_map(
    raw_headers: &[u8],
) -> Result<(HashMap<&str, &str>, usize), ScgiHeaderParseError> {
    let mut borrowed = HashMap::new();
    let mut declared_len: Option<usize> = None;

    parse_headers(raw_headers, |name, value| match name {
        // Protocol marker; not exposed to the caller.
        "SCGI" => Ok(()),
        "CONTENT_LENGTH" => {
            let parsed = value
                .parse::<usize>()
                .map_err(|_| ScgiHeaderParseError::BadLength)?;
            declared_len = Some(parsed);
            Ok(())
        }
        _ => {
            borrowed.insert(name, value);
            Ok(())
        }
    })?;

    match declared_len {
        Some(content_length) => Ok((borrowed, content_length)),
        None => Err(ScgiHeaderParseError::NoLength),
    }
}
}
#[cfg(test)]
mod tests {
    use futures_lite::io::BufReader;
    use crate::read_request;
    use super::*;

    /// Encoded request fixture that both tests compare against.
    const TEST_DATA: &[u8] = include_bytes!("../test_data/dump");

    /// Builds the in-memory request that `TEST_DATA` is expected to match.
    fn sample_request() -> ScgiRequest {
        let headers: ScgiHeaders = [("hello", "world"), ("foo", "bar")]
            .iter()
            .map(|&(name, value)| (name.to_owned(), value.to_owned()))
            .collect();
        ScgiRequest {
            headers,
            body: b"here is some data".to_vec(),
        }
    }

    /// Encoding the sample request must reproduce the fixture byte for byte.
    #[test]
    fn encode_scgi_request() {
        let encoded = sample_request().encode();
        assert_eq!(TEST_DATA, encoded.as_slice());
    }

    /// Reading the fixture back must yield the sample request.
    #[test]
    fn read_scgi_request() {
        let decoded = futures_lite::future::block_on(async {
            let mut stream = BufReader::new(TEST_DATA);
            read_request(&mut stream).await
        })
        .unwrap();
        assert_eq!(sample_request(), decoded);
    }
}