newton-enclave 0.4.15

newton prover enclave compute
use newton_core::crypto::{HpkePrivateKey, SecureEnvelope};
use tracing::debug;

use crate::{
    error::EnclaveError,
    protocol::{EnclaveEnvelope, EnclaveEvalRequest},
    threshold,
};

/// Plaintext contents of an evaluation request after every envelope has been opened.
///
/// Built by `decrypt_request`, either via threshold decryption or direct HPKE per type.
#[derive(Debug)]
pub(crate) struct DecryptedRequest {
    /// Decrypted ephemeral payload.
    ///
    /// On the direct-HPKE path this is `None` when the request carried no ephemeral
    /// envelopes, and otherwise `{"inline": [<value>, ...]}` in envelope order. The shape
    /// produced on the threshold path is determined by `threshold::decrypt_envelopes`
    /// (NOTE(review): presumably the same shape — confirm there).
    pub(crate) ephemeral: Option<serde_json::Value>,
    /// Decrypted entries for the request's identity envelopes.
    pub(crate) identity: Vec<DecryptedEnvelope>,
    /// Decrypted entries for the request's confidential envelopes.
    pub(crate) confidential: Vec<DecryptedEnvelope>,
}

/// One opened domain-tagged envelope: the recovered JSON value plus the domain tag of the
/// envelope it came from.
#[derive(Debug)]
pub(crate) struct DecryptedEnvelope {
    /// Domain copied from the source envelope (on the direct-HPKE path, `item.domain`
    /// verbatim); `None` when the source envelope carried no domain.
    pub(crate) domain: Option<alloy::primitives::FixedBytes<32>>,
    /// The envelope's plaintext, parsed as JSON.
    pub(crate) value: serde_json::Value,
}

/// decrypt all envelopes in an evaluation request.
///
/// when `threshold` is present: checks that every non-zero per-type envelope count matches the
/// number of envelopes of that type in the request, decrypts encrypted peer partials using this
/// enclave's ephemeral key, splits them by type using those counts, then runs per-type threshold
/// or direct HPKE decryption. types with count 0 fall back to direct HPKE (mixed-mode: a task may
/// use threshold for identity but direct for ephemeral).
///
/// when `threshold` is absent: all envelopes use direct HPKE decryption.
///
/// # Errors
///
/// - `MissingInput` if the HPKE key is absent (either mode) or the operator id is absent
///   (threshold mode).
/// - `ThresholdFailed` if a non-zero threshold count disagrees with the number of envelopes of
///   that type, or if a peer supplied the wrong number of partials.
/// - any error from peer-partial decryption or from the per-envelope decryption helpers.
pub(crate) fn decrypt_request(
    req: &EnclaveEvalRequest,
    hpke_sk: Option<&HpkePrivateKey>,
    own_operator_id: Option<&crate::protocol::EnclaveOperatorId>,
) -> Result<DecryptedRequest, EnclaveError> {
    if let Some(threshold) = req.threshold.as_ref() {
        let hpke_sk =
            hpke_sk.ok_or_else(|| EnclaveError::MissingInput("hpke key for threshold decrypt".to_string()))?;
        let own_id = own_operator_id
            .ok_or_else(|| EnclaveError::MissingInput("operator id for threshold decrypt".to_string()))?;

        // A non-zero count selects threshold decryption for that type and fixes how each peer's
        // flat partials vector is sliced, so it must agree with the envelopes actually present.
        // Checking up front stops a bogus count from misaligning partials with envelopes (or from
        // sizing allocations in the split), and avoids decrypting peer partials for nothing.
        ensure_threshold_count_matches("ephemeral", threshold.ephemeral_count, req.ephemeral_envelopes.len())?;
        ensure_threshold_count_matches("identity", threshold.identity_count, req.identity_envelopes.len())?;
        ensure_threshold_count_matches(
            "confidential",
            threshold.confidential_count,
            req.confidential_envelopes.len(),
        )?;

        // Decrypt all peer partials using our ephemeral private key.
        // Each peer encrypted the same partials vector to us; we collect all peers' partials.
        let all_peer_partials = threshold::decrypt_all_peer_partials(
            &threshold.encrypted_peer_partials,
            hpke_sk,
            own_id,
            &threshold.task_id,
        )?;

        // Split the flat partials by envelope type using the counts from ThresholdEvalInput.
        // Each peer contributed partials for [ephemeral..identity..confidential] envelopes
        // concatenated. We transpose: from per-peer to per-envelope.
        let (eph_partials, id_partials, conf_partials) = split_and_transpose_partials(
            &all_peer_partials,
            threshold.ephemeral_count,
            threshold.identity_count,
            threshold.confidential_count,
        )?;

        debug!(
            ephemeral_count = req.ephemeral_envelopes.len(),
            identity_count = req.identity_envelopes.len(),
            confidential_count = req.confidential_envelopes.len(),
            peer_count = all_peer_partials.len(),
            "decrypt_request: threshold decryption from encrypted peer partials"
        );

        // Per-type decryption: threshold when partials are available, direct HPKE fallback
        // when count is 0 for that type (the task may use threshold for identity but direct
        // for ephemeral — any combination of 0+ per type is valid).
        return Ok(DecryptedRequest {
            ephemeral: if !eph_partials.is_empty() {
                threshold::decrypt_envelopes(&req.ephemeral_envelopes, Some(&eph_partials), threshold)?
            } else {
                decrypt_ephemeral(&req.ephemeral_envelopes, hpke_sk)?
            },
            identity: if !id_partials.is_empty() {
                threshold::decrypt_domain_envelopes(&req.identity_envelopes, Some(&id_partials), threshold)?
            } else {
                decrypt_domain(&req.identity_envelopes, hpke_sk)?
            },
            confidential: if !conf_partials.is_empty() {
                threshold::decrypt_domain_envelopes(&req.confidential_envelopes, Some(&conf_partials), threshold)?
            } else {
                decrypt_domain(&req.confidential_envelopes, hpke_sk)?
            },
        });
    }

    debug!(
        ephemeral_count = req.ephemeral_envelopes.len(),
        identity_count = req.identity_envelopes.len(),
        confidential_count = req.confidential_envelopes.len(),
        "decrypt_request: direct HPKE decryption (no threshold)"
    );

    let hpke_sk = hpke_sk.ok_or_else(|| EnclaveError::MissingInput("hpke key".to_string()))?;
    Ok(DecryptedRequest {
        ephemeral: decrypt_ephemeral(&req.ephemeral_envelopes, hpke_sk)?,
        identity: decrypt_domain(&req.identity_envelopes, hpke_sk)?,
        confidential: decrypt_domain(&req.confidential_envelopes, hpke_sk)?,
    })
}

