use core::fmt;
/// Error returned by [`Upgrade::parse`] when an `Upgrade` header value is malformed.
///
/// The enum carries no data, so it is `Copy` and can be used as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UpgradeError {
    /// The header value was empty.
    ///
    /// NOTE(review): `Upgrade::parse` does not currently return this variant; an empty
    /// value parses to an `Upgrade` with no protocols. Kept for API compatibility.
    Empty,
    /// A list element contained more than one `/` separator.
    InvalidFormat,
    /// A protocol name was empty or contained a character that is not a token character.
    InvalidProtocol,
    /// A protocol version was empty or contained a character that is not a token character.
    InvalidVersion,
}
impl fmt::Display for UpgradeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
UpgradeError::Empty => write!(f, "empty Upgrade header"),
UpgradeError::InvalidFormat => write!(f, "invalid Upgrade header format"),
UpgradeError::InvalidProtocol => write!(f, "invalid Upgrade protocol"),
UpgradeError::InvalidVersion => write!(f, "invalid Upgrade protocol version"),
}
}
}
impl std::error::Error for UpgradeError {}
/// A parsed `Upgrade` header value: the ordered list of protocols it offers.
///
/// `Upgrade::default()` holds no protocols, the same value `Upgrade::parse("")` yields.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Upgrade {
    // Protocols in the order they appeared in the header value.
    protocols: Vec<Protocol>,
}
impl Upgrade {
    /// Parses the value of an HTTP `Upgrade` header field.
    ///
    /// The value is a comma-separated list of `name` or `name/version` elements. Empty
    /// list elements (as in `"a,,b"` or a bare `","`) are skipped, so an empty input
    /// yields an `Upgrade` with no protocols. Names and versions are stored
    /// ASCII-lowercased.
    ///
    /// Only HTTP optional whitespace (space and horizontal tab) is trimmed around each
    /// element and on either side of the `/`. Other whitespace, such as CR, LF or NBSP,
    /// is rejected as an invalid character.
    ///
    /// # Errors
    ///
    /// * [`UpgradeError::InvalidFormat`]: an element contains more than one `/`.
    /// * [`UpgradeError::InvalidProtocol`]: a name is empty or contains a character
    ///   that is not a token character.
    /// * [`UpgradeError::InvalidVersion`]: a version after `/` is empty or contains a
    ///   character that is not a token character.
    ///
    /// The first malformed element, in input order, determines the error.
    pub fn parse(input: &str) -> Result<Self, UpgradeError> {
        let protocols = input
            .split(',')
            .map(Self::trim_ows)
            .filter(|part| !part.is_empty())
            .map(Self::parse_protocol)
            // Stops at the first `Err`, just like an early `return` inside a loop.
            .collect::<Result<Vec<_>, _>>()?;
        Ok(Upgrade { protocols })
    }

    /// Parses one non-empty, already-trimmed list element into a [`Protocol`].
    fn parse_protocol(part: &str) -> Result<Protocol, UpgradeError> {
        match part.split_once('/') {
            None => {
                if !is_valid_token(part) {
                    return Err(UpgradeError::InvalidProtocol);
                }
                Ok(Protocol {
                    name: part.to_ascii_lowercase(),
                    version: None,
                })
            }
            Some((name, version)) => {
                // A second separator can never form a valid element.
                if version.contains('/') {
                    return Err(UpgradeError::InvalidFormat);
                }
                let name = Self::trim_ows(name);
                let version = Self::trim_ows(version);
                // Emptiness of both halves is checked before token validity. That way an
                // input such as "a b/" reports the missing version, not the bad name.
                if name.is_empty() {
                    return Err(UpgradeError::InvalidProtocol);
                }
                if version.is_empty() {
                    return Err(UpgradeError::InvalidVersion);
                }
                if !is_valid_token(name) {
                    return Err(UpgradeError::InvalidProtocol);
                }
                if !is_valid_token(version) {
                    return Err(UpgradeError::InvalidVersion);
                }
                Ok(Protocol {
                    name: name.to_ascii_lowercase(),
                    version: Some(version.to_ascii_lowercase()),
                })
            }
        }
    }

    /// Strips HTTP optional whitespace (`OWS`: space / horizontal tab) from both ends.
    fn trim_ows(s: &str) -> &str {
        s.trim_matches(|c: char| c == ' ' || c == '\t')
    }

    /// Returns the parsed protocols in header order.
    pub fn protocols(&self) -> &[Protocol] {
        &self.protocols
    }

