1use std::future::IntoFuture;
16
17use eyeball::SharedObservable;
18use futures_core::Stream;
19use matrix_sdk_base::{
20 SessionMeta, boxed_into_future,
21 crypto::types::qr_login::{QrCodeData, QrCodeMode},
22 store::RoomLoadSettings,
23};
24use oauth2::{DeviceCodeErrorResponseType, StandardDeviceAuthorizationResponse};
25use ruma::{
26 OwnedDeviceId,
27 api::client::discovery::get_authorization_server_metadata::v1::AuthorizationServerMetadata,
28};
29use tracing::trace;
30use vodozemac::Curve25519PublicKey;
31#[cfg(doc)]
32use vodozemac::ecies::CheckCode;
33
34use super::{
35 DeviceAuthorizationOAuthError, QRCodeLoginError, SecureChannelError,
36 messages::{LoginFailureReason, QrAuthMessage},
37 secure_channel::{EstablishedSecureChannel, SecureChannel},
38};
39use crate::{
40 Client,
41 authentication::oauth::{
42 ClientRegistrationData, OAuth, OAuthError,
43 qrcode::{CheckCodeSender, GeneratedQrProgress, LoginProtocolType, QrProgress},
44 },
45};
46
47async fn send_unexpected_message_error(
48 channel: &mut EstablishedSecureChannel,
49) -> Result<(), SecureChannelError> {
50 channel
51 .send_json(QrAuthMessage::LoginFailure {
52 reason: LoginFailureReason::UnexpectedMessageReceived,
53 homeserver: None,
54 })
55 .await
56}
57
58async fn finish_login<Q>(
59 client: &Client,
60 mut channel: EstablishedSecureChannel,
61 registration_data: Option<&ClientRegistrationData>,
62 state: SharedObservable<LoginProgress<Q>>,
63) -> Result<(), QRCodeLoginError> {
64 let oauth = client.oauth();
65
66 trace!("Registering the client with the OAuth 2.0 authorization server.");
68 let server_metadata = register_client(&oauth, registration_data).await?;
69
70 let account = vodozemac::olm::Account::new();
73 let public_key = account.identity_keys().curve25519;
74 let device_id = public_key;
75
76 trace!("Requesting device authorization.");
79 let auth_grant_response =
80 request_device_authorization(&oauth, &server_metadata, device_id).await?;
81
82 trace!("Letting the existing device know about the device authorization grant.");
85 let message =
86 QrAuthMessage::authorization_grant_login_protocol((&auth_grant_response).into(), device_id);
87 channel.send_json(&message).await?;
88
89 match channel.receive_json().await? {
91 QrAuthMessage::LoginProtocolAccepted => (),
92 QrAuthMessage::LoginFailure { reason, homeserver } => {
93 return Err(QRCodeLoginError::LoginFailure { reason, homeserver });
94 }
95 message => {
96 send_unexpected_message_error(&mut channel).await?;
97
98 return Err(QRCodeLoginError::UnexpectedMessage {
99 expected: "m.login.protocol_accepted",
100 received: message,
101 });
102 }
103 }
104
105 let user_code = auth_grant_response.user_code();
109 state.set(LoginProgress::WaitingForToken { user_code: user_code.secret().to_owned() });
110
111 trace!("Waiting for the OAuth 2.0 authorization server to give us the access token.");
114 if let Err(e) = wait_for_tokens(&oauth, &server_metadata, &auth_grant_response).await {
115 if let Some(e) = e.as_request_token_error() {
118 match e {
119 DeviceCodeErrorResponseType::AccessDenied => {
120 channel.send_json(QrAuthMessage::LoginDeclined).await?;
121 }
122 DeviceCodeErrorResponseType::ExpiredToken => {
123 channel
124 .send_json(QrAuthMessage::LoginFailure {
125 reason: LoginFailureReason::AuthorizationExpired,
126 homeserver: None,
127 })
128 .await?;
129 }
130 _ => (),
131 }
132 }
133
134 return Err(e.into());
135 }
136
137 trace!("Discovering our own user id.");
143 let whoami_response = client.whoami().await.map_err(QRCodeLoginError::UserIdDiscovery)?;
144 client
145 .base_client()
146 .activate(
147 SessionMeta {
148 user_id: whoami_response.user_id,
149 device_id: OwnedDeviceId::from(device_id.to_base64()),
150 },
151 RoomLoadSettings::default(),
152 Some(account),
153 )
154 .await
155 .map_err(|error| QRCodeLoginError::SessionTokens(error.into()))?;
156
157 client.oauth().enable_cross_process_lock().await?;
158
159 state.set(LoginProgress::SyncingSecrets);
160
161 trace!("Telling the existing device that we successfully logged in.");
163 let message = QrAuthMessage::LoginSuccess;
164 channel.send_json(&message).await?;
165
166 trace!("Waiting for the secrets bundle.");
169 let bundle = match channel.receive_json().await? {
170 QrAuthMessage::LoginSecrets(bundle) => bundle,
171 QrAuthMessage::LoginFailure { reason, homeserver } => {
172 return Err(QRCodeLoginError::LoginFailure { reason, homeserver });
173 }
174 message => {
175 send_unexpected_message_error(&mut channel).await?;
176
177 return Err(QRCodeLoginError::UnexpectedMessage {
178 expected: "m.login.secrets",
179 received: message,
180 });
181 }
182 };
183
184 client.encryption().import_secrets_bundle(&bundle).await?;
187
188 client
191 .encryption()
192 .ensure_device_keys_upload()
193 .await
194 .map_err(QRCodeLoginError::DeviceKeyUpload)?;
195
196 client.encryption().spawn_initialization_task(None).await;
201 client.encryption().wait_for_e2ee_initialization_tasks().await;
202
203 trace!("successfully logged in and enabled E2EE.");
204
205 state.set(LoginProgress::Done);
207
208 Ok(())
210}
211
212async fn register_client(
216 oauth: &OAuth,
217 registration_data: Option<&ClientRegistrationData>,
218) -> Result<AuthorizationServerMetadata, DeviceAuthorizationOAuthError> {
219 let server_metadata = oauth.server_metadata().await.map_err(OAuthError::from)?;
220 oauth.use_registration_data(&server_metadata, registration_data).await?;
221
222 Ok(server_metadata)
223}
224
225async fn request_device_authorization(
226 oauth: &OAuth,
227 server_metadata: &AuthorizationServerMetadata,
228 device_id: Curve25519PublicKey,
229) -> Result<StandardDeviceAuthorizationResponse, DeviceAuthorizationOAuthError> {
230 let response = oauth
231 .request_device_authorization(server_metadata, Some(device_id.to_base64().into()))
232 .await?;
233 Ok(response)
234}
235
236async fn wait_for_tokens(
237 oauth: &OAuth,
238 server_metadata: &AuthorizationServerMetadata,
239 auth_response: &StandardDeviceAuthorizationResponse,
240) -> Result<(), DeviceAuthorizationOAuthError> {
241 oauth.exchange_device_code(server_metadata, auth_response).await?;
242 Ok(())
243}
244
/// The stages of a QR code login, as published to progress subscribers.
///
/// The type parameter `Q` is the payload of the
/// [`LoginProgress::EstablishingSecureChannel`] stage; this file uses
/// [`QrProgress`] for a scanned QR code and [`GeneratedQrProgress`] for a
/// generated one.
#[derive(Clone, Debug, Default)]
pub enum LoginProgress<Q> {
    /// The default state, before any progress has been reported.
    #[default]
    Starting,
    /// Progress of the secure channel setup; the payload is flow-specific.
    EstablishingSecureChannel(Q),
    /// The existing device accepted the login protocol and we are waiting for
    /// the authorization server to hand out the access token.
    WaitingForToken {
        /// The user code taken from the device authorization response.
        user_code: String,
    },
    /// The session has been activated and the secrets bundle is being
    /// requested from the existing device.
    SyncingSecrets,
    /// The login finished successfully.
    Done,
}
268
/// A pending login that uses a QR code scanned from another device.
///
/// Awaiting this value (through its [`IntoFuture`] implementation) runs the
/// login.
#[derive(Debug)]
pub struct LoginWithQrCode<'a> {
    /// The client that is being logged in.
    client: &'a Client,
    /// Optional client registration data handed to the OAuth registration step.
    registration_data: Option<&'a ClientRegistrationData>,
    /// The QR code data used to establish the secure channel.
    qr_code_data: &'a QrCodeData,
    /// Observable holding the current [`LoginProgress`].
    state: SharedObservable<LoginProgress<QrProgress>>,
}
278
impl LoginWithQrCode<'_> {
    /// Returns a stream of [`LoginProgress`] updates for this login.
    ///
    /// The stream is obtained by subscribing to the internal progress
    /// observable and does not borrow from `self` (note the `use<>` bound).
    pub fn subscribe_to_progress(&self) -> impl Stream<Item = LoginProgress<QrProgress>> + use<> {
        self.state.subscribe()
    }
}
289
290impl<'a> IntoFuture for LoginWithQrCode<'a> {
291 type Output = Result<(), QRCodeLoginError>;
292 boxed_into_future!(extra_bounds: 'a);
293
294 fn into_future(self) -> Self::IntoFuture {
295 Box::pin(async move {
296 let channel = self.establish_secure_channel().await?;
305
306 trace!("Established the secure channel.");
307
308 let check_code = channel.check_code().to_owned();
312 self.state.set(LoginProgress::EstablishingSecureChannel(QrProgress { check_code }));
313
314 finish_login(self.client, channel, self.registration_data, self.state).await
321 })
322 }
323}
324
325impl<'a> LoginWithQrCode<'a> {
326 pub(crate) fn new(
327 client: &'a Client,
328 qr_code_data: &'a QrCodeData,
329 registration_data: Option<&'a ClientRegistrationData>,
330 ) -> LoginWithQrCode<'a> {
331 LoginWithQrCode { client, registration_data, qr_code_data, state: Default::default() }
332 }
333
334 async fn establish_secure_channel(
335 &self,
336 ) -> Result<EstablishedSecureChannel, SecureChannelError> {
337 let http_client = self.client.inner.http_client.inner.clone();
338
339 let channel = EstablishedSecureChannel::from_qr_code(
340 http_client,
341 self.qr_code_data,
342 QrCodeMode::Login,
343 )
344 .await?;
345
346 Ok(channel)
347 }
348}
349
/// A pending login that generates a QR code for another device to scan.
///
/// Awaiting this value (through its [`IntoFuture`] implementation) runs the
/// login.
#[derive(Debug)]
pub struct LoginWithGeneratedQrCode<'a> {
    /// The client that is being logged in.
    client: &'a Client,
    /// Optional client registration data handed to the OAuth registration step.
    registration_data: Option<&'a ClientRegistrationData>,
    /// Observable holding the current [`LoginProgress`].
    state: SharedObservable<LoginProgress<GeneratedQrProgress>>,
}
358
impl LoginWithGeneratedQrCode<'_> {
    /// Returns a stream of [`LoginProgress`] updates for this login.
    ///
    /// The updates carry the generated QR code
    /// ([`GeneratedQrProgress::QrReady`]) and the [`CheckCodeSender`]
    /// ([`GeneratedQrProgress::QrScanned`]) needed to drive the login.
    pub fn subscribe_to_progress(
        &self,
    ) -> impl Stream<Item = LoginProgress<GeneratedQrProgress>> + use<> {
        self.state.subscribe()
    }
}
370
371impl<'a> IntoFuture for LoginWithGeneratedQrCode<'a> {
372 type Output = Result<(), QRCodeLoginError>;
373 boxed_into_future!(extra_bounds: 'a);
374
375 fn into_future(self) -> Self::IntoFuture {
376 Box::pin(async move {
377 let mut channel = self.establish_secure_channel().await?;
380
381 trace!("Established the secure channel.");
382
383 let message = channel.receive_json().await?;
387
388 let homeserver = match message {
391 QrAuthMessage::LoginProtocols { protocols, homeserver } => {
392 if !protocols.contains(&LoginProtocolType::DeviceAuthorizationGrant) {
393 channel
394 .send_json(QrAuthMessage::LoginFailure {
395 reason: LoginFailureReason::UnsupportedProtocol,
396 homeserver: None,
397 })
398 .await?;
399
400 return Err(QRCodeLoginError::LoginFailure {
401 reason: LoginFailureReason::UnsupportedProtocol,
402 homeserver: None,
403 });
404 }
405
406 homeserver
407 }
408 _ => {
409 send_unexpected_message_error(&mut channel).await?;
410
411 return Err(QRCodeLoginError::UnexpectedMessage {
412 expected: "m.login.protocols",
413 received: message,
414 });
415 }
416 };
417
418 if self.client.homeserver() != homeserver {
421 self.client
422 .switch_homeserver_and_re_resolve_well_known(homeserver)
423 .await
424 .map_err(QRCodeLoginError::ServerReset)?;
425 }
426
427 finish_login(self.client, channel, self.registration_data, self.state).await
430 })
431 }
432}
433
434impl<'a> LoginWithGeneratedQrCode<'a> {
435 pub(crate) fn new(
436 client: &'a Client,
437 registration_data: Option<&'a ClientRegistrationData>,
438 ) -> Self {
439 Self { client, registration_data, state: Default::default() }
440 }
441
442 async fn establish_secure_channel(
443 &self,
444 ) -> Result<EstablishedSecureChannel, SecureChannelError> {
445 let http_client = self.client.inner.http_client.clone();
446
447 let secure_channel = SecureChannel::login(http_client, &self.client.homeserver()).await?;
451
452 let qr_code_data = secure_channel.qr_code_data().clone();
456 trace!("Generated QR code.");
457 self.state.set(LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrReady(
458 qr_code_data,
459 )));
460
461 let channel = secure_channel.connect().await?;
465
466 trace!("Waiting for checkcode.");
471 let (tx, rx) = tokio::sync::oneshot::channel();
472 self.state.set(LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrScanned(
473 CheckCodeSender::new(tx),
474 )));
475
476 let check_code = rx.await.map_err(|_| SecureChannelError::CannotReceiveCheckCode)?;
480 trace!("Received check code.");
481 channel.confirm(check_code)
482 }
483}
484
485#[cfg(all(test, not(target_family = "wasm")))]
486mod test {
487 use std::time::Duration;
488
489 use assert_matches2::{assert_let, assert_matches};
490 use futures_util::StreamExt;
491 use matrix_sdk_base::crypto::types::{SecretsBundle, qr_login::QrCodeModeData};
492 use matrix_sdk_common::executor::spawn;
493 use matrix_sdk_test::async_test;
494 use serde_json::json;
495 use vodozemac::ecies::CheckCode;
496
497 use super::*;
498 use crate::{
499 authentication::oauth::qrcode::{
500 messages::LoginProtocolType,
501 secure_channel::{SecureChannel, test::MockedRendezvousServer},
502 },
503 config::RequestConfig,
504 http_client::HttpClient,
505 test_utils::{client::oauth::mock_client_metadata, mocks::MatrixMockServer},
506 };
507
    /// How Alice, the already logged-in device simulated by the test helpers,
    /// behaves during a login attempt.
    enum AliceBehaviour {
        /// Alice accepts the login protocol and sends the secrets bundle.
        HappyPath,
        /// Alice answers the login protocol message with a login failure
        /// carrying the `UnsupportedProtocol` reason.
        DeclinedProtocol,
        /// Alice answers the login protocol message with `LoginDeclined`
        /// instead of accepting it.
        UnexpectedMessage,
        /// Alice sends `LoginDeclined` where the secrets bundle is expected.
        UnexpectedMessageInsteadOfSecrets,
        /// Alice sends a login failure with the `DeviceNotFound` reason where
        /// the secrets bundle is expected.
        RefuseSecrets,
        /// Alice never joins; the failure helpers use a two second rendezvous
        /// expiration and do not spawn her task.
        LetSessionExpire,
    }
516
    /// Selects which response the mocked token endpoint is configured with in
    /// the failure helpers.
    enum TokenResponse {
        /// The token mock is set up with `ok()`.
        Ok,
        /// The token mock is set up with `access_denied()`.
        AccessDenied,
        /// The token mock is set up with `expired_token()`.
        ExpiredToken,
    }
523
524 fn secrets_bundle() -> SecretsBundle {
525 let json = json!({
526 "cross_signing": {
527 "master_key": "rTtSv67XGS6k/rg6/yTG/m573cyFTPFRqluFhQY+hSw",
528 "self_signing_key": "4jbPt7jh5D2iyM4U+3IDa+WthgJB87IQN1ATdkau+xk",
529 "user_signing_key": "YkFKtkjcsTxF6UAzIIG/l6Nog/G2RigCRfWj3cjNWeM",
530 },
531 });
532
533 serde_json::from_value(json).expect("We should be able to deserialize a secrets bundle")
534 }
535
536 async fn grant_login(
539 alice: SecureChannel,
540 check_code_receiver: tokio::sync::oneshot::Receiver<CheckCode>,
541 behavior: AliceBehaviour,
542 ) {
543 let alice = alice.connect().await.expect("Alice should be able to connect the channel");
544
545 let check_code =
546 check_code_receiver.await.expect("We should receive the check code from bob");
547
548 let mut alice = alice
549 .confirm(check_code.to_digit())
550 .expect("Alice should be able to confirm the secure channel");
551
552 let message = alice
553 .receive_json()
554 .await
555 .expect("Alice should be able to receive the initial message from Bob");
556
557 assert_let!(QrAuthMessage::LoginProtocol { protocol, .. } = message);
558 assert_eq!(protocol, LoginProtocolType::DeviceAuthorizationGrant);
559
560 let message = match behavior {
561 AliceBehaviour::DeclinedProtocol => QrAuthMessage::LoginFailure {
562 reason: LoginFailureReason::UnsupportedProtocol,
563 homeserver: None,
564 },
565 AliceBehaviour::UnexpectedMessage => QrAuthMessage::LoginDeclined,
566 _ => QrAuthMessage::LoginProtocolAccepted,
567 };
568
569 alice.send_json(message).await.unwrap();
570
571 let message: QrAuthMessage = alice.receive_json().await.unwrap();
572 assert_let!(QrAuthMessage::LoginSuccess = message);
573
574 let message = match behavior {
575 AliceBehaviour::UnexpectedMessageInsteadOfSecrets => QrAuthMessage::LoginDeclined,
576 AliceBehaviour::RefuseSecrets => QrAuthMessage::LoginFailure {
577 reason: LoginFailureReason::DeviceNotFound,
578 homeserver: None,
579 },
580 _ => QrAuthMessage::LoginSecrets(secrets_bundle()),
581 };
582
583 alice.send_json(message).await.unwrap();
584 }
585
586 #[async_test]
587 async fn test_qr_login() {
588 let server = MatrixMockServer::new().await;
589 let rendezvous_server =
590 MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await;
591 let (sender, receiver) = tokio::sync::oneshot::channel();
592
593 let oauth_server = server.oauth();
594 oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
595 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
596 oauth_server
597 .mock_device_authorization()
598 .ok()
599 .expect(1)
600 .named("device_authorization")
601 .mount()
602 .await;
603 oauth_server.mock_token().ok().expect(1).named("token").mount().await;
604
605 server.mock_versions().ok().expect(1..).named("versions").mount().await;
606 server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;
607 server.mock_upload_keys().ok().expect(1).named("upload_keys").mount().await;
608 server.mock_query_keys().ok().expect(1).named("query_keys").mount().await;
609
610 let client = HttpClient::new(reqwest::Client::new(), Default::default());
611 let alice = SecureChannel::reciprocate(client, &rendezvous_server.homeserver_url)
612 .await
613 .expect("Alice should be able to create a secure channel.");
614
615 assert_let!(QrCodeModeData::Reciprocate { server_name } = &alice.qr_code_data().mode_data);
616
617 let bob = Client::builder()
618 .server_name_or_homeserver_url(server_name)
619 .request_config(RequestConfig::new().disable_retry())
620 .build()
621 .await
622 .expect("We should be able to build the Client object from the URL in the QR code");
623
624 let qr_code = alice.qr_code_data().clone();
625
626 let oauth = bob.oauth();
627 let registration_data = mock_client_metadata().into();
628 let login_bob = oauth.login_with_qr_code(Some(®istration_data)).scan(&qr_code);
629 let mut updates = login_bob.subscribe_to_progress();
630
631 let updates_task = spawn(async move {
632 let mut sender = Some(sender);
633
634 while let Some(update) = updates.next().await {
635 match update {
636 LoginProgress::EstablishingSecureChannel(QrProgress { check_code }) => {
637 sender
638 .take()
639 .expect("The establishing secure channel update should be received only once")
640 .send(check_code)
641 .expect("Bob should be able to send the check code to Alice");
642 }
643 LoginProgress::Done => break,
644 _ => (),
645 }
646 }
647 });
648 let alice_task =
649 spawn(async { grant_login(alice, receiver, AliceBehaviour::HappyPath).await });
650
651 login_bob.await.expect("Bob should be able to login");
653 alice_task.await.expect("Alice should have completed it's task successfully");
654 updates_task.await.unwrap();
655
656 assert!(bob.encryption().cross_signing_status().await.unwrap().is_complete());
657 let own_identity =
658 bob.encryption().get_user_identity(bob.user_id().unwrap()).await.unwrap().unwrap();
659
660 assert!(own_identity.is_verified());
661 }
662
663 async fn grant_login_with_generated_qr(
664 alice: &Client,
665 qr_receiver: tokio::sync::oneshot::Receiver<QrCodeData>,
666 cctx_receiver: tokio::sync::oneshot::Receiver<CheckCodeSender>,
667 behavior: AliceBehaviour,
668 ) {
669 let qr_code_data = qr_receiver.await.expect("Alice should receive the QR code");
670
671 let mut channel = EstablishedSecureChannel::from_qr_code(
672 alice.inner.http_client.inner.clone(),
673 &qr_code_data,
674 QrCodeMode::Reciprocate,
675 )
676 .await
677 .expect("Alice should be able to establish the secure channel");
678
679 trace!("Established the secure channel.");
680
681 let check_code = channel.check_code().to_digit();
684
685 let check_code_sender =
686 cctx_receiver.await.expect("Alice should receive the CheckCodeSender");
687
688 check_code_sender
689 .send(check_code)
690 .await
691 .expect("Alice should be able to send the check code to Bob");
692
693 let message = QrAuthMessage::LoginProtocols {
695 protocols: vec![LoginProtocolType::DeviceAuthorizationGrant],
696 homeserver: alice.homeserver(),
697 };
698 channel
699 .send_json(message)
700 .await
701 .expect("Alice should be able to send the `m.login.protocols` message to Bob");
702
703 let message: QrAuthMessage = channel
705 .receive_json()
706 .await
707 .expect("Alice should be able to receive the `m.login.protocol` message from Bob");
708 assert_let!(QrAuthMessage::LoginProtocol { protocol, .. } = message);
709 assert_eq!(protocol, LoginProtocolType::DeviceAuthorizationGrant);
710
711 let message = match behavior {
713 AliceBehaviour::DeclinedProtocol => QrAuthMessage::LoginFailure {
714 reason: LoginFailureReason::UnsupportedProtocol,
715 homeserver: None,
716 },
717 AliceBehaviour::UnexpectedMessage => QrAuthMessage::LoginDeclined,
718 _ => QrAuthMessage::LoginProtocolAccepted,
719 };
720 channel
721 .send_json(message)
722 .await
723 .expect("Alice should be able to send the `m.login.protocol_accepted` message to Bob");
724
725 let message: QrAuthMessage = channel
726 .receive_json()
727 .await
728 .expect("Alice should be able to receive the `m.login.success` message from Bob");
729 assert_let!(QrAuthMessage::LoginSuccess = message);
730
731 let message = match behavior {
733 AliceBehaviour::UnexpectedMessageInsteadOfSecrets => QrAuthMessage::LoginDeclined,
734 AliceBehaviour::RefuseSecrets => QrAuthMessage::LoginFailure {
735 reason: LoginFailureReason::DeviceNotFound,
736 homeserver: None,
737 },
738 _ => QrAuthMessage::LoginSecrets(secrets_bundle()),
739 };
740 channel
741 .send_json(message)
742 .await
743 .expect("Alice should be able to send the `m.login.secrets` message to Bob");
744 }
745
746 #[async_test]
747 async fn test_generated_qr_login() {
748 let server = MatrixMockServer::new().await;
749 let rendezvous_server =
750 MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await;
751 let (qr_sender, qr_receiver) = tokio::sync::oneshot::channel();
752 let (cctx_sender, cctx_receiver) = tokio::sync::oneshot::channel();
753
754 let oauth_server = server.oauth();
755 oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
756 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
757 oauth_server
758 .mock_device_authorization()
759 .ok()
760 .expect(1)
761 .named("device_authorization")
762 .mount()
763 .await;
764 oauth_server.mock_token().ok().expect(1).named("token").mount().await;
765
766 server.mock_versions().ok().expect(1..).named("versions").mount().await;
767 server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;
768 server.mock_upload_keys().ok().expect(1).named("upload_keys").mount().await;
769 server.mock_query_keys().ok().expect(1).named("query_keys").mount().await;
770
771 let homeserver_url = rendezvous_server.homeserver_url.clone();
772
773 let alice = server.client_builder().logged_in_with_oauth().build().await;
776 assert!(alice.session_meta().is_some(), "Alice should be logged in");
777
778 let bob = Client::builder()
780 .server_name_or_homeserver_url(&homeserver_url)
781 .request_config(RequestConfig::new().disable_retry())
782 .build()
783 .await
784 .expect("Should be able to create a client for Bob");
785
786 let secure_channel = SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url)
787 .await
788 .expect("Bob should be able to create a secure channel");
789
790 assert_eq!(QrCodeModeData::Login, secure_channel.qr_code_data().mode_data);
791
792 let registration_data = mock_client_metadata().into();
793 let bob_oauth = bob.oauth();
794 let bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate();
795 let mut bob_updates = bob_login.subscribe_to_progress();
796
797 let updates_task = spawn(async move {
798 let mut qr_sender = Some(qr_sender);
799 let mut cctx_sender = Some(cctx_sender);
800
801 while let Some(update) = bob_updates.next().await {
802 match update {
803 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrReady(qr)) => {
804 qr_sender
805 .take()
806 .expect("The establishing secure channel update with a qr code should be received only once")
807 .send(qr)
808 .expect("Bob should be able to send the qr code code to Alice");
809 }
810 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrScanned(
811 cctx,
812 )) => {
813 cctx_sender
814 .take()
815 .expect("The establishing secure channel update with a CheckCodeSender should be received only once")
816 .send(cctx)
817 .expect("Bob should be able to send the qr code code to Alice");
818 }
819 LoginProgress::Done => break,
820 _ => (),
821 }
822 }
823 });
824
825 let alice_task = spawn(async move {
826 grant_login_with_generated_qr(
827 &alice,
828 qr_receiver,
829 cctx_receiver,
830 AliceBehaviour::HappyPath,
831 )
832 .await
833 });
834
835 bob_login.await.expect("Bob should be able to login");
837 alice_task.await.expect("Alice should have completed it's task successfully");
838 updates_task.await.unwrap();
839
840 assert!(bob.encryption().cross_signing_status().await.unwrap().is_complete());
841 let own_identity =
842 bob.encryption().get_user_identity(bob.user_id().unwrap()).await.unwrap().unwrap();
843
844 assert!(own_identity.is_verified());
845 }
846
847 #[async_test]
848 async fn test_generated_qr_login_with_homeserver_swap() {
849 let server = MatrixMockServer::new().await;
850 let rendezvous_server =
851 MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await;
852 let (qr_sender, qr_receiver) = tokio::sync::oneshot::channel();
853 let (cctx_sender, cctx_receiver) = tokio::sync::oneshot::channel();
854
855 let login_server = MatrixMockServer::new().await;
856 let oauth_server = login_server.oauth();
857 oauth_server.mock_server_metadata().ok().expect(1).named("server_metadata").mount().await;
858 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
859 oauth_server
860 .mock_device_authorization()
861 .ok()
862 .expect(1)
863 .named("device_authorization")
864 .mount()
865 .await;
866 oauth_server.mock_token().ok().expect(1).named("token").mount().await;
867
868 server.mock_versions().ok().expect(1..).named("versions").mount().await;
869
870 login_server.mock_well_known().ok().expect(1).named("well_known").mount().await;
871 login_server.mock_versions().ok().expect(1..).named("versions").mount().await;
872 login_server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;
873 login_server.mock_upload_keys().ok().expect(1).named("upload_keys").mount().await;
874 login_server.mock_query_keys().ok().expect(1).named("query_keys").mount().await;
875
876 let homeserver_url = rendezvous_server.homeserver_url.clone();
877
878 let alice = login_server.client_builder().logged_in_with_oauth().build().await;
881 assert!(alice.session_meta().is_some(), "Alice should be logged in");
882
883 let bob = Client::builder()
885 .server_name_or_homeserver_url(&homeserver_url)
886 .request_config(RequestConfig::new().disable_retry())
887 .build()
888 .await
889 .expect("Should be able to create a client for Bob");
890
891 let secure_channel = SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url)
892 .await
893 .expect("Bob should be able to create a secure channel");
894
895 assert_eq!(QrCodeModeData::Login, secure_channel.qr_code_data().mode_data);
896
897 let registration_data = mock_client_metadata().into();
898 let bob_oauth = bob.oauth();
899 let bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate();
900 let mut bob_updates = bob_login.subscribe_to_progress();
901
902 let updates_task = spawn(async move {
903 let mut qr_sender = Some(qr_sender);
904 let mut cctx_sender = Some(cctx_sender);
905
906 while let Some(update) = bob_updates.next().await {
907 match update {
908 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrReady(qr)) => {
909 qr_sender
910 .take()
911 .expect("The establishing secure channel update with a qr code should be received only once")
912 .send(qr)
913 .expect("Bob should be able to send the qr code code to Alice");
914 }
915 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrScanned(
916 cctx,
917 )) => {
918 cctx_sender
919 .take()
920 .expect("The establishing secure channel update with a CheckCodeSender should be received only once")
921 .send(cctx)
922 .expect("Bob should be able to send the qr code code to Alice");
923 }
924 LoginProgress::Done => break,
925 _ => (),
926 }
927 }
928 });
929
930 let alice_task = spawn(async move {
931 grant_login_with_generated_qr(
932 &alice,
933 qr_receiver,
934 cctx_receiver,
935 AliceBehaviour::HappyPath,
936 )
937 .await
938 });
939
940 bob_login.await.expect("Bob should be able to login");
942 alice_task.await.expect("Alice should have completed it's task successfully");
943 updates_task.await.unwrap();
944
945 assert!(bob.encryption().cross_signing_status().await.unwrap().is_complete());
946 let own_identity =
947 bob.encryption().get_user_identity(bob.user_id().unwrap()).await.unwrap().unwrap();
948
949 assert!(own_identity.is_verified());
950 }
951
952 async fn test_failure(
953 token_response: TokenResponse,
954 alice_behavior: AliceBehaviour,
955 ) -> Result<(), QRCodeLoginError> {
956 let server = MatrixMockServer::new().await;
957 let expiration = match alice_behavior {
958 AliceBehaviour::LetSessionExpire => Duration::from_secs(2),
959 _ => Duration::MAX,
960 };
961 let rendezvous_server =
962 MockedRendezvousServer::new(server.server(), "abcdEFG12345", expiration).await;
963 let (sender, receiver) = tokio::sync::oneshot::channel();
964
965 let oauth_server = server.oauth();
966 let expected_calls = match alice_behavior {
967 AliceBehaviour::LetSessionExpire => 0,
968 _ => 1,
969 };
970 oauth_server
971 .mock_server_metadata()
972 .ok()
973 .expect(expected_calls)
974 .named("server_metadata")
975 .mount()
976 .await;
977 oauth_server
978 .mock_registration()
979 .ok()
980 .expect(expected_calls)
981 .named("registration")
982 .mount()
983 .await;
984 oauth_server
985 .mock_device_authorization()
986 .ok()
987 .expect(expected_calls)
988 .named("device_authorization")
989 .mount()
990 .await;
991
992 let token_mock = oauth_server.mock_token();
993 let token_mock = match token_response {
994 TokenResponse::Ok => token_mock.ok(),
995 TokenResponse::AccessDenied => token_mock.access_denied(),
996 TokenResponse::ExpiredToken => token_mock.expired_token(),
997 };
998 token_mock.named("token").mount().await;
999
1000 server.mock_versions().ok().named("versions").mount().await;
1001 server.mock_who_am_i().ok().named("whoami").mount().await;
1002
1003 let client = HttpClient::new(reqwest::Client::new(), Default::default());
1004 let alice = SecureChannel::reciprocate(client, &rendezvous_server.homeserver_url)
1005 .await
1006 .expect("Alice should be able to create a secure channel.");
1007
1008 assert_let!(QrCodeModeData::Reciprocate { server_name } = &alice.qr_code_data().mode_data);
1009
1010 let bob = Client::builder()
1011 .server_name_or_homeserver_url(server_name)
1012 .request_config(RequestConfig::new().disable_retry())
1013 .build()
1014 .await
1015 .expect("We should be able to build the Client object from the URL in the QR code");
1016
1017 let qr_code = alice.qr_code_data().clone();
1018
1019 let oauth = bob.oauth();
1020 let registration_data = mock_client_metadata().into();
1021 let login_bob = oauth.login_with_qr_code(Some(®istration_data)).scan(&qr_code);
1022 let mut updates = login_bob.subscribe_to_progress();
1023
1024 let _updates_task = spawn(async move {
1025 let mut sender = Some(sender);
1026
1027 while let Some(update) = updates.next().await {
1028 match update {
1029 LoginProgress::EstablishingSecureChannel(QrProgress { check_code }) => {
1030 sender
1031 .take()
1032 .expect("The establishing secure channel update should be received only once")
1033 .send(check_code)
1034 .expect("Bob should be able to send the check code to Alice");
1035 }
1036 LoginProgress::Done => break,
1037 _ => (),
1038 }
1039 }
1040 });
1041
1042 if !matches!(alice_behavior, AliceBehaviour::LetSessionExpire) {
1043 let _alice_task =
1044 spawn(async move { grant_login(alice, receiver, alice_behavior).await });
1045 }
1046
1047 login_bob.await
1048 }
1049
1050 async fn test_generated_failure(
1051 token_response: TokenResponse,
1052 alice_behavior: AliceBehaviour,
1053 ) -> Result<(), QRCodeLoginError> {
1054 let server = MatrixMockServer::new().await;
1055 let expiration = match alice_behavior {
1056 AliceBehaviour::LetSessionExpire => Duration::from_secs(2),
1057 _ => Duration::MAX,
1058 };
1059 let rendezvous_server =
1060 MockedRendezvousServer::new(server.server(), "abcdEFG12345", expiration).await;
1061
1062 let (qr_sender, qr_receiver) = tokio::sync::oneshot::channel();
1063 let (cctx_sender, cctx_receiver) = tokio::sync::oneshot::channel();
1064
1065 let oauth_server = server.oauth();
1066 let expected_calls = match alice_behavior {
1067 AliceBehaviour::LetSessionExpire => 0,
1068 _ => 1,
1069 };
1070 oauth_server
1071 .mock_server_metadata()
1072 .ok()
1073 .expect(expected_calls)
1074 .named("server_metadata")
1075 .mount()
1076 .await;
1077 oauth_server
1078 .mock_registration()
1079 .ok()
1080 .expect(expected_calls)
1081 .named("registration")
1082 .mount()
1083 .await;
1084 oauth_server
1085 .mock_device_authorization()
1086 .ok()
1087 .expect(expected_calls)
1088 .named("device_authorization")
1089 .mount()
1090 .await;
1091
1092 let token_mock = oauth_server.mock_token();
1093 let token_mock = match token_response {
1094 TokenResponse::Ok => token_mock.ok(),
1095 TokenResponse::AccessDenied => token_mock.access_denied(),
1096 TokenResponse::ExpiredToken => token_mock.expired_token(),
1097 };
1098 token_mock.named("token").mount().await;
1099
1100 server.mock_versions().ok().named("versions").mount().await;
1101 server.mock_who_am_i().ok().named("whoami").mount().await;
1102
1103 let homeserver_url = rendezvous_server.homeserver_url.clone();
1104
1105 let alice = server.client_builder().logged_in_with_oauth().build().await;
1108 assert!(alice.session_meta().is_some(), "Alice should be logged in");
1109
1110 let bob = Client::builder()
1112 .server_name_or_homeserver_url(&homeserver_url)
1113 .request_config(RequestConfig::new().disable_retry())
1114 .build()
1115 .await
1116 .expect("Should be able to create a client for Bob");
1117
1118 let secure_channel = SecureChannel::login(bob.inner.http_client.clone(), &homeserver_url)
1119 .await
1120 .expect("Bob should be able to create a secure channel");
1121
1122 assert_eq!(QrCodeModeData::Login, secure_channel.qr_code_data().mode_data);
1123
1124 let registration_data = mock_client_metadata().into();
1125 let bob_oauth = bob.oauth();
1126 let bob_login = bob_oauth.login_with_qr_code(Some(®istration_data)).generate();
1127 let mut bob_updates = bob_login.subscribe_to_progress();
1128
1129 let _updates_task = spawn(async move {
1130 let mut qr_sender = Some(qr_sender);
1131 let mut cctx_sender = Some(cctx_sender);
1132
1133 while let Some(update) = bob_updates.next().await {
1134 match update {
1135 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrReady(qr)) => {
1136 qr_sender
1137 .take()
1138 .expect("The establishing secure channel update with a qr code should be received only once")
1139 .send(qr)
1140 .expect("Bob should be able to send the qr code code to Alice");
1141 }
1142 LoginProgress::EstablishingSecureChannel(GeneratedQrProgress::QrScanned(
1143 cctx,
1144 )) => {
1145 cctx_sender
1146 .take()
1147 .expect("The establishing secure channel update with a CheckCodeSender should be received only once")
1148 .send(cctx)
1149 .expect("Bob should be able to send the qr code code to Alice");
1150 }
1151 LoginProgress::Done => break,
1152 _ => (),
1153 }
1154 }
1155 });
1156
1157 if !matches!(alice_behavior, AliceBehaviour::LetSessionExpire) {
1158 let _alice_task = spawn(async move {
1159 grant_login_with_generated_qr(&alice, qr_receiver, cctx_receiver, alice_behavior)
1160 .await
1161 });
1162 }
1163
1164 bob_login.await
1165 }
1166
1167 #[async_test]
1168 async fn test_qr_login_refused_access_token() {
1169 let result = test_failure(TokenResponse::AccessDenied, AliceBehaviour::HappyPath).await;
1170
1171 assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
1172 assert_eq!(
1173 e.as_request_token_error(),
1174 Some(&DeviceCodeErrorResponseType::AccessDenied),
1175 "The server should have told us that access has been denied."
1176 );
1177 }
1178
1179 #[async_test]
1180 async fn test_generated_qr_login_refused_access_token() {
1181 let result =
1182 test_generated_failure(TokenResponse::AccessDenied, AliceBehaviour::HappyPath).await;
1183
1184 assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
1185 assert_eq!(
1186 e.as_request_token_error(),
1187 Some(&DeviceCodeErrorResponseType::AccessDenied),
1188 "The server should have told us that access has been denied."
1189 );
1190 }
1191
1192 #[async_test]
1193 async fn test_qr_login_expired_token() {
1194 let result = test_failure(TokenResponse::ExpiredToken, AliceBehaviour::HappyPath).await;
1195
1196 assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
1197 assert_eq!(
1198 e.as_request_token_error(),
1199 Some(&DeviceCodeErrorResponseType::ExpiredToken),
1200 "The server should have told us that access has been denied."
1201 );
1202 }
1203
1204 #[async_test]
1205 async fn test_generated_qr_login_expired_token() {
1206 let result =
1207 test_generated_failure(TokenResponse::ExpiredToken, AliceBehaviour::HappyPath).await;
1208
1209 assert_let!(Err(QRCodeLoginError::OAuth(e)) = result);
1210 assert_eq!(
1211 e.as_request_token_error(),
1212 Some(&DeviceCodeErrorResponseType::ExpiredToken),
1213 "The server should have told us that access has been denied."
1214 );
1215 }
1216
1217 #[async_test]
1218 async fn test_qr_login_declined_protocol() {
1219 let result = test_failure(TokenResponse::Ok, AliceBehaviour::DeclinedProtocol).await;
1220
1221 assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
1222 assert_eq!(
1223 reason,
1224 LoginFailureReason::UnsupportedProtocol,
1225 "Alice should have told us that the protocol is unsupported."
1226 );
1227 }
1228
1229 #[async_test]
1230 async fn test_generated_qr_login_declined_protocol() {
1231 let result =
1232 test_generated_failure(TokenResponse::Ok, AliceBehaviour::DeclinedProtocol).await;
1233
1234 assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
1235 assert_eq!(
1236 reason,
1237 LoginFailureReason::UnsupportedProtocol,
1238 "Alice should have told us that the protocol is unsupported."
1239 );
1240 }
1241
1242 #[async_test]
1243 async fn test_qr_login_unexpected_message() {
1244 let result = test_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessage).await;
1245
1246 assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
1247 assert_eq!(expected, "m.login.protocol_accepted");
1248 }
1249
1250 #[async_test]
1251 async fn test_generated_qr_login_unexpected_message() {
1252 let result =
1253 test_generated_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessage).await;
1254
1255 assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
1256 assert_eq!(expected, "m.login.protocol_accepted");
1257 }
1258
1259 #[async_test]
1260 async fn test_qr_login_unexpected_message_instead_of_secrets() {
1261 let result =
1262 test_failure(TokenResponse::Ok, AliceBehaviour::UnexpectedMessageInsteadOfSecrets)
1263 .await;
1264
1265 assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
1266 assert_eq!(expected, "m.login.secrets");
1267 }
1268
1269 #[async_test]
1270 async fn test_generated_qr_login_unexpected_message_instead_of_secrets() {
1271 let result = test_generated_failure(
1272 TokenResponse::Ok,
1273 AliceBehaviour::UnexpectedMessageInsteadOfSecrets,
1274 )
1275 .await;
1276
1277 assert_let!(Err(QRCodeLoginError::UnexpectedMessage { expected, .. }) = result);
1278 assert_eq!(expected, "m.login.secrets");
1279 }
1280
1281 #[async_test]
1282 async fn test_qr_login_refuse_secrets() {
1283 let result = test_failure(TokenResponse::Ok, AliceBehaviour::RefuseSecrets).await;
1284
1285 assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
1286 assert_eq!(reason, LoginFailureReason::DeviceNotFound);
1287 }
1288
1289 #[async_test]
1290 async fn test_generated_qr_login_refuse_secrets() {
1291 let result = test_generated_failure(TokenResponse::Ok, AliceBehaviour::RefuseSecrets).await;
1292
1293 assert_let!(Err(QRCodeLoginError::LoginFailure { reason, .. }) = result);
1294 assert_eq!(reason, LoginFailureReason::DeviceNotFound);
1295 }
1296
1297 #[async_test]
1298 async fn test_qr_login_session_expired() {
1299 let result = test_failure(TokenResponse::Ok, AliceBehaviour::LetSessionExpire).await;
1300
1301 assert_matches!(result, Err(QRCodeLoginError::NotFound));
1302 }
1303
1304 #[async_test]
1305 async fn test_generated_qr_login_session_expired() {
1306 let result =
1307 test_generated_failure(TokenResponse::Ok, AliceBehaviour::LetSessionExpire).await;
1308
1309 assert_matches!(result, Err(QRCodeLoginError::NotFound));
1310 }
1311
1312 #[async_test]
1313 async fn test_device_authorization_endpoint_missing() {
1314 let server = MatrixMockServer::new().await;
1315 let rendezvous_server =
1316 MockedRendezvousServer::new(server.server(), "abcdEFG12345", Duration::MAX).await;
1317 let (sender, receiver) = tokio::sync::oneshot::channel();
1318
1319 let oauth_server = server.oauth();
1320 oauth_server
1321 .mock_server_metadata()
1322 .ok_without_device_authorization()
1323 .expect(1)
1324 .named("server_metadata")
1325 .mount()
1326 .await;
1327 oauth_server.mock_registration().ok().expect(1).named("registration").mount().await;
1328
1329 server.mock_versions().ok().named("versions").mount().await;
1330 server.mock_who_am_i().ok().named("whoami").mount().await;
1331
1332 let client = HttpClient::new(reqwest::Client::new(), Default::default());
1333 let alice = SecureChannel::reciprocate(client, &rendezvous_server.homeserver_url)
1334 .await
1335 .expect("Alice should be able to create a secure channel.");
1336
1337 assert_let!(QrCodeModeData::Reciprocate { server_name } = &alice.qr_code_data().mode_data);
1338
1339 let bob = Client::builder()
1340 .server_name_or_homeserver_url(server_name)
1341 .request_config(RequestConfig::new().disable_retry())
1342 .build()
1343 .await
1344 .expect("We should be able to build the Client object from the URL in the QR code");
1345
1346 let qr_code = alice.qr_code_data().clone();
1347
1348 let oauth = bob.oauth();
1349 let registration_data = mock_client_metadata().into();
1350 let login_bob = oauth.login_with_qr_code(Some(®istration_data)).scan(&qr_code);
1351 let mut updates = login_bob.subscribe_to_progress();
1352
1353 let _updates_task = spawn(async move {
1354 let mut sender = Some(sender);
1355
1356 while let Some(update) = updates.next().await {
1357 match update {
1358 LoginProgress::EstablishingSecureChannel(QrProgress { check_code }) => {
1359 sender
1360 .take()
1361 .expect("The establishing secure channel update should be received only once")
1362 .send(check_code)
1363 .expect("Bob should be able to send the check code to Alice");
1364 }
1365 LoginProgress::Done => break,
1366 _ => (),
1367 }
1368 }
1369 });
1370 let _alice_task =
1371 spawn(async move { grant_login(alice, receiver, AliceBehaviour::HappyPath).await });
1372 let error = login_bob.await.unwrap_err();
1373
1374 assert_matches!(
1375 error,
1376 QRCodeLoginError::OAuth(DeviceAuthorizationOAuthError::NoDeviceAuthorizationEndpoint)
1377 );
1378 }
1379}