matrix_sdk/authentication/oauth/qrcode/rendezvous_channel.rs

1// Copyright 2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::time::Duration;
16
17use http::{
18    HeaderMap, HeaderName, Method, StatusCode,
19    header::{CONTENT_TYPE, ETAG, EXPIRES, IF_MATCH, IF_NONE_MATCH, LAST_MODIFIED},
20};
21use matrix_sdk_base::sleep;
22use ruma::api::{
23    EndpointError,
24    error::{FromHttpResponseError, HeaderDeserializationError, IntoHttpError},
25};
26use tracing::{debug, instrument, trace};
27use url::Url;
28
29use crate::{HttpError, RumaApiError, http_client::HttpClient};
30
/// The content type used for every message we put into, and expect to read
/// from, the rendezvous channel.
const TEXT_PLAIN_CONTENT_TYPE: &str = "text/plain";

/// How long to wait before polling the rendezvous session again after it
/// answered `304 Not Modified`, i.e. had nothing new for us.
///
/// Test builds use a much shorter interval so the test suite doesn't spend
/// its time sleeping.
#[cfg(not(test))]
const POLL_TIMEOUT: Duration = Duration::from_millis(1_000);
#[cfg(test)]
const POLL_TIMEOUT: Duration = Duration::from_millis(10);

