Skip to main content

cgi_response/
lib.rs

1//! Parse a CGI child's stdout into an `http::Response`, then stream
2//! the body through `http_body::Body`. The other half of a CGI
3//! gateway — building the RFC 3875 environment for the child —
4//! lives in the `cgi-request` crate; pair them when you need both
5//! directions.
6//!
7//! See the crate-level README for what is and isn't covered here.
8
9use std::io;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13use bytes::Bytes;
14use http::{HeaderName, HeaderValue, StatusCode};
15use http_body::{Body, Frame, SizeHint};
16use tokio::io::{AsyncRead, AsyncReadExt as _, ReadBuf};
17use tokio::time::Instant;
18
19/// Errors from [`read_until_header_end`].
20#[derive(Debug, thiserror::Error)]
21pub enum HeaderReadError {
22	/// The reader closed (`read` returned 0) or errored before the
23	/// `\r\n\r\n` separator was seen. Hosts typically map this to
24	/// `502 Bad Gateway`.
25	#[error("cgi child closed before producing a usable header block")]
26	Eof,
27	/// The deadline expired before the header block completed. Hosts
28	/// typically map this to `504 Gateway Timeout`.
29	#[error("cgi connect timeout exceeded before header block ended")]
30	Timeout,
31}
32
33/// Read from `stdout` until the RFC 3875 header / body separator
34/// (`\r\n\r\n`), or until `deadline`. Returns the header block (up
35/// to and including the separator), the leftover bytes that
36/// arrived in the same `read()` past the separator, and the
37/// still-open reader for downstream body streaming.
38///
39/// # Errors
40///
41/// - [`HeaderReadError::Eof`] when the reader closes / errors
42///   before the separator.
43/// - [`HeaderReadError::Timeout`] when `deadline` expires first.
44pub async fn read_until_header_end<R>(
45	mut stdout: R,
46	deadline: Instant,
47) -> Result<(Vec<u8>, Vec<u8>, R), HeaderReadError>
48where
49	R: AsyncRead + Unpin + Send,
50{
51	let mut buf = Vec::with_capacity(1024);
52	let mut tmp = [0u8; 4096];
53	loop {
54		let read = tokio::time::timeout_at(deadline, stdout.read(&mut tmp))
55			.await
56			.map_err(|_| HeaderReadError::Timeout)?;
57		match read {
58			Ok(n) if n > 0 => {
59				buf.extend_from_slice(&tmp[..n]);
60				if let Some(end) = find_header_end(&buf) {
61					let leftover = buf.split_off(end);
62					return Ok((buf, leftover, stdout));
63				}
64			}
65			// EOF (n == 0) or read error — both mean the child won't
66			// produce any more bytes; map to "no usable header block".
67			Ok(_) | Err(_) => return Err(HeaderReadError::Eof),
68		}
69	}
70}
71
/// Locate the first `\r\n\r\n` in `buf` and return the index just
/// past it, or `None` when the separator is absent.
fn find_header_end(buf: &[u8]) -> Option<usize> {
	const SEPARATOR: &[u8] = b"\r\n\r\n";
	for (start, window) in buf.windows(SEPARATOR.len()).enumerate() {
		if window == SEPARATOR {
			return Some(start + SEPARATOR.len());
		}
	}
	None
}
75
76/// Errors from [`parse_response_headers`].
77#[derive(Debug, thiserror::Error)]
78pub enum ParseError {
79	#[error("non-utf8 header block: {0}")]
80	NonUtf8(String),
81	#[error("malformed header line: {0}")]
82	MalformedHeader(String),
83	#[error("invalid header name {0}")]
84	InvalidHeaderName(String),
85	#[error("invalid header value for {0}")]
86	InvalidHeaderValue(String),
87	#[error("invalid Status header: {0}")]
88	InvalidStatus(String),
89}
90
91/// Build an `http::response::Builder` from an RFC 3875 header
92/// block. Status resolution:
93///
94/// * `Status: 200 OK` → status code (CGI-specific header, not an
95///   HTTP/1.1 status line).
96/// * `Location: /...` without a `Status:` → 302 Found.
97/// * No `Status:`, no `Location:` → 200 OK.
98///
99/// Other headers pass through untouched.
100///
101/// # Errors
102///
103/// As [`ParseError`].
104pub fn parse_response_headers(block: &[u8]) -> Result<http::response::Builder, ParseError> {
105	let s = std::str::from_utf8(block).map_err(|e| ParseError::NonUtf8(e.to_string()))?;
106	let mut status: Option<StatusCode> = None;
107	let mut location_seen = false;
108	let mut builder = http::Response::builder();
109	for line in s.split("\r\n") {
110		if line.is_empty() {
111			continue;
112		}
113		let (name, value) =
114			line.split_once(':').ok_or_else(|| ParseError::MalformedHeader(line.to_owned()))?;
115		let name = name.trim();
116		let value = value.trim();
117		if name.eq_ignore_ascii_case("Status") {
118			let code = value
119				.split_whitespace()
120				.next()
121				.ok_or_else(|| ParseError::InvalidStatus(format!("empty value: {value:?}")))?;
122			let parsed: u16 =
123				code.parse().map_err(|e| ParseError::InvalidStatus(format!("parse {code:?}: {e}")))?;
124			status =
125				Some(StatusCode::from_u16(parsed).map_err(|e| ParseError::InvalidStatus(e.to_string()))?);
126		} else {
127			let header_name =
128				HeaderName::try_from(name).map_err(|_| ParseError::InvalidHeaderName(name.to_owned()))?;
129			let header_val = HeaderValue::try_from(value)
130				.map_err(|_| ParseError::InvalidHeaderValue(name.to_owned()))?;
131			if header_name.as_str().eq_ignore_ascii_case("location") {
132				location_seen = true;
133			}
134			builder = builder.header(header_name, header_val);
135		}
136	}
137	let final_status = match (status, location_seen) {
138		(Some(s), _) => s,
139		(None, true) => StatusCode::FOUND,
140		(None, false) => StatusCode::OK,
141	};
142	Ok(builder.status(final_status))
143}
144
145/// Streaming body for a CGI response: yields the leftover bytes
146/// (from the post-header read) first, then reads the rest from the
147/// child's stdout to EOF. A `total_deadline` caps the total
148/// streaming time; mid-body the next `poll_frame` past the deadline
149/// returns an `io::Error`.
150///
151/// `G` is a generic drop guard that the body owns for its
152/// lifetime — typically a permit, an `Arc`, or a cancellation
153/// guard the host wants to keep alive while bytes are still
154/// flowing. Use `()` when you don't need one.
155pub struct CgiResponseBody<R, G = ()> {
156	initial: Option<Bytes>,
157	stdout: R,
158	deadline: Instant,
159	_guard: G,
160}
161
162impl<R, G> CgiResponseBody<R, G> {
163	/// Build a body from a leftover-bytes prefix, an open reader,
164	/// the wall-clock deadline for stream completion, and a
165	/// caller-supplied drop guard.
166	pub fn new(initial: Vec<u8>, stdout: R, deadline: Instant, guard: G) -> Self {
167		let initial = if initial.is_empty() { None } else { Some(Bytes::from(initial)) };
168		Self { initial, stdout, deadline, _guard: guard }
169	}
170}
171
172impl<R, G> Body for CgiResponseBody<R, G>
173where
174	R: AsyncRead + Unpin + Send,
175	G: Send + Unpin + 'static,
176{
177	type Data = Bytes;
178	type Error = io::Error;
179
180	fn poll_frame(
181		mut self: Pin<&mut Self>,
182		cx: &mut Context<'_>,
183	) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
184		if let Some(b) = self.initial.take() {
185			return Poll::Ready(Some(Ok(Frame::data(b))));
186		}
187		if Instant::now() >= self.deadline {
188			return Poll::Ready(Some(Err(io::Error::other("cgi total_timeout exceeded mid-body"))));
189		}
190		let mut buf = [0u8; 4096];
191		let mut read_buf = ReadBuf::new(&mut buf);
192		match Pin::new(&mut self.stdout).poll_read(cx, &mut read_buf) {
193			Poll::Pending => Poll::Pending,
194			Poll::Ready(Ok(())) => {
195				let filled = read_buf.filled();
196				if filled.is_empty() {
197					Poll::Ready(None)
198				} else {
199					Poll::Ready(Some(Ok(Frame::data(Bytes::copy_from_slice(filled)))))
200				}
201			}
202			Poll::Ready(Err(e)) => {
203				Poll::Ready(Some(Err(io::Error::other(format!("cgi stdout read: {e}")))))
204			}
205		}
206	}
207
208	fn is_end_stream(&self) -> bool {
209		false
210	}
211
212	fn size_hint(&self) -> SizeHint {
213		SizeHint::default()
214	}
215}
216
#[cfg(test)]
mod tests {
	use std::sync::Arc;
	use std::time::Duration;

