1use std::time::Duration;
16
17use http::{
18 HeaderMap, HeaderName, Method, StatusCode,
19 header::{CONTENT_TYPE, ETAG, EXPIRES, IF_MATCH, IF_NONE_MATCH, LAST_MODIFIED},
20};
21use matrix_sdk_base::sleep;
22use ruma::api::{
23 EndpointError,
24 error::{FromHttpResponseError, HeaderDeserializationError, IntoHttpError},
25};
26use tracing::{debug, instrument, trace};
27use url::Url;
28
29use crate::{HttpError, RumaApiError, http_client::HttpClient};
30
/// Content type attached to every message written to the rendezvous channel,
/// and the only content type accepted when reading one.
const TEXT_PLAIN_CONTENT_TYPE: &str = "text/plain";
/// Delay between two polls of the rendezvous channel while the server answers
/// `304 Not Modified`. Kept very short in tests so they run quickly.
#[cfg(test)]
const POLL_TIMEOUT: Duration = Duration::from_millis(10);
#[cfg(not(test))]
const POLL_TIMEOUT: Duration = Duration::from_secs(1);

/// An `ETag` header value as received from the rendezvous server.
type Etag = String;
38
39fn get_header(
41 header_map: &HeaderMap,
42 header_name: &HeaderName,
43) -> Result<String, Box<FromHttpResponseError<RumaApiError>>> {
44 let header = header_map
45 .get(header_name)
46 .ok_or(HeaderDeserializationError::MissingHeader(ETAG.to_string()))
47 .map_err(|error| Box::new(FromHttpResponseError::from(error)))?;
48
49 let header =
50 header.to_str().map_err(|error| Box::new(FromHttpResponseError::from(error)))?.to_owned();
51
52 Ok(header)
53}
54
55fn response_to_error(status: StatusCode, body: Vec<u8>) -> HttpError {
56 match http::Response::builder().status(status).body(body).map_err(IntoHttpError::from) {
57 Ok(response) => {
58 let error = FromHttpResponseError::<RumaApiError>::Server(RumaApiError::ClientApi(
59 ruma::api::error::Error::from_http_response(response),
60 ));
61
62 error.into()
63 }
64 Err(e) => HttpError::IntoHttp(e),
65 }
66}
67
/// Result of [`Channel::create_inbound`]: the established channel together
/// with the body of the first `GET` performed on the rendezvous URL.
pub(super) struct InboundChannelCreationResult {
    /// The newly established rendezvous channel.
    pub channel: Channel,
    /// Body of the initial `GET` response on the rendezvous URL.
    // NOTE(review): the `dead_code` allowance suggests no caller reads this yet.
    #[allow(dead_code)]
    pub initial_message: Vec<u8>,
}
79
/// The relevant parts of a response to a `GET` on the rendezvous URL.
pub(super) struct RendezvousGetResponse {
    /// HTTP status of the response, e.g. `200 OK` or `304 Not Modified`.
    pub status_code: StatusCode,
    /// Value of the `ETag` header.
    pub etag: String,
    /// Value of the `Expires` header.
    #[allow(dead_code)]
    pub expires: String,
    /// Value of the `Last-Modified` header.
    #[allow(dead_code)]
    pub last_modified: String,
    /// Value of the `Content-Type` header, if the server sent one.
    pub content_type: Option<String>,
    /// The raw response body.
    pub body: Vec<u8>,
}
90
/// A single message read from the rendezvous channel.
pub(super) struct RendezvousMessage {
    /// HTTP status of the `GET` that produced this message.
    pub status_code: StatusCode,
    /// The raw message body.
    pub body: Vec<u8>,
    /// Content type of the body; `text/plain` if the server did not send one.
    pub content_type: String,
}
96
/// A channel for exchanging messages through an MSC4108 HTTP rendezvous
/// session.
pub struct Channel {
    /// HTTP client used for every request to the rendezvous server.
    client: HttpClient,
    /// URL of the rendezvous session that messages are read from and written to.
    rendezvous_url: Url,
    /// The last ETag received from the server. It is sent as `If-Match` when
    /// writing and as `If-None-Match` when polling.
    etag: Etag,
}
102
103impl Channel {
104 pub(super) async fn create_outbound(
110 client: HttpClient,
111 rendezvous_server: &Url,
112 ) -> Result<Self, HttpError> {
113 use std::borrow::Cow;
114
115 use ruma::api::{SupportedVersions, client::rendezvous::create_rendezvous_session};
116
117 let request = create_rendezvous_session::unstable_msc4108::Request::default();
118 let response = client
119 .send(
120 request,
121 None,
122 rendezvous_server.to_string(),
123 None,
124 Cow::Owned(SupportedVersions {
125 versions: Default::default(),
126 features: Default::default(),
127 }),
128 Default::default(),
129 )
130 .await?;
131
132 let rendezvous_url = response.url;
133 let etag = response.etag;
134
135 Ok(Self { client, rendezvous_url, etag })
136 }
137
138 pub(super) async fn create_inbound(
143 client: HttpClient,
144 rendezvous_url: &Url,
145 ) -> Result<InboundChannelCreationResult, HttpError> {
146 let response = Self::receive_message_impl(&client.inner, None, rendezvous_url).await?;
149
150 let etag = response.etag.clone();
151
152 let initial_message = RendezvousMessage {
153 status_code: response.status_code,
154 body: response.body,
155 content_type: response.content_type.unwrap_or_else(|| "text/plain".to_owned()),
156 };
157
158 let channel = Self { client, rendezvous_url: rendezvous_url.clone(), etag };
159
160 Ok(InboundChannelCreationResult { channel, initial_message: initial_message.body })
161 }
162
163 pub(super) fn rendezvous_url(&self) -> &Url {
166 &self.rendezvous_url
167 }
168
169 #[instrument(skip_all)]
174 pub(super) async fn send(&mut self, message: Vec<u8>) -> Result<(), HttpError> {
175 let etag = self.etag.clone();
176
177 let request = self
178 .client
179 .inner
180 .request(Method::PUT, self.rendezvous_url().to_owned())
181 .body(message)
182 .header(IF_MATCH, etag)
183 .header(CONTENT_TYPE, TEXT_PLAIN_CONTENT_TYPE);
184
185 debug!("Sending a request to the rendezvous channel {request:?}");
186
187 let response = request.send().await?;
188 let status = response.status();
189
190 debug!("Response for the rendezvous sending request {response:?}");
191
192 if status.is_success() {
193 let etag = get_header(response.headers(), &ETAG)?;
196 self.etag = etag;
197
198 Ok(())
199 } else {
200 let body = response.bytes().await?;
201 let error = response_to_error(status, body.to_vec());
202
203 return Err(error);
204 }
205 }
206
207 pub(super) async fn receive(&mut self) -> Result<Vec<u8>, HttpError> {
216 loop {
217 let message = self.receive_single_message().await?;
218
219 trace!(
220 status_code = %message.status_code,
221 "Received data from the rendezvous channel"
222 );
223
224 if message.status_code == StatusCode::OK
225 && message.content_type == TEXT_PLAIN_CONTENT_TYPE
226 && !message.body.is_empty()
227 {
228 return Ok(message.body);
229 } else if message.status_code == StatusCode::NOT_MODIFIED {
230 sleep::sleep(POLL_TIMEOUT).await;
231 continue;
232 } else {
233 let error = response_to_error(message.status_code, message.body);
234
235 return Err(error);
236 }
237 }
238 }
239
240 #[instrument]
241 async fn receive_message_impl(
242 client: &reqwest::Client,
243 etag: Option<String>,
244 rendezvous_url: &Url,
245 ) -> Result<RendezvousGetResponse, HttpError> {
246 let mut builder = client.request(Method::GET, rendezvous_url.to_owned());
247
248 if let Some(etag) = etag {
249 builder = builder.header(IF_NONE_MATCH, etag);
250 }
251
252 let response = builder.send().await?;
253
254 debug!("Received data from the rendezvous channel {response:?}");
255
256 let status_code = response.status();
257
258 if status_code.is_client_error() {
259 return Err(response_to_error(status_code, response.bytes().await?.to_vec()));
260 }
261
262 let headers = response.headers();
263 let etag = get_header(headers, &ETAG)?;
264 let expires = get_header(headers, &EXPIRES)?;
265 let last_modified = get_header(headers, &LAST_MODIFIED)?;
266 #[allow(clippy::result_large_err)]
267 let content_type = headers
268 .get(CONTENT_TYPE)
269 .map(|c| c.to_str().map_err(FromHttpResponseError::<RumaApiError>::from))
270 .transpose()?
271 .map(ToOwned::to_owned);
272
273 let body = response.bytes().await?.to_vec();
274
275 let response =
276 RendezvousGetResponse { status_code, etag, expires, last_modified, content_type, body };
277
278 Ok(response)
279 }
280
281 async fn receive_single_message(&mut self) -> Result<RendezvousMessage, HttpError> {
282 let etag = Some(self.etag.clone());
283
284 let RendezvousGetResponse { status_code, etag, content_type, body, .. } =
285 Self::receive_message_impl(&self.client.inner, etag, &self.rendezvous_url).await?;
286
287 self.etag = etag;
289
290 let message = RendezvousMessage {
291 status_code,
292 body,
293 content_type: content_type.unwrap_or_else(|| "text/plain".to_owned()),
294 };
295
296 Ok(message)
297 }
298}
299
#[cfg(all(test, not(target_family = "wasm")))]
mod test {
    use matrix_sdk_test::async_test;
    use serde_json::json;
    use similar_asserts::assert_eq;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{header, method, path},
    };

    use super::*;
    use crate::config::RequestConfig;

    /// Mock the MSC4108 rendezvous session creation endpoint so that it hands
    /// out `rendezvous_url` with ETag `1`.
    async fn mock_rendezvous_create(server: &MockServer, rendezvous_url: &Url) {
        server
            .register(
                Mock::given(method("POST"))
                    .and(path("/_matrix/client/unstable/org.matrix.msc4108/rendezvous"))
                    .respond_with(
                        ResponseTemplate::new(200)
                            .append_header("X-Max-Bytes", "10240")
                            .append_header("ETag", "1")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_json(json!({
                                "url": rendezvous_url,
                            })),
                    ),
            )
            .await;
    }

    #[async_test]
    async fn test_creation() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");

        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = Channel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        assert_eq!(
            alice.rendezvous_url(),
            &rendezvous_url,
            "Alice should have configured the rendezvous URL correctly."
        );

        assert_eq!(alice.etag, "1", "Alice should have remembered the ETAG the server gave us.");

        let mut bob = {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET")).and(path("/abcdEFG12345")).respond_with(
                        ResponseTemplate::new(200)
                            .append_header("Content-Type", "text/plain")
                            .append_header("ETag", "2")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                    ),
                )
                .await;

            let client = HttpClient::new(reqwest::Client::new(), RequestConfig::short_retry());
            let InboundChannelCreationResult { channel: bob, initial_message: _ } =
                Channel::create_inbound(client, &rendezvous_url).await.expect(
                    "We should be able to create a rendezvous channel from a received message",
                );

            assert_eq!(alice.rendezvous_url(), bob.rendezvous_url());

            bob
        };

        assert_eq!(bob.etag, "2", "Bob should have remembered the ETAG the server gave us.");

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(304)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                        ),
                )
                .await;

            let response = alice
                .receive_single_message()
                .await
                .expect("We should be able to wait for data on the rendezvous channel.");
            assert_eq!(response.status_code, StatusCode::NOT_MODIFIED);
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("PUT"))
                        .and(path("/abcdEFG12345"))
                        .and(header("Content-Type", "text/plain"))
                        .respond_with(
                            ResponseTemplate::new(200)
                                .append_header("ETag", "1")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT"),
                        ),
                )
                .await;

            bob.send(b"Hello world".to_vec())
                .await
                .expect("We should be able to send data to the rendezvous server.");
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(200)
                                .append_header("ETag", "3")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_string("Hello world"),
                        ),
                )
                .await;

            let response = alice
                .receive_single_message()
                .await
                .expect("We should be able to wait and get data on the rendezvous channel.");

            assert_eq!(response.status_code, StatusCode::OK);
            assert_eq!(response.body, b"Hello world");
            assert_eq!(response.content_type, TEXT_PLAIN_CONTENT_TYPE);
        }
    }

    #[async_test]
    async fn test_retry_mechanism() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = Channel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "1"))
                    .respond_with(
                        ResponseTemplate::new(304)
                            .append_header("ETag", "2")
                            .append_header("Content-Type", "text/plain")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_string(""),
                    )
                    .expect(1),
            )
            .await;

        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "2"))
                    .respond_with(
                        ResponseTemplate::new(200)
                            .append_header("ETag", "3")
                            .append_header("Content-Type", "text/plain")
                            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                            .set_body_string("Hello world"),
                    )
                    .expect(1),
            )
            .await;

        let response = alice
            .receive()
            .await
            .expect("We should be able to wait and get data on the rendezvous channel.");

        assert_eq!(response, b"Hello world");
    }

    #[async_test]
    async fn test_receive_error() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = Channel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(404)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_string(""),
                        )
                        .expect(1),
                )
                .await;

            alice.receive().await.expect_err("We should return an error if we receive a 404");
        }

        {
            let _scope = server
                .register_as_scoped(
                    Mock::given(method("GET"))
                        .and(path("/abcdEFG12345"))
                        .and(header("if-none-match", "1"))
                        .respond_with(
                            ResponseTemplate::new(504)
                                .append_header("ETag", "1")
                                .append_header("Content-Type", "text/plain")
                                .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
                                .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
                                .set_body_json(json!({
                                    "errcode": "M_NOT_FOUND",
                                    "error": "No resource was found for this request.",
                                })),
                        )
                        .expect(1),
                )
                .await;

            alice
                .receive()
                .await
                .expect_err("We should return an error if we receive a gateway timeout");
        }
    }

    /// A bare server error (no rendezvous headers at all, as a reverse proxy
    /// might send) must surface as an error and must not clobber the ETag.
    #[async_test]
    async fn test_receive_server_error_without_headers() {
        let server = MockServer::start().await;
        let url =
            Url::parse(&server.uri()).expect("We should be able to parse the example homeserver");
        let rendezvous_url =
            url.join("abcdEFG12345").expect("We should be able to create a rendezvous URL");
        mock_rendezvous_create(&server, &rendezvous_url).await;

        let client = HttpClient::new(reqwest::Client::new(), RequestConfig::new().disable_retry());

        let mut alice = Channel::create_outbound(client, &url)
            .await
            .expect("We should be able to create an outbound rendezvous channel");

        server
            .register(
                Mock::given(method("GET"))
                    .and(path("/abcdEFG12345"))
                    .and(header("if-none-match", "1"))
                    .respond_with(ResponseTemplate::new(502))
                    .expect(1),
            )
            .await;

        alice
            .receive()
            .await
            .expect_err("We should return an error if we receive a bad gateway response");

        assert_eq!(alice.etag, "1", "A failed poll should not change the remembered ETAG.");
    }
}