matrix_sdk/client/builder/
homeserver_config.rs

1// Copyright 2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use ruma::{
16    OwnedServerName, ServerName,
17    api::client::discovery::{discover_homeserver, get_supported_versions},
18};
19use tracing::debug;
20use url::Url;
21
22use crate::{
23    ClientBuildError, HttpError, config::RequestConfig, http_client::HttpClient,
24    sanitize_server_name,
25};
26
/// Configuration for the homeserver.
///
/// Each variant is a different way of pointing the client at a homeserver;
/// [`HomeserverConfig::discover`] resolves it into concrete URLs.
#[derive(Clone, Debug)]
pub(super) enum HomeserverConfig {
    /// A homeserver URL, including the protocol.
    ///
    /// It is only parsed when discovering, no request is sent.
    HomeserverUrl(String),

    /// A server name, with the protocol kept separately.
    ///
    /// The homeserver is looked up with well-known discovery on that server.
    ServerName { server: OwnedServerName, protocol: UrlScheme },

    /// A server name with or without the protocol (it will fall back to
    /// `https` if absent), or a homeserver URL.
    ServerNameOrHomeserverUrl(String),
}
40
/// A simple helper to represent `http` or `https` in a URL.
#[derive(Clone, Copy, Debug)]
pub(super) enum UrlScheme {
    /// The `http` scheme.
    Http,
    /// The `https` scheme.
    Https,
}
47
/// The `Ok` result for `HomeserverConfig::discover`.
pub(super) struct HomeserverDiscoveryResult {
    /// The URL of the server that well-known discovery was run against, or
    /// `None` when the homeserver was not found through well-known discovery.
    pub server: Option<Url>,
    /// The URL of the homeserver.
    pub homeserver: Url,
    /// The supported versions response, present only when `discover` had to
    /// request it to check that a URL points to a homeserver.
    pub supported_versions: Option<get_supported_versions::Response>,
    /// The well-known response, present only when well-known discovery
    /// succeeded.
    pub well_known: Option<discover_homeserver::Response>,
}
55
56impl HomeserverConfig {
57    pub async fn discover(
58        &self,
59        http_client: &HttpClient,
60    ) -> Result<HomeserverDiscoveryResult, ClientBuildError> {
61        Ok(match self {
62            Self::HomeserverUrl(url) => {
63                let homeserver = Url::parse(url)?;
64
65                HomeserverDiscoveryResult {
66                    server: None, // We can't know the `server` if we only have a `homeserver`.
67                    homeserver,
68                    supported_versions: None,
69                    well_known: None,
70                }
71            }
72
73            Self::ServerName { server, protocol } => {
74                let (server, well_known) =
75                    discover_homeserver(server, protocol, http_client).await?;
76
77                HomeserverDiscoveryResult {
78                    server: Some(server),
79                    homeserver: Url::parse(&well_known.homeserver.base_url)?,
80                    supported_versions: None,
81                    well_known: Some(well_known),
82                }
83            }
84
85            Self::ServerNameOrHomeserverUrl(server_name_or_url) => {
86                let (server, homeserver, supported_versions, well_known) =
87                    discover_homeserver_from_server_name_or_url(
88                        server_name_or_url.to_owned(),
89                        http_client,
90                    )
91                    .await?;
92
93                HomeserverDiscoveryResult { server, homeserver, supported_versions, well_known }
94            }
95        })
96    }
97}
98
99/// Discovers a homeserver from a server name or a URL.
100///
101/// Tries well-known discovery and checking if the URL points to a homeserver.
102async fn discover_homeserver_from_server_name_or_url(
103    mut server_name_or_url: String,
104    http_client: &HttpClient,
105) -> Result<
106    (
107        Option<Url>,
108        Url,
109        Option<get_supported_versions::Response>,
110        Option<discover_homeserver::Response>,
111    ),
112    ClientBuildError,
113> {
114    let mut discovery_error: Option<ClientBuildError> = None;
115
116    // Attempt discovery as a server name first.
117    let sanitize_result = sanitize_server_name(&server_name_or_url);
118
119    if let Ok(server_name) = sanitize_result.as_ref() {
120        let protocol = if server_name_or_url.starts_with("http://") {
121            UrlScheme::Http
122        } else {
123            UrlScheme::Https
124        };
125
126        match discover_homeserver(server_name, &protocol, http_client).await {
127            Ok((server, well_known)) => {
128                return Ok((
129                    Some(server),
130                    Url::parse(&well_known.homeserver.base_url)?,
131                    None,
132                    Some(well_known),
133                ));
134            }
135            Err(e) => {
136                debug!(error = %e, "Well-known discovery failed.");
137                discovery_error = Some(e);
138
139                // Check if the server name points to a homeserver.
140                server_name_or_url = match protocol {
141                    UrlScheme::Http => format!("http://{server_name}"),
142                    UrlScheme::Https => format!("https://{server_name}"),
143                }
144            }
145        }
146    }
147
148    // When discovery fails, or the input isn't a valid server name, fallback to
149    // trying a homeserver URL.
150    if let Ok(homeserver_url) = Url::parse(&server_name_or_url) {
151        // Make sure the URL is definitely for a homeserver.
152        match get_supported_versions(&homeserver_url, http_client).await {
153            Ok(response) => {
154                return Ok((None, homeserver_url, Some(response), None));
155            }
156            Err(e) => {
157                debug!(error = %e, "Checking supported versions failed.");
158            }
159        }
160    }
161
162    Err(discovery_error.unwrap_or(ClientBuildError::InvalidServerName))
163}
164
165/// Discovers a homeserver by looking up the well-known at the supplied server
166/// name.
167async fn discover_homeserver(
168    server_name: &ServerName,
169    protocol: &UrlScheme,
170    http_client: &HttpClient,
171) -> Result<(Url, discover_homeserver::Response), ClientBuildError> {
172    debug!("Trying to discover the homeserver");
173
174    let server = Url::parse(&match protocol {
175        UrlScheme::Http => format!("http://{server_name}"),
176        UrlScheme::Https => format!("https://{server_name}"),
177    })?;
178
179    let well_known = http_client
180        .send(
181            discover_homeserver::Request::new(),
182            Some(RequestConfig::short_retry()),
183            server.to_string(),
184            None,
185            (),
186            Default::default(),
187        )
188        .await
189        .map_err(|e| match e {
190            HttpError::Api(err) => ClientBuildError::AutoDiscovery(*err),
191            err => ClientBuildError::Http(err),
192        })?;
193
194    debug!(homeserver_url = well_known.homeserver.base_url, "Discovered the homeserver");
195
196    Ok((server, well_known))
197}
198
199pub(super) async fn get_supported_versions(
200    homeserver_url: &Url,
201    http_client: &HttpClient,
202) -> Result<get_supported_versions::Response, HttpError> {
203    http_client
204        .send(
205            get_supported_versions::Request::new(),
206            Some(RequestConfig::short_retry()),
207            homeserver_url.to_string(),
208            None,
209            (),
210            Default::default(),
211        )
212        .await
213}
214
#[cfg(all(test, not(target_family = "wasm")))]
mod tests {
    use matrix_sdk_test::async_test;
    use ruma::OwnedServerName;
    use serde_json::json;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;
    use crate::http_client::HttpSettings;

