use std::{fmt::Display, io::Write};
use thiserror::Error;
use crate::{stringesc::StringLosslessExt, stringutil::CharClassExt};
use super::{FieldName, FieldPair, FieldValue, HeaderMap};
#[derive(Error, Debug)]
pub enum FormatError {
#[error(transparent)]
Data(#[from] FormatDataError),
#[error(transparent)]
Io(#[from] std::io::Error),
}
/// Describes a header field that could not be formatted.
///
/// At most one of `name` / `value` is set by the header formatter's
/// validators, identifying which part of the field was rejected.
#[derive(Debug)]
pub struct FormatDataError {
    // Zero-based index of the offending field as produced by the formatter's
    // `enumerate()`; also 0 when built from `std::fmt::Error`.
    line: u64,
    // The rejected field name, when the name failed validation.
    name: Option<FieldName>,
    // The rejected field value, when the value failed validation.
    value: Option<FieldValue>,
}

impl FormatDataError {
    /// Returns the recorded index of the offending field.
    pub fn line(&self) -> u64 {
        let Self { line, .. } = self;
        *line
    }

    /// Returns the rejected field name, if the name was the problem.
    pub fn name(&self) -> Option<&FieldName> {
        Option::as_ref(&self.name)
    }

    /// Returns the rejected field value, if the value was the problem.
    pub fn value(&self) -> Option<&FieldValue> {
        Option::as_ref(&self.value)
    }
}
impl std::error::Error for FormatDataError {}

impl Display for FormatDataError {
    /// Renders a single-line description containing the recorded `line`, plus
    /// the offending name and/or value when present.
    ///
    /// Name and value text is escaped with `str::escape_debug`. Value validation
    /// fails precisely when the value contains CR or LF, and printing that text
    /// verbatim would split the message across lines (and allow log injection).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "invalid name-value field line {}", self.line)?;

        if let Some(name) = &self.name {
            write!(f, " name '{}'", name.text.escape_debug())?;
        }

        if let Some(value) = &self.value {
            write!(f, " value '{}'", value.text.escape_debug())?;
        }

        Ok(())
    }
}
impl From<std::fmt::Error> for FormatDataError {
fn from(_: std::fmt::Error) -> Self {
FormatDataError {
line: 0,
name: None,
value: None,
}
}
}
/// Serializes header fields as `name: value\r\n` lines.
///
/// Configure it with the `set_*` methods; every option defaults to `false`.
#[derive(Debug, Clone)]
pub struct HeaderFormatter {
    // Encode field text with `to_utf8_lossless()` instead of plain UTF-8.
    lossless_scheme: bool,
    // Write a field's raw bytes verbatim when they are present.
    use_raw: bool,
    // Skip name/value validation entirely.
    disable_validation: bool,
}
impl HeaderFormatter {
    /// Creates a formatter with every option off: text is written as plain
    /// UTF-8, raw bytes are ignored, and validation is enabled.
    pub fn new() -> Self {
        Self {
            lossless_scheme: false,
            use_raw: false,
            disable_validation: false,
        }
    }

    /// When `true`, field text (used whenever raw bytes are not selected) is
    /// encoded with `to_utf8_lossless()` instead of being written as plain UTF-8.
    pub fn set_lossless_scheme(&mut self, value: bool) {
        self.lossless_scheme = value;
    }

    /// When `true`, a field's raw name/value bytes are written verbatim when
    /// present.
    ///
    /// If the value has raw bytes, the separator becomes a bare `:` with no
    /// space, presumably because the raw value carries its own leading
    /// whitespace (e.g. `\tv2`).
    pub fn set_use_raw(&mut self, value: bool) {
        self.use_raw = value;
    }

    /// When `true`, names and values are written without being validated.
    pub fn set_disable_validation(&mut self, value: bool) {
        self.disable_validation = value;
    }

    /// Writes every field of `header` to `dest` as `name: value\r\n` lines and
    /// returns the total number of bytes written.
    ///
    /// # Errors
    ///
    /// * [`FormatError::Data`] if validation is enabled and a name is empty or
    ///   contains a non-token byte, or a value contains CR or LF. The error's
    ///   `line` is the zero-based index of the offending field. Fields before
    ///   it have already been written to `dest`; no part of the offending field
    ///   is written.
    /// * [`FormatError::Io`] if writing to `dest` fails.
    pub fn format_header<W: Write>(
        &self,
        header: &HeaderMap,
        mut dest: W,
    ) -> Result<usize, FormatError> {
        let mut num_bytes = 0;
        // Two scratch buffers: in lossless mode both the name and the value
        // slices borrow from a buffer, and both must stay alive until the
        // field has been validated and written.
        let mut name_temp = Vec::new();
        let mut value_temp = Vec::new();

        for (index, pair) in header.iter().enumerate() {
            // usize -> u64 is lossless on current targets.
            // NOTE(review): this index is zero-based, and `From<fmt::Error>`
            // also reports 0, so 0 is ambiguous — consider 1-based lines.
            let line = index as u64;
            let name_bytes = self.get_name_bytes(pair, &mut name_temp);
            let value_bytes = self.get_value_bytes(pair, &mut value_temp);

            // Validate the whole field before writing any of it, so that a
            // rejected value cannot leave a dangling `name: ` fragment in `dest`.
            self.validate_name(pair, name_bytes, line)?;
            self.validate_value(pair, value_bytes, line)?;

            let separator: &[u8] = if self.use_raw && pair.value.raw.is_some() {
                b":"
            } else {
                b": "
            };

            dest.write_all(name_bytes)?;
            dest.write_all(separator)?;
            dest.write_all(value_bytes)?;
            dest.write_all(b"\r\n")?;
            num_bytes += name_bytes.len() + separator.len() + value_bytes.len() + 2;
        }

        Ok(num_bytes)
    }

