use std::cmp::Ordering;
use std::collections::HashMap;
use crate::metadata::{CodeCase, FieldCase, FormatConfig};
/// Concrete media types the renderers in this module can produce.
///
/// Each variant corresponds to exactly one MIME string (see
/// `ContentType::mime_type`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(super) enum ContentType {
    /// `application/json`
    Json,
    /// `text/html`
    Html,
    /// `application/graphql-response+json`
    GraphQL,
    /// `text/plain`
    Text,
    /// `application/json-rpc`
    JsonRpc,
}
impl ContentType {
    /// Returns the bare `type/subtype` MIME string (no parameters) for this
    /// content type.
    pub(super) fn mime_type(self) -> &'static str {
        let mime: &'static str = match self {
            Self::Json => "application/json",
            Self::Html => "text/html",
            Self::GraphQL => "application/graphql-response+json",
            Self::Text => "text/plain",
            Self::JsonRpc => "application/json-rpc",
        };
        mime
    }
}
/// Output renderer selected by content negotiation.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Renderer {
    /// Produces `application/json`.
    Json,
    /// Produces `text/html`.
    Html,
    /// Produces `application/graphql-response+json`.
    GraphQL,
    /// Produces `text/plain`.
    Text,
    /// Produces `application/json-rpc`.
    JsonRpc,
}
impl Renderer {
    /// Returns the [`ContentType`] this renderer emits.
    pub(super) fn content_type(&self) -> ContentType {
        match *self {
            Self::Json => ContentType::Json,
            Self::Html => ContentType::Html,
            Self::GraphQL => ContentType::GraphQL,
            Self::Text => ContentType::Text,
            Self::JsonRpc => ContentType::JsonRpc,
        }
    }

    /// Built-in formatting for this renderer, used when no override is
    /// configured.
    ///
    /// GraphQL uses `FieldCase::CamelCase` and `CodeCase::ScreamingSnakeCase`.
    /// Every other renderer uses `FormatConfig::default()`.
    fn default_format_config(self) -> FormatConfig {
        match self {
            Self::Json | Self::Html | Self::Text | Self::JsonRpc => FormatConfig::default(),
            Self::GraphQL => FormatConfig {
                code_case: CodeCase::ScreamingSnakeCase,
                field_case: FieldCase::CamelCase,
            },
        }
    }
}
/// One media range from an `Accept` header, e.g. `text/*;q=0.5`.
#[derive(Clone, Debug)]
pub(super) struct MediaType {
    /// Lowercased top-level type, or `*`.
    pub(super) ty: String,
    /// Lowercased subtype, or `*`.
    pub(super) subtype: String,
    /// Weight from the `q` parameter. `MediaType::parse` clamps it to
    /// `0.0..=1.0` and defaults it to `1.0`.
    pub(super) quality: f32,
}
impl MediaType {
    /// Parses one media range such as `text/html;level=1;q=0.8`.
    ///
    /// The type and subtype are trimmed and lowercased. The weight comes from
    /// the `q` parameter, whose name is matched case-insensitively
    /// (RFC 9110 §12.4.2). Weights that fail to parse, or parse to NaN or
    /// infinity, are ignored. Usable weights are clamped to `0.0..=1.0`, and
    /// the last usable `q` wins. Without one the quality is `1.0`. Other
    /// parameters are ignored.
    ///
    /// Returns `None` when there is no `/`, or when the type or subtype is
    /// empty.
    pub(super) fn parse(s: &str) -> Option<Self> {
        let s = s.trim();
        let (media, params) = s.split_once(';').unwrap_or((s, ""));
        let (type_, subtype) = media.trim().split_once('/')?;
        let (type_, subtype) = (type_.trim(), subtype.trim());
        // `text/` or `/json` is not a media range; reject it instead of
        // producing a range that can never match anything.
        if type_.is_empty() || subtype.is_empty() {
            return None;
        }

        let mut quality = 1.0;
        for param in params.split(';') {
            let Some((name, value)) = param.trim().split_once('=') else {
                continue;
            };
            if !name.eq_ignore_ascii_case("q") {
                continue;
            }
            if let Ok(parsed) = value.parse::<f32>() {
                // NaN/infinite weights would break the ordering of ranges, so
                // they are ignored rather than clamped.
                if parsed.is_finite() {
                    quality = parsed.clamp(0.0, 1.0);
                }
            }
        }

        Some(MediaType {
            ty: type_.to_lowercase(),
            subtype: subtype.to_lowercase(),
            quality,
        })
    }

