matrix_oracle/
server.rs

1//! Resolution for the server-server API
2
3use std::net::{IpAddr, SocketAddr};
4
5use reqwest_middleware::ClientWithMiddleware;
6use serde::{Deserialize, Serialize};
7use tracing::{debug, info, instrument};
8use trust_dns_resolver::{
9	error::{ResolveError, ResolveErrorKind},
10	TokioAsyncResolver,
11};
12
13use crate::cache;
14
15pub mod error;
16
17/// well-known information about the delegated server for server-server
18/// communication.
19///
20/// See [the specification] for more information.
21///
22/// [the specification]: https://matrix.org/docs/spec/server_server/latest#get-well-known-matrix-server
23#[derive(Debug, Clone, Deserialize, Serialize)]
24pub struct ServerWellKnown {
25	/// The server name to delegate server-server communications to, with
26	/// optional port
27	#[serde(rename = "m.server")]
28	pub server: String,
29}
30
31/// Client for server-server well-known lookups.
32#[derive(Debug, Clone)]
33pub struct Resolver {
34	/// HTTP client.
35	http: ClientWithMiddleware,
36	/// DNS resolver.
37	resolver: TokioAsyncResolver,
38}
39
/// Resolved server name
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Server {
	/// IP address with implicit default port (8448)
	Ip(IpAddr),
	/// IP address and explicit port
	Socket(SocketAddr),
	/// Host string with implicit default port (8448)
	Host(String),
	/// Host string with explicit port, in `host:port` form.
	HostPort(String),
	/// Address (`host:port`) from the SRV record, followed by the hostname
	/// the record was looked up for.
	Srv(String, String),
}

impl Server {
	/// The value to use for the `Host` HTTP header.
	///
	/// IPv6 literals are wrapped in square brackets, matching how the
	/// [`Server::Socket`] variant is rendered, because a bare IPv6 address is
	/// not a valid `Host` value.
	#[must_use]
	pub fn host_header(&self) -> String {
		match self {
			Server::Ip(IpAddr::V6(addr)) => format!("[{}]", addr),
			Server::Ip(addr) => addr.to_string(),
			Server::Socket(addr) => addr.to_string(),
			Server::Host(host) | Server::HostPort(host) | Server::Srv(_, host) => host.clone(),
		}
	}

