use std::error::Error;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::process::Stdio;
use std::task::{Context, Poll};
use futures::{Future, FutureExt, StreamExt};
use http::uri::{Authority, Scheme};
use http::{header, Request, Response, StatusCode};
use hyper::Body;
use tokio::io::{self, AsyncBufReadExt, BufReader};
use tokio::process::Command;
use tokio_util::io::{ReaderStream, StreamReader};
use tower::Service;
/// A [`tower::Service`] that answers each HTTP request by running a CGI script.
///
/// Request metadata is handed to the script as CGI meta-variables (environment
/// variables), the request body is written to the script's STDIN, and the script's
/// STDOUT is parsed as a CGI response (header block, blank line, body).
#[derive(Debug, Clone)]
pub struct Cgi {
    /// Path of the script to execute; it is canonicalized on every call.
    path: PathBuf,
    /// Whether the server's own environment is cleared before the CGI variables are set.
    env_clear: bool,
}

impl Cgi {
    /// Creates a service that runs the script at `path`.
    ///
    /// By default the script's environment is cleared, so it only sees the CGI
    /// meta-variables; use [`Cgi::env_clear`] to change that.
    pub fn new<P: AsRef<Path>>(path: P) -> Self {
        Cgi {
            path: path.as_ref().to_path_buf(),
            env_clear: true,
        }
    }

    /// Sets whether the server's environment is cleared before running the script.
    ///
    /// With `clear == false` the script inherits the server's environment variables in
    /// addition to the CGI meta-variables.
    #[must_use]
    pub fn env_clear(mut self, clear: bool) -> Self {
        self.env_clear = clear;
        self
    }
}
type BoxedError = Box<dyn Error + Sync + Send>;
impl Service<Request<Body>> for Cgi {
type Response = Response<Body>;
type Error = BoxedError;
#[allow(clippy::type_complexity)]
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
#[inline]
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let script_path = self.path.clone();
let env_clear = self.env_clear;
async move {
let script_path = std::fs::canonicalize(script_path)?;
let mut cmd = Command::new(&script_path);
let cmd = if env_clear { cmd.env_clear() } else { &mut cmd };
let mut child = cmd
.env("GATEWAY_INTERFACE", "CGI/1.1")
.env("QUERY_STRING", req.uri().query().unwrap_or_default())
.env("PATH_INFO", req.uri().path())
.env("PATH_TRANSLATED", &script_path)
.env("REQUEST_METHOD", req.method().as_str().to_ascii_uppercase())
.env("SCRIPT_NAME", req.uri().path())
.env(
"SERVER_NAME",
req.headers()
.get(header::HOST)
.and_then(|val| val.to_str().ok())
.and_then(|host| host.parse::<Authority>().ok())
.map(|authority| authority.host().to_owned())
.unwrap_or_default(),
)
.env(
"SERVER_PORT",
req.uri()
.port()
.map(|port| port.to_string())
.or_else(|| {
req.headers().get("x-forwarded-proto").and_then(|val| {
match val.to_str() {
Ok("http") => Some("80".to_string()),
Ok("https") => Some("443".to_string()),
_ => None,
}
})
})
.or_else(|| match req.uri().scheme() {
Some(scheme) if *scheme == Scheme::HTTP => Some("80".to_string()),
Some(scheme) if *scheme == Scheme::HTTPS => Some("443".to_string()),
_ => None,
})
.unwrap_or_else(|| "80".to_string()),
)
.env("SERVER_PROTOCOL", format!("{:?}", req.version()))
.env("SERVER_SOFTWARE", "tower-cgi/0.0.1")
.env(
"CONTENT_TYPE",
req.headers()
.get(header::CONTENT_TYPE)
.and_then(|val| val.to_str().ok())
.unwrap_or_default(),
)
.env(
"CONTENT_LENGTH",
req.headers()
.get(header::CONTENT_LENGTH)
.and_then(|val| val.to_str().ok())
.unwrap_or_default(),
)
.envs(
req.headers()
.into_iter()
.map(|(name, value)| {
let name = format!("HTTP_{}", name)
.replace("-", "_")
.to_ascii_uppercase();
Ok((name, value.to_str()?))
})
.collect::<Result<Vec<_>, Self::Error>>()?,
)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::inherit())
.spawn()?;
let mut stdin = child.stdin.take().ok_or("Failed to get process STDIN")?;
let stdout = child.stdout.take().ok_or("Failed to get process STDOUT")?;
tokio::spawn(async move { child.wait().await.unwrap() });
let write_request_body = async move {
let request_body = req
.into_body()
.map(|chunk| chunk.map_err(|err| io::Error::new(io::ErrorKind::Other, err)));
let mut request_body_reader = StreamReader::new(request_body);
io::copy(&mut request_body_reader, &mut stdin).await?;
Ok::<_, Self::Error>(io::copy(&mut request_body_reader, &mut stdin).await?)
};
let read_response = async move {
let mut stdout_reader = BufReader::new(stdout);
let mut headers = Vec::new();
loop {
stdout_reader.read_until(b'\n', &mut headers).await?;
match headers.as_slice() {
[.., b'\r', b'\n', b'\r', b'\n'] => break,
[.., b'\n', b'\n'] => break,
_ => continue,
}
}
let mut parsed_headers = [httparse::EMPTY_HEADER; 64];
httparse::parse_headers(&headers, &mut parsed_headers)?;
let response = parsed_headers
.into_iter()
.filter(|header| *header != httparse::EMPTY_HEADER)
.map(|header| (header.name, header.value))
.try_fold(
Response::builder().status(200),
|response, (name, value)| {
if name.to_ascii_lowercase() == "status" {
Ok::<_, Self::Error>(
response.status(StatusCode::from_bytes(&value[0..3])?),
)
} else {
Ok(response.header(name, value))
}
},
)?;
let body_reader = ReaderStream::new(stdout_reader);
let response = response.body(Body::wrap_stream(body_reader))?;
Ok::<_, Self::Error>(response)
};
let (_, response) = tokio::try_join!(write_request_body, read_response)?;
Ok(response)
}
.boxed()
}
}
#[cfg(test)]
mod tests {
    use std::fs::Permissions;
    use std::io;
    use std::io::Write;
    use std::os::unix::fs::PermissionsExt;
    use http::Request;
    use hyper::Body;
    use indoc::indoc;
    use tempfile::{NamedTempFile, TempPath};
    use tower::ServiceExt;
    use crate::Cgi;

