newton-enclave 0.4.16

newton prover enclave compute
//! Enclave-resident privacy compute.

mod decrypt;
mod domain;
mod error;
/// KMS-attested HPKE key derivation.
pub mod kms;
/// enclave protocol messages.
pub mod protocol;
mod threshold;
/// WASM data provider executor for enclave-resident execution.
#[cfg(feature = "wasm")]
pub mod wasm;
/// framed enclave transport helpers.
pub mod wire;

pub use error::EnclaveError;
pub use newton_core::crypto::{EnclaveOperatorId, EncryptedPartialDH};
pub use protocol::{
    EnclaveEnvelope, EnclaveEvalRequest, EnclaveEvalResponse, EnclaveInitRequest, EnclavePartialDhRequest,
    EnclavePartialDhResponse, EnclaveThresholdConfig, EnclaveWireRequest, EnclaveWireRequestBody, EnclaveWireResponse,
    EnclaveWireResponseBody, GetAttestationRequest, GetAttestationResponse, PrepareEvalRequest, PrepareEvalResponse,
    ThresholdEvalInput, WasmPluginError, WasmPluginInput, WasmPluginOutput,
};

/// enclave key state.
///
/// holds the enclave's HPKE private key, its threshold key share, and the operator id
/// used as AAD for peer partial decryption. `Debug` is implemented by hand (instead of
/// derived) so that formatting this struct never prints key material.
#[derive(Default)]
pub struct EnclaveState {
    /// HPKE private key handed to request decryption (`decrypt::decrypt_request`).
    hpke_sk: Option<newton_core::crypto::HpkePrivateKey>,
    /// threshold key share used to compute partial DH values.
    threshold_key_share: Option<newton_core::dkg::KeyShare>,
    /// operator id for this enclave instance, used as AAD for peer partial decryption.
    operator_id: Option<protocol::EnclaveOperatorId>,
}

impl std::fmt::Debug for EnclaveState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // report only whether secrets are loaded, never their contents; the operator id
        // is not secret and is printed as-is.
        f.debug_struct("EnclaveState")
            .field("hpke_sk", &self.hpke_sk.as_ref().map(|_| "<redacted>"))
            .field(
                "threshold_key_share",
                &self.threshold_key_share.as_ref().map(|_| "<redacted>"),
            )
            .field("operator_id", &self.operator_id)
            .finish()
    }
}

impl EnclaveState {
    /// build enclave state from operator-held privacy keys.
    pub fn new(
        hpke_sk: Option<newton_core::crypto::HpkePrivateKey>,
        threshold_key_share: Option<newton_core::dkg::KeyShare>,
    ) -> Self {
        Self {
            hpke_sk,
            threshold_key_share,
            operator_id: None,
        }
    }

    /// set the operator id for this enclave (used as AAD for peer partial decryption).
    pub fn set_operator_id(&mut self, id: protocol::EnclaveOperatorId) {
        self.operator_id = Some(id);
    }

    /// initialize enclave state from centralized hpke key bytes (legacy Phase 1 path).
    pub fn init_hpke_private_key(&mut self, key: &[u8]) -> Result<(), EnclaveError> {
        self.hpke_sk = Some(
            newton_core::crypto::HpkePrivateKey::from_bytes(key)
                .map_err(|e| EnclaveError::InvalidRequest(e.to_string()))?,
        );
        Ok(())
    }

    /// initialize HPKE key via KMS-attested seed derivation.
    ///
    /// the caller (enclave binary) is responsible for decrypting the KMS seed
    /// using the Nitro attestation document. this method receives the zeroizable
    /// seed and derives the HPKE keypair deterministically via HKDF-SHA256.
    pub fn init_from_kms_seed(&mut self, seed: &zeroize::Zeroizing<[u8; 32]>) -> Result<(), EnclaveError> {
        self.hpke_sk = Some(kms::derive_hpke_from_seed(seed)?);
        Ok(())
    }

    /// initialize threshold key share from an encrypted FROST keystore.
    pub fn init_threshold_keystore(&mut self, keystore: &[u8], password: &[u8]) -> Result<(), EnclaveError> {
        let keystore: newton_core::dkg::keystore::ThresholdKeystore = serde_json::from_slice(keystore)
            .map_err(|e| EnclaveError::InvalidRequest(format!("invalid threshold keystore JSON: {e}")))?;
        let plaintext = zeroize::Zeroizing::new(
            newton_core::dkg::keystore::decrypt_keystore(&keystore, password)
                .map_err(|e| EnclaveError::DecryptFailed(format!("threshold keystore decrypt failed: {e}")))?,
        );
        let serializable: newton_core::dkg::types::SerializableKeyShare = serde_json::from_slice(&plaintext)
            .map_err(|e| EnclaveError::InvalidRequest(format!("invalid threshold key share: {e}")))?;
        self.threshold_key_share = Some(
            serializable
                .to_key_share()
                .map_err(|e| EnclaveError::InvalidRequest(format!("invalid threshold key share: {e}")))?,
        );
        Ok(())
    }