    /// Returns `true` if any protocol's name equals `protocol`, ignoring ASCII case.
    ///
    /// Versions are not compared.
    pub fn has_protocol(&self, protocol: &str) -> bool {
        self.protocols
            .iter()
            .any(|p| p.name.eq_ignore_ascii_case(protocol))
    }
}
impl fmt::Display for Upgrade {
    /// Writes the protocols as a `", "`-separated list, e.g. `websocket, h2c/1.0`.
    ///
    /// Each protocol is written straight into the formatter, so no temporary strings
    /// are allocated. An `Upgrade` with no protocols renders as the empty string.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        for (index, protocol) in self.protocols.iter().enumerate() {
            if index > 0 {
                f.write_str(", ")?;
            }
            write!(f, "{}", protocol)?;
        }
        Ok(())
    }
}
/// One element of an `Upgrade` header: a protocol name with an optional version.
///
/// Derives `Hash` so protocols can be collected into hash sets or used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Protocol {
    // Protocol name; `Upgrade::parse` stores it ASCII-lowercased.
    name: String,
    // Text after the `/`, if any; `Upgrade::parse` stores it ASCII-lowercased.
    version: Option<String>,
}
impl Protocol {
    /// Returns the protocol name as stored (`Upgrade::parse` stores it lowercased).
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the version that followed `/` in the element, or `None` if there was none.
    pub fn version(&self) -> Option<&str> {
        self.version.as_ref().map(String::as_str)
    }
}
impl fmt::Display for Protocol {
    /// Formats as `name` or `name/version`, the same shape the element has in the header.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.name)?;
        if let Some(ref version) = self.version {
            write!(f, "/{}", version)?;
        }
        Ok(())
    }
}
/// Returns `true` if `s` is a non-empty HTTP `token`, i.e. every byte is a `tchar`.
fn is_valid_token(s: &str) -> bool {
    if s.is_empty() {
        return false;
    }
    s.bytes().all(is_token_char)
}
/// Returns `true` if `b` is an HTTP `tchar`: an ASCII letter or digit, or one of
/// the symbols ``!#$%&'*+-.^_`|~``.
fn is_token_char(b: u8) -> bool {
    const TCHAR_SYMBOLS: &[u8] = b"!#$%&'*+-.^_`|~";
    b.is_ascii_alphanumeric() || TCHAR_SYMBOLS.contains(&b)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_simple() {
        let upgrade = Upgrade::parse("websocket").unwrap();
        assert!(upgrade.has_protocol("websocket"));
        assert_eq!(upgrade.protocols().len(), 1);
    }

    #[test]
    fn parse_with_version() {
        let upgrade = Upgrade::parse("h2c/1.0, websocket").unwrap();
        assert_eq!(upgrade.protocols()[0].name(), "h2c");
        assert_eq!(upgrade.protocols()[0].version(), Some("1.0"));
    }

    #[test]
    fn parse_invalid() {
        assert!(Upgrade::parse("bad value").is_err());
        assert!(Upgrade::parse("websocket/").is_err());
        assert!(Upgrade::parse("websocket/1/2").is_err());
    }

    // Pins which error variant each kind of malformed element produces.
    #[test]
    fn parse_error_kinds() {
        assert_eq!(Upgrade::parse("bad value"), Err(UpgradeError::InvalidProtocol));
        assert_eq!(Upgrade::parse("/1.0"), Err(UpgradeError::InvalidProtocol));
        assert_eq!(Upgrade::parse("websocket/"), Err(UpgradeError::InvalidVersion));
        assert_eq!(Upgrade::parse("h2c/1 0"), Err(UpgradeError::InvalidVersion));
        assert_eq!(Upgrade::parse("websocket/1/2"), Err(UpgradeError::InvalidFormat));
        // The first bad element in input order decides the error.
        assert_eq!(Upgrade::parse("ok, a/1/2, bad value"), Err(UpgradeError::InvalidFormat));
    }

    #[test]
    fn parse_empty_elements() {
        let upgrade = Upgrade::parse("").unwrap();
        assert!(upgrade.protocols().is_empty());
        let upgrade = Upgrade::parse(",").unwrap();
        assert!(upgrade.protocols().is_empty());
        let upgrade = Upgrade::parse("websocket,,h2c").unwrap();
        assert_eq!(upgrade.protocols().len(), 2);
    }

    // Spaces and tabs around elements and around '/' are tolerated.
    #[test]
    fn parse_tolerates_spaces_and_tabs() {
        let upgrade = Upgrade::parse("\twebsocket ,  h2c / 1.0 ").unwrap();
        assert_eq!(upgrade.protocols().len(), 2);
        assert_eq!(upgrade.protocols()[1].name(), "h2c");
        assert_eq!(upgrade.protocols()[1].version(), Some("1.0"));
    }

    #[test]
    fn parse_normalizes_case() {
        let upgrade = Upgrade::parse("WebSocket, H2C/1.0").unwrap();
        assert_eq!(upgrade.protocols()[0].name(), "websocket");
        assert_eq!(upgrade.protocols()[0].version(), None);
        assert_eq!(upgrade.protocols()[1].name(), "h2c");
        assert!(upgrade.has_protocol("WEBSOCKET"));
        assert!(!upgrade.has_protocol("http"));
    }

    #[test]
    fn display() {
        let upgrade = Upgrade::parse("websocket, h2c/1.0").unwrap();
        assert_eq!(upgrade.to_string(), "websocket, h2c/1.0");
    }

    #[test]
    fn display_empty() {
        assert_eq!(Upgrade::parse("").unwrap().to_string(), "");
    }
}