matrix_sdk/authentication/oauth/qrcode/rendezvous_channel.rs

1// Copyright 2024 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::time::Duration;
16
17use http::{
18    header::{CONTENT_TYPE, ETAG, EXPIRES, IF_MATCH, IF_NONE_MATCH, LAST_MODIFIED},
19    HeaderMap, HeaderName, Method, StatusCode,
20};
21use ruma::api::{
22    error::{FromHttpResponseError, HeaderDeserializationError, IntoHttpError, MatrixError},
23    EndpointError,
24};
25use tracing::{debug, instrument, trace};
26use url::Url;
27
28use crate::{http_client::HttpClient, HttpError, RumaApiError};
29
/// The `Content-Type` we attach to every message we `PUT` into the rendezvous
/// session, and the one we expect on messages we receive.
const TEXT_PLAIN_CONTENT_TYPE: &str = "text/plain";
// How long `RendezvousChannel::receive()` sleeps after a `304 Not Modified`
// before polling the session again. Tests use a much shorter interval so they
// don't stall.
#[cfg(test)]
const POLL_TIMEOUT: Duration = Duration::from_millis(10);
#[cfg(not(test))]
const POLL_TIMEOUT: Duration = Duration::from_secs(1);

/// The raw value of an `ETag` header handed out by the rendezvous server.
type Etag = String;
37
38/// Get a header from a [`HeaderMap`] and parse it as a UTF-8 string.
39fn get_header(
40    header_map: &HeaderMap,
41    header_name: &HeaderName,
42) -> Result<String, Box<FromHttpResponseError<RumaApiError>>> {
43    let header = header_map
44        .get(header_name)
45        .ok_or(HeaderDeserializationError::MissingHeader(ETAG.to_string()))
46        .map_err(|error| Box::new(FromHttpResponseError::from(error)))?;
47
48    let header =
49        header.to_str().map_err(|error| Box::new(FromHttpResponseError::from(error)))?.to_owned();
50
51    Ok(header)
52}
53
/// The result of the [`RendezvousChannel::create_inbound()`] method.
pub(super) struct InboundChannelCreationResult {
    /// The connected [`RendezvousChannel`], already holding the ETag of the
    /// initial response.
    pub channel: RendezvousChannel,
    /// The initial message we received when we connected to the
    /// [`RendezvousChannel`], i.e. the raw body bytes of the first `GET`.
    ///
    /// This is currently unused, but left in for completeness' sake.
    #[allow(dead_code)]
    pub initial_message: Vec<u8>,
}
65
/// The parts of a `GET` response from the rendezvous session that we extract.
struct RendezvousGetResponse {
    /// The HTTP status code of the response.
    pub status_code: StatusCode,
    /// The value of the `ETag` header.
    pub etag: String,
    // The raw, unparsed value of the `Expires` header.
    //
    // TODO: This is currently unused, but will be required once we implement the reciprocation of
    // a login. Left here so we don't forget about it. We should put this into the
    // [`RendezvousChannel`] struct, once we parse it into a [`SystemTime`].
    #[allow(dead_code)]
    pub expires: String,
    // The raw, unparsed value of the `Last-Modified` header; currently unused.
    #[allow(dead_code)]
    pub last_modified: String,
    /// The value of the `Content-Type` header, if the server sent one.
    pub content_type: Option<String>,
    /// The raw response body.
    pub body: Vec<u8>,
}
79
/// A single poll result from the rendezvous session, as seen by the channel.
struct RendezvousMessage {
    /// The HTTP status code of the poll (e.g. `200 OK` or `304 Not Modified`).
    pub status_code: StatusCode,
    /// The raw response body.
    pub body: Vec<u8>,
    /// The response's `Content-Type`; `text/plain` when the server sent none.
    pub content_type: String,
}
85
/// A bidirectional message channel built on top of an MSC4108 rendezvous
/// session, which both devices poll and write to over plain HTTP.
pub(super) struct RendezvousChannel {
    /// The HTTP client used for all requests to the rendezvous session.
    client: HttpClient,
    /// The URL of the rendezvous session.
    rendezvous_url: Url,
    /// The ETag from the most recent response we processed. It is sent as
    /// `If-Match` when writing and as `If-None-Match` when polling.
    etag: Etag,
}
91
92fn response_to_error(status: StatusCode, body: Vec<u8>) -> HttpError {
93    match http::Response::builder().status(status).body(body).map_err(IntoHttpError::from) {
94        Ok(response) => {
95            let error = FromHttpResponseError::<RumaApiError>::Server(RumaApiError::Other(
96                MatrixError::from_http_response(response),
97            ));
98
99            error.into()
100        }
101        Err(e) => HttpError::IntoHttp(e),
102    }
103}
104
105impl RendezvousChannel {
106    /// Create a new outbound [`RendezvousChannel`].
107    ///
108    /// By outbound we mean that we're going to tell the Matrix server to create
109    /// a new rendezvous session. We're going to send an initial empty message
110    /// through the channel.
111    #[cfg(test)]
112    pub(super) async fn create_outbound(
113        client: HttpClient,
114        rendezvous_server: &Url,
115    ) -> Result<Self, HttpError> {
116        use ruma::api::client::rendezvous::create_rendezvous_session;
117
118        let request = create_rendezvous_session::unstable::Request::default();
119        let response = client
120            .send(request, None, rendezvous_server.to_string(), None, &[], Default::default())
121            .await?;
122
123        let rendezvous_url = response.url;
124        let etag = response.etag;
125
126        Ok(Self { client, rendezvous_url, etag })
127    }
128
129    /// Create a new inbound [`RendezvousChannel`].
130    ///
131    /// By inbound we mean that we're going to attempt to read an initial
132    /// message from the rendezvous session on the given [`rendezvous_url`].
133    pub(super) async fn create_inbound(
134        client: HttpClient,
135        rendezvous_url: &Url,
136    ) -> Result<InboundChannelCreationResult, HttpError> {
137        // Receive the initial message, which should be empty. But we need the ETAG to
138        // fully establish the rendezvous channel.
139        let response = Self::receive_message_impl(&client.inner, None, rendezvous_url).await?;
140
141        let etag = response.etag.clone();
142
143        let initial_message = RendezvousMessage {
144            status_code: response.status_code,
145            body: response.body,
146            content_type: response.content_type.unwrap_or_else(|| "text/plain".to_owned()),
147        };
148
149        let channel = Self { client, rendezvous_url: rendezvous_url.clone(), etag };
150
151        Ok(InboundChannelCreationResult { channel, initial_message: initial_message.body })
152    }
153
154    /// Get the URL of the rendezvous session we're using to exchange messages
155    /// through the channel.
156    pub(super) fn rendezvous_url(&self) -> &Url {
157        &self.rendezvous_url
158    }
159
160    /// Send the given `message` through the [`RendezvousChannel`] to the other
161    /// device.
162    ///
163    /// The message must be of the `text/plain` content type.
164    #[instrument(skip_all)]
165    pub(super) async fn send(&mut self, message: Vec<u8>) -> Result<(), HttpError> {
166        let etag = self.etag.clone();
167
168        let request = self
169            .client
170            .inner
171            .request(Method::PUT, self.rendezvous_url().to_owned())
172            .body(message)
173            .header(IF_MATCH, etag)
174            .header(CONTENT_TYPE, TEXT_PLAIN_CONTENT_TYPE);
175
176        debug!("Sending a request to the rendezvous channel {request:?}");
177
178        let response = request.send().await?;
179        let status = response.status();
180
181        debug!("Response for the rendezvous sending request {response:?}");
182
183        if status.is_success() {
184            // We successfully send out a message, get the ETAG and update our internal copy
185            // of the ETAG.
186            let etag = get_header(response.headers(), &ETAG)?;
187            self.etag = etag;
188
189            Ok(())
190        } else {
191            let body = response.bytes().await?;
192            let error = response_to_error(status, body.to_vec());
193
194            return Err(error);
195        }
196    }
197
198    /// Attempt to receive a message from the [`RendezvousChannel`] from the
199    /// other device.
200    ///
201    /// The content should be of the `text/plain` content type but the parsing
202    /// and verification of this fact is left up to the caller.
203    ///
204    /// This method will wait in a loop for the channel to give us a new
205    /// message.
206    pub(super) async fn receive(&mut self) -> Result<Vec<u8>, HttpError> {
207        loop {
208            let message = self.receive_single_message().await?;
209
210            trace!(
211                status_code = %message.status_code,
212                "Received data from the rendezvous channel"
213            );
214
215            if message.status_code == StatusCode::OK
216                && message.content_type == TEXT_PLAIN_CONTENT_TYPE
217                && !message.body.is_empty()
218            {
219                return Ok(message.body);
220            } else if message.status_code == StatusCode::NOT_MODIFIED {
221                tokio::time::sleep(POLL_TIMEOUT).await;
222                continue;
223            } else {
224                let error = response_to_error(message.status_code, message.body);
225
226                return Err(error);
227            }
228        }
229    }
230
231    #[instrument]
232    async fn receive_message_impl(
233        client: &reqwest::Client,
234        etag: Option<String>,
235        rendezvous_url: &Url,
236    ) -> Result<RendezvousGetResponse, HttpError> {
237        let mut builder = client.request(Method::GET, rendezvous_url.to_owned());
238
239        if let Some(etag) = etag {
240            builder = builder.header(IF_NONE_MATCH, etag);
241        }
242
243        let response = builder.send().await?;
244
245        debug!("Received data from the rendezvous channel {response:?}");
246
247        let status_code = response.status();
248        let headers = response.headers();
249
250        let etag = get_header(headers, &ETAG)?;
251        let expires = get_header(headers, &EXPIRES)?;
252        let last_modified = get_header(headers, &LAST_MODIFIED)?;
253        let content_type = response
254            .headers()
255            .get(CONTENT_TYPE)
256            .map(|c| c.to_str().map_err(FromHttpResponseError::<RumaApiError>::from))
257            .transpose()?
258            .map(ToOwned::to_owned);
259
260        let body = response.bytes().await?.to_vec();
261
262        let response =
263            RendezvousGetResponse { status_code, etag, expires, last_modified, content_type, body };
264
265        Ok(response)
266    }
267
268    async fn receive_single_message(&mut self) -> Result<RendezvousMessage, HttpError> {
269        let etag = Some(self.etag.clone());
270
271        let RendezvousGetResponse { status_code, etag, content_type, body, .. } =
272            Self::receive_message_impl(&self.client.inner, etag, &self.rendezvous_url).await?;
273
274        // We received a response with an ETAG, put it into the copy of our etag.
275        self.etag = etag;
276
277        let message = RendezvousMessage {
278            status_code,
279            body,
280            content_type: content_type.unwrap_or_else(|| "text/plain".to_owned()),
281        };
282
283        Ok(message)
284    }
285}
286
#[cfg(all(test, not(target_family = "wasm")))]
mod test {
    use matrix_sdk_test::async_test;
    use serde_json::json;
    use similar_asserts::assert_eq;
    use wiremock::{
        matchers::{header, method, path},
        Mock, MockServer, ResponseTemplate,
    };

