use std::collections::HashMap;
use curve25519_dalek::edwards::CompressedEdwardsY;
use newton_aggregator::PartialDecryptionData;
use newton_core::{
crypto::{compute_aad, SecureEnvelope},
dkg::{self, KeyShare, PartialDecryption},
};
use crate::{
decrypt::DecryptedEnvelope,
error::EnclaveError,
protocol::{
EnclaveEnvelope, EnclavePartialDhRequest, EnclavePartialDhResponse, EncryptedPartialDH, ThresholdEvalInput,
},
};
/// Computes this enclave's partial decryption for every requested `enc` point and seals
/// the resulting list separately for each peer enclave.
///
/// Each entry of `req.enc_points` must be exactly 32 bytes; it is converted with
/// `dkg::montgomery_to_edwards` and combined with this enclave's secret share via
/// `dkg::compute_partial_decryption`. The full list of partials is then HPKE-sealed once per
/// entry in `req.peer_enclave_pubkeys` (see `encrypt_partials_per_peer`).
///
/// # Errors
/// - `MissingInput` if no key share is loaded or `peer_enclave_pubkeys` is empty.
/// - `InvalidRequest` if an `enc` point is not 32 bytes long.
/// - `ThresholdFailed` if point conversion or per-peer encryption fails.
pub(crate) fn partial_dh(
    key_share: Option<&KeyShare>,
    req: EnclavePartialDhRequest,
) -> Result<EnclavePartialDhResponse, EnclaveError> {
    let share = match key_share {
        Some(share) => share,
        None => return Err(EnclaveError::MissingInput("threshold key share".to_string())),
    };
    // Without any recipients the partials could never be delivered, so refuse early.
    if req.peer_enclave_pubkeys.is_empty() {
        return Err(EnclaveError::MissingInput(
            "peer_enclave_pubkeys is required for threshold partial DH".to_string(),
        ));
    }

    let mut partials = Vec::with_capacity(req.enc_points.len());
    for raw_point in req.enc_points.iter() {
        let montgomery: [u8; 32] = raw_point
            .as_slice()
            .try_into()
            .map_err(|_| EnclaveError::InvalidRequest("enc point must be 32 bytes".to_string()))?;
        let edwards = dkg::montgomery_to_edwards(&montgomery)
            .map_err(|e| EnclaveError::ThresholdFailed(e.to_string()))?;
        let partial = dkg::compute_partial_decryption(share.index, &share.secret_share, &edwards);
        partials.push(PartialDecryptionData::from(partial));
    }

    let encrypted_partials =
        encrypt_partials_per_peer(&partials, &req.peer_enclave_pubkeys, &req.task_id, share.index)?;
    Ok(EnclavePartialDhResponse { encrypted_partials })
}
/// JSON-serializes `partials` once and HPKE-encrypts that payload separately to every peer.
///
/// Each ciphertext's associated data comes from `partial_dh_aad(sender_index, task_id, recipient)`.
/// That AAD ties the blob to this sender, this task, and the intended recipient.
/// The output holds one `EncryptedPartialDH` per peer, in the same order as `peer_pubkeys`.
///
/// # Errors
/// Returns `ThresholdFailed` in any of these cases:
/// - serialization fails;
/// - a peer public key is rejected by `HpkePublicKey::from_bytes`;
/// - HPKE encryption fails.
///
/// The first failure aborts the whole operation.
fn encrypt_partials_per_peer(
    partials: &[PartialDecryptionData],
    peer_pubkeys: &[(crate::protocol::EnclaveOperatorId, [u8; 32])],
    task_id: &newton_core::TaskId,
    sender_index: u32,
) -> Result<Vec<EncryptedPartialDH>, EnclaveError> {
    use newton_core::crypto::hpke::{encrypt, HpkePublicKey};

    // Every peer receives the same plaintext; only the key and AAD differ per recipient.
    let payload = match serde_json::to_vec(partials) {
        Ok(bytes) => bytes,
        Err(e) => return Err(EnclaveError::ThresholdFailed(format!("serialize partials: {e}"))),
    };

    let mut sealed = Vec::with_capacity(peer_pubkeys.len());
    for (recipient, raw_pubkey) in peer_pubkeys {
        let recipient_key = HpkePublicKey::from_bytes(raw_pubkey)
            .map_err(|e| EnclaveError::ThresholdFailed(format!("invalid peer pubkey: {e}")))?;
        let aad = partial_dh_aad(sender_index, task_id, recipient);
        let (encapped_key, ciphertext) = encrypt(&recipient_key, &payload, &aad)
            .map_err(|e| EnclaveError::ThresholdFailed(format!("HPKE encrypt to peer failed: {e}")))?;
        sealed.push(EncryptedPartialDH {
            recipient: *recipient,
            sender_index,
            encapped_key,
            ciphertext,
        });
    }
    Ok(sealed)
}
/// Builds the HPKE associated data for one sealed partial-DH blob.
///
/// Layout: `sender_index` as 4 little-endian bytes, then the raw `task_id` bytes, then the
/// raw `recipient_id` bytes, with no separators or length prefixes.
///
/// Both `encrypt_partials_per_peer` and `decrypt_all_peer_partials` derive their AAD here,
/// so any change to this layout must be made on both sides at once.
fn partial_dh_aad(
    sender_index: u32,
    task_id: &newton_core::TaskId,
    recipient_id: &crate::protocol::EnclaveOperatorId,
) -> Vec<u8> {
    let sender = sender_index.to_le_bytes();
    [&sender[..], task_id.as_slice(), recipient_id.as_slice()].concat()
}
/// Opens every sealed partial-DH blob addressed to this enclave.
///
/// Blobs whose `recipient` differs from `own_operator_id` are skipped. Each remaining blob is:
/// 1. HPKE-decrypted with `hpke_sk`, using the AAD rebuilt from the blob's own
///    `sender_index`, the `task_id`, and `own_operator_id`;
/// 2. parsed from JSON into a list of partials.
///
/// The result holds one inner `Vec` per accepted blob, in input order.
///
/// NOTE(review): the `sender_index` used for the AAD is taken from the blob itself. The
/// indices inside the decrypted partials are not cross-checked against it here.
///
/// # Errors
/// Returns `DecryptFailed` on the first blob that fails to decrypt or deserialize.
pub(crate) fn decrypt_all_peer_partials(
    encrypted: &[EncryptedPartialDH],
    hpke_sk: &newton_core::crypto::HpkePrivateKey,
    own_operator_id: &crate::protocol::EnclaveOperatorId,
    task_id: &newton_core::TaskId,
) -> Result<Vec<Vec<PartialDecryptionData>>, EnclaveError> {
    let mut opened = Vec::new();
    for blob in encrypted {
        if blob.recipient != *own_operator_id {
            continue;
        }
        let aad = partial_dh_aad(blob.sender_index, task_id, own_operator_id);
        let plaintext =
            match newton_core::crypto::hpke::decrypt(hpke_sk, &blob.encapped_key, &blob.ciphertext, &aad) {
                Ok(bytes) => bytes,
                Err(e) => {
                    return Err(EnclaveError::DecryptFailed(format!(
                        "peer partial decrypt failed (likely enclave key rotation between phases): {e}"
                    )))
                }
            };
        let sender_partials = serde_json::from_slice::<Vec<PartialDecryptionData>>(&plaintext)
            .map_err(|e| EnclaveError::DecryptFailed(format!("peer partial deserialize failed: {e}")))?;
        opened.push(sender_partials);
    }
    Ok(opened)
}
/// Threshold-decrypts `envelopes` and wraps the JSON plaintexts as `{"inline": [...]}`.
///
/// Returns `Ok(None)` when there are no envelopes; in that case `partials` and
/// `threshold` are not inspected at all. Otherwise every envelope is decrypted via
/// `decrypt_values`, and any error it reports is propagated unchanged.
pub(crate) fn decrypt_envelopes(
    envelopes: &[SecureEnvelope],
    partials: Option<&[Vec<PartialDecryptionData>]>,
    threshold: &ThresholdEvalInput,
) -> Result<Option<serde_json::Value>, EnclaveError> {
    match envelopes {
        [] => Ok(None),
        _ => {
            let inline = decrypt_values(envelopes, partials, threshold)?;
            Ok(Some(serde_json::json!({ "inline": inline })))
        }
    }
}
/// Threshold-decrypts domain-tagged envelopes, pairing each plaintext with its `domain`.
///
/// The output has one `DecryptedEnvelope` per input, in input order.
///
/// With no envelopes this returns an empty `Vec` immediately, without requiring
/// `partials` or validating `threshold`. That matches `decrypt_envelopes`, which also
/// short-circuits on empty input, so a caller with nothing to decrypt does not need to
/// supply partials.
///
/// # Errors
/// Propagates any error from `decrypt_values`, for example missing or insufficient
/// partials, malformed envelope fields, or a failed threshold decrypt.
pub(crate) fn decrypt_domain_envelopes(
    envelopes: &[EnclaveEnvelope],
    partials: Option<&[Vec<PartialDecryptionData>]>,
    threshold: &ThresholdEvalInput,
) -> Result<Vec<DecryptedEnvelope>, EnclaveError> {
    if envelopes.is_empty() {
        return Ok(Vec::new());
    }
    // `decrypt_values` works on bare `SecureEnvelope`s, so strip the domain tags here.
    let secure_envelopes = envelopes.iter().map(|e| e.envelope.clone()).collect::<Vec<_>>();
    let values = decrypt_values(&secure_envelopes, partials, threshold)?;
    // `decrypt_values` returns exactly one value per envelope, in order, so zip re-pairs them.
    Ok(envelopes
        .iter()
        .zip(values)
        .map(|(item, value)| DecryptedEnvelope {
            domain: item.domain,
            value,
        })
        .collect())
}
/// Threshold-decrypts each envelope using its matching set of partial decryptions.
///
/// `partials[i]` holds the wire-format partials collected for `envelopes[i]`, so the two
/// slices must be the same length. Each envelope's processing steps are:
/// 1. Convert its partials with `partials_for_envelope`.
/// 2. Require at least `threshold.config.threshold` of them.
/// 3. Hex-decode the 32-byte `enc` and the ciphertext.
/// 4. Compute the AAD from `policy_client` and `chain_id`.
/// 5. Combine via `dkg::threshold_decrypt`.
/// 6. Parse the plaintext as JSON.
///
/// Results are returned in envelope order.
///
/// The public-share map, the threshold public key, and the required partial count do not
/// depend on the envelope. They are therefore derived once, before any envelope is
/// processed. As a consequence, a malformed share set or threshold configuration is
/// reported up front, even when `envelopes` is empty.
///
/// # Errors
/// - `MissingInput` if `partials` is `None`.
/// - `ThresholdFailed` in any of these cases:
///   - the length of `partials` does not match that of `envelopes`;
///   - the public shares or threshold context are malformed;
///   - an envelope's partials are malformed or too few;
///   - the combine step fails (the message is prefixed with `envelope {i}`).
/// - `DecryptFailed` if `enc`/ciphertext hex is bad, AAD computation fails, or the
///   plaintext is not JSON.
fn decrypt_values(
    envelopes: &[SecureEnvelope],
    partials: Option<&[Vec<PartialDecryptionData>]>,
    threshold: &ThresholdEvalInput,
) -> Result<Vec<serde_json::Value>, EnclaveError> {
    let partials = partials.ok_or_else(|| EnclaveError::MissingInput("threshold partials".to_string()))?;
    if partials.len() != envelopes.len() {
        return Err(EnclaveError::ThresholdFailed(format!(
            "partials length {} does not match envelopes length {}",
            partials.len(),
            envelopes.len()
        )));
    }

    // Loop-invariant inputs: derive once instead of per envelope (the threshold-context
    // derivation re-parses every share).
    let share_points = public_shares(&threshold.public_shares)?;
    let pk_r = threshold_public_key(&threshold.public_shares, threshold.config)?;
    // Compare in `usize`. Casting `len()` down to `u32` could truncate and falsely pass the check.
    let required = usize::try_from(threshold.config.threshold).unwrap_or(usize::MAX);

    envelopes
        .iter()
        .zip(partials)
        .enumerate()
        .map(|(i, (envelope, envelope_partials))| {
            let parsed = partials_for_envelope(envelope_partials)?;
            if parsed.len() < required {
                return Err(EnclaveError::ThresholdFailed(format!(
                    "insufficient partials: got {}, need {}",
                    parsed.len(),
                    threshold.config.threshold
                )));
            }
            let enc: [u8; 32] = hex::decode(&envelope.enc)
                .map_err(|e| EnclaveError::DecryptFailed(format!("invalid enc hex: {e}")))?
                .as_slice()
                .try_into()
                .map_err(|_| EnclaveError::DecryptFailed("enc must be 32 bytes".to_string()))?;
            let ciphertext = hex::decode(&envelope.ciphertext)
                .map_err(|e| EnclaveError::DecryptFailed(format!("invalid ciphertext hex: {e}")))?;
            let aad = compute_aad(&envelope.policy_client, envelope.chain_id)
                .map_err(|e| EnclaveError::DecryptFailed(e.to_string()))?;
            let plaintext = dkg::threshold_decrypt(
                &parsed,
                &enc,
                &pk_r,
                &ciphertext,
                &aad,
                &share_points,
                threshold.config.threshold,
            )
            .map_err(|e| EnclaveError::ThresholdFailed(format!("envelope {i}: {e}")))?;
            serde_json::from_slice(&plaintext)
                .map_err(|e| EnclaveError::DecryptFailed(format!("threshold plaintext is not json: {e}")))
        })
        .collect()
}
/// Converts wire-format partial decryptions into `dkg::PartialDecryption` values.
///
/// Preserves input order. Stops at the first item whose `TryFrom` conversion fails and
/// surfaces that conversion's error message as `ThresholdFailed`.
fn partials_for_envelope(wire: &[PartialDecryptionData]) -> Result<Vec<PartialDecryption>, EnclaveError> {
    let mut converted = Vec::with_capacity(wire.len());
    for item in wire {
        match PartialDecryption::try_from(item) {
            Ok(partial) => converted.push(partial),
            Err(msg) => return Err(EnclaveError::ThresholdFailed(msg)),
        }
    }
    Ok(converted)
}
/// Decodes `(index, compressed Edwards-Y bytes)` pairs into a map from index to curve point.
///
/// Each encoding must be exactly 32 bytes and must decompress to a valid point. If the
/// same index appears more than once, the later entry replaces the earlier one.
///
/// # Errors
/// Returns `ThresholdFailed` for the first share that has the wrong length or fails to
/// decompress.
fn public_shares(
    shares: &[(u32, Vec<u8>)],
) -> Result<HashMap<u32, curve25519_dalek::edwards::EdwardsPoint>, EnclaveError> {
    let mut points = HashMap::with_capacity(shares.len());
    for (index, encoded) in shares {
        let compressed: [u8; 32] = encoded
            .as_slice()
            .try_into()
            .map_err(|_| EnclaveError::ThresholdFailed("public share must be 32 bytes".to_string()))?;
        let point = match CompressedEdwardsY(compressed).decompress() {
            Some(point) => point,
            None => return Err(EnclaveError::ThresholdFailed("invalid public share".to_string())),
        };
        points.insert(*index, point);
    }
    Ok(points)
}
/// Derives the group's HPKE public key from the published public shares.
///
/// Every share must be exactly 32 bytes. The `(index, bytes)` pairs are handed to
/// `dkg::threshold_context_from_public_shares` together with the `threshold`/`total` values
/// from `config`. The function returns the resulting context's `public_key.hpke_public_key`.
///
/// # Errors
/// Returns `ThresholdFailed` if a share has the wrong length or if the DKG context
/// derivation rejects the inputs.
fn threshold_public_key(
    shares: &[(u32, Vec<u8>)],
    config: crate::protocol::EnclaveThresholdConfig,
) -> Result<[u8; 32], EnclaveError> {
    let mut fixed_shares = Vec::with_capacity(shares.len());
    for (index, encoded) in shares {
        let bytes: [u8; 32] = encoded
            .as_slice()
            .try_into()
            .map_err(|_| EnclaveError::ThresholdFailed("public share must be 32 bytes".to_string()))?;
        fixed_shares.push((*index, bytes));
    }

    let dkg_config = dkg::ThresholdConfig {
        threshold: config.threshold,
        total: config.total,
    };
    let context = match dkg::threshold_context_from_public_shares(&fixed_shares, dkg_config) {
        Ok(context) => context,
        Err(e) => return Err(EnclaveError::ThresholdFailed(e.to_string())),
    };
    Ok(context.public_key.hpke_public_key)
}