mtop_client/
discovery.rs

1use crate::core::MtopError;
2use crate::dns::{DnsClient, Message, MessageId, Name, RecordClass, RecordData, RecordType};
3use rustls_pki_types::ServerName;
4use std::cmp::Ordering;
5use std::collections::HashSet;
6use std::fmt;
7use std::net::{IpAddr, SocketAddr};
8use std::path::PathBuf;
9
10const DNS_A_PREFIX: &str = "dns+";
11const DNS_SRV_PREFIX: &str = "dnssrv+";
12const UNIX_SOCKET_PREFIX: &str = "/";
13
/// Identity and address of a single Memcached server. Used as a key when indexing
/// per-server responses or errors, and when establishing connections.
///
/// The derived ordering compares the variant first (`Name` < `Socket` < `Path`) and then
/// the wrapped value.
#[derive(Debug, Clone, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum ServerID {
    /// A `host:port` string. `from_host_port` produces this when the host is not an IP literal.
    Name(String),
    /// An IP address and port.
    Socket(SocketAddr),
    /// Filesystem path of a UNIX socket.
    Path(PathBuf),
}

impl ServerID {
    /// Build an ID from a host and a port.
    ///
    /// A host that parses as an IPv4/IPv6 address yields `Socket`. Anything else yields
    /// `Name` holding the literal text `host:port`.
    fn from_host_port<S>(host: S, port: u16) -> Self
    where
        S: AsRef<str>,
    {
        let host: &str = host.as_ref();
        match host.parse::<IpAddr>() {
            Ok(ip) => Self::Socket(SocketAddr::new(ip, port)),
            Err(_) => Self::Name(format!("{}:{}", host, port)),
        }
    }
}

impl From<SocketAddr> for ServerID {
    fn from(addr: SocketAddr) -> Self {
        ServerID::Socket(addr)
    }
}

impl From<(&str, u16)> for ServerID {
    fn from((host, port): (&str, u16)) -> Self {
        ServerID::from_host_port(host, port)
    }
}

impl From<(String, u16)> for ServerID {
    fn from((host, port): (String, u16)) -> Self {
        ServerID::from_host_port(host, port)
    }
}