    /// Returns `true` when this range covers `content_type`.
    ///
    /// A `*` in either the type or the subtype position matches anything in
    /// that position.
    pub(super) fn matches(&self, content_type: ContentType) -> bool {
        let (ct_type, ct_subtype) = content_type
            .mime_type()
            .split_once('/')
            .expect("every ContentType MIME string has the form `type/subtype`");
        let type_matches = self.ty == "*" || self.ty == ct_type;
        let subtype_matches = self.subtype == "*" || self.subtype == ct_subtype;
        type_matches && subtype_matches
    }

    /// Ranks how specific the range is: `*/*` is 0, one wildcard is 1, and a
    /// fully specified `type/subtype` is 2.
    fn specificity(&self) -> u8 {
        match (self.ty.as_str(), self.subtype.as_str()) {
            ("*", "*") => 0,
            (_, "*") | ("*", _) => 1,
            _ => 2,
        }
    }
}
/// Configuration for choosing a [`Renderer`] from an `Accept` header.
#[derive(Debug, Clone)]
pub struct NegotiationConfig {
    /// `(media type, renderer)` pairs in priority order. Index 0 is tried
    /// first.
    mappings: Vec<(String, Renderer)>,
    /// Renderer used when negotiation finds nothing to serve.
    fallback: Renderer,
    /// Per-renderer formatting overrides. Renderers without an entry use
    /// their built-in defaults.
    format_overrides: HashMap<Renderer, FormatConfig>,
}
impl Default for NegotiationConfig {
fn default() -> Self {
Self {
mappings: vec![
("application/json".to_string(), Renderer::Json),
("text/html".to_string(), Renderer::Html),
(
"application/graphql-response+json".to_string(),
Renderer::GraphQL,
),
("application/json+graphql".to_string(), Renderer::GraphQL),
("text/plain".to_string(), Renderer::Text),
("application/json-rpc".to_string(), Renderer::JsonRpc),
],
fallback: Renderer::Json,
format_overrides: HashMap::new(),
}
}
}
impl NegotiationConfig {
    /// Creates a configuration with the default media-type mappings, a JSON
    /// fallback and no format overrides (same as `Default::default`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Maps `media_type` to `renderer` with the highest priority.
    ///
    /// The mapping is inserted in front of every existing one. The most
    /// recently added mapping therefore wins over earlier ones and over the
    /// defaults.
    #[must_use = "builder methods take self by value and return the modified value"]
    pub fn with_mapping(mut self, media_type: impl Into<String>, renderer: Renderer) -> Self {
        self.mappings.insert(0, (media_type.into(), renderer));
        self
    }

    /// Sets the renderer used when there is no `Accept` header or nothing in
    /// it can be served.
    #[must_use = "builder methods take self by value and return the modified value"]
    pub fn with_fallback(mut self, renderer: Renderer) -> Self {
        self.fallback = renderer;
        self
    }

    /// Overrides the [`FormatConfig`] used for `renderer`, replacing any
    /// previous override for it.
    #[must_use = "builder methods take self by value and return the modified value"]
    pub fn with_format_config(mut self, renderer: Renderer, format_config: FormatConfig) -> Self {
        self.format_overrides.insert(renderer, format_config);
        self
    }

    /// Returns the override registered for `renderer`, or the renderer's
    /// built-in default when there is none.
    pub(super) fn format_for(&self, renderer: Renderer) -> FormatConfig {
        self.format_overrides
            .get(&renderer)
            .copied()
            .unwrap_or_else(|| renderer.default_format_config())
    }

    /// Chooses a renderer for an `Accept` header value.
    ///
    /// Same as
    /// [`negotiate_with_content_type`](Self::negotiate_with_content_type),
    /// without the content type.
    pub fn negotiate(&self, accept_header: Option<&str>) -> Renderer {
        self.negotiate_with_content_type(accept_header).0
    }

