use futures_core::{Future, Stream};
use futures_util::StreamExt as _;
use ruma::{
api::client::keys::get_keys,
events::{
GlobalAccountDataEventType,
secret::{request::SecretName, send::ToDeviceSecretSendEvent},
secret_storage::{default_key::SecretStorageDefaultKeyEvent, secret::SecretEventContent},
},
serde::Raw,
};
use serde_json::{json, value::to_raw_value};
use tracing::{error, info, instrument, warn};
#[cfg(doc)]
use crate::encryption::{
backups::Backups,
secret_storage::{SecretStorage, SecretStore},
};
use crate::{Client, client::WeakClient, encryption::backups::BackupState};
pub mod futures;
mod types;
pub use self::types::{EnableProgress, RecoveryError, RecoveryState, Result};
use self::{
futures::{Enable, RecoverAndReset, Reset},
types::{BackupDisabledContent, SecretStorageDisabledContent},
};
use crate::encryption::{AuthData, CrossSigningResetAuthType, CrossSigningResetHandle};
/// High-level recovery API for the current user's account.
///
/// It combines secret storage (see [`SecretStorage`] and [`SecretStore`]) with server-side
/// key backups (see [`Backups`]). It also maintains the cached [`RecoveryState`] that is
/// exposed through [`Recovery::state`] and [`Recovery::state_stream`].
#[derive(Debug)]
pub struct Recovery {
    /// The client this recovery handle operates on.
    pub(super) client: Client,
}
impl Recovery {
pub const KNOWN_SECRETS: &[SecretName] = &[
SecretName::CrossSigningMasterKey,
SecretName::CrossSigningUserSigningKey,
SecretName::CrossSigningSelfSigningKey,
SecretName::RecoveryKey,
];
pub fn state(&self) -> RecoveryState {
self.client.inner.e2ee.recovery_state.get()
}
pub fn state_stream(&self) -> impl Stream<Item = RecoveryState> + use<> {
self.client.inner.e2ee.recovery_state.subscribe_reset()
}
#[instrument(skip_all)]
pub fn enable(&self) -> Enable<'_> {
Enable::new(self)
}
#[instrument(skip_all)]
pub async fn enable_backup(&self) -> Result<()> {
if !self.client.encryption().backups().fetch_exists_on_server().await? {
self.mark_backup_as_enabled().await?;
self.client.encryption().backups().create().await?;
self.client.encryption().backups().maybe_trigger_backup();
Ok(())
} else {
Err(RecoveryError::BackupExistsOnServer)
}
}
#[instrument(skip_all)]
pub async fn disable(&self) -> Result<()> {
self.client.encryption().backups().disable().await?;
if let Ok(Some(default_event)) =
self.client.encryption().secret_storage().fetch_default_key_id().await
&& let Ok(default_event) = default_event.deserialize()
{
let key_id = default_event.key_id;
let event_type = GlobalAccountDataEventType::SecretStorageKey(key_id);
self.client
.account()
.set_account_data_raw(event_type, Raw::new(&json!({})).expect("").cast_unchecked())
.await?;
}
self.client.account().set_account_data(SecretStorageDisabledContent {}).await?;
self.client.account().set_account_data(BackupDisabledContent { disabled: true }).await?;
self.delete_all_known_secrets().await?;
self.update_recovery_state().await?;
Ok(())
}
#[instrument(skip_all)]
pub fn reset_key(&self) -> Reset<'_> {
Reset::new(self)
}
#[instrument(skip_all)]
pub fn recover_and_reset<'a>(&'a self, old_key: &'a str) -> RecoverAndReset<'a> {
RecoverAndReset::new(self, old_key)
}
pub async fn reset_identity(&self) -> Result<Option<IdentityResetHandle>> {
self.client.encryption().backups().disable_and_delete().await?;
self.client.account().set_account_data(SecretStorageDisabledContent {}).await?;
self.client.encryption().recovery().update_recovery_state().await?;
let cross_signing_reset_handle = self.client.encryption().reset_cross_signing().await?;
if let Some(handle) = cross_signing_reset_handle {
Ok(Some(IdentityResetHandle {
client: self.client.clone(),
cross_signing_reset_handle: handle,
}))
} else {
if self.client.encryption().recovery().should_auto_enable_backups().await? {
self.client.encryption().recovery().enable_backup().await?; }
Ok(None)
}
}
#[instrument(skip_all)]
pub async fn recover(&self, recovery_key: &str) -> Result<()> {
let store =
self.client.encryption().secret_storage().open_secret_store(recovery_key).await?;
store.import_secrets().await?;
self.update_recovery_state().await?;
Ok(())
}
pub async fn is_last_device(&self) -> Result<bool> {
let olm_machine = self.client.olm_machine().await;
let olm_machine = olm_machine.as_ref().ok_or(crate::Error::NoOlmMachine)?;
let user_id = olm_machine.user_id();
self.client.encryption().ensure_initial_key_query().await?;
let devices = self.client.encryption().get_user_devices(user_id).await?;
Ok(devices.devices().count() == 1)
}
async fn all_known_secrets_available(&self) -> Result<bool> {
let cross_signing_complete = self
.client
.encryption()
.cross_signing_status()
.await
.map(|status| status.is_complete());
if !cross_signing_complete.unwrap_or_default() {
return Ok(false);
}
if self.client.encryption().backups().are_enabled().await {
Ok(true)
} else {
self.are_backups_marked_as_disabled().await
}
}
async fn should_auto_enable_backups(&self) -> Result<bool> {
Ok(self.client.inner.e2ee.encryption_settings.auto_enable_backups
&& !self.client.encryption().backups().are_enabled().await
&& !self.client.encryption().backups().fetch_exists_on_server().await?
&& !self.are_backups_marked_as_disabled().await?)
}
pub(crate) async fn setup(&self) -> Result<()> {
info!("Setting up account data listeners and trying to setup recovery");
self.client.add_event_handler(Self::default_key_event_handler);
self.client.add_event_handler(Self::secret_send_event_handler);
self.client.inner.e2ee.initialize_recovery_state_update_task(&self.client);
self.update_recovery_state().await?;
if self.should_auto_enable_backups().await? {
info!("Trying to automatically enable backups");
if let Err(e) = self.enable_backup().await {
warn!("Could not automatically enable backups: {e:?}");
}
}
Ok(())
}
async fn delete_all_known_secrets(&self) -> Result<()> {
for secret_name in Self::KNOWN_SECRETS {
let event_type = GlobalAccountDataEventType::from(secret_name.to_owned());
let content = SecretEventContent::new(Default::default());
let secret_content = Raw::from_json(
to_raw_value(&content)
.expect("We should be able to serialize a raw empty secret event content"),
);
self.client.account().set_account_data_raw(event_type, secret_content).await?;
}
Ok(())
}
async fn are_backups_marked_as_disabled(&self) -> Result<bool> {
Ok(self
.client
.account()
.fetch_account_data_static::<BackupDisabledContent>()
.await?
.map(|event| event.deserialize().map(|event| event.disabled).unwrap_or(false))
.unwrap_or(false))
}
async fn mark_backup_as_enabled(&self) -> Result<()> {
self.client.account().set_account_data(BackupDisabledContent { disabled: false }).await?;
Ok(())
}
async fn check_recovery_state(&self) -> Result<RecoveryState> {
Ok(if self.client.encryption().secret_storage().is_enabled().await? {
if self.all_known_secrets_available().await? {
RecoveryState::Enabled
} else {
RecoveryState::Incomplete
}
} else {
RecoveryState::Disabled
})
}
async fn update_recovery_state(&self) -> Result<()> {
let new_state = self.check_recovery_state().await?;
let old_state = self.client.inner.e2ee.recovery_state.set(new_state);
if new_state != old_state {
info!("Recovery state changed from {old_state:?} to {new_state:?}");
}
Ok(())
}
async fn update_recovery_state_no_fail(&self) {
if let Err(e) = self.update_recovery_state().await {
error!("Couldn't update the recovery state: {e:?}");
}
}
#[instrument]
async fn secret_send_event_handler(_: ToDeviceSecretSendEvent, client: Client) {
client.encryption().recovery().update_recovery_state_no_fail().await;
}
#[instrument]
async fn default_key_event_handler(_: SecretStorageDefaultKeyEvent, client: Client) {
client.encryption().recovery().update_recovery_state_no_fail().await;
}
pub(crate) fn update_state_after_backup_state_change(
client: &Client,
) -> impl Future<Output = ()> + use<> {
let mut stream = client.encryption().backups().state_stream();
let weak = WeakClient::from_client(client);
async move {
while let Some(update) = stream.next().await {
if let Some(client) = weak.get() {
match update {
Ok(update) => {
if matches!(update, BackupState::Unknown | BackupState::Enabled) {
client
.encryption()
.recovery()
.update_recovery_state_no_fail()
.await;
}
}
Err(_) => {
client.encryption().recovery().update_recovery_state_no_fail().await;
}
}
} else {
break;
}
}
}
}
#[instrument(skip_all)]
pub(crate) async fn update_state_after_keys_query(&self, response: &get_keys::v3::Response) {
if let Some(user_id) = self.client.user_id()
&& response.master_keys.contains_key(user_id)
{
self.update_recovery_state_no_fail().await;
}
}
}
/// Handle for an identity reset started by [`Recovery::reset_identity`] that still needs
/// the cross-signing reset to be completed.
#[derive(Debug)]
pub struct IdentityResetHandle {
    /// Client used to re-enable backups once the reset completes.
    client: Client,
    /// The pending cross-signing reset.
    cross_signing_reset_handle: CrossSigningResetHandle,
}
impl IdentityResetHandle {
    /// The kind of authentication the pending cross-signing reset asks for.
    pub fn auth_type(&self) -> &CrossSigningResetAuthType {
        let pending = &self.cross_signing_reset_handle;
        &pending.auth_type
    }

