use crate::config::Config;
use crate::core::auth::verify_signature;
use crate::core::notification::{Notifier, UserEvent};
use crate::core::user::{OneTimePreKey, PreKeyBundle, SignedPreKey};
use crate::error::{AppError, Result};
use crate::proto::obscura::v1::PreKeyStatus;
use crate::storage::key_repo::KeyRepository;
use crate::storage::message_repo::MessageRepository;
use sqlx::{PgConnection, PgPool};
use std::sync::Arc;
use uuid::Uuid;
#[derive(Clone)]
pub struct KeyService {
pool: PgPool,
key_repo: KeyRepository,
message_repo: MessageRepository,
notifier: Arc<dyn Notifier>,
config: Config,
}
pub struct KeyUploadParams {
pub user_id: Uuid,
pub identity_key: Option<Vec<u8>>,
pub registration_id: Option<i32>,
pub signed_pre_key: SignedPreKey,
pub one_time_pre_keys: Vec<OneTimePreKey>,
}
impl KeyService {
pub fn new(
pool: PgPool,
key_repo: KeyRepository,
message_repo: MessageRepository,
notifier: Arc<dyn Notifier>,
config: Config,
) -> Self {
Self { pool, key_repo, message_repo, notifier, config }
}
pub async fn get_pre_key_bundle(&self, user_id: Uuid) -> Result<Option<PreKeyBundle>> {
self.key_repo.fetch_pre_key_bundle(user_id).await
}
pub async fn fetch_identity_key(&self, user_id: Uuid) -> Result<Option<Vec<u8>>> {
self.key_repo.fetch_identity_key(user_id).await
}
pub async fn check_pre_key_status(&self, user_id: Uuid) -> Result<Option<PreKeyStatus>> {
let count = self.key_repo.count_one_time_pre_keys(&self.pool, user_id).await?;
if count < self.config.messaging.pre_key_refill_threshold as i64 {
Ok(Some(PreKeyStatus {
one_time_pre_key_count: count as i32,
min_threshold: self.config.messaging.pre_key_refill_threshold,
}))
} else {
Ok(None)
}
}
pub async fn upload_keys(&self, params: KeyUploadParams) -> Result<()> {
let mut tx = self.pool.begin().await?;
let user_id = params.user_id;
let is_takeover = self.upload_keys_internal(&mut tx, params).await?;
tx.commit().await?;
if is_takeover {
self.notifier.notify(user_id, UserEvent::Disconnect);
}
Ok(())
}
pub async fn upload_keys_internal(&self, conn: &mut PgConnection, params: KeyUploadParams) -> Result<bool> {
let mut is_takeover = false;
let ik_bytes = if let Some(new_ik_bytes) = ¶ms.identity_key {
let existing_ik_opt = self.key_repo.fetch_identity_key_for_update(&mut *conn, params.user_id).await?;
if let Some(existing_ik) = existing_ik_opt {
if existing_ik != *new_ik_bytes {
is_takeover = true;
}
} else {
is_takeover = true;
}
new_ik_bytes.clone()
} else {
self.key_repo
.fetch_identity_key_for_update(&mut *conn, params.user_id)
.await?
.ok_or_else(|| AppError::BadRequest("Identity key missing".into()))?
};
verify_keys(&ik_bytes, ¶ms.signed_pre_key)?;
let current_count =
if is_takeover { 0 } else { self.key_repo.count_one_time_pre_keys(&mut *conn, params.user_id).await? };
let new_keys_count = params.one_time_pre_keys.len() as i64;
if current_count + new_keys_count > self.config.messaging.max_pre_keys {
return Err(AppError::BadRequest(format!(
"Too many pre-keys. Limit is {}",
self.config.messaging.max_pre_keys
)));
}
if is_takeover {
let reg_id = params
.registration_id
.ok_or_else(|| AppError::BadRequest("registrationId required for takeover".into()))?;
self.key_repo.delete_all_signed_pre_keys(&mut *conn, params.user_id).await?;
self.key_repo.delete_all_one_time_pre_keys(&mut *conn, params.user_id).await?;
self.message_repo.delete_all_for_user(&mut *conn, params.user_id).await?;
if let Some(ik) = ¶ms.identity_key {
self.key_repo.upsert_identity_key(&mut *conn, params.user_id, ik, reg_id).await?;
}
}
self.key_repo
.upsert_signed_pre_key(
&mut *conn,
params.user_id,
params.signed_pre_key.key_id,
¶ms.signed_pre_key.public_key,
¶ms.signed_pre_key.signature,
)
.await?;
let otpk_vec: Vec<(i32, Vec<u8>)> =
params.one_time_pre_keys.into_iter().map(|k| (k.key_id, k.public_key)).collect();
self.key_repo.insert_one_time_pre_keys(&mut *conn, params.user_id, &otpk_vec).await?;
Ok(is_takeover)
}
}
fn verify_keys(ik_bytes: &[u8], signed_pre_key: &SignedPreKey) -> Result<()> {
let ik_raw = if ik_bytes.len() == 33 { &ik_bytes[1..] } else { ik_bytes };
if verify_signature(ik_raw, &signed_pre_key.public_key, &signed_pre_key.signature).is_ok() {
return Ok(());
}
if signed_pre_key.public_key.len() == 33 {
let spk_pub_raw = &signed_pre_key.public_key[1..];
if verify_signature(ik_raw, spk_pub_raw, &signed_pre_key.signature).is_ok() {
return Ok(());
}
}
Err(AppError::BadRequest("Invalid signature".into()))
}
#[cfg(test)]
mod tests {
use super::*;
use ed25519_dalek::{Signer, SigningKey};
use rand::RngCore;
use rand::rngs::OsRng;
fn generate_keys() -> (SigningKey, Vec<u8>, SigningKey, Vec<u8>, Vec<u8>) {
let mut ik_bytes = [0u8; 32];
OsRng.fill_bytes(&mut ik_bytes);
let ik = SigningKey::from_bytes(&ik_bytes);
let ik_pub = ik.verifying_key().to_bytes().to_vec();
let mut spk_bytes = [0u8; 32];
OsRng.fill_bytes(&mut spk_bytes);
let spk = SigningKey::from_bytes(&spk_bytes);
let spk_pub = spk.verifying_key().to_bytes().to_vec();
let signature = ik.sign(&spk_pub).to_bytes().to_vec();
(ik, ik_pub, spk, spk_pub, signature)
}
#[test]
fn test_verify_keys_standard() {
let (_, ik_pub, _, spk_pub, signature) = generate_keys();
let spk = SignedPreKey { key_id: 1, public_key: spk_pub, signature };
assert!(verify_keys(&ik_pub, &spk).is_ok());
}
#[test]
fn test_verify_keys_strict_33() {
let mut ik_bytes = [0u8; 32];
OsRng.fill_bytes(&mut ik_bytes);
let ik = SigningKey::from_bytes(&ik_bytes);
let mut ik_pub_33 = ik.verifying_key().to_bytes().to_vec();
ik_pub_33.insert(0, 0x05);
let mut spk_bytes = [0u8; 32];
OsRng.fill_bytes(&mut spk_bytes);
let spk = SigningKey::from_bytes(&spk_bytes);
let mut spk_pub_33 = spk.verifying_key().to_bytes().to_vec();
spk_pub_33.insert(0, 0x05);
let signature = ik.sign(&spk_pub_33).to_bytes().to_vec();
let spk = SignedPreKey { key_id: 1, public_key: spk_pub_33, signature };
assert!(verify_keys(&ik_pub_33, &spk).is_ok());
}
#[test]
fn test_verify_keys_libsignal_behavior() {
let mut ik_bytes = [0u8; 32];
OsRng.fill_bytes(&mut ik_bytes);
let ik = SigningKey::from_bytes(&ik_bytes);
let mut ik_pub_33 = ik.verifying_key().to_bytes().to_vec();
ik_pub_33.insert(0, 0x05);
let mut spk_bytes = [0u8; 32];
OsRng.fill_bytes(&mut spk_bytes);
let spk = SigningKey::from_bytes(&spk_bytes);
let spk_pub_32 = spk.verifying_key().to_bytes().to_vec();
let mut spk_pub_33 = spk_pub_32.clone();
spk_pub_33.insert(0, 0x05);
let signature = ik.sign(&spk_pub_32).to_bytes().to_vec();
let spk = SignedPreKey { key_id: 1, public_key: spk_pub_33, signature };
assert!(verify_keys(&ik_pub_33, &spk).is_ok());
}
#[test]
fn test_verify_keys_mixed() {
let (_, ik_pub, _, spk_pub, signature) = generate_keys();
let mut ik_pub_33 = ik_pub.clone();
ik_pub_33.insert(0, 0x05);
let spk = SignedPreKey { key_id: 1, public_key: spk_pub, signature };
assert!(verify_keys(&ik_pub_33, &spk).is_ok());
}
}