1pub(crate) mod address;
20pub mod metadata;
21
22use crate::net::cluster::proto;
23use eyre::ContextCompat;
24use serde::{Deserialize, Serialize};
25
26pub use self::{
27 address::{AddressKind, EndpointAddress},
28 metadata::DynamicMetadata,
29};
30
31pub use quilkin_xds::locality::Locality;
32
33pub type EndpointMetadata = metadata::MetadataView<Metadata>;
34pub use base64_set::Set;
35
36#[derive(Debug, Deserialize, Serialize, PartialEq, Clone, Eq, schemars::JsonSchema)]
38#[non_exhaustive]
39#[serde(deny_unknown_fields)]
40pub struct Endpoint {
41 #[schemars(with = "String")]
42 pub address: EndpointAddress,
43 #[serde(default)]
44 pub metadata: EndpointMetadata,
45}
46
47impl Endpoint {
48 pub fn new(address: EndpointAddress) -> Self {
50 Self {
51 address,
52 ..<_>::default()
53 }
54 }
55
56 pub fn with_metadata(address: EndpointAddress, metadata: impl Into<EndpointMetadata>) -> Self {
58 Self {
59 address,
60 metadata: metadata.into(),
61 ..<_>::default()
62 }
63 }
64
65 #[inline]
66 pub fn from_proto(proto: proto::Endpoint) -> eyre::Result<Self> {
67 let host: AddressKind = if let Some(host) = proto.host2 {
68 match host.inner.context("should be unreachable")? {
69 proto::host::Inner::Name(name) => AddressKind::Name(name),
70 proto::host::Inner::Ipv4(v4) => {
71 AddressKind::Ip(std::net::Ipv4Addr::from(v4).into())
72 }
73 proto::host::Inner::Ipv6(v6) => AddressKind::Ip(
74 std::net::Ipv6Addr::from(((v6.first as u128) << 64) | v6.second as u128).into(),
75 ),
76 }
77 } else {
78 proto.host.parse()?
79 };
80
81 Ok(Self {
82 address: (host, proto.port as u16).into(),
83 metadata: proto
84 .metadata
85 .map(TryFrom::try_from)
86 .transpose()?
87 .unwrap_or_default(),
88 })
89 }
90
91 #[inline]
92 pub fn into_proto(self) -> proto::Endpoint {
93 let host = match self.address.host {
94 AddressKind::Name(name) => proto::host::Inner::Name(name),
95 AddressKind::Ip(ip) => match ip {
96 std::net::IpAddr::V4(v4) => {
97 proto::host::Inner::Ipv4(u32::from_be_bytes(v4.octets()))
98 }
99 std::net::IpAddr::V6(v6) => {
100 let ip = u128::from_be_bytes(v6.octets());
101
102 let first = ((ip >> 64) & 0xffffffffffffffff) as u64;
103 let second = (ip & 0xffffffffffffffff) as u64;
104
105 proto::host::Inner::Ipv6(proto::Ipv6 { first, second })
106 }
107 },
108 };
109
110 proto::Endpoint {
111 host: String::new(),
112 port: self.address.port.into(),
113 metadata: Some(self.metadata.into()),
114 host2: Some(proto::Host { inner: Some(host) }),
115 }
116 }
117}
118
119impl Default for Endpoint {
120 fn default() -> Self {
121 Self {
122 address: EndpointAddress::UNSPECIFIED,
123 metadata: <_>::default(),
124 }
125 }
126}
127
128impl std::str::FromStr for Endpoint {
129 type Err = <EndpointAddress as std::str::FromStr>::Err;
130
131 fn from_str(s: &str) -> Result<Self, Self::Err> {
132 Ok(Self {
133 address: s.parse()?,
134 ..Self::default()
135 })
136 }
137}
138
139impl From<Endpoint> for proto::Endpoint {
140 fn from(endpoint: Endpoint) -> Self {
141 Self {
142 host: endpoint.address.host.to_string(),
143 port: endpoint.address.port.into(),
144 metadata: Some(endpoint.metadata.into()),
145 host2: None,
146 }
147 }
148}
149
150impl TryFrom<proto::Endpoint> for Endpoint {
151 type Error = eyre::Error;
152
153 fn try_from(endpoint: proto::Endpoint) -> Result<Self, Self::Error> {
154 let host: address::AddressKind = endpoint.host.parse()?;
155 if endpoint.port > u16::MAX as u32 {
156 return Err(eyre::eyre!("invalid endpoint port"));
157 }
158
159 Ok(Self {
160 address: (host, endpoint.port as u16).into(),
161 metadata: endpoint
162 .metadata
163 .map(TryFrom::try_from)
164 .transpose()?
165 .unwrap_or_default(),
166 })
167 }
168}
169
170impl std::cmp::PartialEq<EndpointAddress> for Endpoint {
171 fn eq(&self, rhs: &EndpointAddress) -> bool {
172 self.address == *rhs
173 }
174}
175
176impl<T: Into<EndpointAddress>> From<T> for Endpoint {
177 fn from(value: T) -> Self {
178 Self::new(value.into())
179 }
180}
181
182impl Ord for Endpoint {
183 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
184 self.address.cmp(&other.address)
185 }
186}
187
188impl PartialOrd for Endpoint {
189 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
190 Some(self.cmp(other))
191 }
192}
193
194impl std::hash::Hash for Endpoint {
195 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
196 self.address.hash(state);
197 self.metadata.known.tokens.hash(state);
198 }
199}
200
201#[derive(
203 Default, Debug, Deserialize, Serialize, PartialEq, Clone, PartialOrd, Eq, schemars::JsonSchema,
204)]
205pub struct Metadata {
206 #[serde(
207 serialize_with = "base64_set::serialize",
208 deserialize_with = "base64_set::deserialize"
209 )]
210 pub tokens: base64_set::Set,
211}
212
213impl From<Metadata> for crate::net::endpoint::metadata::MetadataView<Metadata> {
214 fn from(metadata: Metadata) -> Self {
215 Self {
216 known: metadata,
217 ..<_>::default()
218 }
219 }
220}
221
222impl From<Metadata> for prost_types::Struct {
223 fn from(metadata: Metadata) -> Self {
224 let tokens = prost_types::Value {
225 kind: Some(prost_types::value::Kind::ListValue(
226 prost_types::ListValue {
227 values: metadata
228 .tokens
229 .into_iter()
230 .map(crate::codec::base64::encode)
231 .map(prost_types::value::Kind::StringValue)
232 .map(|k| prost_types::Value { kind: Some(k) })
233 .collect(),
234 },
235 )),
236 };
237
238 Self {
239 fields: <_>::from([("tokens".into(), tokens)]),
240 }
241 }
242}
243
244impl std::convert::TryFrom<prost_types::Struct> for Metadata {
245 type Error = MetadataError;
246
247 fn try_from(mut value: prost_types::Struct) -> Result<Self, Self::Error> {
248 use prost_types::value::Kind;
249 const TOKENS: &str = "tokens";
250
251 let tokens =
252 if let Some(kind) = value.fields.remove(TOKENS).and_then(|v| v.kind) {
253 match kind {
254 Kind::ListValue(list) => list
255 .values
256 .into_iter()
257 .filter_map(|v| v.kind)
258 .map(|kind| {
259 if let Kind::StringValue(string) = kind {
260 crate::codec::base64::decode(string)
261 .map_err(MetadataError::InvalidBase64)
262 } else {
263 Err(MetadataError::InvalidType {
264 key: "quilkin.dev.tokens",
265 expected: "base64 string",
266 })
267 }
268 })
269 .collect::<Result<_, _>>()?,
270 Kind::StringValue(string) => <_>::from([crate::codec::base64::decode(string)
271 .map_err(MetadataError::InvalidBase64)?]),
272 _ => return Err(MetadataError::MissingKey(TOKENS)),
273 }
274 } else {
275 <_>::default()
276 };
277
278 Ok(Self { tokens })
279 }
280}
281
282#[derive(Debug, Clone, thiserror::Error)]
283pub enum MetadataError {
284 #[error("Invalid bas64 encoded token: `{0}`.")]
285 InvalidBase64(base64::DecodeError),
286 #[error("Missing required key `{0}`.")]
287 MissingKey(&'static str),
288 #[error("Invalid type ({expected}) given for `{key}`.")]
289 InvalidType {
290 key: &'static str,
291 expected: &'static str,
292 },
293}
294
295mod base64_set {
300 use serde::de::Error;
301
302 pub type Set = std::collections::BTreeSet<Vec<u8>>;
303
304 pub fn serialize<S>(set: &Set, ser: S) -> Result<S::Ok, S::Error>
305 where
306 S: serde::Serializer,
307 {
308 ser.collect_seq(set.iter().map(crate::codec::base64::encode))
309 }
310
311 pub fn deserialize<'de, D>(de: D) -> Result<Set, D::Error>
312 where
313 D: serde::Deserializer<'de>,
314 {
315 struct TokenVisitor;
316
317 impl<'de> serde::de::Visitor<'de> for TokenVisitor {
318 type Value = Set;
319
320 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
321 formatter.write_str("an array of base64 encoded tokens")
322 }
323
324 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
325 where
326 A: serde::de::SeqAccess<'de>,
327 {
328 let mut set = Set::new();
329
330 while let Some(token) = seq.next_element::<std::borrow::Cow<'_, str>>()? {
331 let decoded =
332 crate::codec::base64::decode(token.as_ref()).map_err(Error::custom)?;
333
334 if !set.insert(decoded) {
335 return Err(Error::custom(
336 "Found duplicate tokens in endpoint metadata.",
337 ));
338 }
339 }
340
341 Ok(set)
342 }
343 }
344
345 de.deserialize_seq(TokenVisitor)
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn endpoint_metadata() {
        let metadata = Metadata {
            tokens: vec!["Man".into()].into_iter().collect(),
        };

        assert_eq!(
            serde_json::to_value(EndpointMetadata::from(metadata)).unwrap(),
            serde_json::json!({
                crate::net::endpoint::metadata::KEY: {
                    "tokens": ["TWFu"],
                }
            })
        );
    }

    #[test]
    fn parse_dns_endpoints() {
        let localhost = "address: localhost:80";
        serde_yaml::from_str::<Endpoint>(localhost).unwrap();
    }

    #[test]
    fn yaml_parse_invalid_endpoint_metadata() {
        // Fixtures spell out newlines and indentation explicitly, so the YAML
        // nesting (`tokens` under `quilkin.dev`) does not depend on how this
        // source file happens to be indented.
        let not_a_list = "quilkin.dev:\n  tokens: OGdqM3YyaQ==\n";
        let not_a_string_value = "quilkin.dev:\n  tokens:\n    - map:\n        a: b\n";
        let not_a_base64_string =
            "quilkin.dev:\n  tokens:\n    - OGdqM3YyaQ== #8gj3v2i\n    - iix\n";

        for yaml in [not_a_list, not_a_string_value, not_a_base64_string] {
            serde_yaml::from_str::<EndpointMetadata>(yaml).unwrap_err();
        }
    }

    /// Round-trips `endpoint` through `into_proto`/`from_proto` and asserts it is unchanged.
    fn assert_proto_round_trip(endpoint: Endpoint) {
        let expected = endpoint.clone();
        let actual = Endpoint::from_proto(endpoint.into_proto()).unwrap();
        assert_eq!(expected, actual);
    }

    #[test]
    fn endpoint_proto_conversion() {
        assert_proto_round_trip(Endpoint::new(EndpointAddress {
            host: AddressKind::Ip(std::net::IpAddr::V6(std::net::Ipv6Addr::new(
                0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0xab, 0xcd,
            ))),
            port: 2001,
        }));

        assert_proto_round_trip(Endpoint::new(EndpointAddress {
            host: AddressKind::Ip(std::net::IpAddr::V4(std::net::Ipv4Addr::new(
                192, 168, 0, 1,
            ))),
            port: 7777,
        }));

        assert_proto_round_trip(Endpoint::new(EndpointAddress {
            host: AddressKind::Name("example.com".into()),
            port: u16::MAX,
        }));
    }
}
414}