1use crate::*;
4
5#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
12pub struct Url(bytes::Bytes);
13
14impl serde::Serialize for Url {
15 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
16 where
17 S: serde::Serializer,
18 {
19 serializer.serialize_str(self.as_str())
20 }
21}
22
23impl<'de> serde::Deserialize<'de> for Url {
24 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
25 where
26 D: serde::Deserializer<'de>,
27 {
28 struct V;
29
30 impl serde::de::Visitor<'_> for V {
31 type Value = bytes::Bytes;
32
33 fn expecting(
34 &self,
35 f: &mut std::fmt::Formatter,
36 ) -> std::fmt::Result {
37 f.write_str("a valid Kitsune2 Url")
38 }
39
40 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
41 where
42 E: serde::de::Error,
43 {
44 Ok(bytes::Bytes::copy_from_slice(v))
45 }
46
47 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
48 where
49 E: serde::de::Error,
50 {
51 Ok(bytes::Bytes::copy_from_slice(v.as_bytes()))
52 }
53 }
54
55 let b = deserializer.deserialize_bytes(V)?;
56
57 Url::new(b).map_err(serde::de::Error::custom)
58 }
59}
60
61impl From<Url> for bytes::Bytes {
62 fn from(u: Url) -> Self {
63 u.0
64 }
65}
66
67impl From<&Url> for bytes::Bytes {
68 fn from(u: &Url) -> Self {
69 u.0.clone()
70 }
71}
72
73impl AsRef<str> for Url {
74 fn as_ref(&self) -> &str {
75 self.as_str()
76 }
77}
78
79impl std::convert::TryFrom<bytes::Bytes> for Url {
80 type Error = K2Error;
81
82 fn try_from(b: bytes::Bytes) -> Result<Self, Self::Error> {
83 Self::new(b)
84 }
85}
86
87impl std::fmt::Display for Url {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.write_str(self.as_str())
90 }
91}
92
93impl std::fmt::Debug for Url {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 write!(f, "{}", self)
96 }
97}
98
99impl std::str::FromStr for Url {
100 type Err = K2Error;
101
102 fn from_str(src: &str) -> Result<Self, Self::Err> {
103 Self::from_str(src)
104 }
105}
106
107impl Url {
108 pub fn new(src: bytes::Bytes) -> K2Result<Self> {
110 let str_src = std::str::from_utf8(&src).map_err(|err| {
111 K2Error::other_src("Kitsne2 Url is not valid utf8", err)
112 })?;
113
114 let parsed = ::url::Url::parse(str_src).map_err(|err| {
115 K2Error::other_src("Could not parse as Kitsune2 Url", err)
116 })?;
117
118 let scheme = match parsed.scheme() {
119 scheme @ "ws" | scheme @ "wss" => scheme,
120 oth => {
121 return Err(K2Error::other(format!(
122 "Invalid Kitsune2 Url Scheme: {oth}",
123 )));
124 }
125 };
126
127 let host = match parsed.host_str() {
128 Some(host) => host,
129 None => {
130 return Err(K2Error::other(
131 "Invalid Kitsune2 Url, Missing Host",
132 ));
133 }
134 };
135
136 let port = match parsed.port_or_known_default() {
137 Some(port) => port,
138 None => {
139 return Err(K2Error::other(
140 "Invalid Kitsune2 Url, Explicit Port Required",
141 ));
142 }
143 };
144
145 let path = parsed.path();
146
147 if path.split('/').count() != 2 {
148 return Err(K2Error::other(
149 "Invalid Kitsune2 Url, path must contain exactly 1 slash",
150 ));
151 }
152
153 let canonical = if path == "/" {
154 format!("{scheme}://{host}:{port}")
155 } else {
156 format!("{scheme}://{host}:{port}{path}")
157 };
158
159 if str_src != canonical.as_str() {
160 return Err(K2Error::other(format!(
161 "Invalid Kitsune2 Url, Non-Canonical. Expected: {canonical}. Got: {str_src}",
162 )));
163 }
164
165 Ok(Self(src))
166 }
167
168 #[allow(clippy::should_implement_trait)]
172 pub fn from_str<S: AsRef<str>>(src: S) -> K2Result<Self> {
173 Self::new(bytes::Bytes::copy_from_slice(src.as_ref().as_bytes()))
174 }
175
176 pub fn as_str(&self) -> &str {
178 unsafe { std::str::from_utf8_unchecked(&self.0) }
180 }
181
182 pub fn uses_tls(&self) -> bool {
184 &self.0[..3] == b"wss"
185 }
186
187 pub fn is_peer(&self) -> bool {
189 self.peer_id().is_some()
190 }
191
192 pub fn peer_id(&self) -> Option<&str> {
194 match self.as_str().split_once("://") {
195 None => None,
196 Some((_, r)) => match r.rsplit_once('/') {
197 None => None,
198 Some((_, r)) => {
199 if r.is_empty() {
200 None
201 } else {
202 Some(r)
203 }
204 }
205 },
206 }
207 }
208
209 pub fn addr(&self) -> &str {
211 let addr = self.as_str().split_once("://").unwrap().1;
213 match addr.split_once('/') {
214 None => addr,
215 Some((addr, _)) => addr,
216 }
217 }
218}
219
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn happy_serialize() {
        const URL: &str = "wss://test.com:443";
        let u = Url::from_str(URL).unwrap();
        let e = serde_json::to_string(&u).unwrap();
        assert_eq!(format!("\"{URL}\""), e);
        let d: Url = serde_json::from_str(&e).unwrap();
        assert_eq!(d, u);
    }

    #[test]
    fn deserialize_validates() {
        // Missing explicit port: rejected by `Url::new`, so serde must fail.
        assert!(serde_json::from_str::<Url>("\"ws://a.b\"").is_err());
        // Not a string at all.
        assert!(serde_json::from_str::<Url>("42").is_err());
    }

    #[test]
    fn fixture_parse() {
        // (input, peer_id, uses_tls, addr)
        const F: &[(&str, Option<&str>, bool, &str)] = &[
            ("ws://a.b:80", None, false, "a.b:80"),
            ("ws://1.1.1.1:80", None, false, "1.1.1.1:80"),
            ("ws://[::1]:80", None, false, "[::1]:80"),
            ("wss://a.b:443", None, true, "a.b:443"),
            ("ws://a.b:999", None, false, "a.b:999"),
            ("ws://a.b:80/foo", Some("foo"), false, "a.b:80"),
            ("wss://a.b:443/foo", Some("foo"), true, "a.b:443"),
            ("ws://a.b:999/foo", Some("foo"), false, "a.b:999"),
            ("ws://[::1]:80/foo", Some("foo"), false, "[::1]:80"),
        ];

        for (s, id, tls, addr) in F.iter() {
            let u = Url::from_str(s).unwrap();
            assert_eq!(s, &u.as_str());
            assert_eq!(id, &u.peer_id());
            assert_eq!(id.is_some(), u.is_peer());
            assert_eq!(tls, &u.uses_tls());
            assert_eq!(addr, &u.addr());
            // Display and Debug both print the bare Url text.
            assert_eq!(*s, u.to_string());
            assert_eq!(*s, format!("{u:?}"));
        }
    }

    #[test]
    fn conversions_agree() {
        const URL: &str = "wss://a.b:443/foo";
        let raw = bytes::Bytes::from_static(URL.as_bytes());

        let via_inherent = Url::from_str(URL).unwrap();
        let via_parse: Url = URL.parse().unwrap();
        let via_try_from =
            <Url as std::convert::TryFrom<bytes::Bytes>>::try_from(raw.clone())
                .unwrap();

        assert_eq!(via_inherent, via_parse);
        assert_eq!(via_inherent, via_try_from);

        let as_ref: &str = via_inherent.as_ref();
        assert_eq!(URL, as_ref);
        assert_eq!(raw, bytes::Bytes::from(&via_inherent));
        assert_eq!(raw, bytes::Bytes::from(via_inherent));
    }

    #[test]
    fn fixture_no_parse() {
        const F: &[&str] = &[
            // An explicit port is required, even the scheme default.
            "ws://a.b",
            "wss://a.b",
            // Only the ws / wss schemes are allowed.
            "w://a.b:80",
            "http://a.b:80",
            // The path must be absent or a single non-empty segment.
            "ws://a.b:80/",
            "ws://a.b:80/foo/bar",
            // The input must already be in canonical form.
            "WS://a.b:80",
            "ws://A.B:80",
            "ws://user@a.b:80",
            "ws://a.b:80/foo?q=1",
            "ws://a.b:80/foo#frag",
            // Not URLs at all.
            "",
            "not a url",
        ];

        for s in F.iter() {
            assert!(Url::from_str(s).is_err(), "unexpectedly parsed: {s:?}");
        }
    }
}