newton-enclave 0.4.15

newton prover enclave compute
//! threshold HPKE decryption inside the enclave.
//!
//! this module implements the enclave side of newton's threshold HPKE decryption
//! (see `docs/THRESHOLD_DKG.md` for the full protocol). two operations:
//!
//! 1. `partial_dh` — computes partial DH outputs during the Prepare phase using the
//!    enclave-resident key share. called once per task; outputs are aggregated off-chain
//!    and returned in the Commit request.
//!
//! 2. `decrypt_envelopes` / `decrypt_domain_envelopes` — reconstructs the threshold-encrypted
//!    plaintext during Commit using collected partials and the public shares. the enclave
//!    never exposes the key share outside this module.
//!
//! this is the most cryptographically sensitive code in the enclave. any change here
//! should be reviewed against the DKG spec and accompanied by test coverage.

use std::collections::HashMap;

use curve25519_dalek::edwards::CompressedEdwardsY;
use newton_aggregator::PartialDecryptionData;
use newton_core::{
    crypto::{compute_aad, SecureEnvelope},
    dkg::{self, KeyShare, PartialDecryption},
};

use crate::{
    decrypt::DecryptedEnvelope,
    error::EnclaveError,
    protocol::{
        EnclaveEnvelope, EnclavePartialDhRequest, EnclavePartialDhResponse, EncryptedPartialDH, ThresholdEvalInput,
    },
};

/// Computes this enclave's partial DH output for every encapsulated point in `req` and
/// encrypts the resulting vector to each peer listed in `req.peer_enclave_pubkeys`.
///
/// # Errors
///
/// - [`EnclaveError::MissingInput`] if `key_share` is `None` or no peer pubkeys were supplied.
/// - [`EnclaveError::InvalidRequest`] if an enc point is not exactly 32 bytes long.
/// - [`EnclaveError::ThresholdFailed`] if `dkg::montgomery_to_edwards` rejects a point, or if
///   encrypting the partials to a peer fails.
pub(crate) fn partial_dh(
    key_share: Option<&KeyShare>,
    req: EnclavePartialDhRequest,
) -> Result<EnclavePartialDhResponse, EnclaveError> {
    let share = match key_share {
        Some(share) => share,
        None => return Err(EnclaveError::MissingInput("threshold key share".to_string())),
    };

    // Without recipients there is nobody to encrypt the partials to, so reject up front.
    if req.peer_enclave_pubkeys.is_empty() {
        return Err(EnclaveError::MissingInput(
            "peer_enclave_pubkeys is required for threshold partial DH".to_string(),
        ));
    }

    let mut partials = Vec::with_capacity(req.enc_points.len());
    for point in req.enc_points.iter() {
        let montgomery: [u8; 32] = point
            .as_slice()
            .try_into()
            .map_err(|_| EnclaveError::InvalidRequest("enc point must be 32 bytes".to_string()))?;
        let edwards =
            dkg::montgomery_to_edwards(&montgomery).map_err(|e| EnclaveError::ThresholdFailed(e.to_string()))?;
        let partial = dkg::compute_partial_decryption(share.index, &share.secret_share, &edwards);
        partials.push(PartialDecryptionData::from(partial));
    }

    let encrypted_partials =
        encrypt_partials_per_peer(&partials, &req.peer_enclave_pubkeys, &req.task_id, share.index)?;
    Ok(EnclavePartialDhResponse { encrypted_partials })
}

