use crate::exception::{Error, Result};
use async_trait::async_trait;
use bytes::Bytes;
use http::HeaderMap;
use serde_json::Value;
use std::collections::HashMap;
/// Error type used by the parsers in this module; an alias of
/// `crate::exception::Error`.
pub type ParseError = Error;
/// Result type used by the parsers in this module; an alias of
/// `crate::exception::Result`.
pub type ParseResult<T> = Result<T>;
/// A media type (MIME type) such as `text/html; charset=utf-8`.
///
/// Values produced by [`MediaType::parse`] have `main_type`, `sub_type` and
/// parameter names lower-cased, because RFC 9110 defines them as
/// case-insensitive. Parameter values keep their original case, with any
/// quoted-string syntax removed. Values built with [`MediaType::new`] /
/// [`MediaType::with_param`] are stored exactly as given.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MediaType {
    /// Top-level type, e.g. `application` in `application/json`.
    pub main_type: String,
    /// Subtype, e.g. `json` in `application/json`.
    pub sub_type: String,
    /// `name=value` parameters that follow the type, e.g. `charset` -> `utf-8`.
    pub parameters: HashMap<String, String>,
}

impl MediaType {
    /// Creates a media type with no parameters; both parts are stored verbatim.
    pub fn new(main_type: impl Into<String>, sub_type: impl Into<String>) -> Self {
        Self {
            main_type: main_type.into(),
            sub_type: sub_type.into(),
            parameters: HashMap::new(),
        }
    }

    /// Returns `self` with parameter `key` set to `value`, replacing any
    /// previous value stored under the same key.
    #[must_use]
    pub fn with_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.parameters.insert(key.into(), value.into());
        self
    }

    /// Parses a `Content-Type`-style string such as
    /// `multipart/form-data; boundary="abc"`.
    ///
    /// - Type, subtype and parameter names are lower-cased, and surrounding
    ///   whitespace is trimmed.
    /// - Quoted parameter values are unquoted, with `\` escapes resolved, and
    ///   may contain `;`.
    /// - Parameters without `=` or with an empty name are skipped.
    /// - If a parameter name repeats, the last value wins.
    ///
    /// # Errors
    ///
    /// Returns `Error::Validation` carrying the original input when the part
    /// before the first unquoted `;` is not a non-empty type and a non-empty
    /// subtype separated by exactly one `/`.
    pub fn parse(content_type: &str) -> ParseResult<Self> {
        let mut segments = Self::split_unquoted(content_type).into_iter();
        // `split_unquoted` always yields at least one (possibly empty) segment.
        let essence = segments.next().unwrap_or("").trim();

        let Some((main, sub)) = essence.split_once('/') else {
            return Err(Error::Validation(content_type.to_string()));
        };
        let (main, sub) = (main.trim(), sub.trim());
        if main.is_empty() || sub.is_empty() || sub.contains('/') {
            return Err(Error::Validation(content_type.to_string()));
        }

        let mut media_type = MediaType::new(main.to_ascii_lowercase(), sub.to_ascii_lowercase());
        for segment in segments {
            let Some((name, value)) = segment.split_once('=') else {
                continue;
            };
            let name = name.trim();
            if name.is_empty() {
                continue;
            }
            media_type
                .parameters
                .insert(name.to_ascii_lowercase(), Self::unquote(value.trim()));
        }
        Ok(media_type)
    }

    /// Returns `true` if this media type matches `pattern` (`type/subtype`),
    /// where either half of the pattern may be the wildcard `*`.
    ///
    /// Comparison is ASCII case-insensitive and ignores whitespace around the
    /// pattern and its halves. Patterns that are not exactly `type/subtype`
    /// never match; this includes patterns that carry parameters. This
    /// value's own parameters are not consulted.
    pub fn matches(&self, pattern: &str) -> bool {
        let Some((main, sub)) = pattern.trim().split_once('/') else {
            return false;
        };
        if sub.contains('/') {
            return false;
        }
        let part_matches = |wanted: &str, actual: &str| {
            let wanted = wanted.trim();
            wanted == "*" || wanted.eq_ignore_ascii_case(actual)
        };
        part_matches(main, &self.main_type) && part_matches(sub, &self.sub_type)
    }

    /// Splits `input` on every `;` that is not inside a double-quoted string.
    /// Inside quotes, a backslash escapes the following character. The result
    /// always has at least one element.
    fn split_unquoted(input: &str) -> Vec<&str> {
        let mut segments = Vec::new();
        let mut start = 0;
        let mut in_quotes = false;
        let mut escaped = false;
        for (i, c) in input.char_indices() {
            if escaped {
                escaped = false;
            } else if in_quotes && c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_quotes = !in_quotes;
            } else if c == ';' && !in_quotes {
                segments.push(&input[start..i]);
                // `;` is one byte, so `i + 1` is a char boundary.
                start = i + 1;
            }
        }
        segments.push(&input[start..]);
        segments
    }

    /// Removes quoted-string syntax: if `value` is wrapped in `"`, strips the
    /// quotes and resolves `\x` escapes. Any other value is returned unchanged.
    fn unquote(value: &str) -> String {
        let Some(inner) = value.strip_prefix('"').and_then(|v| v.strip_suffix('"')) else {
            return value.to_string();
        };
        let mut out = String::with_capacity(inner.len());
        let mut chars = inner.chars();
        while let Some(c) = chars.next() {
            if c == '\\' {
                // A trailing lone backslash is malformed; it is dropped.
                if let Some(escaped) = chars.next() {
                    out.push(escaped);
                }
            } else {
                out.push(c);
            }
        }
        out
    }

    /// Returns `true` if `value` is a non-empty RFC 9110 `token` (only `tchar`s).
    fn is_token(value: &str) -> bool {
        !value.is_empty()
            && value
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || b"!#$%&'*+-.^_`|~".contains(&b))
    }
}

