pub mod error;
use std::collections::BTreeMap;
use reqwest::{StatusCode, Url};
use reqwest_middleware::ClientWithMiddleware;
use serde::{Deserialize, Serialize};
use self::error::{Error, FailError};
use crate::cache;
/// Body of the `/.well-known/matrix/client` discovery document.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ClientWellKnown {
    /// The `m.homeserver` entry; the document fails to deserialize without it.
    #[serde(rename = "m.homeserver")]
    pub homeserver: HomeserverInfo,
    /// The optional `m.identity_server` entry; `None` when the key is absent.
    #[serde(rename = "m.identity_server")]
    pub identity_server: Option<IdentityServerInfo>,
}
/// The `m.homeserver` entry of a client well-known document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HomeserverInfo {
    base_url: String,
}

impl HomeserverInfo {
    /// The advertised homeserver base URL, exactly as it appeared in the document.
    ///
    /// The field itself stays private; this accessor lets code outside the
    /// module read a deserialized [`ClientWellKnown`]'s homeserver URL.
    #[must_use]
    pub fn base_url(&self) -> &str {
        &self.base_url
    }
}
/// The `m.identity_server` entry of a client well-known document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdentityServerInfo {
    base_url: String,
}

impl IdentityServerInfo {
    /// The advertised identity server base URL, exactly as it appeared in the document.
    ///
    /// The field itself stays private; this accessor lets code outside the
    /// module read it.
    #[must_use]
    pub fn base_url(&self) -> &str {
        &self.base_url
    }
}
/// Resolves Matrix server names to the base URL of their client-server API.
///
/// Every lookup is issued through `http`, a middleware-wrapped HTTP client.
#[derive(Debug, Clone)]
pub struct Resolver {
    /// Client used for the well-known lookup and the follow-up validation requests.
    http: ClientWithMiddleware,
}
/// Minimal view of the `/_matrix/client/versions` response body.
///
/// The resolver deserializes this only to confirm that the homeserver answers
/// with a well-formed versions document. Nothing reads the fields afterwards,
/// which is why `dead_code` is allowed here.
#[derive(Deserialize)]
#[allow(dead_code)]
struct Versions {
    /// Values under the `versions` key; required.
    versions: Vec<String>,
    /// Values under `unstable_features`; an absent key yields an empty map.
    #[serde(default)]
    unstable_features: BTreeMap<String, bool>,
}
impl Resolver {
    /// Creates a resolver backed by a fresh `reqwest::Client` with the crate's
    /// cache middleware attached.
    #[must_use]
    pub fn new() -> Self {
        Self::with(reqwest::Client::new())
    }

    /// Creates a resolver that wraps `http` with the crate's cache middleware.
    #[must_use]
    pub fn with(http: reqwest::Client) -> Self {
        Self { http: reqwest_middleware::ClientBuilder::new(http).with(cache()).build() }
    }

    /// Resolves the Matrix server `name` (host, optionally with `:port`) to the
    /// base URL of its client-server API.
    ///
    /// Steps:
    /// 1. GET `https://{name}/.well-known/matrix/client`. Tests use `http://`.
    ///    A 404 means there is no delegation, so the server-name URL itself is
    ///    returned.
    /// 2. Otherwise the response must be a success status with a
    ///    [`ClientWellKnown`] JSON body.
    /// 3. The advertised homeserver is validated by fetching
    ///    `_matrix/client/versions` and parsing it as a versions document.
    /// 4. If an identity server is advertised, it must answer
    ///    `_matrix/identity/api/v1` with a success status.
    ///
    /// The returned URL always ends in `/`. Callers can therefore `join`
    /// relative API paths without losing a path prefix in the base URL.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - a URL fails to parse or join;
    /// - any HTTP request fails;
    /// - the well-known lookup returns a non-success status other than 404;
    /// - a response body does not deserialize;
    /// - the identity server check fails.
    pub async fn resolve(&self, name: &str) -> Result<Url, Error> {
        // The wiremock-based tests serve plain HTTP, so the scheme differs under test.
        #[cfg(not(test))]
        let url = Url::parse(&format!("https://{}", name))?;
        #[cfg(test)]
        let url = Url::parse(&format!("http://{}", name))?;

        let response = self.http.get(url.join(".well-known/matrix/client")?).send().await?;
        if response.status() == StatusCode::NOT_FOUND {
            return Ok(url);
        }
        // Any other non-success status is a failed lookup. Report it as an HTTP
        // status error rather than letting it surface as a JSON decode failure.
        let well_known = response.error_for_status()?.json::<ClientWellKnown>().await?;

        let homeserver = with_trailing_slash(Url::parse(&well_known.homeserver.base_url)?);
        // Only the shape of the body matters; the parsed value is discarded.
        self.http
            .get(homeserver.join("_matrix/client/versions")?)
            .send()
            .await
            .map_err(FailError::Http)?
            .json::<Versions>()
            .await
            .map_err(|e| FailError::Http(e.into()))?;

        if let Some(identity) = well_known.identity_server {
            let identity_url = with_trailing_slash(Url::parse(&identity.base_url)?);
            // NOTE(review): newer Matrix spec versions validate identity servers
            // against `/_matrix/identity/v2`; confirm `/api/v1` is still intended.
            let result: Result<_, FailError> = async {
                self.http
                    .get(identity_url.join("_matrix/identity/api/v1")?)
                    .send()
                    .await?
                    .error_for_status()?;
                Ok(())
            }
            .await;
            result?;
        }

        Ok(homeserver)
    }
}

