use std::ops::{Deref, DerefMut};
use borsh::{BorshDeserialize, BorshSerialize};
use qos_nsm::types::NsmResponse;
use serde::{Serialize, de::DeserializeOwned};
use crate::protocol::{
ProtocolError,
services::{
boot::{Approval, VersionedManifestEnvelope},
genesis::{GenesisOutput, GenesisSet},
},
};
/// Newtype wrapper around a `T` that is transported as JSON text.
///
/// The wrapper dereferences to the inner value, so it can be used almost
/// anywhere a `&T` / `&mut T` is expected; call [`JsonBytes::into_inner`] to
/// take ownership of the value again.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct JsonBytes<T>(T);

impl<T> JsonBytes<T> {
    /// Wraps `value` without transforming it.
    #[must_use]
    pub fn new(value: T) -> Self {
        JsonBytes(value)
    }

    /// Consumes the wrapper and hands back the wrapped value.
    #[must_use]
    pub fn into_inner(self) -> T {
        let JsonBytes(inner) = self;
        inner
    }
}

impl<T> Deref for JsonBytes<T> {
    type Target = T;

    /// Borrows the wrapped value.
    fn deref(&self) -> &T {
        let JsonBytes(inner) = self;
        inner
    }
}

impl<T> DerefMut for JsonBytes<T> {
    /// Mutably borrows the wrapped value.
    fn deref_mut(&mut self) -> &mut T {
        let JsonBytes(inner) = self;
        inner
    }
}
impl<T> BorshSerialize for JsonBytes<T>
where
T: Serialize,
{
fn serialize<W: borsh::io::Write>(
&self,
writer: &mut W,
) -> borsh::io::Result<()> {
let bytes =
serde_json::to_vec(&self.0).map_err(borsh::io::Error::other)?;
BorshSerialize::serialize(&bytes, writer)
}
}
impl<T> BorshDeserialize for JsonBytes<T>
where
T: DeserializeOwned,
{
fn deserialize_reader<R: borsh::io::Read>(
reader: &mut R,
) -> borsh::io::Result<Self> {
let bytes = Vec::<u8>::deserialize_reader(reader)?;
let value =
serde_json::from_slice(&bytes).map_err(borsh::io::Error::other)?;
Ok(Self(value))
}
}
/// Wire format used to encode or decode a `ProtocolMsg`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProtocolMsgEncoding {
    /// serde JSON (the wire helpers go through `qos_json`).
    Json,
    /// Borsh binary encoding.
    Borsh,
}
/// A protocol message, encoded on the wire as either JSON or Borsh
/// (see `ProtocolMsgEncoding`).
///
/// JSON uses serde's default externally-tagged enum representation, with
/// variant names renamed to camelCase. Borsh tags each variant by its
/// position in this declaration, so the variant order is part of the wire
/// format: do not reorder variants or insert new ones in the middle (the
/// tests pin `BootStandardRequest` to leading byte 3).
///
/// Byte fields annotated with `qos_hex::serde` are routed through that
/// helper when (de)serialized as JSON (presumably as hex strings — confirm in
/// `qos_hex`); Borsh ignores serde attributes and writes raw bytes.
#[derive(
    Debug,
    PartialEq,
    borsh::BorshSerialize,
    borsh::BorshDeserialize,
    serde::Serialize,
    serde::Deserialize,
)]
#[serde(rename_all = "camelCase")]
pub enum ProtocolMsg {
    /// Carries a [`ProtocolError`].
    ProtocolErrorResponse(ProtocolError),
    /// Unit request; `StatusResponse` is its named counterpart.
    StatusRequest,
    /// Carries a `ProtocolPhase`.
    StatusResponse(super::ProtocolPhase),
    /// Manifest envelope (any version) plus raw `pivot` bytes.
    ///
    /// Per this module's tests, a V2 envelope in this variant can be sent as
    /// JSON but fails to encode as Borsh.
    BootStandardRequest {
        manifest_envelope: Box<VersionedManifestEnvelope>,
        #[serde(with = "qos_hex::serde")]
        pivot: Vec<u8>,
    },
    /// Carries an `NsmResponse`.
    BootStandardResponse {
        nsm_response: NsmResponse,
    },
    /// Genesis set plus an optional `dr_key`.
    BootGenesisRequest {
        set: GenesisSet,
        // Omitted from JSON when `None`; a missing key decodes as `None`.
        #[serde(
            default,
            skip_serializing_if = "Option::is_none",
            with = "qos_hex::serde::option"
        )]
        dr_key: Option<Vec<u8>>,
    },
    /// Carries an `NsmResponse` and the boxed `GenesisOutput`.
    BootGenesisResponse {
        nsm_response: NsmResponse,
        genesis_output: Box<GenesisOutput>,
    },
    /// Raw `share` bytes together with an `Approval`.
    ProvisionRequest {
        #[serde(with = "qos_hex::serde")]
        share: Vec<u8>,
        approval: Approval,
    },
    /// Reports whether `reconstructed` is true.
    ProvisionResponse {
        reconstructed: bool,
    },
    /// Opaque `data` payload.
    ProxyRequest {
        #[serde(with = "qos_hex::serde")]
        data: Vec<u8>,
    },
    /// Opaque `data` payload.
    ProxyResponse {
        #[serde(with = "qos_hex::serde")]
        data: Vec<u8>,
    },
    /// Unit request; `LiveAttestationDocResponse` is its named counterpart.
    LiveAttestationDocRequest,
    /// Carries an `NsmResponse` and, optionally, a manifest envelope.
    LiveAttestationDocResponse {
        nsm_response: NsmResponse,
        // Omitted from JSON when `None`; a missing key decodes as `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        manifest_envelope: Option<Box<VersionedManifestEnvelope>>,
    },
    /// Same payload shape as `BootStandardRequest`.
    BootKeyForwardRequest {
        manifest_envelope: Box<VersionedManifestEnvelope>,
        #[serde(with = "qos_hex::serde")]
        pivot: Vec<u8>,
    },
    /// Carries an `NsmResponse`.
    BootKeyForwardResponse {
        nsm_response: NsmResponse,
    },
    /// Manifest envelope plus raw COSE Sign1 attestation document bytes.
    ExportKeyRequest {
        manifest_envelope: Box<VersionedManifestEnvelope>,
        #[serde(with = "qos_hex::serde")]
        cose_sign1_attestation_doc: Vec<u8>,
    },
    /// Encrypted quorum key bytes and a signature.
    ExportKeyResponse {
        #[serde(with = "qos_hex::serde")]
        encrypted_quorum_key: Vec<u8>,
        #[serde(with = "qos_hex::serde")]
        signature: Vec<u8>,
    },
    /// Same payload shape as `ExportKeyResponse`.
    InjectKeyRequest {
        #[serde(with = "qos_hex::serde")]
        encrypted_quorum_key: Vec<u8>,
        #[serde(with = "qos_hex::serde")]
        signature: Vec<u8>,
    },
    /// Unit response.
    InjectKeyResponse,
    /// Unit request; `ManifestEnvelopeResponse` is its named counterpart.
    ManifestEnvelopeRequest,
    /// Optionally carries a manifest envelope.
    ManifestEnvelopeResponse {
        // Note the `Box<Option<..>>` nesting (not `Option<Box<..>>`): the
        // field is omitted from JSON when the inner option is `None`, and
        // `default` restores a boxed `None` when the key is absent.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        manifest_envelope: Box<Option<VersionedManifestEnvelope>>,
    },
    /// Unit request; `VersionResponse` is its named counterpart.
    VersionRequest,
    /// Free-form `version` and `commit` strings.
    VersionResponse {
        version: String,
        commit: String,
    },
    /// Borsh-only form of `BootStandardRequest` whose manifest envelope is
    /// embedded as JSON bytes inside the Borsh frame (via `JsonBytes`).
    ///
    /// `#[serde(skip)]` keeps it out of the JSON wire format, so JSON-encoding
    /// it fails. Being the last variant, it does not shift the Borsh tags of
    /// the variants above. NOTE(review): presumably this exists so V2
    /// envelopes, which cannot be Borsh-encoded directly, can still travel
    /// over the Borsh wire — confirm.
    #[serde(skip)]
    BootStandardJsonEnvelopeRequest {
        manifest_envelope: Box<JsonBytes<VersionedManifestEnvelope>>,
        pivot: Vec<u8>,
    },
}
impl ProtocolMsg {
    /// Decodes a message from wire bytes and reports which encoding matched.
    ///
    /// JSON is tried first; only if it does not parse are the bytes decoded
    /// as Borsh, which must consume the entire input.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::ProtocolMsgDeserialization`] when `bytes`
    /// decode as neither JSON nor Borsh.
    pub fn from_wire(
        bytes: &[u8],
    ) -> Result<(Self, ProtocolMsgEncoding), ProtocolError> {
        let json_attempt: Result<Self, _> = qos_json::from_slice(bytes);
        match json_attempt {
            Ok(msg) => Ok((msg, ProtocolMsgEncoding::Json)),
            Err(_) => {
                let msg =
                    <Self as borsh::BorshDeserialize>::try_from_slice(bytes)
                        .map_err(|_| ProtocolError::ProtocolMsgDeserialization)?;
                Ok((msg, ProtocolMsgEncoding::Borsh))
            }
        }
    }