    /// Builds an [`HttpClient`] with the default settings.
    fn default_http_client() -> HttpClient {
        let inner = HttpSettings::default().make_client().unwrap();
        HttpClient::new(inner, Default::default())
    }

    /// Makes `server` answer well-known requests by pointing at `homeserver`.
    async fn mount_well_known(server: &MockServer, homeserver: &MockServer) {
        let body = json!({
            "m.homeserver": {
                "base_url": homeserver.uri(),
            },
        });

        Mock::given(method("GET"))
            .and(path("/.well-known/matrix/client"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(server)
            .await;
    }

    /// A plain homeserver URL is used as-is.
    #[async_test]
    async fn test_url() {
        let http_client = default_http_client();
        let config =
            HomeserverConfig::HomeserverUrl("https://matrix-client.matrix.org".to_owned());

        let result = config.discover(&http_client).await.unwrap();

        assert_eq!(result.server, None);
        assert_eq!(result.homeserver, Url::parse("https://matrix-client.matrix.org").unwrap());
        assert!(result.supported_versions.is_none());
    }

    /// A server name is resolved through its well-known response.
    #[async_test]
    async fn test_server_name() {
        let http_client = default_http_client();

        let server = MockServer::start().await;
        let homeserver = MockServer::start().await;
        mount_well_known(&server, &homeserver).await;

        let server_name = OwnedServerName::try_from(server.address().to_string()).unwrap();
        let config =
            HomeserverConfig::ServerName { server: server_name, protocol: UrlScheme::Http };

        let result = config.discover(&http_client).await.unwrap();

        assert_eq!(result.server, Some(Url::parse(&server.uri()).unwrap()));
        assert_eq!(result.homeserver, Url::parse(&homeserver.uri()).unwrap());
        assert!(result.supported_versions.is_none());
    }

    /// An ambiguous input that serves a well-known response is handled as a
    /// server name.
    #[async_test]
    async fn test_server_name_or_url_with_name() {
        let http_client = default_http_client();

        let server = MockServer::start().await;
        let homeserver = MockServer::start().await;
        mount_well_known(&server, &homeserver).await;

        let config = HomeserverConfig::ServerNameOrHomeserverUrl(server.uri().to_string());

        let result = config.discover(&http_client).await.unwrap();

        assert_eq!(result.server, Some(Url::parse(&server.uri()).unwrap()));
        assert_eq!(result.homeserver, Url::parse(&homeserver.uri()).unwrap());
        assert!(result.supported_versions.is_none());
    }

    /// An ambiguous input without a well-known response falls back to the
    /// supported versions check.
    #[async_test]
    async fn test_server_name_or_url_with_url() {
        let http_client = default_http_client();

        let homeserver = MockServer::start().await;

        let versions_body = json!({
            "versions": [],
        });
        Mock::given(method("GET"))
            .and(path("/_matrix/client/versions"))
            .respond_with(ResponseTemplate::new(200).set_body_json(versions_body))
            .mount(&homeserver)
            .await;

        let config = HomeserverConfig::ServerNameOrHomeserverUrl(homeserver.uri().to_string());

        let result = config.discover(&http_client).await.unwrap();

        assert!(result.server.is_none());
        assert_eq!(result.homeserver, Url::parse(&homeserver.uri()).unwrap());
        assert!(result.supported_versions.is_some());
    }
}