matrix_sdk/authentication/oauth/qrcode/
login.rs1use std::future::IntoFuture;
16
17use eyeball::SharedObservable;
18use futures_core::Stream;
19use matrix_sdk_base::{
20 boxed_into_future,
21 crypto::types::qr_login::{QrCodeData, QrCodeMode},
22 store::RoomLoadSettings,
23 SessionMeta,
24};
25use oauth2::{DeviceCodeErrorResponseType, StandardDeviceAuthorizationResponse};
26use ruma::{
27 api::client::discovery::get_authorization_server_metadata::msc2965::AuthorizationServerMetadata,
28 OwnedDeviceId,
29};
30use tracing::trace;
31use vodozemac::{ecies::CheckCode, Curve25519PublicKey};
32
33use super::{
34 messages::{LoginFailureReason, QrAuthMessage},
35 secure_channel::EstablishedSecureChannel,
36 DeviceAuthorizationOAuthError, QRCodeLoginError, SecureChannelError,
37};
38#[cfg(doc)]
39use crate::authentication::oauth::OAuth;
40use crate::{
41 authentication::oauth::{ClientRegistrationData, OAuthError},
42 Client,
43};
44
45async fn send_unexpected_message_error(
46 channel: &mut EstablishedSecureChannel,
47) -> Result<(), SecureChannelError> {
48 channel
49 .send_json(QrAuthMessage::LoginFailure {
50 reason: LoginFailureReason::UnexpectedMessageReceived,
51 homeserver: None,
52 })
53 .await
54}
55
56#[derive(Clone, Debug, Default)]
58pub enum LoginProgress {
59 #[default]
61 Starting,
62 EstablishingSecureChannel {
66 check_code: CheckCode,
68 },
69 WaitingForToken {
73 user_code: String,
77 },
78 Done,
80}
81
82#[derive(Debug)]
84pub struct LoginWithQrCode<'a> {
85 client: &'a Client,
86 registration_data: Option<&'a ClientRegistrationData>,
87 qr_code_data: &'a QrCodeData,
88 state: SharedObservable<LoginProgress>,
89}
90
91impl LoginWithQrCode<'_> {
92 pub fn subscribe_to_progress(&self) -> impl Stream<Item = LoginProgress> {
98 self.state.subscribe()
99 }
100}
101
102impl<'a> IntoFuture for LoginWithQrCode<'a> {
103 type Output = Result<(), QRCodeLoginError>;
104 boxed_into_future!(extra_bounds: 'a);
105
106 fn into_future(self) -> Self::IntoFuture {
107 Box::pin(async move {
108 let mut channel = self.establish_secure_channel().await?;
112
113 trace!("Established the secure channel.");
114
115 let check_code = channel.check_code().to_owned();
118 self.state.set(LoginProgress::EstablishingSecureChannel { check_code });
119
120 trace!("Registering the client with the OAuth 2.0 authorization server.");
122 let server_metadata = self.register_client().await?;
123
124 let account = vodozemac::olm::Account::new();
127 let public_key = account.identity_keys().curve25519;
128 let device_id = public_key;
129
130 trace!("Requesting device authorization.");
133 let auth_grant_response =
134 self.request_device_authorization(&server_metadata, device_id).await?;
135
136 trace!("Letting the existing device know about the device authorization grant.");
139 let message = QrAuthMessage::authorization_grant_login_protocol(
140 (&auth_grant_response).into(),
141 device_id,
142 );
143 channel.send_json(&message).await?;
144
145 match channel.receive_json().await? {
147 QrAuthMessage::LoginProtocolAccepted => (),
148 QrAuthMessage::LoginFailure { reason, homeserver } => {
149 return Err(QRCodeLoginError::LoginFailure { reason, homeserver });
150 }
151 message => {
152 send_unexpected_message_error(&mut channel).await?;
153
154 return Err(QRCodeLoginError::UnexpectedMessage {
155 expected: "m.login.protocol_accepted",
156 received: message,
157 });
158 }
159 }
160
161 let user_code = auth_grant_response.user_code();
165 self.state
166 .set(LoginProgress::WaitingForToken { user_code: user_code.secret().to_owned() });
167
168 trace!("Waiting for the OAuth 2.0 authorization server to give us the access token.");
171 if let Err(e) = self.wait_for_tokens(&server_metadata, &auth_grant_response).await {
172 if let Some(e) = e.as_request_token_error() {
175 match e {
176 DeviceCodeErrorResponseType::AccessDenied => {
177 channel.send_json(QrAuthMessage::LoginDeclined).await?;
178 }
179 DeviceCodeErrorResponseType::ExpiredToken => {
180 channel
181 .send_json(QrAuthMessage::LoginFailure {
182 reason: LoginFailureReason::AuthorizationExpired,
183 homeserver: None,
184 })
185 .await?;
186 }
187 _ => (),
188 }
189 }
190
191 return Err(e.into());
192 };
193
194 trace!("Discovering our own user id.");
200 let whoami_response =
201 self.client.whoami().await.map_err(QRCodeLoginError::UserIdDiscovery)?;
202 self.client
203 .base_client()
204 .activate(
205 SessionMeta {
206 user_id: whoami_response.user_id,
207 device_id: OwnedDeviceId::from(device_id.to_base64()),
208 },
209 RoomLoadSettings::default(),
210 Some(account),
211 )
212 .await
213 .map_err(|error| QRCodeLoginError::SessionTokens(error.into()))?;
214
215 self.client.oauth().enable_cross_process_lock().await?;
216
217 trace!("Telling the existing device that we successfully logged in.");
219 let message = QrAuthMessage::LoginSuccess;
220 channel.send_json(&message).await?;
221
222 trace!("Waiting for the secrets bundle.");
225 let bundle = match channel.receive_json().await? {
226 QrAuthMessage::LoginSecrets(bundle) => bundle,
227 QrAuthMessage::LoginFailure { reason, homeserver } => {
228 return Err(QRCodeLoginError::LoginFailure { reason, homeserver });
229 }
230 message => {
231 send_unexpected_message_error(&mut channel).await?;
232
233 return Err(QRCodeLoginError::UnexpectedMessage {
234 expected: "m.login.secrets",
235 received: message,
236 });
237 }
238 };
239
240 self.client.encryption().import_secrets_bundle(&bundle).await?;
243
244 self.client
247 .encryption()
248 .ensure_device_keys_upload()
249 .await
250 .map_err(QRCodeLoginError::DeviceKeyUpload)?;
251
252 self.client.encryption().spawn_initialization_task(None);
257 self.client.encryption().wait_for_e2ee_initialization_tasks().await;
258
259 trace!("successfully logged in and enabled E2EE.");
260
261 self.state.set(LoginProgress::Done);
263
264 Ok(())
266 })
267 }
268}
269
270impl<'a> LoginWithQrCode<'a> {
271 pub(crate) fn new(
272 client: &'a Client,
273 qr_code_data: &'a QrCodeData,
274 registration_data: Option<&'a ClientRegistrationData>,
275 ) -> LoginWithQrCode<'a> {
276 LoginWithQrCode { client, registration_data, qr_code_data, state: Default::default() }
277 }
278
279 async fn establish_secure_channel(
280 &self,
281 ) -> Result<EstablishedSecureChannel, SecureChannelError> {
282 let http_client = self.client.inner.http_client.inner.clone();
283
284 let channel = EstablishedSecureChannel::from_qr_code(
285 http_client,
286 self.qr_code_data,
287 QrCodeMode::Login,
288 )
289 .await?;
290
291 Ok(channel)
292 }
293
294 async fn register_client(
298 &self,
299 ) -> Result<AuthorizationServerMetadata, DeviceAuthorizationOAuthError> {
300 let oauth = self.client.oauth();
301 let server_metadata = oauth.server_metadata().await.map_err(OAuthError::from)?;
302 oauth.use_registration_data(&server_metadata, self.registration_data).await?;
303
304 Ok(server_metadata)
305 }
306
307 async fn request_device_authorization(
308 &self,
309 server_metadata: &AuthorizationServerMetadata,
310 device_id: Curve25519PublicKey,
311 ) -> Result<StandardDeviceAuthorizationResponse, DeviceAuthorizationOAuthError> {
312 let oauth = self.client.oauth();
313 let response = oauth
314 .request_device_authorization(server_metadata, Some(device_id.to_base64().into()))
315 .await?;
316 Ok(response)
317 }
318
319 async fn wait_for_tokens(
320 &self,
321 server_metadata: &AuthorizationServerMetadata,
322 auth_response: &StandardDeviceAuthorizationResponse,
323 ) -> Result<(), DeviceAuthorizationOAuthError> {
324 let oauth = self.client.oauth();
325 oauth.exchange_device_code(server_metadata, auth_response).await?;
326 Ok(())
327 }
328}
329
#[cfg(all(test, not(target_family = "wasm")))]
mod test {
    use assert_matches2::{assert_let, assert_matches};
    use futures_util::{join, StreamExt};
    use matrix_sdk_base::crypto::types::{qr_login::QrCodeModeData, SecretsBundle};
    use matrix_sdk_common::executor::spawn;
    use matrix_sdk_test::async_test;
    use serde_json::json;

