use alloc::string::{String, ToString};
use alloc::vec::Vec;
use core::fmt;
use crate::validate::{is_valid_token, trim_ows};
/// Reasons an `Upgrade` header value can be rejected by [`Upgrade::parse`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum UpgradeError {
    /// The header value is empty, i.e. it names no protocol.
    Empty,
    /// The header is structurally malformed, e.g. a protocol with more than one `/`.
    InvalidFormat,
    /// A protocol name is empty or not a valid token.
    InvalidProtocol,
    /// A protocol version is empty or not a valid token.
    InvalidVersion,
}

impl fmt::Display for UpgradeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant maps to a fixed message, so pick the text first and
        // emit it with a single write.
        let message = match self {
            Self::Empty => "empty Upgrade header",
            Self::InvalidFormat => "invalid Upgrade header format",
            Self::InvalidProtocol => "invalid Upgrade protocol",
            Self::InvalidVersion => "invalid Upgrade protocol version",
        };
        f.write_str(message)
    }
}

impl core::error::Error for UpgradeError {}
/// A parsed `Upgrade` header value: the protocols it lists, in order of appearance.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Upgrade {
    protocols: Vec<Protocol>,
}

impl Upgrade {
    /// Parses an `Upgrade` header value such as `"h2c, websocket/13"`.
    ///
    /// The value is a comma-separated list. Each element is a protocol name,
    /// optionally followed by `/` and a version. Empty elements (as in `"a, , b"`)
    /// are skipped. Elements, names and versions are passed through `trim_ows`
    /// before validation. Names and versions are stored ASCII-lowercased.
    ///
    /// # Errors
    ///
    /// Elements are checked left to right, and the first failure is returned:
    ///
    /// - [`UpgradeError::InvalidFormat`] if an element contains more than one `/`.
    /// - [`UpgradeError::InvalidProtocol`] if a name is empty or rejected by
    ///   `is_valid_token`.
    /// - [`UpgradeError::InvalidVersion`] if a version is empty or rejected by
    ///   `is_valid_token`.
    /// - [`UpgradeError::Empty`] if the value contains no protocols at all
    ///   (e.g. `""` or `" , "`).
    pub fn parse(input: &str) -> Result<Self, UpgradeError> {
        let input = trim_ows(input);
        let mut protocols = Vec::new();
        for part in input.split(',') {
            let part = trim_ows(part);
            // Empty list elements are tolerated and simply ignored.
            if part.is_empty() {
                continue;
            }
            protocols.push(Self::parse_protocol(part)?);
        }
        // A header that names no protocol is reported with the dedicated
        // `Empty` error instead of being accepted as an empty list.
        if protocols.is_empty() {
            return Err(UpgradeError::Empty);
        }
        Ok(Upgrade { protocols })
    }

    /// Parses one non-empty, already-trimmed list element (`name` or
    /// `name/version`) into a [`Protocol`].
    fn parse_protocol(part: &str) -> Result<Protocol, UpgradeError> {
        let Some((name, version)) = part.split_once('/') else {
            // Bare protocol name without a version.
            if !is_valid_token(part) {
                return Err(UpgradeError::InvalidProtocol);
            }
            return Ok(Protocol {
                name: part.to_ascii_lowercase(),
                version: None,
            });
        };
        // Only a single `/` may separate the name from the version.
        if version.contains('/') {
            return Err(UpgradeError::InvalidFormat);
        }
        let name = trim_ows(name);
        let version = trim_ows(version);
        if name.is_empty() {
            return Err(UpgradeError::InvalidProtocol);
        }
        if version.is_empty() {
            return Err(UpgradeError::InvalidVersion);
        }
        if !is_valid_token(name) {
            return Err(UpgradeError::InvalidProtocol);
        }
        if !is_valid_token(version) {
            return Err(UpgradeError::InvalidVersion);
        }
        Ok(Protocol {
            name: name.to_ascii_lowercase(),
            version: Some(version.to_ascii_lowercase()),
        })
    }

    /// Returns the parsed protocols in the order they appeared in the header.
    #[must_use]
    pub fn protocols(&self) -> &[Protocol] {
        &self.protocols
    }

    /// Returns `true` if any listed protocol's name equals `protocol`,
    /// ignoring ASCII case. Versions are not considered.
    #[must_use]
    pub fn has_protocol(&self, protocol: &str) -> bool {
        self.protocols
            .iter()
            .any(|p| p.name.eq_ignore_ascii_case(protocol))
    }
}
impl fmt::Display for Upgrade {
    /// Formats the protocols as a `", "`-separated list, e.g. `h2c, websocket/13`.
    ///
    /// Each protocol is written straight into the formatter. This avoids building
    /// a `String` per protocol, a `Vec` of them, and a joined `String`, so
    /// formatting needs no heap allocation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for (index, protocol) in self.protocols.iter().enumerate() {
            if index > 0 {
                f.write_str(", ")?;
            }
            write!(f, "{protocol}")?;
        }
        Ok(())
    }
}
/// One entry of an `Upgrade` header: a protocol name with an optional version.
///
/// Entries produced by [`Upgrade::parse`] hold ASCII-lowercased text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Protocol {
    /// Protocol name (the part before any `/`).
    name: String,
    /// Protocol version (the part after `/`), if one was given.
    version: Option<String>,
}

impl Protocol {
    /// Returns the protocol name.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the protocol version, or `None` when the entry had no `/version`.
    pub fn version(&self) -> Option<&str> {
        self.version.as_ref().map(String::as_str)
    }
}

impl fmt::Display for Protocol {
    /// Writes `name`, or `name/version` when a version is present.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.name)?;
        if let Some(version) = &self.version {
            f.write_str("/")?;
            f.write_str(version)?;
        }
        Ok(())
    }
}