1use std::collections::HashMap;
16use std::convert::Infallible;
17use std::future::Future;
18use std::pin::Pin;
19use std::task::{Context, Poll};
20
21use bytes::Bytes;
22use http::{Request, Response, StatusCode};
23use http_body::{Body, Frame};
24use http_body_util::{BodyExt, Full, combinators::BoxBody};
25use pin_project_lite::pin_project;
26use thiserror::Error;
27use tokio::io::{AsyncReadExt, AsyncWriteExt};
28use tokio::process::{Child, ChildStdin, ChildStdout, Command};
29use tokio::sync::mpsc;
30use tower_service::Service;
31
32#[derive(Error, Debug)]
33pub enum CgiError {
34 #[error("Failed to spawn CGI process: {0}")]
35 SpawnError(#[from] std::io::Error),
36
37 #[error("CGI process returned non-zero exit code: {0}")]
38 ProcessError(i32),
39
40 #[error("Failed to parse CGI response: {0}")]
41 ParseError(String),
42
43 #[error("Invalid CGI response header: {0}")]
44 InvalidHeader(String),
45
46 #[error("CGI process was killed by signal")]
47 ProcessKilled,
48
49 #[error("CGI process timed out after {0}ms")]
50 Timeout(u64),
51}
52
53impl CgiError {
54 fn into_response(self) -> Response<BoxBody<Bytes, Infallible>> {
55 Response::builder()
56 .status(StatusCode::INTERNAL_SERVER_ERROR)
57 .body(BoxBody::new(
58 Full::new(Bytes::from(self.to_string())).map_err(|_: Infallible| unreachable!()),
59 ))
60 .unwrap()
61 }
62}
63
/// Largest number of bytes buffered while looking for the end of the CGI header section
/// (64 KiB).
const MAX_HEADER_SIZE: usize = 64 << 10;
/// Size of the buffer used for each read of the CGI response body (8 KiB).
const BODY_CHUNK_SIZE: usize = 8 << 10;
/// Capacity of the channel that carries body chunks to the response body.
const CHANNEL_BUFFER_SIZE: usize = 16;
70
71pin_project! {
72 struct CgiStreamBody {
74 rx: mpsc::Receiver<Result<Bytes, std::io::Error>>,
75 done: bool,
76 }
77}
78
79impl CgiStreamBody {
80 fn new(rx: mpsc::Receiver<Result<Bytes, std::io::Error>>) -> Self {
81 Self { rx, done: false }
82 }
83}
84
85impl Body for CgiStreamBody {
86 type Data = Bytes;
87 type Error = Infallible;
88
89 fn poll_frame(
90 self: Pin<&mut Self>,
91 cx: &mut Context<'_>,
92 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
93 let this = self.project();
94
95 if *this.done {
96 return Poll::Ready(None);
97 }
98
99 match this.rx.poll_recv(cx) {
100 Poll::Ready(Some(Ok(bytes))) if !bytes.is_empty() => {
101 Poll::Ready(Some(Ok(Frame::data(bytes))))
102 }
103 Poll::Pending => Poll::Pending,
104 _ => {
105 *this.done = true;
106 Poll::Ready(None)
107 }
108 }
109 }
110
111 fn is_end_stream(&self) -> bool {
112 self.done
113 }
114}
115
/// Settings describing how a CGI program is launched and which server metadata it receives.
///
/// Create one with [`CgiConfig::new`] and refine it with the chainable setter methods.
#[derive(Debug, Clone)]
pub struct CgiConfig {
    /// Program to execute.
    command: String,
    /// Arguments passed to the program.
    args: Vec<String>,
    /// Working directory for the child process, if any.
    working_dir: Option<String>,
    /// Extra environment variables; inserted last, so they override every other value.
    extra_env: HashMap<String, String>,
    /// Value reported as `SERVER_SOFTWARE`.
    server_software: String,
    /// Value reported as `SERVER_NAME`.
    server_name: String,
    /// Value reported as `SERVER_PORT`.
    server_port: u16,
    /// Time limit in milliseconds, if any.
    timeout_ms: Option<u64>,
    /// Names of variables copied from the server's own environment when they are set.
    inherit_env: Vec<String>,
    /// When set, reported as `SCRIPT_NAME`; the rest of the request path becomes `PATH_INFO`.
    script_name: Option<String>,
    /// Prefix joined with `PATH_INFO` to form `PATH_TRANSLATED`.
    document_root: Option<String>,
    /// Requested stderr pass-through setting for the child process.
    pass_through_stderr: bool,
    /// Whether `HTTPS=on` is exported to the program.
    https: bool,
}

impl CgiConfig {
    /// Creates a configuration for `command` with these defaults: no arguments, no working
    /// directory, server name `localhost`, port `80`, no timeout, stderr pass-through
    /// enabled and HTTPS disabled.
    pub fn new(command: impl Into<String>) -> Self {
        Self {
            command: command.into(),
            args: Vec::new(),
            working_dir: None,
            extra_env: HashMap::new(),
            server_software: "rust-cgi-server/0.1".to_string(),
            server_name: "localhost".to_string(),
            server_port: 80,
            timeout_ms: None,
            inherit_env: Vec::new(),
            script_name: None,
            document_root: None,
            pass_through_stderr: true,
            https: false,
        }
    }

    /// Replaces the argument list passed to the program.
    pub fn args(mut self, args: Vec<String>) -> Self {
        self.args = args;
        self
    }

    /// Sets the working directory of the child process.
    pub fn working_dir(mut self, dir: impl Into<String>) -> Self {
        self.working_dir = Some(dir.into());
        self
    }

    /// Adds (or replaces) one extra environment variable.
    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.extra_env.insert(key.into(), value.into());
        self
    }

    /// Sets the value reported as `SERVER_SOFTWARE`.
    pub fn server_software(mut self, software: impl Into<String>) -> Self {
        self.server_software = software.into();
        self
    }

    /// Sets the value reported as `SERVER_NAME`.
    pub fn server_name(mut self, name: impl Into<String>) -> Self {
        self.server_name = name.into();
        self
    }

    /// Sets the value reported as `SERVER_PORT`.
    pub fn server_port(mut self, port: u16) -> Self {
        self.server_port = port;
        self
    }

    /// Chooses whether `HTTPS=on` is exported to the program.
    pub fn https(mut self, https: bool) -> Self {
        self.https = https;
        self
    }

    /// Sets the time limit, stored with millisecond granularity.
    ///
    /// Durations whose millisecond count does not fit in a `u64` saturate at `u64::MAX`
    /// instead of wrapping around to an arbitrary (possibly tiny) value.
    pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
        self.timeout_ms = Some(u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX));
        self
    }

    /// Replaces the list of environment variable names inherited from the server.
    pub fn inherit_env(mut self, vars: Vec<String>) -> Self {
        self.inherit_env = vars;
        self
    }

    /// Adds one environment variable name to inherit from the server.
    pub fn inherit(mut self, var: impl Into<String>) -> Self {
        self.inherit_env.push(var.into());
        self
    }

    /// Sets the fixed `SCRIPT_NAME`.
    pub fn script_name(mut self, name: impl Into<String>) -> Self {
        self.script_name = Some(name.into());
        self
    }

    /// Sets the document root used to build `PATH_TRANSLATED`.
    pub fn document_root(mut self, root: impl Into<String>) -> Self {
        self.document_root = Some(root.into());
        self
    }

    /// Sets the stderr pass-through flag.
    pub fn pass_through_stderr(mut self, pass_through: bool) -> Self {
        self.pass_through_stderr = pass_through;
        self
    }
}
236
237fn build_cgi_env<B>(
238 request: &Request<B>,
239 config: &CgiConfig,
240 remote_addr: Option<&str>,
241) -> HashMap<String, String> {
242 let mut env = HashMap::new();
243
244 env.insert("GATEWAY_INTERFACE".to_string(), "CGI/1.1".to_string());
246 env.insert("REQUEST_METHOD".to_string(), request.method().to_string());
247 env.insert("SERVER_NAME".to_string(), config.server_name.clone());
248 env.insert("SERVER_PORT".to_string(), config.server_port.to_string());
249 env.insert(
250 "SERVER_PROTOCOL".to_string(),
251 format!("{:?}", request.version()),
252 );
253 env.insert(
254 "SERVER_SOFTWARE".to_string(),
255 config.server_software.clone(),
256 );
257
258 let path = request.uri().path();
259 let (script_name, path_info) = if let Some(ref sn) = config.script_name {
260 let path_info = if path.starts_with(sn) {
261 path[sn.len()..].to_string()
262 } else {
263 String::new()
264 };
265 (sn.clone(), path_info)
266 } else {
267 (path.to_string(), String::new())
268 };
269
270 env.insert("SCRIPT_NAME".to_string(), script_name);
271 env.insert("PATH_INFO".to_string(), path_info.clone());
272
273 if !path_info.is_empty()
275 && let Some(ref doc_root) = config.document_root
276 {
277 let path_translated = format!("{}{}", doc_root, path_info);
278 env.insert("PATH_TRANSLATED".to_string(), path_translated);
279 }
280
281 env.insert(
283 "QUERY_STRING".to_string(),
284 request.uri().query().unwrap_or("").to_string(),
285 );
286
287 env.insert(
288 "REQUEST_URI".to_string(),
289 request
290 .uri()
291 .path_and_query()
292 .map(|pq| pq.to_string())
293 .unwrap_or_else(|| path.to_string()),
294 );
295
296 env.insert(
299 "REMOTE_ADDR".to_string(),
300 remote_addr.unwrap_or("127.0.0.1").to_string(),
301 );
302
303 if config.https {
305 env.insert("HTTPS".to_string(), "on".to_string());
306 }
307
308 if let Some(content_type) = request.headers().get(http::header::CONTENT_TYPE)
309 && let Ok(ct) = content_type.to_str()
310 {
311 env.insert("CONTENT_TYPE".to_string(), ct.to_string());
312 }
313
314 if let Some(content_length) = request.headers().get(http::header::CONTENT_LENGTH)
315 && let Ok(cl) = content_length.to_str()
316 {
317 env.insert("CONTENT_LENGTH".to_string(), cl.to_string());
318 }
319
320 for var_name in &config.inherit_env {
321 if let Ok(value) = std::env::var(var_name) {
322 env.insert(var_name.clone(), value);
323 }
324 }
325
326 for (name, value) in request.headers() {
329 if name == http::header::CONTENT_TYPE || name == http::header::CONTENT_LENGTH {
331 continue;
332 }
333
334 if name == http::header::AUTHORIZATION || name == http::header::PROXY_AUTHORIZATION {
336 continue;
337 }
338
339 if name == http::header::CONNECTION
341 || name.as_str().eq_ignore_ascii_case("keep-alive")
342 || name == http::header::TRANSFER_ENCODING
343 || name == http::header::TE
344 || name == http::header::TRAILER
345 || name == http::header::UPGRADE
346 {
347 continue;
348 }
349
350 let env_name = format!("HTTP_{}", name.as_str().to_uppercase().replace('-', "_"));
351
352 if let Ok(v) = value.to_str() {
353 env.insert(env_name, v.to_string());
354 }
355 }
356
357 for (key, value) in &config.extra_env {
358 env.insert(key.clone(), value.clone());
359 }
360
361 env
362}
363
/// Returns the offset of the first blank-line separator in `data`, if any.
///
/// Both `\r\n\r\n` and `\n\n` are recognised; the earliest position at which either one
/// starts is returned.
fn find_header_body_separator(data: &[u8]) -> Option<usize> {
    (0..data.len()).find(|&start| {
        let rest = &data[start..];
        rest.starts_with(b"\r\n\r\n") || rest.starts_with(b"\n\n")
    })
}
375
/// Returns the length of the separator starting at `pos`: 4 when the bytes there are
/// `\r\n\r\n`, otherwise 2.
///
/// Panics if `pos` is greater than `data.len()`.
fn separator_length_at(data: &[u8], pos: usize) -> usize {
    match &data[pos..] {
        [b'\r', b'\n', b'\r', b'\n', ..] => 4,
        _ => 2,
    }
}
384
385struct ParsedHeaders {
387 status: StatusCode,
388 headers: http::HeaderMap,
389}
390
391fn parse_cgi_headers(header_bytes: &[u8]) -> Result<ParsedHeaders, CgiError> {
393 let header_str = std::str::from_utf8(header_bytes)
394 .map_err(|e| CgiError::ParseError(format!("Invalid UTF-8 in headers: {}", e)))?;
395
396 let mut status = StatusCode::OK;
397 let mut headers = http::HeaderMap::new();
398
399 for line in header_str.lines() {
400 let line = line.trim();
401 if line.is_empty() {
402 continue;
403 }
404
405 if let Some((name, value)) = line.split_once(':') {
406 let name = name.trim();
407 let value = value.trim();
408
409 if name.eq_ignore_ascii_case("Status") {
410 let code_str = value.split_once(' ').map_or(value, |(code, _)| code);
411 status = code_str
412 .parse::<u16>()
413 .map_err(|_| {
414 CgiError::InvalidHeader(format!("Invalid status code: {}", code_str))
415 })
416 .and_then(|code| {
417 StatusCode::from_u16(code).map_err(|_| {
418 CgiError::InvalidHeader(format!("Invalid status code: {}", code))
419 })
420 })?;
421 } else {
422 let header_name = http::header::HeaderName::from_bytes(name.as_bytes())
423 .map_err(|e| CgiError::InvalidHeader(format!("Invalid header name: {}", e)))?;
424 let header_value = http::header::HeaderValue::from_str(value)
425 .map_err(|e| CgiError::InvalidHeader(format!("Invalid header value: {}", e)))?;
426 headers.append(header_name, header_value);
427 }
428 }
429 }
430
431 Ok(ParsedHeaders { status, headers })
432}
433
434async fn write_request_body<B>(mut stdin: ChildStdin, body: B, timeout: Option<std::time::Duration>)
436where
437 B: Body + Send,
438 B::Data: Send,
439 B::Error: std::fmt::Debug,
440{
441 use bytes::Buf;
442
443 let write_future = async {
444 let mut body = std::pin::pin!(body);
445 loop {
446 match std::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
447 Some(Ok(frame)) => {
448 if let Ok(data) = frame.into_data()
449 && let Err(e) = stdin.write_all(data.chunk()).await
450 {
451 if e.kind() != std::io::ErrorKind::BrokenPipe {
454 eprintln!("CGI stdin write error: {}", e);
455 }
456 break;
457 }
458 }
459 Some(Err(e)) => {
460 eprintln!("Error reading request body: {:?}", e);
461 break;
462 }
463 None => break,
464 }
465 }
466 drop(stdin);
468 };
469
470 if let Some(timeout_duration) = timeout {
471 if tokio::time::timeout(timeout_duration, write_future)
472 .await
473 .is_err()
474 {
475 eprintln!("Timeout writing request body to CGI stdin");
476 }
477 } else {
478 write_future.await;
479 }
480}
481
482async fn read_and_parse_headers(
485 stdout: &mut ChildStdout,
486 child: &mut Child,
487 timeout: Option<std::time::Duration>,
488) -> Result<(ParsedHeaders, Vec<u8>), CgiError> {
489 let read_future = async {
490 let mut header_buf = Vec::with_capacity(4096);
491 let mut temp_buf = [0u8; 1024];
492
493 loop {
494 let n = stdout
495 .read(&mut temp_buf)
496 .await
497 .map_err(CgiError::SpawnError)?;
498
499 if n == 0 {
500 if header_buf.is_empty() {
503 if let Ok(Some(status)) = child.try_wait()
505 && !status.success()
506 {
507 return Err(status
508 .code()
509 .map_or(CgiError::ProcessKilled, CgiError::ProcessError));
510 }
511 return Err(CgiError::ParseError(
513 "CGI script produced no output".to_string(),
514 ));
515 }
516 let parsed = parse_cgi_headers(&header_buf)?;
518 return Ok((parsed, Vec::new()));
519 }
520
521 header_buf.extend_from_slice(&temp_buf[..n]);
522
523 if let Some(pos) = find_header_body_separator(&header_buf) {
524 let sep_len = separator_length_at(&header_buf, pos);
525 let body_start = pos + sep_len;
526
527 let parsed = parse_cgi_headers(&header_buf[..pos])?;
528 let initial_body = header_buf[body_start..].to_vec();
529
530 return Ok((parsed, initial_body));
531 }
532
533 if header_buf.len() > MAX_HEADER_SIZE {
535 return Err(CgiError::ParseError(format!(
536 "Headers exceed maximum size of {} bytes",
537 MAX_HEADER_SIZE
538 )));
539 }
540 }
541 };
542
543 if let Some(timeout_duration) = timeout {
544 tokio::time::timeout(timeout_duration, read_future)
545 .await
546 .map_err(|_| CgiError::Timeout(timeout_duration.as_millis() as u64))?
547 } else {
548 read_future.await
549 }
550}
551
552async fn stream_stdout(
554 mut stdout: ChildStdout,
555 tx: mpsc::Sender<Result<Bytes, std::io::Error>>,
556 initial_body: Vec<u8>,
557 mut child: Child,
558) {
559 if !initial_body.is_empty() && tx.send(Ok(Bytes::from(initial_body))).await.is_err() {
560 return;
561 }
562
563 let mut buf = vec![0u8; BODY_CHUNK_SIZE];
564 loop {
565 match stdout.read(&mut buf).await {
566 Ok(0) => break,
567 Ok(n) => {
568 let chunk = Bytes::copy_from_slice(&buf[..n]);
569 if tx.send(Ok(chunk)).await.is_err() {
570 break;
571 }
572 }
573 Err(e) => {
574 let _ = tx.send(Err(e)).await;
575 break;
576 }
577 }
578 }
579
580 drop(stdout);
581 drop(tx);
582
583 match child.wait().await {
584 Ok(status) if !status.success() => match status.code() {
585 Some(code) => eprintln!("CGI process exited with code: {code}"),
586 None => eprintln!("CGI process was killed by signal"),
587 },
588 Err(e) => eprintln!("Error waiting for CGI process: {e}"),
589 _ => {}
590 }
591}
592
593async fn execute_cgi<B>(
595 config: &CgiConfig,
596 env: HashMap<String, String>,
597 body: B,
598) -> Result<Response<BoxBody<Bytes, Infallible>>, CgiError>
599where
600 B: Body + Send + 'static,
601 B::Data: Send,
602 B::Error: std::fmt::Debug + Send,
603{
604 use std::time::Duration;
605
606 let mut cmd = Command::new(&config.command);
607
608 if !config.args.is_empty() {
609 cmd.args(&config.args);
610 }
611
612 if let Some(ref dir) = config.working_dir {
613 cmd.current_dir(dir);
614 }
615
616 cmd.env_clear();
617 for (key, value) in &env {
618 cmd.env(key, value);
619 }
620
621 cmd.stdin(std::process::Stdio::piped());
622 cmd.stdout(std::process::Stdio::piped());
623 cmd.stderr(std::process::Stdio::inherit()); let mut child = cmd.spawn()?;
626 let timeout = config.timeout_ms.map(Duration::from_millis);
627
628 let stdin = child.stdin.take().unwrap();
629 tokio::spawn(write_request_body(stdin, body, timeout));
630
631 let mut stdout = child.stdout.take().unwrap();
632
633 let (parsed_headers, initial_body) =
635 read_and_parse_headers(&mut stdout, &mut child, timeout).await?;
636
637 let (tx, rx) = mpsc::channel(CHANNEL_BUFFER_SIZE);
638 tokio::spawn(stream_stdout(stdout, tx, initial_body, child));
639
640 let mut response_builder = Response::builder().status(parsed_headers.status);
641 for (name, value) in parsed_headers.headers.iter() {
642 response_builder = response_builder.header(name, value);
643 }
644
645 response_builder
646 .body(BoxBody::new(CgiStreamBody::new(rx)))
647 .map_err(|e| CgiError::ParseError(format!("Failed to build response: {}", e)))
648}
649
650#[derive(Clone)]
655pub struct CgiService {
656 config: CgiConfig,
657 remote_addr: Option<String>,
658}
659
660impl CgiService {
661 pub fn new(command: impl Into<String>) -> Self {
663 Self {
664 config: CgiConfig::new(command),
665 remote_addr: None,
666 }
667 }
668
669 pub fn with_config(config: CgiConfig) -> Self {
671 Self {
672 config,
673 remote_addr: None,
674 }
675 }
676
677 pub fn remote_addr(mut self, addr: impl Into<String>) -> Self {
679 self.remote_addr = Some(addr.into());
680 self
681 }
682}
683
684impl<B> Service<Request<B>> for CgiService
685where
686 B: Body + Send + 'static,
687 B::Data: Send,
688 B::Error: std::fmt::Debug + Send,
689{
690 type Response = Response<BoxBody<Bytes, Infallible>>;
691 type Error = Infallible;
692 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
693
694 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
695 Poll::Ready(Ok(()))
696 }
697
698 fn call(&mut self, request: Request<B>) -> Self::Future {
699 let config = self.config.clone();
700 let remote_addr = self.remote_addr.clone();
701
702 Box::pin(async move {
703 let env = build_cgi_env(&request, &config, remote_addr.as_deref());
704 let (_parts, body) = request.into_parts();
705 Ok(execute_cgi(&config, env, body)
706 .await
707 .unwrap_or_else(|e| e.into_response()))
708 })
709 }
710}