use std::net::{IpAddr, SocketAddr};
use reqwest_middleware::ClientWithMiddleware;
use serde::{Deserialize, Serialize};
use tracing::{debug, info, instrument};
use trust_dns_resolver::{
error::{ResolveError, ResolveErrorKind},
TokioAsyncResolver,
};
use crate::cache;
pub mod error;
/// JSON body served at `/.well-known/matrix/server`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ServerWellKnown {
    /// Delegated server name, read from the `m.server` key.
    #[serde(rename = "m.server")]
    pub server: String,
}
/// Resolves Matrix server names using an HTTP client (for well-known
/// delegation) and an async DNS resolver (for SRV and IP lookups).
#[derive(Clone, Debug)]
pub struct Resolver {
    /// HTTP client; the constructors attach the `cache()` middleware to it.
    http: ClientWithMiddleware,
    /// DNS resolver used for SRV records and host-to-IP lookups.
    resolver: TokioAsyncResolver,
}
/// The outcome of resolving a Matrix server name: what to connect to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Server {
    /// A bare IP literal; the default port 8448 applies.
    Ip(IpAddr),
    /// An IP literal with an explicit port.
    Socket(SocketAddr),
    /// A hostname without a port; the default port 8448 applies.
    Host(String),
    /// A `host:port` string.
    HostPort(String),
    /// An SRV result: the `host:port` target taken from the record, followed
    /// by the hostname the record was looked up for.
    Srv(String, String),
}

impl Server {
    /// Value to send as the HTTP `Host` header when contacting this server.
    ///
    /// For SRV results this is the looked-up hostname, not the SRV target.
    ///
    /// NOTE(review): an IPv6 `Ip` value is rendered without brackets here,
    /// which is not a valid HTTP `Host` value; confirm how callers use this
    /// (e.g. for TLS name checks) before bracketing it.
    #[must_use]
    pub fn host_header(&self) -> String {
        match self {
            Server::Ip(addr) => addr.to_string(),
            Server::Socket(addr) => addr.to_string(),
            Server::Host(host) | Server::HostPort(host) | Server::Srv(_, host) => host.clone(),
        }
    }

    /// The `host:port` (or `ip:port`) string to connect to, using port 8448
    /// when no port is known.
    #[must_use]
    pub fn address(&self) -> String {
        match self {
            // Build a `SocketAddr` so IPv6 addresses get bracketed. A plain
            // `format!("{}:8448", ip)` yields e.g. "::1:8448", which is a
            // different (valid) IPv6 address rather than "[::1]:8448".
            Server::Ip(addr) => SocketAddr::new(*addr, 8448).to_string(),
            Server::Socket(addr) => addr.to_string(),
            Server::Host(host) => format!("{}:8448", host),
            Server::HostPort(host) | Server::Srv(host, _) => host.clone(),
        }
    }
}
impl Resolver {
    /// Creates a resolver from the system DNS configuration and a default
    /// `reqwest` client wrapped in the `cache()` middleware.
    ///
    /// # Errors
    ///
    /// Returns a [`ResolveError`] if the system resolver configuration cannot
    /// be loaded.
    pub fn new() -> Result<Self, ResolveError> {
        Ok(Self {
            http: reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
                .with(cache())
                .build(),
            resolver: TokioAsyncResolver::tokio_from_system_conf()?,
        })
    }

    /// Creates a resolver from a caller-supplied HTTP client and DNS resolver.
    ///
    /// The client is still wrapped in the `cache()` middleware.
    #[must_use]
    pub fn with(http: reqwest::Client, resolver: TokioAsyncResolver) -> Self {
        Self { http: reqwest_middleware::ClientBuilder::new(http).with(cache()).build(), resolver }
    }