impl fmt::Display for ServerID {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Name(name) => fmt::Display::fmt(name, f),
            Self::Socket(addr) => fmt::Display::fmt(addr, f),
            // Paths are rendered with `Debug`, i.e. quoted and escaped.
            Self::Path(path) => fmt::Debug::fmt(path, f),
        }
    }
}
64
65/// An individual server that is part of a Memcached cluster.
66#[derive(Debug, Clone, Eq, PartialEq, Hash)]
67pub struct Server {
68    id: ServerID,
69    name: Option<ServerName<'static>>,
70}
71
72impl Server {
73    pub fn new(id: ServerID, name: ServerName<'static>) -> Self {
74        Self { id, name: Some(name) }
75    }
76
77    pub fn without_name(id: ServerID) -> Self {
78        Self { id, name: None }
79    }
80
81    pub fn id(&self) -> &ServerID {
82        &self.id
83    }
84
85    pub fn server_name(&self) -> &Option<ServerName<'static>> {
86        &self.name
87    }
88}
89
90impl PartialOrd for Server {
91    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
92        Some(self.cmp(other))
93    }
94}
95
96impl Ord for Server {
97    fn cmp(&self, other: &Self) -> Ordering {
98        self.id.cmp(&other.id)
99    }
100}
101
102impl fmt::Display for Server {
103    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104        self.id.fmt(f)
105    }
106}
107
108/// Service discovery implementation for finding Memcached servers using DNS.
109///
110/// Different types of DNS records and different behaviors are used based on the
111/// presence of specific prefixes for hostnames. See `resolve_by_proto` for details.
112pub struct Discovery {
113    client: Box<dyn DnsClient + Send + Sync>,
114}
115
116impl Discovery {
117    pub fn new<C>(client: C) -> Self
118    where
119        C: DnsClient + Send + Sync + 'static,
120    {
121        Self {
122            client: Box::new(client),
123        }
124    }
125
126    /// Resolve a hostname to one or multiple Memcached servers based on DNS records
127    /// and/or the presence of certain prefixes on the hostnames.
128    ///
129    /// * `dns+` will resolve a hostname into multiple A and AAAA records and use the
130    ///   IP addresses from the records as Memcached servers.
131    /// * `dnssrv+` will resolve a hostname into multiple SRV records and use the
132    ///   unresolved targets from the SRV records as Memcached servers. Resolution of
133    ///   the targets to IP addresses will happen at connection time using the system
134    ///   resolver.
135    /// * `/` will resolve a hostname into a UNIX socket path and will use this path
136    ///   as a local Memcached server on a UNIX socket.
137    /// * No prefix with an IPv4 or IPv6 address will use the IP address as a Memcached
138    ///   server.
139    /// * No prefix with a non-IP address will use the host as a Memcached server.
140    ///   Resolution of the host to an IP address will happen at connection time using the
141    ///   system resolver.
142    pub async fn resolve_by_proto(&self, name: &str) -> Result<Vec<Server>, MtopError> {
143        if name.starts_with(DNS_A_PREFIX) {
144            Ok(self.resolve_a_aaaa(name.trim_start_matches(DNS_A_PREFIX)).await?)
145        } else if name.starts_with(DNS_SRV_PREFIX) {
146            Ok(self.resolve_srv(name.trim_start_matches(DNS_SRV_PREFIX)).await?)
147        } else if name.starts_with(UNIX_SOCKET_PREFIX) {
148            Ok(Self::resolve_unix_addr(name))
149        } else if let Ok(addr) = name.parse::<SocketAddr>() {
150            Ok(Self::resolve_socket_addr(name, addr)?)
151        } else {
152            Ok(Self::resolve_bare_host(name)?)
153        }
154    }
155
156    async fn resolve_srv(&self, name: &str) -> Result<Vec<Server>, MtopError> {
157        let (host, port) = Self::host_and_port(name)?;
158        let server_name = Self::server_name(host)?;
159        let name = host.parse()?;
160        let id = MessageId::random();
161
162        let res = self.client.resolve(id, name, RecordType::SRV, RecordClass::INET).await?;
163        Ok(Self::servers_from_answers(port, &server_name, &res))
164    }
165
166    async fn resolve_a_aaaa(&self, name: &str) -> Result<Vec<Server>, MtopError> {
167        let (host, port) = Self::host_and_port(name)?;
168        let server_name = Self::server_name(host)?;
169        let name: Name = host.parse()?;
170        let id = MessageId::random();
171
172        let res = self.client.resolve(id, name.clone(), RecordType::A, RecordClass::INET).await?;
173        let mut out = Self::servers_from_answers(port, &server_name, &res);
174
175        let res = self.client.resolve(id, name, RecordType::AAAA, RecordClass::INET).await?;
176        out.extend(Self::servers_from_answers(port, &server_name, &res));
177
178        Ok(out)
179    }
180
181    fn resolve_unix_addr(name: &str) -> Vec<Server> {
182        let path = PathBuf::from(name);
183        vec![Server::without_name(ServerID::Path(path))]
184    }
185
186    fn resolve_socket_addr(name: &str, addr: SocketAddr) -> Result<Vec<Server>, MtopError> {
187        let (host, _port) = Self::host_and_port(name)?;
188        let server_name = Self::server_name(host)?;
189        Ok(vec![Server::new(ServerID::from(addr), server_name)])
190    }
191
192    fn resolve_bare_host(name: &str) -> Result<Vec<Server>, MtopError> {
193        let (host, port) = Self::host_and_port(name)?;
194        let server_name = Self::server_name(host)?;
195        Ok(vec![Server::new(ServerID::from((host, port)), server_name)])
196    }
197
198    fn servers_from_answers(port: u16, server_name: &ServerName<'static>, message: &Message) -> Vec<Server> {
199        let mut servers = HashSet::new();
200
201        for answer in message.answers() {
202            let id = match answer.rdata() {
203                RecordData::A(data) => {
204                    let addr = SocketAddr::new(IpAddr::V4(data.addr()), port);
205                    ServerID::from(addr)
206                }
207                RecordData::AAAA(data) => {
208                    let addr = SocketAddr::new(IpAddr::V6(data.addr()), port);
209                    ServerID::from(addr)
210                }
211                RecordData::SRV(data) => {
212                    let target = data.target().to_string();
213
214                    ServerID::from((&target as &str, port))
215                }
216                _ => {
217                    tracing::warn!(message = "unexpected record data for answer", answer = ?answer);
218                    continue;
219                }
220            };
221
222            // Insert server into a HashSet to deduplicate them. We can potentially end up with
223            // duplicates when a SRV query returns multiple answers per hostname (such as when
224            // each host has more than a single port). Because we ignore the port number from the
225            // SRV answer we need to deduplicate here.
226            servers.insert(Server::new(id, server_name.to_owned()));
227        }
228
229        servers.into_iter().collect()
230    }
231
232    fn host_and_port(name: &str) -> Result<(&str, u16), MtopError> {
233        name.rsplit_once(':')
234            .ok_or_else(|| {
235                MtopError::configuration(format!(
236                    "invalid server name '{}', must be of the form 'host:port'",
237                    name
238                ))
239            })
240            // IPv6 addresses use brackets around them to disambiguate them from a port number.
241            // Since we're parsing the host and port, strip the brackets because they aren't
242            // needed or valid to include in a TLS ServerName.
243            .map(|(host, port)| (host.trim_start_matches('[').trim_end_matches(']'), port))
244            .and_then(|(host, port)| {
245                port.parse().map(|p| (host, p)).map_err(|e| {
246                    MtopError::configuration_cause(format!("unable to parse port number from '{}'", name), e)
247                })
248            })
249    }
250
251    fn server_name(host: &str) -> Result<ServerName<'static>, MtopError> {
252        ServerName::try_from(host)
253            .map(|s| s.to_owned())
254            .map_err(|e| MtopError::configuration_cause(format!("invalid server name '{}'", host), e))
255    }
256}
257
258impl fmt::Debug for Discovery {
259    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260        f.debug_struct("Discovery").field("client", &"...").finish()
261    }
262}
263
#[cfg(test)]
mod test {
    use super::{Discovery, ServerID};
    use crate::core::MtopError;
    use crate::dns::{
        DnsClient, Flags, Message, MessageId, Name, Question, Record, RecordClass, RecordData, RecordDataA,
        RecordDataAAAA, RecordDataSRV, RecordType,
    };
    use async_trait::async_trait;
    use std::collections::VecDeque;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
    use std::path::PathBuf;
    use std::str::FromStr;
    use tokio::sync::Mutex;