impl std::fmt::Display for MediaType {
    /// Renders `type/subtype` followed by `; name=value` for each parameter.
    ///
    /// Parameters are written sorted by name so the output is deterministic;
    /// `HashMap` iteration order is not. A value that is not a valid token is
    /// written as a quoted string, unless it is already wrapped in quotes.
    /// This lets the output be parsed back by [`MediaType::parse`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}/{}", self.main_type, self.sub_type)?;
        let mut params: Vec<(&String, &String)> = self.parameters.iter().collect();
        params.sort_unstable();
        for (key, value) in params {
            let already_quoted =
                value.len() >= 2 && value.starts_with('"') && value.ends_with('"');
            if already_quoted || MediaType::is_token(value) {
                write!(f, "; {key}={value}")?;
            } else {
                let escaped = value.replace('\\', "\\\\").replace('"', "\\\"");
                write!(f, "; {key}=\"{escaped}\"")?;
            }
        }
        Ok(())
    }
}
/// Output of a successful [`Parser::parse`] call. Which variant is produced is
/// decided by the individual parser.
#[derive(Debug, Clone)]
pub enum ParsedData {
/// JSON payload held as a `serde_json::Value`.
Json(Value),
/// XML payload represented as a `serde_json::Value`; how XML maps onto that
/// value is up to the producing parser (NOTE(review): confirm the mapping).
Xml(Value),
/// YAML payload represented as a `serde_json::Value`.
Yaml(Value),
/// Flat name -> value pairs (presumably form fields; verify against the form
/// parser). Being a `HashMap`, it can hold only one value per name.
Form(HashMap<String, String>),
/// Multipart payload split into text `fields` and uploaded `files`.
MultiPart {
/// Non-file parts as name -> value; one value per name (`HashMap`).
fields: HashMap<String, String>,
/// File parts, in the order the parser pushed them.
files: Vec<UploadedFile>,
},
/// A single uploaded file (presumably a raw, non-multipart body; confirm).
File(UploadedFile),
/// MessagePack payload represented as a `serde_json::Value`.
MessagePack(Value),
/// Protobuf payload represented as a `serde_json::Value`; how messages are
/// mapped onto that value is up to the producing parser.
Protobuf(Value),
}
/// A file received in a parsed body, together with its optional metadata.
#[derive(Debug, Clone)]
pub struct UploadedFile {
    /// Name of the field/part the file was submitted under.
    pub name: String,
    /// Original file name, if one was provided.
    pub filename: Option<String>,
    /// Declared content type of the file, if one was provided.
    pub content_type: Option<String>,
    /// Byte length of `data`, captured by [`UploadedFile::new`]. It is a plain
    /// public field, so it is not kept in sync if `data` is replaced later.
    pub size: usize,
    /// Raw file contents.
    pub data: Bytes,
}