	use http_body_util::BodyExt as _;
	use tokio::io::AsyncWriteExt as _;

	use super::*;

	/// Parse `block` and finish the builder with an empty body,
	/// panicking if either step fails.
	fn parsed(block: &[u8]) -> http::Response<()> {
		parse_response_headers(block).expect("parse").body(()).unwrap()
	}

	#[test]
	fn parse_status_header_picks_up_code() {
		let response = parsed(b"Status: 201 Created\r\nContent-Type: text/plain\r\n\r\n");
		assert_eq!(response.status(), StatusCode::CREATED);
		assert_eq!(response.headers().get("content-type").unwrap(), "text/plain");
	}

	#[test]
	fn parse_location_without_status_defaults_to_302() {
		let response = parsed(b"Location: /elsewhere\r\n\r\n");
		assert_eq!(response.status(), StatusCode::FOUND);
		assert_eq!(response.headers().get("location").unwrap(), "/elsewhere");
	}

	#[test]
	fn parse_no_status_no_location_defaults_to_200() {
		let response = parsed(b"Content-Type: text/plain\r\n\r\n");
		assert_eq!(response.status(), StatusCode::OK);
	}

	#[test]
	fn parse_rejects_malformed_line() {
		let outcome = parse_response_headers(b"no-colon-here\r\n\r\n");
		assert!(matches!(outcome, Err(ParseError::MalformedHeader(_))));
	}

