use super::cgi_util::{
InFlightGuard, build_cgi_env, collect_body, parse_cgi_response,
socket_roundtrip,
};
use crate::error::{HttpResponse, response_502};
use crate::error::ReqBody;
use crate::handler::Handler;
use crate::headers::RequestContext;
use crate::metrics::Metrics;
use async_trait::async_trait;
use hyper::Request;
use std::sync::Arc;
use std::sync::atomic::Ordering;
/// Request handler that forwards matched requests to an SCGI backend and
/// converts the backend's CGI-style reply into an HTTP response.
pub(crate) struct ScgiHandler {
    /// Shared counters updated per request (`scgi_requests_total`,
    /// `scgi_errors_total`, `scgi_in_flight`).
    metrics: Arc<Metrics>,
    /// Backend socket address, passed verbatim to `socket_roundtrip`.
    /// NOTE(review): whether this is a TCP address or a Unix-socket path is
    /// decided by `socket_roundtrip` — confirm there.
    socket: String,
    /// Root path handed to `build_cgi_env` (presumably the document root).
    root: String,
    /// Optional index name handed to `build_cgi_env` (presumably the default
    /// script/file for directory requests — confirm in `build_cgi_env`).
    index: Option<String>,
}
#[async_trait]
impl Handler for ScgiHandler {
async fn handle(
&self,
req: Request<ReqBody>,
matched_prefix: &str,
_ctx: &RequestContext<'_>,
) -> HttpResponse {
self.metrics
.scgi_requests_total
.fetch_add(1, Ordering::Relaxed);
let _guard = InFlightGuard::new(
self.metrics.clone(),
|m| &m.scgi_in_flight,
);
let (parts, body) = req.into_parts();
let body_bytes = match collect_body(
body,
&self.metrics.scgi_errors_total,
)
.await
{
Ok(b) => b,
Err(resp) => return resp,
};
let env = build_cgi_env(
&parts,
&self.root,
matched_prefix,
&self.index,
&body_bytes,
);
let request_bytes = build_scgi_request(&env, &body_bytes);
match socket_roundtrip(
&self.socket, &request_bytes, "scgi",
)
.await
{
Ok(raw) => match parse_cgi_response(&raw) {
Ok(resp) => resp,
Err(e) => {
self.metrics
.scgi_errors_total
.fetch_add(1, Ordering::Relaxed);
tracing::error!(
socket = %self.socket,
"scgi: malformed CGI response: {e}"
);
response_502()
}
},
Err(e) => {
self.metrics
.scgi_errors_total
.fetch_add(1, Ordering::Relaxed);
tracing::error!(
socket = %self.socket,
"scgi: connection error: {e}"
);
response_502()
}
}
}
}
impl ScgiHandler {
    /// Creates a handler that talks to the SCGI backend at `socket`.
    ///
    /// `socket` and `root` are copied into owned strings. `index` and `metrics`
    /// are stored as given; `root` and `index` are later passed to
    /// `build_cgi_env` for every request.
    pub(crate) fn new(
        socket: &str,
        root: &str,
        index: Option<String>,
        metrics: Arc<Metrics>,
    ) -> Self {
        let socket = String::from(socket);
        let root = String::from(root);
        Self {
            metrics,
            socket,
            root,
            index,
        }
    }
}
/// Encodes a complete SCGI request: a netstring-framed header block followed
/// by the raw request body.
///
/// Wire layout:
///
/// ```text
/// <len>:CONTENT_LENGTH\0<body len>\0SCGI\01\0<NAME>\0<value>\0...,<body>
/// ```
///
/// The SCGI protocol requires three things of the header block:
/// - `CONTENT_LENGTH` must be the first header.
/// - A `SCGI` header with value `1` must be present.
/// - Header names must not repeat.
///
/// Both required headers are generated here, with `CONTENT_LENGTH` derived from
/// `body`. Any `CONTENT_LENGTH` or `SCGI` entries in `env` are therefore skipped,
/// and every other pair is emitted in its original order.
pub(crate) fn build_scgi_request(env: &[(String, String)], body: &[u8]) -> Vec<u8> {
    const CONTENT_LENGTH: &str = "CONTENT_LENGTH";
    const SCGI: &str = "SCGI";
    const SCGI_VERSION: &str = "1";

    let content_length = body.len().to_string();
    let mut header_block = Vec::new();
    append_pair(&mut header_block, CONTENT_LENGTH, &content_length);
    append_pair(&mut header_block, SCGI, SCGI_VERSION);
    for (key, value) in env {
        // Both required headers were written above; re-emitting either would
        // produce a duplicate name, which the protocol forbids.
        if key != CONTENT_LENGTH && key != SCGI {
            append_pair(&mut header_block, key, value);
        }
    }

    let len_prefix = header_block.len().to_string();
    // prefix + ':' + header block + ',' + body
    let mut out = Vec::with_capacity(len_prefix.len() + header_block.len() + body.len() + 2);
    out.extend_from_slice(len_prefix.as_bytes());
    out.push(b':');
    out.extend_from_slice(&header_block);
    out.push(b',');
    out.extend_from_slice(body);
    out
}

