use std::fmt::Display;
use http::uri::{Authority, Scheme};
use crate::ValidationError;
/// The origin of an HTTP(S) endpoint: a URI scheme paired with an authority.
///
/// Values produced by [`Origin::new`] (and therefore by [`str::parse`]) always use the
/// `http` or `https` scheme. Equality and hashing are derived from the stored scheme and
/// authority, and the authority is kept exactly as it was given to the constructor.
#[non_exhaustive]
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Origin {
    /// The URI scheme; [`Origin::new`] only admits `http` and `https`.
    scheme: Scheme,
    /// The authority component, stored as passed to the constructor.
    authority: Authority,
}
impl Origin {
    /// Creates an origin from a scheme and an authority.
    ///
    /// Both arguments accept anything convertible into the corresponding `http::uri`
    /// type, such as a `&str` or an already-parsed [`Scheme`] / [`Authority`].
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - `scheme` or `authority` cannot be converted;
    /// - the scheme is neither `http` nor `https`;
    /// - the authority carries a non-empty port that is not a valid `u16`
    ///   (for example `example.com:99999` or `example.com:abc`).
    pub fn new(
        scheme: impl TryInto<Scheme, Error: Into<http::Error>>,
        authority: impl TryInto<Authority, Error: Into<http::Error>>,
    ) -> Result<Self, ValidationError> {
        let scheme: Scheme = scheme.try_into().map_err(Into::into)?;
        if scheme != Scheme::HTTP && scheme != Scheme::HTTPS {
            return Err(ValidationError::invalid_uri(format!(
                "unsupported scheme: {scheme}, only HTTP and HTTPS schemes are supported",
            )));
        }
        let authority: Authority = authority.try_into().map_err(Into::into)?;
        // `Authority::port_u16` returns `None` for port text that does not fit in a `u16`,
        // and `port()` would then silently report the scheme default instead. Reject such
        // authorities here so a stored port is always either absent/empty or a valid `u16`.
        if let Some(port) =
            Self::explicit_port(&authority).filter(|port| port.parse::<u16>().is_err())
        {
            return Err(ValidationError::invalid_uri(format!(
                "invalid port: {port}, the port must be an integer between 0 and 65535"
            )));
        }
        Ok(Self { scheme, authority })
    }

    /// Returns the raw, non-empty port text of `authority`, if it has one.
    fn explicit_port(authority: &Authority) -> Option<&str> {
        let raw = authority.as_str();
        // Userinfo may itself contain ':' (`user:password@host`), so look past the last '@'.
        let host_port = raw.rsplit_once('@').map_or(raw, |(_, rest)| rest);
        // An IPv6 literal contains colons of its own; a port can only follow its ']'.
        let after_host = host_port.rsplit_once(']').map_or(host_port, |(_, rest)| rest);
        after_host
            .rsplit_once(':')
            .map(|(_, port)| port)
            .filter(|port| !port.is_empty())
    }

    /// Returns the scheme of this origin.
    #[must_use]
    pub const fn scheme(&self) -> &Scheme {
        &self.scheme
    }

    /// Returns the authority of this origin, exactly as it was stored.
    #[must_use]
    pub const fn authority(&self) -> &Authority {
        &self.authority
    }

    /// Consumes the origin and returns its scheme and authority.
    #[must_use]
    pub fn into_parts(self) -> (Scheme, Authority) {
        (self.scheme, self.authority)
    }

    /// Returns the effective port.
    ///
    /// This is the port written in the authority if there is one, otherwise the scheme
    /// default: 80 for `http` and 443 for `https`.
    ///
    /// # Panics
    ///
    /// Panics if the scheme is neither `http` nor `https`, which [`Origin::new`] never
    /// allows.
    pub fn port(&self) -> u16 {
        match self.authority.port_u16() {
            Some(port) => port,
            None if self.scheme == Scheme::HTTP => 80,
            None if self.scheme == Scheme::HTTPS => 443,
            None => unreachable!("the scheme is always either http or https"),
        }
    }

    /// Returns a new origin with the same scheme and host, using `port`.
    ///
    /// The new authority is built from this origin's host followed by `:port`.
    #[expect(
        clippy::expect_used,
        reason = "the host is always valid, and we are even stricter about valid port than http crate, so this should never fail"
    )]
    #[must_use]
    #[expect(clippy::missing_panics_doc, reason = "impossible panic")]
    pub fn with_port(self, port: u16) -> Self {
        let host = self.authority.host();
        Self::new(self.scheme, format!("{host}:{port}"))
            .expect("Scheme and host are already valid and port is a valid u16")
    }

    /// Returns `true` when the scheme is `https`.
    #[must_use]
    pub fn is_https(&self) -> bool {
        self.scheme == Scheme::HTTPS
    }
}
impl std::str::FromStr for Origin {
    type Err = ValidationError;

