#[cfg(all(not(feature = "std"), feature = "alloc"))]
use alloc::{
string::{String, ToString},
vec::Vec,
};
use core::{fmt, str::FromStr};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Value of the HTTP `Vary` header.
///
/// It is either the wildcard `*` or a list of header field names. The list
/// keeps names in the order they were added.
// NOTE(review): the derived `PartialEq` compares the name list element by
// element, so it is both order- and case-sensitive (`Headers(["Accept"])` !=
// `Headers(["accept"])`). This differs from `add`/`remove`/`contains`, which
// ignore case. Confirm that this is intended.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub enum Vary {
    /// The wildcard form, written as `*`.
    Wildcard,
    /// An explicit list of header field names, kept in insertion order.
    Headers(Vec<String>),
}
impl Vary {
    /// Creates an empty header list (the [`Vary::Headers`] form with no names).
    #[must_use]
    pub fn new() -> Self {
        Self::Headers(Vec::new())
    }

    /// Creates the wildcard (`*`) form.
    #[must_use]
    pub fn wildcard() -> Self {
        Self::Wildcard
    }

    /// Returns `true` for the wildcard form.
    #[must_use]
    pub fn is_wildcard(&self) -> bool {
        matches!(self, Self::Wildcard)
    }

    /// Returns the listed header names in insertion order.
    ///
    /// Returns `None` for the wildcard form.
    #[must_use]
    pub fn headers(&self) -> Option<&[String]> {
        match self {
            Self::Wildcard => None,
            Self::Headers(h) => Some(h),
        }
    }

    /// Appends a header name unless an equal name is already listed.
    ///
    /// Names are compared ASCII case-insensitively, and the first spelling
    /// added is the one kept. On the wildcard form this does nothing.
    pub fn add(&mut self, name: impl Into<String>) {
        if let Self::Headers(headers) = self {
            let name = name.into();
            if !Self::list_contains(headers, &name) {
                headers.push(name);
            }
        }
    }

    /// Removes every listed name equal to `name` (ASCII case-insensitive).
    ///
    /// Returns `true` if at least one entry was removed. Always returns
    /// `false` for the wildcard form.
    pub fn remove(&mut self, name: &str) -> bool {
        match self {
            Self::Wildcard => false,
            Self::Headers(headers) => {
                let before = headers.len();
                headers.retain(|h| !h.eq_ignore_ascii_case(name));
                headers.len() < before
            }
        }
    }

    /// Returns `true` if `name` is listed (ASCII case-insensitive).
    ///
    /// Always returns `false` for the wildcard form, which lists no explicit
    /// names.
    #[must_use]
    pub fn contains(&self, name: &str) -> bool {
        match self {
            Self::Wildcard => false,
            Self::Headers(headers) => Self::list_contains(headers, name),
        }
    }

    /// Returns the number of listed names, or `None` for the wildcard form.
    #[must_use]
    pub fn len(&self) -> Option<usize> {
        match self {
            Self::Wildcard => None,
            Self::Headers(h) => Some(h.len()),
        }
    }

    /// Returns `true` for an empty header list.
    ///
    /// The wildcard form is never considered empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Wildcard => false,
            Self::Headers(h) => h.is_empty(),
        }
    }

    // HTTP field names are case-insensitive ASCII tokens (RFC 9110 §5.1), so an
    // ASCII case-insensitive comparison is sufficient. Unlike the previous
    // `to_lowercase()` approach, it allocates nothing per element.
    fn list_contains(headers: &[String], name: &str) -> bool {
        headers.iter().any(|h| h.eq_ignore_ascii_case(name))
    }
}
impl Default for Vary {
fn default() -> Self {
Self::new()
}
}
impl fmt::Display for Vary {
    /// Writes `*` for the wildcard form. Otherwise writes the header names in
    /// order, separated by `", "` (an empty list writes nothing).
    ///
    /// Width and fill flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let names = match self {
            Self::Wildcard => return f.write_str("*"),
            Self::Headers(names) => names,
        };
        // The separator is empty before the first name and ", " afterwards.
        let mut separator = "";
        for name in names {
            f.write_str(separator)?;
            f.write_str(name)?;
            separator = ", ";
        }
        Ok(())
    }
}
/// The `Err` type of `Vary`'s `FromStr` implementation.
///
/// It carries no details; its message is always `"invalid Vary header"`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseVaryError;

impl fmt::Display for ParseVaryError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "invalid Vary header")
    }
}