    /// Returns the bytes to emit for the field name.
    ///
    /// Uses the raw bytes when `use_raw` is set and they exist. Otherwise it
    /// uses the text, either lossless-encoded into `temp` or borrowed as UTF-8.
    fn get_name_bytes<'a>(&self, pair: &'a FieldPair, temp: &'a mut Vec<u8>) -> &'a [u8] {
        match pair.name.raw.as_ref() {
            Some(raw) if self.use_raw => raw.as_slice(),
            _ => {
                if self.lossless_scheme {
                    *temp = pair.name.text.to_utf8_lossless();
                    temp
                } else {
                    pair.name.text.as_bytes()
                }
            }
        }
    }

    /// Returns the bytes to emit for the field value, following the same
    /// raw / lossless / UTF-8 selection as `get_name_bytes`.
    fn get_value_bytes<'a>(&self, pair: &'a FieldPair, temp: &'a mut Vec<u8>) -> &'a [u8] {
        match pair.value.raw.as_ref() {
            Some(raw) if self.use_raw => raw.as_slice(),
            _ => {
                if self.lossless_scheme {
                    *temp = pair.value.text.to_utf8_lossless();
                    temp
                } else {
                    pair.value.text.as_bytes()
                }
            }
        }
    }

    /// Checks that `name_bytes` is non-empty and consists only of bytes
    /// accepted by `is_token`. Always succeeds when validation is disabled.
    ///
    /// # Errors
    ///
    /// Returns [`FormatError::Data`] carrying `line` and a copy of the name.
    fn validate_name(
        &self,
        pair: &FieldPair,
        name_bytes: &[u8],
        line: u64,
    ) -> Result<(), FormatError> {
        if self.disable_validation {
            return Ok(());
        }
        // A field name is `1*tchar`. An empty name would produce a line that
        // starts with `:`, which no longer reads back as a name-value field.
        if name_bytes.is_empty() || !name_bytes.iter().all(|c| c.is_token()) {
            return Err(FormatDataError {
                line,
                name: Some(pair.name.clone()),
                value: None,
            }
            .into());
        }
        Ok(())
    }

    /// Checks that `value_bytes` contains no CR or LF, which would terminate
    /// the line early. Always succeeds when validation is disabled.
    ///
    /// # Errors
    ///
    /// Returns [`FormatError::Data`] carrying `line` and a copy of the value.
    fn validate_value(
        &self,
        pair: &FieldPair,
        value_bytes: &[u8],
        line: u64,
    ) -> Result<(), FormatError> {
        if !self.disable_validation && value_bytes.iter().any(|&c| c == b'\r' || c == b'\n') {
            return Err(FormatDataError {
                line,
                name: None,
                value: Some(pair.value.clone()),
            }
            .into());
        }
        Ok(())
    }
}
impl Default for HeaderFormatter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Formats `map` with `formatter`, panicking on failure, and returns the
    /// produced bytes.
    fn render(formatter: &HeaderFormatter, map: &HeaderMap) -> Vec<u8> {
        let mut out: Vec<u8> = Vec::new();
        formatter.format_header(map, &mut out).unwrap();
        out
    }

    #[test]
    fn test_format() {
        let mut map = HeaderMap::new();
        map.insert("k1", "v1");

        assert_eq!(render(&HeaderFormatter::new(), &map), b"k1: v1\r\n");
    }

    #[test]
    fn test_format_lossless() {
        let mut map = HeaderMap::new();
        map.insert("k1", "v1\u{FFFD}\u{1055FF}");

        let mut formatter = HeaderFormatter::new();
        formatter.set_lossless_scheme(true);

        assert_eq!(render(&formatter, &map), b"k1: v1\xff\r\n");
    }

    #[test]
    fn test_format_raw() {
        let mut map = HeaderMap::new();
        map.insert("k1", "v1");
        map.insert(
            FieldName::new("k2".to_string(), Some(b"K2".to_vec())),
            FieldValue::new("v2".to_string(), Some(b"\tv2".to_vec())),
        );

        let mut formatter = HeaderFormatter::new();
        formatter.set_use_raw(true);

        assert_eq!(render(&formatter, &map), b"k1: v1\r\nK2:\tv2\r\n");
    }

    #[test]
    fn test_format_invalid_key() {
        let mut map = HeaderMap::new();
        map.insert("k1:", "v1");

        let mut formatter = HeaderFormatter::new();
        let mut sink: Vec<u8> = Vec::new();
        assert!(formatter.format_header(&map, &mut sink).is_err());

        formatter.set_disable_validation(true);
        assert_eq!(render(&formatter, &map), b"k1:: v1\r\n");
    }

    #[test]
    fn test_format_invalid_value() {
        let mut map = HeaderMap::new();
        map.insert("k1", "v1\n");

        let mut formatter = HeaderFormatter::new();
        let mut sink: Vec<u8> = Vec::new();
        assert!(formatter.format_header(&map, &mut sink).is_err());

        formatter.set_disable_validation(true);
        assert_eq!(render(&formatter, &map), b"k1: v1\n\r\n");
    }
}