cgi_service/
lib.rs

1//! A Tower service that implements the CGI protocol (RFC 3875).
2//!
3//! # Example
4//!
5//! ```no_run
6//! use axum::{Router, routing::any_service};
7//! use cgi_service::CgiService;
8//!
9//! let app: Router = Router::new().route(
10//!     "/",
11//!     any_service(CgiService::new("/usr/lib/cgi-bin/script")),
12//! );
13//! ```
14
15use std::collections::HashMap;
16use std::convert::Infallible;
17use std::future::Future;
18use std::pin::Pin;
19use std::task::{Context, Poll};
20
21use bytes::Bytes;
22use http::{Request, Response, StatusCode};
23use http_body::{Body, Frame};
24use http_body_util::{BodyExt, Full, combinators::BoxBody};
25use pin_project_lite::pin_project;
26use thiserror::Error;
27use tokio::io::{AsyncReadExt, AsyncWriteExt};
28use tokio::process::{Child, ChildStdin, ChildStdout, Command};
29use tokio::sync::mpsc;
30use tower_service::Service;
31
/// Errors that can occur while running a CGI script for a request.
///
/// The service converts every variant into a `500 Internal Server Error` response whose
/// body is the variant's `Display` text (see `CgiError::into_response`).
#[derive(Error, Debug)]
pub enum CgiError {
    /// An I/O error. Produced when spawning the process fails, and also by any other
    /// `std::io::Error` converted into `CgiError` (e.g. failures reading the script's stdout).
    #[error("Failed to spawn CGI process: {0}")]
    SpawnError(#[from] std::io::Error),

    /// The script produced no output and had already exited with this non-zero exit code.
    #[error("CGI process returned non-zero exit code: {0}")]
    ProcessError(i32),

    /// The script's output, or the response built from it, could not be interpreted
    /// (non-UTF-8 headers, no output at all, oversized headers, response build failure).
    #[error("Failed to parse CGI response: {0}")]
    ParseError(String),

    /// A header line had an invalid name or value, or the `Status` field was not a valid code.
    #[error("Invalid CGI response header: {0}")]
    InvalidHeader(String),

    /// The script produced no output and terminated without an exit code
    /// (`ExitStatus::code()` returned `None`, e.g. killed by a signal on Unix).
    #[error("CGI process was killed by signal")]
    ProcessKilled,

    /// The configured timeout elapsed while waiting for the response headers.
    /// The payload is the timeout in milliseconds.
    #[error("CGI process timed out after {0}ms")]
    Timeout(u64),
}
52
53impl CgiError {
54    fn into_response(self) -> Response<BoxBody<Bytes, Infallible>> {
55        Response::builder()
56            .status(StatusCode::INTERNAL_SERVER_ERROR)
57            .body(BoxBody::new(
58                Full::new(Bytes::from(self.to_string())).map_err(|_: Infallible| unreachable!()),
59            ))
60            .unwrap()
61    }
62}
63
/// Upper bound (64 KiB) on buffered CGI response headers; reading gives up beyond this.
const MAX_HEADER_SIZE: usize = 1 << 16;
/// Number of bytes requested per read of the script's stdout while streaming (8 KiB).
const BODY_CHUNK_SIZE: usize = 1 << 13;
/// How many body chunks may be queued between the stdout reader and the response body.
const CHANNEL_BUFFER_SIZE: usize = 16;
70
pin_project! {
    /// A streaming body that reads from a CGI process stdout via a channel.
    ///
    /// The sending half is fed by the background task that reads the script's stdout;
    /// the body ends when that channel closes or yields an error.
    struct CgiStreamBody {
        // Chunks of the script's stdout (or a read error) sent by the reader task.
        rx: mpsc::Receiver<Result<Bytes, std::io::Error>>,
        // Set once end-of-stream has been reported; later polls return `None` immediately.
        done: bool,
    }
}
78
79impl CgiStreamBody {
80    fn new(rx: mpsc::Receiver<Result<Bytes, std::io::Error>>) -> Self {
81        Self { rx, done: false }
82    }
83}
84
85impl Body for CgiStreamBody {
86    type Data = Bytes;
87    type Error = Infallible;
88
89    fn poll_frame(
90        self: Pin<&mut Self>,
91        cx: &mut Context<'_>,
92    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
93        let this = self.project();
94
95        if *this.done {
96            return Poll::Ready(None);
97        }
98
99        match this.rx.poll_recv(cx) {
100            Poll::Ready(Some(Ok(bytes))) if !bytes.is_empty() => {
101                Poll::Ready(Some(Ok(Frame::data(bytes))))
102            }
103            Poll::Pending => Poll::Pending,
104            _ => {
105                *this.done = true;
106                Poll::Ready(None)
107            }
108        }
109    }
110
111    fn is_end_stream(&self) -> bool {
112        self.done
113    }
114}
115
/// Configuration for a CGI service.
///
/// Use the builder methods to customize the CGI execution environment.
#[derive(Debug, Clone)]
pub struct CgiConfig {
    /// Path of the executable run for each request.
    command: String,
    /// Command-line arguments passed to the executable.
    args: Vec<String>,
    /// Working directory for the process; `None` keeps the server's current directory.
    working_dir: Option<String>,
    /// Extra environment variables; applied last, so they override computed meta-variables.
    extra_env: HashMap<String, String>,
    /// Value of the SERVER_SOFTWARE meta-variable.
    server_software: String,
    /// Value of the SERVER_NAME meta-variable.
    server_name: String,
    /// Value of the SERVER_PORT meta-variable.
    server_port: u16,
    /// Timeout in milliseconds (see [`CgiConfig::timeout`]); `None` means no limit.
    timeout_ms: Option<u64>,
    /// Names of variables copied from this process's environment when they are set.
    inherit_env: Vec<String>,
    /// Fixed SCRIPT_NAME; when set, PATH_INFO is derived from the request path.
    script_name: Option<String>,
    /// Document root used to compute PATH_TRANSLATED.
    document_root: Option<String>,
    /// Whether the script's stderr is passed through to the parent process.
    pass_through_stderr: bool,
    /// Whether the HTTPS meta-variable is set to "on".
    https: bool,
}

impl CgiConfig {
    /// Creates a new configuration with the given command path.
    pub fn new(command: impl Into<String>) -> Self {
        Self {
            command: command.into(),
            args: Vec::new(),
            working_dir: None,
            extra_env: HashMap::new(),
            server_software: "rust-cgi-server/0.1".to_string(),
            server_name: "localhost".to_string(),
            server_port: 80,
            timeout_ms: None,
            inherit_env: Vec::new(),
            script_name: None,
            document_root: None,
            pass_through_stderr: true,
            https: false,
        }
    }