    /// Complete the cross-signing reset, optionally with authentication data.
    ///
    /// Once the reset succeeds, backups are re-enabled if automatic backup enabling applies.
    ///
    /// # Errors
    ///
    /// Propagates errors from the cross-signing reset and from enabling backups.
    pub async fn reset(&self, auth: Option<AuthData>) -> Result<()> {
        self.cross_signing_reset_handle.auth(auth).await?;

        let should_enable =
            self.client.encryption().recovery().should_auto_enable_backups().await?;

        if should_enable {
            self.client.encryption().recovery().enable_backup().await?;
        }

        Ok(())
    }

    /// Cancel the pending cross-signing reset.
    pub async fn cancel(&self) {
        let pending = &self.cross_signing_reset_handle;
        pending.cancel().await;
    }
}
#[cfg(all(test, not(target_family = "wasm")))]
pub(crate) mod tests {
    use assert_matches::assert_matches;
    use matrix_sdk_test::async_test;
    use ruma::{
        events::{secret::request::SecretName, secret_storage::key},
        serde::Base64,
    };
    use serde_json::json;

    use super::Recovery;
    use crate::{
        encryption::{recovery::types::RecoveryError, secret_storage::SecretStorageError},
        test_utils::mocks::MatrixMockServer,
    };

    /// Recovery key used by every test in this module.
    const RECOVERY_KEY: &str = "EsTj 3yST y93F SLpB jJsz eAXc 2XzA ygD3 w69H fGaN TKBj jXEd";

