#![no_std]
use heapless::String;
use thiserror::Error;
#[derive(Error, Debug, PartialEq)]
pub enum ParseError {
#[error("query string did not contain '='")]
QueryNoEquals,
#[error("buffer was not large enough")]
BufferTooSmall,
#[error("only ascii characters are supported for decoding (<=128)")]
EncodedNonAscii,
#[error("malformed http request")]
BadRequest,
#[error("only GET and POST are supported")]
UnsupportedMethod,
#[error("only http/1.1 is supported")]
UnsupportedProtocol,
}
impl From<heapless::CapacityError> for ParseError {
fn from(_value: heapless::CapacityError) -> Self {
Self::BufferTooSmall
}
}
/// Iterator over the `key=value` pairs of a query string.
///
/// Each item is parsed into a [`QueryParam`] whose key and value have byte capacities
/// `KN` and `VN`. Created by [`RequestLine::query_params`].
#[derive(Debug, Clone)]
pub struct QueryParams<'a, const KN: usize, const VN: usize> {
    /// The not-yet-consumed remainder of the query string.
    rest: &'a str,
}
impl<'a, const KN: usize, const VN: usize> Iterator for QueryParams<'a, KN, VN> {
    type Item = Result<QueryParam<KN, VN>, ParseError>;

    /// Yields the next `&`-separated segment parsed as a [`QueryParam`].
    ///
    /// Empty segments (from `a=1&&b=2`, or a leading or trailing `&`) are skipped.
    /// A segment that fails to parse (e.g. one without `=`) yields `Some(Err(..))`.
    /// Iteration can continue past it.
    fn next(&mut self) -> Option<Self::Item> {
        // Loop instead of recursing on empty segments. Rust does not guarantee tail
        // calls, and a long run of `&` in untrusted input must not grow the stack;
        // that matters on small `no_std` targets.
        while !self.rest.is_empty() {
            let (segment, tail) = self.rest.split_once('&').unwrap_or((self.rest, ""));
            self.rest = tail;
            if !segment.is_empty() {
                return Some(segment.parse());
            }
        }
        None
    }
}
/// A single `key=value` pair from a query string.
///
/// `KN` and `VN` are the byte capacities of the key and value buffers.
#[derive(Debug, Clone, PartialEq)]
pub struct QueryParam<const KN: usize, const VN: usize> {
    /// The key: the text before the first `=`, percent-decoded when produced by `FromStr`.
    pub k: String<KN>,
    /// The value: the text after the first `=`, percent-decoded when produced by `FromStr`.
    pub v: String<VN>,
}
impl<const KN: usize, const VN: usize> QueryParam<KN, VN> {
pub fn entry(&self) -> (&String<KN>, &String<VN>) {
(&self.k, &self.v)
}
}
impl<const KN: usize, const VN: usize> core::str::FromStr for QueryParam<KN, VN> {
type Err = ParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let (k, v) = s.split_once('=').ok_or(ParseError::QueryNoEquals)?;
Ok(QueryParam {
k: unescape::<KN>(k)?,
v: unescape::<VN>(v)?,
})
}
}
/// The parsed first line of an HTTP request.
///
/// `N` is the byte capacity of the stored request target.
#[derive(Clone, Debug, PartialEq)]
pub struct RequestLine<const N: usize> {
    /// The request method.
    pub method: Method,
    /// The request target (path plus any `?query`), percent-decoded when produced by `FromStr`.
    pub target: String<N>,
    /// The protocol version.
    pub protocol: Protocol,
}
/// The HTTP request methods accepted by `RequestLine`'s `FromStr` impl.
///
/// A fieldless enum, so it is cheap to copy and has total equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Method {
    /// The `GET` method.
    GET,
    /// The `POST` method.
    POST,
}
/// The HTTP protocol versions accepted by `RequestLine`'s `FromStr` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Protocol {
    /// `HTTP/1.1`, the only version the parser accepts.
    HTTP1,
}
impl<const N: usize> RequestLine<N> {
    /// Returns an iterator over the `key=value` pairs in the query part of `target`
    /// (everything after the first `?`).
    ///
    /// Yields nothing when the target contains no `?`. `KN` and `VN` are the byte
    /// capacities of each decoded key and value.
    ///
    /// NOTE(review): when `self` came from `FromStr`, `target` has already been
    /// percent-decoded, and each key/value is decoded again by `QueryParam`'s `FromStr`.
    /// As a result, an encoded `%26` (`&`), `%3D` (`=`) or `%3F` (`?`) in the original
    /// request acts as a separator here, and `%25XX` sequences are decoded twice.
    pub fn query_params<'a, const KN: usize, const VN: usize>(&'a self) -> QueryParams<'a, KN, VN> {
        let query = match self.target.as_str().split_once('?') {
            Some((_path, query)) => query,
            None => "",
        };
        QueryParams { rest: query }
    }
}
impl<const N: usize> core::str::FromStr for RequestLine<N> {
    type Err = ParseError;

