matrix_oracle/
client.rs

1//! Resolution for the client-server API
2
3pub mod error;
4
5use std::collections::BTreeMap;
6
7use reqwest::{StatusCode, Url};
8use reqwest_middleware::ClientWithMiddleware;
9use serde::{Deserialize, Serialize};
10
11use self::error::{Error, FailError};
12use crate::cache;
13
14/// well-known information for the client-server API.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ClientWellKnown {
17	/// Information about the homeserver to connect to.
18	#[serde(rename = "m.homeserver")]
19	pub homeserver: HomeserverInfo,
20
21	/// Information about the identity server to connect to.
22	#[serde(rename = "m.identity_server")]
23	pub identity_server: Option<IdentityServerInfo>,
24}
25
26/// Information about the homeserver to connect to.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct HomeserverInfo {
29	/// The base url to use for client-server API endpoints.
30	base_url: String,
31}
32
33/// Information about the identity server to connect to.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct IdentityServerInfo {
36	/// The base url to use for identity server API endpoints.
37	base_url: String,
38}
39
40/// Resolver for well-known lookups for the client-server API.
41#[derive(Clone, Debug)]
42pub struct Resolver {
43	/// The HTTP client used to send and receive requests. Should transparently
44	/// handle HTTP caching.
45	http: ClientWithMiddleware,
46}
47
48/// Represents the set of matrix versions a server support. Used exclusively for
49/// validating the contents of a response
50#[allow(dead_code)]
51#[derive(Deserialize)]
52struct Versions {
53	/// List of matrix spec versions the server supports.
54	pub versions: Vec<String>,
55	/// Set of unstable matrix extensions which the server supports
56	#[serde(default)]
57	pub unstable_features: BTreeMap<String, bool>,
58}
59
60impl Resolver {
61	/// Construct a new resolver.
62	#[must_use]
63	pub fn new() -> Self {
64		Self {
65			http: reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
66				.with(cache())
67				.build(),
68		}
69	}
70
71	/// Construct a new resolver with the given reqwest client.
72	#[must_use]
73	pub fn with(http: reqwest::Client) -> Self {
74		Self { http: reqwest_middleware::ClientBuilder::new(http).with(cache()).build() }
75	}
76
77	/// Get the base URL for the client-server API with the given name.
78	pub async fn resolve(&self, name: &str) -> Result<Url, Error> {
79		#[cfg(not(test))]
80		let url = Url::parse(&format!("https://{}", name))?;
81		#[cfg(test)]
82		let url = Url::parse(&format!("http://{}", name))?;
83
84		// 3. make a GET request to the well-known endpoint
85		let response = self.http.get(url.join(".well-known/matrix/client")?).send().await?;
86		// a. if the returned status code is 404, then IGNORE
87		if response.status() == StatusCode::NOT_FOUND {
88			return Ok(url);
89		};
90		// c. parse the response as json
91		let well_known = response.json::<ClientWellKnown>().await?;
92		// d+e.i Extract base_url and parse it as a URL
93		let url = Url::parse(&well_known.homeserver.base_url)?;
94		// e.ii Validate versions endpoint
95		self.http
96			.get(url.join("_matrix/client/versions")?)
97			.send()
98			.await
99			.map_err(FailError::Http)?
100			.json::<Versions>()
101			.await
102			.map_err(|e| FailError::Http(e.into()))?;
103
104		// f. if present, validate identity server endpoint
105		if let Some(identity) = well_known.identity_server {
106			let url = Url::parse(&identity.base_url)?;
107			let result: Result<_, FailError> = async {
108				self.http
109					.get(url.join("_matrix/identity/api/v1")?)
110					.send()
111					.await?
112					.error_for_status()?;
113				Ok(())
114			}
115			.await;
116			result?;
117		}
118
119		Ok(url)
120	}
121}
122
123impl Default for Resolver {
124	fn default() -> Self {
125		Self { http: ClientWithMiddleware::from(reqwest::Client::new()) }
126	}
127}
128
#[cfg(test)]
mod tests {
	use wiremock::{
		matchers::{method, path},
		Mock, MockServer, ResponseTemplate,
	};

	use super::Resolver;

	/// A 404 from the well-known endpoint makes resolution fall back to the URL
	/// derived from the server name.
	#[tokio::test]
	async fn not_found() -> Result<(), Box<dyn std::error::Error>> {
		let server = MockServer::start().await;
		let addr = *server.address();
		let port = addr.port();

		let well_known = Mock::given(method("GET"))
			.and(path("/.well-known/matrix/client"))
			.respond_with(ResponseTemplate::new(404))
			.expect(1);
		well_known.mount(&server).await;

		// Point the fake host name at the mock server.
		let client = reqwest::Client::builder().resolve("example.test", addr).build()?;
		let resolver = Resolver::with(client);
		let resolved = resolver.resolve(&format!("example.test:{}", port)).await?;

		assert_eq!(format!("http://example.test:{}/", port), resolved.to_string());
		Ok(())
	}

	/// A well-known document pointing at another host resolves to that host
	/// once its versions endpoint has answered.
	#[tokio::test]
	async fn resolve() -> Result<(), Box<dyn std::error::Error>> {
		let server = MockServer::start().await;
		let addr = *server.address();
		let port = addr.port();

		let well_known_body = format!(
			r#"{{"m.homeserver": {{"base_url": "http://destination.test:{}"}} }}"#,
			port
		);
		let well_known = Mock::given(method("GET"))
			.and(path("/.well-known/matrix/client"))
			.respond_with(
				ResponseTemplate::new(200).set_body_raw(well_known_body, "application/json"),
			)
			.expect(1);
		well_known.mount(&server).await;

		let versions = Mock::given(method("GET"))
			.and(path("/_matrix/client/versions"))
			.respond_with(
				ResponseTemplate::new(200)
					.set_body_raw(r#"{"versions":["r0.0.1"]}"#, "application/json"),
			)
			.expect(1);
		versions.mount(&server).await;

		// Both fake host names are served by the same mock server.
		let client = reqwest::Client::builder()
			.resolve("example.test", addr)
			.resolve("destination.test", addr)
			.build()?;
		let resolver = Resolver::with(client);

		let resolved = resolver.resolve(&format!("example.test:{}", port)).await?;

		assert_eq!(resolved.to_string(), format!("http://destination.test:{}/", port));
		Ok(())
	}
}