/// Encrypts the complete partials vector once per peer, under that peer's ephemeral X25519
/// public key. The HPKE construction is the one used for SecureEnvelope
/// (X25519+HKDF-SHA256+ChaCha20Poly1305).
///
/// Every blob is bound to `sender_index || task_id || recipient_id` through the AAD
/// (see `partial_dh_aad`), which guards against:
/// - replaying a blob in a different task (task_id),
/// - handing one recipient's blob to another recipient (recipient_id),
/// - claiming a different sender (sender_index).
///
/// # Errors
///
/// Returns [`EnclaveError::ThresholdFailed`] if the partials cannot be serialized, a peer
/// pubkey is rejected, or HPKE encryption fails. Processing stops at the first failing peer.
fn encrypt_partials_per_peer(
    partials: &[PartialDecryptionData],
    peer_pubkeys: &[(crate::protocol::EnclaveOperatorId, [u8; 32])],
    task_id: &newton_core::TaskId,
    sender_index: u32,
) -> Result<Vec<EncryptedPartialDH>, EnclaveError> {
    use newton_core::crypto::hpke::{encrypt, HpkePublicKey};

    // Serialize once: every recipient receives the same plaintext, only the AAD differs.
    let serialized = match serde_json::to_vec(partials) {
        Ok(bytes) => bytes,
        Err(e) => return Err(EnclaveError::ThresholdFailed(format!("serialize partials: {e}"))),
    };

    let mut blobs = Vec::with_capacity(peer_pubkeys.len());
    for (recipient_id, pubkey_bytes) in peer_pubkeys {
        let recipient_key = HpkePublicKey::from_bytes(pubkey_bytes)
            .map_err(|e| EnclaveError::ThresholdFailed(format!("invalid peer pubkey: {e}")))?;
        let aad = partial_dh_aad(sender_index, task_id, recipient_id);
        let (encapped_key, ciphertext) = encrypt(&recipient_key, &serialized, &aad)
            .map_err(|e| EnclaveError::ThresholdFailed(format!("HPKE encrypt to peer failed: {e}")))?;
        blobs.push(EncryptedPartialDH {
            recipient: *recipient_id,
            sender_index,
            encapped_key,
            ciphertext,
        });
    }
    Ok(blobs)
}

/// Builds the structured AAD shared by per-peer partial DH encryption and decryption.
///
/// Layout: `sender_index (4 bytes, little-endian) || task_id || recipient_id`, where the two
/// identifiers are appended exactly as returned by their `as_slice()` (expected to be 32 bytes
/// each).
fn partial_dh_aad(
    sender_index: u32,
    task_id: &newton_core::TaskId,
    recipient_id: &crate::protocol::EnclaveOperatorId,
) -> Vec<u8> {
    let index_le = sender_index.to_le_bytes();
    // Order matters: both sides must concatenate the segments identically.
    let segments: [&[u8]; 3] = [&index_le, task_id.as_slice(), recipient_id.as_slice()];
    segments.concat()
}

/// Opens every encrypted partial DH blob addressed to this enclave with its ephemeral
/// private key and returns the decoded partials, one `Vec<PartialDecryptionData>` per blob
/// (i.e. per sending peer; each peer's vector covers all envelopes).
///
/// Blobs whose `recipient` differs from `own_operator_id` are skipped without error. The AAD
/// is rebuilt as `sender_index || task_id || recipient_id` so it matches the encrypting side.
///
/// # Errors
///
/// Returns [`EnclaveError::DecryptFailed`] on the first blob that fails HPKE decryption or
/// whose plaintext is not a JSON list of partials.
pub(crate) fn decrypt_all_peer_partials(
    encrypted: &[EncryptedPartialDH],
    hpke_sk: &newton_core::crypto::HpkePrivateKey,
    own_operator_id: &crate::protocol::EnclaveOperatorId,
    task_id: &newton_core::TaskId,
) -> Result<Vec<Vec<PartialDecryptionData>>, EnclaveError> {
    let mut per_peer = Vec::new();
    for blob in encrypted {
        // Only blobs addressed to this enclave can be opened with `hpke_sk`.
        if blob.recipient != *own_operator_id {
            continue;
        }

        let aad = partial_dh_aad(blob.sender_index, task_id, own_operator_id);
        let opened = newton_core::crypto::hpke::decrypt(hpke_sk, &blob.encapped_key, &blob.ciphertext, &aad);
        let plaintext = match opened {
            Ok(bytes) => bytes,
            Err(e) => {
                return Err(EnclaveError::DecryptFailed(format!(
                    "peer partial decrypt failed (likely enclave key rotation between phases): {e}"
                )))
            }
        };

        let partials = serde_json::from_slice::<Vec<PartialDecryptionData>>(&plaintext)
            .map_err(|e| EnclaveError::DecryptFailed(format!("peer partial deserialize failed: {e}")))?;
        per_peer.push(partials);
    }
    Ok(per_peer)
}