    /// Like [`ProtocolMsg::from_wire`], but discards the detected encoding.
    ///
    /// # Errors
    ///
    /// Same as [`ProtocolMsg::from_wire`].
    pub fn from_wire_any(bytes: &[u8]) -> Result<Self, ProtocolError> {
        let (msg, _encoding) = Self::from_wire(bytes)?;
        Ok(msg)
    }

    /// Encodes this message with the requested wire `encoding`.
    ///
    /// # Errors
    ///
    /// Returns [`ProtocolError::InvalidMsg`] if the message cannot be encoded
    /// in that format (for example, `BootStandardJsonEnvelopeRequest` is
    /// excluded from JSON by `#[serde(skip)]`).
    pub fn to_wire(
        &self,
        encoding: ProtocolMsgEncoding,
    ) -> Result<Vec<u8>, ProtocolError> {
        let encoded = match encoding {
            ProtocolMsgEncoding::Json => qos_json::to_vec(self).ok(),
            ProtocolMsgEncoding::Borsh => borsh::to_vec(self).ok(),
        };
        encoded.ok_or(ProtocolError::InvalidMsg)
    }

    /// Shorthand for `to_wire(ProtocolMsgEncoding::Json)`.
    ///
    /// # Errors
    ///
    /// Same as [`ProtocolMsg::to_wire`].
    pub fn to_json_wire(&self) -> Result<Vec<u8>, ProtocolError> {
        self.to_wire(ProtocolMsgEncoding::Json)
    }

