use qos_nsm::NsmProvider;
use super::{
error::ProtocolError, msg::ProtocolMsg, services::provision::SecretBuilder,
};
use crate::handles::Handles;
/// Lifecycle phase of the protocol state machine.
///
/// The phase determines which request routes `ProtocolState` serves and which
/// phase changes `ProtocolState::transition` accepts.
///
/// Serde writes each variant under its PascalCase `rename` and, via `alias`,
/// also accepts the camelCase spelling when deserializing.
///
/// NOTE(review): borsh is believed to encode enum variants by declaration
/// index, so reordering or inserting variants would likely change the wire
/// format — confirm before touching the variant order.
#[derive(
Debug,
Copy,
PartialEq,
Eq,
Clone,
borsh::BorshSerialize,
borsh::BorshDeserialize,
serde::Serialize,
serde::Deserialize,
)]
pub enum ProtocolPhase {
/// Terminal phase: `transition` permits no move out of it. Only read-only
/// routes (status, version, manifest envelope, live attestation doc) are
/// served.
#[serde(rename = "UnrecoverableError", alias = "unrecoverableError")]
UnrecoverableError,
/// Default starting phase (unless a mock/test override is supplied). Serves
/// the genesis, standard and key-forward boot requests, plus status, version
/// and manifest envelope.
#[serde(
rename = "WaitingForBootInstruction",
alias = "waitingForBootInstruction"
)]
WaitingForBootInstruction,
/// Entered after a successful genesis boot. Only status and version are
/// served; the only allowed transition is to `UnrecoverableError`.
#[serde(rename = "GenesisBooted", alias = "genesisBooted")]
GenesisBooted,
/// Entered after a successful standard boot. Accepts provision requests
/// (quorum shares) and stays here until a provision response reports the
/// key as reconstructed.
#[serde(
rename = "WaitingForQuorumShards",
alias = "waitingForQuorumShards"
)]
WaitingForQuorumShards,
/// Entered once the quorum key is in place: either reconstructed from
/// provisioned shares or injected after a key-forward boot. Serves key
/// export; the only allowed transition is to `UnrecoverableError`.
#[serde(rename = "QuorumKeyProvisioned", alias = "quorumKeyProvisioned")]
QuorumKeyProvisioned,
/// Entered after a successful key-forward boot; waits for an inject-key
/// request to move to `QuorumKeyProvisioned`.
#[serde(
rename = "WaitingForForwardedKey",
alias = "waitingForForwardedKey"
)]
WaitingForForwardedKey,
}
/// Result of offering a request to one route.
///
/// `None` means the route does not serve that request. Otherwise the inner
/// `Result` carries the reply: `Ok` on success, `Err` with an error reply.
type ProtocolRouteResponse = Option<Result<ProtocolMsg, ProtocolMsg>>;

/// Signature every route handler implements.
type ProtocolRouteHandler =
    dyn Fn(&ProtocolMsg, &mut ProtocolState) -> ProtocolRouteResponse;

/// A request handler together with the phases to enter after it answers.
struct ProtocolRoute {
    /// Inspects a request and, if it is one it serves, produces the reply.
    handler: Box<ProtocolRouteHandler>,
    /// Phase to transition to when `handler` replies with `Ok`.
    ok_phase: ProtocolPhase,
    /// Phase to transition to when `handler` replies with `Err`.
    err_phase: ProtocolPhase,
}
impl ProtocolRoute {
    /// Offers `msg` to this route's handler and applies the resulting phase
    /// change to `state`.
    ///
    /// * `None` means the handler does not serve `msg`, and `state` is left
    ///   untouched.
    /// * `Some(Ok(ProvisionResponse { reconstructed: false }))` is returned
    ///   as-is with no phase change, because the key is not reconstructed
    ///   yet.
    /// * Any other `Some(Ok(_))` or `Some(Err(_))` moves `state` to
    ///   `ok_phase` or `err_phase` respectively. If
    ///   `ProtocolState::transition` rejects that move, the reply is replaced
    ///   by a `ProtocolErrorResponse` carrying the transition error. By then
    ///   the state has already been forced into `UnrecoverableError`.
    pub fn try_msg(
        &self,
        msg: &ProtocolMsg,
        state: &mut ProtocolState,
    ) -> ProtocolRouteResponse {
        // Not a request this route serves: propagate `None`, no transition.
        let outcome = (self.handler)(msg, state)?;

        // A partial provision must not advance the phase yet.
        let key_still_pending = matches!(
            outcome,
            Ok(ProtocolMsg::ProvisionResponse { reconstructed: false })
        );
        if key_still_pending {
            return Some(outcome);
        }

        let next_phase =
            if outcome.is_ok() { self.ok_phase } else { self.err_phase };
        match state.transition(next_phase) {
            Ok(()) => Some(outcome),
            Err(e) => Some(Err(ProtocolMsg::ProtocolErrorResponse(e))),
        }
    }