    /// Chooses a renderer and the content type to respond with for an
    /// `Accept` header value.
    ///
    /// Media ranges are tried by descending quality, then by descending
    /// specificity (`type/subtype` before `type/*` before `*/*`). Ranges that
    /// tie keep their header order. For each range:
    ///
    /// 1. A configured mapping whose media type equals the range wins. The
    ///    content type is that media type. Explicit mappings are checked
    ///    before any implicit match, so they cannot be shadowed by an
    ///    earlier mapping's renderer.
    /// 2. Otherwise, the first mapping (in priority order) whose renderer's
    ///    own content type is covered by the range wins. The content type is
    ///    the renderer's content type.
    ///
    /// A range with `q=0` means "not acceptable" (RFC 9110 §12.4.2). Such a
    /// range is never selected. A candidate is also rejected when a `q=0`
    /// range more specific than the selecting range covers it; for example,
    /// `*/*, application/json;q=0` does not pick JSON.
    ///
    /// Returns the fallback renderer and its content type when the header is
    /// absent or nothing in it can be served.
    pub fn negotiate_with_content_type(&self, accept_header: Option<&str>) -> (Renderer, String) {
        let Some(accept) = accept_header else {
            return self.fallback_choice();
        };

        let (mut acceptable, refused): (Vec<MediaType>, Vec<MediaType>) = accept
            .split(',')
            .filter_map(MediaType::parse)
            .partition(|range| range.quality > 0.0);
        // `sort_by` is stable, so equally ranked ranges keep header order.
        // `MediaType::parse` never yields NaN/infinite qualities and all
        // qualities here are > 0, so `total_cmp` is the plain numeric order
        // and there is no panic path.
        acceptable.sort_by(|a, b| match b.quality.total_cmp(&a.quality) {
            Ordering::Equal => b.specificity().cmp(&a.specificity()),
            unequal => unequal,
        });

        // Parse the configured patterns once rather than once per range.
        // Unparseable patterns cannot match exactly, but their renderers still
        // take part in the content-type matching below.
        let explicit: Vec<(MediaType, Renderer)> = self
            .mappings
            .iter()
            .filter_map(|(pattern, renderer)| {
                MediaType::parse(pattern).map(|parsed| (parsed, *renderer))
            })
            .collect();

        for range in &acceptable {
            let exact = explicit.iter().find(|(pattern, _)| {
                pattern.ty == range.ty
                    && pattern.subtype == range.subtype
                    && !Self::is_refused(&refused, range, &pattern.ty, &pattern.subtype)
            });
            if let Some((pattern, renderer)) = exact {
                return (*renderer, format!("{}/{}", pattern.ty, pattern.subtype));
            }

            let implied = self.mappings.iter().find_map(|(_, renderer)| {
                let content_type = renderer.content_type();
                let mime = content_type.mime_type();
                let (ty, subtype) = mime.split_once('/')?;
                let usable = range.matches(content_type)
                    && !Self::is_refused(&refused, range, ty, subtype);
                usable.then(|| (*renderer, mime.to_string()))
            });
            if let Some(choice) = implied {
                return choice;
            }
        }

        self.fallback_choice()
    }

    /// The fallback renderer paired with its own content type.
    fn fallback_choice(&self) -> (Renderer, String) {
        (
            self.fallback,
            self.fallback.content_type().mime_type().to_string(),
        )
    }

    /// Returns `true` when a `q=0` range in `refused` that is more specific
    /// than `selected_by` covers `ty/subtype`. In that case the client
    /// explicitly declared that media type not acceptable.
    fn is_refused(refused: &[MediaType], selected_by: &MediaType, ty: &str, subtype: &str) -> bool {
        refused.iter().any(|range| {
            range.specificity() > selected_by.specificity()
                && (range.ty == "*" || range.ty == ty)
                && (range.subtype == "*" || range.subtype == subtype)
        })
    }