/// Accept a per-type threshold count that is either 0 (that type uses direct HPKE) or exactly
/// the number of envelopes of that type in the request; reject anything else.
fn ensure_threshold_count_matches(kind: &str, count: usize, envelopes: usize) -> Result<(), EnclaveError> {
    if count == 0 || count == envelopes {
        Ok(())
    } else {
        Err(EnclaveError::ThresholdFailed(format!(
            "{kind} threshold count {count} does not match {envelopes} {kind} envelope(s)"
        )))
    }
}

/// Per-envelope partials for each envelope type: `(ephemeral, identity, confidential)`,
/// each indexed as `[env_idx][peer_idx]`.
type PartialsByType = (
    Vec<Vec<newton_aggregator::PartialDecryptionData>>,
    Vec<Vec<newton_aggregator::PartialDecryptionData>>,
    Vec<Vec<newton_aggregator::PartialDecryptionData>>,
);

/// Transpose per-peer partials into per-envelope partials, split by type.
///
/// Input: `all_peer_partials[peer_idx]` = flat Vec of partials for all envelopes from one peer,
/// laid out as `[ephemeral..identity..confidential]`.
/// Output: `(eph[env_idx][peer_idx], id[env_idx][peer_idx], conf[env_idx][peer_idx])`.
///
/// When all three counts are 0, returns three empty vectors regardless of the peers.
///
/// # Errors
///
/// `ThresholdFailed` if the counts overflow `usize` when summed, or if any peer's vector does
/// not contain exactly `eph_count + id_count + conf_count` partials.
fn split_and_transpose_partials(
    all_peer_partials: &[Vec<newton_aggregator::PartialDecryptionData>],
    eph_count: usize,
    id_count: usize,
    conf_count: usize,
) -> Result<PartialsByType, EnclaveError> {
    // The counts originate from the request, so the sum must not be allowed to overflow:
    // a wrapped total could pass the length check below and then panic while slicing.
    let total = eph_count
        .checked_add(id_count)
        .and_then(|n| n.checked_add(conf_count))
        .ok_or_else(|| {
            EnclaveError::ThresholdFailed(format!(
                "envelope counts overflow: ephemeral {eph_count}, identity {id_count}, confidential {conf_count}"
            ))
        })?;
    if total == 0 {
        return Ok((vec![], vec![], vec![]));
    }

    for (peer_idx, peer_partials) in all_peer_partials.iter().enumerate() {
        if peer_partials.len() != total {
            return Err(EnclaveError::ThresholdFailed(format!(
                "peer {peer_idx} has {} partials, expected {total}",
                peer_partials.len()
            )));
        }
    }

    // Build each per-envelope column individually: `vec![Vec::with_capacity(n); k]` clones
    // the prototype, and cloning an empty Vec does not keep its capacity.
    let peer_count = all_peer_partials.len();
    let mut eph: Vec<Vec<_>> = (0..eph_count).map(|_| Vec::with_capacity(peer_count)).collect();
    let mut id: Vec<Vec<_>> = (0..id_count).map(|_| Vec::with_capacity(peer_count)).collect();
    let mut conf: Vec<Vec<_>> = (0..conf_count).map(|_| Vec::with_capacity(peer_count)).collect();

    for peer_partials in all_peer_partials {
        // Every peer vector has exactly `total` entries (checked above), so these splits
        // cannot panic and the final piece has exactly `conf_count` entries.
        let (eph_part, rest) = peer_partials.split_at(eph_count);
        let (id_part, conf_part) = rest.split_at(id_count);
        for (column, partial) in eph.iter_mut().zip(eph_part) {
            column.push(partial.clone());
        }
        for (column, partial) in id.iter_mut().zip(id_part) {
            column.push(partial.clone());
        }
        for (column, partial) in conf.iter_mut().zip(conf_part) {
            column.push(partial.clone());
        }
    }

    Ok((eph, id, conf))
}

