xocomil 0.3.0

A lightweight, zero-allocation HTTP/1.1 request parser and response writer
Documentation
//! Percent-decoding for HTTP request targets and form-encoded values.
//!
//! Two modes:
//!
//! - [`Mode::Path`] — RFC 3986 percent-decoding. `+` is left as-is.
//! - [`Mode::Form`] — `application/x-www-form-urlencoded` decoding.
//!   `+` decodes to space, in addition to `%XX` escapes.
//!
//! [`decode`] is zero-copy when the input contains no escapes (and, for
//! [`Mode::Form`], no `+`): the input slice is returned as-is and `out`
//! is left untouched. Otherwise, the decoded bytes are written into
//! `out` and a sub-slice of `out` is returned.
//!
//! # Example
//!
//! ```
//! use xocomil::pct::{decode, Mode};
//!
//! // Already clean — no allocation, `out` untouched.
//! let mut buf = [0u8; 32];
//! assert_eq!(decode(b"hello", Mode::Path, &mut buf).unwrap(), b"hello");
//!
//! // Needs decoding — written into `out`.
//! let mut buf = [0u8; 32];
//! assert_eq!(decode(b"foo%20bar", Mode::Path, &mut buf).unwrap(), b"foo bar");
//!
//! // Form mode also decodes `+`.
//! let mut buf = [0u8; 32];
//! assert_eq!(decode(b"a+b", Mode::Form, &mut buf).unwrap(), b"a b");
//! ```

use crate::error::PctErrorKind;

/// Selects which percent-decoding flavor to apply.
///
/// Both modes decode `%XX` escapes; they differ only in how a literal
/// `+` byte is treated.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Mode {
    /// RFC 3986 percent-decoding. `+` passes through unchanged.
    /// Use for request paths and individual path segments.
    Path,
    /// `application/x-www-form-urlencoded` decoding. `+` decodes to
    /// space; `%XX` escapes are also decoded. Use for query string
    /// values and form bodies.
    Form,
}

/// Percent-decodes `input`, borrowing it unchanged whenever possible.
///
/// If `input` has no `%` byte (and, in [`Mode::Form`], no `+` byte)
/// there is nothing to decode: `input` itself is returned and `out` is
/// never written to.
///
/// In every other case the work is delegated to [`decode_into`], and the
/// returned slice is the leading part of `out` that was filled in.
///
/// # Errors
///
/// Returns [`PctErrorKind::InvalidEscape`] if a `%` is not followed by
/// two hex digits, or [`PctErrorKind::BufferTooSmall`] if `out` cannot
/// hold the decoded output.
pub fn decode<'a>(
    input: &'a [u8],
    mode: Mode,
    out: &'a mut [u8],
) -> Result<&'a [u8], PctErrorKind> {
    // Only form decoding gives `+` a special meaning.
    let plus_is_special = matches!(mode, Mode::Form);
    let is_clean = input
        .iter()
        .all(|&byte| byte != b'%' && !(plus_is_special && byte == b'+'));

    if is_clean {
        // Zero-copy fast path: hand the caller's slice straight back.
        Ok(input)
    } else {
        let len = decode_into(input, mode, out)?;
        Ok(&out[..len])
    }
}

/// Computes how many bytes the decoded form of `input` occupies.
///
/// Every valid `%XX` escape counts as one byte and every other input
/// byte counts as one byte. Use this to size the buffer handed to
/// [`decode`] or [`decode_into`]. For inputs without escapes the result
/// equals `input.len()`.
///
/// # Errors
///
/// Returns [`PctErrorKind::InvalidEscape`] if a `%` is not followed by
/// two hex digits.
pub fn decoded_len(input: &[u8], mode: Mode) -> Result<usize, PctErrorKind> {
    // A `+` yields exactly one output byte in either mode (itself or a
    // space), so the mode has no influence on the length.
    let _ = mode;

    let mut pending = input;
    let mut total = 0;
    while let Some((&head, tail)) = pending.split_first() {
        pending = if head == b'%' {
            // An escape needs two hex digits right after the `%`.
            match tail {
                [hi, lo, after @ ..] if hex_pair(*hi, *lo).is_some() => after,
                _ => return Err(PctErrorKind::InvalidEscape),
            }
        } else {
            tail
        };
        total += 1;
    }
    Ok(total)
}

/// Decodes `input` into the front of `out` and reports how many bytes
/// were stored.
///
/// This never borrows from `input`: even escape-free input is copied
/// byte by byte into `out`. Call [`decode`] instead when the zero-copy
/// fast path is wanted.
///
/// # Errors
///
/// Returns [`PctErrorKind::InvalidEscape`] if a `%` is not followed by
/// two hex digits, or [`PctErrorKind::BufferTooSmall`] if `out` has no
/// room left for the next decoded byte. Bytes decoded before the error
/// was detected have already been written to `out`.
pub fn decode_into(input: &[u8], mode: Mode, out: &mut [u8]) -> Result<usize, PctErrorKind> {
    let plus_is_space = matches!(mode, Mode::Form);
    let mut pending = input;
    let mut written = 0;

    while let Some((&head, tail)) = pending.split_first() {
        let byte = match head {
            b'%' => match tail {
                [hi, lo, after @ ..] => {
                    pending = after;
                    hex_pair(*hi, *lo).ok_or(PctErrorKind::InvalidEscape)?
                }
                // Fewer than two bytes remain after the `%`.
                _ => return Err(PctErrorKind::InvalidEscape),
            },
            b'+' if plus_is_space => {
                pending = tail;
                b' '
            }
            literal => {
                pending = tail;
                literal
            }
        };

        // The escape is validated before the capacity check, so a bad
        // escape is reported even when `out` is already full.
        let slot = out.get_mut(written).ok_or(PctErrorKind::BufferTooSmall)?;
        *slot = byte;
        written += 1;
    }
    Ok(written)
}