    /// Writes `program` to a new executable (mode 0755) temporary file and returns its path.
    ///
    /// The file is removed when the returned [`TempPath`] is dropped. This is a plain
    /// function because it only does blocking filesystem work and never awaits anything.
    fn temp_cgi_script(program: &str) -> io::Result<TempPath> {
        let mut file = NamedTempFile::new()?;
        file.as_file_mut()
            .set_permissions(Permissions::from_mode(0o755))?;
        writeln!(file, "{}", program)?;
        let path = file.into_temp_path();
        // NOTE(review): presumably a workaround for ETXTBSY ("Text file busy") when another
        // test thread forks while this file's write handle is still open; confirm before
        // removing.
        std::thread::sleep(std::time::Duration::from_secs(1));
        Ok(path)
    }

    #[tokio::test]
    async fn test_status_code() {
        let script = temp_cgi_script(indoc! {r#"
            #!/bin/sh
            echo "Status: 201 Created"
            echo ""
        "#})
        .unwrap();
        let svc = Cgi::new(&script);
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.oneshot(req).await.unwrap();
        assert_eq!(res.status(), 201);
    }

    #[tokio::test]
    async fn test_response_headers() {
        let script = temp_cgi_script(indoc! {r#"
            #!/bin/sh
            echo "Status: 200"
            echo "x-some-header: hello"
            echo "x-other-header: bye"
            echo ""
        "#})
        .unwrap();
        let svc = Cgi::new(&script);
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.oneshot(req).await.unwrap();
        assert_eq!(res.headers()["x-some-header"], "hello");
        assert_eq!(res.headers()["x-other-header"], "bye");
    }

    #[tokio::test]
    async fn test_request_headers() {
        let script = temp_cgi_script(indoc! {r#"
            #!/bin/sh
            echo "Status: 200"
            echo "x-req-header: ${HTTP_SOME_REQUEST_HEADER}"
            echo ""
        "#})
        .unwrap();
        let svc = Cgi::new(&script);
        let req = Request::builder()
            .header("some-request-header", "hello")
            .body(Body::empty())
            .unwrap();
        let res = svc.oneshot(req).await.unwrap();
        assert_eq!(res.headers()["x-req-header"], "hello");
    }

    #[tokio::test]
    async fn test_response_body() {
        let script = temp_cgi_script(indoc! {r#"
            #!/bin/sh
            echo "Status: 200"
            echo ""
            printf "Hello"
        "#})
        .unwrap();
        let svc = Cgi::new(&script);
        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.oneshot(req).await.unwrap();
        let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
        assert_eq!(&body[..], b"Hello");
    }

    #[tokio::test]
    async fn test_request_body() {
        let script = temp_cgi_script(indoc! {r#"
            #!/bin/sh
            echo "Status: 200"
            echo ""
            cat -
        "#})
        .unwrap();
        let svc = Cgi::new(&script);
        let req = Request::builder().body(Body::from(&b"input"[..])).unwrap();
        let res = svc.oneshot(req).await.unwrap();
        let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
        assert_eq!(&body[..], b"input");
    }
}