impl UploadedFile {
    /// Builds a file with no filename or content type. `size` is set to
    /// `data.len()`.
    pub fn new(name: String, data: Bytes) -> Self {
        // `size` is listed before `data` so the length is read before `data` is moved.
        Self {
            size: data.len(),
            filename: None,
            content_type: None,
            name,
            data,
        }
    }

    /// Builder-style setter: returns the file with `filename` set.
    pub fn with_filename(self, filename: String) -> Self {
        Self {
            filename: Some(filename),
            ..self
        }
    }

    /// Builder-style setter: returns the file with `content_type` set.
    pub fn with_content_type(self, content_type: String) -> Self {
        Self {
            content_type: Some(content_type),
            ..self
        }
    }
}
/// A body decoder that can be registered in a [`ParserRegistry`].
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait Parser: Send + Sync {
    /// Media-type patterns this parser accepts, in the form understood by
    /// [`MediaType::matches`] (for example `application/json` or `text/*`).
    fn media_types(&self) -> Vec<String>;

    /// Decodes `body` into [`ParsedData`].
    ///
    /// `content_type` is the raw content-type string, if the caller has one.
    /// `headers` are the accompanying `http::HeaderMap`.
    async fn parse(
        &self,
        content_type: Option<&str>,
        body: Bytes,
        headers: &HeaderMap,
    ) -> ParseResult<ParsedData>;

    /// Returns `true` when all of the following hold:
    /// - `content_type` is present;
    /// - it parses as a [`MediaType`];
    /// - it matches at least one pattern from [`Parser::media_types`].
    ///
    /// An absent or unparseable content type yields `false`.
    fn can_parse(&self, content_type: Option<&str>) -> bool {
        let Some(media_type) = content_type.and_then(|raw| MediaType::parse(raw).ok()) else {
            return false;
        };
        self.media_types()
            .iter()
            .any(|pattern| media_type.matches(pattern))
    }
}
#[derive(Default)]
pub struct ParserRegistry {
parsers: Vec<Box<dyn Parser>>,
}
impl ParserRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn register<P: Parser + 'static>(mut self, parser: P) -> Self {
self.parsers.push(Box::new(parser));
self
}
pub async fn parse(
&self,
content_type: Option<&str>,
body: Bytes,
headers: &HeaderMap,
) -> ParseResult<ParsedData> {
for parser in &self.parsers {
if parser.can_parse(content_type) {
return parser.parse(content_type, body, headers).await;
}
}
Err(Error::Validation(
content_type.unwrap_or("none").to_string(),
))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_media_type_parse() {
        let plain = MediaType::parse("application/json").unwrap();
        assert_eq!(
            (plain.main_type.as_str(), plain.sub_type.as_str()),
            ("application", "json")
        );

        let html = MediaType::parse("text/html; charset=utf-8").unwrap();
        assert_eq!(html.main_type, "text");
        assert_eq!(html.sub_type, "html");
        assert_eq!(
            html.parameters.get("charset").map(String::as_str),
            Some("utf-8")
        );
    }

    #[test]
    fn test_media_type_matches() {
        let json = MediaType::new("application", "json");
        for pattern in ["application/json", "application/*", "*/json", "*/*"] {
            assert!(json.matches(pattern));
        }
        assert!(!json.matches("text/html"));
    }
}