matrix_oracle/
client.rs
1pub mod error;
4
5use std::collections::BTreeMap;
6
7use reqwest::{StatusCode, Url};
8use reqwest_middleware::ClientWithMiddleware;
9use serde::{Deserialize, Serialize};
10
11use self::error::{Error, FailError};
12use crate::cache;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ClientWellKnown {
17 #[serde(rename = "m.homeserver")]
19 pub homeserver: HomeserverInfo,
20
21 #[serde(rename = "m.identity_server")]
23 pub identity_server: Option<IdentityServerInfo>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct HomeserverInfo {
29 base_url: String,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct IdentityServerInfo {
36 base_url: String,
38}
39
40#[derive(Clone, Debug)]
42pub struct Resolver {
43 http: ClientWithMiddleware,
46}
47
48#[allow(dead_code)]
51#[derive(Deserialize)]
52struct Versions {
53 pub versions: Vec<String>,
55 #[serde(default)]
57 pub unstable_features: BTreeMap<String, bool>,
58}
59
60impl Resolver {
61 #[must_use]
63 pub fn new() -> Self {
64 Self {
65 http: reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
66 .with(cache())
67 .build(),
68 }
69 }
70
71 #[must_use]
73 pub fn with(http: reqwest::Client) -> Self {
74 Self { http: reqwest_middleware::ClientBuilder::new(http).with(cache()).build() }
75 }
76
77 pub async fn resolve(&self, name: &str) -> Result<Url, Error> {
79 #[cfg(not(test))]
80 let url = Url::parse(&format!("https://{}", name))?;
81 #[cfg(test)]
82 let url = Url::parse(&format!("http://{}", name))?;
83
84 let response = self.http.get(url.join(".well-known/matrix/client")?).send().await?;
86 if response.status() == StatusCode::NOT_FOUND {
88 return Ok(url);
89 };
90 let well_known = response.json::<ClientWellKnown>().await?;
92 let url = Url::parse(&well_known.homeserver.base_url)?;
94 self.http
96 .get(url.join("_matrix/client/versions")?)
97 .send()
98 .await
99 .map_err(FailError::Http)?
100 .json::<Versions>()
101 .await
102 .map_err(|e| FailError::Http(e.into()))?;
103
104 if let Some(identity) = well_known.identity_server {
106 let url = Url::parse(&identity.base_url)?;
107 let result: Result<_, FailError> = async {
108 self.http
109 .get(url.join("_matrix/identity/api/v1")?)
110 .send()
111 .await?
112 .error_for_status()?;
113 Ok(())
114 }
115 .await;
116 result?;
117 }
118
119 Ok(url)
120 }
121}
122
123impl Default for Resolver {
124 fn default() -> Self {
125 Self { http: ClientWithMiddleware::from(reqwest::Client::new()) }
126 }
127}
128
#[cfg(test)]
mod tests {
    use wiremock::{
        matchers::{method, path},
        Mock, MockServer, ResponseTemplate,
    };

    use super::Resolver;

    /// Builds a resolver whose HTTP client routes every host in `hosts` to
    /// `mock_server`.
    fn resolver_for(
        mock_server: &MockServer,
        hosts: &[&str],
    ) -> Result<Resolver, Box<dyn std::error::Error>> {
        let http = hosts
            .iter()
            .fold(reqwest::Client::builder(), |builder, host| {
                builder.resolve(host, *mock_server.address())
            })
            .build()?;
        Ok(Resolver::with(http))
    }

    /// Mounts a `/.well-known/matrix/client` mock answering 200 with `body`,
    /// expected to be hit exactly once.
    async fn mount_well_known(mock_server: &MockServer, body: String) {
        Mock::given(method("GET"))
            .and(path("/.well-known/matrix/client"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(body, "application/json"))
            .expect(1)
            .mount(mock_server)
            .await;
    }

    /// Mounts a valid `/_matrix/client/versions` mock, expected exactly once.
    async fn mount_versions(mock_server: &MockServer) {
        Mock::given(method("GET"))
            .and(path("/_matrix/client/versions"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_raw(r#"{"versions":["r0.0.1"]}"#, "application/json"),
            )
            .expect(1)
            .mount(mock_server)
            .await;
    }

    /// Well-known body delegating to `destination.test` and advertising
    /// `identity.test` as identity server, both on `port`.
    fn well_known_with_identity(port: u16) -> String {
        format!(
            r#"{{"m.homeserver": {{"base_url": "http://destination.test:{}"}}, "m.identity_server": {{"base_url": "http://identity.test:{}"}} }}"#,
            port, port
        )
    }

    #[tokio::test]
    async fn not_found() -> Result<(), Box<dyn std::error::Error>> {
        let mock_server = MockServer::start().await;
        let port = mock_server.address().port();

        Mock::given(method("GET"))
            .and(path("/.well-known/matrix/client"))
            .respond_with(ResponseTemplate::new(404))
            .expect(1)
            .mount(&mock_server)
            .await;

        let resolver = resolver_for(&mock_server, &["example.test"])?;
        let url = resolver.resolve(&format!("example.test:{}", port)).await?;

        assert_eq!(format!("http://example.test:{}/", port), url.to_string());
        Ok(())
    }

    #[tokio::test]
    async fn resolve() -> Result<(), Box<dyn std::error::Error>> {
        let mock_server = MockServer::start().await;
        let port = mock_server.address().port();

        mount_well_known(
            &mock_server,
            format!(r#"{{"m.homeserver": {{"base_url": "http://destination.test:{}"}} }}"#, port),
        )
        .await;
        mount_versions(&mock_server).await;

        let resolver = resolver_for(&mock_server, &["example.test", "destination.test"])?;
        let url = resolver.resolve(&format!("example.test:{}", port)).await?;

        assert_eq!(url.to_string(), format!("http://destination.test:{}/", port));
        Ok(())
    }

    #[tokio::test]
    async fn well_known_server_error() -> Result<(), Box<dyn std::error::Error>> {
        let mock_server = MockServer::start().await;
        let port = mock_server.address().port();

        Mock::given(method("GET"))
            .and(path("/.well-known/matrix/client"))
            .respond_with(ResponseTemplate::new(500))
            .expect(1)
            .mount(&mock_server)
            .await;

        let resolver = resolver_for(&mock_server, &["example.test"])?;
        let result = resolver.resolve(&format!("example.test:{}", port)).await;

        assert!(result.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn invalid_homeserver() -> Result<(), Box<dyn std::error::Error>> {
        let mock_server = MockServer::start().await;
        let port = mock_server.address().port();

        // No versions mock is mounted, so the homeserver check gets a 404.
        mount_well_known(
            &mock_server,
            format!(r#"{{"m.homeserver": {{"base_url": "http://destination.test:{}"}} }}"#, port),
        )
        .await;

        let resolver = resolver_for(&mock_server, &["example.test", "destination.test"])?;
        let result = resolver.resolve(&format!("example.test:{}", port)).await;

        assert!(result.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn identity_server() -> Result<(), Box<dyn std::error::Error>> {
        let mock_server = MockServer::start().await;
        let port = mock_server.address().port();

        mount_well_known(&mock_server, well_known_with_identity(port)).await;
        mount_versions(&mock_server).await;
        Mock::given(method("GET"))
            .and(path("/_matrix/identity/api/v1"))
            .respond_with(ResponseTemplate::new(200))
            .expect(1)
            .mount(&mock_server)
            .await;

        let resolver = resolver_for(
            &mock_server,
            &["example.test", "destination.test", "identity.test"],
        )?;
        let url = resolver.resolve(&format!("example.test:{}", port)).await?;

        assert_eq!(url.to_string(), format!("http://destination.test:{}/", port));
        Ok(())
    }

    #[tokio::test]
    async fn identity_server_unreachable() -> Result<(), Box<dyn std::error::Error>> {
        let mock_server = MockServer::start().await;
        let port = mock_server.address().port();

        // No identity mock is mounted, so the identity check gets a 404.
        mount_well_known(&mock_server, well_known_with_identity(port)).await;
        mount_versions(&mock_server).await;

        let resolver = resolver_for(
            &mock_server,
            &["example.test", "destination.test", "identity.test"],
        )?;
        let result = resolver.resolve(&format!("example.test:{}", port)).await;

        assert!(result.is_err());
        Ok(())
    }
}