newton-enclave 0.4.15

newton prover enclave compute
//! length-prefixed framing over vsock (a stream protocol with no built-in message boundaries).
//!
//! framing format: 4-byte big-endian length prefix (via tokio `read_u32`/`write_u32`) followed
//! by a bincode-encoded payload. max frame size is `MAX_FRAME_LEN` (16 MiB).
//!
//! **field-order stability:** bincode is positional, not self-describing. any change to struct
//! field order or types in the request/response types is a wire-breaking change and must be
//! accompanied by an `ENCLAVE_PROTOCOL_VERSION` bump. callers check the decoded `version` field,
//! so a mismatch surfaces as a clean version error rather than corrupt data — provided the
//! payload still decodes; a layout change that breaks decoding fails earlier, inside this module,
//! with an `InvalidRequest` error.
//!
//! both the enclave binary and the operator vsock client depend on this module. the protocol
//! is one request per connection — the enclave reads one frame, responds, then closes.

use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

use crate::{
    protocol::{EnclaveWireRequest, EnclaveWireResponse, MAX_FRAME_LEN},
    EnclaveError,
};

/// receive a single [`EnclaveWireRequest`] from `reader`.
///
/// typed entry point over the shared frame decoder; see the module docs for the wire format.
///
/// # Errors
///
/// fails with `EnclaveError::InvalidRequest` on an I/O failure, a length prefix larger than
/// `MAX_FRAME_LEN`, or a payload that does not decode as a request.
pub async fn read_request<R: AsyncRead + Unpin>(
    reader: &mut R,
) -> Result<EnclaveWireRequest, EnclaveError> {
    read_frame::<R, EnclaveWireRequest>(reader).await
}

/// send `request` to `writer` as a single length-prefixed frame and flush it.
///
/// typed entry point over the shared frame encoder; see the module docs for the wire format.
///
/// # Errors
///
/// fails with `EnclaveError::InvalidRequest` if encoding fails, the encoded payload exceeds
/// `MAX_FRAME_LEN`, or a write/flush fails.
pub async fn write_request<W: AsyncWrite + Unpin>(
    writer: &mut W,
    request: &EnclaveWireRequest,
) -> Result<(), EnclaveError> {
    write_frame::<W, EnclaveWireRequest>(writer, request).await
}

/// receive a single [`EnclaveWireResponse`] from `reader`.
///
/// typed entry point over the shared frame decoder; see the module docs for the wire format.
///
/// # Errors
///
/// fails with `EnclaveError::InvalidRequest` on an I/O failure, a length prefix larger than
/// `MAX_FRAME_LEN`, or a payload that does not decode as a response.
pub async fn read_response<R: AsyncRead + Unpin>(
    reader: &mut R,
) -> Result<EnclaveWireResponse, EnclaveError> {
    read_frame::<R, EnclaveWireResponse>(reader).await
}

/// send `response` to `writer` as a single length-prefixed frame and flush it.
///
/// typed entry point over the shared frame encoder; see the module docs for the wire format.
///
/// # Errors
///
/// fails with `EnclaveError::InvalidRequest` if encoding fails, the encoded payload exceeds
/// `MAX_FRAME_LEN`, or a write/flush fails.
pub async fn write_response<W: AsyncWrite + Unpin>(
    writer: &mut W,
    response: &EnclaveWireResponse,
) -> Result<(), EnclaveError> {
    write_frame::<W, EnclaveWireResponse>(writer, response).await
}

/// read one length-prefixed frame from `reader` and bincode-decode its payload as `T`.
///
/// the 4-byte big-endian prefix is compared against `MAX_FRAME_LEN` before the payload buffer
/// is allocated, so an oversized prefix is rejected without allocating for it.
///
/// any bytes left in the payload after `T` has been decoded are ignored. presumably this is what
/// lets a peer on a newer layout that only appended fields still decode far enough for callers to
/// report a version mismatch — confirm before tightening.
///
/// # Errors
///
/// every failure (short read / EOF, oversized prefix, decode error) is reported as
/// `EnclaveError::InvalidRequest`, regardless of whether a request or a response was expected.
async fn read_frame<R, T>(reader: &mut R) -> Result<T, EnclaveError>
where
    R: AsyncRead + Unpin,
    T: serde::de::DeserializeOwned,
{
    let prefix = match reader.read_u32().await {
        Ok(prefix) => prefix,
        Err(err) => return Err(EnclaveError::InvalidRequest(err.to_string())),
    };
    let len = prefix as usize;
    if len > MAX_FRAME_LEN {
        return Err(EnclaveError::InvalidRequest(format!("frame too large: {len}")));
    }

    let mut payload = vec![0u8; len];
    if let Err(err) = reader.read_exact(&mut payload).await {
        return Err(EnclaveError::InvalidRequest(err.to_string()));
    }

    match bincode::serde::decode_from_slice(&payload, bincode::config::standard()) {
        Ok((msg, _consumed)) => Ok(msg),
        Err(err) => Err(EnclaveError::InvalidRequest(err.to_string())),
    }
}