    #[test]
    fn test_server_id_from_ipv4_addr() {
        let addr = SocketAddr::from((Ipv4Addr::new(127, 1, 1, 1), 11211));
        let id = ServerID::from(addr);
        assert_eq!("127.1.1.1:11211", id.to_string());
    }

    #[test]
    fn test_server_id_from_ipv6_addr() {
        let addr = SocketAddr::from((Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 11211));
        let id = ServerID::from(addr);
        assert_eq!("[::1]:11211", id.to_string());
    }

    #[test]
    fn test_server_id_from_ipv4_pair() {
        let pair = ("10.1.1.22", 11212);
        let id = ServerID::from(pair);
        assert_eq!("10.1.1.22:11212", id.to_string());
    }

    #[test]
    fn test_server_id_from_ipv6_pair() {
        let pair = ("::1", 11212);
        let id = ServerID::from(pair);
        assert_eq!("[::1]:11212", id.to_string());
    }

    #[test]
    fn test_server_id_from_host_pair() {
        let pair = ("cache.example.com", 11211);
        let id = ServerID::from(pair);
        assert_eq!("cache.example.com:11211", id.to_string());
    }

    /// `DnsClient` that ignores the query and replays canned responses in the order
    /// they were passed to `new`. Panics if it receives more queries than it has responses.
    struct MockDnsClient {
        responses: Mutex<VecDeque<Message>>,
    }

