kitsune2_api/
url.rs

1//! Url-related types.
2
3use crate::*;
4
5// We're using bytes::Bytes as the storage type for urls instead of String,
6// even though it adds a little complexity overhead to the accessor functions
7// here, because Bytes are more cheaply clone-able, and we need it to be bytes
8// for the protobuf wire message types.
9
10/// A validated Kitsune2 Url.
11#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
12pub struct Url(bytes::Bytes);
13
14impl serde::Serialize for Url {
15    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
16    where
17        S: serde::Serializer,
18    {
19        serializer.serialize_str(self.as_str())
20    }
21}
22
23impl<'de> serde::Deserialize<'de> for Url {
24    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
25    where
26        D: serde::Deserializer<'de>,
27    {
28        struct V;
29
30        impl serde::de::Visitor<'_> for V {
31            type Value = bytes::Bytes;
32
33            fn expecting(
34                &self,
35                f: &mut std::fmt::Formatter,
36            ) -> std::fmt::Result {
37                f.write_str("a valid Kitsune2 Url")
38            }
39
40            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
41            where
42                E: serde::de::Error,
43            {
44                Ok(bytes::Bytes::copy_from_slice(v))
45            }
46
47            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
48            where
49                E: serde::de::Error,
50            {
51                Ok(bytes::Bytes::copy_from_slice(v.as_bytes()))
52            }
53        }
54
55        let b = deserializer.deserialize_bytes(V)?;
56
57        Url::new(b).map_err(serde::de::Error::custom)
58    }
59}
60
61impl From<Url> for bytes::Bytes {
62    fn from(u: Url) -> Self {
63        u.0
64    }
65}
66
67impl From<&Url> for bytes::Bytes {
68    fn from(u: &Url) -> Self {
69        u.0.clone()
70    }
71}
72
73impl AsRef<str> for Url {
74    fn as_ref(&self) -> &str {
75        self.as_str()
76    }
77}
78
79impl std::convert::TryFrom<bytes::Bytes> for Url {
80    type Error = K2Error;
81
82    fn try_from(b: bytes::Bytes) -> Result<Self, Self::Error> {
83        Self::new(b)
84    }
85}
86
87impl std::fmt::Display for Url {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.write_str(self.as_str())
90    }
91}
92
93impl std::fmt::Debug for Url {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        write!(f, "{}", self)
96    }
97}
98
99impl std::str::FromStr for Url {
100    type Err = K2Error;
101
102    fn from_str(src: &str) -> Result<Self, Self::Err> {
103        Self::from_str(src)
104    }
105}
106
107impl Url {
108    /// Construct a new validated Kitsune2 Url.
109    pub fn new(src: bytes::Bytes) -> K2Result<Self> {
110        let str_src = std::str::from_utf8(&src).map_err(|err| {
111            K2Error::other_src("Kitsne2 Url is not valid utf8", err)
112        })?;
113
114        let parsed = ::url::Url::parse(str_src).map_err(|err| {
115            K2Error::other_src("Could not parse as Kitsune2 Url", err)
116        })?;
117
118        let scheme = match parsed.scheme() {
119            scheme @ "ws" | scheme @ "wss" => scheme,
120            oth => {
121                return Err(K2Error::other(format!(
122                    "Invalid Kitsune2 Url Scheme: {oth}",
123                )));
124            }
125        };
126
127        let host = match parsed.host_str() {
128            Some(host) => host,
129            None => {
130                return Err(K2Error::other(
131                    "Invalid Kitsune2 Url, Missing Host",
132                ));
133            }
134        };
135
136        let port = match parsed.port_or_known_default() {
137            Some(port) => port,
138            None => {
139                return Err(K2Error::other(
140                    "Invalid Kitsune2 Url, Explicit Port Required",
141                ));
142            }
143        };
144
145        let path = parsed.path();
146
147        if path.split('/').count() != 2 {
148            return Err(K2Error::other(
149                "Invalid Kitsune2 Url, path must contain exactly 1 slash",
150            ));
151        }
152
153        let canonical = if path == "/" {
154            format!("{scheme}://{host}:{port}")
155        } else {
156            format!("{scheme}://{host}:{port}{path}")
157        };
158
159        if str_src != canonical.as_str() {
160            return Err(K2Error::other(format!(
161                "Invalid Kitsune2 Url, Non-Canonical. Expected: {canonical}. Got: {str_src}",
162            )));
163        }
164
165        Ok(Self(src))
166    }
167
168    /// Construct a new validated Kitsune2 Url from a str.
169    // We *do* also implement the trait. But it's not as usable,
170    // so implement a better local version as well.
171    #[allow(clippy::should_implement_trait)]
172    pub fn from_str<S: AsRef<str>>(src: S) -> K2Result<Self> {
173        Self::new(bytes::Bytes::copy_from_slice(src.as_ref().as_bytes()))
174    }
175
176    /// Get this url as a str.
177    pub fn as_str(&self) -> &str {
178        // we've already checked it is valid utf8 in the constructor.
179        unsafe { std::str::from_utf8_unchecked(&self.0) }
180    }
181
182    /// Returns true if the protocol scheme is `wss`.
183    pub fn uses_tls(&self) -> bool {
184        &self.0[..3] == b"wss"
185    }
186
187    /// Returns true if this is a peer url. Otherwise, this is a server url.
188    pub fn is_peer(&self) -> bool {
189        self.peer_id().is_some()
190    }
191
192    /// Returns the peer id if this is a peer url.
193    pub fn peer_id(&self) -> Option<&str> {
194        match self.as_str().split_once("://") {
195            None => None,
196            Some((_, r)) => match r.rsplit_once('/') {
197                None => None,
198                Some((_, r)) => {
199                    if r.is_empty() {
200                        None
201                    } else {
202                        Some(r)
203                    }
204                }
205            },
206        }
207    }
208
209    /// Returns the host:port to use for connecting to this url.
210    pub fn addr(&self) -> &str {
211        // unwraps in here because this has all been validated by constructor
212        let addr = self.as_str().split_once("://").unwrap().1;
213        match addr.split_once('/') {
214            None => addr,
215            Some((addr, _)) => addr,
216        }
217    }
218}
219
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn happy_serialize() {
        const URL: &str = "wss://test.com:443";
        let u = Url::from_str(URL).unwrap();
        let e = serde_json::to_string(&u).unwrap();
        assert_eq!(format!("\"{URL}\""), e);
        let d: Url = serde_json::from_str(&e).unwrap();
        assert_eq!(d, u);
    }

    #[test]
    fn deserialize_validates() {
        // Deserialization must apply the same validation as `Url::new`.
        assert!(serde_json::from_str::<Url>("\"ws://a.b\"").is_err());
        assert!(serde_json::from_str::<Url>("\"http://a.b:80\"").is_err());
        assert!(serde_json::from_str::<Url>("42").is_err());
    }

    #[test]
    fn fixture_parse() {
        const F: &[(&str, Option<&str>, bool, &str)] = &[
            ("ws://a.b:80", None, false, "a.b:80"),
            ("ws://1.1.1.1:80", None, false, "1.1.1.1:80"),
            ("ws://[::1]:80", None, false, "[::1]:80"),
            ("wss://a.b:443", None, true, "a.b:443"),
            ("wss://a.b:80", None, true, "a.b:80"),
            ("ws://a.b:999", None, false, "a.b:999"),
            ("ws://a.b:80/foo", Some("foo"), false, "a.b:80"),
            ("wss://a.b:443/foo", Some("foo"), true, "a.b:443"),
            ("ws://a.b:999/foo", Some("foo"), false, "a.b:999"),
            ("ws://[::1]:80/foo", Some("foo"), false, "[::1]:80"),
        ];

        for (s, id, tls, addr) in F.iter() {
            let u = Url::from_str(s).unwrap();
            assert_eq!(s, &u.as_str());
            assert_eq!(id, &u.peer_id());
            assert_eq!(id.is_some(), u.is_peer());
            assert_eq!(tls, &u.uses_tls());
            assert_eq!(addr, &u.addr());
        }
    }

    #[test]
    fn conversions_round_trip() {
        const URL: &str = "wss://a.b:443/foo";
        let u = Url::from_str(URL).unwrap();

        // `str::parse` goes through the `FromStr` trait impl.
        let parsed: Url = URL.parse().unwrap();
        assert_eq!(u, parsed);

        let raw = bytes::Bytes::from_static(URL.as_bytes());
        let tried =
            <Url as std::convert::TryFrom<bytes::Bytes>>::try_from(raw)
                .unwrap();
        assert_eq!(u, tried);

        assert_eq!(URL, u.to_string());
        assert_eq!(URL, format!("{u:?}"));
        assert_eq!(URL, AsRef::<str>::as_ref(&u));

        let borrowed: bytes::Bytes = (&u).into();
        let owned: bytes::Bytes = u.into();
        assert_eq!(URL.as_bytes(), &borrowed[..]);
        assert_eq!(URL.as_bytes(), &owned[..]);
    }

    #[test]
    fn fixture_no_parse() {
        const F: &[&str] = &[
            "",
            "not a url",
            "ws://a.b",
            "wss://a.b",
            "w://a.b:80",
            "http://a.b:80",
            "WS://a.b:80",
            "ws://A.B:80",
            "ws://a.b:080",
            "ws://user@a.b:80",
            "ws://a.b:80/",
            "ws://a.b:80/foo/",
            "ws://a.b:80/foo/bar",
            "ws://a.b:80?x=1",
            "ws://a.b:80#frag",
        ];

        for s in F.iter() {
            assert!(Url::from_str(s).is_err(), "should not parse: {s}");
        }
    }
}