	/// The address to connect to, always including a port.
	///
	/// Variants without an explicit port use the default port 8448.
	#[must_use]
	pub fn address(&self) -> String {
		match self {
			// Going through `SocketAddr` brackets IPv6 literals, so the port
			// separator stays unambiguous (`[::1]:8448`, not `::1:8448`).
			Server::Ip(addr) => SocketAddr::new(*addr, 8448).to_string(),
			Server::Socket(addr) => addr.to_string(),
			Server::Host(host) => format!("{}:8448", host),
			Server::HostPort(addr) | Server::Srv(addr, _) => addr.clone(),
		}
	}
}
80
81impl Resolver {
82	/// Constructs a new client.
83	pub fn new() -> Result<Self, ResolveError> {
84		Ok(Self {
85			http: reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
86				.with(cache())
87				.build(),
88			resolver: TokioAsyncResolver::tokio_from_system_conf()?,
89		})
90	}
91
92	/// Constructs a new client with the given HTTP client and DNS resolver
93	/// instances.
94	#[must_use]
95	pub fn with(http: reqwest::Client, resolver: TokioAsyncResolver) -> Self {
96		Self { http: reqwest_middleware::ClientBuilder::new(http).with(cache()).build(), resolver }
97	}
98
99	/// Resolve the given server name
100	#[instrument(skip(self, port), err)]
101	pub async fn resolve(
102		&self,
103		name: &str,
104		#[cfg(test)] port: Option<u16>,
105	) -> error::Result<Server> {
106		// 1. The host is an ip literal
107		debug!("Parsing socket literal");
108		if let Ok(addr) = name.parse::<SocketAddr>() {
109			info!("The server name is a socket literal");
110			return Ok(Server::Socket(addr));
111		}
112		debug!("Parsing IP literal");
113		if let Ok(addr) = name.parse::<IpAddr>() {
114			info!("The server name is an IP literal");
115			return Ok(Server::Ip(addr));
116		}
117		// 2. The host is not an ip literal, but includes a port
118		debug!("Parsing host with port");
119		if split_port(name).is_some() {
120			info!("The servername is a host with port");
121			return Ok(Server::HostPort(name.to_owned()));
122		}
123		// 3. Query the .well-known endpoint
124		debug!("Querying well known");
125		if let Some(well_known) = self
126			.well_known(
127				name,
128				#[cfg(test)]
129				port,
130			)
131			.await?
132		{
133			debug!("Well-known received: {:?}", &well_known);
134			// 3.1 delegated_hostname is an ip literal
135			debug!("Parsing delegated socket literal");
136			if let Ok(addr) = well_known.server.parse::<SocketAddr>() {
137				info!("The server name is a delegated IP literal");
138				return Ok(Server::Socket(addr));
139			}
140			debug!("Parsing delegated IP literal");
141			if let Ok(addr) = well_known.server.parse::<IpAddr>() {
142				info!("The server name is a delegated socket literal");
143				return Ok(Server::Ip(addr));
144			}
145			// 3.2 delegated_hostname includes a port
146			debug!("Parsing delegated hostname with port");
147			if split_port(&well_known.server).is_some() {
148				info!("The server name is a delegated hostname with port");
149				return Ok(Server::HostPort(well_known.server));
150			}
151			// 3.3 Look up SRV record
152			debug!("Looking up SRV record for delegated hostname");
153			if let Some(name) = self.srv_lookup(&well_known.server).await {
154				info!("The server name is a delegated SRV record");
155				return Ok(Server::Srv(name, well_known.server));
156			}
157			// 3.4 Use hostname in .well-known
158			debug!("Using delegated hostname directly");
159			return Ok(Server::Host(well_known.server));
160		}
161		// 4. The .well-known lookup failed, query SRV
162		debug!("Looking up SRV record for hostname");
163		if let Some(srv) = self.srv_lookup(name).await {
164			info!("The server name is an SRV record");
165			return Ok(Server::Srv(srv, name.to_owned()));
166		}
167		// 5. No SRV record found, use hostname
168		debug!("Using provided hostname directly");
169		Ok(Server::Host(name.to_owned()))
170	}
171
172	/// Query the .well-known information for a host.
173	#[cfg_attr(test, allow(unused_variables))]
174	#[instrument(skip(self, name, port), err)]
175	async fn well_known(
176		&self,
177		name: &str,
178		#[cfg(test)] port: Option<u16>,
179	) -> error::Result<Option<ServerWellKnown>> {
180		#[cfg(not(test))]
181		let response = self.http.get(format!("https://{}/.well-known/matrix/server", name)).send().await;
182
183		#[cfg(test)]
184		#[allow(clippy::expect_used)]
185		let response = self
186			.http
187			.get(format!(
188				"http://{name}:{port}/.well-known/matrix/server",
189				port = port.expect("port needed for test env")
190			))
191			.send()
192			.await;
193
194		// Only return Err on connection failure, skip to next step for other errors.
195		let response = match response {
196			Ok(response) => response,
197			Err(reqwest_middleware::Error::Reqwest(e)) if e.is_connect() => return Err(e.into()),
198			Err(_) => return Ok(None),
199		};
200		let well_known = response.json::<ServerWellKnown>().await.ok();
201		Ok(well_known)
202	}
203
204	/// Query the matrix SRV DNS record for a hostname
205	#[instrument(skip(self, name))]
206	async fn srv_lookup(&self, name: &str) -> Option<String> {
207		let srv = self.resolver.srv_lookup(format!("_matrix._tcp.{}", name)).await.ok()?;
208		// Get a record with the lowest priority value
209		match srv.iter().min_by_key(|srv| srv.priority()) {
210			Some(srv) => {
211				let target = srv.target().to_ascii();
212				let host = target.trim_end_matches('.');
213				Some(format!("{}:{}", host, srv.port()))
214			}
215			None => None,
216		}
217	}
218
219	/// Get the [`SocketAddr`] of an address
220	pub async fn socket(&self, server: &Server) -> Result<SocketAddr, ResolveError> {
221		let (host, port) = match *server {
222			Server::Ip(ip) => return Ok(SocketAddr::new(ip, 8448)),
223			Server::Socket(socket) => return Ok(socket),
224			Server::Host(ref host) => (host.as_str(), 8448),
225			#[allow(clippy::expect_used)]
226			Server::HostPort(ref host) => split_port(host).expect("HostPort was constructed with port"),
227			#[allow(clippy::expect_used)]
228			Server::Srv(ref addr, _) => split_port(addr).expect("The SRV record includes the port"),
229		};
230		let record = self.resolver.lookup_ip(host).await?;
231		// We naively get the first IP.
232		let socket = SocketAddr::new(
233			record.iter().next().ok_or(ResolveErrorKind::Message("No records"))?,
234			port,
235		);
236		Ok(socket)
237	}
238}
239
/// Get the port at the end of a host string if there is one.
///
/// Returns the host and port when `host` contains exactly one `:` followed by
/// a valid `u16`; returns `None` for every other input.
fn split_port(host: &str) -> Option<(&str, u16)> {
	// Splitting at the first `:` leaves any further `:` in `port`, where it
	// makes the parse fail, so only a single separator is ever accepted.
	let (host, port) = host.split_once(':')?;
	let port = port.parse::<u16>().ok()?;
	Some((host, port))
}
250
#[cfg(test)]
mod tests {
	use std::net::{IpAddr, SocketAddr};