    /// Parses an HTTP request line of the form `METHOD target HTTP/1.1`.
    ///
    /// Only the first line of `s` is read, so a whole request head may be passed in;
    /// `str::lines` also drops a trailing `\r`. Tokens are separated by ASCII whitespace.
    /// The method and protocol are validated before the target is percent-decoded
    /// with [`unescape`] into an `N`-byte buffer.
    ///
    /// # Errors
    ///
    /// - [`ParseError::BadRequest`] if `s` is empty or the line does not have exactly
    ///   three tokens.
    /// - [`ParseError::UnsupportedMethod`] if the method is not `GET` or `POST`.
    /// - [`ParseError::UnsupportedProtocol`] if the version is not exactly `HTTP/1.1`.
    /// - Any error from [`unescape`] while decoding the target, e.g.
    ///   [`ParseError::BufferTooSmall`] when it does not fit in `N` bytes.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let line = s.lines().next().ok_or(ParseError::BadRequest)?;
        let mut parts = line.split_ascii_whitespace();
        let method_str = parts.next().ok_or(ParseError::BadRequest)?;
        let target_str = parts.next().ok_or(ParseError::BadRequest)?;
        let protocol_str = parts.next().ok_or(ParseError::BadRequest)?;
        // A request line has exactly three tokens (method, target, version). Anything
        // after the version means the line is malformed, not something to ignore.
        if parts.next().is_some() {
            return Err(ParseError::BadRequest);
        }

        let method = match method_str {
            "GET" => Method::GET,
            "POST" => Method::POST,
            _ => return Err(ParseError::UnsupportedMethod),
        };
        // Check the cheap fixed tokens before decoding the target, so an unsupported
        // protocol is reported as such rather than as a decoding or capacity error.
        if protocol_str != "HTTP/1.1" {
            return Err(ParseError::UnsupportedProtocol);
        }
        let target: String<N> = unescape(target_str)?;

        Ok(RequestLine {
            method,
            target,
            protocol: Protocol::HTTP1,
        })
    }
}
pub fn unescape<const N: usize>(escaped: &str) -> Result<String<N>, ParseError> {
let mut out = String::<N>::new();
let bytes = escaped.as_bytes();
let mut i = 0;
while i < bytes.len() {
match bytes[i] {
b'+' => {
out.push(' ')?;
i += 1;
}
b'%' if i + 2 < bytes.len() => {
if let (Some(hi), Some(lo)) =
(hex_char_to_dec(bytes[i + 1]), hex_char_to_dec(bytes[i + 2]))
{
let c = (hi << 4 | lo) as char;
if !c.is_ascii() {
return Err(ParseError::EncodedNonAscii);
}
out.push((hi << 4 | lo) as char)?;
i += 3;
} else {
out.push('%')?;
i += 1;
}
}
b => {
out.push(b as char)?;
i += 1;
}
}
}
Ok(out)
}
/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value in `0..=15`.
///
/// Returns `None` for any other byte, including non-ASCII bytes.
fn hex_char_to_dec(b: u8) -> Option<u8> {
    // `to_digit(16)` accepts exactly the ASCII hex digits in either case. Bytes
    // >= 0x80 map to Latin-1 chars, which it rejects. A hex digit is < 16, so the
    // narrowing cast is lossless.
    char::from(b).to_digit(16).map(|digit| digit as u8)
}
#[cfg(test)]
mod tests {
    use heapless::String;

    use crate::{ParseError, RequestLine, hex_char_to_dec, unescape};

    #[test]
    fn test_unescape() {
        let escaped = "%21%40%23%24%25%5E%26%2A%28%29123asd";
        let unescaped: String<32> = unescape(escaped).unwrap();
        assert_eq!(unescaped, "!@#$%^&*()123asd");
        let non_ascii_str = "%C3B3";
        let non_ascii = unescape::<32>(non_ascii_str);
        assert_eq!(non_ascii, Err(ParseError::EncodedNonAscii));
        let percent_near_end_str = "123abc%20987%f";
        let percent_near_end: String<32> = unescape(percent_near_end_str).unwrap();
        assert_eq!(percent_near_end, "123abc 987%f");
    }

    #[test]
    fn unescape_plus_and_literal_percent() {
        let plus: String<8> = unescape("a+b").unwrap();
        assert_eq!(plus, "a b");
        // A '%' that does not start a valid two-digit hex escape is kept verbatim.
        let bad_hex: String<8> = unescape("%zz%4").unwrap();
        assert_eq!(bad_hex, "%zz%4");
    }

