use newton_core::crypto::{HpkePrivateKey, SecureEnvelope};
use tracing::debug;
use crate::{
error::EnclaveError,
protocol::{EnclaveEnvelope, EnclaveEvalRequest},
threshold,
};
/// Decrypted contents of an `EnclaveEvalRequest`, as returned by `decrypt_request`.
///
/// The three fields mirror the request's three envelope groups (ephemeral, identity,
/// confidential).
// NOTE(review): `Debug` is derived on a type holding decrypted values — confirm it is
// never formatted into logs.
#[derive(Debug)]
pub(crate) struct DecryptedRequest {
/// Decrypted ephemeral payload. On the direct HPKE path this is `None` when the request
/// carries no ephemeral envelopes, otherwise a JSON object of the form
/// `{ "inline": [..] }` with one entry per envelope. The threshold path delegates to
/// `threshold::decrypt_envelopes`; presumably the same shape — TODO confirm.
pub(crate) ephemeral: Option<serde_json::Value>,
/// Decrypted identity envelopes, in request order.
pub(crate) identity: Vec<DecryptedEnvelope>,
/// Decrypted confidential envelopes, in request order.
pub(crate) confidential: Vec<DecryptedEnvelope>,
}
/// A single decrypted domain-scoped envelope (identity or confidential).
#[derive(Debug)]
pub(crate) struct DecryptedEnvelope {
/// Optional 32-byte domain tag, copied unchanged from the source envelope's `domain`.
pub(crate) domain: Option<alloy::primitives::FixedBytes<32>>,
/// The envelope plaintext parsed as JSON.
pub(crate) value: serde_json::Value,
}
/// Decrypts all ephemeral, identity and confidential envelopes of `req`.
///
/// Two modes are supported, selected by `req.threshold`:
///
/// * **No threshold parameters** — every envelope is opened directly with `hpke_sk`.
/// * **Threshold parameters present** — the encrypted peer partials are decrypted with
///   `hpke_sk` / `own_operator_id`, regrouped per envelope type, and each envelope group
///   that has partials is decrypted through the `threshold` module. A group with no
///   partials falls back to direct HPKE decryption.
///
/// # Errors
///
/// * [`EnclaveError::MissingInput`] if `hpke_sk` is `None`, or if `own_operator_id` is
///   `None` while threshold parameters are present.
/// * Any error propagated from peer-partial decryption, partial regrouping, or envelope
///   decryption.
pub(crate) fn decrypt_request(
    req: &EnclaveEvalRequest,
    hpke_sk: Option<&HpkePrivateKey>,
    own_operator_id: Option<&crate::protocol::EnclaveOperatorId>,
) -> Result<DecryptedRequest, EnclaveError> {
    let params = match req.threshold.as_ref() {
        Some(params) => params,
        None => {
            // Direct path: no threshold material, so only the HPKE key is needed.
            debug!(
                ephemeral_count = req.ephemeral_envelopes.len(),
                identity_count = req.identity_envelopes.len(),
                confidential_count = req.confidential_envelopes.len(),
                "decrypt_request: direct HPKE decryption (no threshold)"
            );
            let sk = match hpke_sk {
                Some(sk) => sk,
                None => return Err(EnclaveError::MissingInput("hpke key".to_string())),
            };
            let ephemeral = decrypt_ephemeral(&req.ephemeral_envelopes, sk)?;
            let identity = decrypt_domain(&req.identity_envelopes, sk)?;
            let confidential = decrypt_domain(&req.confidential_envelopes, sk)?;
            return Ok(DecryptedRequest {
                ephemeral,
                identity,
                confidential,
            });
        }
    };

    // Threshold path: both the HPKE key and our operator id are mandatory.
    let sk = match hpke_sk {
        Some(sk) => sk,
        None => {
            return Err(EnclaveError::MissingInput(
                "hpke key for threshold decrypt".to_string(),
            ))
        }
    };
    let operator = match own_operator_id {
        Some(id) => id,
        None => {
            return Err(EnclaveError::MissingInput(
                "operator id for threshold decrypt".to_string(),
            ))
        }
    };

    let peer_partials = threshold::decrypt_all_peer_partials(
        &params.encrypted_peer_partials,
        sk,
        operator,
        &params.task_id,
    )?;
    let (eph_partials, id_partials, conf_partials) = split_and_transpose_partials(
        &peer_partials,
        params.ephemeral_count,
        params.identity_count,
        params.confidential_count,
    )?;

    debug!(
        ephemeral_count = req.ephemeral_envelopes.len(),
        identity_count = req.identity_envelopes.len(),
        confidential_count = req.confidential_envelopes.len(),
        peer_count = peer_partials.len(),
        "decrypt_request: threshold decryption from encrypted peer partials"
    );

    // For each group: no partials means the group is decrypted directly with HPKE.
    let ephemeral = if eph_partials.is_empty() {
        decrypt_ephemeral(&req.ephemeral_envelopes, sk)?
    } else {
        threshold::decrypt_envelopes(&req.ephemeral_envelopes, Some(&eph_partials), params)?
    };
    let identity = if id_partials.is_empty() {
        decrypt_domain(&req.identity_envelopes, sk)?
    } else {
        threshold::decrypt_domain_envelopes(&req.identity_envelopes, Some(&id_partials), params)?
    };
    let confidential = if conf_partials.is_empty() {
        decrypt_domain(&req.confidential_envelopes, sk)?
    } else {
        threshold::decrypt_domain_envelopes(
            &req.confidential_envelopes,
            Some(&conf_partials),
            params,
        )?
    };

    Ok(DecryptedRequest {
        ephemeral,
        identity,
        confidential,
    })
}
/// Partial decryptions regrouped by envelope type, in the order
/// `(ephemeral, identity, confidential)`.
///
/// As built by `split_and_transpose_partials`, each outer `Vec` is indexed by envelope
/// and each inner `Vec` holds that envelope's partials, one per peer in peer order.
type PartialsByType = (
Vec<Vec<newton_aggregator::PartialDecryptionData>>,
Vec<Vec<newton_aggregator::PartialDecryptionData>>,
Vec<Vec<newton_aggregator::PartialDecryptionData>>,
);
/// Regroups per-peer partial decryptions into per-envelope lists, split by envelope type.
///
/// Each entry of `all_peer_partials` is one peer's flat list of partials, laid out as
/// `eph_count` ephemeral partials, followed by `id_count` identity partials, followed by
/// `conf_count` confidential partials. The result is the transpose of that layout: for
/// each type, the outer index is the envelope and the inner list holds one partial per
/// peer, in peer order.
///
/// When all three counts are zero, three empty vectors are returned and the peer lists
/// are not inspected.
///
/// # Errors
///
/// Returns [`EnclaveError::ThresholdFailed`] if the three counts overflow `usize` when
/// summed, or if any peer's list length differs from the expected total.
fn split_and_transpose_partials(
    all_peer_partials: &[Vec<newton_aggregator::PartialDecryptionData>],
    eph_count: usize,
    id_count: usize,
    conf_count: usize,
) -> Result<PartialsByType, EnclaveError> {
    // Checked arithmetic: with wrapping addition a bogus set of counts could produce a
    // small `total`, pass the per-peer length check, and then panic on slicing or
    // allocation below.
    let total = eph_count
        .checked_add(id_count)
        .and_then(|n| n.checked_add(conf_count))
        .ok_or_else(|| {
            EnclaveError::ThresholdFailed(format!(
                "partial counts overflow: {eph_count} + {id_count} + {conf_count}"
            ))
        })?;
    if total == 0 {
        return Ok((vec![], vec![], vec![]));
    }

    for (peer_idx, peer_partials) in all_peer_partials.iter().enumerate() {
        if peer_partials.len() != total {
            return Err(EnclaveError::ThresholdFailed(format!(
                "peer {peer_idx} has {} partials, expected {total}",
                peer_partials.len()
            )));
        }
    }

    // Build each bucket individually: `vec![Vec::with_capacity(n); k]` clones the
    // prototype, and cloning an empty `Vec` does not carry its capacity over, so only
    // one bucket would actually be preallocated.
    let peer_count = all_peer_partials.len();
    let new_buckets = |count: usize| -> Vec<Vec<newton_aggregator::PartialDecryptionData>> {
        (0..count).map(|_| Vec::with_capacity(peer_count)).collect()
    };
    let mut eph = new_buckets(eph_count);
    let mut id = new_buckets(id_count);
    let mut conf = new_buckets(conf_count);

    for peer_partials in all_peer_partials {
        // Every list was verified to hold exactly `total` entries, so both splits are
        // in bounds and the three segments have exactly the expected lengths.
        let (eph_part, rest) = peer_partials.split_at(eph_count);
        let (id_part, conf_part) = rest.split_at(id_count);

        for (bucket, partial) in eph.iter_mut().zip(eph_part) {
            bucket.push(partial.clone());
        }
        for (bucket, partial) in id.iter_mut().zip(id_part) {
            bucket.push(partial.clone());
        }
        for (bucket, partial) in conf.iter_mut().zip(conf_part) {
            bucket.push(partial.clone());
        }
    }

    Ok((eph, id, conf))
}
/// Opens every ephemeral envelope with `hpke_sk` and bundles the results into one JSON value.
///
/// Returns `Ok(None)` when `envelopes` is empty. Otherwise returns a JSON object of the
/// form `{ "inline": [v0, v1, ..] }`, where `vN` is the decrypted JSON value of the N-th
/// envelope.
///
/// # Errors
///
/// Returns [`EnclaveError::DecryptFailed`] for the first envelope that fails to decrypt;
/// the message is prefixed with that envelope's index.
fn decrypt_ephemeral(
    envelopes: &[SecureEnvelope],
    hpke_sk: &HpkePrivateKey,
) -> Result<Option<serde_json::Value>, EnclaveError> {
    if envelopes.is_empty() {
        return Ok(None);
    }

    let mut values = Vec::with_capacity(envelopes.len());
    for (i, envelope) in envelopes.iter().enumerate() {
        match decrypt_value(envelope, hpke_sk) {
            Ok(value) => values.push(value),
            Err(e) => {
                return Err(EnclaveError::DecryptFailed(format!("ephemeral {i}: {e}")));
            }
        }
    }

    Ok(Some(serde_json::json!({ "inline": values })))
}
/// Opens every domain-scoped envelope with `hpke_sk`, preserving input order.
///
/// Each output entry carries the source envelope's `domain` unchanged alongside the
/// decrypted JSON value. An empty input yields an empty vector.
///
/// # Errors
///
/// Returns [`EnclaveError::DecryptFailed`] for the first envelope that fails to decrypt;
/// the message is prefixed with that envelope's index.
fn decrypt_domain(
    envelopes: &[EnclaveEnvelope],
    hpke_sk: &HpkePrivateKey,
) -> Result<Vec<DecryptedEnvelope>, EnclaveError> {
    let mut decrypted = Vec::with_capacity(envelopes.len());
    for (i, item) in envelopes.iter().enumerate() {
        let value = match decrypt_value(&item.envelope, hpke_sk) {
            Ok(value) => value,
            Err(e) => {
                return Err(EnclaveError::DecryptFailed(format!(
                    "domain envelope {i}: {e}"
                )));
            }
        };
        decrypted.push(DecryptedEnvelope {
            domain: item.domain,
            value,
        });
    }
    Ok(decrypted)
}
/// Opens a single envelope with `hpke_sk` and parses the plaintext as JSON.
///
/// # Errors
///
/// Returns [`EnclaveError::DecryptFailed`] if the envelope cannot be opened (carrying the
/// underlying error's message) or if the plaintext is not valid JSON.
pub(crate) fn decrypt_value(
    envelope: &SecureEnvelope,
    hpke_sk: &HpkePrivateKey,
) -> Result<serde_json::Value, EnclaveError> {
    let plaintext = match envelope.open(hpke_sk) {
        Ok(bytes) => bytes,
        Err(e) => return Err(EnclaveError::DecryptFailed(e.to_string())),
    };

    match serde_json::from_slice(&plaintext) {
        Ok(value) => Ok(value),
        Err(e) => Err(EnclaveError::DecryptFailed(format!(
            "plaintext is not json: {e}"
        ))),
    }
}