	use trust_dns_resolver::TokioAsyncResolver;
	use wiremock::{
		matchers::{method, path},
		Mock, MockServer, ResponseTemplate,
	};

	use super::{Resolver, Server};

	/// Mounts a mock that answers exactly one well-known request, delegating
	/// to `delegated`.
	async fn mount_well_known(mock_server: &MockServer, delegated: String) {
		let body = format!(r#"{{"m.server": "{}"}}"#, delegated);
		Mock::given(method("GET"))
			.and(path("/.well-known/matrix/server"))
			.respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
			.up_to_n_times(1)
			.expect(1)
			.mount(mock_server)
			.await;
	}

	/// Validates correct parsing of IP literals and server name with port
	#[tokio::test]
	async fn literals() -> Result<(), Box<dyn std::error::Error>> {
		let resolver = Resolver::new()?;
		let localhost = IpAddr::from([127, 0, 0, 1]);
		let cases = vec![
			("127.0.0.1", Server::Ip(localhost), "1. IP literal"),
			(
				"127.0.0.1:4884",
				Server::Socket(SocketAddr::new(localhost, 4884)),
				"1. Socket literal",
			),
			(
				"example.test:1234",
				Server::HostPort(String::from("example.test:1234")),
				"2. Host with port",
			),
		];
		for (name, expected, step) in cases {
			assert_eq!(resolver.resolve(name, None).await?, expected, "{}", step);
		}
		Ok(())
	}

	/// Validates correct handling of the .well-known http endpoint.
	#[tokio::test]
	async fn http() -> Result<(), Box<dyn std::error::Error>> {
		let mock_server = MockServer::start().await;
		let addr = *mock_server.address();

		// Route both test hostnames to the mock server.
		let client = reqwest::Client::builder()
			.resolve("example.test", addr)
			.resolve("destination.test", addr)
			.build()?;
		let resolver = Resolver::with(client, TokioAsyncResolver::tokio_from_system_conf()?);

		mount_well_known(&mock_server, addr.ip().to_string()).await;
		let resolved = resolver.resolve("example.test", Some(addr.port())).await?;
		assert_eq!(resolved, Server::Ip(addr.ip()), "3.1 delegated_hostname is an IP literal");

		mount_well_known(&mock_server, addr.to_string()).await;
		let resolved = resolver.resolve("example.test", Some(addr.port())).await?;
		assert_eq!(resolved, Server::Socket(addr), "3.1 delegated_hostname is a socket literal");

		let delegated = format!("destination.test:{}", addr.port());
		mount_well_known(&mock_server, delegated.clone()).await;
		let resolved = resolver.resolve("example.test", Some(addr.port())).await?;
		assert_eq!(resolved, Server::HostPort(delegated), "3.2 delegated_hostname includes a port");
		Ok(())
	}
}