1use crate::core::MtopError;
2use crate::dns::{DnsClient, Message, MessageId, Name, RecordClass, RecordData, RecordType};
3use rustls_pki_types::ServerName;
4use std::cmp::Ordering;
5use std::collections::HashSet;
6use std::fmt;
7use std::net::{IpAddr, SocketAddr};
8use std::path::PathBuf;
9
/// Scheme prefix selecting `A`/`AAAA` DNS resolution of the `host:port` that follows.
const DNS_A_PREFIX: &str = "dns+";
/// Scheme prefix selecting `SRV` DNS resolution of the `host:port` that follows.
const DNS_SRV_PREFIX: &str = "dnssrv+";
/// Server names starting with this are treated as filesystem (UNIX socket) paths.
const UNIX_SOCKET_PREFIX: &str = "/";
13
/// Identifies a single server to connect to.
///
/// The derived `Ord` compares the variant first (in declaration order: `Name`,
/// `Socket`, `Path`) and then the contained value, so reordering the variants
/// changes how servers sort.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub enum ServerID {
    /// A `host:port` string. `from_host_port` produces this when the host is not an
    /// IP literal; the host is not resolved to an address.
    Name(String),
    /// A concrete IP address and port.
    Socket(SocketAddr),
    /// A filesystem path (used for UNIX socket servers by `Discovery`).
    Path(PathBuf),
}
22
23impl ServerID {
24 fn from_host_port<S>(host: S, port: u16) -> Self
25 where
26 S: AsRef<str>,
27 {
28 let host = host.as_ref();
29 if let Ok(ip) = host.parse::<IpAddr>() {
30 Self::Socket(SocketAddr::new(ip, port))
31 } else {
32 Self::Name(format!("{}:{}", host, port))
33 }
34 }
35}
36
37impl From<SocketAddr> for ServerID {
38 fn from(value: SocketAddr) -> Self {
39 Self::Socket(value)
40 }
41}
42
43impl From<(&str, u16)> for ServerID {
44 fn from(value: (&str, u16)) -> Self {
45 Self::from_host_port(value.0, value.1)
46 }
47}
48
49impl From<(String, u16)> for ServerID {
50 fn from(value: (String, u16)) -> Self {
51 Self::from_host_port(value.0, value.1)
52 }
53}
54
55impl fmt::Display for ServerID {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 match self {
58 ServerID::Name(n) => n.fmt(f),
59 ServerID::Socket(s) => s.fmt(f),
60 ServerID::Path(p) => fmt::Debug::fmt(p, f),
61 }
62 }
63}
64
/// A server to connect to: its [`ServerID`] plus an optional `rustls` [`ServerName`].
///
/// Equality and hashing (derived) consider both fields; ordering (the manual `Ord`
/// impl) considers only `id`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Server {
    /// Where to connect.
    id: ServerID,
    /// Server name for this server; `None` when built with `Server::without_name`.
    name: Option<ServerName<'static>>,
}
71
72impl Server {
73 pub fn new(id: ServerID, name: ServerName<'static>) -> Self {
74 Self { id, name: Some(name) }
75 }
76
77 pub fn without_name(id: ServerID) -> Self {
78 Self { id, name: None }
79 }
80
81 pub fn id(&self) -> &ServerID {
82 &self.id
83 }
84
85 pub fn server_name(&self) -> &Option<ServerName<'static>> {
86 &self.name
87 }
88}
89
90impl PartialOrd for Server {
91 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
92 Some(self.cmp(other))
93 }
94}
95
96impl Ord for Server {
97 fn cmp(&self, other: &Self) -> Ordering {
98 self.id.cmp(&other.id)
99 }
100}
101
102impl fmt::Display for Server {
103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 self.id.fmt(f)
105 }
106}
107
/// Turns server specifications (`dns+host:port`, `dnssrv+host:port`, socket
/// addresses, paths, or plain `host:port`) into [`Server`]s, issuing DNS queries
/// through `client` for the `dns+` and `dnssrv+` forms.
pub struct Discovery {
    // Type-erased so `Discovery` itself has no type parameter for the client.
    client: Box<dyn DnsClient + Send + Sync>,
}
115
116impl Discovery {
117 pub fn new<C>(client: C) -> Self
118 where
119 C: DnsClient + Send + Sync + 'static,
120 {
121 Self {
122 client: Box::new(client),
123 }
124 }
125
126 pub async fn resolve_by_proto(&self, name: &str) -> Result<Vec<Server>, MtopError> {
143 if name.starts_with(DNS_A_PREFIX) {
144 Ok(self.resolve_a_aaaa(name.trim_start_matches(DNS_A_PREFIX)).await?)
145 } else if name.starts_with(DNS_SRV_PREFIX) {
146 Ok(self.resolve_srv(name.trim_start_matches(DNS_SRV_PREFIX)).await?)
147 } else if name.starts_with(UNIX_SOCKET_PREFIX) {
148 Ok(Self::resolve_unix_addr(name))
149 } else if let Ok(addr) = name.parse::<SocketAddr>() {
150 Ok(Self::resolve_socket_addr(name, addr)?)
151 } else {
152 Ok(Self::resolve_bare_host(name)?)
153 }
154 }
155
156 async fn resolve_srv(&self, name: &str) -> Result<Vec<Server>, MtopError> {
157 let (host, port) = Self::host_and_port(name)?;
158 let server_name = Self::server_name(host)?;
159 let name = host.parse()?;
160 let id = MessageId::random();
161
162 let res = self.client.resolve(id, name, RecordType::SRV, RecordClass::INET).await?;
163 Ok(Self::servers_from_answers(port, &server_name, &res))
164 }
165
166 async fn resolve_a_aaaa(&self, name: &str) -> Result<Vec<Server>, MtopError> {
167 let (host, port) = Self::host_and_port(name)?;
168 let server_name = Self::server_name(host)?;
169 let name: Name = host.parse()?;
170 let id = MessageId::random();
171
172 let res = self.client.resolve(id, name.clone(), RecordType::A, RecordClass::INET).await?;
173 let mut out = Self::servers_from_answers(port, &server_name, &res);
174
175 let res = self.client.resolve(id, name, RecordType::AAAA, RecordClass::INET).await?;
176 out.extend(Self::servers_from_answers(port, &server_name, &res));
177
178 Ok(out)
179 }
180
181 fn resolve_unix_addr(name: &str) -> Vec<Server> {
182 let path = PathBuf::from(name);
183 vec![Server::without_name(ServerID::Path(path))]
184 }
185
186 fn resolve_socket_addr(name: &str, addr: SocketAddr) -> Result<Vec<Server>, MtopError> {
187 let (host, _port) = Self::host_and_port(name)?;
188 let server_name = Self::server_name(host)?;
189 Ok(vec![Server::new(ServerID::from(addr), server_name)])
190 }
191
192 fn resolve_bare_host(name: &str) -> Result<Vec<Server>, MtopError> {
193 let (host, port) = Self::host_and_port(name)?;
194 let server_name = Self::server_name(host)?;
195 Ok(vec![Server::new(ServerID::from((host, port)), server_name)])
196 }
197
198 fn servers_from_answers(port: u16, server_name: &ServerName<'static>, message: &Message) -> Vec<Server> {
199 let mut servers = HashSet::new();
200
201 for answer in message.answers() {
202 let id = match answer.rdata() {
203 RecordData::A(data) => {
204 let addr = SocketAddr::new(IpAddr::V4(data.addr()), port);
205 ServerID::from(addr)
206 }
207 RecordData::AAAA(data) => {
208 let addr = SocketAddr::new(IpAddr::V6(data.addr()), port);
209 ServerID::from(addr)
210 }
211 RecordData::SRV(data) => {
212 let target = data.target().to_string();
213
214 ServerID::from((&target as &str, port))
215 }
216 _ => {
217 tracing::warn!(message = "unexpected record data for answer", answer = ?answer);
218 continue;
219 }
220 };
221
222 servers.insert(Server::new(id, server_name.to_owned()));
227 }
228
229 servers.into_iter().collect()
230 }
231
232 fn host_and_port(name: &str) -> Result<(&str, u16), MtopError> {
233 name.rsplit_once(':')
234 .ok_or_else(|| {
235 MtopError::configuration(format!(
236 "invalid server name '{}', must be of the form 'host:port'",
237 name
238 ))
239 })
240 .map(|(host, port)| (host.trim_start_matches('[').trim_end_matches(']'), port))
244 .and_then(|(host, port)| {
245 port.parse().map(|p| (host, p)).map_err(|e| {
246 MtopError::configuration_cause(format!("unable to parse port number from '{}'", name), e)
247 })
248 })
249 }
250
251 fn server_name(host: &str) -> Result<ServerName<'static>, MtopError> {
252 ServerName::try_from(host)
253 .map(|s| s.to_owned())
254 .map_err(|e| MtopError::configuration_cause(format!("invalid server name '{}'", host), e))
255 }
256}
257
258impl fmt::Debug for Discovery {
259 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260 f.debug_struct("Discovery").field("client", &"...").finish()
261 }
262}
263
#[cfg(test)]
mod test {
    use super::{Discovery, ServerID};
    use crate::core::MtopError;
    use crate::dns::{
        DnsClient, Flags, Message, MessageId, Name, Question, Record, RecordClass, RecordData, RecordDataA,
        RecordDataAAAA, RecordDataSRV, RecordType,
    };
    use async_trait::async_trait;
    use std::collections::VecDeque;
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
    use std::path::PathBuf;
    use std::str::FromStr;
    use tokio::sync::Mutex;