    use super::*;
    use crate::config::RequestConfig;

    /// Mount a mock for the MSC4108 session-creation endpoint that hands out
    /// `rendezvous_url` together with an initial ETag of `1`.
    async fn mock_rendezvous_create(server: &MockServer, rendezvous_url: &Url) {
        server
            .register(
                Mock::given(method("POST"))
                    .and(path("/_matrix/client/unstable/org.matrix.msc4108/rendezvous"))
                    .respond_with(
                        ResponseTemplate::new(200)
                            .append_header("X-Max-Bytes", "10240")
                            .append_header("ETag", "1")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_json(json!({
                                "url": rendezvous_url,
                            })),
                    ),
            )
            .await;
    }

    #[async_test]
    async fn test_creation() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");

        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        assert_eq!(
            alice.rendezvous_url(),
            &rendezvous_url,
            "Alice should have configured the rendezvous URL correctly."
        );

        assert_eq!(alice.etag, "1", "Alice should have remembered the ETAG the server gave us.");

        let mut bob = {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET")).and(path("/abcdEFG12345")).respond_with(
                        ResponseTemplate::new(200)
                            .append_header("Content-Type", "text/plain")
                            .append_header("ETag", "2")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                    ),
                )
                .await;

            let client = HttpClient::new(reqwest::Client::new(), RequestConfig::short_retry());
            let InboundChannelCreationResult { channel: bob, initial_message: _ } =
                RendezvousChannel::create_inbound(client, &rendezvous_url).await.expect(
                    "We should be able to create a rendezvous channel from a received message",
                );

            assert_eq!(alice.rendezvous_url(), bob.rendezvous_url());

            bob
        };

        assert_eq!(bob.etag, "2", "Bob should have remembered the ETAG the server gave us.");

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(304)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                        ),
                )
                .await;

            let response = alice
                .receive_single_message()
                .await
                .expect("We should be able to wait for data on the rendezvous channel.");
            assert_eq!(response.status_code, StatusCode::NOT_MODIFIED);
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("PUT"))
                        .and(path("/abcdEFG12345"))
                        .and(header("Content-Type", "text/plain"))
                        .respond_with(
                            ResponseTemplate::new(200)
                                .append_header("ETag", "1")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                        ),
                )
                .await;

            bob.send(b"Hello world".to_vec())
                .await
                .expect("We should be able to send data to the rendezvous server.");
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(200)
                                .append_header("ETag", "3")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_string("Hello world"),
                        ),
                )
                .await;

            let response = alice
                .receive_single_message()
                .await
                .expect("We should be able to wait and get data on the rendezvous channel.");

            assert_eq!(response.status_code, StatusCode::OK);
            assert_eq!(response.body, b"Hello world");
            assert_eq!(response.content_type, TEXT_PLAIN_CONTENT_TYPE);
        }
    }

    #[async_test]
    async fn test_retry_mechanism() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "1"))
                    .respond_with(
                        ResponseTemplate::new(304)
                            .append_header("ETag", "2")
                            .append_header("Content-Type", "text/plain")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_string(""),
                    )
                    .expect(1),
            )
            .await;

        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "2"))
                    .respond_with(
                        ResponseTemplate::new(200)
                            .append_header("ETag", "3")
                            .append_header("Content-Type", "text/plain")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_string("Hello world"),
                    )
                    .expect(1),
            )
            .await;

        let response = alice
            .receive()
            .await
            .expect("We should be able to wait and get data on the rendezvous channel.");

        assert_eq!(response, b"Hello world");
    }

    #[async_test]
    async fn test_receive_error() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(404)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_string(""),
                        )
                        .expect(1),
                )
                .await;

            alice.receive().await.expect_err("We should return an error if we receive a 404");
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(504)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_json(json!({
                                  "errcode": "M_NOT_FOUND",
                                  "error": "No resource was found for this request.",
                                })),
                        )
                        .expect(1),
                )
                .await;

            alice
                .receive()
                .await
                .expect_err("We should return an error if we receive a gateway timeout");
        }
    }
}