/// bincode-encode `msg` and write it to `writer` as one length-prefixed frame, then flush.
///
/// the payload is fully encoded and size-checked before anything is written, so an oversized
/// message never leaves a partial frame on the stream.
///
/// # Errors
///
/// returns `EnclaveError::InvalidRequest` if encoding fails, if the encoded payload exceeds
/// `MAX_FRAME_LEN` (or cannot be represented in the 4-byte length prefix), or if a write or the
/// final flush fails.
async fn write_frame<W, T>(writer: &mut W, msg: &T) -> Result<(), EnclaveError>
where
    W: AsyncWrite + Unpin,
    T: serde::Serialize,
{
    let buf = bincode::serde::encode_to_vec(msg, bincode::config::standard())
        .map_err(|e| EnclaveError::InvalidRequest(e.to_string()))?;

    // checked conversion instead of `as u32`: the prefix must never silently truncate, even if
    // `MAX_FRAME_LEN` is later raised beyond what a u32 can express.
    let len = match u32::try_from(buf.len()) {
        Ok(len) if buf.len() <= MAX_FRAME_LEN => len,
        _ => {
            return Err(EnclaveError::InvalidRequest(format!(
                "frame too large: {}",
                buf.len()
            )))
        }
    };

    writer
        .write_u32(len)
        .await
        .map_err(|e| EnclaveError::InvalidRequest(e.to_string()))?;
    writer
        .write_all(&buf)
        .await
        .map_err(|e| EnclaveError::InvalidRequest(e.to_string()))?;
    // flush so a buffered writer does not hold back part of the frame from the peer.
    writer
        .flush()
        .await
        .map_err(|e| EnclaveError::InvalidRequest(e.to_string()))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{EnclaveWireRequestBody, EnclaveWireResponseBody, ENCLAVE_PROTOCOL_VERSION};

    #[tokio::test]
    async fn request_round_trip_over_duplex() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let request = crate::protocol::EnclaveWireRequest {
            version: ENCLAVE_PROTOCOL_VERSION,
            request_id: 7,
            body: EnclaveWireRequestBody::Health,
        };

        let writer = tokio::spawn(async move { write_request(&mut client, &request).await });
        let decoded = read_request(&mut server).await.unwrap();
        writer.await.unwrap().unwrap();

        assert_eq!(decoded.request_id, 7);
        assert!(matches!(decoded.body, EnclaveWireRequestBody::Health));
    }

    #[tokio::test]
    async fn response_round_trip_over_duplex() {
        let (mut client, mut server) = tokio::io::duplex(1024);
        let response = crate::protocol::EnclaveWireResponse {
            version: ENCLAVE_PROTOCOL_VERSION,
            request_id: 9,
            body: EnclaveWireResponseBody::Health,
        };

        let writer = tokio::spawn(async move { write_response(&mut client, &response).await });
        let decoded = read_response(&mut server).await.unwrap();
        writer.await.unwrap().unwrap();

        assert_eq!(decoded.request_id, 9);
        assert!(matches!(decoded.body, EnclaveWireResponseBody::Health));
    }

    #[tokio::test]
    async fn frame_too_large_is_rejected() {
        let (mut client, mut server) = tokio::io::duplex(8);
        let writer = tokio::spawn(async move {
            client.write_u32((MAX_FRAME_LEN + 1) as u32).await.unwrap();
            client.flush().await.unwrap();
        });
        let error = read_request(&mut server).await.unwrap_err();
        writer.await.unwrap();

        assert!(matches!(error, EnclaveError::InvalidRequest(message) if message.contains("frame too large")));
    }

    /// a prefix that promises more payload than the peer sends before hanging up must surface
    /// as an error, not hang or yield a partially-decoded message.
    #[tokio::test]
    async fn truncated_frame_is_rejected() {
        let (mut client, mut server) = tokio::io::duplex(64);
        let writer = tokio::spawn(async move {
            client.write_u32(16).await.unwrap();
            client.write_all(&[1, 2, 3]).await.unwrap();
            client.flush().await.unwrap();
            // dropping the client end signals EOF to the server once the buffer is drained.
            drop(client);
        });
        let error = read_request(&mut server).await.unwrap_err();
        writer.await.unwrap();

        assert!(matches!(error, EnclaveError::InvalidRequest(_)));
    }

    /// verify that alloy Bytes / FixedBytes fields in the partial-dh request round-trip
    /// through bincode. bincode is positional — this guards against silent serde mismatches
    /// for types with variable-length byte sequences.
    #[tokio::test]
    async fn partial_dh_request_with_bytes_round_trips() {
        use crate::protocol::{EnclavePartialDhRequest, EnclaveWireRequest, EnclaveWireRequestBody};

        let task_id = alloy::primitives::FixedBytes::ZERO;
        let enc_points: Vec<Vec<u8>> = vec![vec![0xab; 32], vec![0xcd; 32]];
        let peer_enclave_pubkeys = vec![([0x11; 32], [0x22; 32])];
        let req = EnclavePartialDhRequest {
            task_id,
            enc_points: enc_points.clone(),
            peer_enclave_pubkeys: peer_enclave_pubkeys.clone(),
        };
        let wire_req = EnclaveWireRequest {
            version: ENCLAVE_PROTOCOL_VERSION,
            request_id: 55,
            body: EnclaveWireRequestBody::PartialDh(req),
        };

        let (mut client, mut server) = tokio::io::duplex(4096);
        let writer = tokio::spawn(async move { write_request(&mut client, &wire_req).await });
        let decoded = read_request(&mut server).await.unwrap();
        writer.await.unwrap().unwrap();

        assert_eq!(decoded.request_id, 55);
        let EnclaveWireRequestBody::PartialDh(decoded_req) = decoded.body else {
            panic!("expected PartialDh body");
        };
        // every field is checked: a positional layout drift in any of them would show up here.
        assert_eq!(decoded_req.task_id, task_id);
        assert_eq!(decoded_req.enc_points, enc_points);
        assert_eq!(decoded_req.peer_enclave_pubkeys, peer_enclave_pubkeys);
    }

    /// verify that EnclaveInitRequest with optional Bytes fields round-trips.
    #[tokio::test]
    async fn init_request_with_bytes_round_trips() {
        use crate::protocol::{EnclaveInitRequest, EnclaveWireRequest, EnclaveWireRequestBody};

        let init = EnclaveInitRequest {
            hpke_private_key: Some(vec![0x11; 32]),
            threshold_keystore: Some(b"{\"key\":\"data\"}".to_vec()),
            threshold_keystore_password: Some(b"hunter2".to_vec()),
            kms_seed_ciphertext: None,
        };
        let wire_req = EnclaveWireRequest {
            version: ENCLAVE_PROTOCOL_VERSION,
            request_id: 77,
            body: EnclaveWireRequestBody::Init(init.clone()),
        };

        let (mut client, mut server) = tokio::io::duplex(4096);
        let writer = tokio::spawn(async move { write_request(&mut client, &wire_req).await });
        let decoded = read_request(&mut server).await.unwrap();
        writer.await.unwrap().unwrap();

        assert_eq!(decoded.request_id, 77);
        let EnclaveWireRequestBody::Init(decoded_init) = decoded.body else {
            panic!("expected Init body");
        };
        assert_eq!(decoded_init.hpke_private_key, init.hpke_private_key);
        assert_eq!(decoded_init.threshold_keystore, init.threshold_keystore);
        assert_eq!(
            decoded_init.threshold_keystore_password,
            init.threshold_keystore_password
        );
        // the trailing `None` must stay `None`; a mis-aligned decode could misread it.
        assert!(decoded_init.kms_seed_ciphertext.is_none());
    }
}