    /// return enclave-held hpke public key bytes.
    pub fn hpke_public_key(&self) -> Result<Vec<u8>, EnclaveError> {
        let sk = self
            .hpke_sk
            .as_ref()
            .ok_or_else(|| EnclaveError::MissingInput("hpke key".to_string()))?;
        Ok(sk.to_public_key().to_bytes().to_vec())
    }

    /// return a reference to the enclave-held HPKE private key.
    /// callers: LoopbackEnclave (operator crate) and enclave WASM executor.
    /// does not widen plaintext exposure — returns a reference, not a copy.
    pub fn hpke_sk_ref(&self) -> Option<&newton_core::crypto::HpkePrivateKey> {
        self.hpke_sk.as_ref()
    }

    /// return whether a threshold share is resident in enclave memory.
    pub fn has_threshold_key_share(&self) -> bool {
        self.threshold_key_share.is_some()
    }

    /// decrypt secret-bearing inputs and evaluate the policy.
    pub fn evaluate(&self, req: EnclaveEvalRequest) -> Result<EnclaveEvalResponse, EnclaveError> {
        let decrypted = decrypt::decrypt_request(&req, self.hpke_sk.as_ref(), self.operator_id.as_ref())?;
        let mut domain_data = Vec::new();
        domain_data.extend(domain::identity_data(
            &decrypted.identity,
            req.initialization_timestamp,
        )?);
        domain_data.extend(domain::confidential_data(&decrypted.confidential)?);

        let additional_data = domain::merge_additional(req.proof_data, decrypted.ephemeral);
        let result = newton_core::common::evaluate_task_with_resolved_policy(
            serde_json::json!(req.intent),
            &req.policy_task_data,
            req.resolved_policy,
            domain_data,
            additional_data,
        )
        .map_err(|e| EnclaveError::PolicyEvalFailed(e.to_string()))?;

        Ok(EnclaveEvalResponse {
            verified: *result.result.as_bool().unwrap_or(&false),
        })
    }

    /// compute threshold partial dh values.
    pub fn partial_dh(&self, req: EnclavePartialDhRequest) -> Result<EnclavePartialDhResponse, EnclaveError> {
        threshold::partial_dh(self.threshold_key_share.as_ref(), req)
    }
}

#[cfg(test)]
mod tests {
    use alloy::primitives::FixedBytes;
    use newton_core::crypto::{generate_keypair, hpke::HpkePublicKey, SecureEnvelope};

    use super::*;

    const POLICY_CLIENT: &str = "0x0000000000000000000000000000000000000001";
    const CHAIN_ID: u64 = 31337;

