Skip to main content

matrix_sdk/authentication/oauth/qrcode/rendezvous_channel/
msc_4108.rs

1// Copyright 2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::time::Duration;
16
17use http::{
18    HeaderMap, HeaderName, Method, StatusCode,
19    header::{CONTENT_TYPE, ETAG, EXPIRES, IF_MATCH, IF_NONE_MATCH, LAST_MODIFIED},
20};
21use matrix_sdk_base::sleep;
22use ruma::api::{
23    EndpointError,
24    error::{FromHttpResponseError, HeaderDeserializationError, IntoHttpError},
25};
26use tracing::{debug, instrument, trace};
27use url::Url;
28
29use crate::{HttpError, RumaApiError, http_client::HttpClient};
30
/// The content type used for messages exchanged over the rendezvous channel.
///
/// It is sent as the `Content-Type` of outgoing messages and assumed for
/// incoming responses that don't carry a `Content-Type` header.
const TEXT_PLAIN_CONTENT_TYPE: &str = "text/plain";

/// How long `Channel::receive()` sleeps between two polls of the rendezvous
/// session. The delay is much shorter in test builds.
const POLL_TIMEOUT: Duration =
    if cfg!(test) { Duration::from_millis(10) } else { Duration::from_secs(1) };

/// The value of an `ETag` header, kept as an opaque string.
type Etag = String;
38
39/// Get a header from a [`HeaderMap`] and parse it as a UTF-8 string.
40fn get_header(
41    header_map: &HeaderMap,
42    header_name: &HeaderName,
43) -> Result<String, Box<FromHttpResponseError<RumaApiError>>> {
44    let header = header_map
45        .get(header_name)
46        .ok_or(HeaderDeserializationError::MissingHeader(ETAG.to_string()))
47        .map_err(|error| Box::new(FromHttpResponseError::from(error)))?;
48
49    let header =
50        header.to_str().map_err(|error| Box::new(FromHttpResponseError::from(error)))?.to_owned();
51
52    Ok(header)
53}
54
55fn response_to_error(status: StatusCode, body: Vec<u8>) -> HttpError {
56    match http::Response::builder().status(status).body(body).map_err(IntoHttpError::from) {
57        Ok(response) => {
58            let error = FromHttpResponseError::<RumaApiError>::Server(RumaApiError::ClientApi(
59                ruma::api::error::Error::from_http_response(response),
60            ));
61
62            error.into()
63        }
64        Err(e) => HttpError::IntoHttp(e),
65    }
66}
67
68/// The result of the [`RendezvousChannel::create_inbound()`] method.
69pub(super) struct InboundChannelCreationResult {
70    /// The connected [`RendezvousChannel`].
71    pub channel: Channel,
72    /// The initial message we received when we connected to the
73    /// [`RendezvousChannel`].
74    ///
75    /// This is currently unused, but left in for completeness sake.
76    #[allow(dead_code)]
77    pub initial_message: Vec<u8>,
78}
79
80pub(super) struct RendezvousGetResponse {
81    pub status_code: StatusCode,
82    pub etag: String,
83    #[allow(dead_code)]
84    pub expires: String,
85    #[allow(dead_code)]
86    pub last_modified: String,
87    pub content_type: Option<String>,
88    pub body: Vec<u8>,
89}
90
91pub(super) struct RendezvousMessage {
92    pub status_code: StatusCode,
93    pub body: Vec<u8>,
94    pub content_type: String,
95}
96
97pub struct Channel {
98    client: HttpClient,
99    rendezvous_url: Url,
100    etag: Etag,
101}
102
103impl Channel {
104    /// Create a new outbound [`RendezvousChannel`].
105    ///
106    /// By outbound we mean that we're going to tell the Matrix server to create
107    /// a new rendezvous session. We're going to send an initial empty message
108    /// through the channel.
109    pub(super) async fn create_outbound(
110        client: HttpClient,
111        rendezvous_server: &Url,
112    ) -> Result<Self, HttpError> {
113        use std::borrow::Cow;
114
115        use ruma::api::{SupportedVersions, client::rendezvous::create_rendezvous_session};
116
117        let request = create_rendezvous_session::unstable_msc4108::Request::default();
118        let response = client
119            .send(
120                request,
121                None,
122                rendezvous_server.to_string(),
123                None,
124                Cow::Owned(SupportedVersions {
125                    versions: Default::default(),
126                    features: Default::default(),
127                }),
128                Default::default(),
129            )
130            .await?;
131
132        let rendezvous_url = response.url;
133        let etag = response.etag;
134
135        Ok(Self { client, rendezvous_url, etag })
136    }
137
138    /// Create a new inbound [`RendezvousChannel`].
139    ///
140    /// By inbound we mean that we're going to attempt to read an initial
141    /// message from the rendezvous session on the given [`rendezvous_url`].
142    pub(super) async fn create_inbound(
143        client: HttpClient,
144        rendezvous_url: &Url,
145    ) -> Result<InboundChannelCreationResult, HttpError> {
146        // Receive the initial message, which should be empty. But we need the ETAG to
147        // fully establish the rendezvous channel.
148        let response = Self::receive_message_impl(&client.inner, None, rendezvous_url).await?;
149
150        let etag = response.etag.clone();
151
152        let initial_message = RendezvousMessage {
153            status_code: response.status_code,
154            body: response.body,
155            content_type: response.content_type.unwrap_or_else(|| "text/plain".to_owned()),
156        };
157
158        let channel = Self { client, rendezvous_url: rendezvous_url.clone(), etag };
159
160        Ok(InboundChannelCreationResult { channel, initial_message: initial_message.body })
161    }
162
163    /// Get the URL of the rendezvous session we're using to exchange messages
164    /// through the channel.
165    pub(super) fn rendezvous_url(&self) -> &Url {
166        &self.rendezvous_url
167    }
168
169    /// Send the given `message` through the [`RendezvousChannel`] to the other
170    /// device.
171    ///
172    /// The message must be of the `text/plain` content type.
173    #[instrument(skip_all)]
174    pub(super) async fn send(&mut self, message: Vec<u8>) -> Result<(), HttpError> {
175        let etag = self.etag.clone();
176
177        let request = self
178            .client
179            .inner
180            .request(Method::PUT, self.rendezvous_url().to_owned())
181            .body(message)
182            .header(IF_MATCH, etag)
183            .header(CONTENT_TYPE, TEXT_PLAIN_CONTENT_TYPE);
184
185        debug!("Sending a request to the rendezvous channel {request:?}");
186
187        let response = request.send().await?;
188        let status = response.status();
189
190        debug!("Response for the rendezvous sending request {response:?}");
191
192        if status.is_success() {
193            // We successfully send out a message, get the ETAG and update our internal copy
194            // of the ETAG.
195            let etag = get_header(response.headers(), &ETAG)?;
196            self.etag = etag;
197
198            Ok(())
199        } else {
200            let body = response.bytes().await?;
201            let error = response_to_error(status, body.to_vec());
202
203            return Err(error);
204        }
205    }
206
207    /// Attempt to receive a message from the [`RendezvousChannel`] from the
208    /// other device.
209    ///
210    /// The content should be of the `text/plain` content type but the parsing
211    /// and verification of this fact is left up to the caller.
212    ///
213    /// This method will wait in a loop for the channel to give us a new
214    /// message.
215    pub(super) async fn receive(&mut self) -> Result<Vec<u8>, HttpError> {
216        loop {
217            let message = self.receive_single_message().await?;
218
219            trace!(
220                status_code = %message.status_code,
221                "Received data from the rendezvous channel"
222            );
223
224            if message.status_code == StatusCode::OK
225                && message.content_type == TEXT_PLAIN_CONTENT_TYPE
226                && !message.body.is_empty()
227            {
228                return Ok(message.body);
229            } else if message.status_code == StatusCode::NOT_MODIFIED {
230                sleep::sleep(POLL_TIMEOUT).await;
231                continue;
232            } else {
233                let error = response_to_error(message.status_code, message.body);
234
235                return Err(error);
236            }
237        }
238    }
239
240    #[instrument]
241    async fn receive_message_impl(
242        client: &reqwest::Client,
243        etag: Option<String>,
244        rendezvous_url: &Url,
245    ) -> Result<RendezvousGetResponse, HttpError> {
246        let mut builder = client.request(Method::GET, rendezvous_url.to_owned());
247
248        if let Some(etag) = etag {
249            builder = builder.header(IF_NONE_MATCH, etag);
250        }
251
252        let response = builder.send().await?;
253
254        debug!("Received data from the rendezvous channel {response:?}");
255
256        let status_code = response.status();
257
258        if status_code.is_client_error() {
259            return Err(response_to_error(status_code, response.bytes().await?.to_vec()));
260        }
261
262        let headers = response.headers();
263        let etag = get_header(headers, &ETAG)?;
264        let expires = get_header(headers, &EXPIRES)?;
265        let last_modified = get_header(headers, &LAST_MODIFIED)?;
266        #[allow(clippy::result_large_err)]
267        let content_type = headers
268            .get(CONTENT_TYPE)
269            .map(|c| c.to_str().map_err(FromHttpResponseError::<RumaApiError>::from))
270            .transpose()?
271            .map(ToOwned::to_owned);
272
273        let body = response.bytes().await?.to_vec();
274
275        let response =
276            RendezvousGetResponse { status_code, etag, expires, last_modified, content_type, body };
277
278        Ok(response)
279    }
280
281    async fn receive_single_message(&mut self) -> Result<RendezvousMessage, HttpError> {
282        let etag = Some(self.etag.clone());
283
284        let RendezvousGetResponse { status_code, etag, content_type, body, .. } =
285            Self::receive_message_impl(&self.client.inner, etag, &self.rendezvous_url).await?;
286
287        // We received a response with an ETAG, put it into the copy of our etag.
288        self.etag = etag;
289
290        let message = RendezvousMessage {
291            status_code,
292            body,
293            content_type: content_type.unwrap_or_else(|| "text/plain".to_owned()),
294        };
295
296        Ok(message)
297    }
298}
299
300#[cfg(all(test, not(target_family = "wasm")))]
301mod test {
302    use matrix_sdk_test::async_test;
303    use serde_json::json;
304    use similar_asserts::assert_eq;
305    use wiremock::{
306        Mock, MockServer, ResponseTemplate,
307        matchers::{header, method, path},
308    };
309
310    use super::*;
311    use crate::config::RequestConfig;
312
313    async fn mock_rendzvous_create(server: &MockServer, rendezvous_url: &Url) {
314        server
315            .register(
316                Mock::given(method("POST"))
317                    .and(path("/_matrix/client/unstable/org.matrix.msc4108/rendezvous"))
318                    .respond_with(
319                        ResponseTemplate::new(200)
320                            .append_header("X-Max-Bytes", "10240")
321                            .append_header("ETag", "1")
322                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
323                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
324                            .set_body_json(json!({
325                                "url": rendezvous_url,
326                            })),
327                    ),
328            )
329            .await;
330    }
331
332    #[async_test]
333    async fn test_creation() {
334        let server = MockServer::start().await;
335        let url =
336            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
337        let rendezvous_url =
338            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
339
340        mock_rendzvous_create(&server, &rendezvous_url).await;
341
342        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());
343
344        let mut alice = Channel::create_outbound(client, &url)
345            .await
346            .expect("We should be able to create an outbound rendezvous channel");
347
348        assert_eq!(
349            alice.rendezvous_url(),
350            &rendezvous_url,
351            "Alice should have configured the rendezvous URL correctly."
352        );
353
354        assert_eq!(alice.etag, "1", "Alice should have remembered the ETAG the server gave us.");
355
356        let mut bob = {
357            let _scope = server
358                .register_as_scoped(
359                    Mock::given(method("GET")).and(path("/abcdEFG12345")).respond_with(
360                        ResponseTemplate::new(200)
361                            .append_header("Content-Type", "text/plain")
362                            .append_header("ETag", "2")
363                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
364                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
365                    ),
366                )
367                .await;
368
369            let client = HttpClient::new(reqwest::Client::new(), RequestConfig::short_retry());
370            let InboundChannelCreationResult { channel: bob, initial_message: _ } =
371                Channel::create_inbound(client, &rendezvous_url).await.expect(
372                    "We should be able to create a rendezvous channel from a received message",
373                );
374
375            assert_eq!(alice.rendezvous_url(), bob.rendezvous_url());
376
377            bob
378        };
379
380        assert_eq!(bob.etag, "2", "Bob should have remembered the ETAG the server gave us.");
381
382        {
383            let _scope = server
384                .register_as_scoped(
385                    Mock::given(method("GET"))
386                        .and(path("/abcdEFG12345"))
387                        .and(header("if-none-match", "1"))
388                        .respond_with(
389                            ResponseTemplate::new(304)
390                                .append_header("ETag", "1")
391                                .append_header("Content-Type", "text/plain")
392                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
393                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
394                        ),
395                )
396                .await;
397
398            let response = alice
399                .receive_single_message()
400                .await
401                .expect("We should be able to wait for data on the rendezvous channel.");
402            assert_eq!(response.status_code, StatusCode::NOT_MODIFIED);
403        }
404
405        {
406            let _scope = server
407                .register_as_scoped(
408                    Mock::given(method("PUT"))
409                        .and(path("/abcdEFG12345"))
410                        .and(header("Content-Type", "text/plain"))
411                        .respond_with(
412                            ResponseTemplate::new(200)
413                                .append_header("ETag", "1")
414                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
415                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
416                        ),
417                )
418                .await;
419
420            bob.send(b"Hello world".to_vec())
421                .await
422                .expect("We should be able to send data to the rendezouvs server.");
423        }
424
425        {
426            let _scope = server
427                .register_as_scoped(
428                    Mock::given(method("GET"))
429                        .and(path("/abcdEFG12345"))
430                        .and(header("if-none-match", "1"))
431                        .respond_with(
432                            ResponseTemplate::new(200)
433                                .append_header("ETag", "3")
434                                .append_header("Content-Type", "text/plain")
435                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
436                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
437                                .set_body_string("Hello world"),
438                        ),
439                )
440                .await;
441
442            let response = alice
443                .receive_single_message()
444                .await
445                .expect("We should be able to wait and get data on the rendezvous channel.");
446
447            assert_eq!(response.status_code, StatusCode::OK);
448            assert_eq!(response.body, b"Hello world");
449            assert_eq!(response.content_type, TEXT_PLAIN_CONTENT_TYPE);
450        }
451    }
452
453    #[async_test]
454    async fn test_retry_mechanism() {
455        let server = MockServer::start().await;
456        let url =
457            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
458        let rendezvous_url =
459            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
460        mock_rendzvous_create(&server, &rendezvous_url).await;
461
462        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());
463
464        let mut alice = Channel::create_outbound(client, &url)
465            .await
466            .expect("We should be able to create an outbound rendezvous channel");
467
468        server
469            .register(
470                Mock::given(method("GET"))
471                    .and(path("/abcdEFG12345"))
472                    .and(header("if-none-match", "1"))
473                    .respond_with(
474                        ResponseTemplate::new(304)
475                            .append_header("ETag", "2")
476                            .append_header("Content-Type", "text/plain")
477                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
478                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
479                            .set_body_string(""),
480                    )
481                    .expect(1),
482            )
483            .await;
484
485        server
486            .register(
487                Mock::given(method("GET"))
488                    .and(path("/abcdEFG12345"))
489                    .and(header("if-none-match", "2"))
490                    .respond_with(
491                        ResponseTemplate::new(200)
492                            .append_header("ETag", "3")
493                            .append_header("Content-Type", "text/plain")
494                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
495                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
496                            .set_body_string("Hello world"),
497                    )
498                    .expect(1),
499            )
500            .await;
501
502        let response = alice
503            .receive()
504            .await
505            .expect("We should be able to wait and get data on the rendezvous channel.");
506
507        assert_eq!(response, b"Hello world");
508    }
509
510    #[async_test]
511    async fn test_receive_error() {
512        let server = MockServer::start().await;
513        let url =
514            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
515        let rendezvous_url =
516            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
517        mock_rendzvous_create(&server, &rendezvous_url).await;
518
519        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());
520
521        let mut alice = Channel::create_outbound(client, &url)
522            .await
523            .expect("We should be able to create an outbound rendezvous channel");
524
525        {
526            let _scope = server
527                .register_as_scoped(
528                    Mock::given(method("GET"))
529                        .and(path("/abcdEFG12345"))
530                        .and(header("if-none-match", "1"))
531                        .respond_with(
532                            ResponseTemplate::new(404)
533                                .append_header("ETag", "1")
534                                .append_header("Content-Type", "text/plain")
535                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
536                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
537                                .set_body_string(""),
538                        )
539                        .expect(1),
540                )
541                .await;
542
543            alice.receive().await.expect_err("We should return an error if we receive a 404");
544        }
545
546        {
547            let _scope = server
548                .register_as_scoped(
549                    Mock::given(method("GET"))
550                        .and(path("/abcdEFG12345"))
551                        .and(header("if-none-match", "1"))
552                        .respond_with(
553                            ResponseTemplate::new(504)
554                                .append_header("ETag", "1")
555                                .append_header("Content-Type", "text/plain")
556                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
557                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
558                                .set_body_json(json!({
559                                  "errcode": "M_NOT_FOUND",
560                                  "error": "No resource was found for this request.",
561                                })),
562                        )
563                        .expect(1),
564                )
565                .await;
566
567            alice
568                .receive()
569                .await
570                .expect_err("We should return an error if we receive a gateway timeout");
571        }
572    }
573}