1use std::net::{IpAddr, SocketAddr};
4
5use reqwest_middleware::ClientWithMiddleware;
6use serde::{Deserialize, Serialize};
7use tracing::{debug, info, instrument};
8use trust_dns_resolver::{
9 error::{ResolveError, ResolveErrorKind},
10 TokioAsyncResolver,
11};
12
13use crate::cache;
14
15pub mod error;
16
17#[derive(Debug, Clone, Deserialize, Serialize)]
24pub struct ServerWellKnown {
25 #[serde(rename = "m.server")]
28 pub server: String,
29}
30
31#[derive(Debug, Clone)]
33pub struct Resolver {
34 http: ClientWithMiddleware,
36 resolver: TokioAsyncResolver,
38}
39
/// The outcome of resolving a server name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Server {
    /// An IP literal without a port.
    Ip(IpAddr),
    /// An IP literal with an explicit port.
    Socket(SocketAddr),
    /// A hostname without a port.
    Host(String),
    /// A hostname with an explicit port, stored as one `host:port` string.
    HostPort(String),
    /// A server found through an SRV record: the first field is the SRV
    /// target as `host:port`, the second is the hostname the record was
    /// looked up for.
    Srv(String, String),
}

impl Server {
    /// Port used whenever the server name carries no explicit port.
    const DEFAULT_PORT: u16 = 8448;

    /// Returns the value to send as the HTTP `Host` header for this server.
    ///
    /// For [`Server::Srv`] this is the hostname the SRV record was looked up
    /// for, not the SRV target.
    // NOTE(review): an IPv6 `Ip` is rendered without brackets here; confirm
    // whether callers expect the bracketed `[addr]` form for the header.
    #[must_use]
    pub fn host_header(&self) -> String {
        match self {
            Server::Ip(addr) => addr.to_string(),
            Server::Socket(addr) => addr.to_string(),
            Server::Host(host) | Server::HostPort(host) | Server::Srv(_, host) => host.clone(),
        }
    }

