async_scgi/
lib.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at http://mozilla.org/MPL/2.0/.
4//
5// This Source Code Form is "Incompatible With Secondary Licenses", as
6// defined by the Mozilla Public License, v. 2.0.
7
8//! Async SCGI Client & Server
9//!
10//! This library will work with any async runtime that uses the [`futures-io`](https://crates.io/crates/futures-io)
11//! library I/O traits.
12//!
13//! This crate provides two main tools:
14//! - The [`ScgiRequest`] type to read & write SCGI requests.
15//! - The [`read_request`] function to read an SCGI request from a socket.
16//!
17//! ## Client Example
18//!
19//! ```no_run
20//! # use std::str::from_utf8;
21//! # use futures_lite::{AsyncReadExt, AsyncWriteExt};
22//! # use smol::net::TcpStream;
23//! use async_scgi::{ScgiHeaders, ScgiRequest};
24//!
25//! # fn main() -> anyhow::Result<()> {
26//! # smol::block_on(async {
27//! let mut stream = TcpStream::connect("127.0.0.1:12345").await?;
28//! let mut headers = ScgiHeaders::new();
29//! headers.insert("PATH_INFO".to_owned(), "/".to_owned());
30//! headers.insert("SERVER_NAME".to_owned(), "example.com".to_owned());
31//! let body = b"Hello world!";
32//! let req = ScgiRequest {
33//!     headers,
34//!     body: body.to_vec(),
35//! };
36//! stream.write_all(&req.encode()).await?;
37//! let mut resp = vec![];
38//! stream.read_to_end(&mut resp).await?;
39//! let resp_str = from_utf8(&resp)?;
40//! println!("{}", resp_str);
41//! # Ok(())
42//! # })
43//! # }
44//! ```
45//!
46//! ## Server Example
47//!
48//! ```no_run
49//! # use futures_lite::{AsyncWriteExt, StreamExt};
50//! # use smol::io::BufReader;
51//! # use smol::net::TcpListener;
52//! # use std::str::from_utf8;
53//! #
54//! # fn main() -> anyhow::Result<()> {
55//! # smol::block_on(async {
56//! let listener = TcpListener::bind("127.0.0.1:12345").await?;
57//! let mut incoming = listener.incoming();
58//! while let Some(stream) = incoming.next().await {
59//!     let mut stream = BufReader::new(stream?);
60//!     let req = async_scgi::read_request(&mut stream).await?;
61//!     println!("Headers: {:?}", req.headers);
62//!     println!("Body: {}", from_utf8(&req.body).unwrap());
63//!     stream.write_all(b"Hello Client!").await?;
64//! }
65//! # Ok(())
66//! # })
67//! # }
68//! ```
69
70#[cfg(test)]
71use std::collections::BTreeMap;
72use std::collections::HashMap;
73use std::str::{from_utf8, Utf8Error};
74
75use futures_lite::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt};
76use headers::{encode_headers, ScgiHeaderParseError};
77use thiserror::Error;
78
79/// An error that occurred while reading an SCGI request.
80#[derive(Error, Debug)]
81pub enum ScgiReadError {
82    /// Length can't be decoded to an integer.
83    #[error("Length can't be decoded to an integer")]
84    BadLength,
85    /// The length or the headers are not in UTF-8.
86    #[error("The length or the headers are not in UTF-8")]
87    Utf8(#[from] Utf8Error),
88    /// Netstring sanity checks fail.
89    #[error("Netstring sanity checks fail")]
90    BadNetstring,
91    /// Error parsing SCGI headers.
92    #[error("Error parsing SCGI headers")]
93    BadHeaders(#[from] ScgiHeaderParseError),
94    /// IO Error.
95    #[error("IO Error")]
96    IO(#[from] std::io::Error),
97}
98
99/// An ScgiRequest header map.
100#[cfg(not(test))]
101pub type ScgiHeaders = HashMap<String, String>;
102#[cfg(test)]
103pub type ScgiHeaders = BTreeMap<String, String>;
104
105/// An SCGI request.
106///
107/// The `SCGI` and `CONTENT_LENGTH` length headers are added automatically when
108/// [`ScgiRequest::encode`] is called and removed when requests are read.
109#[derive(Debug, Default, PartialEq, Eq)]
110pub struct ScgiRequest {
111    /// The request header name, value pairs.
112    pub headers: ScgiHeaders,
113    /// The request body.
114    pub body: Vec<u8>,
115}
116
117impl ScgiRequest {
118    /// Create an empty ScgiRequest.
119    pub fn new() -> Self {
120        Self::default()
121    }
122
123    /// Create an ScgiRequest with a set of headers.
124    pub fn from_headers(headers: ScgiHeaders) -> Self {
125        Self {
126            headers,
127            body: Vec::new(),
128        }
129    }
130
131    /// Encode an ScgiRequest to be sent over the wire.
132    pub fn encode(&self) -> Vec<u8> {
133        let headers = encode_headers(&self.headers, self.body.len());
134        let mut buf = Vec::with_capacity(headers.len() + 6);
135        buf.extend(headers.len().to_string().as_bytes());
136        buf.push(b':');
137        buf.extend(headers);
138        buf.push(b',');
139        buf.extend(&self.body);
140        buf
141    }
142}
143
144/// Read an SCGI request.
145pub async fn read_request<S: AsyncBufRead + Unpin>(
146    stream: &mut S,
147) -> Result<ScgiRequest, ScgiReadError> {
148    let mut len_part = Vec::with_capacity(10);
149    let read = stream.read_until(b':', &mut len_part).await?;
150    if len_part[read - 1] != b':' {
151        return Err(ScgiReadError::BadNetstring);
152    }
153    let length = from_utf8(&len_part[..read - 1])?
154        .parse::<usize>()
155        .map_err(|_| ScgiReadError::BadLength)?;
156    let mut headers = vec![0; length];
157    stream.read_exact(&mut headers).await?;
158    let mut end_delim = 0;
159    stream
160        .read_exact(std::slice::from_mut(&mut end_delim))
161        .await?;
162    if end_delim != b',' {
163        return Err(ScgiReadError::BadNetstring);
164    }
165    let (headers, content_length) = headers::header_string_map(&headers)?;
166    let mut body = vec![0; content_length];
167    stream.read_exact(&mut body).await?;
168    Ok(ScgiRequest { headers, body })
169}
170
171/// Functions for working with SCGI headers directly.
172pub mod headers {
173    use super::*;
174    use memchr::memchr;
175
176    /// An error that occurred while parsing SCGI headers.
177    #[derive(Error, Debug)]
178    pub enum ScgiHeaderParseError {
179        /// The length or the headers are not in UTF-8.
180        #[error("The length or the headers are not in UTF-8")]
181        Utf8(#[from] Utf8Error),
182        /// Error parsing the null-terminated headers.
183        #[error("Error parsing the null-terminated headers")]
184        BadHeaderVals,
185        /// CONTENT_LENGTH can't be decoded to an integer.
186        #[error("CONTENT_LENGTH can't be decoded to an integer")]
187        BadLength,
188        /// CONTENT_LENGTH header was missing.
189        #[error("CONTENT_LENGTH header was missing")]
190        NoLength,
191    }
192
193    /// Encode headers to be sent over the wire.
194    ///
195    /// The `SCGI` and `CONTENT_LENGTH` headers are added automatically.
196    pub fn encode_headers(headers: &ScgiHeaders, content_length: usize) -> Vec<u8> {
197        let mut buf = Vec::new();
198
199        // Add required SCGI version header
200        buf.extend(b"SCGI");
201        buf.push(0);
202        buf.extend(b"1");
203        buf.push(0);
204
205        // Add required CONTENT_LENGTH version header
206        buf.extend(b"CONTENT_LENGTH");
207        buf.push(0);
208        buf.extend(content_length.to_string().as_bytes());
209        buf.push(0);
210
211        for (name, value) in headers.iter() {
212            buf.extend(name.as_bytes());
213            buf.push(0);
214            buf.extend(value.as_bytes());
215            buf.push(0);
216        }
217        buf
218    }
219
220    /// Parse the headers, invoking the `header` closure for every header parsed.
221    pub fn parse_headers<'h>(
222        raw_headers: &'h [u8],
223        mut headers_fn: impl FnMut(&'h str, &'h str) -> Result<(), ScgiHeaderParseError>,
224    ) -> Result<(), ScgiHeaderParseError> {
225        let mut pos = 0;
226        while pos < raw_headers.len() {
227            let null = memchr(0, &raw_headers[pos..]).ok_or(ScgiHeaderParseError::BadHeaderVals)?;
228            let header_name = from_utf8(&raw_headers[pos..pos + null])?;
229            pos += null + 1;
230            let null = memchr(0, &raw_headers[pos..]).ok_or(ScgiHeaderParseError::BadHeaderVals)?;
231            let header_value = from_utf8(&raw_headers[pos..pos + null])?;
232            headers_fn(header_name, header_value)?;
233            pos += null + 1;
234        }
235        Ok(())
236    }
237
238    /// Parse the headers and pack them as strings into a map.
239    ///
240    /// The value of the `CONTENT_LENGTH` header is returned in adition to the
241    /// header map. The `SCGI` and `CONTENT_LENGTH` headers are not included
242    /// in the header map.
243    pub fn header_string_map(
244        raw_headers: &[u8],
245    ) -> Result<(ScgiHeaders, usize), ScgiHeaderParseError> {
246        let mut headers_map = ScgiHeaders::new();
247        let mut content_length = None;
248        parse_headers(raw_headers, |name, value| {
249            // Ignore SCGI version header
250            if name != "SCGI" {
251                if name == "CONTENT_LENGTH" {
252                    content_length =
253                        Some(value.parse().map_err(|_| ScgiHeaderParseError::BadLength)?);
254                } else {
255                    headers_map.insert(name.to_owned(), value.to_owned());
256                }
257            }
258            Ok(())
259        })?;
260        Ok((
261            headers_map,
262            content_length.ok_or(ScgiHeaderParseError::NoLength)?,
263        ))
264    }
265
266    /// Parse the headers and pack them as slices into a map.
267    ///
268    /// The value of the `CONTENT_LENGTH` header is returned in adition to the
269    /// header map. The `SCGI` and `CONTENT_LENGTH` headers are not included
270    /// in the header map.
271    pub fn header_str_map(
272        raw_headers: &[u8],
273    ) -> Result<(HashMap<&str, &str>, usize), ScgiHeaderParseError> {
274        let mut headers_map = HashMap::new();
275        let mut content_length = None;
276        parse_headers(raw_headers, |name, value| {
277            // Ignore SCGI version header
278            if name != "SCGI" {
279                if name == "CONTENT_LENGTH" {
280                    content_length =
281                        Some(value.parse().map_err(|_| ScgiHeaderParseError::BadLength)?);
282                } else {
283                    headers_map.insert(name, value);
284                }
285            }
286            Ok(())
287        })?;
288        Ok((
289            headers_map,
290            content_length.ok_or(ScgiHeaderParseError::NoLength)?,
291        ))
292    }
293}
294
#[cfg(test)]
mod tests {
    use futures_lite::future::block_on;
    use futures_lite::io::BufReader;

    use super::*;

    /// Wire bytes that `sample_request().encode()` is expected to produce.
    const TEST_DATA: &[u8] = include_bytes!("../test_data/dump");

    /// Build the request that `TEST_DATA` encodes.
    fn sample_request() -> ScgiRequest {
        let mut headers = ScgiHeaders::new();
        headers.insert("hello".to_owned(), "world".to_owned());
        headers.insert("foo".to_owned(), "bar".to_owned());
        ScgiRequest {
            headers,
            body: "here is some data".into(),
        }
    }

    #[test]
    fn encode_scgi_request() {
        assert_eq!(TEST_DATA, &sample_request().encode());
    }

    #[test]
    fn read_scgi_request() {
        let mut stream = BufReader::new(TEST_DATA);
        let decoded = block_on(read_request(&mut stream)).unwrap();
        assert_eq!(sample_request(), decoded);
    }
}