    use super::*;
    use crate::{
        authentication::oauth::qrcode::{
            messages::LoginProtocolType,
            secure_channel::{test::MockedRendezvousServer, SecureChannel},
        },
        config::RequestConfig,
        http_client::HttpClient,
        test_utils::{client::oauth::mock_client_metadata, mocks::MatrixMockServer},
    };

    /// How the existing device (Alice) behaves while granting the login.
    enum AliceBehaviour {
        /// Accept the protocol and send the secrets bundle.
        HappyPath,
        /// Refuse the proposed login protocol as unsupported.
        DeclinedProtocol,
        /// Answer the protocol proposal with an unexpected message.
        UnexpectedMessage,
        /// Send an unexpected message instead of the secrets bundle.
        UnexpectedMessageInsteadOfSecrets,
        /// Refuse to send the secrets bundle.
        RefuseSecrets,
    }

    /// How the mocked OAuth 2.0 token endpoint answers.
    enum TokenResponse {
        /// Hand out the tokens.
        Ok,
        /// Reply with an `access_denied` error.
        AccessDenied,
        /// Reply with an `expired_token` error.
        ExpiredToken,
    }

    /// A secrets bundle containing only cross-signing keys.
    fn secrets_bundle() -> SecretsBundle {
        let json = json!({
            "cross_signing": {
                "master_key": "rTtSv67XGS6k/rg6/yTG/m573cyFTPFRqluFhQY+hSw",
                "self_signing_key": "4jbPt7jh5D2iyM4U+3IDa+WthgJB87IQN1ATdkau+xk",
                "user_signing_key": "YkFKtkjcsTxF6UAzIIG/l6Nog/G2RigCRfWj3cjNWeM",
            },
        });

        serde_json::from_value(json).expect("We should be able to deserialize a secrets bundle")
    }

