1use std::io;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12
13use bytes::Bytes;
14use http::{HeaderName, HeaderValue, StatusCode};
15use http_body::{Body, Frame, SizeHint};
16use tokio::io::{AsyncRead, AsyncReadExt as _, ReadBuf};
17use tokio::time::Instant;
18
19#[derive(Debug, thiserror::Error)]
21pub enum HeaderReadError {
22 #[error("cgi child closed before producing a usable header block")]
26 Eof,
27 #[error("cgi connect timeout exceeded before header block ended")]
30 Timeout,
31}
32
33pub async fn read_until_header_end<R>(
45 mut stdout: R,
46 deadline: Instant,
47) -> Result<(Vec<u8>, Vec<u8>, R), HeaderReadError>
48where
49 R: AsyncRead + Unpin + Send,
50{
51 let mut buf = Vec::with_capacity(1024);
52 let mut tmp = [0u8; 4096];
53 loop {
54 let read = tokio::time::timeout_at(deadline, stdout.read(&mut tmp))
55 .await
56 .map_err(|_| HeaderReadError::Timeout)?;
57 match read {
58 Ok(n) if n > 0 => {
59 buf.extend_from_slice(&tmp[..n]);
60 if let Some(end) = find_header_end(&buf) {
61 let leftover = buf.split_off(end);
62 return Ok((buf, leftover, stdout));
63 }
64 }
65 Ok(_) | Err(_) => return Err(HeaderReadError::Eof),
68 }
69 }
70}
71
/// Returns the index one past the first `\r\n\r\n` in `buf`, or `None` when
/// the sequence does not occur.
fn find_header_end(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    buf.windows(TERMINATOR.len())
        .enumerate()
        .find_map(|(start, window)| (window == TERMINATOR).then_some(start + TERMINATOR.len()))
}
75
76#[derive(Debug, thiserror::Error)]
78pub enum ParseError {
79 #[error("non-utf8 header block: {0}")]
80 NonUtf8(String),
81 #[error("malformed header line: {0}")]
82 MalformedHeader(String),
83 #[error("invalid header name {0}")]
84 InvalidHeaderName(String),
85 #[error("invalid header value for {0}")]
86 InvalidHeaderValue(String),
87 #[error("invalid Status header: {0}")]
88 InvalidStatus(String),
89}
90
91pub fn parse_response_headers(block: &[u8]) -> Result<http::response::Builder, ParseError> {
105 let s = std::str::from_utf8(block).map_err(|e| ParseError::NonUtf8(e.to_string()))?;
106 let mut status: Option<StatusCode> = None;
107 let mut location_seen = false;
108 let mut builder = http::Response::builder();
109 for line in s.split("\r\n") {
110 if line.is_empty() {
111 continue;
112 }
113 let (name, value) =
114 line.split_once(':').ok_or_else(|| ParseError::MalformedHeader(line.to_owned()))?;
115 let name = name.trim();
116 let value = value.trim();
117 if name.eq_ignore_ascii_case("Status") {
118 let code = value
119 .split_whitespace()
120 .next()
121 .ok_or_else(|| ParseError::InvalidStatus(format!("empty value: {value:?}")))?;
122 let parsed: u16 =
123 code.parse().map_err(|e| ParseError::InvalidStatus(format!("parse {code:?}: {e}")))?;
124 status =
125 Some(StatusCode::from_u16(parsed).map_err(|e| ParseError::InvalidStatus(e.to_string()))?);
126 } else {
127 let header_name =
128 HeaderName::try_from(name).map_err(|_| ParseError::InvalidHeaderName(name.to_owned()))?;
129 let header_val = HeaderValue::try_from(value)
130 .map_err(|_| ParseError::InvalidHeaderValue(name.to_owned()))?;
131 if header_name.as_str().eq_ignore_ascii_case("location") {
132 location_seen = true;
133 }
134 builder = builder.header(header_name, header_val);
135 }
136 }
137 let final_status = match (status, location_seen) {
138 (Some(s), _) => s,
139 (None, true) => StatusCode::FOUND,
140 (None, false) => StatusCode::OK,
141 };
142 Ok(builder.status(final_status))
143}
144
145pub struct CgiResponseBody<R, G = ()> {
156 initial: Option<Bytes>,
157 stdout: R,
158 deadline: Instant,
159 _guard: G,
160}
161
162impl<R, G> CgiResponseBody<R, G> {
163 pub fn new(initial: Vec<u8>, stdout: R, deadline: Instant, guard: G) -> Self {
167 let initial = if initial.is_empty() { None } else { Some(Bytes::from(initial)) };
168 Self { initial, stdout, deadline, _guard: guard }
169 }
170}
171
172impl<R, G> Body for CgiResponseBody<R, G>
173where
174 R: AsyncRead + Unpin + Send,
175 G: Send + Unpin + 'static,
176{
177 type Data = Bytes;
178 type Error = io::Error;
179
180 fn poll_frame(
181 mut self: Pin<&mut Self>,
182 cx: &mut Context<'_>,
183 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
184 if let Some(b) = self.initial.take() {
185 return Poll::Ready(Some(Ok(Frame::data(b))));
186 }
187 if Instant::now() >= self.deadline {
188 return Poll::Ready(Some(Err(io::Error::other("cgi total_timeout exceeded mid-body"))));
189 }
190 let mut buf = [0u8; 4096];
191 let mut read_buf = ReadBuf::new(&mut buf);
192 match Pin::new(&mut self.stdout).poll_read(cx, &mut read_buf) {
193 Poll::Pending => Poll::Pending,
194 Poll::Ready(Ok(())) => {
195 let filled = read_buf.filled();
196 if filled.is_empty() {
197 Poll::Ready(None)
198 } else {
199 Poll::Ready(Some(Ok(Frame::data(Bytes::copy_from_slice(filled)))))
200 }
201 }
202 Poll::Ready(Err(e)) => {
203 Poll::Ready(Some(Err(io::Error::other(format!("cgi stdout read: {e}")))))
204 }
205 }
206 }
207
208 fn is_end_stream(&self) -> bool {
209 false
210 }
211
212 fn size_hint(&self) -> SizeHint {
213 SizeHint::default()
214 }
215}
216
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use http_body_util::BodyExt as _;
    use tokio::io::AsyncWriteExt as _;

    use super::*;

    /// Parses `block` and finishes the builder with an empty body.
    fn parsed(block: &[u8]) -> http::Response<()> {
        parse_response_headers(block).expect("parse").body(()).unwrap()
    }

    /// A deadline `after` from now on the Tokio clock.
    fn deadline_after(after: Duration) -> Instant {
        Instant::now() + after
    }

    #[test]
    fn parse_status_header_picks_up_code() {
        let resp = parsed(b"Status: 201 Created\r\nContent-Type: text/plain\r\n\r\n");
        assert_eq!(resp.status(), StatusCode::CREATED);
        assert_eq!(resp.headers().get("content-type").unwrap(), "text/plain");
    }

    #[test]
    fn parse_location_without_status_defaults_to_302() {
        let resp = parsed(b"Location: /elsewhere\r\n\r\n");
        assert_eq!(resp.status(), StatusCode::FOUND);
        assert_eq!(resp.headers().get("location").unwrap(), "/elsewhere");
    }

    #[test]
    fn parse_no_status_no_location_defaults_to_200() {
        let resp = parsed(b"Content-Type: text/plain\r\n\r\n");
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[test]
    fn parse_rejects_malformed_line() {
        let outcome = parse_response_headers(b"no-colon-here\r\n\r\n");
        assert!(matches!(outcome, Err(ParseError::MalformedHeader(_))));
    }

    #[tokio::test]
    async fn read_until_header_end_returns_block_and_leftover() {
        let (mut child_end, server_end) = tokio::io::duplex(64);
        tokio::spawn(async move {
            child_end.write_all(b"Status: 200 OK\r\n\r\nbody-bytes-here").await.unwrap();
        });
        let deadline = deadline_after(Duration::from_secs(2));
        let (head, leftover, _rest) =
            read_until_header_end(server_end, deadline).await.expect("ok");
        assert_eq!(head, b"Status: 200 OK\r\n\r\n");
        assert_eq!(leftover, b"body-bytes-here");
    }

    #[tokio::test]
    async fn read_until_header_end_eof_returns_err() {
        let (child_end, server_end) = tokio::io::duplex(64);
        // Closing the writer makes the first read report EOF.
        drop(child_end);
        let deadline = deadline_after(Duration::from_secs(2));
        let outcome = read_until_header_end(server_end, deadline).await;
        assert!(matches!(outcome, Err(HeaderReadError::Eof)));
    }

    #[tokio::test(start_paused = true)]
    async fn read_until_header_end_timeout_returns_err() {
        // The writer is kept alive but never writes, so the read stays pending.
        let (_child_end, server_end) = tokio::io::duplex(64);
        let deadline = deadline_after(Duration::from_millis(50));
        // Move the paused clock past the deadline before reading.
        tokio::time::advance(Duration::from_millis(60)).await;
        let outcome = read_until_header_end(server_end, deadline).await;
        assert!(matches!(outcome, Err(HeaderReadError::Timeout)));
    }

    #[tokio::test]
    async fn cgi_response_body_yields_leftover_then_streams_to_eof() {
        let (mut child_end, server_end) = tokio::io::duplex(64);
        tokio::spawn(async move {
            child_end.write_all(b"-streamed").await.unwrap();
            drop(child_end);
        });
        let mut body = CgiResponseBody::new(
            b"leftover".to_vec(),
            server_end,
            deadline_after(Duration::from_secs(2)),
            (),
        );

        let first = body.frame().await.expect("first frame").expect("data");
        assert_eq!(first.into_data().unwrap(), &b"leftover"[..]);

        // Everything after the first frame comes from the stream itself.
        let mut streamed = Vec::new();
        while let Some(frame) = body.frame().await {
            streamed.extend_from_slice(frame.unwrap().into_data().unwrap().as_ref());
        }
        assert_eq!(streamed, b"-streamed");
    }

    #[tokio::test]
    async fn cgi_response_body_keeps_guard_alive_until_drop() {
        let guard = std::sync::Arc::new(());
        // The weak handle lets us observe when the body releases the guard.
        let observer = std::sync::Arc::downgrade(&guard);
        let (_child_end, server_end) = tokio::io::duplex(64);
        let body = CgiResponseBody::new(
            Vec::new(),
            server_end,
            deadline_after(Duration::from_secs(1)),
            guard,
        );
        assert!(observer.strong_count() > 0);
        drop(body);
        assert_eq!(observer.strong_count(), 0);
    }
}