1use std::collections::HashMap;
16use std::convert::Infallible;
17use std::future::Future;
18use std::pin::Pin;
19use std::task::{Context, Poll};
20
21use bytes::Bytes;
22use http::{Request, Response, StatusCode};
23use http_body::{Body, Frame};
24use http_body_util::{BodyExt, Full, combinators::BoxBody};
25use pin_project_lite::pin_project;
26use thiserror::Error;
27use tokio::io::{AsyncReadExt, AsyncWriteExt};
28use tokio::process::{Child, ChildStdin, ChildStdout, Command};
29use tokio::sync::mpsc;
30use tower_service::Service;
31
32#[derive(Error, Debug)]
33pub enum CgiError {
34 #[error("Failed to spawn CGI process: {0}")]
35 SpawnError(#[from] std::io::Error),
36
37 #[error("CGI process returned non-zero exit code: {0}")]
38 ProcessError(i32),
39
40 #[error("Failed to parse CGI response: {0}")]
41 ParseError(String),
42
43 #[error("Invalid CGI response header: {0}")]
44 InvalidHeader(String),
45
46 #[error("CGI process was killed by signal")]
47 ProcessKilled,
48
49 #[error("CGI process timed out after {0}ms")]
50 Timeout(u64),
51}
52
53impl CgiError {
54 fn into_response(self) -> Response<BoxBody<Bytes, Infallible>> {
55 Response::builder()
56 .status(StatusCode::INTERNAL_SERVER_ERROR)
57 .body(BoxBody::new(
58 Full::new(Bytes::from(self.to_string())).map_err(|_: Infallible| unreachable!()),
59 ))
60 .unwrap()
61 }
62}
63
/// Maximum number of bytes buffered while looking for the end of the CGI
/// response headers; exceeding it is reported as a parse error.
const MAX_HEADER_SIZE: usize = 64 * 1024;
/// Size of the buffer used for each read of the CGI response body from stdout.
const BODY_CHUNK_SIZE: usize = 8 * 1024;
/// Capacity of the channel that carries body chunks to the response body.
const CHANNEL_BUFFER_SIZE: usize = 16;
70
71pin_project! {
72 pub struct CgiStreamBody {
74 rx: mpsc::Receiver<Result<Bytes, std::io::Error>>,
75 done: bool,
76 }
77}
78
79impl CgiStreamBody {
80 fn new(rx: mpsc::Receiver<Result<Bytes, std::io::Error>>) -> Self {
81 Self { rx, done: false }
82 }
83}
84
85impl Body for CgiStreamBody {
86 type Data = Bytes;
87 type Error = Infallible;
88
89 fn poll_frame(
90 self: Pin<&mut Self>,
91 cx: &mut Context<'_>,
92 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
93 let this = self.project();
94
95 if *this.done {
96 return Poll::Ready(None);
97 }
98
99 match this.rx.poll_recv(cx) {
100 Poll::Ready(Some(Ok(bytes))) => {
101 if bytes.is_empty() {
102 *this.done = true;
103 Poll::Ready(None)
104 } else {
105 Poll::Ready(Some(Ok(Frame::data(bytes))))
106 }
107 }
108 Poll::Ready(Some(Err(_))) => {
109 *this.done = true;
110 Poll::Ready(None)
111 }
112 Poll::Ready(None) => {
113 *this.done = true;
114 Poll::Ready(None)
115 }
116 Poll::Pending => Poll::Pending,
117 }
118 }
119
120 fn is_end_stream(&self) -> bool {
121 self.done
122 }
123}
124
/// Settings describing which CGI program to run and which server-side values
/// to expose to it.
///
/// Built with [`CgiConfig::new`] followed by the chainable setters below.
#[derive(Debug, Clone)]
pub struct CgiConfig {
    /// Program to execute.
    command: String,
    /// Arguments passed to the program.
    args: Vec<String>,
    /// Working directory for the program, if one was set.
    working_dir: Option<String>,
    /// Additional environment variables supplied by the embedder.
    extra_env: HashMap<String, String>,
    /// Value reported as `SERVER_SOFTWARE`.
    server_software: String,
    /// Value reported as `SERVER_NAME`.
    server_name: String,
    /// Value reported as `SERVER_PORT`.
    server_port: u16,
    /// Time limit in milliseconds; `None` means no limit.
    timeout_ms: Option<u64>,
    /// Names of variables to copy from the server's own environment.
    inherit_env: Vec<String>,
    /// Fixed `SCRIPT_NAME`; when set it is also used to derive `PATH_INFO`.
    script_name: Option<String>,
    /// Document root used to compute `PATH_TRANSLATED`.
    document_root: Option<String>,
    /// Requested stderr pass-through policy; it is handed to the code that
    /// runs the process.
    pass_through_stderr: bool,
    /// Whether the request should be reported to the script as HTTPS.
    https: bool,
}

impl CgiConfig {
    /// Creates a configuration for `command` with defaults: no arguments, no
    /// working directory, no extra or inherited environment, server name
    /// `localhost` on port 80, no timeout, stderr pass-through requested and
    /// HTTPS off.
    pub fn new(command: impl Into<String>) -> Self {
        Self {
            command: command.into(),
            args: Vec::new(),
            working_dir: None,
            extra_env: HashMap::new(),
            server_software: "rust-cgi-server/0.1".to_string(),
            server_name: "localhost".to_string(),
            server_port: 80,
            timeout_ms: None,
            inherit_env: Vec::new(),
            script_name: None,
            document_root: None,
            pass_through_stderr: true,
            https: false,
        }
    }

    /// Replaces the argument list passed to the program.
    pub fn args(mut self, args: Vec<String>) -> Self {
        self.args = args;
        self
    }

    /// Sets the working directory the program is started in.
    pub fn working_dir(mut self, dir: impl Into<String>) -> Self {
        self.working_dir = Some(dir.into());
        self
    }

    /// Adds (or overwrites) one extra environment variable.
    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.extra_env.insert(key.into(), value.into());
        self
    }

    /// Sets the `SERVER_SOFTWARE` value.
    pub fn server_software(mut self, software: impl Into<String>) -> Self {
        self.server_software = software.into();
        self
    }

    /// Sets the `SERVER_NAME` value.
    pub fn server_name(mut self, name: impl Into<String>) -> Self {
        self.server_name = name.into();
        self
    }

    /// Sets the `SERVER_PORT` value.
    pub fn server_port(mut self, port: u16) -> Self {
        self.server_port = port;
        self
    }

    /// Marks requests as arriving over HTTPS (or not).
    pub fn https(mut self, https: bool) -> Self {
        self.https = https;
        self
    }

    /// Sets the time limit, stored with millisecond resolution.
    ///
    /// `Duration::as_millis` returns a `u128`; durations too large for a
    /// `u64` saturate at `u64::MAX` instead of silently wrapping to an
    /// unrelated, possibly very short, timeout.
    pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
        let millis = u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX);
        self.timeout_ms = Some(millis);
        self
    }

    /// Replaces the list of environment variable names to inherit.
    pub fn inherit_env(mut self, vars: Vec<String>) -> Self {
        self.inherit_env = vars;
        self
    }

    /// Appends one environment variable name to the inherit list.
    pub fn inherit(mut self, var: impl Into<String>) -> Self {
        self.inherit_env.push(var.into());
        self
    }

    /// Sets a fixed `SCRIPT_NAME`.
    pub fn script_name(mut self, name: impl Into<String>) -> Self {
        self.script_name = Some(name.into());
        self
    }

    /// Sets the document root used for `PATH_TRANSLATED`.
    pub fn document_root(mut self, root: impl Into<String>) -> Self {
        self.document_root = Some(root.into());
        self
    }

    /// Sets the requested stderr pass-through policy.
    pub fn pass_through_stderr(mut self, pass_through: bool) -> Self {
        self.pass_through_stderr = pass_through;
        self
    }
}
231
232fn build_cgi_env<B>(
233 request: &Request<B>,
234 config: &CgiConfig,
235 remote_addr: Option<&str>,
236) -> HashMap<String, String> {
237 let mut env = HashMap::new();
238
239 env.insert("GATEWAY_INTERFACE".to_string(), "CGI/1.1".to_string());
241 env.insert("REQUEST_METHOD".to_string(), request.method().to_string());
242 env.insert("SERVER_NAME".to_string(), config.server_name.clone());
243 env.insert("SERVER_PORT".to_string(), config.server_port.to_string());
244 env.insert(
245 "SERVER_PROTOCOL".to_string(),
246 format!("{:?}", request.version()),
247 );
248 env.insert(
249 "SERVER_SOFTWARE".to_string(),
250 config.server_software.clone(),
251 );
252
253 let path = request.uri().path();
254 let (script_name, path_info) = if let Some(ref sn) = config.script_name {
255 let path_info = if path.starts_with(sn) {
256 path[sn.len()..].to_string()
257 } else {
258 String::new()
259 };
260 (sn.clone(), path_info)
261 } else {
262 (path.to_string(), String::new())
263 };
264
265 env.insert("SCRIPT_NAME".to_string(), script_name);
266 env.insert("PATH_INFO".to_string(), path_info.clone());
267
268 if !path_info.is_empty() {
270 if let Some(ref doc_root) = config.document_root {
271 let path_translated = format!("{}{}", doc_root, path_info);
272 env.insert("PATH_TRANSLATED".to_string(), path_translated);
273 }
274 }
275
276 env.insert(
278 "QUERY_STRING".to_string(),
279 request.uri().query().unwrap_or("").to_string(),
280 );
281
282 env.insert(
283 "REQUEST_URI".to_string(),
284 request
285 .uri()
286 .path_and_query()
287 .map(|pq| pq.to_string())
288 .unwrap_or_else(|| path.to_string()),
289 );
290
291 env.insert(
294 "REMOTE_ADDR".to_string(),
295 remote_addr.unwrap_or("127.0.0.1").to_string(),
296 );
297
298 if config.https {
300 env.insert("HTTPS".to_string(), "on".to_string());
301 }
302
303 if let Some(content_type) = request.headers().get(http::header::CONTENT_TYPE) {
304 if let Ok(ct) = content_type.to_str() {
305 env.insert("CONTENT_TYPE".to_string(), ct.to_string());
306 }
307 }
308
309 if let Some(content_length) = request.headers().get(http::header::CONTENT_LENGTH) {
310 if let Ok(cl) = content_length.to_str() {
311 env.insert("CONTENT_LENGTH".to_string(), cl.to_string());
312 }
313 }
314
315 for var_name in &config.inherit_env {
316 if let Ok(value) = std::env::var(var_name) {
317 env.insert(var_name.clone(), value);
318 }
319 }
320
321 for (name, value) in request.headers() {
324 if name == http::header::CONTENT_TYPE || name == http::header::CONTENT_LENGTH {
326 continue;
327 }
328
329 if name == http::header::AUTHORIZATION || name == http::header::PROXY_AUTHORIZATION {
331 continue;
332 }
333
334 if name == http::header::CONNECTION
336 || name.as_str().eq_ignore_ascii_case("keep-alive")
337 || name == http::header::TRANSFER_ENCODING
338 || name == http::header::TE
339 || name == http::header::TRAILER
340 || name == http::header::UPGRADE
341 {
342 continue;
343 }
344
345 let env_name = format!("HTTP_{}", name.as_str().to_uppercase().replace('-', "_"));
346
347 if let Ok(v) = value.to_str() {
348 env.insert(env_name, v.to_string());
349 }
350 }
351
352 for (key, value) in &config.extra_env {
353 env.insert(key.clone(), value.clone());
354 }
355
356 env
357}
358
/// Returns the length of the header/body separator that `tail` starts with,
/// or `None` if it does not start with one.
///
/// A separator is the end of the last header line followed by a blank line.
/// Besides pure CRLF (`\r\n\r\n`) and pure LF (`\n\n`) output, an LF-terminated
/// header line followed by a CRLF blank line (`\n\r\n`) is recognised.
fn separator_prefix_len(tail: &[u8]) -> Option<usize> {
    const SEPARATORS: [&[u8]; 3] = [b"\r\n\r\n", b"\n\r\n", b"\n\n"];
    SEPARATORS
        .iter()
        .copied()
        .find(|sep| tail.starts_with(sep))
        .map(|sep| sep.len())
}