/// Appends one `name\0value\0` header pair to an SCGI header block.
///
/// NOTE(review): a NUL byte inside `key` or `value` would corrupt the pair
/// framing. This assumes upstream HTTP parsing has already rejected NULs in
/// header values and the URI — confirm for every source feeding the CGI env.
fn append_pair(buf: &mut Vec<u8>, key: &str, value: &str) {
    buf.reserve(key.len() + value.len() + 2);
    buf.extend_from_slice(key.as_bytes());
    buf.push(b'\0');
    buf.extend_from_slice(value.as_bytes());
    buf.push(b'\0');
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned CGI environment from borrowed `(name, value)` pairs.
    fn env(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    /// Splits an encoded request into `(header_block, bytes_after_comma)`,
    /// asserting that the netstring framing `<len>:<data>,` is well formed.
    fn split_netstring(req: &[u8]) -> (&[u8], &[u8]) {
        let colon = req
            .iter()
            .position(|&b| b == b':')
            .expect("netstring length prefix must be terminated by ':'");
        let declared_len: usize = std::str::from_utf8(&req[..colon])
            .expect("length prefix must be ASCII")
            .parse()
            .expect("length prefix must be a decimal integer");
        let data_start = colon + 1;
        let data_end = data_start + declared_len;
        assert!(
            req.len() > data_end,
            "request shorter than its declared netstring length"
        );
        assert_eq!(req[data_end], b',', "netstring must end with comma");
        (&req[data_start..data_end], &req[data_end + 1..])
    }

    /// Decodes a header block into its `name\0value\0` pairs, asserting that it
    /// consists of whole, NUL-terminated pairs.
    fn header_pairs(block: &[u8]) -> Vec<(String, String)> {
        assert_eq!(block.last(), Some(&0), "header block must end with NUL");
        let fields: Vec<String> = block[..block.len() - 1]
            .split(|&b| b == 0)
            .map(|f| String::from_utf8(f.to_vec()).expect("header field must be UTF-8"))
            .collect();
        assert_eq!(
            fields.len() % 2,
            0,
            "header block must contain whole name/value pairs"
        );
        fields
            .chunks_exact(2)
            .map(|kv| (kv[0].clone(), kv[1].clone()))
            .collect()
    }

    #[test]
    fn build_scgi_request_netstring_format() {
        let e = env(&[("REQUEST_METHOD", "GET"), ("QUERY_STRING", "")]);
        let req = build_scgi_request(&e, b"");
        let (header_block, rest) = split_netstring(&req);
        assert!(rest.is_empty(), "no body bytes expected after the comma");
        let pairs = header_pairs(header_block);
        assert!(pairs.contains(&("REQUEST_METHOD".to_string(), "GET".to_string())));
        assert!(pairs.contains(&("QUERY_STRING".to_string(), String::new())));
    }

    #[test]
    fn build_scgi_request_content_length_first() {
        let e = env(&[("REQUEST_METHOD", "POST"), ("CONTENT_LENGTH", "5")]);
        let req = build_scgi_request(&e, b"hello");
        let (header_block, _) = split_netstring(&req);
        assert!(
            header_block.starts_with(b"CONTENT_LENGTH\x00"),
            "CONTENT_LENGTH must be first in SCGI header block"
        );
    }

    #[test]
    fn build_scgi_request_body_appended() {
        let e = env(&[("REQUEST_METHOD", "POST")]);
        let body: &[u8] = b"name=Alice";
        let req = build_scgi_request(&e, body);
        let (_, rest) = split_netstring(&req);
        assert_eq!(rest, body);
    }

    #[test]
    fn build_scgi_request_content_length_matches_body() {
        let body = b"hello world";
        let e = env(&[("REQUEST_METHOD", "POST")]);
        let req = build_scgi_request(&e, body);
        let (header_block, _) = split_netstring(&req);
        let pairs = header_pairs(header_block);
        assert_eq!(
            pairs[0],
            ("CONTENT_LENGTH".to_string(), body.len().to_string())
        );
    }

    #[test]
    fn build_scgi_request_ignores_env_content_length() {
        let e = env(&[("CONTENT_LENGTH", "999"), ("REQUEST_METHOD", "POST")]);
        let req = build_scgi_request(&e, b"hello");
        let (header_block, _) = split_netstring(&req);
        let lengths: Vec<String> = header_pairs(header_block)
            .into_iter()
            .filter(|(k, _)| k == "CONTENT_LENGTH")
            .map(|(_, v)| v)
            .collect();
        assert_eq!(lengths, ["5"], "CONTENT_LENGTH must appear once, from the body");
    }

    #[test]
    fn build_scgi_request_empty_body() {
        let req = build_scgi_request(&[], b"");
        let (header_block, rest) = split_netstring(&req);
        assert!(header_block.starts_with(b"CONTENT_LENGTH\0"));
        assert!(rest.is_empty());
        assert_eq!(req.last(), Some(&b','));
    }

    #[test]
    fn build_scgi_request_preserves_pair_order() {
        let e = env(&[("AAA", "1"), ("BBB", "2"), ("CCC", "3")]);
        let req = build_scgi_request(&e, b"");
        let (header_block, _) = split_netstring(&req);
        let names: Vec<String> = header_pairs(header_block)
            .into_iter()
            .map(|(k, _)| k)
            .filter(|k| ["AAA", "BBB", "CCC"].contains(&k.as_str()))
            .collect();
        assert_eq!(names, ["AAA", "BBB", "CCC"], "pair order not preserved");
    }
}