/// Ensures a hierarchical URL's path ends in `/`.
///
/// Without the slash, `Url::join` treats the last path segment as a file name
/// and drops it. For example, `https://host/matrix` joined with
/// `_matrix/client/versions` would become `https://host/_matrix/client/versions`.
fn with_trailing_slash(mut url: Url) -> Url {
    if !url.cannot_be_a_base() && !url.path().ends_with('/') {
        let path = format!("{}/", url.path());
        url.set_path(&path);
    }
    url
}
impl Default for Resolver {
fn default() -> Self {
Self { http: ClientWithMiddleware::from(reqwest::Client::new()) }
}
}
#[cfg(test)]
mod tests {
    use wiremock::{
        matchers::{method, path},
        Mock, MockServer, ResponseTemplate,
    };

    use super::Resolver;

    /// Builds a client that resolves every name in `hosts` to the mock server's address.
    fn client_for(server: &MockServer, hosts: &[&str]) -> reqwest::Result<reqwest::Client> {
        let addr = *server.address();
        hosts
            .iter()
            .fold(reqwest::Client::builder(), |builder, host| builder.resolve(host, addr))
            .build()
    }

    /// A 404 from the well-known endpoint makes the server name itself the result.
    #[tokio::test]
    async fn not_found() -> Result<(), Box<dyn std::error::Error>> {
        let server = MockServer::start().await;
        let port = server.address().port();

        Mock::given(method("GET"))
            .and(path("/.well-known/matrix/client"))
            .respond_with(ResponseTemplate::new(404))
            .expect(1)
            .mount(&server)
            .await;

        let resolver = Resolver::with(client_for(&server, &["example.test"])?);
        let resolved = resolver.resolve(&format!("example.test:{}", port)).await?;

        assert_eq!(format!("http://example.test:{}/", port), resolved.to_string());
        Ok(())
    }

    /// A delegating well-known document yields the advertised homeserver URL.
    #[tokio::test]
    async fn resolve() -> Result<(), Box<dyn std::error::Error>> {
        let server = MockServer::start().await;
        let port = server.address().port();

        let well_known_body = format!(
            r#"{{"m.homeserver": {{"base_url": "http://destination.test:{}"}} }}"#,
            port
        );
        Mock::given(method("GET"))
            .and(path("/.well-known/matrix/client"))
            .respond_with(
                ResponseTemplate::new(200).set_body_raw(well_known_body, "application/json"),
            )
            .expect(1)
            .mount(&server)
            .await;

        let versions_body = r#"{"versions":["r0.0.1"]}"#;
        Mock::given(method("GET"))
            .and(path("/_matrix/client/versions"))
            .respond_with(
                ResponseTemplate::new(200).set_body_raw(versions_body, "application/json"),
            )
            .expect(1)
            .mount(&server)
            .await;

        let resolver =
            Resolver::with(client_for(&server, &["example.test", "destination.test"])?);
        let resolved = resolver.resolve(&format!("example.test:{}", port)).await?;

        assert_eq!(resolved.to_string(), format!("http://destination.test:{}/", port));
        Ok(())
    }
}