use crate::{PreKeyBundle, Signature, X3DHError, X3dhVault, CSUITE};
use alloc::vec;
use arrayref::array_ref;
use ockam_core::vault::{
KeyId, PublicKey, SecretAttributes, SecretPersistence, SecretType, AES256_SECRET_LENGTH_U32,
CURVE25519_SECRET_LENGTH_U32,
};
use ockam_core::Result;
use ockam_core::{async_trait, compat::boxed::Box};
use ockam_core::{
compat::{
string::{String, ToString},
vec::Vec,
},
vault::{Secret, SecretKey},
};
use ockam_core::{CompletedKeyExchange, KeyExchanger};
#[derive(Debug)]
enum ResponderState {
HandleInitiatorKeys,
SendBundle,
Done,
}
pub struct Responder<V: X3dhVault> {
identity_key: Option<KeyId>,
signed_prekey: Option<KeyId>,
one_time_prekey: Option<KeyId>,
state: ResponderState,
vault: V,
completed_key_exchange: Option<CompletedKeyExchange>,
}
impl<V: X3dhVault> Responder<V> {
pub(crate) fn new(vault: V, identity_key: Option<KeyId>) -> Self {
Self {
identity_key,
signed_prekey: None,
one_time_prekey: None,
completed_key_exchange: None,
state: ResponderState::HandleInitiatorKeys,
vault,
}
}
async fn prologue(&mut self) -> Result<()> {
let p_atts = SecretAttributes::new(
SecretType::X25519,
SecretPersistence::Persistent,
CURVE25519_SECRET_LENGTH_U32,
);
let e_atts = SecretAttributes::new(
SecretType::X25519,
SecretPersistence::Ephemeral,
CURVE25519_SECRET_LENGTH_U32,
);
if self.identity_key.is_none() {
self.identity_key = Some(self.vault.secret_generate(p_atts).await?);
}
self.signed_prekey = Some(self.vault.secret_generate(p_atts).await?);
self.one_time_prekey = Some(self.vault.secret_generate(e_atts).await?);
Ok(())
}
}
impl<V: X3dhVault> core::fmt::Debug for Responder<V> {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(
f,
r#"X3dhResponder {{ identity_key: {:?},
signed_prekey: {:?},
one_time_prekey: {:?},
state: {:?},
vault,
completed_key_exchange: {:?} }}"#,
self.identity_key,
self.signed_prekey,
self.one_time_prekey,
self.state,
self.completed_key_exchange
)
}
}
#[async_trait]
impl<V: X3dhVault> KeyExchanger for Responder<V> {
async fn name(&self) -> Result<String> {
Ok("X3DH".to_string())
}
async fn generate_request(&mut self, _payload: &[u8]) -> Result<Vec<u8>> {
match self.state {
ResponderState::SendBundle => {
let identity_secret_key =
self.identity_key.as_ref().ok_or(X3DHError::InvalidState)?;
let signed_prekey = self.signed_prekey.as_ref().ok_or(X3DHError::InvalidState)?;
let one_time_prekey = self
.one_time_prekey
.as_ref()
.ok_or(X3DHError::InvalidState)?;
let signed_prekey_pub = self.vault.secret_public_key_get(signed_prekey).await?;
let signature = self
.vault
.sign(identity_secret_key, signed_prekey_pub.data())
.await?;
let identity_key = self
.vault
.secret_public_key_get(identity_secret_key)
.await?;
let one_time_prekey_pub = self.vault.secret_public_key_get(one_time_prekey).await?;
if signature.as_ref().len() != 64 {
return Err(X3DHError::SignatureLenMismatch.into());
}
let signature_array = array_ref![signature.as_ref(), 0, 64]; let bundle = PreKeyBundle {
identity_key,
signed_prekey: signed_prekey_pub,
signature_prekey: Signature(*signature_array),
one_time_prekey: one_time_prekey_pub,
};
self.state = ResponderState::Done;
Ok(bundle.to_bytes())
}
ResponderState::HandleInitiatorKeys | ResponderState::Done => {
Err(X3DHError::InvalidState.into())
}
}
}
async fn handle_response(&mut self, response: &[u8]) -> Result<Vec<u8>> {
match self.state {
ResponderState::HandleInitiatorKeys => {
if response.len() != 64 {
return Err(X3DHError::MessageLenMismatch.into());
}
self.prologue().await?;
let other_identity_pubkey =
PublicKey::new(array_ref![response, 0, 32].to_vec(), SecretType::X25519);
let other_ephemeral_pubkey =
PublicKey::new(array_ref![response, 32, 32].to_vec(), SecretType::X25519);
let signed_prekey = self.signed_prekey.as_ref().ok_or(X3DHError::InvalidState)?;
let one_time_prekey = self
.one_time_prekey
.as_ref()
.ok_or(X3DHError::InvalidState)?;
let local_static_secret =
self.identity_key.as_ref().ok_or(X3DHError::InvalidState)?;
let dh1 = self
.vault
.ec_diffie_hellman(signed_prekey, &other_identity_pubkey)
.await?;
let dh2 = self
.vault
.ec_diffie_hellman(local_static_secret, &other_ephemeral_pubkey)
.await?;
let dh3 = self
.vault
.ec_diffie_hellman(signed_prekey, &other_ephemeral_pubkey)
.await?;
let dh4 = self
.vault
.ec_diffie_hellman(one_time_prekey, &other_ephemeral_pubkey)
.await?;
let mut ikm_bytes = vec![0xFFu8; 32]; ikm_bytes.extend_from_slice(
self.vault.secret_export(&dh1).await?.try_as_key()?.as_ref(),
);
ikm_bytes.extend_from_slice(
self.vault.secret_export(&dh2).await?.try_as_key()?.as_ref(),
);
ikm_bytes.extend_from_slice(
self.vault.secret_export(&dh3).await?.try_as_key()?.as_ref(),
);
ikm_bytes.extend_from_slice(
self.vault.secret_export(&dh4).await?.try_as_key()?.as_ref(),
);
let ikm = self
.vault
.secret_import(
Secret::Key(SecretKey::new(ikm_bytes.clone())),
SecretAttributes::new(
SecretType::Buffer,
SecretPersistence::Ephemeral,
ikm_bytes.len() as u32,
),
)
.await?;
let salt = self
.vault
.secret_import(
Secret::Key(SecretKey::new(vec![0u8; 32])),
SecretAttributes::new(
SecretType::Buffer,
SecretPersistence::Ephemeral,
32u32,
),
)
.await?;
let atts = SecretAttributes::new(
SecretType::Aes,
SecretPersistence::Persistent,
AES256_SECRET_LENGTH_U32,
);
let mut keyrefs = self
.vault
.hkdf_sha256(&salt, CSUITE, Some(&ikm), vec![atts, atts])
.await?;
let decrypt_key = keyrefs.pop().ok_or(X3DHError::InvalidState)?;
let encrypt_key = keyrefs.pop().ok_or(X3DHError::InvalidState)?;
let mut state_hash = self.vault.sha256(CSUITE).await?.to_vec();
state_hash.append(&mut ikm_bytes);
let state_hash = self.vault.sha256(state_hash.as_slice()).await?;
self.completed_key_exchange = Some(CompletedKeyExchange::new(
state_hash,
encrypt_key,
decrypt_key,
));
self.state = ResponderState::SendBundle;
Ok(vec![])
}
ResponderState::SendBundle | ResponderState::Done => {
Err(X3DHError::InvalidState.into())
}
}
}
async fn is_complete(&self) -> Result<bool> {
Ok(matches!(self.state, ResponderState::Done))
}
async fn finalize(self) -> Result<CompletedKeyExchange> {
self.completed_key_exchange
.ok_or_else(|| X3DHError::InvalidState.into())
}
}