    /// Shorthand for `to_wire(ProtocolMsgEncoding::Borsh)`.
    ///
    /// # Errors
    ///
    /// Same as [`ProtocolMsg::to_wire`].
    pub fn to_borsh_wire(&self) -> Result<Vec<u8>, ProtocolError> {
        self.to_wire(ProtocolMsgEncoding::Borsh)
    }
}
impl std::fmt::Display for ProtocolMsg {
    // Prints a short label for the message. Only `ProvisionResponse`,
    // `BootKeyForwardResponse` and `VersionResponse` include field data;
    // every other variant prints just its name, so payloads (keys, shares,
    // proxied data) never leak into logs through `Display`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::ProvisionResponse { reconstructed } => {
                return write!(
                    f,
                    "ProvisionResponse{{ reconstructed: {reconstructed} }}"
                );
            }
            Self::BootKeyForwardResponse { nsm_response } => {
                return match nsm_response {
                    NsmResponse::Attestation { .. } => write!(
                        f,
                        "BootKeyForwardResponse {{ nsm_response: Attestation }}"
                    ),
                    NsmResponse::Error(ecode) => write!(
                        f,
                        "BootKeyForwardResponse {{ nsm_response: Error({ecode:?}) }}"
                    ),
                    _ => write!(
                        f,
                        "BootKeyForwardResponse {{ nsm_response: Other }}"
                    ),
                };
            }
            Self::VersionResponse { version, commit } => {
                return write!(
                    f,
                    "VersionResponse{{ version: {version}, commit: {commit} }}"
                );
            }
            Self::ProtocolErrorResponse(_) => "ProtocolErrorResponse",
            Self::StatusRequest => "StatusRequest",
            Self::StatusResponse(_) => "StatusResponse",
            Self::BootStandardRequest { .. } => "BootStandardRequest",
            Self::BootStandardJsonEnvelopeRequest { .. } => {
                "BootStandardJsonEnvelopeRequest"
            }
            Self::BootStandardResponse { .. } => "BootStandardResponse",
            Self::BootGenesisRequest { .. } => "BootGenesisRequest",
            Self::BootGenesisResponse { .. } => "BootGenesisResponse",
            Self::ProvisionRequest { .. } => "ProvisionRequest",
            Self::ProxyRequest { .. } => "ProxyRequest",
            Self::ProxyResponse { .. } => "ProxyResponse",
            Self::LiveAttestationDocRequest => "LiveAttestationDocRequest",
            Self::LiveAttestationDocResponse { .. } => {
                "LiveAttestationDocResponse"
            }
            Self::BootKeyForwardRequest { .. } => "BootKeyForwardRequest",
            Self::ExportKeyRequest { .. } => "ExportKeyRequest",
            Self::ExportKeyResponse { .. } => "ExportKeyResponse",
            Self::InjectKeyRequest { .. } => "InjectKeyRequest",
            Self::InjectKeyResponse => "InjectKeyResponse",
            Self::ManifestEnvelopeRequest => "ManifestEnvelopeRequest",
            Self::ManifestEnvelopeResponse { .. } => "ManifestEnvelopeResponse",
            Self::VersionRequest => "VersionRequest",
        };
        f.write_str(label)
    }
}
#[cfg(test)]
mod test {
    use borsh::BorshDeserialize;
    use std::collections::BTreeSet;