/// The value of an `ETag` header handed out by the rendezvous server.
type Etag = String;
38
39/// Get a header from a [`HeaderMap`] and parse it as a UTF-8 string.
40fn get_header(
41    header_map: &HeaderMap,
42    header_name: &HeaderName,
43) -> Result<String, Box<FromHttpResponseError<RumaApiError>>> {
44    let header = header_map
45        .get(header_name)
46        .ok_or(HeaderDeserializationError::MissingHeader(ETAG.to_string()))
47        .map_err(|error| Box::new(FromHttpResponseError::from(error)))?;
48
49    let header =
50        header.to_str().map_err(|error| Box::new(FromHttpResponseError::from(error)))?.to_owned();
51
52    Ok(header)
53}
54
55/// The result of the [`RendezvousChannel::create_inbound()`] method.
56pub(super) struct InboundChannelCreationResult {
57    /// The connected [`RendezvousChannel`].
58    pub channel: RendezvousChannel,
59    /// The initial message we received when we connected to the
60    /// [`RendezvousChannel`].
61    ///
62    /// This is currently unused, but left in for completeness sake.
63    #[allow(dead_code)]
64    pub initial_message: Vec<u8>,
65}
66
67struct RendezvousGetResponse {
68    pub status_code: StatusCode,
69    pub etag: String,
70    // TODO: This is currently unused, but will be required once we implement the reciprocation of
71    // a login. Left here so we don't forget about it. We should put this into the
72    // [`RendezvousChannel`] struct, once we parse it into a [`SystemTime`].
73    #[allow(dead_code)]
74    pub expires: String,
75    #[allow(dead_code)]
76    pub last_modified: String,
77    pub content_type: Option<String>,
78    pub body: Vec<u8>,
79}
80
81struct RendezvousMessage {
82    pub status_code: StatusCode,
83    pub body: Vec<u8>,
84    pub content_type: String,
85}
86
87pub(super) struct RendezvousChannel {
88    client: HttpClient,
89    rendezvous_url: Url,
90    etag: Etag,
91}
92
93fn response_to_error(status: StatusCode, body: Vec<u8>) -> HttpError {
94    match http::Response::builder().status(status).body(body).map_err(IntoHttpError::from) {
95        Ok(response) => {
96            let error = FromHttpResponseError::<RumaApiError>::Server(RumaApiError::ClientApi(
97                ruma::api::client::Error::from_http_response(response),
98            ));
99
100            error.into()
101        }
102        Err(e) => HttpError::IntoHttp(e),
103    }
104}
105
106impl RendezvousChannel {
107    /// Create a new outbound [`RendezvousChannel`].
108    ///
109    /// By outbound we mean that we're going to tell the Matrix server to create
110    /// a new rendezvous session. We're going to send an initial empty message
111    /// through the channel.
112    pub(super) async fn create_outbound(
113        client: HttpClient,
114        rendezvous_server: &Url,
115    ) -> Result<Self, HttpError> {
116        use std::borrow::Cow;
117
118        use ruma::api::{SupportedVersions, client::rendezvous::create_rendezvous_session};
119
120        let request = create_rendezvous_session::unstable::Request::default();
121        let response = client
122            .send(
123                request,
124                None,
125                rendezvous_server.to_string(),
126                None,
127                Cow::Owned(SupportedVersions {
128                    versions: Default::default(),
129                    features: Default::default(),
130                }),
131                Default::default(),
132            )
133            .await?;
134
135        let rendezvous_url = response.url;
136        let etag = response.etag;
137
138        Ok(Self { client, rendezvous_url, etag })
139    }
140
141    /// Create a new inbound [`RendezvousChannel`].
142    ///
143    /// By inbound we mean that we're going to attempt to read an initial
144    /// message from the rendezvous session on the given [`rendezvous_url`].
145    pub(super) async fn create_inbound(
146        client: HttpClient,
147        rendezvous_url: &Url,
148    ) -> Result<InboundChannelCreationResult, HttpError> {
149        // Receive the initial message, which should be empty. But we need the ETAG to
150        // fully establish the rendezvous channel.
151        let response = Self::receive_message_impl(&client.inner, None, rendezvous_url).await?;
152
153        let etag = response.etag.clone();
154
155        let initial_message = RendezvousMessage {
156            status_code: response.status_code,
157            body: response.body,
158            content_type: response.content_type.unwrap_or_else(|| "text/plain".to_owned()),
159        };
160
161        let channel = Self { client, rendezvous_url: rendezvous_url.clone(), etag };
162
163        Ok(InboundChannelCreationResult { channel, initial_message: initial_message.body })
164    }
165
166    /// Get the URL of the rendezvous session we're using to exchange messages
167    /// through the channel.
168    pub(super) fn rendezvous_url(&self) -> &Url {
169        &self.rendezvous_url
170    }
171
172    /// Send the given `message` through the [`RendezvousChannel`] to the other
173    /// device.
174    ///
175    /// The message must be of the `text/plain` content type.
176    #[instrument(skip_all)]
177    pub(super) async fn send(&mut self, message: Vec<u8>) -> Result<(), HttpError> {
178        let etag = self.etag.clone();
179
180        let request = self
181            .client
182            .inner
183            .request(Method::PUT, self.rendezvous_url().to_owned())
184            .body(message)
185            .header(IF_MATCH, etag)
186            .header(CONTENT_TYPE, TEXT_PLAIN_CONTENT_TYPE);
187
188        debug!("Sending a request to the rendezvous channel {request:?}");
189
190        let response = request.send().await?;
191        let status = response.status();
192
193        debug!("Response for the rendezvous sending request {response:?}");
194
195        if status.is_success() {
196            // We successfully send out a message, get the ETAG and update our internal copy
197            // of the ETAG.
198            let etag = get_header(response.headers(), &ETAG)?;
199            self.etag = etag;
200
201            Ok(())
202        } else {
203            let body = response.bytes().await?;
204            let error = response_to_error(status, body.to_vec());
205
206            return Err(error);
207        }
208    }
209
210    /// Attempt to receive a message from the [`RendezvousChannel`] from the
211    /// other device.
212    ///
213    /// The content should be of the `text/plain` content type but the parsing
214    /// and verification of this fact is left up to the caller.
215    ///
216    /// This method will wait in a loop for the channel to give us a new
217    /// message.
218    pub(super) async fn receive(&mut self) -> Result<Vec<u8>, HttpError> {
219        loop {
220            let message = self.receive_single_message().await?;
221
222            trace!(
223                status_code = %message.status_code,
224                "Received data from the rendezvous channel"
225            );
226
227            if message.status_code == StatusCode::OK
228                && message.content_type == TEXT_PLAIN_CONTENT_TYPE
229                && !message.body.is_empty()
230            {
231                return Ok(message.body);
232            } else if message.status_code == StatusCode::NOT_MODIFIED {
233                sleep::sleep(POLL_TIMEOUT).await;
234                continue;
235            } else {
236                let error = response_to_error(message.status_code, message.body);
237
238                return Err(error);
239            }
240        }
241    }
242
243    #[instrument]
244    async fn receive_message_impl(
245        client: &reqwest::Client,
246        etag: Option<String>,
247        rendezvous_url: &Url,
248    ) -> Result<RendezvousGetResponse, HttpError> {
249        let mut builder = client.request(Method::GET, rendezvous_url.to_owned());
250
251        if let Some(etag) = etag {
252            builder = builder.header(IF_NONE_MATCH, etag);
253        }
254
255        let response = builder.send().await?;
256
257        debug!("Received data from the rendezvous channel {response:?}");
258
259        let status_code = response.status();
260
261        if status_code.is_client_error() {
262            return Err(response_to_error(status_code, response.bytes().await?.to_vec()));
263        }
264
265        let headers = response.headers();
266        let etag = get_header(headers, &ETAG)?;
267        let expires = get_header(headers, &EXPIRES)?;
268        let last_modified = get_header(headers, &LAST_MODIFIED)?;
269        let content_type = headers
270            .get(CONTENT_TYPE)
271            .map(|c| c.to_str().map_err(FromHttpResponseError::<RumaApiError>::from))
272            .transpose()?
273            .map(ToOwned::to_owned);
274
275        let body = response.bytes().await?.to_vec();
276
277        let response =
278            RendezvousGetResponse { status_code, etag, expires, last_modified, content_type, body };
279
280        Ok(response)
281    }
282
283    async fn receive_single_message(&mut self) -> Result<RendezvousMessage, HttpError> {
284        let etag = Some(self.etag.clone());
285
286        let RendezvousGetResponse { status_code, etag, content_type, body, .. } =
287            Self::receive_message_impl(&self.client.inner, etag, &self.rendezvous_url).await?;
288
289        // We received a response with an ETAG, put it into the copy of our etag.
290        self.etag = etag;
291
292        let message = RendezvousMessage {
293            status_code,
294            body,
295            content_type: content_type.unwrap_or_else(|| "text/plain".to_owned()),
296        };
297
298        Ok(message)
299    }
300}
301
#[cfg(all(test, not(target_family = "wasm")))]
mod test {
    use matrix_sdk_test::async_test;
    use serde_json::json;
    use similar_asserts::assert_eq;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{header, method, path},
    };

    use super::*;
    use crate::config::RequestConfig;

    /// Mock the endpoint that creates a rendezvous session, answering with the
    /// given `rendezvous_url` and an initial `ETag` of `1`.
    async fn mock_rendezvous_create(server: &MockServer, rendezvous_url: &Url) {
        server
            .register(
                Mock::given(method("POST"))
                    .and(path("/_matrix/client/unstable/org.matrix.msc4108/rendezvous"))
                    .respond_with(
                        ResponseTemplate::new(200)
                            .append_header("X-Max-Bytes", "10240")
                            .append_header("ETag", "1")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_json(json!({
                                "url": rendezvous_url,
                            })),
                    ),
            )
            .await;
    }

    #[async_test]
    async fn test_creation() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");

        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        assert_eq!(
            alice.rendezvous_url(),
            &rendezvous_url,
            "Alice should have configured the rendezvous URL correctly."
        );

        assert_eq!(alice.etag, "1", "Alice should have remembered the ETAG the server gave us.");

        let mut bob = {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET")).and(path("/abcdEFG12345")).respond_with(
                        ResponseTemplate::new(200)
                            .append_header("Content-Type", "text/plain")
                            .append_header("ETag", "2")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                    ),
                )
                .await;

            let client = HttpClient::new(reqwest::Client::new(), RequestConfig::short_retry());
            let InboundChannelCreationResult { channel: bob, initial_message: _ } =
                RendezvousChannel::create_inbound(client, &rendezvous_url).await.expect(
                    "We should be able to create a rendezvous channel from a received message",
                );

            assert_eq!(alice.rendezvous_url(), bob.rendezvous_url());

            bob
        };

        assert_eq!(bob.etag, "2", "Bob should have remembered the ETAG the server gave us.");

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(304)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                        ),
                )
                .await;

            let response = alice
                .receive_single_message()
                .await
                .expect("We should be able to wait for data on the rendezvous channel.");
            assert_eq!(response.status_code, StatusCode::NOT_MODIFIED);
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("PUT"))
                        .and(path("/abcdEFG12345"))
                        .and(header("Content-Type", "text/plain"))
                        .respond_with(
                            ResponseTemplate::new(200)
                                .append_header("ETag", "1")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                        ),
                )
                .await;

            bob.send(b"Hello world".to_vec())
                .await
                .expect("We should be able to send data to the rendezvous server.");
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(200)
                                .append_header("ETag", "3")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_string("Hello world"),
                        ),
                )
                .await;

            let response = alice
                .receive_single_message()
                .await
                .expect("We should be able to wait and get data on the rendezvous channel.");

            assert_eq!(response.status_code, StatusCode::OK);
            assert_eq!(response.body, b"Hello world");
            assert_eq!(response.content_type, TEXT_PLAIN_CONTENT_TYPE);
        }
    }

    #[async_test]
    async fn test_retry_mechanism() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "1"))
                    .respond_with(
                        ResponseTemplate::new(304)
                            .append_header("ETag", "2")
                            .append_header("Content-Type", "text/plain")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_string(""),
                    )
                    .expect(1),
            )
            .await;

        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "2"))
                    .respond_with(
                        ResponseTemplate::new(200)
                            .append_header("ETag", "3")
                            .append_header("Content-Type", "text/plain")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_string("Hello world"),
                    )
                    .expect(1),
            )
            .await;

        let response = alice
            .receive()
            .await
            .expect("We should be able to wait and get data on the rendezvous channel.");

        assert_eq!(response, b"Hello world");
    }

    #[async_test]
    async fn test_receive_error() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(404)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_string(""),
                        )
                        .expect(1),
                )
                .await;

            alice.receive().await.expect_err("We should return an error if we receive a 404");
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(504)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_json(json!({
                                  "errcode": "M_NOT_FOUND",
                                  "error": "No resource was found for this request.",
                                })),
                        )
                        .expect(1),
                )
                .await;

            alice
                .receive()
                .await
                .expect_err("We should return an error if we receive a gateway timeout");
        }
    }

    #[async_test]
    async fn test_receive_server_error_without_rendezvous_headers() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        // Proxies and gateways usually answer with a bare error response that carries none
        // of the rendezvous headers; this must surface as an error from `receive()`.
        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "1"))
                    .respond_with(ResponseTemplate::new(502))
                    .expect(1),
            )
            .await;

        alice
            .receive()
            .await
            .expect_err("We should return an error if we receive a bad gateway response");
    }
}