    #[test]
    fn unescape_reports_overflow() {
        assert_eq!(unescape::<2>("abc"), Err(ParseError::BufferTooSmall));
        assert_eq!(unescape::<2>("%41%42%43"), Err(ParseError::BufferTooSmall));
    }

    #[test]
    fn test_hex_to_dec() {
        assert_eq!(hex_char_to_dec(b'F'), Some(15));
        assert_eq!(hex_char_to_dec(b'0'), Some(0));
        assert_eq!(hex_char_to_dec(b'A'), Some(10));
        assert_eq!(hex_char_to_dec(b'H'), None);
        assert_eq!(hex_char_to_dec(0x0), None);
    }

    #[test]
    fn request_line_get() {
        let line = "GET /submit HTTP/1.1";
        let parsed = line.parse::<RequestLine<32>>().unwrap();
        let mut target: String<32> = String::new();
        target.push_str("/submit").unwrap();
        let expected = RequestLine {
            method: crate::Method::GET,
            target,
            protocol: crate::Protocol::HTTP1,
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn request_line_post() {
        let line = "POST / HTTP/1.1";
        let parsed = line.parse::<RequestLine<32>>().unwrap();
        let mut target: String<32> = String::new();
        target.push_str("/").unwrap();
        let expected = RequestLine {
            method: crate::Method::POST,
            target,
            protocol: crate::Protocol::HTTP1,
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn request_line_with_params() {
        let line = "GET /submit?name=http%20lite HTTP/1.1";
        let parsed = line.parse::<RequestLine<32>>().unwrap();
        let mut target: String<32> = String::new();
        target.push_str("/submit?name=http lite").unwrap();
        let expected = RequestLine {
            method: crate::Method::GET,
            target,
            protocol: crate::Protocol::HTTP1,
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn request_line_errors() {
        assert_eq!("".parse::<RequestLine<32>>(), Err(ParseError::BadRequest));
        assert_eq!("GET /".parse::<RequestLine<32>>(), Err(ParseError::BadRequest));
        assert_eq!(
            "PUT / HTTP/1.1".parse::<RequestLine<32>>(),
            Err(ParseError::UnsupportedMethod)
        );
        assert_eq!(
            "GET / HTTP/1.0".parse::<RequestLine<32>>(),
            Err(ParseError::UnsupportedProtocol)
        );
        assert_eq!(
            "GET /a-long-path HTTP/1.1".parse::<RequestLine<4>>(),
            Err(ParseError::BufferTooSmall)
        );
    }

    #[test]
    fn request_line_reads_only_first_line() {
        let parsed: RequestLine<32> = "GET /x HTTP/1.1\r\nHost: example\r\n\r\n".parse().unwrap();
        assert_eq!(parsed.method, crate::Method::GET);
        assert_eq!(parsed.target, "/x");
        assert_eq!(parsed.protocol, crate::Protocol::HTTP1);
    }

    #[test]
    fn iterate_query_params() {
        let line: RequestLine<64> = "GET /search?q=hi&lang=en HTTP/1.1".parse().unwrap();
        let mut params = line.query_params::<16, 16>();
        let first = params.next().unwrap().unwrap();
        assert_eq!(first.k.as_str(), "q");
        assert_eq!(first.v.as_str(), "hi");
        let second = params.next().unwrap().unwrap();
        assert_eq!(second.k.as_str(), "lang");
        assert_eq!(second.v.as_str(), "en");
        assert!(params.next().is_none());
    }

    #[test]
    fn query_params_skip_empty_segments() {
        let line: RequestLine<64> = "GET /s?&a=1&&b=2& HTTP/1.1".parse().unwrap();
        let mut params = line.query_params::<8, 8>();
        let first = params.next().unwrap().unwrap();
        assert_eq!(first.k.as_str(), "a");
        assert_eq!(first.v.as_str(), "1");
        let second = params.next().unwrap().unwrap();
        assert_eq!(second.k.as_str(), "b");
        assert_eq!(second.v.as_str(), "2");
        assert!(params.next().is_none());
    }

    #[test]
    fn query_params_edge_cases() {
        let no_query: RequestLine<32> = "GET /s HTTP/1.1".parse().unwrap();
        assert!(no_query.query_params::<8, 8>().next().is_none());

        let flag: RequestLine<32> = "GET /s?flag HTTP/1.1".parse().unwrap();
        let mut params = flag.query_params::<8, 8>();
        assert!(matches!(params.next(), Some(Err(ParseError::QueryNoEquals))));
        assert!(params.next().is_none());

        // Only the first '=' separates key from value.
        let eq_in_value: RequestLine<32> = "GET /s?k=a=b HTTP/1.1".parse().unwrap();
        let mut params = eq_in_value.query_params::<8, 8>();
        let pair = params.next().unwrap().unwrap();
        assert_eq!(pair.k.as_str(), "k");
        assert_eq!(pair.v.as_str(), "a=b");
    }
}