rama_http_headers/common/
origin.rs1use std::convert::TryFrom;
2use std::fmt;
3
4use bytes::Bytes;
5use rama_http_types::HeaderValue;
6use rama_http_types::dep::http::uri::{self, Authority, Scheme, Uri};
7
8use crate::Error;
9use crate::util::{IterExt, TryFromValues};
10
11#[derive(Clone, Debug, PartialEq, Eq, Hash)]
29pub struct Origin(OriginOrNull);
30
31derive_header! {
32 Origin(_),
33 name: ORIGIN
34}
35
36#[derive(Clone, Debug, PartialEq, Eq, Hash)]
37enum OriginOrNull {
38 Origin(Scheme, Authority),
39 Null,
40}
41
42impl Origin {
43 pub const NULL: Origin = Origin(OriginOrNull::Null);
45
46 #[inline]
48 pub fn is_null(&self) -> bool {
49 matches!(self.0, OriginOrNull::Null)
50 }
51
52 #[inline]
54 pub fn scheme(&self) -> &str {
55 match self.0 {
56 OriginOrNull::Origin(ref scheme, _) => scheme.as_str(),
57 OriginOrNull::Null => "",
58 }
59 }
60
61 #[inline]
63 pub fn hostname(&self) -> &str {
64 match self.0 {
65 OriginOrNull::Origin(_, ref auth) => auth.host(),
66 OriginOrNull::Null => "",
67 }
68 }
69
70 #[inline]
72 pub fn port(&self) -> Option<u16> {
73 match self.0 {
74 OriginOrNull::Origin(_, ref auth) => auth.port_u16(),
75 OriginOrNull::Null => None,
76 }
77 }
78
79 pub fn try_from_parts(
81 scheme: &str,
82 host: &str,
83 port: impl Into<Option<u16>>,
84 ) -> Result<Self, InvalidOrigin> {
85 struct MaybePort(Option<u16>);
86
87 impl fmt::Display for MaybePort {
88 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
89 if let Some(port) = self.0 {
90 write!(f, ":{}", port)
91 } else {
92 Ok(())
93 }
94 }
95 }
96
97 let bytes = Bytes::from(format!("{}://{}{}", scheme, host, MaybePort(port.into())));
98 HeaderValue::from_maybe_shared(bytes)
99 .ok()
100 .and_then(|val| Self::try_from_value(&val))
101 .ok_or(InvalidOrigin)
102 }
103
104 pub(super) fn try_from_value(value: &HeaderValue) -> Option<Self> {
106 OriginOrNull::try_from_value(value).map(Origin)
107 }
108
109 pub(super) fn to_value(&self) -> HeaderValue {
110 (&self.0).into()
111 }
112}
113
114impl fmt::Display for Origin {
115 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
116 match self.0 {
117 OriginOrNull::Origin(ref scheme, ref auth) => write!(f, "{}://{}", scheme, auth),
118 OriginOrNull::Null => f.write_str("null"),
119 }
120 }
121}
122
// Error returned by `Origin::try_from_parts` when the assembled value is not
// accepted as an origin. The `#[doc]` literal is consumed by the project's
// `static_str_error!` macro; NOTE(review): it may also be used as the error's
// display message, so it is intentionally left as-is.
rama_utils::macros::error::static_str_error! {
    #[doc = "origin is not valid"]
    pub struct InvalidOrigin;
}
127
128impl OriginOrNull {
129 fn try_from_value(value: &HeaderValue) -> Option<Self> {
130 if value == "null" {
131 return Some(OriginOrNull::Null);
132 }
133
134 let uri = Uri::try_from(value.as_bytes()).ok()?;
135
136 let (scheme, auth) = match uri.into_parts() {
137 uri::Parts {
138 scheme: Some(scheme),
139 authority: Some(auth),
140 path_and_query: None,
141 ..
142 } => (scheme, auth),
143 uri::Parts {
144 scheme: Some(ref scheme),
145 authority: Some(ref auth),
146 path_and_query: Some(ref p),
147 ..
148 } if p == "/" => (scheme.clone(), auth.clone()),
149 _ => {
150 return None;
151 }
152 };
153
154 Some(OriginOrNull::Origin(scheme, auth))
155 }
156}
157
158impl TryFromValues for OriginOrNull {
159 fn try_from_values<'i, I>(values: &mut I) -> Result<Self, Error>
160 where
161 I: Iterator<Item = &'i HeaderValue>,
162 {
163 values
164 .just_one()
165 .and_then(OriginOrNull::try_from_value)
166 .ok_or_else(Error::invalid)
167 }
168}
169
170impl<'a> From<&'a OriginOrNull> for HeaderValue {
171 fn from(origin: &'a OriginOrNull) -> HeaderValue {
172 match origin {
173 OriginOrNull::Origin(scheme, auth) => {
174 let s = format!("{}://{}", scheme, auth);
175 let bytes = Bytes::from(s);
176 HeaderValue::from_maybe_shared(bytes)
177 .expect("Scheme and Authority are valid header values")
178 }
179 OriginOrNull::Null => HeaderValue::from_static("null"),
182 }
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::super::{test_decode, test_encode};
    use super::*;

    #[test]
    fn origin() {
        let s = "http://web-platform.test:8000";
        let origin = test_decode::<Origin>(&[s]).unwrap();
        assert!(!origin.is_null());
        assert_eq!(origin.scheme(), "http");
        assert_eq!(origin.hostname(), "web-platform.test");
        assert_eq!(origin.port(), Some(8000));

        let headers = test_encode(origin);
        assert_eq!(headers["origin"], s);
    }

    #[test]
    fn origin_with_root_path_drops_trailing_slash() {
        let origin = test_decode::<Origin>(&["https://example.com/"]).unwrap();
        assert_eq!(origin.scheme(), "https");
        assert_eq!(origin.hostname(), "example.com");
        assert_eq!(origin.port(), None);

        let headers = test_encode(origin);
        assert_eq!(headers["origin"], "https://example.com");
    }

    #[test]
    fn rejects_values_that_are_not_origins() {
        for value in [
            "example.com",
            "/relative",
            "http://example.com/path",
            "http://example.com/?query",
            "not an origin",
        ] {
            assert_eq!(
                test_decode::<Origin>(&[value]),
                None,
                "unexpectedly accepted {:?}",
                value
            );
        }
    }

    #[test]
    fn try_from_parts() {
        let origin =
            Origin::try_from_parts("https", "example.com", 8443_u16).expect("valid parts");
        assert_eq!(origin.scheme(), "https");
        assert_eq!(origin.hostname(), "example.com");
        assert_eq!(origin.port(), Some(8443));
        assert_eq!(origin.to_string(), "https://example.com:8443");

        let origin = Origin::try_from_parts("http", "example.com", None).expect("valid parts");
        assert_eq!(origin.port(), None);
        assert_eq!(origin.to_string(), "http://example.com");

        assert!(Origin::try_from_parts("http", "example.com/path", None).is_err());
        assert!(Origin::try_from_parts("http", "exa mple.com", None).is_err());
    }

    #[test]
    fn null() {
        assert_eq!(test_decode::<Origin>(&["null"]), Some(Origin::NULL));

        let null = Origin::NULL;
        assert!(null.is_null());
        assert_eq!(null.scheme(), "");
        assert_eq!(null.hostname(), "");
        assert_eq!(null.port(), None);
        assert_eq!(null.to_string(), "null");

        let headers = test_encode(Origin::NULL);
        assert_eq!(headers["origin"], "null");
    }
}