cgi_service/
lib.rs

1//! A Tower service that implements the CGI protocol (RFC 3875).
2//!
3//! # Example
4//!
5//! ```no_run
6//! use axum::{Router, routing::any_service};
7//! use cgi_service::CgiService;
8//!
9//! let app: Router = Router::new().route(
10//!     "/",
11//!     any_service(CgiService::new("/usr/lib/cgi-bin/script")),
12//! );
13//! ```
14
15use std::collections::HashMap;
16use std::convert::Infallible;
17use std::future::Future;
18use std::pin::Pin;
19use std::task::{Context, Poll};
20
21use bytes::Bytes;
22use http::{Request, Response, StatusCode};
23use http_body::{Body, Frame};
24use http_body_util::{BodyExt, Full, combinators::BoxBody};
25use pin_project_lite::pin_project;
26use thiserror::Error;
27use tokio::io::{AsyncReadExt, AsyncWriteExt};
28use tokio::process::{Child, ChildStdin, ChildStdout, Command};
29use tokio::sync::mpsc;
30use tower_service::Service;
31
32#[derive(Error, Debug)]
33pub enum CgiError {
34    #[error("Failed to spawn CGI process: {0}")]
35    SpawnError(#[from] std::io::Error),
36
37    #[error("CGI process returned non-zero exit code: {0}")]
38    ProcessError(i32),
39
40    #[error("Failed to parse CGI response: {0}")]
41    ParseError(String),
42
43    #[error("Invalid CGI response header: {0}")]
44    InvalidHeader(String),
45
46    #[error("CGI process was killed by signal")]
47    ProcessKilled,
48
49    #[error("CGI process timed out after {0}ms")]
50    Timeout(u64),
51}
52
53impl CgiError {
54    fn into_response(self) -> Response<BoxBody<Bytes, Infallible>> {
55        Response::builder()
56            .status(StatusCode::INTERNAL_SERVER_ERROR)
57            .body(BoxBody::new(
58                Full::new(Bytes::from(self.to_string())).map_err(|_: Infallible| unreachable!()),
59            ))
60            .unwrap()
61    }
62}
63
/// Upper bound (64 KiB) on buffered CGI response header bytes; larger header blocks are rejected.
const MAX_HEADER_SIZE: usize = 64 << 10;
/// Size (8 KiB) of the buffer used for each read of the response body.
const BODY_CHUNK_SIZE: usize = 8 << 10;
/// Capacity, in chunks, of the bounded channel carrying the response body.
const CHANNEL_BUFFER_SIZE: usize = 16;
70
71pin_project! {
72    /// A streaming body that reads from a CGI process stdout via a channel.
73    pub struct CgiStreamBody {
74        rx: mpsc::Receiver<Result<Bytes, std::io::Error>>,
75        done: bool,
76    }
77}
78
79impl CgiStreamBody {
80    fn new(rx: mpsc::Receiver<Result<Bytes, std::io::Error>>) -> Self {
81        Self { rx, done: false }
82    }
83}
84
85impl Body for CgiStreamBody {
86    type Data = Bytes;
87    type Error = Infallible;
88
89    fn poll_frame(
90        self: Pin<&mut Self>,
91        cx: &mut Context<'_>,
92    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
93        let this = self.project();
94
95        if *this.done {
96            return Poll::Ready(None);
97        }
98
99        match this.rx.poll_recv(cx) {
100            Poll::Ready(Some(Ok(bytes))) => {
101                if bytes.is_empty() {
102                    *this.done = true;
103                    Poll::Ready(None)
104                } else {
105                    Poll::Ready(Some(Ok(Frame::data(bytes))))
106                }
107            }
108            Poll::Ready(Some(Err(_))) => {
109                *this.done = true;
110                Poll::Ready(None)
111            }
112            Poll::Ready(None) => {
113                *this.done = true;
114                Poll::Ready(None)
115            }
116            Poll::Pending => Poll::Pending,
117        }
118    }
119
120    fn is_end_stream(&self) -> bool {
121        self.done
122    }
123}
124
/// Configuration describing which program to run as the CGI script and which
/// values to expose to it through CGI meta-variables.
///
/// Built with [`CgiConfig::new`] followed by the chained builder methods.
#[derive(Debug, Clone)]
pub struct CgiConfig {
    /// Program to execute.
    command: String,
    /// Arguments passed to the program.
    args: Vec<String>,
    /// Working directory for the program, if any.
    working_dir: Option<String>,
    /// Extra environment variables; applied last, so they override all others.
    extra_env: HashMap<String, String>,
    /// Value of `SERVER_SOFTWARE`.
    server_software: String,
    /// Value of `SERVER_NAME`.
    server_name: String,
    /// Value of `SERVER_PORT`.
    server_port: u16,
    /// Timeout in milliseconds, if any.
    timeout_ms: Option<u64>,
    /// Names of variables copied from the server's own environment.
    inherit_env: Vec<String>,
    /// Value of `SCRIPT_NAME`; also the prefix stripped from the path to form `PATH_INFO`.
    script_name: Option<String>,
    /// Prefix used to compute `PATH_TRANSLATED`.
    document_root: Option<String>,
    /// Whether the script's stderr should be passed through.
    pass_through_stderr: bool,
    /// Whether the `HTTPS` variable is set to `"on"`.
    https: bool,
}

impl CgiConfig {
    /// Creates a configuration for `command` with defaults: no arguments, no
    /// working directory, no timeout, `SERVER_SOFTWARE` = `rust-cgi-server/0.1`,
    /// `SERVER_NAME` = `localhost`, `SERVER_PORT` = `80`, stderr pass-through
    /// enabled and HTTPS disabled.
    pub fn new(command: impl Into<String>) -> Self {
        Self {
            command: command.into(),
            args: Vec::new(),
            working_dir: None,
            extra_env: HashMap::new(),
            server_software: "rust-cgi-server/0.1".to_string(),
            server_name: "localhost".to_string(),
            server_port: 80,
            timeout_ms: None,
            inherit_env: Vec::new(),
            script_name: None,
            document_root: None,
            pass_through_stderr: true,
            https: false,
        }
    }

    /// Replaces the argument list passed to the program.
    pub fn args(mut self, args: Vec<String>) -> Self {
        self.args = args;
        self
    }

    /// Sets the working directory the program is started in.
    pub fn working_dir(mut self, dir: impl Into<String>) -> Self {
        self.working_dir = Some(dir.into());
        self
    }

    /// Adds (or replaces) one extra environment variable for the program.
    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.extra_env.insert(key.into(), value.into());
        self
    }

    /// Sets the value reported as `SERVER_SOFTWARE`.
    pub fn server_software(mut self, software: impl Into<String>) -> Self {
        self.server_software = software.into();
        self
    }

    /// Sets the value reported as `SERVER_NAME`.
    pub fn server_name(mut self, name: impl Into<String>) -> Self {
        self.server_name = name.into();
        self
    }

    /// Sets the value reported as `SERVER_PORT`.
    pub fn server_port(mut self, port: u16) -> Self {
        self.server_port = port;
        self
    }

    /// Set whether the request was received over HTTPS.
    /// When true, sets the HTTPS environment variable to "on".
    pub fn https(mut self, https: bool) -> Self {
        self.https = https;
        self
    }

    /// Sets the timeout, stored with millisecond resolution.
    ///
    /// Durations whose millisecond count does not fit in a `u64` saturate to
    /// `u64::MAX` instead of wrapping around to a small value.
    pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
        // `as_millis` returns a u128; clamp first so the cast below is lossless.
        let millis = timeout.as_millis().min(u128::from(u64::MAX)) as u64;
        self.timeout_ms = Some(millis);
        self
    }

    /// Replaces the list of variable names inherited from the server's environment.
    pub fn inherit_env(mut self, vars: Vec<String>) -> Self {
        self.inherit_env = vars;
        self
    }

    /// Appends one variable name to the inherited-environment list.
    pub fn inherit(mut self, var: impl Into<String>) -> Self {
        self.inherit_env.push(var.into());
        self
    }

    /// When set, PATH_INFO will be computed as the request path minus this script name.
    pub fn script_name(mut self, name: impl Into<String>) -> Self {
        self.script_name = Some(name.into());
        self
    }

    /// Set the document root for computing PATH_TRANSLATED.
    /// When PATH_INFO is non-empty, PATH_TRANSLATED = document_root + PATH_INFO.
    pub fn document_root(mut self, root: impl Into<String>) -> Self {
        self.document_root = Some(root.into());
        self
    }

    /// Records whether the script's stderr should be passed through (default: `true`).
    pub fn pass_through_stderr(mut self, pass_through: bool) -> Self {
        self.pass_through_stderr = pass_through;
        self
    }
}
231
232fn build_cgi_env<B>(
233    request: &Request<B>,
234    config: &CgiConfig,
235    remote_addr: Option<&str>,
236) -> HashMap<String, String> {
237    let mut env = HashMap::new();
238
239    // Required CGI meta-variables (RFC 3875 Section 4.1)
240    env.insert("GATEWAY_INTERFACE".to_string(), "CGI/1.1".to_string());
241    env.insert("REQUEST_METHOD".to_string(), request.method().to_string());
242    env.insert("SERVER_NAME".to_string(), config.server_name.clone());
243    env.insert("SERVER_PORT".to_string(), config.server_port.to_string());
244    env.insert(
245        "SERVER_PROTOCOL".to_string(),
246        format!("{:?}", request.version()),
247    );
248    env.insert(
249        "SERVER_SOFTWARE".to_string(),
250        config.server_software.clone(),
251    );
252
253    let path = request.uri().path();
254    let (script_name, path_info) = if let Some(ref sn) = config.script_name {
255        let path_info = if path.starts_with(sn) {
256            path[sn.len()..].to_string()
257        } else {
258            String::new()
259        };
260        (sn.clone(), path_info)
261    } else {
262        (path.to_string(), String::new())
263    };
264
265    env.insert("SCRIPT_NAME".to_string(), script_name);
266    env.insert("PATH_INFO".to_string(), path_info.clone());
267
268    // PATH_TRANSLATED: physical path corresponding to PATH_INFO (RFC 3875 Section 4.1.6)
269    if !path_info.is_empty() {
270        if let Some(ref doc_root) = config.document_root {
271            let path_translated = format!("{}{}", doc_root, path_info);
272            env.insert("PATH_TRANSLATED".to_string(), path_translated);
273        }
274    }
275
276    // Query string (empty string if none, per spec)
277    env.insert(
278        "QUERY_STRING".to_string(),
279        request.uri().query().unwrap_or("").to_string(),
280    );
281
282    env.insert(
283        "REQUEST_URI".to_string(),
284        request
285            .uri()
286            .path_and_query()
287            .map(|pq| pq.to_string())
288            .unwrap_or_else(|| path.to_string()),
289    );
290
291    // REMOTE_ADDR is required per RFC 3875 Section 4.1.8
292    // Default to 127.0.0.1 if not explicitly set
293    env.insert(
294        "REMOTE_ADDR".to_string(),
295        remote_addr.unwrap_or("127.0.0.1").to_string(),
296    );
297
298    // HTTPS meta-variable indicates secure connection
299    if config.https {
300        env.insert("HTTPS".to_string(), "on".to_string());
301    }
302
303    if let Some(content_type) = request.headers().get(http::header::CONTENT_TYPE) {
304        if let Ok(ct) = content_type.to_str() {
305            env.insert("CONTENT_TYPE".to_string(), ct.to_string());
306        }
307    }
308
309    if let Some(content_length) = request.headers().get(http::header::CONTENT_LENGTH) {
310        if let Ok(cl) = content_length.to_str() {
311            env.insert("CONTENT_LENGTH".to_string(), cl.to_string());
312        }
313    }
314
315    for var_name in &config.inherit_env {
316        if let Ok(value) = std::env::var(var_name) {
317            env.insert(var_name.clone(), value);
318        }
319    }
320
321    // Convert HTTP headers to HTTP_* environment variables
322    // RFC 3875 Section 4.1.18: Exclude certain headers for security/protocol reasons
323    for (name, value) in request.headers() {
324        // Skip Content-Type and Content-Length (handled separately as CGI vars)
325        if name == http::header::CONTENT_TYPE || name == http::header::CONTENT_LENGTH {
326            continue;
327        }
328
329        // Skip Authorization headers for security (RFC 3875 recommendation)
330        if name == http::header::AUTHORIZATION || name == http::header::PROXY_AUTHORIZATION {
331            continue;
332        }
333
334        // Skip hop-by-hop headers that shouldn't be passed to CGI
335        if name == http::header::CONNECTION
336            || name.as_str().eq_ignore_ascii_case("keep-alive")
337            || name == http::header::TRANSFER_ENCODING
338            || name == http::header::TE
339            || name == http::header::TRAILER
340            || name == http::header::UPGRADE
341        {
342            continue;
343        }
344
345        let env_name = format!("HTTP_{}", name.as_str().to_uppercase().replace('-', "_"));
346
347        if let Ok(v) = value.to_str() {
348            env.insert(env_name, v.to_string());
349        }
350    }
351
352    for (key, value) in &config.extra_env {
353        env.insert(key.clone(), value.clone());
354    }
355
356    env
357}
358
/// Returns the index of the first blank-line separator in `data`, if any.
///
/// A separator is either `\r\n\r\n` or `\n\n`; whichever starts at the
/// smallest offset wins. The returned index is where the separator begins.
fn find_header_body_separator(data: &[u8]) -> Option<usize> {
    (0..data.len()).find(|&start| {
        let rest = &data[start..];
        rest.starts_with(b"\r\n\r\n") || rest.starts_with(b"\n\n")
    })
}
370
/// Length in bytes of the header/body separator beginning at `pos`.
///
/// Returns 4 when `data[pos..]` starts with `\r\n\r\n`, and 2 otherwise
/// (covering `\n\n` as well as the fallback case). Panics if `pos` is
/// greater than `data.len()`.
fn separator_length_at(data: &[u8], pos: usize) -> usize {
    let tail = &data[pos..];
    match tail {
        [b'\r', b'\n', b'\r', b'\n', ..] => 4,
        _ => 2,
    }
}
381
382/// Parsed CGI headers
383struct ParsedHeaders {
384    status: StatusCode,
385    headers: http::HeaderMap,
386}
387
388/// Parse CGI headers from raw bytes (without body)
389fn parse_cgi_headers(header_bytes: &[u8]) -> Result<ParsedHeaders, CgiError> {
390    let header_str = std::str::from_utf8(header_bytes)
391        .map_err(|e| CgiError::ParseError(format!("Invalid UTF-8 in headers: {}", e)))?;
392
393    let mut status = StatusCode::OK;
394    let mut headers = http::HeaderMap::new();
395
396    for line in header_str.lines() {
397        let line = line.trim();
398        if line.is_empty() {
399            continue;
400        }
401
402        if let Some((name, value)) = line.split_once(':') {
403            let name = name.trim();
404            let value = value.trim();
405
406            if name.eq_ignore_ascii_case("Status") {
407                let status_parts: Vec<&str> = value.splitn(2, ' ').collect();
408                if let Some(code_str) = status_parts.first() {
409                    status = code_str
410                        .parse::<u16>()
411                        .map_err(|_| {
412                            CgiError::InvalidHeader(format!("Invalid status code: {}", code_str))
413                        })
414                        .and_then(|code| {
415                            StatusCode::from_u16(code).map_err(|_| {
416                                CgiError::InvalidHeader(format!("Invalid status code: {}", code))
417                            })
418                        })?;
419                }
420            } else {
421                let header_name = http::header::HeaderName::from_bytes(name.as_bytes())
422                    .map_err(|e| CgiError::InvalidHeader(format!("Invalid header name: {}", e)))?;
423                let header_value = http::header::HeaderValue::from_str(value)
424                    .map_err(|e| CgiError::InvalidHeader(format!("Invalid header value: {}", e)))?;
425                headers.append(header_name, header_value);
426            }
427        }
428    }
429
430    Ok(ParsedHeaders { status, headers })
431}
432
433/// Write request body chunks to CGI process stdin
434async fn write_request_body<B>(mut stdin: ChildStdin, body: B, timeout: Option<std::time::Duration>)
435where
436    B: Body + Send,
437    B::Data: Send,
438    B::Error: std::fmt::Debug,
439{
440    use bytes::Buf;
441
442    let write_future = async {
443        let mut body = std::pin::pin!(body);
444        loop {
445            match std::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
446                Some(Ok(frame)) => {
447                    if let Ok(data) = frame.into_data() {
448                        if let Err(e) = stdin.write_all(data.chunk()).await {
449                            // Process may have closed stdin early - this is OK
450                            // Some CGI scripts don't read the full body
451                            if e.kind() != std::io::ErrorKind::BrokenPipe {
452                                eprintln!("CGI stdin write error: {}", e);
453                            }
454                            break;
455                        }
456                    }
457                }
458                Some(Err(e)) => {
459                    eprintln!("Error reading request body: {:?}", e);
460                    break;
461                }
462                None => break, // End of body
463            }
464        }
465        // Drop stdin to signal EOF to the CGI process
466        drop(stdin);
467    };
468
469    if let Some(timeout_duration) = timeout {
470        if tokio::time::timeout(timeout_duration, write_future)
471            .await
472            .is_err()
473        {
474            eprintln!("Timeout writing request body to CGI stdin");
475        }
476    } else {
477        write_future.await;
478    }
479}
480
481/// Read stdout until headers are complete, then return parsed headers and any body bytes read.
482/// Also takes the child process to check exit status if no output is received.
483async fn read_and_parse_headers(
484    stdout: &mut ChildStdout,
485    child: &mut Child,
486    timeout: Option<std::time::Duration>,
487) -> Result<(ParsedHeaders, Vec<u8>), CgiError> {
488    let read_future = async {
489        let mut header_buf = Vec::with_capacity(4096);
490        let mut temp_buf = [0u8; 1024];
491
492        loop {
493            let n = stdout
494                .read(&mut temp_buf)
495                .await
496                .map_err(CgiError::SpawnError)?;
497
498            if n == 0 {
499                // EOF before headers complete
500                // If no data was received at all, check if the process failed
501                if header_buf.is_empty() {
502                    // Try to get process exit status
503                    if let Ok(Some(status)) = child.try_wait() {
504                        if !status.success() {
505                            if let Some(code) = status.code() {
506                                return Err(CgiError::ProcessError(code));
507                            } else {
508                                return Err(CgiError::ProcessKilled);
509                            }
510                        }
511                    }
512                    // Process exited successfully but produced no output
513                    return Err(CgiError::ParseError(
514                        "CGI script produced no output".to_string(),
515                    ));
516                }
517                // Some data was received but no header separator - treat as headers only
518                let parsed = parse_cgi_headers(&header_buf)?;
519                return Ok((parsed, Vec::new()));
520            }
521
522            header_buf.extend_from_slice(&temp_buf[..n]);
523
524            // Check for header/body separator
525            if let Some(pos) = find_header_body_separator(&header_buf) {
526                let sep_len = separator_length_at(&header_buf, pos);
527                let body_start = pos + sep_len;
528
529                let parsed = parse_cgi_headers(&header_buf[..pos])?;
530                let initial_body = header_buf[body_start..].to_vec();
531
532                return Ok((parsed, initial_body));
533            }
534
535            // Prevent unbounded header growth
536            if header_buf.len() > MAX_HEADER_SIZE {
537                return Err(CgiError::ParseError(format!(
538                    "Headers exceed maximum size of {} bytes",
539                    MAX_HEADER_SIZE
540                )));
541            }
542        }
543    };
544
545    if let Some(timeout_duration) = timeout {
546        tokio::time::timeout(timeout_duration, read_future)
547            .await
548            .map_err(|_| CgiError::Timeout(timeout_duration.as_millis() as u64))?
549    } else {
550        read_future.await
551    }
552}
553
554/// Stream stdout to the response body channel
555async fn stream_stdout(
556    mut stdout: ChildStdout,
557    tx: mpsc::Sender<Result<Bytes, std::io::Error>>,
558    initial_body: Vec<u8>,
559    mut child: Child,
560    pass_through_stderr: bool,
561) {
562    // First, send any body data that was read during header parsing
563    if !initial_body.is_empty() {
564        if tx.send(Ok(Bytes::from(initial_body))).await.is_err() {
565            return; // Receiver dropped, stop streaming
566        }
567    }
568
569    // Continue reading stdout in chunks
570    let mut buf = vec![0u8; BODY_CHUNK_SIZE];
571    loop {
572        match stdout.read(&mut buf).await {
573            Ok(0) => break, // EOF
574            Ok(n) => {
575                let chunk = Bytes::copy_from_slice(&buf[..n]);
576                if tx.send(Ok(chunk)).await.is_err() {
577                    break; // Receiver dropped
578                }
579            }
580            Err(e) => {
581                let _ = tx.send(Err(e)).await;
582                break;
583            }
584        }
585    }
586
587    // Close our stdout handle before waiting
588    drop(stdout);
589    // Drop the sender to signal end of stream
590    drop(tx);
591
592    // Wait for process to exit and handle stderr
593    match child.wait().await {
594        Ok(status) => {
595            if !status.success() {
596                if let Some(code) = status.code() {
597                    eprintln!("CGI process exited with code: {}", code);
598                } else {
599                    eprintln!("CGI process was killed by signal");
600                }
601            }
602        }
603        Err(e) => {
604            eprintln!("Error waiting for CGI process: {}", e);
605        }
606    }
607
608    // Note: stderr is handled separately via pass_through_stderr in config
609    // In streaming mode, we can't easily capture stderr without risking deadlocks
610    let _ = pass_through_stderr;
611}
612
613/// Execute CGI script with streaming request and response bodies
614async fn execute_cgi<B>(
615    config: &CgiConfig,
616    env: HashMap<String, String>,
617    body: B,
618) -> Result<Response<BoxBody<Bytes, Infallible>>, CgiError>
619where
620    B: Body + Send + 'static,
621    B::Data: Send,
622    B::Error: std::fmt::Debug + Send,
623{
624    use std::time::Duration;
625
626    let mut cmd = Command::new(&config.command);
627
628    if !config.args.is_empty() {
629        cmd.args(&config.args);
630    }
631
632    if let Some(ref dir) = config.working_dir {
633        cmd.current_dir(dir);
634    }
635
636    cmd.env_clear();
637    for (key, value) in &env {
638        cmd.env(key, value);
639    }
640
641    cmd.stdin(std::process::Stdio::piped());
642    cmd.stdout(std::process::Stdio::piped());
643    cmd.stderr(std::process::Stdio::inherit()); // Pass through stderr directly
644
645    let mut child = cmd.spawn()?;
646
647    // Get timeout duration
648    let timeout = config.timeout_ms.map(Duration::from_millis);
649
650    // Take stdin and spawn writer task to stream request body
651    let stdin = child.stdin.take().expect("stdin was configured as piped");
652    tokio::spawn(write_request_body(stdin, body, timeout));
653
654    // Take stdout for reading
655    let mut stdout = child.stdout.take().expect("stdout was configured as piped");
656
657    // Read and parse headers (with timeout)
658    // Pass child reference to check exit status if no output is received
659    let (parsed_headers, initial_body) =
660        read_and_parse_headers(&mut stdout, &mut child, timeout).await?;
661
662    // Create channel for streaming body
663    let (tx, rx) = mpsc::channel(CHANNEL_BUFFER_SIZE);
664
665    // Spawn stdout reader task
666    let pass_through_stderr = config.pass_through_stderr;
667    tokio::spawn(stream_stdout(
668        stdout,
669        tx,
670        initial_body,
671        child,
672        pass_through_stderr,
673    ));
674
675    // Build response with streaming body
676    let mut response_builder = Response::builder().status(parsed_headers.status);
677    for (name, value) in parsed_headers.headers.iter() {
678        response_builder = response_builder.header(name, value);
679    }
680
681    response_builder
682        .body(BoxBody::new(CgiStreamBody::new(rx)))
683        .map_err(|e| CgiError::ParseError(format!("Failed to build response: {}", e)))
684}
685
686#[derive(Clone)]
687pub struct CgiService {
688    config: CgiConfig,
689    remote_addr: Option<String>,
690}
691
692impl CgiService {
693    pub fn new(command: impl Into<String>) -> Self {
694        Self {
695            config: CgiConfig::new(command),
696            remote_addr: None,
697        }
698    }
699
700    pub fn with_config(config: CgiConfig) -> Self {
701        Self {
702            config,
703            remote_addr: None,
704        }
705    }
706
707    pub fn remote_addr(mut self, addr: impl Into<String>) -> Self {
708        self.remote_addr = Some(addr.into());
709        self
710    }
711}
712
713impl<B> Service<Request<B>> for CgiService
714where
715    B: Body + Send + 'static,
716    B::Data: Send,
717    B::Error: std::fmt::Debug + Send,
718{
719    type Response = Response<BoxBody<Bytes, Infallible>>;
720    type Error = Infallible;
721    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
722
723    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
724        Poll::Ready(Ok(()))
725    }
726
727    fn call(&mut self, request: Request<B>) -> Self::Future {
728        let config = self.config.clone();
729        let remote_addr = self.remote_addr.clone();
730
731        Box::pin(async move {
732            // Build environment from request headers (without consuming body)
733            let env = build_cgi_env(&request, &config, remote_addr.as_deref());
734
735            // Extract body for streaming to CGI process
736            let (_parts, body) = request.into_parts();
737
738            // Execute CGI with streaming
739            Ok(execute_cgi(&config, env, body)
740                .await
741                .unwrap_or_else(|e| e.into_response()))
742        })
743    }
744}