// `std::error::Error` is only available when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for ParseVaryError {}
impl FromStr for Vary {
type Err = ParseVaryError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let s = s.trim();
if s == "*" {
return Ok(Self::Wildcard);
}
let mut vary = Self::new();
for part in s.split(',') {
let part = part.trim();
if !part.is_empty() {
vary.add(part.to_string());
}
}
Ok(vary)
}
}
// Unit tests for `Vary`, its `Display` and `FromStr` implementations, and
// `ParseVaryError`. Header-name matching is expected to be case-insensitive
// throughout, and the wildcard form is expected to ignore `add`/`remove`.
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh value is an empty header list that displays as "".
    #[test]
    fn new_is_empty() {
        let v = Vary::new();
        assert!(v.is_empty());
        assert!(!v.is_wildcard());
        assert_eq!(v.to_string(), "");
    }

    // The wildcard is not "empty" and displays as "*".
    #[test]
    fn wildcard() {
        let v = Vary::wildcard();
        assert!(v.is_wildcard());
        assert!(!v.is_empty());
        assert_eq!(v.to_string(), "*");
    }

    // Lookups ignore case.
    #[test]
    fn add_and_contains() {
        let mut v = Vary::new();
        v.add("Accept");
        v.add("Accept-Encoding");
        assert!(v.contains("Accept"));
        assert!(v.contains("accept"));
        assert!(v.contains("ACCEPT-ENCODING"));
        assert!(!v.contains("Content-Type"));
        assert_eq!(v.len(), Some(2));
    }

    // A case-variant duplicate is dropped; the first spelling is kept.
    #[test]
    fn add_deduplicates() {
        let mut v = Vary::new();
        v.add("Accept");
        v.add("accept");
        assert_eq!(v.len(), Some(1));
        assert_eq!(v.to_string(), "Accept");
    }

    // Removal also ignores case and reports success.
    #[test]
    fn remove_present() {
        let mut v = Vary::new();
        v.add("Accept");
        v.add("Content-Type");
        assert!(v.remove("accept"));
        assert!(!v.contains("Accept"));
        assert_eq!(v.len(), Some(1));
    }

    #[test]
    fn remove_absent_returns_false() {
        let mut v = Vary::new();
        v.add("Accept");
        assert!(!v.remove("Content-Type"));
    }

    // Mutating the wildcard form must leave it a wildcard.
    #[test]
    fn add_on_wildcard_is_noop() {
        let mut v = Vary::wildcard();
        v.add("Accept");
        assert!(v.is_wildcard());
    }

    #[test]
    fn remove_on_wildcard_returns_false() {
        let mut v = Vary::wildcard();
        assert!(!v.remove("Accept"));
    }

    // Names are joined with ", " in insertion order.
    #[test]
    fn display_multiple() {
        let mut v = Vary::new();
        v.add("Accept");
        v.add("Accept-Encoding");
        assert_eq!(v.to_string(), "Accept, Accept-Encoding");
    }

    #[test]
    fn parse_wildcard() {
        let v: Vary = "*".parse().unwrap();
        assert!(v.is_wildcard());
    }

    // Whitespace around list members is trimmed.
    #[test]
    fn parse_header_list() {
        let v: Vary = "Accept, Accept-Encoding".parse().unwrap();
        assert!(v.contains("Accept"));
        assert!(v.contains("Accept-Encoding"));
        assert_eq!(v.len(), Some(2));
    }

    // Display output parses back to an equal value (the derived `PartialEq`
    // is order-sensitive, so this also checks order is preserved).
    #[test]
    fn roundtrip() {
        let mut v = Vary::new();
        v.add("Accept");
        v.add("Origin");
        let s = v.to_string();
        let parsed: Vary = s.parse().unwrap();
        assert_eq!(parsed, v);
    }

    #[test]
    fn default_is_empty_headers() {
        let v = Vary::default();
        assert!(v.is_empty());
        assert!(!v.is_wildcard());
    }

    // `headers()` exposes names in insertion order.
    #[test]
    fn headers_returns_slice_for_headers_variant() {
        let mut v = Vary::new();
        v.add("Accept");
        v.add("Content-Type");
        let h = v.headers().unwrap();
        assert_eq!(h.len(), 2);
        assert_eq!(h[0], "Accept");
        assert_eq!(h[1], "Content-Type");
    }

    #[test]
    fn headers_returns_none_for_wildcard() {
        let v = Vary::wildcard();
        assert!(v.headers().is_none());
    }

    #[test]
    fn parse_vary_error_display() {
        let e = ParseVaryError;
        assert_eq!(e.to_string(), "invalid Vary header");
    }

    #[test]
    fn len_returns_none_for_wildcard() {
        assert_eq!(Vary::wildcard().len(), None);
    }

    // The wildcard lists no explicit names, so `contains` is always false.
    #[test]
    fn contains_returns_false_for_wildcard() {
        let v = Vary::wildcard();
        assert!(!v.contains("Accept"));
    }
}