/// Opens every ephemeral envelope with `hpke_sk` and bundles the parsed JSON values, in input
/// order, as `{"inline": [...]}`.
///
/// Returns `Ok(None)` when `envelopes` is empty. The first envelope that fails to decrypt or
/// parse aborts the call with a `DecryptFailed` error that names its index.
fn decrypt_ephemeral(
    envelopes: &[SecureEnvelope],
    hpke_sk: &HpkePrivateKey,
) -> Result<Option<serde_json::Value>, EnclaveError> {
    if envelopes.is_empty() {
        return Ok(None);
    }

    let mut inline = Vec::with_capacity(envelopes.len());
    for (idx, env) in envelopes.iter().enumerate() {
        let value = match decrypt_value(env, hpke_sk) {
            Ok(v) => v,
            Err(err) => return Err(EnclaveError::DecryptFailed(format!("ephemeral {idx}: {err}"))),
        };
        inline.push(value);
    }

    Ok(Some(serde_json::json!({ "inline": inline })))
}

/// Opens each domain-tagged envelope with `hpke_sk`, pairing the parsed JSON value with the
/// envelope's `domain`.
///
/// Output order matches input order. The first failure aborts the call with a `DecryptFailed`
/// error that names the envelope's index.
fn decrypt_domain(
    envelopes: &[EnclaveEnvelope],
    hpke_sk: &HpkePrivateKey,
) -> Result<Vec<DecryptedEnvelope>, EnclaveError> {
    let mut decrypted = Vec::with_capacity(envelopes.len());
    for (idx, entry) in envelopes.iter().enumerate() {
        let value = decrypt_value(&entry.envelope, hpke_sk)
            .map_err(|err| EnclaveError::DecryptFailed(format!("domain envelope {idx}: {err}")))?;
        decrypted.push(DecryptedEnvelope {
            domain: entry.domain,
            value,
        });
    }
    Ok(decrypted)
}

/// Opens `envelope` with `hpke_sk` and parses the resulting plaintext as JSON.
///
/// # Errors
///
/// `EnclaveError::DecryptFailed` if the envelope cannot be opened (carrying the open error's
/// message), or if the plaintext is not valid JSON.
pub(crate) fn decrypt_value(
    envelope: &SecureEnvelope,
    hpke_sk: &HpkePrivateKey,
) -> Result<serde_json::Value, EnclaveError> {
    let plaintext = match envelope.open(hpke_sk) {
        Ok(bytes) => bytes,
        Err(err) => return Err(EnclaveError::DecryptFailed(err.to_string())),
    };
    match serde_json::from_slice(&plaintext) {
        Ok(value) => Ok(value),
        Err(err) => Err(EnclaveError::DecryptFailed(format!("plaintext is not json: {err}"))),
    }
}