    /// Builds a map from each lowercased configured media type to its
    /// renderer.
    ///
    /// When a media type is mapped more than once, the highest-priority
    /// mapping is kept: the one earliest in priority order, i.e. the most
    /// recently added via `with_mapping`. This matches the mapping that
    /// negotiation uses.
    pub fn build_lookup(&self) -> HashMap<String, Renderer> {
        let mut lookup = HashMap::with_capacity(self.mappings.len());
        for (media_type, renderer) in &self.mappings {
            // Mappings are iterated in priority order, so the first insert
            // for a key must win.
            lookup.entry(media_type.to_lowercase()).or_insert(*renderer);
        }
        lookup
    }
}
#[cfg(test)]
// Qualities are parsed from short decimal literals and compared against the
// very same literals (or the exact clamp bounds 0.0/1.0), so exact float
// equality is what these assertions intend.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    #[test]
    fn test_media_type_parse() {
        let mt = MediaType::parse("text/html").unwrap();
        assert_eq!(mt.ty, "text");
        assert_eq!(mt.subtype, "html");
        assert_eq!(mt.quality, 1.0);
    }

    #[test]
    fn test_media_type_parse_with_quality() {
        let mt = MediaType::parse("text/html;q=0.9").unwrap();
        assert_eq!(mt.ty, "text");
        assert_eq!(mt.subtype, "html");
        assert_eq!(mt.quality, 0.9);
    }

    #[test]
    fn test_media_type_parse_ignores_other_params() {
        let mt = MediaType::parse("text/html;level=1;q=0.5").unwrap();
        assert_eq!(mt.subtype, "html");
        assert_eq!(mt.quality, 0.5);
    }

    #[test]
    fn test_media_type_parse_wildcard() {
        let mt = MediaType::parse("*/*").unwrap();
        assert_eq!(mt.ty, "*");
        assert_eq!(mt.subtype, "*");
    }

    #[test]
    fn test_negotiate_json() {
        let config = NegotiationConfig::new();
        let renderer = config.negotiate(Some("application/json"));
        assert_eq!(renderer, Renderer::Json);
    }

    #[test]
    fn test_negotiate_html() {
        let config = NegotiationConfig::new();
        let renderer = config.negotiate(Some("text/html"));
        assert_eq!(renderer, Renderer::Html);
    }

    #[test]
    fn test_negotiate_graphql() {
        let config = NegotiationConfig::new();
        let renderer = config.negotiate(Some("application/graphql-response+json"));
        assert_eq!(renderer, Renderer::GraphQL);
    }

    #[test]
    fn test_negotiate_graphql_legacy() {
        let config = NegotiationConfig::new();
        let renderer = config.negotiate(Some("application/json+graphql"));
        assert_eq!(renderer, Renderer::GraphQL);
    }

    #[test]
    fn test_negotiate_graphql_legacy_echoes_requested_content_type() {
        let config = NegotiationConfig::new();
        let (renderer, content_type) =
            config.negotiate_with_content_type(Some("application/json+graphql"));
        assert_eq!(renderer, Renderer::GraphQL);
        assert_eq!(content_type, "application/json+graphql");
    }

    #[test]
    fn test_negotiate_text() {
        let config = NegotiationConfig::new();
        let renderer = config.negotiate(Some("text/plain"));
        assert_eq!(renderer, Renderer::Text);
    }

    #[test]
    fn test_negotiate_json_rpc() {
        let config = NegotiationConfig::new();
        let renderer = config.negotiate(Some("application/json-rpc"));
        assert_eq!(renderer, Renderer::JsonRpc);
    }

    #[test]
    fn test_negotiate_is_case_insensitive() {
        let config = NegotiationConfig::new();
        assert_eq!(config.negotiate(Some("TEXT/HTML")), Renderer::Html);
    }

    #[test]
    fn test_negotiate_with_quality() {
        let config = NegotiationConfig::new();
        let renderer = config.negotiate(Some("application/json;q=0.8, text/html;q=0.9"));
        assert_eq!(renderer, Renderer::Html);
    }

    #[test]
    fn test_negotiate_wildcard_fallback() {
        let config = NegotiationConfig::new();
        let renderer = config.negotiate(Some("*/*"));
        assert_eq!(renderer, Renderer::Json);
    }

    #[test]
    fn test_negotiate_no_header_fallback() {
        let config = NegotiationConfig::new();
        let renderer = config.negotiate(None);
        assert_eq!(renderer, Renderer::Json);
    }

    #[test]
    fn test_negotiate_no_header_content_type() {
        let config = NegotiationConfig::new();
        assert_eq!(
            config.negotiate_with_content_type(None),
            (Renderer::Json, "application/json".to_string())
        );
    }

    #[test]
    fn test_negotiate_custom_fallback() {
        let config = NegotiationConfig::new().with_fallback(Renderer::Text);
        let renderer = config.negotiate(None);
        assert_eq!(renderer, Renderer::Text);
    }

    #[test]
    fn test_negotiate_unparseable_header_falls_back() {
        let config = NegotiationConfig::new().with_fallback(Renderer::Text);
        assert_eq!(config.negotiate(Some("")), Renderer::Text);
        assert_eq!(config.negotiate(Some("garbage")), Renderer::Text);
    }

    #[test]
    fn test_negotiate_custom_mapping() {
        let config =
            NegotiationConfig::new().with_mapping("application/vnd.api+json", Renderer::Json);
        assert_eq!(
            config.negotiate_with_content_type(Some("application/vnd.api+json")),
            (Renderer::Json, "application/vnd.api+json".to_string())
        );
    }

    #[test]
    fn test_negotiate_multiple_accept() {
        let config = NegotiationConfig::new();
        let renderer = config.negotiate(Some("text/html, application/json, */*"));
        assert_eq!(renderer, Renderer::Html);
    }

    #[test]
    fn test_negotiate_text_wildcard() {
        let config = NegotiationConfig::new();
        let renderer = config.negotiate(Some("text/*"));
        assert_eq!(renderer, Renderer::Html);
    }

    #[test]
    fn test_negotiate_application_wildcard() {
        let config = NegotiationConfig::new();
        let renderer = config.negotiate(Some("application/*"));
        assert_eq!(renderer, Renderer::Json);
    }

    #[test]
    fn test_negotiate_wildcard_with_quality() {
        let config = NegotiationConfig::new();
        let renderer = config.negotiate(Some("text/*;q=0.5, application/json"));
        assert_eq!(renderer, Renderer::Json);
    }

    #[test]
    fn test_negotiate_wildcard_fallback_order() {
        // A wildcard range resolves to the first mapping, in priority order,
        // whose renderer it accepts. Moving `text/plain` to the front
        // therefore changes which `text/*` renderer is chosen.
        let config = NegotiationConfig::new();
        assert_eq!(
            config.negotiate_with_content_type(Some("text/*")),
            (Renderer::Html, "text/html".to_string())
        );
        let config = NegotiationConfig::new().with_mapping("text/plain", Renderer::Text);
        assert_eq!(
            config.negotiate_with_content_type(Some("text/*")),
            (Renderer::Text, "text/plain".to_string())
        );
    }

    #[test]
    fn test_negotiate_specific_over_wildcard() {
        let config = NegotiationConfig::new();
        let renderer = config.negotiate(Some("text/*, text/plain"));
        assert_eq!(renderer, Renderer::Text);
    }

    #[test]
    fn test_negotiate_type_wildcard() {
        let config = NegotiationConfig::new();
        let renderer = config.negotiate(Some("*/json"));
        assert_eq!(renderer, Renderer::Json);
    }

    #[test]
    fn test_build_lookup_lowercases_keys() {
        let lookup = NegotiationConfig::new()
            .with_mapping("Application/X-Custom", Renderer::Text)
            .build_lookup();
        assert_eq!(lookup.get("application/x-custom"), Some(&Renderer::Text));
        assert_eq!(lookup.get("text/html"), Some(&Renderer::Html));
        assert_eq!(lookup.len(), 7);
    }

    #[test]
    fn test_media_type_matches() {
        let mt = MediaType::parse("text/*").unwrap();
        assert!(mt.matches(ContentType::Html));
        assert!(mt.matches(ContentType::Text));
        assert!(!mt.matches(ContentType::Json));
        assert!(!mt.matches(ContentType::GraphQL));
    }

    #[test]
    fn test_media_type_matches_application_wildcard() {
        let mt = MediaType::parse("application/*").unwrap();
        assert!(mt.matches(ContentType::Json));
        assert!(mt.matches(ContentType::GraphQL));
        assert!(!mt.matches(ContentType::Html));
        assert!(!mt.matches(ContentType::Text));
    }

    #[test]
    fn test_media_type_matches_full_wildcard() {
        let mt = MediaType::parse("*/*").unwrap();
        assert!(mt.matches(ContentType::Json));
        assert!(mt.matches(ContentType::Html));
        assert!(mt.matches(ContentType::GraphQL));
        assert!(mt.matches(ContentType::Text));
    }

    #[test]
    fn test_content_type_mime_types() {
        assert_eq!(
            Renderer::GraphQL.content_type().mime_type(),
            "application/graphql-response+json"
        );
        assert_eq!(
            Renderer::JsonRpc.content_type().mime_type(),
            "application/json-rpc"
        );
    }

    #[test]
    fn test_media_type_parse_nan_quality() {
        let mt = MediaType::parse("text/html;q=NaN").unwrap();
        assert_eq!(mt.quality, 1.0);
        assert!(!mt.quality.is_nan());
    }

    #[test]
    fn test_media_type_parse_infinite_quality() {
        let mt = MediaType::parse("text/html;q=inf").unwrap();
        assert_eq!(mt.quality, 1.0);
        assert!(mt.quality.is_finite());
    }

    #[test]
    fn test_media_type_parse_negative_quality() {
        let mt = MediaType::parse("text/html;q=-0.5").unwrap();
        assert_eq!(mt.quality, 0.0);
    }

    #[test]
    fn test_media_type_parse_quality_over_one() {
        let mt = MediaType::parse("text/html;q=2.5").unwrap();
        assert_eq!(mt.quality, 1.0);
    }

    #[test]
    fn test_negotiate_with_nan_quality() {
        let config = NegotiationConfig::new();
        let renderer = config.negotiate(Some("text/html;q=NaN, application/json;q=0.5"));
        assert_eq!(renderer, Renderer::Html);
    }
}