1use std::time::Duration;
16
17use http::{
18 header::{CONTENT_TYPE, ETAG, EXPIRES, IF_MATCH, IF_NONE_MATCH, LAST_MODIFIED},
19 HeaderMap, HeaderName, Method, StatusCode,
20};
21use ruma::api::{
22 error::{FromHttpResponseError, HeaderDeserializationError, IntoHttpError, MatrixError},
23 EndpointError,
24};
25use tracing::{debug, instrument, trace};
26use url::Url;
27
28use crate::{http_client::HttpClient, HttpError, RumaApiError};
29
/// MIME type attached to every payload we `PUT` on the rendezvous channel and
/// expected on every payload we accept from it.
const TEXT_PLAIN_CONTENT_TYPE: &str = "text/plain";

/// How long to wait before polling the rendezvous channel again after the
/// server answered `304 Not Modified`.
#[cfg(not(test))]
const POLL_TIMEOUT: Duration = Duration::from_millis(1_000);
/// Shortened poll delay so that tests exercising the polling loop stay fast.
#[cfg(test)]
const POLL_TIMEOUT: Duration = Duration::from_millis(10);

/// Entity tag returned by the rendezvous server, kept as the raw header value.
type Etag = String;
37
38fn get_header(
40 header_map: &HeaderMap,
41 header_name: &HeaderName,
42) -> Result<String, Box<FromHttpResponseError<RumaApiError>>> {
43 let header = header_map
44 .get(header_name)
45 .ok_or(HeaderDeserializationError::MissingHeader(ETAG.to_string()))
46 .map_err(|error| Box::new(FromHttpResponseError::from(error)))?;
47
48 let header =
49 header.to_str().map_err(|error| Box::new(FromHttpResponseError::from(error)))?.to_owned();
50
51 Ok(header)
52}
53
54pub(super) struct InboundChannelCreationResult {
56 pub channel: RendezvousChannel,
58 #[allow(dead_code)]
63 pub initial_message: Vec<u8>,
64}
65
66struct RendezvousGetResponse {
67 pub status_code: StatusCode,
68 pub etag: String,
69 #[allow(dead_code)]
73 pub expires: String,
74 #[allow(dead_code)]
75 pub last_modified: String,
76 pub content_type: Option<String>,
77 pub body: Vec<u8>,
78}
79
80struct RendezvousMessage {
81 pub status_code: StatusCode,
82 pub body: Vec<u8>,
83 pub content_type: String,
84}
85
86pub(super) struct RendezvousChannel {
87 client: HttpClient,
88 rendezvous_url: Url,
89 etag: Etag,
90}
91
92fn response_to_error(status: StatusCode, body: Vec<u8>) -> HttpError {
93 match http::Response::builder().status(status).body(body).map_err(IntoHttpError::from) {
94 Ok(response) => {
95 let error = FromHttpResponseError::<RumaApiError>::Server(RumaApiError::Other(
96 MatrixError::from_http_response(response),
97 ));
98
99 error.into()
100 }
101 Err(e) => HttpError::IntoHttp(e),
102 }
103}
104
105impl RendezvousChannel {
106 #[cfg(test)]
112 pub(super) async fn create_outbound(
113 client: HttpClient,
114 rendezvous_server: &Url,
115 ) -> Result<Self, HttpError> {
116 use ruma::api::client::rendezvous::create_rendezvous_session;
117
118 let request = create_rendezvous_session::unstable::Request::default();
119 let response = client
120 .send(request, None, rendezvous_server.to_string(), None, &[], Default::default())
121 .await?;
122
123 let rendezvous_url = response.url;
124 let etag = response.etag;
125
126 Ok(Self { client, rendezvous_url, etag })
127 }
128
129 pub(super) async fn create_inbound(
134 client: HttpClient,
135 rendezvous_url: &Url,
136 ) -> Result<InboundChannelCreationResult, HttpError> {
137 let response = Self::receive_message_impl(&client.inner, None, rendezvous_url).await?;
140
141 let etag = response.etag.clone();
142
143 let initial_message = RendezvousMessage {
144 status_code: response.status_code,
145 body: response.body,
146 content_type: response.content_type.unwrap_or_else(|| "text/plain".to_owned()),
147 };
148
149 let channel = Self { client, rendezvous_url: rendezvous_url.clone(), etag };
150
151 Ok(InboundChannelCreationResult { channel, initial_message: initial_message.body })
152 }
153
154 pub(super) fn rendezvous_url(&self) -> &Url {
157 &self.rendezvous_url
158 }
159
160 #[instrument(skip_all)]
165 pub(super) async fn send(&mut self, message: Vec<u8>) -> Result<(), HttpError> {
166 let etag = self.etag.clone();
167
168 let request = self
169 .client
170 .inner
171 .request(Method::PUT, self.rendezvous_url().to_owned())
172 .body(message)
173 .header(IF_MATCH, etag)
174 .header(CONTENT_TYPE, TEXT_PLAIN_CONTENT_TYPE);
175
176 debug!("Sending a request to the rendezvous channel {request:?}");
177
178 let response = request.send().await?;
179 let status = response.status();
180
181 debug!("Response for the rendezvous sending request {response:?}");
182
183 if status.is_success() {
184 let etag = get_header(response.headers(), &ETAG)?;
187 self.etag = etag;
188
189 Ok(())
190 } else {
191 let body = response.bytes().await?;
192 let error = response_to_error(status, body.to_vec());
193
194 return Err(error);
195 }
196 }
197
198 pub(super) async fn receive(&mut self) -> Result<Vec<u8>, HttpError> {
207 loop {
208 let message = self.receive_single_message().await?;
209
210 trace!(
211 status_code = %message.status_code,
212 "Received data from the rendezvous channel"
213 );
214
215 if message.status_code == StatusCode::OK
216 && message.content_type == TEXT_PLAIN_CONTENT_TYPE
217 && !message.body.is_empty()
218 {
219 return Ok(message.body);
220 } else if message.status_code == StatusCode::NOT_MODIFIED {
221 tokio::time::sleep(POLL_TIMEOUT).await;
222 continue;
223 } else {
224 let error = response_to_error(message.status_code, message.body);
225
226 return Err(error);
227 }
228 }
229 }
230
231 #[instrument]
232 async fn receive_message_impl(
233 client: &reqwest::Client,
234 etag: Option<String>,
235 rendezvous_url: &Url,
236 ) -> Result<RendezvousGetResponse, HttpError> {
237 let mut builder = client.request(Method::GET, rendezvous_url.to_owned());
238
239 if let Some(etag) = etag {
240 builder = builder.header(IF_NONE_MATCH, etag);
241 }
242
243 let response = builder.send().await?;
244
245 debug!("Received data from the rendezvous channel {response:?}");
246
247 let status_code = response.status();
248 let headers = response.headers();
249
250 let etag = get_header(headers, &ETAG)?;
251 let expires = get_header(headers, &EXPIRES)?;
252 let last_modified = get_header(headers, &LAST_MODIFIED)?;
253 let content_type = response
254 .headers()
255 .get(CONTENT_TYPE)
256 .map(|c| c.to_str().map_err(FromHttpResponseError::<RumaApiError>::from))
257 .transpose()?
258 .map(ToOwned::to_owned);
259
260 let body = response.bytes().await?.to_vec();
261
262 let response =
263 RendezvousGetResponse { status_code, etag, expires, last_modified, content_type, body };
264
265 Ok(response)
266 }
267
268 async fn receive_single_message(&mut self) -> Result<RendezvousMessage, HttpError> {
269 let etag = Some(self.etag.clone());
270
271 let RendezvousGetResponse { status_code, etag, content_type, body, .. } =
272 Self::receive_message_impl(&self.client.inner, etag, &self.rendezvous_url).await?;
273
274 self.etag = etag;
276
277 let message = RendezvousMessage {
278 status_code,
279 body,
280 content_type: content_type.unwrap_or_else(|| "text/plain".to_owned()),
281 };
282
283 Ok(message)
284 }
285}
286
#[cfg(all(test, not(target_family = "wasm")))]
mod test {
    use matrix_sdk_test::async_test;
    use serde_json::json;
    use similar_asserts::assert_eq;
    use wiremock::{
        matchers::{header, method, path},
        Mock, MockServer, ResponseTemplate,
    };