    /// an envelope sealed to the enclave key opens to the original JSON.
    #[test]
    fn direct_envelope_decrypts_inside_enclave() {
        let (sk, pk) = generate_keypair();
        let envelope = SecureEnvelope::seal(br#"{"amount":7}"#, POLICY_CLIENT, CHAIN_ID, &pk, &[0x11; 32]).unwrap();

        let value = decrypt::decrypt_value(&envelope, &sk).unwrap();
        assert_eq!(value, serde_json::json!({ "amount": 7 }));
    }

    /// tampering with the ciphertext must make decryption fail rather than yield data.
    #[test]
    fn corrupt_envelope_fails_closed() {
        let (sk, pk) = generate_keypair();
        let mut envelope = SecureEnvelope::seal(br#"{"amount":7}"#, POLICY_CLIENT, CHAIN_ID, &pk, &[0x11; 32]).unwrap();
        // Flip bits in the first byte to guarantee corruption regardless of original value
        let first_byte = u8::from_str_radix(&envelope.ciphertext[0..2], 16).unwrap();
        let corrupted = format!("{:02x}", first_byte ^ 0xff);
        envelope.ciphertext.replace_range(0..2, &corrupted);

        assert!(decrypt::decrypt_value(&envelope, &sk).is_err());
    }

    /// partials are encrypted per peer and can be decrypted by that peer.
    #[test]
    fn partial_dh_encrypts_per_peer_and_round_trips() {
        let (_public_key, _commitment, shares) =
            newton_core::dkg::generate_shares(newton_core::dkg::ThresholdConfig { threshold: 2, total: 3 }).unwrap();
        let key_share = shares[0].clone();
        let (_sk, pk) = generate_keypair();
        let envelope = SecureEnvelope::seal(br#"{"ok":true}"#, POLICY_CLIENT, CHAIN_ID, &pk, &[0x11; 32]).unwrap();
        let enc = hex::decode(envelope.enc).unwrap();

        // Generate peer enclave keypairs
        let (peer_sk, peer_pk) = generate_keypair();
        let peer_operator_id: newton_core::crypto::EnclaveOperatorId = [0xAA; 32];
        let task_id = FixedBytes::ZERO;

        let resp = EnclaveState::new(None, Some(key_share))
            .partial_dh(EnclavePartialDhRequest {
                task_id,
                enc_points: vec![enc],
                peer_enclave_pubkeys: vec![(peer_operator_id, peer_pk.to_bytes().try_into().unwrap())],
            })
            .unwrap();

        assert_eq!(resp.encrypted_partials.len(), 1);
        assert_eq!(resp.encrypted_partials[0].recipient, peer_operator_id);

        // Verify round-trip: decrypt the blob as the peer
        let decrypted = crate::threshold::decrypt_all_peer_partials(
            &resp.encrypted_partials,
            &peer_sk,
            &peer_operator_id,
            &task_id,
        )
        .unwrap();
        assert_eq!(decrypted.len(), 1);
        assert_eq!(decrypted[0].len(), 1);
        assert!(decrypted[0][0].operator_index > 0);
    }

    /// a threshold-sealed identity envelope decrypts once enough encrypted peer partials
    /// are supplied, using the key and operator id held by an `EnclaveState`.
    #[test]
    fn threshold_identity_with_encrypted_peer_partials_decrypts() {
        let (hpke_sk, hpke_pk) = generate_keypair();
        let (threshold_public_key, _commitment, shares) =
            newton_core::dkg::generate_shares(newton_core::dkg::ThresholdConfig { threshold: 2, total: 3 }).unwrap();
        let public_shares: Vec<(u32, Vec<u8>)> = shares
            .iter()
            .map(|share| (share.index, share.public_share.compress().to_bytes().to_vec()))
            .collect();

        let identity_plaintext = br#"{"name":"alice"}"#;
        let direct_plaintext = br#"{"amount":7}"#;
        let task_id = FixedBytes::ZERO;

        let threshold_hpke_pk = HpkePublicKey::from_bytes(&threshold_public_key.hpke_public_key).unwrap();
        let identity_envelope = SecureEnvelope::seal(
            identity_plaintext,
            POLICY_CLIENT,
            CHAIN_ID,
            &threshold_hpke_pk,
            &[0x22; 32],
        )
        .unwrap();
        let direct_envelope =
            SecureEnvelope::seal(direct_plaintext, POLICY_CLIENT, CHAIN_ID, &hpke_pk, &[0x33; 32]).unwrap();

        // Simulate 2 operators computing encrypted partials for the identity envelope.
        // Each operator encrypts its partial to the recipient (our enclave).
        let recipient_operator_id: newton_core::crypto::EnclaveOperatorId = [0xBB; 32];
        let identity_enc = hex::decode(&identity_envelope.enc).unwrap();

        let mut all_encrypted: Vec<newton_core::crypto::EncryptedPartialDH> = Vec::new();
        for share in &shares[..2] {
            let resp = EnclaveState::new(None, Some(share.clone()))
                .partial_dh(EnclavePartialDhRequest {
                    task_id,
                    enc_points: vec![identity_enc.clone()],
                    peer_enclave_pubkeys: vec![(recipient_operator_id, hpke_pk.to_bytes().try_into().unwrap())],
                })
                .unwrap();
            all_encrypted.extend(resp.encrypted_partials);
        }

        let req = EnclaveEvalRequest {
            task_id,
            policy_client: POLICY_CLIENT.parse().unwrap(),
            intent: Default::default(),
            intent_signature: Default::default(),
            policy_task_data: Default::default(),
            resolved_policy: newton_core::common::ResolvedPolicyInputs {
                policy_config: Default::default(),
                entrypoint: "allow".to_string(),
                schema: serde_json::json!({}),
            },
            initialization_timestamp: 0,
            proof_data: None,
            ephemeral_envelopes: vec![direct_envelope],
            identity_envelopes: vec![EnclaveEnvelope {
                ref_id: None,
                domain: None,
                envelope: identity_envelope,
            }],
            confidential_envelopes: vec![],
            threshold: Some(ThresholdEvalInput {
                task_id,
                encrypted_peer_partials: all_encrypted,
                ephemeral_count: 0,
                identity_count: 1,
                confidential_count: 0,
                public_shares,
                config: EnclaveThresholdConfig { threshold: 2, total: 3 },
            }),
        };

        let mut state = EnclaveState::new(Some(hpke_sk), None);
        state.set_operator_id(recipient_operator_id);
        // decrypt with the state's own key and operator id (as `evaluate` does), so the
        // `set_operator_id` wiring is actually exercised.
        let decrypted =
            decrypt::decrypt_request(&req, state.hpke_sk.as_ref(), state.operator_id.as_ref()).unwrap();
        assert_eq!(
            decrypted.ephemeral,
            Some(serde_json::json!({ "inline": [{ "amount": 7 }] }))
        );
        assert_eq!(decrypted.identity.len(), 1);
        assert_eq!(decrypted.identity[0].value, serde_json::json!({ "name": "alice" }));
    }
}