    /// Route for `StatusRequest`; never changes the phase.
    pub fn status(current_phase: ProtocolPhase) -> Self {
        Self::phase_preserving(Box::new(handlers::status), current_phase)
    }

    /// Route for `VersionRequest`; never changes the phase.
    pub fn version(current_phase: ProtocolPhase) -> Self {
        Self::phase_preserving(Box::new(handlers::version), current_phase)
    }

    /// Route for `ManifestEnvelopeRequest`; never changes the phase.
    pub fn manifest_envelope(current_phase: ProtocolPhase) -> Self {
        Self::phase_preserving(
            Box::new(handlers::manifest_envelope),
            current_phase,
        )
    }

    /// Route for `LiveAttestationDocRequest`; never changes the phase.
    pub fn live_attestation_doc(current_phase: ProtocolPhase) -> Self {
        Self::phase_preserving(
            Box::new(handlers::live_attestation_doc),
            current_phase,
        )
    }

    /// Route for `BootGenesisRequest`. On success it moves to
    /// `GenesisBooted`; on failure it moves to `UnrecoverableError`.
    pub fn boot_genesis(_current_phase: ProtocolPhase) -> Self {
        Self::phase_advancing(
            Box::new(handlers::boot_genesis),
            ProtocolPhase::GenesisBooted,
        )
    }

    /// Route for both standard-boot request forms. On success it moves to
    /// `WaitingForQuorumShards`; on failure it moves to `UnrecoverableError`.
    pub fn boot_standard(_current_phase: ProtocolPhase) -> Self {
        Self::phase_advancing(
            Box::new(handlers::boot_standard),
            ProtocolPhase::WaitingForQuorumShards,
        )
    }

    /// Route for `BootKeyForwardRequest`. On success it moves to
    /// `WaitingForForwardedKey`; on failure it moves to `UnrecoverableError`.
    pub fn boot_key_forward(_current_phase: ProtocolPhase) -> Self {
        Self::phase_advancing(
            Box::new(handlers::boot_key_forward),
            ProtocolPhase::WaitingForForwardedKey,
        )
    }

    /// Route for `ProvisionRequest`. A reconstructing success moves to
    /// `QuorumKeyProvisioned`; a failure moves to `UnrecoverableError`.
    pub fn provision(_current_phase: ProtocolPhase) -> Self {
        Self::phase_advancing(
            Box::new(handlers::provision),
            ProtocolPhase::QuorumKeyProvisioned,
        )
    }

    /// Route for `ExportKeyRequest`. The phase is unchanged whether the
    /// export succeeds or fails.
    pub fn export_key(current_phase: ProtocolPhase) -> Self {
        Self::phase_preserving(Box::new(handlers::export_key), current_phase)
    }

    /// Route for `InjectKeyRequest`. On success it moves to
    /// `QuorumKeyProvisioned`; on failure it moves to `UnrecoverableError`.
    pub fn inject_key(_current_phase: ProtocolPhase) -> Self {
        Self::phase_advancing(
            Box::new(handlers::inject_key),
            ProtocolPhase::QuorumKeyProvisioned,
        )
    }

