wcgi_host/
cgi.rs

1use std::collections::HashMap;
2
3use http::{header::HeaderName, HeaderMap, HeaderValue, StatusCode};
4use tokio::io::{AsyncBufRead, AsyncBufReadExt};
5
6use crate::CgiError;
7
8const SERVER_SOFTWARE: &str = concat!(env!("CARGO_PKG_NAME"), " ", env!("CARGO_PKG_VERSION"));
9
10/// The RFC says certain headers that should be mapped to environment variables,
11/// if present.
12const KNOWN_META_VARIABLES: [(&str, &str); 3] = [
13    ("Content-Length", "CONTENT_LENGTH"),
14    ("Content-Type", "CONTENT_TYPE"),
15    ("Authorization", "AUTH_TYPE"),
16];
17
18pub(crate) fn prepare_environment_variables(
19    parts: http::request::Parts,
20    env: &mut HashMap<String, String>,
21) {
22    env.insert(
23        "REQUEST_METHOD".to_string(),
24        parts.method.as_str().to_string(),
25    );
26    env.insert(
27        "QUERY_STRING".to_string(),
28        parts.uri.query().unwrap_or("").to_string(),
29    );
30    env.insert("PATH_INFO".to_string(), parts.uri.path().to_string());
31    env.insert("SERVER_SOFTWARE".to_string(), SERVER_SOFTWARE.to_string());
32
33    // FIXME(Michael-F-Bryan): we hard-code the assumption that our CGI files
34    // were mounted under /app/. This should be configurable.
35    // https://github.com/wasmerio/wcgi/issues/21
36    env.insert("DOCUMENT_ROOT".to_string(), "/app/".to_string());
37    env.insert(
38        "SCRIPT_FILENAME".to_string(),
39        format!("/app{}", parts.uri.path()),
40    );
41
42    if let Some(protocol) = server_protocol(parts.version) {
43        env.insert("SERVER_PROTOCOL".to_string(), protocol.to_string());
44    }
45
46    if let Some(content_length) = parts
47        .headers
48        .get("Content-Length")
49        .and_then(|v| v.to_str().ok())
50    {
51        env.insert("CONTENT_LENGTH".to_string(), content_length.to_string());
52    }
53    for (header_name, env_variable) in KNOWN_META_VARIABLES {
54        if let Some(value) = parts.headers.get(header_name).and_then(|v| v.to_str().ok()) {
55            env.insert(env_variable.to_string(), value.to_string());
56        }
57    }
58
59    // All "protocol-specific" HTTP sure headers are passed in as $HTTP_xxx
60    for (name, value) in &parts.headers {
61        if let Ok(value) = value.to_str() {
62            let name = format!("HTTP_{}", name.to_string().to_uppercase().replace('-', "_"));
63            env.insert(name, value.to_string());
64        }
65    }
66}
67
68pub(crate) async fn extract_response_header(
69    stdout: &mut (impl AsyncBufRead + Unpin),
70) -> Result<http::response::Parts, CgiError> {
71    let mut headers = parse_cgi_headers(stdout).await?;
72
73    let mut builder = http::response::Builder::new();
74
75    let status = headers.remove("Status").and_then(|status| {
76        let status = status.to_str().ok()?;
77        parse_status_header(status)
78    });
79
80    if let Some(status) = status {
81        builder = builder.status(status);
82    }
83
84    // Note: This can only panic when the TryInto calls used in the builder's
85    // various builder methods fail. However, this should never happen because
86    // we parse the headers and status code ourselves.
87    //
88    // Don't look too closely or you'll notice that we call into_parts() just so
89    // we can get a Parts object that the caller can pass to
90    // Response::from_parts() later on.
91    let (mut parts, _) = builder
92        .body(())
93        .expect("All builder inputs should already be validated")
94        .into_parts();
95
96    parts.headers.extend(headers);
97
98    Ok(parts)
99}
100
101/// Parse the header section from the response.
102///
103/// # Implementation Notes
104///
105/// Any invalid headers will be silently discarded. This might happen if lines
106/// aren't in the form `name: value` or when header name/value contains invalid
107/// characters.
108async fn parse_cgi_headers(
109    stdout_receiver: &mut (impl AsyncBufRead + Unpin),
110) -> Result<HeaderMap, CgiError> {
111    let mut headers = HeaderMap::new();
112    let mut buffer = String::new();
113
114    loop {
115        buffer.clear();
116        stdout_receiver
117            .read_line(&mut buffer)
118            .await
119            .map_err(CgiError::StdoutRead)?;
120
121        let line = buffer.trim();
122
123        if line.is_empty() {
124            // We found the CRLF CRLF that indicates the end of the header
125            // section.
126            break;
127        }
128
129        let (key, value) = match line.split_once(':') {
130            Some((k, v)) => (k, v),
131            None => {
132                // Let's be lenient and ignore the invalid header line.
133                continue;
134            }
135        };
136
137        if let Some((key, value)) = parse_header_pair(key, value) {
138            headers.append(key, value);
139        }
140    }
141
142    Ok(headers)
143}
144
145fn parse_header_pair(key: &str, value: &str) -> Option<(HeaderName, HeaderValue)> {
146    let key = HeaderName::from_bytes(key.trim().as_bytes()).ok()?;
147    let value = value.trim().parse().ok()?;
148
149    Some((key, value))
150}
151
152fn parse_status_header(status_code: &str) -> Option<StatusCode> {
153    // Note: the status code header may contain just the number (i.e. "200") or
154    // also the reason ("200 OK").
155    let src = match status_code.split_once(' ') {
156        Some((s, _)) => s,
157        None => status_code,
158    };
159
160    src.parse().ok()
161}
162
163fn server_protocol(version: http::Version) -> Option<&'static str> {
164    match version {
165        http::Version::HTTP_09 => Some("HTTP/0.9"),
166        http::Version::HTTP_10 => Some("HTTP/1.0"),
167        http::Version::HTTP_11 => Some("HTTP/1.1"),
168        http::Version::HTTP_2 => Some("HTTP/2.0"),
169        http::Version::HTTP_3 => Some("HTTP/3.0"),
170        _ => None,
171    }
172}
173
#[cfg(test)]
mod tests {
    use tokio::io::{AsyncRead, AsyncReadExt, BufReader};

