1use std::time::Duration;
16
17use http::{
18 HeaderMap, HeaderName, Method, StatusCode,
19 header::{CONTENT_TYPE, ETAG, EXPIRES, IF_MATCH, IF_NONE_MATCH, LAST_MODIFIED},
20};
21use matrix_sdk_base::sleep;
22use ruma::api::{
23 EndpointError,
24 error::{FromHttpResponseError, HeaderDeserializationError, IntoHttpError},
25};
26use tracing::{debug, instrument, trace};
27use url::Url;
28
29use crate::{HttpError, RumaApiError, http_client::HttpClient};
30
/// MIME type attached to every message we `PUT` on the rendezvous channel,
/// and the only content type accepted for incoming messages.
const TEXT_PLAIN_CONTENT_TYPE: &str = "text/plain";

/// Delay between two polls of the rendezvous URL while the server keeps
/// answering `304 Not Modified`. Shortened under `cfg(test)` so that tests
/// exercising the polling loop don't stall.
const POLL_TIMEOUT: Duration =
    if cfg!(test) { Duration::from_millis(10) } else { Duration::from_secs(1) };

/// Entity tag handed out by the rendezvous server, stored verbatim.
type Etag = String;
38
39fn get_header(
41 header_map: &HeaderMap,
42 header_name: &HeaderName,
43) -> Result<String, Box<FromHttpResponseError<RumaApiError>>> {
44 let header = header_map
45 .get(header_name)
46 .ok_or(HeaderDeserializationError::MissingHeader(ETAG.to_string()))
47 .map_err(|error| Box::new(FromHttpResponseError::from(error)))?;
48
49 let header =
50 header.to_str().map_err(|error| Box::new(FromHttpResponseError::from(error)))?.to_owned();
51
52 Ok(header)
53}
54
/// Result of [`RendezvousChannel::create_inbound`]: the connected channel plus
/// the body of the first response read from the rendezvous URL.
pub(super) struct InboundChannelCreationResult {
    /// The connected [`RendezvousChannel`].
    pub channel: RendezvousChannel,
    /// Body of the initial `GET` response on the rendezvous URL.
    // NOTE(review): not read by any code visible here (the test destructures it
    // as `_`), hence the `dead_code` allowance — confirm whether callers use it.
    #[allow(dead_code)]
    pub initial_message: Vec<u8>,
}
66
/// Parsed result of a single `GET` on the rendezvous URL.
struct RendezvousGetResponse {
    /// HTTP status code of the response.
    pub status_code: StatusCode,
    /// Value of the required `ETag` header.
    pub etag: String,
    /// Value of the required `Expires` header, kept as a raw string.
    // Parsed (its presence is enforced) but not consumed by any caller in this
    // file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    pub expires: String,
    /// Value of the required `Last-Modified` header, kept as a raw string.
    // Same situation as `expires`: required, but currently unused.
    #[allow(dead_code)]
    pub last_modified: String,
    /// Value of the optional `Content-Type` header.
    pub content_type: Option<String>,
    /// Raw response body.
    pub body: Vec<u8>,
}
80
/// A message polled from the rendezvous channel, reduced to the parts the
/// receive loop inspects.
struct RendezvousMessage {
    /// HTTP status code of the poll response (e.g. `200` or `304`).
    pub status_code: StatusCode,
    /// Raw response body.
    pub body: Vec<u8>,
    /// Content type of the response; `"text/plain"` when the server sent none.
    pub content_type: String,
}
86
/// An HTTP rendezvous channel: both parties exchange messages by `PUT`ting to
/// and polling (`GET`) the same rendezvous URL.
pub(super) struct RendezvousChannel {
    /// HTTP client used for every request to the rendezvous server.
    client: HttpClient,
    /// URL of the rendezvous session.
    rendezvous_url: Url,
    /// Most recent `ETag` received from the server; sent as `If-Match` when
    /// sending and as `If-None-Match` when polling.
    etag: Etag,
}
92
93fn response_to_error(status: StatusCode, body: Vec<u8>) -> HttpError {
94 match http::Response::builder().status(status).body(body).map_err(IntoHttpError::from) {
95 Ok(response) => {
96 let error = FromHttpResponseError::<RumaApiError>::Server(RumaApiError::ClientApi(
97 ruma::api::client::Error::from_http_response(response),
98 ));
99
100 error.into()
101 }
102 Err(e) => HttpError::IntoHttp(e),
103 }
104}
105
106impl RendezvousChannel {
107 pub(super) async fn create_outbound(
113 client: HttpClient,
114 rendezvous_server: &Url,
115 ) -> Result<Self, HttpError> {
116 use std::borrow::Cow;
117
118 use ruma::api::{SupportedVersions, client::rendezvous::create_rendezvous_session};
119
120 let request = create_rendezvous_session::unstable::Request::default();
121 let response = client
122 .send(
123 request,
124 None,
125 rendezvous_server.to_string(),
126 None,
127 Cow::Owned(SupportedVersions {
128 versions: Default::default(),
129 features: Default::default(),
130 }),
131 Default::default(),
132 )
133 .await?;
134
135 let rendezvous_url = response.url;
136 let etag = response.etag;
137
138 Ok(Self { client, rendezvous_url, etag })
139 }
140
141 pub(super) async fn create_inbound(
146 client: HttpClient,
147 rendezvous_url: &Url,
148 ) -> Result<InboundChannelCreationResult, HttpError> {
149 let response = Self::receive_message_impl(&client.inner, None, rendezvous_url).await?;
152
153 let etag = response.etag.clone();
154
155 let initial_message = RendezvousMessage {
156 status_code: response.status_code,
157 body: response.body,
158 content_type: response.content_type.unwrap_or_else(|| "text/plain".to_owned()),
159 };
160
161 let channel = Self { client, rendezvous_url: rendezvous_url.clone(), etag };
162
163 Ok(InboundChannelCreationResult { channel, initial_message: initial_message.body })
164 }
165
166 pub(super) fn rendezvous_url(&self) -> &Url {
169 &self.rendezvous_url
170 }
171
172 #[instrument(skip_all)]
177 pub(super) async fn send(&mut self, message: Vec<u8>) -> Result<(), HttpError> {
178 let etag = self.etag.clone();
179
180 let request = self
181 .client
182 .inner
183 .request(Method::PUT, self.rendezvous_url().to_owned())
184 .body(message)
185 .header(IF_MATCH, etag)
186 .header(CONTENT_TYPE, TEXT_PLAIN_CONTENT_TYPE);
187
188 debug!("Sending a request to the rendezvous channel {request:?}");
189
190 let response = request.send().await?;
191 let status = response.status();
192
193 debug!("Response for the rendezvous sending request {response:?}");
194
195 if status.is_success() {
196 let etag = get_header(response.headers(), &ETAG)?;
199 self.etag = etag;
200
201 Ok(())
202 } else {
203 let body = response.bytes().await?;
204 let error = response_to_error(status, body.to_vec());
205
206 return Err(error);
207 }
208 }
209
210 pub(super) async fn receive(&mut self) -> Result<Vec<u8>, HttpError> {
219 loop {
220 let message = self.receive_single_message().await?;
221
222 trace!(
223 status_code = %message.status_code,
224 "Received data from the rendezvous channel"
225 );
226
227 if message.status_code == StatusCode::OK
228 && message.content_type == TEXT_PLAIN_CONTENT_TYPE
229 && !message.body.is_empty()
230 {
231 return Ok(message.body);
232 } else if message.status_code == StatusCode::NOT_MODIFIED {
233 sleep::sleep(POLL_TIMEOUT).await;
234 continue;
235 } else {
236 let error = response_to_error(message.status_code, message.body);
237
238 return Err(error);
239 }
240 }
241 }
242
243 #[instrument]
244 async fn receive_message_impl(
245 client: &reqwest::Client,
246 etag: Option<String>,
247 rendezvous_url: &Url,
248 ) -> Result<RendezvousGetResponse, HttpError> {
249 let mut builder = client.request(Method::GET, rendezvous_url.to_owned());
250
251 if let Some(etag) = etag {
252 builder = builder.header(IF_NONE_MATCH, etag);
253 }
254
255 let response = builder.send().await?;
256
257 debug!("Received data from the rendezvous channel {response:?}");
258
259 let status_code = response.status();
260
261 if status_code.is_client_error() {
262 return Err(response_to_error(status_code, response.bytes().await?.to_vec()));
263 }
264
265 let headers = response.headers();
266 let etag = get_header(headers, &ETAG)?;
267 let expires = get_header(headers, &EXPIRES)?;
268 let last_modified = get_header(headers, &LAST_MODIFIED)?;
269 let content_type = headers
270 .get(CONTENT_TYPE)
271 .map(|c| c.to_str().map_err(FromHttpResponseError::<RumaApiError>::from))
272 .transpose()?
273 .map(ToOwned::to_owned);
274
275 let body = response.bytes().await?.to_vec();
276
277 let response =
278 RendezvousGetResponse { status_code, etag, expires, last_modified, content_type, body };
279
280 Ok(response)
281 }
282
283 async fn receive_single_message(&mut self) -> Result<RendezvousMessage, HttpError> {
284 let etag = Some(self.etag.clone());
285
286 let RendezvousGetResponse { status_code, etag, content_type, body, .. } =
287 Self::receive_message_impl(&self.client.inner, etag, &self.rendezvous_url).await?;
288
289 self.etag = etag;
291
292 let message = RendezvousMessage {
293 status_code,
294 body,
295 content_type: content_type.unwrap_or_else(|| "text/plain".to_owned()),
296 };
297
298 Ok(message)
299 }
300}
301
#[cfg(all(test, not(target_family = "wasm")))]
mod test {
    use matrix_sdk_test::async_test;
    use serde_json::json;
    use similar_asserts::assert_eq;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{header, method, path},
    };

    use super::*;
    use crate::config::RequestConfig;

    /// A response template with the given status, carrying the `ETag`,
    /// `Expires` and `Last-Modified` headers the rendezvous code parses.
    fn rendezvous_response(status: u16, etag: &str) -> ResponseTemplate {
        ResponseTemplate::new(status)
            .append_header("ETag", etag)
            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
    }

    /// Mock the session-creation endpoint, answering with `rendezvous_url` and
    /// an initial `ETag` of `1`.
    async fn mock_rendezvous_create(server: &MockServer, rendezvous_url: &Url) {
        server
            .register(
                Mock::given(method("POST"))
                    .and(path("/_matrix/client/unstable/org.matrix.msc4108/rendezvous"))
                    .respond_with(
                        rendezvous_response(200, "1")
                            .append_header("X-Max-Bytes", "10240")
                            .set_body_json(json!({
                                "url": rendezvous_url,
                            })),
                    ),
            )
            .await;
    }

    #[async_test]
    async fn test_creation() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");

        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        assert_eq!(
            alice.rendezvous_url(),
            &rendezvous_url,
            "Alice should have configured the rendezvous URL correctly."
        );

        assert_eq!(alice.etag, "1", "Alice should have remembered the ETAG the server gave us.");

        let mut bob = {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET")).and(path("/abcdEFG12345")).respond_with(
                        rendezvous_response(200, "2").append_header("Content-Type", "text/plain"),
                    ),
                )
                .await;

            let client = HttpClient::new(reqwest::Client::new(), RequestConfig::short_retry());
            let InboundChannelCreationResult { channel: bob, initial_message: _ } =
                RendezvousChannel::create_inbound(client, &rendezvous_url).await.expect(
                    "We should be able to create a rendezvous channel from a received message",
                );

            assert_eq!(alice.rendezvous_url(), bob.rendezvous_url());

            bob
        };

        assert_eq!(bob.etag, "2", "Bob should have remembered the ETAG the server gave us.");

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            rendezvous_response(304, "1")
                                .append_header("Content-Type", "text/plain"),
                        ),
                )
                .await;

            let response = alice
                .receive_single_message()
                .await
                .expect("We should be able to wait for data on the rendezvous channel.");
            assert_eq!(response.status_code, StatusCode::NOT_MODIFIED);
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("PUT"))
                        .and(path("/abcdEFG12345"))
                        .and(header("Content-Type", "text/plain"))
                        .respond_with(rendezvous_response(200, "1")),
                )
                .await;

            bob.send(b"Hello world".to_vec())
                .await
                .expect("We should be able to send data to the rendezvous server.");
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            rendezvous_response(200, "3")
                                .append_header("Content-Type", "text/plain")
                                .set_body_string("Hello world"),
                        ),
                )
                .await;

            let response = alice
                .receive_single_message()
                .await
                .expect("We should be able to wait and get data on the rendezvous channel.");

            assert_eq!(response.status_code, StatusCode::OK);
            assert_eq!(response.body, b"Hello world");
            assert_eq!(response.content_type, TEXT_PLAIN_CONTENT_TYPE);
        }
    }

    #[async_test]
    async fn test_retry_mechanism() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "1"))
                    .respond_with(
                        rendezvous_response(304, "2")
                            .append_header("Content-Type", "text/plain")
                            .set_body_string(""),
                    )
                    .expect(1),
            )
            .await;

        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "2"))
                    .respond_with(
                        rendezvous_response(200, "3")
                            .append_header("Content-Type", "text/plain")
                            .set_body_string("Hello world"),
                    )
                    .expect(1),
            )
            .await;

        let response = alice
            .receive()
            .await
            .expect("We should be able to wait and get data on the rendezvous channel.");

        assert_eq!(response, b"Hello world");
    }

    #[async_test]
    async fn test_receive_error() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            rendezvous_response(404, "1")
                                .append_header("Content-Type", "text/plain")
                                .set_body_string(""),
                        )
                        .expect(1),
                )
                .await;

            alice.receive().await.expect_err("We should return an error if we receive a 404");
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            rendezvous_response(504, "1")
                                .append_header("Content-Type", "text/plain")
                                .set_body_json(json!({
                                    "errcode": "M_NOT_FOUND",
                                    "error": "No resource was found for this request.",
                                })),
                        )
                        .expect(1),
                )
                .await;

            alice
                .receive()
                .await
                .expect_err("We should return an error if we receive a gateway timeout");
        }
    }
}