/// Finds the first position in `data` at which a header/body separator
/// starts, or `None` if `data` does not contain one yet.
///
/// The header block is `data[..pos]`; use [`separator_length_at`] to learn
/// how many bytes to skip to reach the body.
fn find_header_body_separator(data: &[u8]) -> Option<usize> {
    (0..data.len()).find(|&i| separator_prefix_len(&data[i..]).is_some())
}

/// Returns the length of the separator starting at `pos`, which should be a
/// position returned by [`find_header_body_separator`] for the same `data`.
///
/// If no separator starts at `pos` the result falls back to 2.
///
/// # Panics
///
/// Panics if `pos > data.len()`.
fn separator_length_at(data: &[u8], pos: usize) -> usize {
    separator_prefix_len(&data[pos..]).unwrap_or(2)
}
381
/// Result of parsing the header section of a CGI response.
struct ParsedHeaders {
    /// HTTP status to use for the response.
    status: StatusCode,
    /// Response headers, excluding the CGI `Status` pseudo-header.
    headers: http::HeaderMap,
}
387
388fn parse_cgi_headers(header_bytes: &[u8]) -> Result<ParsedHeaders, CgiError> {
390 let header_str = std::str::from_utf8(header_bytes)
391 .map_err(|e| CgiError::ParseError(format!("Invalid UTF-8 in headers: {}", e)))?;
392
393 let mut status = StatusCode::OK;
394 let mut headers = http::HeaderMap::new();
395
396 for line in header_str.lines() {
397 let line = line.trim();
398 if line.is_empty() {
399 continue;
400 }
401
402 if let Some((name, value)) = line.split_once(':') {
403 let name = name.trim();
404 let value = value.trim();
405
406 if name.eq_ignore_ascii_case("Status") {
407 let status_parts: Vec<&str> = value.splitn(2, ' ').collect();
408 if let Some(code_str) = status_parts.first() {
409 status = code_str
410 .parse::<u16>()
411 .map_err(|_| {
412 CgiError::InvalidHeader(format!("Invalid status code: {}", code_str))
413 })
414 .and_then(|code| {
415 StatusCode::from_u16(code).map_err(|_| {
416 CgiError::InvalidHeader(format!("Invalid status code: {}", code))
417 })
418 })?;
419 }
420 } else {
421 let header_name = http::header::HeaderName::from_bytes(name.as_bytes())
422 .map_err(|e| CgiError::InvalidHeader(format!("Invalid header name: {}", e)))?;
423 let header_value = http::header::HeaderValue::from_str(value)
424 .map_err(|e| CgiError::InvalidHeader(format!("Invalid header value: {}", e)))?;
425 headers.append(header_name, header_value);
426 }
427 }
428 }
429
430 Ok(ParsedHeaders { status, headers })
431}
432
433async fn write_request_body<B>(mut stdin: ChildStdin, body: B, timeout: Option<std::time::Duration>)
435where
436 B: Body + Send,
437 B::Data: Send,
438 B::Error: std::fmt::Debug,
439{
440 use bytes::Buf;
441
442 let write_future = async {
443 let mut body = std::pin::pin!(body);
444 loop {
445 match std::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
446 Some(Ok(frame)) => {
447 if let Ok(data) = frame.into_data() {
448 if let Err(e) = stdin.write_all(data.chunk()).await {
449 if e.kind() != std::io::ErrorKind::BrokenPipe {
452 eprintln!("CGI stdin write error: {}", e);
453 }
454 break;
455 }
456 }
457 }
458 Some(Err(e)) => {
459 eprintln!("Error reading request body: {:?}", e);
460 break;
461 }
462 None => break, }
464 }
465 drop(stdin);
467 };
468
469 if let Some(timeout_duration) = timeout {
470 if tokio::time::timeout(timeout_duration, write_future)
471 .await
472 .is_err()
473 {
474 eprintln!("Timeout writing request body to CGI stdin");
475 }
476 } else {
477 write_future.await;
478 }
479}
480
481async fn read_and_parse_headers(
484 stdout: &mut ChildStdout,
485 child: &mut Child,
486 timeout: Option<std::time::Duration>,
487) -> Result<(ParsedHeaders, Vec<u8>), CgiError> {
488 let read_future = async {
489 let mut header_buf = Vec::with_capacity(4096);
490 let mut temp_buf = [0u8; 1024];
491
492 loop {
493 let n = stdout
494 .read(&mut temp_buf)
495 .await
496 .map_err(CgiError::SpawnError)?;
497
498 if n == 0 {
499 if header_buf.is_empty() {
502 if let Ok(Some(status)) = child.try_wait() {
504 if !status.success() {
505 if let Some(code) = status.code() {
506 return Err(CgiError::ProcessError(code));
507 } else {
508 return Err(CgiError::ProcessKilled);
509 }
510 }
511 }
512 return Err(CgiError::ParseError(
514 "CGI script produced no output".to_string(),
515 ));
516 }
517 let parsed = parse_cgi_headers(&header_buf)?;
519 return Ok((parsed, Vec::new()));
520 }
521
522 header_buf.extend_from_slice(&temp_buf[..n]);
523
524 if let Some(pos) = find_header_body_separator(&header_buf) {
526 let sep_len = separator_length_at(&header_buf, pos);
527 let body_start = pos + sep_len;
528
529 let parsed = parse_cgi_headers(&header_buf[..pos])?;
530 let initial_body = header_buf[body_start..].to_vec();
531
532 return Ok((parsed, initial_body));
533 }
534
535 if header_buf.len() > MAX_HEADER_SIZE {
537 return Err(CgiError::ParseError(format!(
538 "Headers exceed maximum size of {} bytes",
539 MAX_HEADER_SIZE
540 )));
541 }
542 }
543 };
544
545 if let Some(timeout_duration) = timeout {
546 tokio::time::timeout(timeout_duration, read_future)
547 .await
548 .map_err(|_| CgiError::Timeout(timeout_duration.as_millis() as u64))?
549 } else {
550 read_future.await
551 }
552}
553
554async fn stream_stdout(
556 mut stdout: ChildStdout,
557 tx: mpsc::Sender<Result<Bytes, std::io::Error>>,
558 initial_body: Vec<u8>,
559 mut child: Child,
560 pass_through_stderr: bool,
561) {
562 if !initial_body.is_empty() {
564 if tx.send(Ok(Bytes::from(initial_body))).await.is_err() {
565 return; }
567 }
568
569 let mut buf = vec![0u8; BODY_CHUNK_SIZE];
571 loop {
572 match stdout.read(&mut buf).await {
573 Ok(0) => break, Ok(n) => {
575 let chunk = Bytes::copy_from_slice(&buf[..n]);
576 if tx.send(Ok(chunk)).await.is_err() {
577 break; }
579 }
580 Err(e) => {
581 let _ = tx.send(Err(e)).await;
582 break;
583 }
584 }
585 }
586
587 drop(stdout);
589 drop(tx);
591
592 match child.wait().await {
594 Ok(status) => {
595 if !status.success() {
596 if let Some(code) = status.code() {
597 eprintln!("CGI process exited with code: {}", code);
598 } else {
599 eprintln!("CGI process was killed by signal");
600 }
601 }
602 }
603 Err(e) => {
604 eprintln!("Error waiting for CGI process: {}", e);
605 }
606 }
607
608 let _ = pass_through_stderr;
611}
612
613async fn execute_cgi<B>(
615 config: &CgiConfig,
616 env: HashMap<String, String>,
617 body: B,
618) -> Result<Response<BoxBody<Bytes, Infallible>>, CgiError>
619where
620 B: Body + Send + 'static,
621 B::Data: Send,
622 B::Error: std::fmt::Debug + Send,
623{
624 use std::time::Duration;
625
626 let mut cmd = Command::new(&config.command);
627
628 if !config.args.is_empty() {
629 cmd.args(&config.args);
630 }
631
632 if let Some(ref dir) = config.working_dir {
633 cmd.current_dir(dir);
634 }
635
636 cmd.env_clear();
637 for (key, value) in &env {
638 cmd.env(key, value);
639 }
640
641 cmd.stdin(std::process::Stdio::piped());
642 cmd.stdout(std::process::Stdio::piped());
643 cmd.stderr(std::process::Stdio::inherit()); let mut child = cmd.spawn()?;
646
647 let timeout = config.timeout_ms.map(Duration::from_millis);
649
650 let stdin = child.stdin.take().expect("stdin was configured as piped");
652 tokio::spawn(write_request_body(stdin, body, timeout));
653
654 let mut stdout = child.stdout.take().expect("stdout was configured as piped");
656
657 let (parsed_headers, initial_body) =
660 read_and_parse_headers(&mut stdout, &mut child, timeout).await?;
661
662 let (tx, rx) = mpsc::channel(CHANNEL_BUFFER_SIZE);
664
665 let pass_through_stderr = config.pass_through_stderr;
667 tokio::spawn(stream_stdout(
668 stdout,
669 tx,
670 initial_body,
671 child,
672 pass_through_stderr,
673 ));
674
675 let mut response_builder = Response::builder().status(parsed_headers.status);
677 for (name, value) in parsed_headers.headers.iter() {
678 response_builder = response_builder.header(name, value);
679 }
680
681 response_builder
682 .body(BoxBody::new(CgiStreamBody::new(rx)))
683 .map_err(|e| CgiError::ParseError(format!("Failed to build response: {}", e)))
684}
685
686#[derive(Clone)]
687pub struct CgiService {
688 config: CgiConfig,
689 remote_addr: Option<String>,
690}
691
692impl CgiService {
693 pub fn new(command: impl Into<String>) -> Self {
694 Self {
695 config: CgiConfig::new(command),
696 remote_addr: None,
697 }
698 }
699
700 pub fn with_config(config: CgiConfig) -> Self {
701 Self {
702 config,
703 remote_addr: None,
704 }
705 }
706
707 pub fn remote_addr(mut self, addr: impl Into<String>) -> Self {
708 self.remote_addr = Some(addr.into());
709 self
710 }
711}
712
713impl<B> Service<Request<B>> for CgiService
714where
715 B: Body + Send + 'static,
716 B::Data: Send,
717 B::Error: std::fmt::Debug + Send,
718{
719 type Response = Response<BoxBody<Bytes, Infallible>>;
720 type Error = Infallible;
721 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
722
723 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
724 Poll::Ready(Ok(()))
725 }
726
727 fn call(&mut self, request: Request<B>) -> Self::Future {
728 let config = self.config.clone();
729 let remote_addr = self.remote_addr.clone();
730
731 Box::pin(async move {
732 let env = build_cgi_env(&request, &config, remote_addr.as_deref());
734
735 let (_parts, body) = request.into_parts();
737
738 Ok(execute_cgi(&config, env, body)
740 .await
741 .unwrap_or_else(|e| e.into_response()))
742 })
743 }
744}