    #[test]
    fn test_server_id_from_ipv4_addr() {
        let addr = SocketAddr::from((Ipv4Addr::new(127, 1, 1, 1), 11211));
        let id = ServerID::from(addr);
        assert_eq!("127.1.1.1:11211", id.to_string());
    }

    #[test]
    fn test_server_id_from_ipv6_addr() {
        let addr = SocketAddr::from((Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 11211));
        let id = ServerID::from(addr);
        assert_eq!("[::1]:11211", id.to_string());
    }

    #[test]
    fn test_server_id_from_ipv4_pair() {
        let pair = ("10.1.1.22", 11212);
        let id = ServerID::from(pair);
        assert_eq!("10.1.1.22:11212", id.to_string());
    }

    #[test]
    fn test_server_id_from_ipv6_pair() {
        let pair = ("::1", 11212);
        let id = ServerID::from(pair);
        assert_eq!("[::1]:11212", id.to_string());
    }

    #[test]
    fn test_server_id_from_host_pair() {
        let pair = ("cache.example.com", 11211);
        let id = ServerID::from(pair);
        assert_eq!("cache.example.com:11211", id.to_string());
    }

    /// DNS client that answers queries with canned responses in the order they were
    /// passed to [`MockDnsClient::new`]: the first query gets the first response.
    struct MockDnsClient {
        responses: Mutex<VecDeque<Message>>,
    }