    /// Sets command-line arguments to pass to the CGI script.
    pub fn args(mut self, args: Vec<String>) -> Self {
        self.args = args;
        self
    }

    /// Sets the working directory for the CGI process.
    pub fn working_dir(mut self, dir: impl Into<String>) -> Self {
        self.working_dir = Some(dir.into());
        self
    }

    /// Adds an environment variable to pass to the CGI process.
    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.extra_env.insert(key.into(), value.into());
        self
    }

    /// Sets the SERVER_SOFTWARE meta-variable.
    pub fn server_software(mut self, software: impl Into<String>) -> Self {
        self.server_software = software.into();
        self
    }

    /// Sets the SERVER_NAME meta-variable.
    pub fn server_name(mut self, name: impl Into<String>) -> Self {
        self.server_name = name.into();
        self
    }

    /// Sets the SERVER_PORT meta-variable.
    pub fn server_port(mut self, port: u16) -> Self {
        self.server_port = port;
        self
    }

    /// Set whether the request was received over HTTPS.
    /// When true, sets the HTTPS environment variable to "on".
    pub fn https(mut self, https: bool) -> Self {
        self.https = https;
        self
    }

    /// Sets a timeout for CGI script execution.
    ///
    /// The limit is applied separately to two steps: writing the request body to the
    /// script's stdin, and waiting for the script's response headers. Once the headers
    /// have been received, streaming the response body is not time-limited.
    ///
    /// The duration is stored in whole milliseconds, rounded up so that it is never shorter
    /// than requested. It saturates at `u64::MAX` milliseconds.
    pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
        // Truncating would turn e.g. 500µs into a 0 ms (immediately expiring) timeout, and
        // `as u64` would silently wrap huge durations, so round up and saturate instead.
        let has_partial_ms = timeout.subsec_nanos() % 1_000_000 != 0;
        let millis = timeout.as_millis() + u128::from(has_partial_ms);
        self.timeout_ms = Some(u64::try_from(millis).unwrap_or(u64::MAX));
        self
    }

    /// Sets environment variables to inherit from the parent process.
    pub fn inherit_env(mut self, vars: Vec<String>) -> Self {
        self.inherit_env = vars;
        self
    }

    /// Adds a single environment variable to inherit from the parent process.
    pub fn inherit(mut self, var: impl Into<String>) -> Self {
        self.inherit_env.push(var.into());
        self
    }

    /// When set, PATH_INFO will be computed as the request path minus this script name.
    pub fn script_name(mut self, name: impl Into<String>) -> Self {
        self.script_name = Some(name.into());
        self
    }

    /// Set the document root for computing PATH_TRANSLATED.
    /// When PATH_INFO is non-empty, PATH_TRANSLATED = document_root + PATH_INFO.
    pub fn document_root(mut self, root: impl Into<String>) -> Self {
        self.document_root = Some(root.into());
        self
    }

    /// Sets whether to pass through CGI process stderr to the parent process.
    pub fn pass_through_stderr(mut self, pass_through: bool) -> Self {
        self.pass_through_stderr = pass_through;
        self
    }
}
236
237fn build_cgi_env<B>(
238    request: &Request<B>,
239    config: &CgiConfig,
240    remote_addr: Option<&str>,
241) -> HashMap<String, String> {
242    let mut env = HashMap::new();
243
244    // Required CGI meta-variables (RFC 3875 Section 4.1)
245    env.insert("GATEWAY_INTERFACE".to_string(), "CGI/1.1".to_string());
246    env.insert("REQUEST_METHOD".to_string(), request.method().to_string());
247    env.insert("SERVER_NAME".to_string(), config.server_name.clone());
248    env.insert("SERVER_PORT".to_string(), config.server_port.to_string());
249    env.insert(
250        "SERVER_PROTOCOL".to_string(),
251        format!("{:?}", request.version()),
252    );
253    env.insert(
254        "SERVER_SOFTWARE".to_string(),
255        config.server_software.clone(),
256    );
257
258    let path = request.uri().path();
259    let (script_name, path_info) = if let Some(ref sn) = config.script_name {
260        let path_info = if path.starts_with(sn) {
261            path[sn.len()..].to_string()
262        } else {
263            String::new()
264        };
265        (sn.clone(), path_info)
266    } else {
267        (path.to_string(), String::new())
268    };
269
270    env.insert("SCRIPT_NAME".to_string(), script_name);
271    env.insert("PATH_INFO".to_string(), path_info.clone());
272
273    // PATH_TRANSLATED: physical path corresponding to PATH_INFO (RFC 3875 Section 4.1.6)
274    if !path_info.is_empty()
275        && let Some(ref doc_root) = config.document_root
276    {
277        let path_translated = format!("{}{}", doc_root, path_info);
278        env.insert("PATH_TRANSLATED".to_string(), path_translated);
279    }
280
281    // Query string (empty string if none, per spec)
282    env.insert(
283        "QUERY_STRING".to_string(),
284        request.uri().query().unwrap_or("").to_string(),
285    );
286
287    env.insert(
288        "REQUEST_URI".to_string(),
289        request
290            .uri()
291            .path_and_query()
292            .map(|pq| pq.to_string())
293            .unwrap_or_else(|| path.to_string()),
294    );
295
296    // REMOTE_ADDR is required per RFC 3875 Section 4.1.8
297    // Default to 127.0.0.1 if not explicitly set
298    env.insert(
299        "REMOTE_ADDR".to_string(),
300        remote_addr.unwrap_or("127.0.0.1").to_string(),
301    );
302
303    // HTTPS meta-variable indicates secure connection
304    if config.https {
305        env.insert("HTTPS".to_string(), "on".to_string());
306    }
307
308    if let Some(content_type) = request.headers().get(http::header::CONTENT_TYPE)
309        && let Ok(ct) = content_type.to_str()
310    {
311        env.insert("CONTENT_TYPE".to_string(), ct.to_string());
312    }
313
314    if let Some(content_length) = request.headers().get(http::header::CONTENT_LENGTH)
315        && let Ok(cl) = content_length.to_str()
316    {
317        env.insert("CONTENT_LENGTH".to_string(), cl.to_string());
318    }
319
320    for var_name in &config.inherit_env {
321        if let Ok(value) = std::env::var(var_name) {
322            env.insert(var_name.clone(), value);
323        }
324    }
325
326    // Convert HTTP headers to HTTP_* environment variables
327    // RFC 3875 Section 4.1.18: Exclude certain headers for security/protocol reasons
328    for (name, value) in request.headers() {
329        // Skip Content-Type and Content-Length (handled separately as CGI vars)
330        if name == http::header::CONTENT_TYPE || name == http::header::CONTENT_LENGTH {
331            continue;
332        }
333
334        // Skip Authorization headers for security (RFC 3875 recommendation)
335        if name == http::header::AUTHORIZATION || name == http::header::PROXY_AUTHORIZATION {
336            continue;
337        }
338
339        // Skip hop-by-hop headers that shouldn't be passed to CGI
340        if name == http::header::CONNECTION
341            || name.as_str().eq_ignore_ascii_case("keep-alive")
342            || name == http::header::TRANSFER_ENCODING
343            || name == http::header::TE
344            || name == http::header::TRAILER
345            || name == http::header::UPGRADE
346        {
347            continue;
348        }
349
350        let env_name = format!("HTTP_{}", name.as_str().to_uppercase().replace('-', "_"));
351
352        if let Ok(v) = value.to_str() {
353            env.insert(env_name, v.to_string());
354        }
355    }
356
357    for (key, value) in &config.extra_env {
358        env.insert(key.clone(), value.clone());
359    }
360
361    env
362}
363
/// Header/body separators, in the order they are tried at each byte position:
///
/// - CRLF line endings: `\r\n\r\n`.
/// - LF header lines followed by a CRLF blank line: `\n\r\n`.
/// - LF line endings: `\n\n`.
///
/// The two patterns starting with `\n` differ in their second byte, so at most one of them
/// can match at a given position.
const HEADER_SEPARATORS: [&[u8]; 3] = [b"\r\n\r\n", b"\n\r\n", b"\n\n"];

