1use std::io::{Read, Write, stdin};
54use std::collections::HashMap;
55use std::convert::TryFrom;
56
57pub extern crate http;
58
/// An [`http::Request`] whose body is a `Vec<u8>`; `handle` fills that body
/// with the bytes it reads from stdin.
pub type Request = http::Request<Vec<u8>>;
61
/// An [`http::Response`] whose body is a `Vec<u8>`; this is what a handler
/// passed to `handle` returns and what gets serialized to stdout.
pub type Response = http::Response<Vec<u8>>;
64
65pub fn handle<F>(func: F)
73 where F: FnOnce(Request) -> Response
74{
75 let env_vars: HashMap<String, String> = std::env::vars().collect();
76
77 let content_length: usize = env_vars.get("CONTENT_LENGTH")
80 .and_then(|cl| cl.parse::<usize>().ok()).unwrap_or(0);
81
82 let mut stdin_contents = vec![0; content_length];
83 stdin().read_exact(&mut stdin_contents).unwrap();
84
85 let request = parse_request(env_vars, stdin_contents);
86
87 let response = func(request);
88
89 let output = serialize_response(response);
90
91 std::io::stdout().write_all(&output).unwrap();
92}
93
94#[doc(inline)]
95pub use cgi_attributes::main;
96
97pub fn err_to_500<E>(res: Result<Response, E>) -> Response {
98 res.unwrap_or(empty_response(500))
99}
100
101pub fn empty_response<T>(status_code: T) -> Response
104 where http::StatusCode: TryFrom<T>,
105 <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>
106{
107 http::response::Builder::new().status(status_code).body(vec![]).unwrap()
108}
109
110pub fn html_response<T, S>(status_code: T, body: S) -> Response
113 where http::StatusCode: TryFrom<T>,
114 <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
115 S: Into<String>
116{
117 let body: Vec<u8> = body.into().into_bytes();
118 http::response::Builder::new()
119 .status(status_code)
120 .header(http::header::CONTENT_TYPE, "text/html; charset=utf-8")
121 .header(http::header::CONTENT_LENGTH, format!("{}", body.len()).as_str())
122 .body(body)
123 .unwrap()
124}
125
126pub fn string_response<T, S>(status_code: T, body: S) -> Response
128 where http::StatusCode: TryFrom<T>,
129 <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
130 S: Into<String>
131{
132 let body: Vec<u8> = body.into().into_bytes();
133 http::response::Builder::new()
134 .status(status_code)
135 .header(http::header::CONTENT_LENGTH, format!("{}", body.len()).as_str())
136 .body(body)
137 .unwrap()
138}
139
140
141pub fn text_response<T, S>(status_code: T, body: S) -> Response
151 where http::StatusCode: TryFrom<T>,
152 <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
153 S: Into<String>
154{
155 let body: Vec<u8> = body.into().into_bytes();
156 http::response::Builder::new()
157 .status(status_code)
158 .header(http::header::CONTENT_LENGTH, format!("{}", body.len()).as_str())
159 .header(http::header::CONTENT_TYPE, "text/plain; charset=utf-8")
160 .body(body)
161 .unwrap()
162}
163
164
165pub fn binary_response<'a, T>(status_code: T, content_type: impl Into<Option<&'a str>>, body: Vec<u8>) -> Response
186 where http::StatusCode: TryFrom<T>,
187 <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
188{
189 let content_type: Option<&str> = content_type.into();
190
191 let mut response = http::response::Builder::new()
192 .status(status_code)
193 .header(http::header::CONTENT_LENGTH, format!("{}", body.len()).as_str());
194
195 if let Some(ct) = content_type {
196 response = response.header(http::header::CONTENT_TYPE, ct);
197 }
198
199 response.body(body).unwrap()
200}
201
202
203fn parse_request(env_vars: HashMap<String, String>, stdin: Vec<u8>) -> Request {
204 let mut req = http::Request::builder();
205
206 let method = env_vars.get("REQUEST_METHOD").expect("no REQUEST_METHOD set");
207 req = req.method(method.as_str());
208 let uri = if env_vars.get("QUERY_STRING").unwrap_or(&"".to_owned()) != "" {
209 format!("{}?{}", env_vars["SCRIPT_NAME"], env_vars["QUERY_STRING"])
210 } else {
211 env_vars["SCRIPT_NAME"].to_owned()
212 };
213 req = req.uri(uri.as_str());
214
215 if let Some(v) = env_vars.get("SERVER_PROTOCOL") {
216 if v == "HTTP/0.9" {
217 req = req.version(http::version::Version::HTTP_09);
218 } else if v == "HTTP/1.0" {
219 req = req.version(http::version::Version::HTTP_10);
220 } else if v == "HTTP/1.1" {
221 req = req.version(http::version::Version::HTTP_11);
222 } else if v == "HTTP/2.0" {
223 req = req.version(http::version::Version::HTTP_2);
224 } else {
225 unimplemented!("Unsupport HTTP SERVER_PROTOCOL {:?}", v);
226 }
227 }
228
229 for key in env_vars.keys().filter(|k| k.starts_with("HTTP_")) {
230 let header: String = key.chars().skip(5).map(|c| if c == '_' { '-' } else { c }).collect();
231 req = req.header(header.as_str(), env_vars[key].as_str().trim());
232 }
233
234
235 req = add_header(req, &env_vars, "AUTH_TYPE", "X-CGI-Auth-Type");
236 req = add_header(req, &env_vars, "CONTENT_LENGTH", "X-CGI-Content-Length");
237 req = add_header(req, &env_vars, "CONTENT_TYPE", "X-CGI-Content-Type");
238 req = add_header(req, &env_vars, "GATEWAY_INTERFACE", "X-CGI-Gateway-Interface");
239 req = add_header(req, &env_vars, "PATH_INFO", "X-CGI-Path-Info");
240 req = add_header(req, &env_vars, "PATH_TRANSLATED", "X-CGI-Path-Translated");
241 req = add_header(req, &env_vars, "QUERY_STRING", "X-CGI-Query-String");
242 req = add_header(req, &env_vars, "REMOTE_ADDR", "X-CGI-Remote-Addr");
243 req = add_header(req, &env_vars, "REMOTE_HOST", "X-CGI-Remote-Host");
244 req = add_header(req, &env_vars, "REMOTE_IDENT", "X-CGI-Remote-Ident");
245 req = add_header(req, &env_vars, "REMOTE_USER", "X-CGI-Remote-User");
246 req = add_header(req, &env_vars, "REQUEST_METHOD", "X-CGI-Request-Method");
247 req = add_header(req, &env_vars, "SCRIPT_NAME", "X-CGI-Script-Name");
248 req = add_header(req, &env_vars, "SERVER_PORT", "X-CGI-Server-Port");
249 req = add_header(req, &env_vars, "SERVER_PROTOCOL", "X-CGI-Server-Protocol");
250 req = add_header(req, &env_vars, "SERVER_SOFTWARE", "X-CGI-Server-Software");
251
252 req.body(stdin).unwrap()
253
254}
255
256fn add_header(req: http::request::Builder, env_vars: &HashMap<String, String>, meta_var: &str, target_header: &str) -> http::request::Builder {
258 if let Some(var) = env_vars.get(meta_var) {
259 req.header(target_header, var.as_str())
260 } else {
261 req
262 }
263}
264
265fn serialize_response(response: Response) -> Vec<u8> {
267 let mut output = String::new();
268 output.push_str("Status: ");
269 output.push_str(response.status().as_str());
270 if let Some(reason) = response.status().canonical_reason() {
271 output.push_str(" ");
272 output.push_str(reason);
273 }
274 output.push_str("\n");
275
276 {
277 let headers = response.headers();
278 let mut keys: Vec<&http::header::HeaderName> = headers.keys().collect();
279 keys.sort_by_key(|h| h.as_str());
280 for key in keys {
281 output.push_str(key.as_str());
282 output.push_str(": ");
283 output.push_str(headers.get(key).unwrap().to_str().unwrap());
284 output.push_str("\n");
285 }
286 }
287
288 output.push_str("\n");
289
290 let mut output = output.into_bytes();
291
292 let (_, mut body) = response.into_parts();
293
294 output.append(&mut body);
295
296 output
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an environment map from `(name, value)` pairs.
    fn env(input: Vec<(&str, &str)>) -> HashMap<String, String> {
        input.into_iter().map(|(a, b)| (a.to_owned(), b.to_owned())).collect()
    }

    #[test]
    fn test_parse_request() {
        let env_vars = env(vec![
            ("REQUEST_METHOD", "GET"), ("SCRIPT_NAME", "/my/path/script"),
            ("SERVER_PROTOCOL", "HTTP/1.0"), ("HTTP_USER_AGENT", "MyBrowser/1.0"),
            ("QUERY_STRING", "foo=bar&baz=bop"),
        ]);
        let stdin = Vec::new();
        let req = parse_request(env_vars, stdin);
        assert_eq!(req.method(), &http::method::Method::GET);
        assert_eq!(req.uri(), "/my/path/script?foo=bar&baz=bop");
        assert_eq!(req.uri().path(), "/my/path/script");
        assert_eq!(req.uri().query(), Some("foo=bar&baz=bop"));
        assert_eq!(req.version(), http::version::Version::HTTP_10);
        assert_eq!(req.headers()[http::header::USER_AGENT], "MyBrowser/1.0");
        assert_eq!(req.body(), &vec![] as &Vec<u8>);
    }

    /// An empty `QUERY_STRING` must not add a `?` to the URI, and the body
    /// bytes must be passed through unchanged.
    #[test]
    fn test_parse_request_without_query_string() {
        let env_vars = env(vec![
            ("REQUEST_METHOD", "POST"), ("SCRIPT_NAME", "/my/path/script"),
            ("SERVER_PROTOCOL", "HTTP/1.1"), ("QUERY_STRING", ""),
        ]);
        let req = parse_request(env_vars, vec![1, 2, 3]);
        assert_eq!(req.method(), &http::method::Method::POST);
        assert_eq!(req.uri(), "/my/path/script");
        assert_eq!(req.uri().query(), None);
        assert_eq!(req.version(), http::version::Version::HTTP_11);
        assert_eq!(req.headers()["x-cgi-script-name"], "/my/path/script");
        assert_eq!(req.body(), &vec![1u8, 2, 3]);
    }

    /// Serializes a response built from `resp` and `body` and compares the
    /// result with `expected_output`, printing both as text on mismatch.
    fn test_serialized_response(resp: http::response::Builder, body: &str, expected_output: &str) {
        let resp: Response = resp.body(String::from(body).into_bytes()).unwrap();
        let output = serialize_response(resp);
        let expected_output = String::from(expected_output).into_bytes();

        if output != expected_output {
            // `assert_eq!` below only shows byte vectors; print a readable form too.
            println!("output: {}\nexpected: {}", std::str::from_utf8(&output).unwrap(), std::str::from_utf8(&expected_output).unwrap());
        }

        assert_eq!(output, expected_output);
    }

    #[test]
    fn test_serialized_response1() {
        test_serialized_response(
            http::Response::builder().status(200),
            "Hello World",
            "Status: 200 OK\n\nHello World"
        );

        test_serialized_response(
            http::Response::builder().status(200)
                .header("Content-Type", "text/html")
                .header("Content-Language", "en")
                .header("Cache-Control", "max-age=3600"),
            "<html><body><h1>Hello</h1></body></html>",
            "Status: 200 OK\ncache-control: max-age=3600\ncontent-language: en\ncontent-type: text/html\n\n<html><body><h1>Hello</h1></body></html>"
        );
    }

    #[test]
    fn test_shortcuts1() {
        assert_eq!(std::str::from_utf8(&serialize_response(html_response(200, "<html><body><h1>Hello World</h1></body></html>"))).unwrap(),
            "Status: 200 OK\ncontent-length: 46\ncontent-type: text/html; charset=utf-8\n\n<html><body><h1>Hello World</h1></body></html>"
        );
    }

    #[test]
    fn test_shortcuts2() {
        assert_eq!(std::str::from_utf8(&serialize_response(binary_response(200, None, vec![65, 66, 67]))).unwrap(),
            "Status: 200 OK\ncontent-length: 3\n\nABC"
        );

        assert_eq!(std::str::from_utf8(&serialize_response(binary_response(200, "application/octet-stream", vec![65, 66, 67]))).unwrap(),
            "Status: 200 OK\ncontent-length: 3\ncontent-type: application/octet-stream\n\nABC"
        );

        let ct: String = "image/png".to_string();
        assert_eq!(std::str::from_utf8(&serialize_response(binary_response(200, ct.as_str(), vec![65, 66, 67]))).unwrap(),
            "Status: 200 OK\ncontent-length: 3\ncontent-type: image/png\n\nABC"
        );
    }

    #[test]
    fn test_shortcuts3() {
        assert_eq!(std::str::from_utf8(&serialize_response(text_response(200, "Hello"))).unwrap(),
            "Status: 200 OK\ncontent-length: 5\ncontent-type: text/plain; charset=utf-8\n\nHello"
        );

        assert_eq!(std::str::from_utf8(&serialize_response(string_response(404, "nope"))).unwrap(),
            "Status: 404 Not Found\ncontent-length: 4\n\nnope"
        );

        assert_eq!(std::str::from_utf8(&serialize_response(empty_response(500))).unwrap(),
            "Status: 500 Internal Server Error\n\n"
        );
    }

    #[test]
    fn test_err_to_500() {
        let ok: Result<Response, ()> = Ok(empty_response(204));
        assert_eq!(err_to_500(ok).status(), http::StatusCode::NO_CONTENT);

        let err: Result<Response, &str> = Err("boom");
        let resp = err_to_500(err);
        assert_eq!(resp.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
        assert!(resp.body().is_empty());
    }

}