    /// Builds a route that stays in `current_phase` whatever the outcome.
    fn phase_preserving(
        handler: Box<ProtocolRouteHandler>,
        current_phase: ProtocolPhase,
    ) -> Self {
        Self::new(handler, current_phase, current_phase)
    }

    /// Builds a route that enters `ok_phase` on success and
    /// `UnrecoverableError` on failure.
    fn phase_advancing(
        handler: Box<ProtocolRouteHandler>,
        ok_phase: ProtocolPhase,
    ) -> Self {
        Self::new(handler, ok_phase, ProtocolPhase::UnrecoverableError)
    }

    /// Assembles a route from its handler and its two target phases.
    fn new(
        handler: Box<ProtocolRouteHandler>,
        ok_phase: ProtocolPhase,
        err_phase: ProtocolPhase,
    ) -> Self {
        Self { handler, ok_phase, err_phase }
    }
}
/// State shared by all protocol routes: the current phase plus the services
/// the route handlers operate on.
pub(crate) struct ProtocolState {
/// Created empty via `SecretBuilder::new()` in `ProtocolState::new`;
/// presumably accumulates the shares submitted through provision requests —
/// TODO confirm against `services::provision`.
pub provisioner: SecretBuilder,
/// NSM provider; presumably used by the attestation and boot services to
/// produce attestation documents — verify in `services`.
pub attestor: Box<dyn NsmProvider>,
/// Resource handles; the manifest envelope is read through
/// `handles.get_manifest_envelope()`.
pub handles: Handles,
/// Current phase. Private so that code outside this module must change it
/// through `ProtocolState::transition` (initial value is set in `new`).
phase: ProtocolPhase,
}
impl ProtocolState {
    /// Creates protocol state with a fresh `SecretBuilder`, starting in
    /// `ProtocolPhase::WaitingForBootInstruction`.
    ///
    /// `test_only_init_phase_override` is honored only in builds with the
    /// `mock` feature or under `cfg(test)`. All other builds ignore it and
    /// always start in `WaitingForBootInstruction`.
    pub fn new(
        attestor: Box<dyn NsmProvider>,
        handles: Handles,
        // Only read under the mock/test cfg below; unused otherwise.
        #[allow(unused)] test_only_init_phase_override: Option<ProtocolPhase>,
    ) -> Self {
        let provisioner = SecretBuilder::new();

        #[cfg(any(feature = "mock", test))]
        let init_phase = test_only_init_phase_override
            .unwrap_or(ProtocolPhase::WaitingForBootInstruction);
        #[cfg(not(any(feature = "mock", test)))]
        let init_phase = ProtocolPhase::WaitingForBootInstruction;

        Self { attestor, provisioner, phase: init_phase, handles }
    }

    /// Returns the current phase.
    pub fn get_phase(&self) -> ProtocolPhase {
        self.phase
    }

    /// Dispatches `msg_req` to the first route of the current phase that
    /// serves it, and returns that route's reply. Success and error replies
    /// are both returned as plain messages.
    ///
    /// If no route serves the request, returns
    /// `ProtocolErrorResponse(NoMatchingRoute(phase))`.
    pub fn handle_msg_response(
        &mut self,
        msg_req: &ProtocolMsg,
    ) -> ProtocolMsg {
        // The route list is built once for the phase at entry. The first
        // route that answers also ends the search, so a phase change it makes
        // cannot affect which routes were tried.
        for route in &self.routes() {
            if let Some(Ok(reply) | Err(reply)) = route.try_msg(msg_req, self)
            {
                return reply;
            }
        }
        ProtocolMsg::ProtocolErrorResponse(ProtocolError::NoMatchingRoute(
            self.phase,
        ))
    }