    use super::*;
    use crate::protocol::services::boot::{
        Manifest, ManifestEnvelope, ManifestEnvelopeV2, ManifestSet,
        ManifestV2, ManifestVersion, Namespace, NitroConfig, PatchSet,
        PivotConfig, PivotConfigV2, PivotEnv, RestartPolicy, ShareSet,
    };

    /// Namespace shared by every manifest fixture in this module.
    fn sample_namespace() -> Namespace {
        Namespace {
            name: "test".to_string(),
            nonce: 1,
            quorum_key: vec![7; 33],
        }
    }

    /// Nitro config whose PCRs are each filled with their own index.
    fn sample_nitro_config() -> NitroConfig {
        NitroConfig {
            pcr0: vec![0; 48],
            pcr1: vec![1; 48],
            pcr2: vec![2; 48],
            pcr3: vec![3; 48],
            aws_root_certificate: vec![],
            qos_commit: "commit".to_string(),
        }
    }

    /// Small genesis output with no member outputs or permutations.
    fn sample_genesis_output() -> Box<GenesisOutput> {
        Box::new(GenesisOutput {
            quorum_key: vec![3, 2, 1],
            member_outputs: vec![],
            recovery_permutations: vec![],
            threshold: 2,
            dr_key_wrapped_quorum_key: None,
            quorum_key_hash: [22; 64],
            test_message_ciphertext: vec![],
            test_message_signature: vec![],
            test_message: vec![],
        })
    }

    /// V2 manifest envelope with threshold-1, memberless sets and no
    /// approvals.
    fn sample_v2_envelope() -> VersionedManifestEnvelope {
        let manifest = ManifestV2 {
            version: ManifestVersion::V2,
            namespace: sample_namespace(),
            pivot: PivotConfigV2 {
                hash: [9; 32],
                restart: RestartPolicy::Never,
                bridge_config: vec![],
                debug_mode: false,
                args: vec![],
                env: PivotEnv::new(),
            },
            manifest_set: ManifestSet { threshold: 1, members: vec![] },
            share_set: ShareSet { threshold: 1, members: vec![] },
            enclave: sample_nitro_config(),
        };
        VersionedManifestEnvelope::V2(ManifestEnvelopeV2 {
            manifest,
            manifest_set_approvals: vec![],
            share_set_approvals: vec![],
        })
    }