    /// Play the existing device (Alice) that grants the login to Bob.
    ///
    /// Alice connects the secure channel, confirms it with the check code Bob
    /// published, and then answers Bob's messages according to `behavior`.
    async fn grant_login(
        alice: SecureChannel,
        check_code_receiver: tokio::sync::oneshot::Receiver<CheckCode>,
        behavior: AliceBehaviour,
    ) {
        let alice = alice.connect().await.expect("Alice should be able to connect the channel");

        let check_code =
            check_code_receiver.await.expect("We should receive the check code from bob");

        let mut alice = alice
            .confirm(check_code.to_digit())
            .expect("Alice should be able to confirm the secure channel");

        let message = alice
            .receive_json()
            .await
            .expect("Alice should be able to receive the initial message from Bob");

        assert_let!(QrAuthMessage::LoginProtocol { protocol, .. } = message);
        assert_eq!(protocol, LoginProtocolType::DeviceAuthorizationGrant);

        let message = match behavior {
            AliceBehaviour::DeclinedProtocol => QrAuthMessage::LoginFailure {
                reason: LoginFailureReason::UnsupportedProtocol,
                homeserver: None,
            },
            AliceBehaviour::UnexpectedMessage => QrAuthMessage::LoginDeclined,
            _ => QrAuthMessage::LoginProtocolAccepted,
        };

        alice.send_json(message).await.unwrap();

        let message: QrAuthMessage = alice.receive_json().await.unwrap();
        assert_let!(QrAuthMessage::LoginSuccess = message);

        let message = match behavior {
            AliceBehaviour::UnexpectedMessageInsteadOfSecrets => QrAuthMessage::LoginDeclined,
            AliceBehaviour::RefuseSecrets => QrAuthMessage::LoginFailure {
                reason: LoginFailureReason::DeviceNotFound,
                homeserver: None,
            },
            _ => QrAuthMessage::LoginSecrets(secrets_bundle()),
        };

        alice.send_json(message).await.unwrap();
    }