    /// Routes served in the current phase, in the order they are tried.
    #[allow(clippy::too_many_lines)]
    fn routes(&self) -> Vec<ProtocolRoute> {
        // Several phases deliberately share identical route lists.
        #[allow(clippy::match_same_arms)]
        match self.phase {
            ProtocolPhase::UnrecoverableError => {
                vec![
                    ProtocolRoute::status(self.phase),
                    ProtocolRoute::version(self.phase),
                    ProtocolRoute::manifest_envelope(self.phase),
                    ProtocolRoute::live_attestation_doc(self.phase),
                ]
            }
            ProtocolPhase::GenesisBooted => {
                vec![
                    ProtocolRoute::status(self.phase),
                    ProtocolRoute::version(self.phase),
                ]
            }
            ProtocolPhase::WaitingForBootInstruction => vec![
                ProtocolRoute::status(self.phase),
                ProtocolRoute::version(self.phase),
                ProtocolRoute::manifest_envelope(self.phase),
                ProtocolRoute::boot_genesis(self.phase),
                ProtocolRoute::boot_standard(self.phase),
                ProtocolRoute::boot_key_forward(self.phase),
            ],
            ProtocolPhase::WaitingForQuorumShards => {
                vec![
                    ProtocolRoute::status(self.phase),
                    ProtocolRoute::version(self.phase),
                    ProtocolRoute::live_attestation_doc(self.phase),
                    ProtocolRoute::manifest_envelope(self.phase),
                    ProtocolRoute::provision(self.phase),
                ]
            }
            ProtocolPhase::QuorumKeyProvisioned => {
                vec![
                    ProtocolRoute::status(self.phase),
                    ProtocolRoute::version(self.phase),
                    ProtocolRoute::live_attestation_doc(self.phase),
                    ProtocolRoute::manifest_envelope(self.phase),
                    ProtocolRoute::export_key(self.phase),
                ]
            }
            ProtocolPhase::WaitingForForwardedKey => {
                vec![
                    ProtocolRoute::status(self.phase),
                    ProtocolRoute::version(self.phase),
                    ProtocolRoute::live_attestation_doc(self.phase),
                    ProtocolRoute::manifest_envelope(self.phase),
                    ProtocolRoute::inject_key(self.phase),
                ]
            }
        }
    }

    /// Moves the state machine to `next` if the current phase allows it.
    ///
    /// Re-entering the current phase is always accepted as a no-op.
    ///
    /// # Errors
    ///
    /// Returns `ProtocolError::InvalidStateTransition(prev, next)` when the
    /// move is not allowed. In that case the state is first forced into
    /// `ProtocolPhase::UnrecoverableError`.
    pub fn transition(
        &mut self,
        next: ProtocolPhase,
    ) -> Result<(), ProtocolError> {
        if self.phase == next {
            return Ok(());
        }

        // Static table of legal successors. A borrowed (promoted `'static`)
        // slice avoids heap-allocating a fresh `Vec` on every transition.
        #[allow(clippy::match_same_arms)]
        let allowed: &[ProtocolPhase] = match self.phase {
            ProtocolPhase::UnrecoverableError => &[],
            ProtocolPhase::WaitingForBootInstruction => &[
                ProtocolPhase::UnrecoverableError,
                ProtocolPhase::GenesisBooted,
                ProtocolPhase::WaitingForQuorumShards,
                ProtocolPhase::WaitingForForwardedKey,
            ],
            ProtocolPhase::GenesisBooted => {
                &[ProtocolPhase::UnrecoverableError]
            }
            ProtocolPhase::WaitingForQuorumShards => &[
                ProtocolPhase::UnrecoverableError,
                ProtocolPhase::QuorumKeyProvisioned,
            ],
            ProtocolPhase::QuorumKeyProvisioned => {
                &[ProtocolPhase::UnrecoverableError]
            }
            ProtocolPhase::WaitingForForwardedKey => &[
                ProtocolPhase::UnrecoverableError,
                ProtocolPhase::QuorumKeyProvisioned,
            ],
        };

        if !allowed.contains(&next) {
            let prev = self.phase;
            // Any illegal move is treated as fatal for the state machine.
            self.phase = ProtocolPhase::UnrecoverableError;
            return Err(ProtocolError::InvalidStateTransition(prev, next));
        }
        self.phase = next;
        Ok(())
    }
}
mod handlers {
    //! Per-request handlers used by `ProtocolRoute`.
    //!
    //! Every handler returns `None` when `req` is not a request kind it
    //! serves. Otherwise it returns `Some(Ok(reply))`, or
    //! `Some(Err(ProtocolErrorResponse(..)))` when the underlying service
    //! call fails.