/// Threshold-decrypts the inline envelopes and wraps the resulting JSON values as
/// `{ "inline": [...] }`.
///
/// Returns `Ok(None)` when `envelopes` is empty; in that case `partials` and `threshold` are
/// not inspected at all.
///
/// # Errors
///
/// Propagates any error from `decrypt_values`.
pub(crate) fn decrypt_envelopes(
    envelopes: &[SecureEnvelope],
    partials: Option<&[Vec<PartialDecryptionData>]>,
    threshold: &ThresholdEvalInput,
) -> Result<Option<serde_json::Value>, EnclaveError> {
    match envelopes {
        // Nothing to decrypt: skip validation of the other inputs entirely.
        [] => Ok(None),
        _ => decrypt_values(envelopes, partials, threshold)
            .map(|values| Some(serde_json::json!({ "inline": values }))),
    }
}

/// Threshold-decrypts domain-tagged envelopes and pairs each decrypted value with the
/// `domain` of the envelope it came from, preserving input order.
///
/// Unlike `decrypt_envelopes`, an empty `envelopes` slice is not special-cased: `partials`
/// must still be `Some` (and empty) and the public shares must still parse.
///
/// # Errors
///
/// Propagates any error from `decrypt_values`.
pub(crate) fn decrypt_domain_envelopes(
    envelopes: &[EnclaveEnvelope],
    partials: Option<&[Vec<PartialDecryptionData>]>,
    threshold: &ThresholdEvalInput,
) -> Result<Vec<DecryptedEnvelope>, EnclaveError> {
    // `decrypt_values` wants a contiguous slice of the inner envelopes, so copy them out.
    let mut inner = Vec::with_capacity(envelopes.len());
    for item in envelopes {
        inner.push(item.envelope.clone());
    }
    let values = decrypt_values(&inner, partials, threshold)?;

    let mut decrypted = Vec::with_capacity(envelopes.len());
    for (item, value) in envelopes.iter().zip(values) {
        decrypted.push(DecryptedEnvelope {
            domain: item.domain,
            value,
        });
    }
    Ok(decrypted)
}

fn decrypt_values(
    envelopes: &[SecureEnvelope],
    partials: Option<&[Vec<PartialDecryptionData>]>,
    threshold: &ThresholdEvalInput,
) -> Result<Vec<serde_json::Value>, EnclaveError> {
    let partials = partials.ok_or_else(|| EnclaveError::MissingInput("threshold partials".to_string()))?;
    if partials.len() != envelopes.len() {
        return Err(EnclaveError::ThresholdFailed(format!(
            "partials length {} does not match envelopes length {}",
            partials.len(),
            envelopes.len()
        )));
    }

    let public_shares = public_shares(&threshold.public_shares)?;
    envelopes
        .iter()
        .enumerate()
        .map(|(i, envelope)| {
            let partials = partials_for_envelope(&partials[i])?;
            if (partials.len() as u32) < threshold.config.threshold {
                return Err(EnclaveError::ThresholdFailed(format!(
                    "insufficient partials: got {}, need {}",
                    partials.len(),
                    threshold.config.threshold
                )));
            }

            let enc: [u8; 32] = hex::decode(&envelope.enc)
                .map_err(|e| EnclaveError::DecryptFailed(format!("invalid enc hex: {e}")))?
                .as_slice()
                .try_into()
                .map_err(|_| EnclaveError::DecryptFailed("enc must be 32 bytes".to_string()))?;
            let ciphertext = hex::decode(&envelope.ciphertext)
                .map_err(|e| EnclaveError::DecryptFailed(format!("invalid ciphertext hex: {e}")))?;
            let aad = compute_aad(&envelope.policy_client, envelope.chain_id)
                .map_err(|e| EnclaveError::DecryptFailed(e.to_string()))?;
            let pk_r = threshold_public_key(&threshold.public_shares, threshold.config)?;

            let plaintext = dkg::threshold_decrypt(
                &partials,
                &enc,
                &pk_r,
                &ciphertext,
                &aad,
                &public_shares,
                threshold.config.threshold,
            )
            .map_err(|e| EnclaveError::ThresholdFailed(format!("envelope {i}: {e}")))?;
            serde_json::from_slice(&plaintext)
                .map_err(|e| EnclaveError::DecryptFailed(format!("threshold plaintext is not json: {e}")))
        })
        .collect()
}