    #[async_test]
    async fn test_qr_login() {
        let server = MatrixMockServer::new().await;
        let rendezvous_server = MockedRendezvousServer::new(server.server(), "abcdEFG12345").await;
        let (sender, receiver) = tokio::sync::oneshot::channel();

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
        oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
        oauth_server
            .mock_device_authorization()
            .ok()
            .expect(1)
            .named("device_authorization")
            .mount()
            .await;
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        server.mock_versions().ok().expect(1..).named("versions").mount().await;
        server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;
        server.mock_upload_keys().ok().expect(1).named("upload_keys").mount().await;
        server.mock_query_keys().ok().expect(1).named("query_keys").mount().await;

        let client = HttpClient::new(reqwest::Client::new(), Default::default());
        let alice = SecureChannel::new(client, &rendezvous_server.homeserver_url)
            .await
            .expect("Alice should be able to create a secure channel.");

        assert_let!(QrCodeModeData::Reciprocate { server_name } = &alice.qr_code_data().mode_data);

        let bob = Client::builder()
            .server_name_or_homeserver_url(server_name)
            .request_config(RequestConfig::new().disable_retry())
            .build()
            .await
            .expect("We should be able to build the Client object from the URL in the QR code");

        let qr_code = alice.qr_code_data().clone();

        let oauth = bob.oauth();
        let registration_data = mock_client_metadata().into();
        let login_bob = oauth.login_with_qr_code(&qr_code, Some(&registration_data));
        let mut updates = login_bob.subscribe_to_progress();

        let updates_task = spawn(async move {
            let mut sender = Some(sender);

            while let Some(update) = updates.next().await {
                match update {
                    LoginProgress::EstablishingSecureChannel { check_code } => {
                        sender
                            .take()
                            .expect("The establishing secure channel update should be received only once")
                            .send(check_code)
                            .expect("Bob should be able to send the check code to Alice");
                    }
                    LoginProgress::Done => break,
                    _ => (),
                }
            }
        });
        let alice_task =
            spawn(async { grant_login(alice, receiver, AliceBehaviour::HappyPath).await });

        join!(
            async {
                login_bob.await.expect("Bob should be able to login");
            },
            async {
                alice_task.await.expect("Alice should have completed its task successfully");
            },
            async { updates_task.await.unwrap() }
        );

        assert!(bob.encryption().cross_signing_status().await.unwrap().is_complete());
        let own_identity =
            bob.encryption().get_user_identity(bob.user_id().unwrap()).await.unwrap().unwrap();

        assert!(own_identity.is_verified());
    }

    /// Run a QR code login in which the token endpoint answers with
    /// `token_response` and Alice behaves according to `alice_behavior`.
    /// The result of Bob's login is returned.
    async fn test_failure(
        token_response: TokenResponse,
        alice_behavior: AliceBehaviour,
    ) -> Result<(), QRCodeLoginError> {
        let server = MatrixMockServer::new().await;
        let rendezvous_server = MockedRendezvousServer::new(server.server(), "abcdEFG12345").await;
        let (sender, receiver) = tokio::sync::oneshot::channel();

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
        oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
        oauth_server
            .mock_device_authorization()
            .ok()
            .expect(1)
            .named("device_authorization")
            .mount()
            .await;

        let token_mock = oauth_server.mock_token();
        let token_mock = match token_response {
            TokenResponse::Ok => token_mock.ok(),
            TokenResponse::AccessDenied => token_mock.access_denied(),
            TokenResponse::ExpiredToken => token_mock.expired_token(),
        };
        token_mock.named("token").mount().await;

        server.mock_versions().ok().named("versions").mount().await;
        server.mock_who_am_i().ok().named("whoami").mount().await;

        let client = HttpClient::new(reqwest::Client::new(), Default::default());
        let alice = SecureChannel::new(client, &rendezvous_server.homeserver_url)
            .await
            .expect("Alice should be able to create a secure channel.");

        assert_let!(QrCodeModeData::Reciprocate { server_name } = &alice.qr_code_data().mode_data);

        let bob = Client::builder()
            .server_name_or_homeserver_url(server_name)
            .request_config(RequestConfig::new().disable_retry())
            .build()
            .await
            .expect("We should be able to build the Client object from the URL in the QR code");

        let qr_code = alice.qr_code_data().clone();

        let oauth = bob.oauth();
        let registration_data = mock_client_metadata().into();
        let login_bob = oauth.login_with_qr_code(&qr_code, Some(&registration_data));
        let mut updates = login_bob.subscribe_to_progress();

        let _updates_task = spawn(async move {
            let mut sender = Some(sender);

            while let Some(update) = updates.next().await {
                match update {
                    LoginProgress::EstablishingSecureChannel { check_code } => {
                        sender
                            .take()
                            .expect("The establishing secure channel update should be received only once")
                            .send(check_code)
                            .expect("Bob should be able to send the check code to Alice");
                    }
                    LoginProgress::Done => break,
                    _ => (),
                }
            }
        });
        let _alice_task = spawn(async move { grant_login(alice, receiver, alice_behavior).await });
        login_bob.await
    }