    use super::*;

    /// Join fixture lines with CRLF, the line terminator used by CGI output.
    fn crlf_joined(lines: &[&str]) -> String {
        lines.join("\r\n")
    }

    /// Read everything left in `reader` once the header section was consumed.
    async fn remaining(reader: &mut (impl AsyncRead + Unpin)) -> String {
        let mut rest = String::new();
        reader.read_to_string(&mut rest).await.unwrap();
        rest
    }

    #[test]
    fn parse_status_codes() {
        let table = [
            ("200", Some(StatusCode::OK)),
            ("200 OK", Some(StatusCode::OK)),
            ("", None),
        ];

        table
            .iter()
            .for_each(|&(input, want)| assert_eq!(parse_status_header(input), want));
    }

    #[tokio::test]
    async fn parse_response_parts() {
        let response = crlf_joined(&[
            "Status: 503 Database Unavailable",
            "Content-type: text/html",
            "",
            "<HTML>",
            "<HEAD><TITLE>503 Database Unavailable</TITLE></HEAD>",
            "<BODY>",
            "  <H1>Error</H1>",
            "  <P>Sorry, the database is currently not available. Please",
            "    try again later.</P>",
            "</BODY>",
            "</HTML>",
        ]);
        let mut reader = BufReader::new(response.as_bytes());

        let parts = extract_response_header(&mut reader).await.unwrap();
        assert_eq!(parts.status, StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(parts.headers["Content-type"].to_str().unwrap(), "text/html");

        // The reader must now sit right after the blank line ending the headers.
        let body = remaining(&mut reader).await;
        assert!(body.starts_with("<HTML>\r\n<HEAD><TITLE>503 Database Unavailable</TITLE></HEAD>"));
    }

    #[tokio::test]
    async fn respect_duplicate_headers() {
        let response = crlf_joined(&[
            "Status: 503 Database Unavailable",
            "Content-type: text/html",
            "Cookie: first=x",
            "Cookie: second=y",
        ]);
        let mut reader = BufReader::new(response.as_bytes());

        let parts = extract_response_header(&mut reader).await.unwrap();

        let mut cookies = Vec::new();
        for value in parts.headers.get_all("Cookie").iter() {
            cookies.push(value.to_str().unwrap());
        }
        assert_eq!(cookies, ["first=x", "second=y"]);
    }

    #[tokio::test]
    async fn parse_response_parts_with_empty_body() {
        let mut reader = BufReader::new("Location: https://google.com/\r\n".as_bytes());

        let parts = extract_response_header(&mut reader).await.unwrap();
        let location = parts.headers["Location"].to_str().unwrap();
        assert_eq!(location, "https://google.com/");

        // Nothing may follow the header section.
        assert!(remaining(&mut reader).await.is_empty());
    }
}