use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::{
protocol::{EnclaveWireRequest, EnclaveWireResponse, MAX_FRAME_LEN},
EnclaveError,
};
/// Reads one length-prefixed frame from `reader` and decodes it as an
/// [`EnclaveWireRequest`].
///
/// # Errors
///
/// Returns [`EnclaveError::InvalidRequest`] if the frame cannot be read, is
/// larger than `MAX_FRAME_LEN`, or does not decode as a request.
pub async fn read_request<R: AsyncRead + Unpin>(
    reader: &mut R,
) -> Result<EnclaveWireRequest, EnclaveError> {
    let request = read_frame::<R, EnclaveWireRequest>(reader).await?;
    Ok(request)
}
/// Encodes `request` and writes it to `writer` as one length-prefixed frame,
/// flushing the writer afterwards.
///
/// # Errors
///
/// Returns [`EnclaveError::InvalidRequest`] if encoding fails, the encoded
/// payload is larger than `MAX_FRAME_LEN`, or any write/flush fails.
pub async fn write_request<W: AsyncWrite + Unpin>(
    writer: &mut W,
    request: &EnclaveWireRequest,
) -> Result<(), EnclaveError> {
    write_frame::<W, EnclaveWireRequest>(writer, request).await?;
    Ok(())
}
/// Reads one length-prefixed frame from `reader` and decodes it as an
/// [`EnclaveWireResponse`].
///
/// # Errors
///
/// Returns [`EnclaveError::InvalidRequest`] if the frame cannot be read, is
/// larger than `MAX_FRAME_LEN`, or does not decode as a response.
pub async fn read_response<R: AsyncRead + Unpin>(
    reader: &mut R,
) -> Result<EnclaveWireResponse, EnclaveError> {
    let response = read_frame::<R, EnclaveWireResponse>(reader).await?;
    Ok(response)
}
/// Encodes `response` and writes it to `writer` as one length-prefixed frame,
/// flushing the writer afterwards.
///
/// # Errors
///
/// Returns [`EnclaveError::InvalidRequest`] if encoding fails, the encoded
/// payload is larger than `MAX_FRAME_LEN`, or any write/flush fails.
pub async fn write_response<W: AsyncWrite + Unpin>(
    writer: &mut W,
    response: &EnclaveWireResponse,
) -> Result<(), EnclaveError> {
    write_frame::<W, EnclaveWireResponse>(writer, response).await?;
    Ok(())
}
/// Reads a single frame from `reader` and decodes its payload as `T`.
///
/// A frame is a `u32` length header (read with `read_u32`, i.e. big-endian)
/// followed by exactly that many payload bytes. The payload is decoded with
/// bincode's serde support using the standard configuration.
///
/// The number of bytes the decoder consumed is discarded, so any payload
/// bytes left over after the decoded message are silently ignored.
///
/// # Errors
///
/// Every failure is reported as [`EnclaveError::InvalidRequest`] carrying a
/// textual description:
/// - the header or the payload could not be read in full,
/// - the declared length exceeds `MAX_FRAME_LEN`,
/// - the payload does not decode as `T`.
async fn read_frame<R, T>(reader: &mut R) -> Result<T, EnclaveError>
where
    R: AsyncRead + Unpin,
    T: serde::de::DeserializeOwned,
{
    let declared_len = match reader.read_u32().await {
        Ok(header) => header as usize,
        Err(io_err) => return Err(EnclaveError::InvalidRequest(io_err.to_string())),
    };

    // Check the declared size before allocating so a hostile header cannot
    // force an arbitrarily large buffer.
    if declared_len > MAX_FRAME_LEN {
        return Err(EnclaveError::InvalidRequest(format!(
            "frame too large: {declared_len}"
        )));
    }

    let mut payload = vec![0u8; declared_len];
    if let Err(io_err) = reader.read_exact(&mut payload).await {
        return Err(EnclaveError::InvalidRequest(io_err.to_string()));
    }

    match bincode::serde::decode_from_slice(&payload, bincode::config::standard()) {
        Ok((message, _consumed)) => Ok(message),
        Err(decode_err) => Err(EnclaveError::InvalidRequest(decode_err.to_string())),
    }
}
/// Encodes `msg` and writes it to `writer` as a single frame, then flushes.
///
/// A frame is a `u32` length header (written with `write_u32`, i.e.
/// big-endian) followed by the payload bytes. The payload is `msg` encoded
/// with bincode's serde support using the standard configuration.
///
/// # Errors
///
/// Every failure is reported as [`EnclaveError::InvalidRequest`] carrying a
/// textual description:
/// - `msg` could not be encoded,
/// - the encoded payload exceeds `MAX_FRAME_LEN` or does not fit in the
///   `u32` length header,
/// - writing the header, writing the payload, or flushing failed.
async fn write_frame<W, T>(writer: &mut W, msg: &T) -> Result<(), EnclaveError>
where
    W: AsyncWrite + Unpin,
    T: serde::Serialize,
{
    let buf = bincode::serde::encode_to_vec(msg, bincode::config::standard())
        .map_err(|e| EnclaveError::InvalidRequest(e.to_string()))?;

    // Use a checked conversion for the header instead of `as u32`: if
    // `MAX_FRAME_LEN` were ever raised above `u32::MAX`, a plain cast would
    // silently wrap and desynchronise the stream. Both limits are enforced
    // here so nothing is written for an oversized payload.
    let frame_len = match u32::try_from(buf.len()) {
        Ok(len) if buf.len() <= MAX_FRAME_LEN => len,
        _ => {
            return Err(EnclaveError::InvalidRequest(format!(
                "frame too large: {}",
                buf.len()
            )));
        }
    };

    writer
        .write_u32(frame_len)
        .await
        .map_err(|e| EnclaveError::InvalidRequest(e.to_string()))?;
    writer
        .write_all(&buf)
        .await
        .map_err(|e| EnclaveError::InvalidRequest(e.to_string()))?;
    // Flush so the peer sees the complete frame even when `writer` buffers.
    writer
        .flush()
        .await
        .map_err(|e| EnclaveError::InvalidRequest(e.to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::{
        EnclaveInitRequest, EnclavePartialDhRequest, EnclaveWireRequestBody,
        EnclaveWireResponseBody, ENCLAVE_PROTOCOL_VERSION,
    };

    /// Sends `outgoing` through one end of an in-memory duplex pipe with the
    /// given buffer capacity and returns what `read_request` decodes on the
    /// other end. Panics if either side fails.
    async fn round_trip_request(
        outgoing: EnclaveWireRequest,
        pipe_capacity: usize,
    ) -> EnclaveWireRequest {
        let (mut near_end, mut far_end) = tokio::io::duplex(pipe_capacity);
        // The write runs in its own task so it can make progress while this
        // task is reading from the other end of the pipe.
        let send_task =
            tokio::spawn(async move { write_request(&mut near_end, &outgoing).await });
        let received = read_request(&mut far_end).await.unwrap();
        send_task.await.unwrap().unwrap();
        received
    }

    /// A body-less `Health` request survives a write/read cycle.
    #[tokio::test]
    async fn request_round_trip_over_duplex() {
        let outgoing = EnclaveWireRequest {
            version: ENCLAVE_PROTOCOL_VERSION,
            request_id: 7,
            body: EnclaveWireRequestBody::Health,
        };

        let received = round_trip_request(outgoing, 1024).await;

        assert_eq!(received.request_id, 7);
        assert!(matches!(received.body, EnclaveWireRequestBody::Health));
    }

    /// A body-less `Health` response survives a write/read cycle.
    #[tokio::test]
    async fn response_round_trip_over_duplex() {
        let (mut near_end, mut far_end) = tokio::io::duplex(1024);
        let outgoing = EnclaveWireResponse {
            version: ENCLAVE_PROTOCOL_VERSION,
            request_id: 9,
            body: EnclaveWireResponseBody::Health,
        };

        let send_task =
            tokio::spawn(async move { write_response(&mut near_end, &outgoing).await });
        let received = read_response(&mut far_end).await.unwrap();
        send_task.await.unwrap().unwrap();

        assert_eq!(received.request_id, 9);
        assert!(matches!(received.body, EnclaveWireResponseBody::Health));
    }

    /// A header announcing one byte more than `MAX_FRAME_LEN` is rejected
    /// without any payload having been sent.
    #[tokio::test]
    async fn frame_too_large_is_rejected() {
        let (mut near_end, mut far_end) = tokio::io::duplex(8);
        let send_task = tokio::spawn(async move {
            near_end
                .write_u32((MAX_FRAME_LEN + 1) as u32)
                .await
                .unwrap();
            near_end.flush().await.unwrap();
        });

        let failure = read_request(&mut far_end).await.unwrap_err();
        send_task.await.unwrap();

        assert!(matches!(
            failure,
            EnclaveError::InvalidRequest(message) if message.contains("frame too large")
        ));
    }

    /// Byte-vector fields of a `PartialDh` request are preserved on the wire.
    #[tokio::test]
    async fn partial_dh_request_with_bytes_round_trips() {
        let enc_points: Vec<Vec<u8>> = vec![vec![0xab; 32], vec![0xcd; 32]];
        let partial_dh = EnclavePartialDhRequest {
            task_id: alloy::primitives::FixedBytes::ZERO,
            enc_points: enc_points.clone(),
            peer_enclave_pubkeys: vec![([0x11; 32], [0x22; 32])],
        };
        let outgoing = EnclaveWireRequest {
            version: ENCLAVE_PROTOCOL_VERSION,
            request_id: 55,
            body: EnclaveWireRequestBody::PartialDh(partial_dh),
        };

        let received = round_trip_request(outgoing, 4096).await;

        assert_eq!(received.request_id, 55);
        let EnclaveWireRequestBody::PartialDh(received_partial_dh) = received.body else {
            panic!("expected PartialDh body");
        };
        assert_eq!(received_partial_dh.enc_points, enc_points);
    }

    /// Optional byte fields of an `Init` request are preserved on the wire.
    #[tokio::test]
    async fn init_request_with_bytes_round_trips() {
        let init = EnclaveInitRequest {
            hpke_private_key: Some(vec![0x11; 32]),
            threshold_keystore: Some(b"{\"key\":\"data\"}".to_vec()),
            threshold_keystore_password: Some(b"hunter2".to_vec()),
            kms_seed_ciphertext: None,
        };
        let outgoing = EnclaveWireRequest {
            version: ENCLAVE_PROTOCOL_VERSION,
            request_id: 77,
            body: EnclaveWireRequestBody::Init(init.clone()),
        };

        let received = round_trip_request(outgoing, 4096).await;

        assert_eq!(received.request_id, 77);
        let EnclaveWireRequestBody::Init(received_init) = received.body else {
            panic!("expected Init body");
        };
        assert_eq!(received_init.hpke_private_key, init.hpke_private_key);
        assert_eq!(received_init.threshold_keystore, init.threshold_keystore);
        assert_eq!(
            received_init.threshold_keystore_password,
            init.threshold_keystore_password
        );
    }
}