    #[async_test]
    async fn test_qr_login_refused_access_token() {
        let result = test_failure(TokenResponse::AccessDenied, AliceBehaviour::HappyPath).await;

        assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
        assert_eq!(
            e.as_request_token_error(),
            Some(&DeviceCodeErrorResponseType::AccessDenied),
            "The server should have told us that access has been denied."
        );
    }

    #[async_test]
    async fn test_qr_login_expired_token() {
        let result = test_failure(TokenResponse::ExpiredToken, AliceBehaviour::HappyPath).await;

        assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
        assert_eq!(
            e.as_request_token_error(),
            Some(&DeviceCodeErrorResponseType::ExpiredToken),
            "The server should have told us that the token has expired."
        );
    }

    #[async_test]
    async fn test_qr_login_declined_protocol() {
        let result = test_failure(TokenResponse::Ok, AliceBehaviour::DeclinedProtocol).await;

        assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
        assert_eq!(
            reason,
            LoginFailureReason::UnsupportedProtocol,
            "Alice should have told us that the protocol is unsupported."
        );
    }

    #[async_test]
    async fn test_qr_login_unexpected_message() {
        let result = test_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessage).await;

        assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
        assert_eq!(expected, "m.login.protocol_accepted");
    }

    #[async_test]
    async fn test_qr_login_unexpected_message_instead_of_secrets() {
        let result =
            test_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessageInsteadOfSecrets)
                .await;

        assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
        assert_eq!(expected, "m.login.secrets");
    }

    #[async_test]
    async fn test_qr_login_refuse_secrets() {
        let result = test_failure(TokenResponse::Ok, AliceBehaviour::RefuseSecrets).await;

        assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
        assert_eq!(reason, LoginFailureReason::DeviceNotFound);
    }

    #[async_test]
    async fn test_device_authorization_endpoint_missing() {
        let server = MatrixMockServer::new().await;
        let rendezvous_server = MockedRendezvousServer::new(server.server(), "abcdEFG12345").await;
        let (sender, receiver) = tokio::sync::oneshot::channel();

        let oauth_server = server.oauth();
        oauth_server
            .mock_server_metadata()
            .ok_without_device_authorization()
            .expect(1)
            .named("server_metadata")
            .mount()
            .await;
        oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;

        server.mock_versions().ok().named("versions").mount().await;
        server.mock_who_am_i().ok().named("whoami").mount().await;

        let client = HttpClient::new(reqwest::Client::new(), Default::default());
        let alice = SecureChannel::new(client, &rendezvous_server.homeserver_url)
            .await
            .expect("Alice should be able to create a secure channel.");

        assert_let!(QrCodeModeData::Reciprocate { server_name } = &alice.qr_code_data().mode_data);

        let bob = Client::builder()
            .server_name_or_homeserver_url(server_name)
            .request_config(RequestConfig::new().disable_retry())
            .build()
            .await
            .expect("We should be able to build the Client object from the URL in the QR code");

        let qr_code = alice.qr_code_data().clone();

        let oauth = bob.oauth();
        let registration_data = mock_client_metadata().into();
        let login_bob = oauth.login_with_qr_code(&qr_code, Some(&registration_data));
        let mut updates = login_bob.subscribe_to_progress();

        let _updates_task = spawn(async move {
            let mut sender = Some(sender);

            while let Some(update) = updates.next().await {
                match update {
                    LoginProgress::EstablishingSecureChannel { check_code } => {
                        sender
                            .take()
                            .expect("The establishing secure channel update should be received only once")
                            .send(check_code)
                            .expect("Bob should be able to send the check code to Alice");
                    }
                    LoginProgress::Done => break,
                    _ => (),
                }
            }
        });
        let _alice_task =
            spawn(async move { grant_login(alice, receiver, AliceBehaviour::HappyPath).await });
        let error = login_bob.await.unwrap_err();

        assert_matches!(
            error,
            QRCodeLoginError::OAuth(DeviceAuthorizationOAuthError::NoDeviceAuthorizationEndpoint)
        );
    }
}