    /// Parses an absolute URI such as `https://example.com:8443` into an origin.
    ///
    /// Only the scheme and authority of the parsed URI are used; anything after the
    /// authority is not read. The result is validated by [`Origin::new`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let uri = s.parse::<http::Uri>().map_err(ValidationError::from)?;
        let Some(scheme) = uri.scheme().cloned() else {
            return Err(ValidationError::invalid_uri("missing scheme"));
        };
        let Some(authority) = uri.authority().cloned() else {
            return Err(ValidationError::invalid_uri("missing authority"));
        };
        Self::new(scheme, authority)
    }
}
#[cfg(feature = "serde")]
impl serde::Serialize for Origin {
    /// Serializes the origin as a single string produced by its [`Display`] impl.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.collect_str(self)
    }
}
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for Origin {
    /// Deserializes an origin from a string and validates it through [`str::parse`].
    ///
    /// Any parse or validation failure is reported through `serde::de::Error::custom`.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // Fully-qualified call: no `use serde::Deserialize` brings the trait into scope here,
        // so the shorthand `String::deserialize` would not resolve.
        let text = <String as serde::Deserialize<'de>>::deserialize(deserializer)?;
        text.parse().map_err(serde::de::Error::custom)
    }
}
impl Display for Origin {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}://", self.scheme)?;
match (self.scheme.as_str(), self.port()) {
("http", 80) | ("https", 443) => write!(f, "{}", self.authority.host()),
_ => write!(f, "{}", self.authority),
}
}
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;

    /// Builds an origin from string parts; test inputs are expected to be valid.
    fn origin(scheme: &str, authority: &str) -> Origin {
        Origin::new(scheme, authority).unwrap()
    }

    #[test]
    fn test_port() {
        let cases: [(&str, &str, u16); 4] = [
            ("http", "example.com", 80),
            ("https", "example.com", 443),
            ("http", "example.com:8080", 8080),
            ("https", "example.com:8443", 8443),
        ];
        for (scheme, authority, expected) in cases {
            assert_eq!(origin(scheme, authority).port(), expected, "{scheme}://{authority}");
        }
    }

    #[test]
    fn test_origin_display() {
        let cases = [
            ("http", "example.com", "http://example.com"),
            ("https", "example.com:443", "https://example.com"),
            ("https", "example.com:8443", "https://example.com:8443"),
            ("https", "[::1]:8443", "https://[::1]:8443"),
        ];
        for (scheme, authority, expected) in cases {
            assert_eq!(origin(scheme, authority).to_string(), expected);
        }
    }

    #[test]
    fn test_scheme_accessor() {
        assert_eq!(origin("http", "example.com").scheme().as_str(), "http");
        assert_eq!(origin("https", "example.com:8443").scheme().as_str(), "https");
    }

    #[test]
    fn test_authority_accessor() {
        let cases = [
            ("https", "example.com:8443"),
            ("http", "example.com"),
            ("https", "[::1]:8080"),
        ];
        for (scheme, authority) in cases {
            assert_eq!(origin(scheme, authority).authority().as_str(), authority);
        }
    }

    #[test]
    fn test_into_parts() {
        let (scheme, authority) = origin("https", "example.com:8443").into_parts();
        assert_eq!(
            (scheme.as_str(), authority.as_str()),
            ("https", "example.com:8443")
        );
    }

    #[test]
    fn test_with_port() {
        let moved = origin("https", "example.com").with_port(8443);
        assert_eq!(moved.port(), 8443);
        assert_eq!(moved.to_string(), "https://example.com:8443");
    }

    #[test]
    #[should_panic(expected = "entered unreachable code: the scheme is always either http or https")]
    fn test_with_impossible_scheme() {
        // Build the struct directly to bypass the scheme check in `Origin::new`.
        let unsupported = Origin {
            scheme: Scheme::from_str("ftp").unwrap(),
            authority: Authority::from_static("example.com"),
        };
        unsupported.port();
    }

    #[test]
    fn from_str_valid() {
        let parsed = Origin::from_str("https://example.com:8443").unwrap();
        assert_eq!(parsed.scheme().as_str(), "https");
        assert_eq!(parsed.authority().as_str(), "example.com:8443");
    }

    #[test]
    fn from_str_missing_scheme() {
        assert!(Origin::from_str("example.com").is_err());
    }

    #[test]
    fn from_str_unsupported_scheme() {
        assert!(Origin::from_str("ftp://example.com").is_err());
    }

    #[cfg(feature = "serde")]
    mod serde_tests {
        use super::*;

        #[test]
        fn origin_roundtrip() {
            let original = Origin::new(Scheme::HTTPS, "example.com:8443").unwrap();
            let encoded = serde_json::to_string(&original).unwrap();
            assert_eq!(encoded, r#""https://example.com:8443""#);
            let decoded = serde_json::from_str::<Origin>(&encoded).unwrap();
            assert_eq!(decoded, original);
        }
    }
}