    /// Borsh-encodes `msg`, decodes it again, and checks nothing changed.
    fn assert_borsh_round_trip(msg: &ProtocolMsg) {
        let bytes = borsh::to_vec(msg).unwrap();
        let decoded = ProtocolMsg::try_from_slice(&bytes).unwrap();
        assert_eq!(&decoded, msg);
    }

    /// Encodes `msg` with `encoding`, decodes it through `from_wire`, and
    /// checks both the detected encoding and the decoded message.
    fn assert_wire_round_trip(
        msg: &ProtocolMsg,
        encoding: ProtocolMsgEncoding,
    ) {
        let wire = msg.to_wire(encoding).unwrap();
        let (decoded, detected) = ProtocolMsg::from_wire(&wire).unwrap();
        assert_eq!(detected, encoding);
        assert_eq!(&decoded, msg);
    }

    #[test]
    fn boot_genesis_response_deserialize() {
        let nsm_response = NsmResponse::LockPCR;
        let nsm_bytes = borsh::to_vec(&nsm_response).unwrap();
        let nsm_decoded = NsmResponse::try_from_slice(&nsm_bytes).unwrap();
        assert_eq!(nsm_response, nsm_decoded);

        assert_borsh_round_trip(&ProtocolMsg::BootGenesisResponse {
            nsm_response,
            genesis_output: sample_genesis_output(),
        });
    }

    #[test]
    fn version_response_round_trip() {
        assert_borsh_round_trip(&ProtocolMsg::VersionResponse {
            version: "0.5.0".to_string(),
            commit: "abc1234".to_string(),
        });
    }

    #[test]
    fn version_request_round_trip() {
        assert_borsh_round_trip(&ProtocolMsg::VersionRequest);
    }

    #[test]
    fn json_wire_round_trips_numeric_protocol_payloads() {
        let msg = ProtocolMsg::BootGenesisResponse {
            nsm_response: NsmResponse::DescribeNSM {
                version_major: 1,
                version_minor: 2,
                version_patch: 3,
                module_id: "module".to_string(),
                max_pcrs: 32,
                locked_pcrs: BTreeSet::from([0, 1, 2]),
                digest: qos_nsm::types::NsmDigest::SHA384,
            },
            genesis_output: sample_genesis_output(),
        };
        assert_wire_round_trip(&msg, ProtocolMsgEncoding::Json);
    }

    #[test]
    fn v2_manifest_envelope_is_json_wire_only() {
        let msg = ProtocolMsg::BootStandardRequest {
            manifest_envelope: Box::new(sample_v2_envelope()),
            pivot: vec![],
        };
        assert_wire_round_trip(&msg, ProtocolMsgEncoding::Json);
        assert!(msg.to_borsh_wire().is_err());
    }

    #[test]
    fn boot_standard_json_envelope_request_is_borsh_only() {
        let msg = ProtocolMsg::BootStandardJsonEnvelopeRequest {
            manifest_envelope: Box::new(JsonBytes::new(sample_v2_envelope())),
            pivot: vec![1, 2, 3, 4],
        };
        assert_wire_round_trip(&msg, ProtocolMsgEncoding::Borsh);
        assert!(msg.to_json_wire().is_err());
    }

    #[test]
    fn borsh_variant_discriminants_preserve_legacy_boot_standard_request() {
        let envelope = VersionedManifestEnvelope::V1(ManifestEnvelope {
            manifest: Manifest {
                namespace: sample_namespace(),
                pivot: PivotConfig {
                    hash: [9; 32],
                    restart: RestartPolicy::Never,
                    args: vec![],
                    bridge_config: vec![],
                    debug_mode: false,
                },
                enclave: sample_nitro_config(),
                manifest_set: ManifestSet { threshold: 1, members: vec![] },
                share_set: ShareSet { threshold: 1, members: vec![] },
                patch_set: PatchSet { threshold: 0, members: vec![] },
            },
            manifest_set_approvals: vec![],
            share_set_approvals: vec![],
        });
        let msg = ProtocolMsg::BootStandardRequest {
            manifest_envelope: Box::new(envelope),
            pivot: vec![1, 2, 3, 4],
        };
        // Borsh writes the variant tag first; BootStandardRequest is tag 3.
        let encoded = msg.to_borsh_wire().unwrap();
        assert_eq!(encoded.first().copied(), Some(3));
    }
}