    use super::ProtocolRouteResponse;
    use crate::protocol::{
        ProtocolState,
        msg::ProtocolMsg,
        services::{
            attestation, boot, genesis, key, key::EncryptedQuorumKey, provision,
        },
    };

    /// Answers `StatusRequest` with the current phase.
    pub(super) fn status(
        req: &ProtocolMsg,
        state: &mut ProtocolState,
    ) -> ProtocolRouteResponse {
        if !matches!(req, ProtocolMsg::StatusRequest) {
            return None;
        }
        Some(Ok(ProtocolMsg::StatusResponse(state.get_phase())))
    }

    /// Answers `VersionRequest` with the package version and commit id.
    pub(super) fn version(
        req: &ProtocolMsg,
        _state: &mut ProtocolState,
    ) -> ProtocolRouteResponse {
        if !matches!(req, ProtocolMsg::VersionRequest) {
            return None;
        }
        // Both values are captured by `env!` when this crate is compiled.
        let version = String::from(env!("CARGO_PKG_VERSION"));
        let commit = String::from(env!("QOS_GIT_COMMIT"));
        Some(Ok(ProtocolMsg::VersionResponse { version, commit }))
    }

    /// Answers `ManifestEnvelopeRequest` with the stored manifest envelope.
    /// The response carries `None` if the envelope cannot be read.
    pub(super) fn manifest_envelope(
        req: &ProtocolMsg,
        state: &mut ProtocolState,
    ) -> ProtocolRouteResponse {
        if !matches!(req, ProtocolMsg::ManifestEnvelopeRequest) {
            return None;
        }
        let stored = state.handles.get_manifest_envelope().ok();
        Some(Ok(ProtocolMsg::ManifestEnvelopeResponse {
            manifest_envelope: Box::new(stored),
        }))
    }

    /// Feeds one provisioned share to the provision service and reports
    /// whether the secret has been reconstructed.
    pub(super) fn provision(
        req: &ProtocolMsg,
        state: &mut ProtocolState,
    ) -> ProtocolRouteResponse {
        let ProtocolMsg::ProvisionRequest { share, approval } = req else {
            return None;
        };
        let reply = match provision::provision(share, approval.clone(), state)
        {
            Ok(reconstructed) => {
                Ok(ProtocolMsg::ProvisionResponse { reconstructed })
            }
            Err(e) => Err(ProtocolMsg::ProtocolErrorResponse(e)),
        };
        Some(reply)
    }

    /// Performs a standard boot. It accepts either the plain request or the
    /// JSON-envelope variant, which is unwrapped into the same manifest
    /// envelope.
    pub(super) fn boot_standard(
        req: &ProtocolMsg,
        state: &mut ProtocolState,
    ) -> ProtocolRouteResponse {
        let (envelope, pivot) = match req {
            ProtocolMsg::BootStandardRequest { manifest_envelope, pivot } => {
                ((**manifest_envelope).clone(), pivot)
            }
            ProtocolMsg::BootStandardJsonEnvelopeRequest {
                manifest_envelope,
                pivot,
            } => (manifest_envelope.as_ref().clone().into_inner(), pivot),
            _ => return None,
        };
        let reply = match boot::boot_standard(state, envelope, pivot) {
            Ok(nsm_response) => {
                Ok(ProtocolMsg::BootStandardResponse { nsm_response })
            }
            Err(e) => Err(ProtocolMsg::ProtocolErrorResponse(e)),
        };
        Some(reply)
    }

