matrix_sdk/authentication/oauth/qrcode/
secure_channel.rs1use matrix_sdk_base::crypto::types::qr_login::{QrCodeData, QrCodeMode, QrCodeModeData};
16use serde::{Serialize, de::DeserializeOwned};
17use tracing::{instrument, trace};
18use url::Url;
19use vodozemac::ecies::{
20 CheckCode, Ecies, EstablishedEcies, InboundCreationResult, InitialMessage, Message,
21 OutboundCreationResult,
22};
23
24use super::{
25 SecureChannelError as Error,
26 rendezvous_channel::{InboundChannelCreationResult, RendezvousChannel},
27};
28use crate::{config::RequestConfig, http_client::HttpClient};
29
/// Plaintext that [`EstablishedSecureChannel::from_qr_code`] encrypts into the ECIES initial
/// message when it opens the channel.
///
/// [`SecureChannel::connect`] rejects the channel unless the decrypted initial message equals
/// this value.
const LOGIN_INITIATE_MESSAGE: &str = "MATRIX_QR_CODE_LOGIN_INITIATE";

/// Acknowledgement sent back, encrypted, once a valid [`LOGIN_INITIATE_MESSAGE`] was received.
const LOGIN_OK_MESSAGE: &str = "MATRIX_QR_CODE_LOGIN_OK";
32
33pub(super) struct SecureChannel {
34 channel: RendezvousChannel,
35 qr_code_data: QrCodeData,
36 ecies: Ecies,
37}
38
39impl SecureChannel {
40 pub(super) async fn login(
42 http_client: HttpClient,
43 homeserver_url: &Url,
44 ) -> Result<Self, Error> {
45 let channel = RendezvousChannel::create_outbound(http_client, homeserver_url).await?;
46 let rendezvous_url = channel.rendezvous_url().to_owned();
47 let mode_data = QrCodeModeData::Login;
48
49 let ecies = Ecies::new();
50 let public_key = ecies.public_key();
51
52 let qr_code_data = QrCodeData { public_key, rendezvous_url, mode_data };
53
54 Ok(Self { channel, qr_code_data, ecies })
55 }
56
57 pub(super) async fn reciprocate(
59 http_client: HttpClient,
60 homeserver_url: &Url,
61 ) -> Result<Self, Error> {
62 let mut channel = SecureChannel::login(http_client, homeserver_url).await?;
63 channel.qr_code_data.mode_data =
64 QrCodeModeData::Reciprocate { server_name: homeserver_url.to_string() };
65 Ok(channel)
66 }
67
68 pub(super) fn qr_code_data(&self) -> &QrCodeData {
69 &self.qr_code_data
70 }
71
72 #[instrument(skip(self))]
73 pub(super) async fn connect(mut self) -> Result<AlmostEstablishedSecureChannel, Error> {
74 trace!("Trying to connect the secure channel.");
75
76 let message = self.channel.receive().await?;
77 let message = std::str::from_utf8(&message)?;
78 let message = InitialMessage::decode(message)?;
79
80 let InboundCreationResult { ecies, message } =
81 self.ecies.establish_inbound_channel(&message)?;
82 let message = std::str::from_utf8(&message)?;
83
84 trace!("Received the initial secure channel message");
85
86 if message == LOGIN_INITIATE_MESSAGE {
87 let mut secure_channel = EstablishedSecureChannel { channel: self.channel, ecies };
88
89 trace!("Sending the LOGIN OK message");
90
91 secure_channel.send(LOGIN_OK_MESSAGE).await?;
92
93 Ok(AlmostEstablishedSecureChannel { secure_channel })
94 } else {
95 Err(Error::SecureChannelMessage {
96 expected: LOGIN_INITIATE_MESSAGE,
97 received: message.to_owned(),
98 })
99 }
100 }
101}
102
103pub(super) struct AlmostEstablishedSecureChannel {
106 secure_channel: EstablishedSecureChannel,
107}
108
109impl AlmostEstablishedSecureChannel {
110 pub(super) fn confirm(self, check_code: u8) -> Result<EstablishedSecureChannel, Error> {
115 if check_code == self.secure_channel.check_code().to_digit() {
116 Ok(self.secure_channel)
117 } else {
118 Err(Error::InvalidCheckCode)
119 }
120 }
121}
122
123pub(super) struct EstablishedSecureChannel {
124 channel: RendezvousChannel,
125 ecies: EstablishedEcies,
126}
127
128impl EstablishedSecureChannel {
129 #[instrument(skip(client))]
131 pub(super) async fn from_qr_code(
132 client: reqwest::Client,
133 qr_code_data: &QrCodeData,
134 expected_mode: QrCodeMode,
135 ) -> Result<Self, Error> {
136 if qr_code_data.mode() == expected_mode {
137 Err(Error::InvalidIntent)
138 } else {
139 trace!("Attempting to create a new inbound secure channel from a QR code.");
140
141 let client = HttpClient::new(client, RequestConfig::short_retry());
142 let ecies = Ecies::new();
143
144 let OutboundCreationResult { ecies, message } = ecies.establish_outbound_channel(
149 qr_code_data.public_key,
150 LOGIN_INITIATE_MESSAGE.as_bytes(),
151 )?;
152
153 let InboundChannelCreationResult { mut channel, .. } =
158 RendezvousChannel::create_inbound(client, &qr_code_data.rendezvous_url).await?;
159
160 trace!(
161 "Received the initial message from the rendezvous channel, sending the LOGIN \
162 INITIATE message"
163 );
164
165 let encoded_message = message.encode().as_bytes().to_vec();
168 channel.send(encoded_message).await?;
169
170 trace!("Waiting for the LOGIN OK message");
171
172 let mut ret = Self { channel, ecies };
175 let response = ret.receive().await?;
176
177 trace!("Received the LOGIN OK message, maybe.");
178
179 if response == LOGIN_OK_MESSAGE {
180 Ok(ret)
181 } else {
182 Err(Error::SecureChannelMessage { expected: LOGIN_OK_MESSAGE, received: response })
183 }
184 }
185 }
186
187 pub(super) fn check_code(&self) -> &CheckCode {
191 self.ecies.check_code()
192 }
193
194 pub(super) async fn send_json(&mut self, message: impl Serialize) -> Result<(), Error> {
199 let message = serde_json::to_string(&message)?;
200 self.send(&message).await
201 }
202
203 pub(super) async fn receive_json<D: DeserializeOwned>(&mut self) -> Result<D, Error> {
208 let message = self.receive().await?;
209 Ok(serde_json::from_str(&message)?)
210 }
211
212 async fn send(&mut self, message: &str) -> Result<(), Error> {
213 let message = self.ecies.encrypt(message.as_bytes());
214 let message = message.encode();
215
216 Ok(self.channel.send(message.as_bytes().to_vec()).await?)
217 }
218
219 async fn receive(&mut self) -> Result<String, Error> {
220 let message = self.channel.receive().await?;
221 let ciphertext = std::str::from_utf8(&message)?;
222 let message = Message::decode(ciphertext)?;
223
224 let decrypted = self.ecies.decrypt(&message)?;
225
226 Ok(String::from_utf8(decrypted).map_err(|e| e.utf8_error())?)
227 }
228}
229
#[cfg(all(test, not(target_family = "wasm")))]
pub(super) mod test {
    use std::{
        sync::{
            Arc, Mutex,
            atomic::{AtomicU8, Ordering},
        },
        time::Duration,
    };