/// Find where the blank line that ends the CGI header section begins.
///
/// Returns the index of the first separator in `data`, so `data[..pos]` is the header
/// section. Returns `None` if no complete separator has been received yet. A partial
/// separator at the end of `data` is not matched.
fn find_header_body_separator(data: &[u8]) -> Option<usize> {
    (0..data.len()).find(|&i| {
        let rest = &data[i..];
        HEADER_SEPARATORS.iter().any(|&sep| rest.starts_with(sep))
    })
}

/// Detect the length of the header/body separator at the given position.
///
/// `pos` is expected to come from [`find_header_body_separator`]. If no known separator
/// starts at `pos`, 2 is returned as a fallback.
///
/// # Panics
///
/// Panics if `pos > data.len()`.
fn separator_length_at(data: &[u8], pos: usize) -> usize {
    let rest = &data[pos..];
    HEADER_SEPARATORS
        .iter()
        .find(|&&sep| rest.starts_with(sep))
        .map_or(2, |sep| sep.len())
}
384
/// Parsed CGI headers: the response status plus the remaining header fields.
struct ParsedHeaders {
    /// Response status, taken from the script's `Status` field when present.
    status: StatusCode,
    /// Every header field the script emitted other than `Status`.
    headers: http::HeaderMap,
}
390
391/// Parse CGI headers from raw bytes (without body)
392fn parse_cgi_headers(header_bytes: &[u8]) -> Result<ParsedHeaders, CgiError> {
393    let header_str = std::str::from_utf8(header_bytes)
394        .map_err(|e| CgiError::ParseError(format!("Invalid UTF-8 in headers: {}", e)))?;
395
396    let mut status = StatusCode::OK;
397    let mut headers = http::HeaderMap::new();
398
399    for line in header_str.lines() {
400        let line = line.trim();
401        if line.is_empty() {
402            continue;
403        }
404
405        if let Some((name, value)) = line.split_once(':') {
406            let name = name.trim();
407            let value = value.trim();
408
409            if name.eq_ignore_ascii_case("Status") {
410                let code_str = value.split_once(' ').map_or(value, |(code, _)| code);
411                status = code_str
412                    .parse::<u16>()
413                    .map_err(|_| {
414                        CgiError::InvalidHeader(format!("Invalid status code: {}", code_str))
415                    })
416                    .and_then(|code| {
417                        StatusCode::from_u16(code).map_err(|_| {
418                            CgiError::InvalidHeader(format!("Invalid status code: {}", code))
419                        })
420                    })?;
421            } else {
422                let header_name = http::header::HeaderName::from_bytes(name.as_bytes())
423                    .map_err(|e| CgiError::InvalidHeader(format!("Invalid header name: {}", e)))?;
424                let header_value = http::header::HeaderValue::from_str(value)
425                    .map_err(|e| CgiError::InvalidHeader(format!("Invalid header value: {}", e)))?;
426                headers.append(header_name, header_value);
427            }
428        }
429    }
430
431    Ok(ParsedHeaders { status, headers })
432}
433
434/// Write request body chunks to CGI process stdin
435async fn write_request_body<B>(mut stdin: ChildStdin, body: B, timeout: Option<std::time::Duration>)
436where
437    B: Body + Send,
438    B::Data: Send,
439    B::Error: std::fmt::Debug,
440{
441    use bytes::Buf;
442
443    let write_future = async {
444        let mut body = std::pin::pin!(body);
445        loop {
446            match std::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
447                Some(Ok(frame)) => {
448                    if let Ok(data) = frame.into_data()
449                        && let Err(e) = stdin.write_all(data.chunk()).await
450                    {
451                        // Process may have closed stdin early - this is OK
452                        // Some CGI scripts don't read the full body
453                        if e.kind() != std::io::ErrorKind::BrokenPipe {
454                            eprintln!("CGI stdin write error: {}", e);
455                        }
456                        break;
457                    }
458                }
459                Some(Err(e)) => {
460                    eprintln!("Error reading request body: {:?}", e);
461                    break;
462                }
463                None => break,
464            }
465        }
466        // Drop stdin to signal EOF to the CGI process
467        drop(stdin);
468    };
469
470    if let Some(timeout_duration) = timeout {
471        if tokio::time::timeout(timeout_duration, write_future)
472            .await
473            .is_err()
474        {
475            eprintln!("Timeout writing request body to CGI stdin");
476        }
477    } else {
478        write_future.await;
479    }
480}
481
482/// Read stdout until headers are complete, then return parsed headers and any body bytes read.
483/// Also takes the child process to check exit status if no output is received.
484async fn read_and_parse_headers(
485    stdout: &mut ChildStdout,
486    child: &mut Child,
487    timeout: Option<std::time::Duration>,
488) -> Result<(ParsedHeaders, Vec<u8>), CgiError> {
489    let read_future = async {
490        let mut header_buf = Vec::with_capacity(4096);
491        let mut temp_buf = [0u8; 1024];
492
493        loop {
494            let n = stdout
495                .read(&mut temp_buf)
496                .await
497                .map_err(CgiError::SpawnError)?;
498
499            if n == 0 {
500                // EOF before headers complete
501                // If no data was received at all, check if the process failed
502                if header_buf.is_empty() {
503                    // Try to get process exit status
504                    if let Ok(Some(status)) = child.try_wait()
505                        && !status.success()
506                    {
507                        return Err(status
508                            .code()
509                            .map_or(CgiError::ProcessKilled, CgiError::ProcessError));
510                    }
511                    // Process exited successfully but produced no output
512                    return Err(CgiError::ParseError(
513                        "CGI script produced no output".to_string(),
514                    ));
515                }
516                // Some data was received but no header separator - treat as headers only
517                let parsed = parse_cgi_headers(&header_buf)?;
518                return Ok((parsed, Vec::new()));
519            }
520
521            header_buf.extend_from_slice(&temp_buf[..n]);
522
523            if let Some(pos) = find_header_body_separator(&header_buf) {
524                let sep_len = separator_length_at(&header_buf, pos);
525                let body_start = pos + sep_len;
526
527                let parsed = parse_cgi_headers(&header_buf[..pos])?;
528                let initial_body = header_buf[body_start..].to_vec();
529
530                return Ok((parsed, initial_body));
531            }
532
533            // Prevent unbounded header growth
534            if header_buf.len() > MAX_HEADER_SIZE {
535                return Err(CgiError::ParseError(format!(
536                    "Headers exceed maximum size of {} bytes",
537                    MAX_HEADER_SIZE
538                )));
539            }
540        }
541    };
542
543    if let Some(timeout_duration) = timeout {
544        tokio::time::timeout(timeout_duration, read_future)
545            .await
546            .map_err(|_| CgiError::Timeout(timeout_duration.as_millis() as u64))?
547    } else {
548        read_future.await
549    }
550}
551
552/// Stream stdout to the response body channel
553async fn stream_stdout(
554    mut stdout: ChildStdout,
555    tx: mpsc::Sender<Result<Bytes, std::io::Error>>,
556    initial_body: Vec<u8>,
557    mut child: Child,
558) {
559    if !initial_body.is_empty() && tx.send(Ok(Bytes::from(initial_body))).await.is_err() {
560        return;
561    }
562
563    let mut buf = vec![0u8; BODY_CHUNK_SIZE];
564    loop {
565        match stdout.read(&mut buf).await {
566            Ok(0) => break,
567            Ok(n) => {
568                let chunk = Bytes::copy_from_slice(&buf[..n]);
569                if tx.send(Ok(chunk)).await.is_err() {
570                    break;
571                }
572            }
573            Err(e) => {
574                let _ = tx.send(Err(e)).await;
575                break;
576            }
577        }
578    }
579
580    drop(stdout);
581    drop(tx);
582
583    match child.wait().await {
584        Ok(status) if !status.success() => match status.code() {
585            Some(code) => eprintln!("CGI process exited with code: {code}"),
586            None => eprintln!("CGI process was killed by signal"),
587        },
588        Err(e) => eprintln!("Error waiting for CGI process: {e}"),
589        _ => {}
590    }
591}
592
593/// Execute CGI script with streaming request and response bodies
594async fn execute_cgi<B>(
595    config: &CgiConfig,
596    env: HashMap<String, String>,
597    body: B,
598) -> Result<Response<BoxBody<Bytes, Infallible>>, CgiError>
599where
600    B: Body + Send + 'static,
601    B::Data: Send,
602    B::Error: std::fmt::Debug + Send,
603{
604    use std::time::Duration;
605
606    let mut cmd = Command::new(&config.command);
607
608    if !config.args.is_empty() {
609        cmd.args(&config.args);
610    }
611
612    if let Some(ref dir) = config.working_dir {
613        cmd.current_dir(dir);
614    }
615
616    cmd.env_clear();
617    for (key, value) in &env {
618        cmd.env(key, value);
619    }
620
621    cmd.stdin(std::process::Stdio::piped());
622    cmd.stdout(std::process::Stdio::piped());
623    cmd.stderr(std::process::Stdio::inherit()); // Pass through stderr directly
624
625    let mut child = cmd.spawn()?;
626    let timeout = config.timeout_ms.map(Duration::from_millis);
627
628    let stdin = child.stdin.take().unwrap();
629    tokio::spawn(write_request_body(stdin, body, timeout));
630
631    let mut stdout = child.stdout.take().unwrap();
632
633    // Pass child reference to check exit status if no output is received
634    let (parsed_headers, initial_body) =
635        read_and_parse_headers(&mut stdout, &mut child, timeout).await?;
636
637    let (tx, rx) = mpsc::channel(CHANNEL_BUFFER_SIZE);
638    tokio::spawn(stream_stdout(stdout, tx, initial_body, child));
639
640    let mut response_builder = Response::builder().status(parsed_headers.status);
641    for (name, value) in parsed_headers.headers.iter() {
642        response_builder = response_builder.header(name, value);
643    }
644
645    response_builder
646        .body(BoxBody::new(CgiStreamBody::new(rx)))
647        .map_err(|e| CgiError::ParseError(format!("Failed to build response: {}", e)))
648}
649
650/// A Tower service that executes CGI scripts.
651///
652/// This service implements [RFC 3875](https://www.rfc-editor.org/rfc/rfc3875) and can be used
653/// with any Tower-compatible HTTP server such as Axum or Hyper.
654#[derive(Clone)]
655pub struct CgiService {
656    config: CgiConfig,
657    remote_addr: Option<String>,
658}
659
660impl CgiService {
661    /// Creates a new CGI service with the given command path.
662    pub fn new(command: impl Into<String>) -> Self {
663        Self {
664            config: CgiConfig::new(command),
665            remote_addr: None,
666        }
667    }
668
669    /// Creates a new CGI service with the given configuration.
670    pub fn with_config(config: CgiConfig) -> Self {
671        Self {
672            config,
673            remote_addr: None,
674        }
675    }
676
677    /// Sets the REMOTE_ADDR meta-variable for requests handled by this service.
678    pub fn remote_addr(mut self, addr: impl Into<String>) -> Self {
679        self.remote_addr = Some(addr.into());
680        self
681    }
682}
683
684impl<B> Service<Request<B>> for CgiService
685where
686    B: Body + Send + 'static,
687    B::Data: Send,
688    B::Error: std::fmt::Debug + Send,
689{
690    type Response = Response<BoxBody<Bytes, Infallible>>;
691    type Error = Infallible;
692    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
693
694    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
695        Poll::Ready(Ok(()))
696    }
697
698    fn call(&mut self, request: Request<B>) -> Self::Future {
699        let config = self.config.clone();
700        let remote_addr = self.remote_addr.clone();
701
702        Box::pin(async move {
703            let env = build_cgi_env(&request, &config, remote_addr.as_deref());
704            let (_parts, body) = request.into_parts();
705            Ok(execute_cgi(&config, env, body)
706                .await
707                .unwrap_or_else(|e| e.into_response()))
708        })
709    }
710}