    /// Performs a genesis boot and returns its output together with the NSM
    /// response.
    pub(super) fn boot_genesis(
        req: &ProtocolMsg,
        state: &mut ProtocolState,
    ) -> ProtocolRouteResponse {
        let ProtocolMsg::BootGenesisRequest { set, dr_key } = req else {
            return None;
        };
        let reply = match genesis::boot_genesis(state, set, dr_key.clone()) {
            Ok((genesis_output, nsm_response)) => {
                Ok(ProtocolMsg::BootGenesisResponse {
                    nsm_response,
                    genesis_output: Box::new(genesis_output),
                })
            }
            Err(e) => Err(ProtocolMsg::ProtocolErrorResponse(e)),
        };
        Some(reply)
    }

    /// Produces a live attestation document. On success the response also
    /// includes the stored manifest envelope, if it can be read.
    pub(super) fn live_attestation_doc(
        req: &ProtocolMsg,
        state: &mut ProtocolState,
    ) -> ProtocolRouteResponse {
        if !matches!(req, ProtocolMsg::LiveAttestationDocRequest) {
            return None;
        }
        // The manifest envelope is only read after attestation succeeds.
        let reply = match attestation::live_attestation_doc(state) {
            Ok(nsm_response) => Ok(ProtocolMsg::LiveAttestationDocResponse {
                nsm_response,
                manifest_envelope: state
                    .handles
                    .get_manifest_envelope()
                    .ok()
                    .map(Box::new),
            }),
            Err(e) => Err(ProtocolMsg::ProtocolErrorResponse(e)),
        };
        Some(reply)
    }

    /// Performs a key-forward boot.
    pub(super) fn boot_key_forward(
        req: &ProtocolMsg,
        state: &mut ProtocolState,
    ) -> ProtocolRouteResponse {
        let ProtocolMsg::BootKeyForwardRequest { manifest_envelope, pivot } =
            req
        else {
            return None;
        };
        let reply =
            match key::boot_key_forward(state, manifest_envelope, pivot) {
                Ok(nsm_response) => {
                    Ok(ProtocolMsg::BootKeyForwardResponse { nsm_response })
                }
                Err(e) => Err(ProtocolMsg::ProtocolErrorResponse(e)),
            };
        Some(reply)
    }

    /// Exports the quorum key, encrypted and signed, for the requester
    /// described by the given manifest envelope and attestation document.
    pub(super) fn export_key(
        req: &ProtocolMsg,
        state: &mut ProtocolState,
    ) -> ProtocolRouteResponse {
        let ProtocolMsg::ExportKeyRequest {
            manifest_envelope,
            cose_sign1_attestation_doc,
        } = req
        else {
            return None;
        };
        let reply = match key::export_key(
            state,
            manifest_envelope,
            cose_sign1_attestation_doc,
        ) {
            Ok(EncryptedQuorumKey { encrypted_quorum_key, signature }) => {
                Ok(ProtocolMsg::ExportKeyResponse {
                    encrypted_quorum_key,
                    signature,
                })
            }
            Err(e) => Err(ProtocolMsg::ProtocolErrorResponse(e)),
        };
        Some(reply)
    }

    /// Injects a forwarded (encrypted, signed) quorum key.
    pub(super) fn inject_key(
        req: &ProtocolMsg,
        state: &mut ProtocolState,
    ) -> ProtocolRouteResponse {
        let ProtocolMsg::InjectKeyRequest { encrypted_quorum_key, signature } =
            req
        else {
            return None;
        };
        let forwarded = EncryptedQuorumKey {
            encrypted_quorum_key: encrypted_quorum_key.clone(),
            signature: signature.clone(),
        };
        let reply = match key::inject_key(state, forwarded) {
            Ok(()) => Ok(ProtocolMsg::InjectKeyResponse),
            Err(e) => Err(ProtocolMsg::ProtocolErrorResponse(e)),
        };
        Some(reply)
    }
}