    use super::*;
    use crate::config::RequestConfig;

    /// Mount a mock for the MSC4108 "create rendezvous session" endpoint. The
    /// mock hands out `rendezvous_url` with ETag `1`.
    async fn mock_rendezvous_create(server: &MockServer, rendezvous_url: &Url) {
        server
            .register(
                Mock::given(method("POST"))
                    .and(path("/_matrix/client/unstable/org.matrix.msc4108/rendezvous"))
                    .respond_with(
                        ResponseTemplate::new(200)
                            .append_header("X-Max-Bytes", "10240")
                            .append_header("ETag", "1")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_json(json!({
                                "url": rendezvous_url,
                            })),
                    ),
            )
            .await;
    }

    /// Full round trip: Alice creates a channel, Bob joins it, Bob sends, and
    /// Alice receives.
    #[async_test]
    async fn test_creation() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");

        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        assert_eq!(
            alice.rendezvous_url(),
            &rendezvous_url,
            "Alice should have configured the rendezvous URL correctly."
        );

        assert_eq!(alice.etag, "1", "Alice should have remembered the ETAG the server gave us.");

        let mut bob = {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET")).and(path("/abcdEFG12345")).respond_with(
                        ResponseTemplate::new(200)
                            .append_header("Content-Type", "text/plain")
                            .append_header("ETag", "2")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                    ),
                )
                .await;

            let client = HttpClient::new(reqwest::Client::new(), RequestConfig::short_retry());
            let InboundChannelCreationResult { channel: bob, initial_message: _ } =
                RendezvousChannel::create_inbound(client, &rendezvous_url).await.expect(
                    "We should be able to create a rendezvous channel from a received message",
                );

            assert_eq!(alice.rendezvous_url(), bob.rendezvous_url());

            bob
        };

        assert_eq!(bob.etag, "2", "Bob should have remembered the ETAG the server gave us.");

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(304)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                        ),
                )
                .await;

            let response = alice
                .receive_single_message()
                .await
                .expect("We should be able to wait for data on the rendezvous channel.");
            assert_eq!(response.status_code, StatusCode::NOT_MODIFIED);
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("PUT"))
                        .and(path("/abcdEFG12345"))
                        .and(header("Content-Type", "text/plain"))
                        .respond_with(
                            ResponseTemplate::new(200)
                                .append_header("ETag", "1")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                        ),
                )
                .await;

            bob.send(b"Hello world".to_vec())
                .await
                .expect("We should be able to send data to the rendezvous server.");
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(200)
                                .append_header("ETag", "3")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_string("Hello world"),
                        ),
                )
                .await;

            let response = alice
                .receive_single_message()
                .await
                .expect("We should be able to wait and get data on the rendezvous channel.");

            assert_eq!(response.status_code, StatusCode::OK);
            assert_eq!(response.body, b"Hello world");
            assert_eq!(response.content_type, TEXT_PLAIN_CONTENT_TYPE);
        }
    }

    /// `receive` keeps polling through a `304 Not Modified` and returns the
    /// payload of the subsequent `200 OK`.
    #[async_test]
    async fn test_retry_mechanism() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "1"))
                    .respond_with(
                        ResponseTemplate::new(304)
                            .append_header("ETag", "2")
                            .append_header("Content-Type", "text/plain")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_string(""),
                    )
                    .expect(1),
            )
            .await;

        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "2"))
                    .respond_with(
                        ResponseTemplate::new(200)
                            .append_header("ETag", "3")
                            .append_header("Content-Type", "text/plain")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_string("Hello world"),
                    )
                    .expect(1),
            )
            .await;

        let response = alice
            .receive()
            .await
            .expect("We should be able to wait and get data on the rendezvous channel.");

        assert_eq!(response, b"Hello world");
    }

    /// `receive` surfaces client and server error statuses as errors instead of
    /// polling forever.
    #[async_test]
    async fn test_receive_error() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = RendezvousChannel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(404)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_string(""),
                        )
                        .expect(1),
                )
                .await;

            alice.receive().await.expect_err("We should return an error if we receive a 404");
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(504)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_json(json!({
                                    "errcode": "M_NOT_FOUND",
                                    "error": "No resource was found for this request.",
                                })),
                        )
                        .expect(1),
                )
                .await;

            alice
                .receive()
                .await
                .expect_err("We should return an error if we receive a gateway timeout");
        }
    }
}
560}