    impl MockDnsClient {
        fn new(responses: Vec<Message>) -> Self {
            Self {
                responses: Mutex::new(VecDeque::from(responses)),
            }
        }
    }

    #[async_trait]
    impl DnsClient for MockDnsClient {
        async fn resolve(
            &self,
            _id: MessageId,
            _name: Name,
            _rtype: RecordType,
            _rclass: RecordClass,
        ) -> Result<Message, MtopError> {
            let mut responses = self.responses.lock().await;
            // FIFO so each query receives the response listed for it; running out means
            // the code under test made more queries than the test prepared for.
            let res = responses
                .pop_front()
                .expect("MockDnsClient: more queries were made than responses were provided");
            Ok(res)
        }
    }

    fn response_with_answers(rtype: RecordType, records: Vec<Record>) -> Message {
        let flags = Flags::default().set_recursion_desired().set_recursion_available();
        let mut message = Message::new(MessageId::random(), flags)
            .add_question(Question::new(Name::from_str("example.com.").unwrap(), rtype));

        for r in records {
            message = message.add_answer(r);
        }

        message
    }

    /// Builds an SRV answer for `_cache.example.com.` pointing at `target` on `port`.
    fn srv_record(port: u16, target: &str) -> Record {
        Record::new(
            Name::from_str("_cache.example.com.").unwrap(),
            RecordType::SRV,
            RecordClass::INET,
            300,
            RecordData::SRV(RecordDataSRV::new(100, 10, port, Name::from_str(target).unwrap())),
        )
    }