    /// Start a mock server and a client. Mount mocks that expose an AES-HMAC-SHA2
    /// secret-storage key with ID `abc` and mark it as the user's default key.
    ///
    /// The server is returned so the caller keeps it alive for the duration of the test.
    async fn setup_default_secret_storage_key() -> (MatrixMockServer, Recovery) {
        let server = MatrixMockServer::new().await;
        let client = server.client_builder().build().await;
        let user_id = client.user_id().unwrap();

        let algorithm = key::SecretStorageEncryptionAlgorithm::V1AesHmacSha2(
            key::SecretStorageV1AesHmacSha2Properties::new(
                Some(Base64::parse("xv5b6/p3ExEw++wTyfSHEg==").unwrap()),
                Some(Base64::parse("ujBBbXahnTAMkmPUX2/0+VTfUh63pGyVRuBcDMgmJC8=").unwrap()),
            ),
        );
        let key_content = key::SecretStorageKeyEventContent::new("abc".into(), algorithm);

        server.mock_get_secret_storage_key().ok(user_id, &key_content).mount().await;
        server.mock_get_default_secret_storage_key().ok(user_id, "abc").mount().await;

        (server, Recovery { client })
    }

    /// Run recovery with [`RECOVERY_KEY`] and assert that it fails while importing the
    /// cross-signing master key.
    async fn assert_master_key_import_fails(recovery: &Recovery) {
        let outcome = recovery.recover(RECOVERY_KEY).await;

        assert_matches!(
            outcome,
            Err(RecoveryError::SecretStorage(SecretStorageError::ImportError {
                name: SecretName::CrossSigningMasterKey,
                error: _
            }))
        );
    }

    #[async_test]
    async fn test_recover_with_no_cross_signing_key() {
        let (_server, recovery) = setup_default_secret_storage_key().await;

        assert_master_key_import_fails(&recovery).await;
    }

    #[async_test]
    async fn test_recover_with_invalid_cross_signing_key() {
        let (server, recovery) = setup_default_secret_storage_key().await;

        server
            .mock_get_master_signing_key()
            .ok(recovery.client.user_id().unwrap(), json!({}))
            .mount()
            .await;

        assert_master_key_import_fails(&recovery).await;
    }

    #[async_test]
    async fn test_recover_with_undecryptable_cross_signing_key() {
        let (server, recovery) = setup_default_secret_storage_key().await;

        let undecryptable_secret = json!({
            "encrypted": {
                "abc": {
                    "iv": "xv5b6/p3ExEw++wTyfSHEg==",
                    "mac": "ujBBbXahnTAMkmPUX2/0+VTfUh63pGyVRuBcDMgmJC8=",
                    "ciphertext": "abcd"
                }
            }
        });

        server
            .mock_get_master_signing_key()
            .ok(recovery.client.user_id().unwrap(), undecryptable_secret)
            .mount()
            .await;

        assert_master_key_import_fails(&recovery).await;
    }
}