/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value in
/// `0..=15`, or returns `None` for any other byte.
#[inline]
const fn hex_digit(b: u8) -> Option<u8> {
    // Fold ASCII upper case to lower case so one letter range covers both.
    let folded = b.to_ascii_lowercase();
    if folded.is_ascii_digit() {
        Some(folded - b'0')
    } else if matches!(folded, b'a'..=b'f') {
        Some(folded - b'a' + 10)
    } else {
        None
    }
}

/// Combines two ASCII hex digits into one byte, with `hi` supplying the
/// upper four bits and `lo` the lower four. Returns `None` if either
/// byte is not a hex digit.
#[inline]
const fn hex_pair(hi: u8, lo: u8) -> Option<u8> {
    // `?` is unavailable in `const fn`, hence the explicit early returns.
    let upper = match hex_digit(hi) {
        Some(nibble) => nibble,
        None => return None,
    };
    let lower = match hex_digit(lo) {
        Some(nibble) => nibble,
        None => return None,
    };
    Some((upper << 4) | lower)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Runs [`decode`] against a zeroed `CAP`-byte scratch buffer and
    /// asserts that it succeeds with exactly `expected`.
    fn assert_decodes<const CAP: usize>(input: &[u8], mode: Mode, expected: &[u8]) {
        let mut scratch = [0u8; CAP];
        assert_eq!(decode(input, mode, &mut scratch), Ok(expected));
    }

    /// Runs [`decode`] against a zeroed `CAP`-byte scratch buffer and
    /// asserts that it fails with exactly `expected`.
    fn assert_fails<const CAP: usize>(input: &[u8], mode: Mode, expected: PctErrorKind) {
        let mut scratch = [0u8; CAP];
        assert_eq!(decode(input, mode, &mut scratch), Err(expected));
    }

    #[test]
    fn path_no_escapes_is_zero_copy() {
        let source: &[u8] = b"/users/foo";
        let mut scratch = [0xAAu8; 16];
        let decoded = decode(source, Mode::Path, &mut scratch).unwrap();
        assert_eq!(decoded, source);
        // Same address means the input was borrowed, not copied.
        assert_eq!(decoded.as_ptr(), source.as_ptr());
        // The sentinel fill must survive: the buffer was never written.
        assert_eq!(scratch, [0xAAu8; 16]);
    }

    #[test]
    fn path_decodes_space() {
        assert_decodes::<16>(b"foo%20bar", Mode::Path, b"foo bar");
    }

    #[test]
    fn path_leaves_plus_alone() {
        assert_decodes::<16>(b"a+b%20c", Mode::Path, b"a+b c");
    }

    #[test]
    fn form_decodes_plus_to_space() {
        assert_decodes::<16>(b"a+b", Mode::Form, b"a b");
    }

    #[test]
    fn form_decodes_plus_and_percent() {
        assert_decodes::<16>(b"hello+world%21", Mode::Form, b"hello world!");
    }

    #[test]
    fn form_no_plus_no_percent_is_zero_copy() {
        let source: &[u8] = b"plain";
        let mut scratch = [0xAAu8; 16];
        let decoded = decode(source, Mode::Form, &mut scratch).unwrap();
        assert_eq!(decoded.as_ptr(), source.as_ptr());
    }

    #[test]
    fn upper_and_lower_hex() {
        for escape in [b"%2f", b"%2F"] {
            assert_decodes::<16>(escape, Mode::Path, b"/");
        }
    }

    #[test]
    fn truncated_escape_at_end() {
        // One hex digit missing, then both missing.
        assert_fails::<16>(b"foo%2", Mode::Path, PctErrorKind::InvalidEscape);
        assert_fails::<16>(b"foo%", Mode::Path, PctErrorKind::InvalidEscape);
    }

    #[test]
    fn non_hex_escape() {
        assert_fails::<16>(b"foo%zz", Mode::Path, PctErrorKind::InvalidEscape);
    }

    #[test]
    fn buffer_too_small() {
        // Seven decoded bytes cannot fit in three.
        assert_fails::<3>(b"foo%20bar", Mode::Path, PctErrorKind::BufferTooSmall);
    }

    #[test]
    fn empty_input() {
        for mode in [Mode::Path, Mode::Form] {
            assert_decodes::<4>(b"", mode, b"");
        }
    }

    #[test]
    fn null_byte_via_escape() {
        assert_decodes::<4>(b"%00", Mode::Path, b"\0");
    }

    #[test]
    fn decoded_len_matches_decode() {
        let cases = [
            (&b"foo%20bar"[..], Mode::Path),
            (b"a+b", Mode::Form),
            (b"plain", Mode::Path),
            (b"%2F%2Fetc", Mode::Path),
            (b"", Mode::Path),
        ];
        for (input, mode) in cases {
            let predicted = decoded_len(input, mode).unwrap();
            let mut scratch = [0u8; 32];
            let produced = decode(input, mode, &mut scratch).unwrap().len();
            assert_eq!(produced, predicted, "input={input:?} mode={mode:?}");
        }
    }

    #[test]
    fn decoded_len_invalid_escape() {
        let outcome = decoded_len(b"foo%2", Mode::Path);
        assert_eq!(outcome, Err(PctErrorKind::InvalidEscape));
    }

    #[test]
    fn decode_into_writes_even_when_clean() {
        let mut scratch = [0u8; 8];
        assert_eq!(decode_into(b"abc", Mode::Path, &mut scratch), Ok(3));
        assert_eq!(&scratch[..3], b"abc");
    }
}