use std::path::Path;
use eyre::{OptionExt, Result, bail};
use matrix_sdk::authentication::matrix::MatrixSession;
use matrix_sdk::encryption::{
BackupDownloadStrategy, CrossSigningResetAuthType, EncryptionSettings,
};
use matrix_sdk::ruma::api::client::uiaa;
use matrix_sdk::{AuthSession, Client};
use rand::Rng;
use rusqlite::OptionalExtension;
use tracing::{info, instrument};
use crate::SyncHelper;
use crate::db::SQLiteHelper;
/// Everything [`setup`] needs to log a new device in and prepare end-to-end encryption.
///
/// Despite their names, the `*Callback` type parameters are not all closures:
/// `ask_recovery_key` and `before_create_backup` are futures that `setup` awaits (at most
/// one of the two, depending on whether a key backup already exists on the server), while
/// `print_recovery_key` is a function invoked once and whose returned future is awaited.
/// The trait bounds live on `setup`, not on this struct.
#[derive(Clone)]
pub struct SetupConfig<
    'a,
    AskRecoveryKeyCallback,
    BeforeCreateBackupCallback,
    PrintRecoveryKeyCallback,
> {
    /// Directory holding all persistent state; `setup` creates it and discards prior state in it.
    pub data_dir: &'a Path,
    /// Server name or homeserver URL of the account.
    pub homeserver: &'a str,
    /// Login user name.
    pub username: &'a str,
    /// Login password; also reused for UIAA when the cross-signing identity is reset.
    pub password: &'a str,
    /// Initial display name of the newly created device.
    pub device_name: &'a str,
    /// Future resolving to the recovery key; awaited only when a server backup exists.
    pub ask_recovery_key: AskRecoveryKeyCallback,
    /// Future awaited before a new identity and backup are created (no server backup exists).
    pub before_create_backup: BeforeCreateBackupCallback,
    /// Called with `(recovery_key, key_is_new)`; the returned future is awaited.
    pub print_recovery_key: PrintRecoveryKeyCallback,
}
/// Concurrently removes each `$file` (joined onto `$data_dir`) using `tokio::fs::remove_file`.
///
/// `$data_dir` is evaluated exactly once. Removal is best-effort: every result, including
/// "not found" and any other I/O error, is discarded, so missing files are not an error.
macro_rules! delete_data_file {
    ($data_dir:expr, $($file:expr),* $(,)?) => {{
        let dir = $data_dir;
        _ = tokio::join!($(tokio::fs::remove_file(dir.join($file))),*);
    }};
}
#[instrument(skip_all)]
pub async fn setup<
AskRecoveryKeyCallback,
BeforeCreateBackupCallback,
PrintRecoveryKeyCallback,
PrintRecoveryKeyReturn,
>(
config: SetupConfig<
'_,
AskRecoveryKeyCallback,
BeforeCreateBackupCallback,
PrintRecoveryKeyCallback,
>,
) -> Result<Client>
where
AskRecoveryKeyCallback: Future<Output = Result<String>>,
BeforeCreateBackupCallback: Future<Output = Result<()>>,
PrintRecoveryKeyCallback: FnOnce(String, bool) -> PrintRecoveryKeyReturn,
PrintRecoveryKeyReturn: Future<Output = Result<()>>,
{
tokio::fs::create_dir_all(&config.data_dir).await?;
let session_db = SQLiteHelper::open(&config.data_dir.join("matrixbot-ezlogin.sqlite3"), true)?;
session_db.execute_batch(
"BEGIN TRANSACTION;
DROP TABLE IF EXISTS matrix_session;
DROP TABLE IF EXISTS sync_token;
CREATE TABLE matrix_session (id INTEGER PRIMARY KEY CHECK (id = 0), homeserver TEXT NOT NULL, passphrase TEXT NOT NULL, session BLOB NOT NULL);
CREATE TABLE sync_token (id INTEGER PRIMARY KEY CHECK (id = 0), token TEXT NOT NULL);
COMMIT;",
)?;
delete_data_file!(
&config.data_dir,
"matrix-sdk-crypto.sqlite3",
"matrix-sdk-crypto.sqlite3-journal",
"matrix-sdk-crypto.sqlite3-shm",
"matrix-sdk-crypto.sqlite3-wal",
"matrix-sdk-event-cache.sqlite3",
"matrix-sdk-event-cache.sqlite3-journal",
"matrix-sdk-event-cache.sqlite3-shm",
"matrix-sdk-event-cache.sqlite3-wal",
"matrix-sdk-state.sqlite3",
"matrix-sdk-state.sqlite3-journal",
"matrix-sdk-state.sqlite3-shm",
"matrix-sdk-state.sqlite3-wal",
);
info!("Logging into Matrix.");
let rng = rand::rng();
let db_passphrase = rng
.sample_iter(rand::distr::Alphanumeric)
.take(32)
.map(char::from)
.collect::<String>();
let client: Client = build_client(config.data_dir, config.homeserver, &db_passphrase).await?;
client
.matrix_auth()
.login_username(config.username, config.password)
.initial_device_display_name(config.device_name)
.await?;
match save_session(config, &session_db, db_passphrase, &client).await {
Ok(_) => {
info!("Setup finished.");
Ok(client)
}
Err(err) => {
info!("Logging out of Matrix.");
client.logout().await?;
Err(err)?
}
}
}
/// Restores the session saved by [`setup`] and prepares incremental syncing.
///
/// Opens `matrixbot-ezlogin.sqlite3` in `data_dir` and rebuilds the [`Client`] from the
/// stored session. The same database connection is then handed over to [`SyncHelper`].
/// The `false` flag passed to `SQLiteHelper::open` differs from `setup`'s `true`; it presumably
/// means "do not create", so confirm this against `SQLiteHelper::open`.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the database cannot be opened;
/// - no session has been saved;
/// - the session cannot be restored;
/// - the [`SyncHelper`] cannot be created.
#[instrument(skip_all)]
pub async fn login(data_dir: &Path) -> Result<(Client, SyncHelper)> {
    let db_path = data_dir.join("matrixbot-ezlogin.sqlite3");
    let session_db = SQLiteHelper::open(&db_path, false)?;
    let client = restore_session(data_dir, &session_db).await?;
    let logged_in = (client, SyncHelper::from_opened_db(session_db)?);
    info!("Login finished.");
    Ok(logged_in)
}
/// Logs the stored session out of Matrix and deletes all local data files.
///
/// Restores the session saved by [`setup`] from `data_dir` and logs that device out on the
/// homeserver. It then removes the matrix-sdk crypto/event-cache/state stores and
/// `matrixbot-ezlogin.sqlite3`, including their SQLite journal/WAL/SHM side files.
/// File removal is best-effort; individual removal errors are ignored.
///
/// # Errors
///
/// Returns an error if the session database cannot be opened, the session cannot be
/// restored, or the logout request fails. In those cases no files are deleted.
#[instrument(skip_all)]
pub async fn logout(data_dir: &Path) -> Result<()> {
    let session_db = SQLiteHelper::open(&data_dir.join("matrixbot-ezlogin.sqlite3"), false)?;
    let client = restore_session(data_dir, &session_db).await?;
    info!("Logging out.");
    client.logout().await?;
    // Release our handles on the files before removing them. `client` holds the matrix-sdk
    // stores and `session_db` holds `matrixbot-ezlogin.sqlite3`. Some platforms (e.g. Windows)
    // refuse to delete open files, and that failure would be silently ignored below.
    drop(client);
    drop(session_db);
    info!("Deleting the data files");
    delete_data_file!(
        data_dir,
        "matrix-sdk-crypto.sqlite3",
        "matrix-sdk-crypto.sqlite3-journal",
        "matrix-sdk-crypto.sqlite3-shm",
        "matrix-sdk-crypto.sqlite3-wal",
        "matrix-sdk-event-cache.sqlite3",
        "matrix-sdk-event-cache.sqlite3-journal",
        "matrix-sdk-event-cache.sqlite3-shm",
        "matrix-sdk-event-cache.sqlite3-wal",
        "matrix-sdk-state.sqlite3",
        "matrix-sdk-state.sqlite3-journal",
        "matrix-sdk-state.sqlite3-shm",
        "matrix-sdk-state.sqlite3-wal",
        "matrixbot-ezlogin.sqlite3",
        "matrixbot-ezlogin.sqlite3-journal",
        "matrixbot-ezlogin.sqlite3-shm",
        "matrixbot-ezlogin.sqlite3-wal",
    );
    info!("Logout finished.");
    Ok(())
}
/// Builds a [`Client`] backed by the matrix-sdk SQLite stores in `data_dir`.
///
/// * `homeserver`: a server name or homeserver URL.
/// * `passphrase`: the passphrase handed to the matrix-sdk SQLite stores.
///
/// The client has share-history-on-invite enabled. Its encryption settings auto-enable
/// cross-signing and key backups and use the `AfterDecryptionFailure` backup download strategy.
///
/// An `https_proxy` environment variable (matched in any ASCII case) with a non-empty value
/// is used as the client's proxy. An empty value is treated as unset.
///
/// # Errors
///
/// Returns an error if the matrix-sdk client builder fails.
async fn build_client(data_dir: &Path, homeserver: &str, passphrase: &str) -> Result<Client> {
    let mut client_builder = Client::builder()
        .server_name_or_homeserver_url(homeserver)
        .sqlite_store(data_dir, Some(passphrase))
        .with_enable_share_history_on_invite(true)
        .with_encryption_settings(EncryptionSettings {
            auto_enable_cross_signing: true,
            backup_download_strategy: BackupDownloadStrategy::AfterDecryptionFailure,
            auto_enable_backups: true,
        });
    // An empty value (e.g. `https_proxy=`) conventionally means "no proxy". Passing it through
    // would hand an unparsable proxy URL to the builder and likely make `build()` fail.
    let proxy = std::env::vars_os()
        .find(|(key, value)| key.eq_ignore_ascii_case("https_proxy") && !value.is_empty())
        .map(|(_, value)| value);
    if let Some(proxy) = proxy {
        client_builder = client_builder.proxy(proxy.to_string_lossy());
    }
    Ok(client_builder.build().await?)
}
/// Persists the freshly logged-in session and sets up end-to-end encryption.
///
/// First it stores three values as row `id = 0` of `matrix_session`: the client's homeserver
/// URL, the matrix-sdk store passphrase (`db_passphrase`) and the JSON-serialized
/// [`MatrixSession`]. What happens next depends on whether a key backup already exists on the
/// server:
///
/// * **Backup exists:** awaits `config.ask_recovery_key` and recovers with the returned key.
/// * **No backup:** awaits `config.before_create_backup` and resets the cross-signing identity.
///   The reset authenticates with `config.password` via UIAA, or prints an OAuth approval URL
///   to stderr. It then enables recovery and waits for the initial backup upload, which yields
///   a new recovery key.
///
/// Finally it calls `config.print_recovery_key(recovery_key, !has_backup)` and awaits the
/// returned future. The second argument is `true` when the key was just created.
///
/// The session row is written before any encryption step, so it is still in the table if a
/// later step fails.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the client has no session, or has a non-Matrix session;
/// - serialization or the insert fails;
/// - any of the supplied futures fails;
/// - a backup, recovery or identity-reset request fails.
async fn save_session<
    AskRecoveryKeyCallback,
    BeforeCreateBackupCallback,
    PrintRecoveryKeyCallback,
    PrintRecoveryKeyReturn,