    impl MockDnsClient {
        fn new(responses: Vec<Message>) -> Self {
            Self {
                responses: Mutex::new(VecDeque::from(responses)),
            }
        }
    }

    #[async_trait]
    impl DnsClient for MockDnsClient {
        async fn resolve(
            &self,
            _id: MessageId,
            _name: Name,
            _rtype: RecordType,
            _rclass: RecordClass,
        ) -> Result<Message, MtopError> {
            // FIFO: the first response answers the first query. In the `dns+` test this
            // means the A query gets the A response and the AAAA query gets the AAAA response.
            let mut responses = self.responses.lock().await;
            let res = responses
                .pop_front()
                .expect("MockDnsClient received more queries than it has responses");
            Ok(res)
        }
    }

    /// Build a response message for `example.com.` containing `records` as answers.
    fn response_with_answers(rtype: RecordType, records: Vec<Record>) -> Message {
        let flags = Flags::default().set_recursion_desired().set_recursion_available();
        let mut message = Message::new(MessageId::random(), flags)
            .add_question(Question::new(Name::from_str("example.com.").unwrap(), rtype));

        for r in records {
            message = message.add_answer(r);
        }

        message
    }

    #[tokio::test]
    async fn test_dns_client_resolve_a_aaaa() {
        let response_a = response_with_answers(
            RecordType::A,
            vec![Record::new(
                Name::from_str("example.com.").unwrap(),
                RecordType::A,
                RecordClass::INET,
                300,
                RecordData::A(RecordDataA::new(Ipv4Addr::new(10, 1, 1, 1))),
            )],
        );

        let response_aaaa = response_with_answers(
            RecordType::AAAA,
            vec![Record::new(
                Name::from_str("example.com.").unwrap(),
                RecordType::AAAA,
                RecordClass::INET,
                300,
                RecordData::AAAA(RecordDataAAAA::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))),
            )],
        );

        let client = MockDnsClient::new(vec![response_a, response_aaaa]);
        let discovery = Discovery::new(client);
        let servers = discovery.resolve_by_proto("dns+example.com:11211").await.unwrap();

        let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();
        let id_a = ServerID::from("10.1.1.1:11211".parse::<SocketAddr>().unwrap());
        let id_aaaa = ServerID::from("[::1]:11211".parse::<SocketAddr>().unwrap());

        assert!(ids.contains(&id_a), "expected {:?} to contain {:?}", ids, id_a);
        assert!(ids.contains(&id_aaaa), "expected {:?} to contain {:?}", ids, id_aaaa);
    }

    #[tokio::test]
    async fn test_dns_client_resolve_srv() {
        let response = response_with_answers(
            RecordType::SRV,
            vec![
                Record::new(
                    Name::from_str("_cache.example.com.").unwrap(),
                    RecordType::SRV,
                    RecordClass::INET,
                    300,
                    RecordData::SRV(RecordDataSRV::new(
                        100,
                        10,
                        11211,
                        Name::from_str("cache01.example.com.").unwrap(),
                    )),
                ),
                Record::new(
                    Name::from_str("_cache.example.com.").unwrap(),
                    RecordType::SRV,
                    RecordClass::INET,
                    300,
                    RecordData::SRV(RecordDataSRV::new(
                        100,
                        10,
                        11211,
                        Name::from_str("cache02.example.com.").unwrap(),
                    )),
                ),
            ],
        );

        let client = MockDnsClient::new(vec![response]);
        let discovery = Discovery::new(client);
        let servers = discovery.resolve_by_proto("dnssrv+_cache.example.com:11211").await.unwrap();

        let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();
        let id1 = ServerID::from(("cache01.example.com.", 11211));
        let id2 = ServerID::from(("cache02.example.com.", 11211));

        assert!(ids.contains(&id1), "expected {:?} to contain {:?}", ids, id1);
        assert!(ids.contains(&id2), "expected {:?} to contain {:?}", ids, id2);
    }

    #[tokio::test]
    async fn test_dns_client_resolve_srv_dupes() {
        let response = response_with_answers(
            RecordType::SRV,
            vec![
                Record::new(
                    Name::from_str("_cache.example.com.").unwrap(),
                    RecordType::SRV,
                    RecordClass::INET,
                    300,
                    RecordData::SRV(RecordDataSRV::new(
                        100,
                        10,
                        11211,
                        Name::from_str("cache01.example.com.").unwrap(),
                    )),
                ),
                Record::new(
                    Name::from_str("_cache.example.com.").unwrap(),
                    RecordType::SRV,
                    RecordClass::INET,
                    300,
                    RecordData::SRV(RecordDataSRV::new(
                        100,
                        10,
                        9105,
                        Name::from_str("cache01.example.com.").unwrap(),
                    )),
                ),
            ],
        );

        let client = MockDnsClient::new(vec![response]);
        let discovery = Discovery::new(client);
        let servers = discovery.resolve_by_proto("dnssrv+_cache.example.com:11211").await.unwrap();

        let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();
        let id = ServerID::from(("cache01.example.com.", 11211));

        assert_eq!(ids, vec![id]);
    }

    #[tokio::test]
    async fn test_dns_client_resolve_socket_addr() {
        let name = "127.0.0.2:11211";
        let sock: SocketAddr = "127.0.0.2:11211".parse().unwrap();

        let client = MockDnsClient::new(vec![]);
        let discovery = Discovery::new(client);
        let servers = discovery.resolve_by_proto(name).await.unwrap();

        let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();
        let id = ServerID::from(sock);

        assert!(ids.contains(&id), "expected {:?} to contain {:?}", ids, id);
    }

    #[tokio::test]
    async fn test_dns_client_resolve_socket_addr_ipv6() {
        // Brackets must be stripped from the host for the TLS server name to be valid.
        let name = "[::1]:11211";
        let sock: SocketAddr = name.parse().unwrap();

        let client = MockDnsClient::new(vec![]);
        let discovery = Discovery::new(client);
        let servers = discovery.resolve_by_proto(name).await.unwrap();

        let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();
        assert_eq!(ids, vec![ServerID::from(sock)]);
        assert!(servers[0].server_name().is_some());
    }

    #[tokio::test]
    async fn test_dns_client_resolve_bare_host() {
        let name = "localhost:11211";

        let client = MockDnsClient::new(vec![]);
        let discovery = Discovery::new(client);
        let servers = discovery.resolve_by_proto(name).await.unwrap();

        let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();
        let id = ServerID::from(("localhost", 11211));

        assert!(ids.contains(&id), "expected {:?} to contain {:?}", ids, id);
    }

    #[tokio::test]
    async fn test_dns_client_resolve_unix_socket() {
        let name = "/var/run/memcached.sock";

        let client = MockDnsClient::new(vec![]);
        let discovery = Discovery::new(client);
        let servers = discovery.resolve_by_proto(name).await.unwrap();

        let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();
        assert_eq!(ids, vec![ServerID::Path(PathBuf::from(name))]);
        assert!(servers[0].server_name().is_none());
    }

    #[tokio::test]
    async fn test_dns_client_resolve_missing_port() {
        let client = MockDnsClient::new(vec![]);
        let discovery = Discovery::new(client);
        let res = discovery.resolve_by_proto("localhost").await;

        assert!(res.is_err());
    }
}