    /// Resolves a Matrix server name to the [`Server`] that should be contacted.
    ///
    /// The first matching rule wins:
    /// 1. `name` is a socket literal or an IP literal;
    /// 2. `name` is a hostname with an explicit port;
    /// 3. `name` delegates via `/.well-known/matrix/server`; the `m.server`
    ///    value is then checked as a socket/IP literal, a hostname with a port,
    ///    a hostname with a `_matrix._tcp` SRV record, or a bare hostname;
    /// 4. `name` has a `_matrix._tcp` SRV record;
    /// 5. otherwise `name` is used as a bare hostname.
    ///
    /// In test builds an extra `port` argument selects the port that the
    /// plain-HTTP well-known request is sent to.
    ///
    /// # Errors
    ///
    /// Returns an error only when the well-known request fails to connect;
    /// every other well-known failure is treated as "no delegation".
    #[instrument(skip(self, port), err)]
    pub async fn resolve(
        &self,
        name: &str,
        #[cfg(test)] port: Option<u16>,
    ) -> error::Result<Server> {
        debug!("Parsing socket literal");
        if let Ok(addr) = name.parse::<SocketAddr>() {
            info!("The server name is a socket literal");
            return Ok(Server::Socket(addr));
        }
        debug!("Parsing IP literal");
        if let Ok(addr) = name.parse::<IpAddr>() {
            info!("The server name is an IP literal");
            return Ok(Server::Ip(addr));
        }
        debug!("Parsing host with port");
        if split_port(name).is_some() {
            info!("The servername is a host with port");
            return Ok(Server::HostPort(name.to_owned()));
        }
        debug!("Querying well known");
        if let Some(well_known) = self
            .well_known(
                name,
                #[cfg(test)]
                port,
            )
            .await?
        {
            debug!("Well-known received: {:?}", &well_known);
            debug!("Parsing delegated socket literal");
            if let Ok(addr) = well_known.server.parse::<SocketAddr>() {
                info!("The server name is a delegated socket literal");
                return Ok(Server::Socket(addr));
            }
            debug!("Parsing delegated IP literal");
            if let Ok(addr) = well_known.server.parse::<IpAddr>() {
                info!("The server name is a delegated IP literal");
                return Ok(Server::Ip(addr));
            }
            debug!("Parsing delegated hostname with port");
            if split_port(&well_known.server).is_some() {
                info!("The server name is a delegated hostname with port");
                return Ok(Server::HostPort(well_known.server));
            }
            debug!("Looking up SRV record for delegated hostname");
            if let Some(target) = self.srv_lookup(&well_known.server).await {
                info!("The server name is a delegated SRV record");
                return Ok(Server::Srv(target, well_known.server));
            }
            debug!("Using delegated hostname directly");
            return Ok(Server::Host(well_known.server));
        }
        debug!("Looking up SRV record for hostname");
        if let Some(srv) = self.srv_lookup(name).await {
            info!("The server name is an SRV record");
            return Ok(Server::Srv(srv, name.to_owned()));
        }
        debug!("Using provided hostname directly");
        Ok(Server::Host(name.to_owned()))
    }

    /// Fetches and parses `/.well-known/matrix/server` for `name`.
    ///
    /// Returns `Ok(None)` when the request fails for a reason other than a
    /// connection error, when the response status is not 2xx, or when the
    /// body does not deserialize into a [`ServerWellKnown`].
    ///
    /// # Errors
    ///
    /// Connection failures are returned as errors instead of `Ok(None)`.
    /// NOTE(review): confirm that aborting resolution here, rather than
    /// falling through to the SRV lookup, is the intended behavior.
    #[cfg_attr(test, allow(unused_variables))]
    #[instrument(skip(self, name, port), err)]
    async fn well_known(
        &self,
        name: &str,
        #[cfg(test)] port: Option<u16>,
    ) -> error::Result<Option<ServerWellKnown>> {
        #[cfg(not(test))]
        let response = self.http.get(format!("https://{}/.well-known/matrix/server", name)).send().await;
        #[cfg(test)]
        #[allow(clippy::expect_used)]
        let response = self
            .http
            .get(format!(
                "http://{name}:{port}/.well-known/matrix/server",
                port = port.expect("port needed for test env")
            ))
            .send()
            .await;
        let response = match response {
            Ok(response) => response,
            Err(reqwest_middleware::Error::Reqwest(e)) if e.is_connect() => return Err(e.into()),
            Err(_) => return Ok(None),
        };
        // The spec treats a non-200 response as "no delegation". Without this
        // check, an error page whose body happens to be valid JSON with an
        // `m.server` key would be honoured.
        if !response.status().is_success() {
            return Ok(None);
        }
        Ok(response.json::<ServerWellKnown>().await.ok())
    }

    /// Looks up `_matrix._tcp.<name>` and returns the lowest-priority target
    /// as `host:port`, with the trailing root dot removed.
    ///
    /// Returns `None` when the lookup fails, yields no records, or every record
    /// has the root target `.` (RFC 2782: the service is not available).
    #[instrument(skip(self, name))]
    async fn srv_lookup(&self, name: &str) -> Option<String> {
        let records = self.resolver.srv_lookup(format!("_matrix._tcp.{}", name)).await.ok()?;
        records
            .iter()
            .filter_map(|srv| {
                let target = srv.target().to_ascii();
                let host = target.trim_end_matches('.');
                // A root target would otherwise become the address ":<port>".
                (!host.is_empty()).then(|| (srv.priority(), format!("{}:{}", host, srv.port())))
            })
            // `min_by_key` keeps the first record among equal priorities.
            .min_by_key(|(priority, _)| *priority)
            .map(|(_, address)| address)
    }