    use matrix_sdk_base::crypto::types::qr_login::QrCodeMode;
    use matrix_sdk_common::executor::spawn;
    use matrix_sdk_test::async_test;
    use ruma::time::Instant;
    use serde_json::json;
    use similar_asserts::assert_eq;
    use url::Url;
    use wiremock::{
        Mock, MockGuard, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::{EstablishedSecureChannel, SecureChannel};
    use crate::http_client::HttpClient;

    /// A mocked MSC4108 rendezvous server serving a single session.
    ///
    /// The private fields are never read after construction, hence `dead_code`. They are kept
    /// so the scoped mocks stay registered while this value lives: a scoped mock is removed
    /// when its [`MockGuard`] is dropped.
    #[allow(dead_code)]
    pub struct MockedRendezvousServer {
        pub homeserver_url: Url,
        pub rendezvous_url: Url,
        expiration: Duration,
        content: Arc<Mutex<Option<String>>>,
        created: Arc<Mutex<Option<Instant>>>,
        etag: Arc<AtomicU8>,
        post_guard: MockGuard,
        put_guard: MockGuard,
        get_guard: MockGuard,
    }

    /// The `404 M_NOT_FOUND` response served once the session has expired.
    fn expired_session_response() -> ResponseTemplate {
        ResponseTemplate::new(404).set_body_json(json!({
            "errcode": "M_NOT_FOUND",
            "error": "This rendezvous session does not exist.",
        }))
    }

    /// Append the fixed `Expires` and `Last-Modified` headers every non-404 response carries.
    fn with_fixed_timestamps(response: ResponseTemplate) -> ResponseTemplate {
        response
            .append_header("Expires", "Wed, 07 Sep 2022 14:28:51 GMT")
            .append_header("Last-Modified", "Wed, 07 Sep 2022 14:27:51 GMT")
    }

    /// Whether more than `expiration` has elapsed since `created`.
    ///
    /// Panics if `created` was never set; `MockedRendezvousServer::new` sets it before
    /// registering the mocks that call this.
    fn has_expired(created: &Mutex<Option<Instant>>, expiration: Duration) -> bool {
        created.lock().unwrap().unwrap().elapsed() > expiration
    }

    impl MockedRendezvousServer {
        /// Register the rendezvous endpoints on `server`.
        ///
        /// `location` is resolved against the mock server's URI to build the session URL. The
        /// `PUT`/`GET` session mocks are mounted on that URL's path. Once more than
        /// `expiration` has elapsed, both answer with `404 M_NOT_FOUND`.
        pub async fn new(server: &MockServer, location: &str, expiration: Duration) -> Self {
            let content: Arc<Mutex<Option<String>>> = Mutex::default().into();
            let created: Arc<Mutex<Option<Instant>>> = Mutex::default().into();
            let etag = Arc::new(AtomicU8::new(0));

            let homeserver_url = Url::parse(&server.uri())
                .expect("We should be able to parse the example homeserver");

            let rendezvous_url = homeserver_url
                .join(location)
                .expect("We should be able to create a rendezvous URL");

            // Mount the session mocks on the path `location` actually resolves to, so callers
            // passing another location still hit them (it was hard-coded to `/abcdEFG12345`).
            let session_path = rendezvous_url.path().to_owned();

            // NOTE(review): this runs once while the mocks are being built, so expiration is
            // measured from registration time, not from when the POST request is served.
            *created.lock().unwrap() = Some(Instant::now());

            let post_guard = server
                .register_as_scoped(
                    Mock::given(method("POST"))
                        .and(path("/_matrix/client/unstable/org.matrix.msc4108/rendezvous"))
                        .respond_with(
                            with_fixed_timestamps(
                                ResponseTemplate::new(200)
                                    .append_header("X-Max-Bytes", "10240")
                                    .append_header("ETag", "1"),
                            )
                            .set_body_json(json!({
                                "url": rendezvous_url,
                            })),
                        ),
                )
                .await;

            let put_guard = server
                .register_as_scoped(
                    Mock::given(method("PUT")).and(path(session_path.as_str())).respond_with({
                        let content = content.clone();
                        let created = created.clone();
                        let etag = etag.clone();

                        move |request: &wiremock::Request| {
                            if has_expired(&created, expiration) {
                                return expired_session_response();
                            }

                            let body = String::from_utf8(request.body.clone()).unwrap();
                            *content.lock().unwrap() = Some(body);

                            // Each upload bumps the counter; the ETag returned is the counter's
                            // new value plus one.
                            let previous_etag = etag.fetch_add(1, Ordering::SeqCst);

                            with_fixed_timestamps(
                                ResponseTemplate::new(200)
                                    .append_header("ETag", (previous_etag + 2).to_string()),
                            )
                        }
                    }),
                )
                .await;

            let get_guard = server
                .register_as_scoped(
                    Mock::given(method("GET")).and(path(session_path)).respond_with({
                        let content = content.clone();
                        let created = created.clone();
                        let etag = etag.clone();

                        move |request: &wiremock::Request| {
                            if has_expired(&created, expiration) {
                                return expired_session_response();
                            }

                            let requested_etag =
                                request.headers.get("if-none-match").map(|value| {
                                    std::str::from_utf8(value.as_bytes())
                                        .unwrap()
                                        .parse::<u8>()
                                        .unwrap()
                                });

                            // The content lock is taken before the counter is read and held
                            // for the rest of the handler.
                            let mut stored = content.lock().unwrap();
                            let current_etag = etag.load(Ordering::SeqCst);

                            match requested_etag {
                                // A different ETag than the current one: report "not modified"
                                // and echo the requested ETag back.
                                Some(other) if other != current_etag => with_fixed_timestamps(
                                    ResponseTemplate::new(304)
                                        .append_header("ETag", other.to_string()),
                                ),
                                // No ETag, or the current one: hand out and clear the content.
                                _ => with_fixed_timestamps(
                                    ResponseTemplate::new(200)
                                        .append_header("ETag", current_etag.to_string()),
                                )
                                .set_body_string(stored.take().unwrap_or_default()),
                            }
                        }
                    }),
                )
                .await;

            Self {
                homeserver_url,
                rendezvous_url,
                expiration,
                content,
                created,
                etag,
                post_guard,
                put_guard,
                get_guard,
            }
        }
    }

    #[async_test]
    async fn test_creation() {
        let server = MockServer::start().await;
        let rendezvous_server =
            MockedRendezvousServer::new(&server, "abcdEFG12345", Duration::MAX).await;

        let client = HttpClient::new(reqwest::Client::new(), Default::default());
        let alice = SecureChannel::reciprocate(client, &rendezvous_server.homeserver_url)
            .await
            .expect("Alice should be able to create a secure channel.");

        // Bob uses Alice's QR code data with `Login`, which must differ from the `Reciprocate`
        // mode that Alice's QR code carries.
        let qr_code_data = alice.qr_code_data().clone();

        let bob_task = spawn(async move {
            EstablishedSecureChannel::from_qr_code(
                reqwest::Client::new(),
                &qr_code_data,
                QrCodeMode::Login,
            )
            .await
            .expect("Bob should be able to fully establish the secure channel.")
        });

        let alice_task = spawn(async move {
            alice
                .connect()
                .await
                .expect("Alice should be able to connect the established secure channel")
        });

        let bob = bob_task.await.unwrap();
        let alice = alice_task.await.unwrap();

        // Both sides must derive the same check code from the handshake.
        assert_eq!(alice.secure_channel.check_code(), bob.check_code());

        let alice = alice
            .confirm(bob.check_code().to_digit())
            .expect("Alice should be able to confirm the established secure channel.");

        assert_eq!(bob.channel.rendezvous_url(), alice.channel.rendezvous_url());
    }
}