	#[tokio::test]
	async fn read_until_header_end_returns_block_and_leftover() {
		let (mut writer, reader) = tokio::io::duplex(64);
		tokio::spawn(async move {
			writer.write_all(b"Status: 200 OK\r\n\r\nbody-bytes-here").await.unwrap();
		});
		let deadline = Instant::now() + Duration::from_secs(2);
		let (head, leftover, _rest) = read_until_header_end(reader, deadline).await.expect("ok");
		assert_eq!(head, b"Status: 200 OK\r\n\r\n");
		assert_eq!(leftover, b"body-bytes-here");
	}

	#[tokio::test]
	async fn read_until_header_end_eof_returns_err() {
		let (writer, reader) = tokio::io::duplex(64);
		// Closing the write half first makes the very first read hit EOF.
		drop(writer);
		let deadline = Instant::now() + Duration::from_secs(2);
		let outcome = read_until_header_end(reader, deadline).await;
		assert!(matches!(outcome, Err(HeaderReadError::Eof)));
	}

	#[tokio::test(start_paused = true)]
	async fn read_until_header_end_timeout_returns_err() {
		// The write half stays open and silent so the read just waits.
		let (_writer, reader) = tokio::io::duplex(64);
		let deadline = Instant::now() + Duration::from_millis(50);
		// Move the paused clock past the deadline before reading.
		tokio::time::advance(Duration::from_millis(60)).await;
		let outcome = read_until_header_end(reader, deadline).await;
		assert!(matches!(outcome, Err(HeaderReadError::Timeout)));
	}

	#[tokio::test]
	async fn cgi_response_body_yields_leftover_then_streams_to_eof() {
		let (mut writer, reader) = tokio::io::duplex(64);
		tokio::spawn(async move {
			writer.write_all(b"-streamed").await.unwrap();
			drop(writer);
		});
		let deadline = Instant::now() + Duration::from_secs(2);
		let mut body = CgiResponseBody::new(b"leftover".to_vec(), reader, deadline, ());

		// The leftover prefix must come out before anything read from stdout.
		let first = body.frame().await.expect("first frame").expect("data");
		assert_eq!(first.into_data().unwrap(), &b"leftover"[..]);

		// Whatever follows, up to EOF, is the streamed remainder.
		let mut streamed = Vec::new();
		while let Some(next) = body.frame().await {
			let chunk = next.unwrap().into_data().unwrap();
			streamed.extend_from_slice(&chunk);
		}
		assert_eq!(streamed, b"-streamed");
	}

	#[tokio::test]
	async fn cgi_response_body_keeps_guard_alive_until_drop() {
		// An `Arc<()>` stands in for the guard. A weak handle lets us
		// observe that the strong count stays above zero while the
		// body lives and falls to zero once the body is dropped.
		let guard = Arc::new(());
		let observer = Arc::downgrade(&guard);
		let (_writer, reader) = tokio::io::duplex(64);
		let deadline = Instant::now() + Duration::from_secs(1);
		let body = CgiResponseBody::new(Vec::new(), reader, deadline, guard);
		assert!(observer.strong_count() > 0);
		drop(body);
		assert_eq!(observer.strong_count(), 0);
	}
}