    /// Resolves `server` to a concrete socket address.
    ///
    /// IP and socket literals are returned without any lookup (IP literals on
    /// port 8448). For hostnames, the first IP returned by DNS is paired with
    /// the explicit, SRV-provided, or default 8448 port.
    ///
    /// # Errors
    ///
    /// Returns a [`ResolveError`] if the DNS lookup fails or returns no
    /// addresses.
    ///
    /// # Panics
    ///
    /// Panics if a `HostPort` or `Srv` value without a parseable `:port`
    /// suffix is passed in. Values built by [`Resolver::resolve`] always have one.
    pub async fn socket(&self, server: &Server) -> Result<SocketAddr, ResolveError> {
        let (host, port) = match *server {
            Server::Ip(ip) => return Ok(SocketAddr::new(ip, 8448)),
            Server::Socket(socket) => return Ok(socket),
            Server::Host(ref host) => (host.as_str(), 8448),
            #[allow(clippy::expect_used)]
            Server::HostPort(ref host) => split_port(host).expect("HostPort was constructed with port"),
            #[allow(clippy::expect_used)]
            Server::Srv(ref addr, _) => split_port(addr).expect("The SRV record includes the port"),
        };
        let record = self.resolver.lookup_ip(host).await?;
        let socket = SocketAddr::new(
            record.iter().next().ok_or(ResolveErrorKind::Message("No records"))?,
            port,
        );
        Ok(socket)
    }
}
/// Splits a `host:port` string into its host part and numeric port.
///
/// Yields `None` unless the input contains exactly one `:` and the text after
/// it parses as a `u16`.
fn split_port(host: &str) -> Option<(&str, u16)> {
    let (name, raw_port) = host.split_once(':')?;
    // A second colon would remain inside `raw_port`, so the numeric parse
    // rejects it as well. This keeps the "exactly one colon" rule.
    raw_port.parse::<u16>().ok().map(|port| (name, port))
}
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, SocketAddr};

    use trust_dns_resolver::TokioAsyncResolver;
    use wiremock::{
        matchers::{method, path},
        Mock, MockServer, ResponseTemplate,
    };

    use super::{Resolver, Server};

    /// Mounts a single-use `GET /.well-known/matrix/server` mock that answers
    /// `200` with `body` as `application/json` and must be hit exactly once.
    async fn serve_well_known_once(mock_server: &MockServer, body: String) {
        Mock::given(method("GET"))
            .and(path("/.well-known/matrix/server"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
            .up_to_n_times(1)
            .expect(1)
            .mount(mock_server)
            .await;
    }

    /// Literal names (and hosts with a port) resolve before any well-known
    /// request or SRV lookup is attempted.
    #[tokio::test]
    async fn literals() -> Result<(), Box<dyn std::error::Error>> {
        let resolver = Resolver::new()?;
        let loopback = IpAddr::from([127, 0, 0, 1]);

        assert_eq!(resolver.resolve("127.0.0.1", None).await?, Server::Ip(loopback), "1. IP literal");
        assert_eq!(
            resolver.resolve("127.0.0.1:4884", None).await?,
            Server::Socket(SocketAddr::new(loopback, 4884)),
            "1. Socket literal"
        );
        assert_eq!(
            resolver.resolve("example.test:1234", None).await?,
            Server::HostPort(String::from("example.test:1234")),
            "2. Host with port"
        );
        Ok(())
    }

    /// Delegation through a mocked well-known endpoint, one mock per case.
    #[tokio::test]
    async fn http() -> Result<(), Box<dyn std::error::Error>> {
        let mock_server = MockServer::start().await;
        let addr = *mock_server.address();
        // Route both test hostnames to the mock server.
        let client = reqwest::Client::builder()
            .resolve("example.test", addr)
            .resolve("destination.test", addr)
            .build()?;
        let resolver = Resolver::with(client, TokioAsyncResolver::tokio_from_system_conf()?);
        let port = Some(addr.port());

        serve_well_known_once(&mock_server, format!(r#"{{"m.server": "{}"}}"#, addr.ip())).await;
        assert_eq!(
            resolver.resolve("example.test", port).await?,
            Server::Ip(addr.ip()),
            "3.1 delegated_hostname is an IP literal"
        );

        serve_well_known_once(&mock_server, format!(r#"{{"m.server": "{}"}}"#, addr)).await;
        assert_eq!(
            resolver.resolve("example.test", port).await?,
            Server::Socket(addr),
            "3.1 delegated_hostname is a socket literal"
        );

        serve_well_known_once(
            &mock_server,
            format!(r#"{{"m.server": "destination.test:{}"}}"#, addr.port()),
        )
        .await;
        assert_eq!(
            resolver.resolve("example.test", port).await?,
            Server::HostPort(format!("destination.test:{}", addr.port())),
            "3.2 delegated_hostname includes a port"
        );

        Ok(())
    }
}