    #[tokio::test]
    async fn test_discovery_resolve_a_aaaa() {
        let response_a = response_with_answers(
            RecordType::A,
            vec![Record::new(
                Name::from_str("example.com.").unwrap(),
                RecordType::A,
                RecordClass::INET,
                300,
                RecordData::A(RecordDataA::new(Ipv4Addr::new(10, 1, 1, 1))),
            )],
        );

        let response_aaaa = response_with_answers(
            RecordType::AAAA,
            vec![Record::new(
                Name::from_str("example.com.").unwrap(),
                RecordType::AAAA,
                RecordClass::INET,
                300,
                RecordData::AAAA(RecordDataAAAA::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))),
            )],
        );

        let client = MockDnsClient::new(vec![response_a, response_aaaa]);
        let discovery = Discovery::new(client);
        let servers = discovery.resolve_by_proto("dns+example.com:11211").await.unwrap();

        let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();
        let id_a = ServerID::from("10.1.1.1:11211".parse::<SocketAddr>().unwrap());
        let id_aaaa = ServerID::from("[::1]:11211".parse::<SocketAddr>().unwrap());

        // The A query is issued first, so its results come first.
        assert_eq!(ids, vec![id_a, id_aaaa]);
    }

    #[tokio::test]
    async fn test_discovery_resolve_srv() {
        let response = response_with_answers(
            RecordType::SRV,
            vec![
                srv_record(11211, "cache01.example.com."),
                srv_record(11211, "cache02.example.com."),
            ],
        );

        let client = MockDnsClient::new(vec![response]);
        let discovery = Discovery::new(client);
        let servers = discovery.resolve_by_proto("dnssrv+_cache.example.com:11211").await.unwrap();

        let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();
        let id1 = ServerID::from(("cache01.example.com.", 11211));
        let id2 = ServerID::from(("cache02.example.com.", 11211));

        assert_eq!(ids.len(), 2, "expected exactly two servers, got {:?}", ids);
        assert!(ids.contains(&id1), "expected {:?} to contain {:?}", ids, id1);
        assert!(ids.contains(&id2), "expected {:?} to contain {:?}", ids, id2);
    }

    #[tokio::test]
    async fn test_discovery_resolve_srv_dupes() {
        // Same target with two different SRV ports: the port from the name wins, so
        // both records collapse into a single server.
        let response = response_with_answers(
            RecordType::SRV,
            vec![
                srv_record(11211, "cache01.example.com."),
                srv_record(9105, "cache01.example.com."),
            ],
        );

        let client = MockDnsClient::new(vec![response]);
        let discovery = Discovery::new(client);
        let servers = discovery.resolve_by_proto("dnssrv+_cache.example.com:11211").await.unwrap();

        let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();
        let id = ServerID::from(("cache01.example.com.", 11211));

        assert_eq!(ids, vec![id]);
    }

    #[tokio::test]
    async fn test_discovery_resolve_socket_addr() {
        let name = "127.0.0.2:11211";
        let sock: SocketAddr = "127.0.0.2:11211".parse().unwrap();

        let client = MockDnsClient::new(vec![]);
        let discovery = Discovery::new(client);
        let servers = discovery.resolve_by_proto(name).await.unwrap();

        let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();
        let id = ServerID::from(sock);

        assert_eq!(ids, vec![id]);
    }

    #[tokio::test]
    async fn test_discovery_resolve_ipv6_socket_addr() {
        let client = MockDnsClient::new(vec![]);
        let discovery = Discovery::new(client);
        let servers = discovery.resolve_by_proto("[::1]:11211").await.unwrap();

        let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();
        let id = ServerID::from("[::1]:11211".parse::<SocketAddr>().unwrap());

        assert_eq!(ids, vec![id]);
    }

    #[tokio::test]
    async fn test_discovery_resolve_bare_host() {
        let name = "localhost:11211";

        let client = MockDnsClient::new(vec![]);
        let discovery = Discovery::new(client);
        let servers = discovery.resolve_by_proto(name).await.unwrap();

        let ids = servers.iter().map(|s| s.id().clone()).collect::<Vec<_>>();
        let id = ServerID::from(("localhost", 11211));

        assert_eq!(ids, vec![id]);
    }

    #[tokio::test]
    async fn test_discovery_resolve_unix_socket() {
        let name = "/var/run/memcached.sock";

        let client = MockDnsClient::new(vec![]);
        let discovery = Discovery::new(client);
        let servers = discovery.resolve_by_proto(name).await.unwrap();

        assert_eq!(servers.len(), 1, "expected a single server, got {:?}", servers);
        assert_eq!(servers[0].id(), &ServerID::Path(PathBuf::from(name)));
        assert!(servers[0].server_name().is_none());
    }

    #[tokio::test]
    async fn test_discovery_resolve_invalid_names() {
        let client = MockDnsClient::new(vec![]);
        let discovery = Discovery::new(client);

        // Missing port, non-numeric port, and a port that overflows u16.
        for name in ["localhost", "localhost:http", "localhost:70000"] {
            assert!(
                discovery.resolve_by_proto(name).await.is_err(),
                "expected '{}' to be rejected",
                name
            );
        }
    }
}
501}