/// Converts the wire-format partials for a single envelope into `PartialDecryption` values,
/// preserving order.
///
/// # Errors
///
/// Returns [`EnclaveError::ThresholdFailed`] carrying the conversion error of the first entry
/// that fails `PartialDecryption::try_from`.
fn partials_for_envelope(wire: &[PartialDecryptionData]) -> Result<Vec<PartialDecryption>, EnclaveError> {
    let mut decoded = Vec::with_capacity(wire.len());
    for entry in wire {
        let partial = PartialDecryption::try_from(entry).map_err(EnclaveError::ThresholdFailed)?;
        decoded.push(partial);
    }
    Ok(decoded)
}

/// Decompresses the `(index, compressed Edwards point)` public shares into a map keyed by
/// share index.
///
/// NOTE(review): if the same index appears more than once, the later entry silently replaces
/// the earlier one — confirm callers never pass duplicate indices.
///
/// # Errors
///
/// Returns [`EnclaveError::ThresholdFailed`] if a share is not exactly 32 bytes or does not
/// decompress to a curve point.
fn public_shares(
    shares: &[(u32, Vec<u8>)],
) -> Result<HashMap<u32, curve25519_dalek::edwards::EdwardsPoint>, EnclaveError> {
    let mut points = HashMap::with_capacity(shares.len());
    for (index, encoded) in shares {
        let compressed = <[u8; 32]>::try_from(encoded.as_slice())
            .map_err(|_| EnclaveError::ThresholdFailed("public share must be 32 bytes".to_string()))?;
        let point = match CompressedEdwardsY(compressed).decompress() {
            Some(point) => point,
            None => return Err(EnclaveError::ThresholdFailed("invalid public share".to_string())),
        };
        points.insert(*index, point);
    }
    Ok(points)
}

/// Derives the group HPKE public key from the public shares and the threshold configuration
/// via `dkg::threshold_context_from_public_shares`.
///
/// # Errors
///
/// Returns [`EnclaveError::ThresholdFailed`] if a share is not exactly 32 bytes or the dkg
/// layer rejects the shares/config.
fn threshold_public_key(
    shares: &[(u32, Vec<u8>)],
    config: crate::protocol::EnclaveThresholdConfig,
) -> Result<[u8; 32], EnclaveError> {
    // The dkg API takes fixed-size share encodings, so validate lengths before calling it.
    let mut fixed_shares: Vec<(u32, [u8; 32])> = Vec::with_capacity(shares.len());
    for (index, encoded) in shares {
        let fixed = <[u8; 32]>::try_from(encoded.as_slice())
            .map_err(|_| EnclaveError::ThresholdFailed("public share must be 32 bytes".to_string()))?;
        fixed_shares.push((*index, fixed));
    }

    let dkg_config = dkg::ThresholdConfig {
        threshold: config.threshold,
        total: config.total,
    };
    match dkg::threshold_context_from_public_shares(&fixed_shares, dkg_config) {
        Ok(ctx) => Ok(ctx.public_key.hpke_public_key),
        Err(e) => Err(EnclaveError::ThresholdFailed(e.to_string())),
    }
}