    /// Returns the `host:port` string to connect to.
    ///
    /// Variants without an explicit port get port 8448 appended. For
    /// [`Server::Srv`] this is the SRV target.
    #[must_use]
    pub fn address(&self) -> String {
        match self {
            // Go through `SocketAddr` so IPv6 literals come out bracketed
            // (`[::1]:8448`); plain `"{addr}:8448"` yields `::1:8448`, which
            // is not a parseable socket address.
            Server::Ip(addr) => SocketAddr::new(*addr, Self::DEFAULT_PORT).to_string(),
            Server::Socket(addr) => addr.to_string(),
            Server::Host(host) => format!("{}:{}", host, Self::DEFAULT_PORT),
            Server::HostPort(address) | Server::Srv(address, _) => address.clone(),
        }
    }
}
80
81impl Resolver {
82 pub fn new() -> Result<Self, ResolveError> {
84 Ok(Self {
85 http: reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
86 .with(cache())
87 .build(),
88 resolver: TokioAsyncResolver::tokio_from_system_conf()?,
89 })
90 }
91
92 #[must_use]
95 pub fn with(http: reqwest::Client, resolver: TokioAsyncResolver) -> Self {
96 Self { http: reqwest_middleware::ClientBuilder::new(http).with(cache()).build(), resolver }
97 }
98
99 #[instrument(skip(self, port), err)]
101 pub async fn resolve(
102 &self,
103 name: &str,
104 #[cfg(test)] port: Option<u16>,
105 ) -> error::Result<Server> {
106 debug!("Parsing socket literal");
108 if let Ok(addr) = name.parse::<SocketAddr>() {
109 info!("The server name is a socket literal");
110 return Ok(Server::Socket(addr));
111 }
112 debug!("Parsing IP literal");
113 if let Ok(addr) = name.parse::<IpAddr>() {
114 info!("The server name is an IP literal");
115 return Ok(Server::Ip(addr));
116 }
117 debug!("Parsing host with port");
119 if split_port(name).is_some() {
120 info!("The servername is a host with port");
121 return Ok(Server::HostPort(name.to_owned()));
122 }
123 debug!("Querying well known");
125 if let Some(well_known) = self
126 .well_known(
127 name,
128 #[cfg(test)]
129 port,
130 )
131 .await?
132 {
133 debug!("Well-known received: {:?}", &well_known);
134 debug!("Parsing delegated socket literal");
136 if let Ok(addr) = well_known.server.parse::<SocketAddr>() {
137 info!("The server name is a delegated IP literal");
138 return Ok(Server::Socket(addr));
139 }
140 debug!("Parsing delegated IP literal");
141 if let Ok(addr) = well_known.server.parse::<IpAddr>() {
142 info!("The server name is a delegated socket literal");
143 return Ok(Server::Ip(addr));
144 }
145 debug!("Parsing delegated hostname with port");
147 if split_port(&well_known.server).is_some() {
148 info!("The server name is a delegated hostname with port");
149 return Ok(Server::HostPort(well_known.server));
150 }
151 debug!("Looking up SRV record for delegated hostname");
153 if let Some(name) = self.srv_lookup(&well_known.server).await {
154 info!("The server name is a delegated SRV record");
155 return Ok(Server::Srv(name, well_known.server));
156 }
157 debug!("Using delegated hostname directly");
159 return Ok(Server::Host(well_known.server));
160 }
161 debug!("Looking up SRV record for hostname");
163 if let Some(srv) = self.srv_lookup(name).await {
164 info!("The server name is an SRV record");
165 return Ok(Server::Srv(srv, name.to_owned()));
166 }
167 debug!("Using provided hostname directly");
169 Ok(Server::Host(name.to_owned()))
170 }
171
172 #[cfg_attr(test, allow(unused_variables))]
174 #[instrument(skip(self, name, port), err)]
175 async fn well_known(
176 &self,
177 name: &str,
178 #[cfg(test)] port: Option<u16>,
179 ) -> error::Result<Option<ServerWellKnown>> {
180 #[cfg(not(test))]
181 let response = self.http.get(format!("https://{}/.well-known/matrix/server", name)).send().await;
182
183 #[cfg(test)]
184 #[allow(clippy::expect_used)]
185 let response = self
186 .http
187 .get(format!(
188 "http://{name}:{port}/.well-known/matrix/server",
189 port = port.expect("port needed for test env")
190 ))
191 .send()
192 .await;
193
194 let response = match response {
196 Ok(response) => response,
197 Err(reqwest_middleware::Error::Reqwest(e)) if e.is_connect() => return Err(e.into()),
198 Err(_) => return Ok(None),
199 };
200 let well_known = response.json::<ServerWellKnown>().await.ok();
201 Ok(well_known)
202 }
203
204 #[instrument(skip(self, name))]
206 async fn srv_lookup(&self, name: &str) -> Option<String> {
207 let srv = self.resolver.srv_lookup(format!("_matrix._tcp.{}", name)).await.ok()?;
208 match srv.iter().min_by_key(|srv| srv.priority()) {
210 Some(srv) => {
211 let target = srv.target().to_ascii();
212 let host = target.trim_end_matches('.');
213 Some(format!("{}:{}", host, srv.port()))
214 }
215 None => None,
216 }
217 }
218
219 pub async fn socket(&self, server: &Server) -> Result<SocketAddr, ResolveError> {
221 let (host, port) = match *server {
222 Server::Ip(ip) => return Ok(SocketAddr::new(ip, 8448)),
223 Server::Socket(socket) => return Ok(socket),
224 Server::Host(ref host) => (host.as_str(), 8448),
225 #[allow(clippy::expect_used)]
226 Server::HostPort(ref host) => split_port(host).expect("HostPort was constructed with port"),
227 #[allow(clippy::expect_used)]
228 Server::Srv(ref addr, _) => split_port(addr).expect("The SRV record includes the port"),
229 };
230 let record = self.resolver.lookup_ip(host).await?;
231 let socket = SocketAddr::new(
233 record.iter().next().ok_or(ResolveErrorKind::Message("No records"))?,
234 port,
235 );
236 Ok(socket)
237 }
238}
239
/// Splits a `host:port` string into its host and numeric port.
///
/// Returns `None` unless the input contains exactly one `:` followed by a
/// value that parses as `u16`. The host part is not validated and may be
/// empty.
fn split_port(host: &str) -> Option<(&str, u16)> {
    // `split_once` avoids collecting into a `Vec`. A second `:` ends up in
    // the port part and makes the `u16` parse fail, so inputs with more than
    // one colon are still rejected.
    let (name, port) = host.split_once(':')?;
    port.parse().ok().map(|port| (name, port))
}
250
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, SocketAddr};

    use trust_dns_resolver::TokioAsyncResolver;
    use wiremock::{
        matchers::{method, path},
        Mock, MockServer, ResponseTemplate,
    };

    use super::{Resolver, Server};

    /// Mounts a well-known responder that answers exactly one GET request
    /// with a document delegating to `delegated`.
    async fn mount_delegation(mock_server: &MockServer, delegated: &str) {
        let body = format!(r#"{{"m.server": "{}"}}"#, delegated);
        Mock::given(method("GET"))
            .and(path("/.well-known/matrix/server"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
            .up_to_n_times(1)
            .expect(1)
            .mount(mock_server)
            .await;
    }

    /// Names that resolve without any network lookup.
    #[tokio::test]
    async fn literals() -> Result<(), Box<dyn std::error::Error>> {
        let resolver = Resolver::new()?;
        let loopback = IpAddr::from([127, 0, 0, 1]);

        let resolved = resolver.resolve("127.0.0.1", None).await?;
        assert_eq!(resolved, Server::Ip(loopback), "1. IP literal");

        let resolved = resolver.resolve("127.0.0.1:4884", None).await?;
        assert_eq!(resolved, Server::Socket(SocketAddr::new(loopback, 4884)), "1. Socket literal");

        let resolved = resolver.resolve("example.test:1234", None).await?;
        assert_eq!(
            resolved,
            Server::HostPort(String::from("example.test:1234")),
            "2. Host with port"
        );
        Ok(())
    }

    /// Delegation through the well-known document served by a mock server.
    ///
    /// Each responder is limited to one request, so every `resolve` call is
    /// answered by the responder mounted right before it.
    #[tokio::test]
    async fn http() -> Result<(), Box<dyn std::error::Error>> {
        let mock_server = MockServer::start().await;
        let addr = *mock_server.address();

        // Point both test hostnames at the mock server.
        let client = reqwest::Client::builder()
            .resolve("example.test", addr)
            .resolve("destination.test", addr)
            .build()?;
        let resolver = Resolver::with(client, TokioAsyncResolver::tokio_from_system_conf()?);

        mount_delegation(&mock_server, &addr.ip().to_string()).await;
        assert_eq!(
            resolver.resolve("example.test", Some(addr.port())).await?,
            Server::Ip(addr.ip()),
            "3.1 delegated_hostname is an IP literal"
        );

        mount_delegation(&mock_server, &addr.to_string()).await;
        assert_eq!(
            resolver.resolve("example.test", Some(addr.port())).await?,
            Server::Socket(addr),
            "3.1 delegated_hostname is a socket literal"
        );

        let delegated = format!("destination.test:{}", addr.port());
        mount_delegation(&mock_server, &delegated).await;
        assert_eq!(
            resolver.resolve("example.test", Some(addr.port())).await?,
            Server::HostPort(delegated),
            "3.2 delegated_hostname includes a port"
        );
        Ok(())
    }
}