use std::io::{Read, Write, stdin};
use std::collections::HashMap;
use std::convert::TryFrom;
pub extern crate http;
/// The request type handed to CGI handlers: an `http::Request` whose body is a byte vector.
pub type Request = http::Request<Vec<u8>>;
/// The response type returned by CGI handlers: an `http::Response` whose body is a byte vector.
pub type Response = http::Response<Vec<u8>>;
/// Runs one CGI request/response cycle for the current process.
///
/// The process environment is collected into a map, `CONTENT_LENGTH` bytes of request body are
/// read from standard input, and both are turned into a [`Request`]. That request is passed to
/// `func`, and the [`Response`] it returns is serialised and written to standard output.
///
/// A missing or unparsable `CONTENT_LENGTH` is treated as `0`, i.e. no body is read.
///
/// Environment entries that are not valid Unicode are converted lossily (invalid sequences become
/// U+FFFD) instead of aborting the whole request, which is what `std::env::vars()` would do.
///
/// # Panics
///
/// Panics if the request body cannot be read in full from standard input, if the response cannot
/// be written to standard output, or if the environment does not describe a valid request.
pub fn handle<F>(func: F)
    where F: FnOnce(Request) -> Response
{
    // `std::env::vars()` panics on non-Unicode entries; go through the OS strings instead.
    let lossy = |raw: std::ffi::OsString| -> String {
        raw.into_string()
            .unwrap_or_else(|raw| raw.to_string_lossy().into_owned())
    };
    let env_vars: HashMap<String, String> = std::env::vars_os()
        .map(|(key, value)| (lossy(key), lossy(value)))
        .collect();

    // Only the announced number of body bytes is read from stdin.
    let content_length: usize = env_vars.get("CONTENT_LENGTH")
        .and_then(|cl| cl.parse::<usize>().ok())
        .unwrap_or(0);

    let mut stdin_contents = vec![0; content_length];
    stdin().read_exact(&mut stdin_contents)
        .expect("failed to read CONTENT_LENGTH bytes of request body from stdin");

    let request = parse_request(env_vars, stdin_contents);
    let response = func(request);
    let output = serialize_response(response);

    let stdout = std::io::stdout();
    let mut out = stdout.lock();
    out.write_all(&output).expect("failed to write CGI response to stdout");
    // Flush explicitly so a write error is reported here rather than lost at process exit.
    out.flush().expect("failed to flush CGI response to stdout");
}
/// Generates a `fn main()` that passes the given handler expression to this crate's `handle`.
///
/// `$func` must be usable as the argument of `handle`, i.e. something callable as
/// `FnOnce(Request) -> Response`.
///
/// The expansion refers to the crate through `$crate`, so it keeps working when the crate is
/// imported under a different name.
#[macro_export]
macro_rules! cgi_main {
    ( $func:expr ) => {
        fn main() {
            $crate::handle($func);
        }
    };
}
/// Generates a `fn main()` for a fallible handler.
///
/// `$func` is called with the request and must return a `Result` whose `Ok` value is a response
/// and whose error type implements `Debug`. On `Err`, the error is printed to standard error with
/// `{:?}` and an empty `500` response is sent instead.
///
/// The expansion refers to the crate through `$crate`, so it keeps working when the crate is
/// imported under a different name.
#[macro_export]
macro_rules! cgi_try_main {
    ( $func:expr ) => {
        fn main() {
            $crate::handle(|request: $crate::Request| {
                match $func(request) {
                    Ok(resp) => resp,
                    Err(err) => {
                        eprintln!("{:?}", err);
                        $crate::empty_response(500)
                    }
                }
            })
        }
    };
}
/// Unwraps a handler result, substituting an empty `500` response for any error.
///
/// The error value itself is discarded. The fallback response is only built when `res` is an
/// `Err`, so the success path does no extra work.
pub fn err_to_500<E>(res: Result<Response, E>) -> Response {
    res.unwrap_or_else(|_| empty_response(500))
}
/// Builds a response that carries only a status code: no headers are set and the body is empty.
///
/// `status_code` may be anything convertible into an `http::StatusCode` via `TryFrom`.
///
/// # Panics
///
/// Panics if the response builder reports an error (presumably when `status_code` cannot be
/// converted into a valid status code).
pub fn empty_response<T>(status_code: T) -> Response
    where http::StatusCode: TryFrom<T>,
          <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>
{
    let no_body: Vec<u8> = Vec::new();
    http::Response::builder()
        .status(status_code)
        .body(no_body)
        .unwrap()
}
/// Builds an HTML response from a status code and a string body.
///
/// The body is converted to bytes; `Content-Type` is set to `text/html; charset=utf-8` and
/// `Content-Length` to the byte length of the body.
///
/// # Panics
///
/// Panics if the response builder reports an error (presumably when `status_code` cannot be
/// converted into a valid status code).
pub fn html_response<T, S>(status_code: T, body: S) -> Response
    where http::StatusCode: TryFrom<T>,
          <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
          S: Into<String>
{
    let text: String = body.into();
    let payload: Vec<u8> = text.into_bytes();
    let length = payload.len().to_string();

    http::Response::builder()
        .status(status_code)
        .header(http::header::CONTENT_TYPE, "text/html; charset=utf-8")
        .header(http::header::CONTENT_LENGTH, length.as_str())
        .body(payload)
        .unwrap()
}
/// Builds a response from a status code and a string body without declaring a content type.
///
/// The body is converted to bytes and `Content-Length` is set to its byte length; no
/// `Content-Type` header is added.
///
/// # Panics
///
/// Panics if the response builder reports an error (presumably when `status_code` cannot be
/// converted into a valid status code).
pub fn string_response<T, S>(status_code: T, body: S) -> Response
    where http::StatusCode: TryFrom<T>,
          <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
          S: Into<String>
{
    let text: String = body.into();
    let payload: Vec<u8> = text.into_bytes();
    let length = payload.len().to_string();

    http::Response::builder()
        .status(status_code)
        .header(http::header::CONTENT_LENGTH, length.as_str())
        .body(payload)
        .unwrap()
}
/// Builds a plain-text response from a status code and a string body.
///
/// The body is converted to bytes; `Content-Length` is set to its byte length and
/// `Content-Type` to `text/plain; charset=utf-8`.
///
/// # Panics
///
/// Panics if the response builder reports an error (presumably when `status_code` cannot be
/// converted into a valid status code).
pub fn text_response<T, S>(status_code: T, body: S) -> Response
    where http::StatusCode: TryFrom<T>,
          <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
          S: Into<String>
{
    let text: String = body.into();
    let payload: Vec<u8> = text.into_bytes();
    let length = payload.len().to_string();

    http::Response::builder()
        .status(status_code)
        .header(http::header::CONTENT_LENGTH, length.as_str())
        .header(http::header::CONTENT_TYPE, "text/plain; charset=utf-8")
        .body(payload)
        .unwrap()
}
/// Builds a response from a status code and a raw byte body, with an optional content type.
///
/// `Content-Length` is always set to the length of `body`. `content_type` accepts either a
/// `&str` or an `Option<&str>`; a `Content-Type` header is added only when a value is present.
///
/// # Panics
///
/// Panics if the response builder reports an error (presumably for an invalid status code or a
/// content type that is not a valid header value).
pub fn binary_response<'a, T>(status_code: T, content_type: impl Into<Option<&'a str>>, body: Vec<u8>) -> Response
    where http::StatusCode: TryFrom<T>,
          <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
{
    let content_type: Option<&str> = content_type.into();
    let length = body.len().to_string();

    let builder = http::Response::builder()
        .status(status_code)
        .header(http::header::CONTENT_LENGTH, length.as_str());

    let builder = match content_type {
        Some(ct) => builder.header(http::header::CONTENT_TYPE, ct),
        None => builder,
    };

    builder.body(body).unwrap()
}
/// Builds a [`Request`] from CGI meta-variables and the already-read request body.
///
/// * The method comes from `REQUEST_METHOD`.
/// * The URI is `SCRIPT_NAME`, followed by `?` and `QUERY_STRING` when the latter is non-empty.
/// * The HTTP version comes from `SERVER_PROTOCOL` when that variable is present.
/// * Every `HTTP_*` variable becomes a header: the prefix is dropped, `_` is replaced by `-`,
///   and the value is trimmed.
/// * A fixed set of CGI meta-variables is additionally exposed as `X-CGI-*` headers.
///
/// # Panics
///
/// Panics if `REQUEST_METHOD` or `SCRIPT_NAME` is missing, if `SERVER_PROTOCOL` holds an
/// unrecognised value, or if the request builder reports an error when the body is attached.
fn parse_request(env_vars: HashMap<String, String>, stdin: Vec<u8>) -> Request {
    // CGI meta-variable -> header under which it is exposed to the handler, in emission order.
    const META_VARIABLE_HEADERS: &[(&str, &str)] = &[
        ("AUTH_TYPE", "X-CGI-Auth-Type"),
        ("CONTENT_LENGTH", "X-CGI-Content-Length"),
        ("CONTENT_TYPE", "X-CGI-Content-Type"),
        ("GATEWAY_INTERFACE", "X-CGI-Gateway-Interface"),
        ("PATH_INFO", "X-CGI-Path-Info"),
        ("PATH_TRANSLATED", "X-CGI-Path-Translated"),
        ("QUERY_STRING", "X-CGI-Query-String"),
        ("REMOTE_ADDR", "X-CGI-Remote-Addr"),
        ("REMOTE_HOST", "X-CGI-Remote-Host"),
        ("REMOTE_IDENT", "X-CGI-Remote-Ident"),
        ("REMOTE_USER", "X-CGI-Remote-User"),
        ("REQUEST_METHOD", "X-CGI-Request-Method"),
        ("SCRIPT_NAME", "X-CGI-Script-Name"),
        ("SERVER_PORT", "X-CGI-Server-Port"),
        ("SERVER_PROTOCOL", "X-CGI-Server-Protocol"),
        ("SERVER_SOFTWARE", "X-CGI-Server-Software"),
    ];
    const HTTP_PREFIX: &str = "HTTP_";

    let method = env_vars.get("REQUEST_METHOD")
        .expect("CGI meta-variable REQUEST_METHOD is not set");
    let script_name = env_vars.get("SCRIPT_NAME")
        .expect("CGI meta-variable SCRIPT_NAME is not set");

    // An absent QUERY_STRING and an empty one are treated the same: no `?` is appended.
    let uri = match env_vars.get("QUERY_STRING") {
        Some(query) if !query.is_empty() => format!("{}?{}", script_name, query),
        _ => script_name.to_owned(),
    };

    let mut req = http::Request::builder()
        .method(method.as_str())
        .uri(uri.as_str());

    if let Some(protocol) = env_vars.get("SERVER_PROTOCOL") {
        let version = match protocol.as_str() {
            "HTTP/0.9" => http::version::Version::HTTP_09,
            "HTTP/1.0" => http::version::Version::HTTP_10,
            "HTTP/1.1" => http::version::Version::HTTP_11,
            "HTTP/2.0" | "HTTP/2" => http::version::Version::HTTP_2,
            "HTTP/3.0" | "HTTP/3" => http::version::Version::HTTP_3,
            other => unimplemented!("Unsupport HTTP SERVER_PROTOCOL {:?}", other),
        };
        req = req.version(version);
    }

    for (key, value) in env_vars.iter() {
        if !key.starts_with(HTTP_PREFIX) {
            continue;
        }
        // HTTP_USER_AGENT -> USER-AGENT
        let header = key[HTTP_PREFIX.len()..].replace('_', "-");
        req = req.header(header.as_str(), value.trim());
    }

    for &(meta_var, target_header) in META_VARIABLE_HEADERS.iter() {
        req = add_header(req, &env_vars, meta_var, target_header);
    }

    req.body(stdin)
        .expect("CGI environment does not describe a valid HTTP request")
}
/// Copies one environment variable onto the request builder as a header, if it is set.
///
/// When `meta_var` is present in `env_vars`, its value is added under `target_header`;
/// otherwise the builder is returned untouched.
fn add_header(req: http::request::Builder, env_vars: &HashMap<String, String>, meta_var: &str, target_header: &str) -> http::request::Builder {
    match env_vars.get(meta_var) {
        Some(value) => req.header(target_header, value.as_str()),
        None => req,
    }
}
/// Serialises a [`Response`] into the bytes written to standard output.
///
/// Layout:
///
/// 1. a `Status: <code>[ <canonical reason>]` line,
/// 2. one `name: value` line per header value, with header names in sorted order,
/// 3. an empty line,
/// 4. the body bytes, unchanged.
///
/// Lines are terminated with `\n`. A header that carries several values (for example repeated
/// `Set-Cookie`) produces one line per value, in the order the header map yields them. Header
/// values are copied as raw bytes, so values that are not plain ASCII do not cause a panic.
fn serialize_response(response: Response) -> Vec<u8> {
    let (parts, mut body) = response.into_parts();
    let mut output: Vec<u8> = Vec::new();

    output.extend_from_slice(b"Status: ");
    output.extend_from_slice(parts.status.as_str().as_bytes());
    if let Some(reason) = parts.status.canonical_reason() {
        output.push(b' ');
        output.extend_from_slice(reason.as_bytes());
    }
    output.push(b'\n');

    // `keys()` yields every distinct name once; sorting keeps the output deterministic.
    let mut names: Vec<&http::header::HeaderName> = parts.headers.keys().collect();
    names.sort_unstable_by(|a, b| a.as_str().cmp(b.as_str()));

    for name in names {
        for value in parts.headers.get_all(name).iter() {
            output.extend_from_slice(name.as_str().as_bytes());
            output.extend_from_slice(b": ");
            output.extend_from_slice(value.as_bytes());
            output.push(b'\n');
        }
    }

    // Blank line separating the header section from the body.
    output.push(b'\n');
    output.append(&mut body);
    output
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned environment map from borrowed key/value pairs.
    fn env(input: Vec<(&str, &str)>) -> HashMap<String, String> {
        input.into_iter().map(|(a, b)| (a.to_owned(), b.to_owned())).collect()
    }

    #[test]
    fn test_parse_request() {
        let env_vars = env(vec![
            ("REQUEST_METHOD", "GET"), ("SCRIPT_NAME", "/my/path/script"),
            ("SERVER_PROTOCOL", "HTTP/1.0"), ("HTTP_USER_AGENT", "MyBrowser/1.0"),
            ("QUERY_STRING", "foo=bar&baz=bop"),
        ]);
        let stdin = Vec::new();
        let req = parse_request(env_vars, stdin);
        assert_eq!(req.method(), &http::method::Method::GET);
        assert_eq!(req.uri(), "/my/path/script?foo=bar&baz=bop");
        assert_eq!(req.uri().path(), "/my/path/script");
        assert_eq!(req.uri().query(), Some("foo=bar&baz=bop"));
        assert_eq!(req.version(), http::version::Version::HTTP_10);
        assert_eq!(req.headers()[http::header::USER_AGENT], "MyBrowser/1.0");
        assert_eq!(req.body(), &vec![] as &Vec<u8>);
    }

    /// Without `QUERY_STRING` the URI is just the script name, and the body is passed through.
    #[test]
    fn test_parse_request_without_query_string() {
        let env_vars = env(vec![
            ("REQUEST_METHOD", "POST"), ("SCRIPT_NAME", "/my/path/script"),
        ]);
        let req = parse_request(env_vars, vec![1, 2, 3]);
        assert_eq!(req.method(), &http::method::Method::POST);
        assert_eq!(req.uri().path(), "/my/path/script");
        assert_eq!(req.uri().query(), None);
        assert_eq!(req.body(), &vec![1u8, 2, 3]);
    }

    /// Serialises `resp` with the given body and compares the bytes with `expected_output`,
    /// printing both as text first when they differ.
    fn test_serialized_response(resp: http::response::Builder, body: &str, expected_output: &str) {
        let resp: Response = resp.body(String::from(body).into_bytes()).unwrap();
        let output = serialize_response(resp);
        let expected_output = String::from(expected_output).into_bytes();

        if output != expected_output {
            println!("output: {}\nexpected: {}", std::str::from_utf8(&output).unwrap(), std::str::from_utf8(&expected_output).unwrap());
        }

        assert_eq!(output, expected_output);
    }

    #[test]
    fn test_serialized_response1() {
        test_serialized_response(
            http::Response::builder().status(200),
            "Hello World",
            "Status: 200 OK\n\nHello World"
        );

        test_serialized_response(
            http::Response::builder().status(200)
                .header("Content-Type", "text/html")
                .header("Content-Language", "en")
                .header("Cache-Control", "max-age=3600"),
            "<html><body><h1>Hello</h1></body></html>",
            "Status: 200 OK\ncache-control: max-age=3600\ncontent-language: en\ncontent-type: text/html\n\n<html><body><h1>Hello</h1></body></html>"
        );
    }

    #[test]
    fn test_shortcuts1() {
        assert_eq!(std::str::from_utf8(&serialize_response(html_response(200, "<html><body><h1>Hello World</h1></body></html>"))).unwrap(),
            "Status: 200 OK\ncontent-length: 46\ncontent-type: text/html; charset=utf-8\n\n<html><body><h1>Hello World</h1></body></html>"
        );
    }

    #[test]
    fn test_shortcuts2() {
        assert_eq!(std::str::from_utf8(&serialize_response(binary_response(200, None, vec![65, 66, 67]))).unwrap(),
            "Status: 200 OK\ncontent-length: 3\n\nABC"
        );

        assert_eq!(std::str::from_utf8(&serialize_response(binary_response(200, "application/octet-stream", vec![65, 66, 67]))).unwrap(),
            "Status: 200 OK\ncontent-length: 3\ncontent-type: application/octet-stream\n\nABC"
        );

        let ct: String = "image/png".to_string();
        assert_eq!(std::str::from_utf8(&serialize_response(binary_response(200, ct.as_str(), vec![65, 66, 67]))).unwrap(),
            "Status: 200 OK\ncontent-length: 3\ncontent-type: image/png\n\nABC"
        );
    }

    #[test]
    fn test_shortcuts3() {
        assert_eq!(std::str::from_utf8(&serialize_response(empty_response(404))).unwrap(),
            "Status: 404 Not Found\n\n"
        );

        assert_eq!(std::str::from_utf8(&serialize_response(string_response(200, "hi"))).unwrap(),
            "Status: 200 OK\ncontent-length: 2\n\nhi"
        );

        assert_eq!(std::str::from_utf8(&serialize_response(text_response(200, "hi"))).unwrap(),
            "Status: 200 OK\ncontent-length: 2\ncontent-type: text/plain; charset=utf-8\n\nhi"
        );
    }
}