quilkin/net/
endpoint.rs

1/*
2 * Copyright 2021 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *       http://www.apache.org/licenses/LICENSE-2.0
9 *
10 *  Unless required by applicable law or agreed to in writing, software
11 *  distributed under the License is distributed on an "AS IS" BASIS,
12 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 *  See the License for the specific language governing permissions and
14 *  limitations under the License.
15 */
16
//! Types representing where data is sent.
18
19pub(crate) mod address;
20pub mod metadata;
21
22use crate::net::cluster::proto;
23use eyre::ContextCompat;
24use serde::{Deserialize, Serialize};
25
26pub use self::{
27    address::{AddressKind, EndpointAddress},
28    metadata::DynamicMetadata,
29};
30
31pub use quilkin_xds::locality::Locality;
32
33pub type EndpointMetadata = metadata::MetadataView<Metadata>;
34pub use base64_set::Set;
35
36/// A destination endpoint with any associated metadata.
37#[derive(Debug, Deserialize, Serialize, PartialEq, Clone, Eq, schemars::JsonSchema)]
38#[non_exhaustive]
39#[serde(deny_unknown_fields)]
40pub struct Endpoint {
41    #[schemars(with = "String")]
42    pub address: EndpointAddress,
43    #[serde(default)]
44    pub metadata: EndpointMetadata,
45}
46
47impl Endpoint {
48    /// Creates a new [`Endpoint`] with no metadata.
49    pub fn new(address: EndpointAddress) -> Self {
50        Self {
51            address,
52            ..<_>::default()
53        }
54    }
55
56    /// Creates a new [`Endpoint`] with the specified `metadata`.
57    pub fn with_metadata(address: EndpointAddress, metadata: impl Into<EndpointMetadata>) -> Self {
58        Self {
59            address,
60            metadata: metadata.into(),
61            ..<_>::default()
62        }
63    }
64
65    #[inline]
66    pub fn from_proto(proto: proto::Endpoint) -> eyre::Result<Self> {
67        let host: AddressKind = if let Some(host) = proto.host2 {
68            match host.inner.context("should be unreachable")? {
69                proto::host::Inner::Name(name) => AddressKind::Name(name),
70                proto::host::Inner::Ipv4(v4) => {
71                    AddressKind::Ip(std::net::Ipv4Addr::from(v4).into())
72                }
73                proto::host::Inner::Ipv6(v6) => AddressKind::Ip(
74                    std::net::Ipv6Addr::from(((v6.first as u128) << 64) | v6.second as u128).into(),
75                ),
76            }
77        } else {
78            proto.host.parse()?
79        };
80
81        Ok(Self {
82            address: (host, proto.port as u16).into(),
83            metadata: proto
84                .metadata
85                .map(TryFrom::try_from)
86                .transpose()?
87                .unwrap_or_default(),
88        })
89    }
90
91    #[inline]
92    pub fn into_proto(self) -> proto::Endpoint {
93        let host = match self.address.host {
94            AddressKind::Name(name) => proto::host::Inner::Name(name),
95            AddressKind::Ip(ip) => match ip {
96                std::net::IpAddr::V4(v4) => {
97                    proto::host::Inner::Ipv4(u32::from_be_bytes(v4.octets()))
98                }
99                std::net::IpAddr::V6(v6) => {
100                    let ip = u128::from_be_bytes(v6.octets());
101
102                    let first = ((ip >> 64) & 0xffffffffffffffff) as u64;
103                    let second = (ip & 0xffffffffffffffff) as u64;
104
105                    proto::host::Inner::Ipv6(proto::Ipv6 { first, second })
106                }
107            },
108        };
109
110        proto::Endpoint {
111            host: String::new(),
112            port: self.address.port.into(),
113            metadata: Some(self.metadata.into()),
114            host2: Some(proto::Host { inner: Some(host) }),
115        }
116    }
117}
118
119impl Default for Endpoint {
120    fn default() -> Self {
121        Self {
122            address: EndpointAddress::UNSPECIFIED,
123            metadata: <_>::default(),
124        }
125    }
126}
127
128impl std::str::FromStr for Endpoint {
129    type Err = <EndpointAddress as std::str::FromStr>::Err;
130
131    fn from_str(s: &str) -> Result<Self, Self::Err> {
132        Ok(Self {
133            address: s.parse()?,
134            ..Self::default()
135        })
136    }
137}
138
139impl From<Endpoint> for proto::Endpoint {
140    fn from(endpoint: Endpoint) -> Self {
141        Self {
142            host: endpoint.address.host.to_string(),
143            port: endpoint.address.port.into(),
144            metadata: Some(endpoint.metadata.into()),
145            host2: None,
146        }
147    }
148}
149
150impl TryFrom<proto::Endpoint> for Endpoint {
151    type Error = eyre::Error;
152
153    fn try_from(endpoint: proto::Endpoint) -> Result<Self, Self::Error> {
154        let host: address::AddressKind = endpoint.host.parse()?;
155        if endpoint.port > u16::MAX as u32 {
156            return Err(eyre::eyre!("invalid endpoint port"));
157        }
158
159        Ok(Self {
160            address: (host, endpoint.port as u16).into(),
161            metadata: endpoint
162                .metadata
163                .map(TryFrom::try_from)
164                .transpose()?
165                .unwrap_or_default(),
166        })
167    }
168}
169
170impl std::cmp::PartialEq<EndpointAddress> for Endpoint {
171    fn eq(&self, rhs: &EndpointAddress) -> bool {
172        self.address == *rhs
173    }
174}
175
176impl<T: Into<EndpointAddress>> From<T> for Endpoint {
177    fn from(value: T) -> Self {
178        Self::new(value.into())
179    }
180}
181
182impl Ord for Endpoint {
183    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
184        self.address.cmp(&other.address)
185    }
186}
187
188impl PartialOrd for Endpoint {
189    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
190        Some(self.cmp(other))
191    }
192}
193
194impl std::hash::Hash for Endpoint {
195    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
196        self.address.hash(state);
197        self.metadata.known.tokens.hash(state);
198    }
199}
200
201/// Metadata specific to endpoints.
202#[derive(
203    Default, Debug, Deserialize, Serialize, PartialEq, Clone, PartialOrd, Eq, schemars::JsonSchema,
204)]
205pub struct Metadata {
206    #[serde(
207        serialize_with = "base64_set::serialize",
208        deserialize_with = "base64_set::deserialize"
209    )]
210    pub tokens: base64_set::Set,
211}
212
213impl From<Metadata> for crate::net::endpoint::metadata::MetadataView<Metadata> {
214    fn from(metadata: Metadata) -> Self {
215        Self {
216            known: metadata,
217            ..<_>::default()
218        }
219    }
220}
221
222impl From<Metadata> for prost_types::Struct {
223    fn from(metadata: Metadata) -> Self {
224        let tokens = prost_types::Value {
225            kind: Some(prost_types::value::Kind::ListValue(
226                prost_types::ListValue {
227                    values: metadata
228                        .tokens
229                        .into_iter()
230                        .map(crate::codec::base64::encode)
231                        .map(prost_types::value::Kind::StringValue)
232                        .map(|k| prost_types::Value { kind: Some(k) })
233                        .collect(),
234                },
235            )),
236        };
237
238        Self {
239            fields: <_>::from([("tokens".into(), tokens)]),
240        }
241    }
242}
243
244impl std::convert::TryFrom<prost_types::Struct> for Metadata {
245    type Error = MetadataError;
246
247    fn try_from(mut value: prost_types::Struct) -> Result<Self, Self::Error> {
248        use prost_types::value::Kind;
249        const TOKENS: &str = "tokens";
250
251        let tokens =
252            if let Some(kind) = value.fields.remove(TOKENS).and_then(|v| v.kind) {
253                match kind {
254                    Kind::ListValue(list) => list
255                        .values
256                        .into_iter()
257                        .filter_map(|v| v.kind)
258                        .map(|kind| {
259                            if let Kind::StringValue(string) = kind {
260                                crate::codec::base64::decode(string)
261                                    .map_err(MetadataError::InvalidBase64)
262                            } else {
263                                Err(MetadataError::InvalidType {
264                                    key: "quilkin.dev.tokens",
265                                    expected: "base64 string",
266                                })
267                            }
268                        })
269                        .collect::<Result<_, _>>()?,
270                    Kind::StringValue(string) => <_>::from([crate::codec::base64::decode(string)
271                        .map_err(MetadataError::InvalidBase64)?]),
272                    _ => return Err(MetadataError::MissingKey(TOKENS)),
273                }
274            } else {
275                <_>::default()
276            };
277
278        Ok(Self { tokens })
279    }
280}
281
282#[derive(Debug, Clone, thiserror::Error)]
283pub enum MetadataError {
284    #[error("Invalid bas64 encoded token: `{0}`.")]
285    InvalidBase64(base64::DecodeError),
286    #[error("Missing required key `{0}`.")]
287    MissingKey(&'static str),
288    #[error("Invalid type ({expected}) given for `{key}`.")]
289    InvalidType {
290        key: &'static str,
291        expected: &'static str,
292    },
293}
294
295/// A module for providing base64 encoding for a `BTreeSet` at the `serde`
296/// boundary. Accepts a list of strings representing Base64 encoded data,
297/// this list is then converted into its binary representation while in memory,
298/// and then encoded back as a list of base64 strings.
299mod base64_set {
300    use serde::de::Error;
301
302    pub type Set = std::collections::BTreeSet<Vec<u8>>;
303
304    pub fn serialize<S>(set: &Set, ser: S) -> Result<S::Ok, S::Error>
305    where
306        S: serde::Serializer,
307    {
308        ser.collect_seq(set.iter().map(crate::codec::base64::encode))
309    }
310
311    pub fn deserialize<'de, D>(de: D) -> Result<Set, D::Error>
312    where
313        D: serde::Deserializer<'de>,
314    {
315        struct TokenVisitor;
316
317        impl<'de> serde::de::Visitor<'de> for TokenVisitor {
318            type Value = Set;
319
320            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
321                formatter.write_str("an array of base64 encoded tokens")
322            }
323
324            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
325            where
326                A: serde::de::SeqAccess<'de>,
327            {
328                let mut set = Set::new();
329
330                while let Some(token) = seq.next_element::<std::borrow::Cow<'_, str>>()? {
331                    let decoded =
332                        crate::codec::base64::decode(token.as_ref()).map_err(Error::custom)?;
333
334                    if !set.insert(decoded) {
335                        return Err(Error::custom(
336                            "Found duplicate tokens in endpoint metadata.",
337                        ));
338                    }
339                }
340
341                Ok(set)
342            }
343        }
344
345        de.deserialize_seq(TokenVisitor)
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn endpoint_metadata() {
        let metadata = Metadata {
            tokens: vec!["Man".into()].into_iter().collect(),
        };

        assert_eq!(
            serde_json::to_value(EndpointMetadata::from(metadata)).unwrap(),
            serde_json::json!({
                crate::net::endpoint::metadata::KEY: {
                    "tokens": ["TWFu"],
                }
            })
        );
    }

    #[test]
    fn parse_dns_endpoints() {
        let localhost = "address: localhost:80";
        serde_yaml::from_str::<Endpoint>(localhost).unwrap();
    }

    #[test]
    fn yaml_parse_invalid_endpoint_metadata() {
        let not_a_list = "
quilkin.dev:
    tokens: OGdqM3YyaQ==
";
        let not_a_string_value = "
quilkin.dev:
    tokens:
        - map:
          a: b
";
        let not_a_base64_string = "
quilkin.dev:
    tokens:
        - OGdqM3YyaQ== #8gj3v2i
        - iix
";
        for yaml in &[not_a_list, not_a_string_value, not_a_base64_string] {
            serde_yaml::from_str::<EndpointMetadata>(yaml).unwrap_err();
        }
    }

    // Sanity check conversion between endpoint <-> proto works
    #[test]
    fn endpoint_proto_conversion() {
        let first = Endpoint::new(EndpointAddress {
            host: AddressKind::Ip(std::net::IpAddr::V6(std::net::Ipv6Addr::new(
                0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0xab, 0xcd,
            ))),
            port: 2001,
        });

        let expected = first.clone();
        let proto = first.into_proto();
        let actual = Endpoint::from_proto(proto).unwrap();

        assert_eq!(expected, actual);
    }

    // Cover the remaining `host2` variants (IPv4 and DNS name), including
    // non-empty token metadata.
    #[test]
    fn endpoint_proto_conversion_ipv4_and_name() {
        let tokens: Set = vec![b"abc".to_vec(), b"xyz".to_vec()]
            .into_iter()
            .collect();

        let endpoints = vec![
            Endpoint::with_metadata(
                EndpointAddress {
                    host: AddressKind::Ip(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
                        192, 168, 0, 1,
                    ))),
                    port: 7777,
                },
                Metadata { tokens },
            ),
            Endpoint::new(EndpointAddress {
                host: AddressKind::Name("localhost".into()),
                port: 80,
            }),
        ];

        for endpoint in endpoints {
            let expected = endpoint.clone();
            let actual = Endpoint::from_proto(endpoint.into_proto()).unwrap();
            assert_eq!(expected, actual);
        }
    }

    // The `From`/`TryFrom` pair round-trips through the string `host` field.
    #[test]
    fn endpoint_string_host_proto_conversion() {
        let expected = Endpoint::new(EndpointAddress {
            host: AddressKind::Ip(std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))),
            port: 7000,
        });

        let proto = proto::Endpoint::from(expected.clone());
        let actual = Endpoint::try_from(proto).unwrap();

        assert_eq!(expected, actual);
    }

    #[test]
    fn proto_conversion_rejects_out_of_range_port() {
        let proto = proto::Endpoint {
            host: "127.0.0.1".into(),
            port: u32::from(u16::MAX) + 1,
            metadata: None,
            host2: None,
        };

        assert!(Endpoint::try_from(proto).is_err());
    }
}