>(
    config: SetupConfig<
        '_,
        AskRecoveryKeyCallback,
        BeforeCreateBackupCallback,
        PrintRecoveryKeyCallback,
    >,
    session_db: &rusqlite::Connection,
    db_passphrase: String,
    client: &Client,
) -> Result<()>
where
    AskRecoveryKeyCallback: Future<Output = Result<String>>,
    BeforeCreateBackupCallback: Future<Output = Result<()>>,
    PrintRecoveryKeyCallback: FnOnce(String, bool) -> PrintRecoveryKeyReturn,
    PrintRecoveryKeyReturn: Future<Output = Result<()>>,
{
    info!("Saving the Matrix session.");
    let session = client
        .session()
        .ok_or_eyre("Matrix SDK did not return a session")?;
    // Only password-based Matrix sessions are persisted; other session kinds are rejected.
    let AuthSession::Matrix(matrix_session) = session else {
        bail!("Matrix SDK returned an unsupported session type");
    };
    let session_json = serde_json::to_string(&matrix_session)?;
    // Stored as SQLite JSONB; `restore_session` reads it back as text with `json(session)`.
    // NOTE(review): `jsonb()` needs SQLite 3.45 or newer. Confirm the linked SQLite version.
    session_db.execute(
        "INSERT INTO matrix_session (id, homeserver, passphrase, session) VALUES (0, ?, ?, jsonb(?));",
        (client.homeserver().as_str(), db_passphrase, &session_json),
    )?;
    info!("Setting up encryption.");
    let encryption = client.encryption();
    // Decides below whether to recover from the existing backup or create a new one.
    let has_backup = encryption.backups().fetch_exists_on_server().await?;
    let recovery = encryption.recovery();
    // Let the SDK's background E2EE setup finish before touching recovery. `build_client`
    // auto-enables cross-signing and backups, so that setup runs in the background.
    encryption.wait_for_e2ee_initialization_tasks().await;
    let recovery_key = if has_backup {
        info!("A backup exists on the server, recovering from it.");
        let recovery_key = config.ask_recovery_key.await?;
        recovery.recover(&recovery_key).await?;
        encryption.wait_for_e2ee_initialization_tasks().await;
        info!("Recovered from the server backup.");
        recovery_key
    } else {
        info!("No backup exists on the server, creating a new one.");
        config.before_create_backup.await?;
        info!("Resetting cryptography identity.");
        // A handle is returned only when the reset needs further authentication.
        // NOTE(review): `None` presumably means the reset completed without extra auth. Confirm.
        if let Some(reset_handle) = recovery.reset_identity().await? {
            match reset_handle.auth_type() {
                CrossSigningResetAuthType::Uiaa(uiaa) => {
                    info!("Resetting cryptography identity. (Stage 2: UIAA)");
                    // Re-authenticate with the login password, tied to the server's UIAA session.
                    let mut auth_data = uiaa::Password::new(
                        client
                            .user_id()
                            .ok_or_eyre("failed to get user ID")?
                            .to_owned()
                            .into(),
                        config.password.to_owned(),
                    );
                    auth_data.session = uiaa.session.clone();
                    reset_handle
                        .reset(Some(uiaa::AuthData::Password(auth_data)))
                        .await?;
                }
                CrossSigningResetAuthType::OAuth(oauth) => {
                    // The user must approve the reset out-of-band in a browser.
                    // NOTE(review): `reset(None)` presumably waits until approval is given. Verify.
                    eprintln!(
                        "To reset your end-to-end encryption cross-signing identity, you first need to approve it at: {}",
                        oauth.approval_url
                    );
                    reset_handle.reset(None).await?;
                }
            }
        }
        encryption.wait_for_e2ee_initialization_tasks().await;
        info!("Creating a server backup.");
        // Enabling recovery creates the backup and returns the new recovery key once uploaded.
        let recovery_key = recovery.enable().wait_for_backups_to_upload().await?;
        info!("Finished initial backup.");
        recovery_key
    };
    info!("Saving the recovery key.");
    // Second argument: `true` when the recovery key was newly created by this setup.
    (config.print_recovery_key)(recovery_key, !has_backup).await?;
    Ok(())
}
/// Rebuilds a logged-in [`Client`] from the session row written by `setup`.
///
/// Reads the homeserver, matrix-sdk store passphrase and session JSON from row `id = 0` of
/// `matrix_session`. It then builds a client over the stores in `data_dir` and restores the
/// [`MatrixSession`] into it.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - no session row exists ("no session found, run setup first");
/// - the query or JSON decoding fails;
/// - the client cannot be built;
/// - the session cannot be restored.
async fn restore_session(data_dir: &Path, session_db: &rusqlite::Connection) -> Result<Client> {
    let stored: Option<(String, String, String)> = session_db
        .query_row(
            "SELECT homeserver, passphrase, json(session) FROM matrix_session WHERE id = 0;",
            (),
            |row| row.try_into(),
        )
        .optional()?;
    let (homeserver, passphrase, session_json) =
        stored.ok_or_eyre("no session found, run setup first")?;
    let matrix_session: MatrixSession = serde_json::from_str(&session_json)?;
    info!("Logging into Matrix.");
    let client = build_client(data_dir, &homeserver, &passphrase).await?;
    let auth_session = AuthSession::Matrix(matrix_session);
    client.restore_session(auth_session).await?;
    Ok(client)
}