use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use derivative::Derivative;
use hpke_dispatch::{HpkeError, Kem, Keypair};
use janus_messages::{
HpkeAeadId, HpkeCiphertext, HpkeConfig, HpkeConfigId, HpkeKdfId, HpkeKemId, HpkePublicKey, Role,
};
use serde::{
de::{self, Visitor},
Deserialize, Serialize, Serializer,
};
use std::{
fmt::{self, Debug},
str::FromStr,
};
/// Errors produced by the HPKE helpers in this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Failure reported by `hpke_dispatch` itself, e.g. while sealing or opening a message.
    #[error("HPKE error: {0}")]
    Hpke(#[from] HpkeError),
    /// An algorithm identifier in an HPKE configuration could not be mapped onto an
    /// `hpke_dispatch` algorithm; the payload names which one.
    #[error("invalid HPKE configuration: {0}")]
    InvalidConfiguration(&'static str),
    /// Key generation was requested for a KEM that cannot be used to generate keys.
    #[error("unsupported KEM")]
    UnsupportedKem,
}
/// Reports whether every algorithm named in `config` (AEAD, KDF and KEM) is supported.
///
/// # Errors
///
/// Returns [`Error::InvalidConfiguration`] describing the first unrecognized algorithm,
/// checked in the order AEAD, KDF, KEM.
pub fn is_hpke_config_supported(config: &HpkeConfig) -> Result<(), Error> {
    // Only the success/failure of the conversion matters; the converted value is discarded.
    hpke_dispatch_config_from_hpke_config(config).map(|_| ())
}
fn hpke_dispatch_config_from_hpke_config(
config: &HpkeConfig,
) -> Result<hpke_dispatch::Config, Error> {
Ok(hpke_dispatch::Config {
aead: u16::from(*config.aead_id())
.try_into()
.map_err(|_| Error::InvalidConfiguration("did not recognize aead"))?,
kdf: u16::from(*config.kdf_id())
.try_into()
.map_err(|_| Error::InvalidConfiguration("did not recognize kdf"))?,
kem: u16::from(*config.kem_id())
.try_into()
.map_err(|_| Error::InvalidConfiguration("did not recognize kem"))?,
})
}
/// Domain-separation labels mixed into the HPKE application info, one per kind of
/// encrypted message.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Label {
    /// Used when encrypting an input share.
    InputShare,
    /// Used when encrypting an aggregate share.
    AggregateShare,
}

impl Label {
    /// Returns the byte string for this label.
    ///
    /// The `dap-07` prefix ties these labels to a specific DAP draft revision.
    pub fn as_bytes(&self) -> &'static [u8] {
        let label: &'static [u8] = match *self {
            Label::InputShare => b"dap-07 input share",
            Label::AggregateShare => b"dap-07 aggregate share",
        };
        label
    }
}
/// Bytes passed as the HPKE `info` parameter when sealing or opening.
#[derive(Clone, Debug)]
pub struct HpkeApplicationInfo(Vec<u8>);

impl HpkeApplicationInfo {
    /// Builds application info as `label || sender_role || recipient_role`.
    ///
    /// Each role is encoded as a single byte obtained by casting it to `u8`.
    pub fn new(label: &Label, sender_role: &Role, recipient_role: &Role) -> Self {
        let label_bytes = label.as_bytes();
        // Label bytes plus one byte for each role.
        let mut info = Vec::with_capacity(label_bytes.len() + 2);
        info.extend_from_slice(label_bytes);
        info.push(*sender_role as u8);
        info.push(*recipient_role as u8);
        Self(info)
    }
}
/// Raw bytes of an HPKE private key.
///
/// The `Debug` output deliberately omits the key bytes.
// NOTE(review): `PartialEq` is derived from `Vec<u8>` and is therefore not constant-time;
// avoid comparing secret keys on attacker-observable paths.
#[derive(Clone, Derivative, PartialEq, Eq)]
#[derivative(Debug)]
pub struct HpkePrivateKey(#[derivative(Debug = "ignore")] Vec<u8>);

impl HpkePrivateKey {
    /// Wraps the given raw key bytes without validating them.
    pub fn new(bytes: Vec<u8>) -> Self {
        HpkePrivateKey(bytes)
    }
}

impl AsRef<[u8]> for HpkePrivateKey {
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}

impl From<Vec<u8>> for HpkePrivateKey {
    fn from(bytes: Vec<u8>) -> Self {
        HpkePrivateKey::new(bytes)
    }
}

/// Parses a private key from a hexadecimal string.
///
/// Note that this textual form (hex) differs from the serde form (unpadded base64url).
impl FromStr for HpkePrivateKey {
    type Err = hex::FromHexError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        hex::decode(s).map(HpkePrivateKey::new)
    }
}
/// Serializes the key bytes as an unpadded base64url string.
impl Serialize for HpkePrivateKey {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&URL_SAFE_NO_PAD.encode(self.as_ref()))
    }
}

/// Serde visitor that decodes an unpadded base64url string into an [`HpkePrivateKey`].
struct HpkePrivateKeyVisitor;

impl<'de> Visitor<'de> for HpkePrivateKeyVisitor {
    type Value = HpkePrivateKey;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a base64url-encoded string")
    }

    fn visit_str<E>(self, value: &str) -> Result<HpkePrivateKey, E>
    where
        E: de::Error,
    {
        // The underlying base64 error is replaced with a fixed message.
        match URL_SAFE_NO_PAD.decode(value) {
            Ok(bytes) => Ok(HpkePrivateKey::new(bytes)),
            Err(_) => Err(E::custom("invalid base64url value")),
        }
    }
}

/// Deserializes the key from an unpadded base64url string.
impl<'de> Deserialize<'de> for HpkePrivateKey {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        deserializer.deserialize_str(HpkePrivateKeyVisitor)
    }
}
/// Encrypts `plaintext` to the holder of `recipient_config`'s private key using HPKE base mode.
///
/// `application_info` is used as the HPKE `info` input and `associated_data` as the AEAD
/// associated data; [`open`] must be given the same values to decrypt.
///
/// # Errors
///
/// Returns [`Error::InvalidConfiguration`] if the config names an unrecognized algorithm,
/// or [`Error::Hpke`] if `hpke_dispatch` fails to seal.
pub fn seal(
    recipient_config: &HpkeConfig,
    application_info: &HpkeApplicationInfo,
    plaintext: &[u8],
    associated_data: &[u8],
) -> Result<HpkeCiphertext, Error> {
    let dispatch_config = hpke_dispatch_config_from_hpke_config(recipient_config)?;
    let sealed = dispatch_config.base_mode_seal(
        recipient_config.public_key().as_ref(),
        &application_info.0,
        plaintext,
        associated_data,
    )?;
    // The resulting ciphertext is tagged with the recipient config's ID.
    Ok(HpkeCiphertext::new(
        *recipient_config.id(),
        sealed.encapped_key,
        sealed.ciphertext,
    ))
}
/// Decrypts an HPKE base-mode `ciphertext` using `recipient_keypair`.
///
/// `application_info` and `associated_data` must equal the values passed to [`seal`].
/// The ciphertext's config ID is not consulted; the caller chooses the keypair.
///
/// # Errors
///
/// Returns [`Error::InvalidConfiguration`] if the keypair's config names an unrecognized
/// algorithm, or [`Error::Hpke`] if decryption fails (e.g. wrong key, info or associated data).
pub fn open(
    recipient_keypair: &HpkeKeypair,
    application_info: &HpkeApplicationInfo,
    ciphertext: &HpkeCiphertext,
    associated_data: &[u8],
) -> Result<Vec<u8>, Error> {
    let dispatch_config = hpke_dispatch_config_from_hpke_config(recipient_keypair.config())?;
    let plaintext = dispatch_config.base_mode_open(
        recipient_keypair.private_key().as_ref(),
        ciphertext.encapsulated_key(),
        &application_info.0,
        ciphertext.payload(),
        associated_data,
    )?;
    Ok(plaintext)
}
pub fn generate_hpke_config_and_private_key(
hpke_config_id: HpkeConfigId,
kem_id: HpkeKemId,
kdf_id: HpkeKdfId,
aead_id: HpkeAeadId,
) -> Result<HpkeKeypair, Error> {
let Keypair {
private_key,
public_key,
} = match kem_id {
HpkeKemId::X25519HkdfSha256 => Kem::X25519HkdfSha256.gen_keypair(),
HpkeKemId::P256HkdfSha256 => Kem::DhP256HkdfSha256.gen_keypair(),
_ => return Err(Error::UnsupportedKem),
};
Ok(HpkeKeypair::new(
HpkeConfig::new(
hpke_config_id,
kem_id,
kdf_id,
aead_id,
HpkePublicKey::from(public_key),
),
HpkePrivateKey::new(private_key),
))
}
/// An HPKE configuration (which includes the public key) paired with its private key.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct HpkeKeypair {
    config: HpkeConfig,
    private_key: HpkePrivateKey,
}

impl HpkeKeypair {
    /// Pairs `config` with `private_key`.
    ///
    /// No check is made that the private key corresponds to the config's public key.
    pub fn new(config: HpkeConfig, private_key: HpkePrivateKey) -> HpkeKeypair {
        Self {
            config,
            private_key,
        }
    }

    /// Returns the HPKE configuration half of the pair.
    pub fn config(&self) -> &HpkeConfig {
        &self.config
    }

    /// Returns the private key half of the pair.
    pub fn private_key(&self) -> &HpkePrivateKey {
        &self.private_key
    }
}
#[allow(deprecated)]
pub use deprecated::DivviUpHpkeConfig;

/// Legacy HPKE keypair representation, kept for compatibility.
mod deprecated {
    #![allow(deprecated)]

    use super::{Error, HpkeKeypair, HpkePrivateKey};
    use hpke_dispatch::{Aead, Kdf, Kem};
    use janus_messages::{HpkeConfig, HpkeConfigId, HpkePublicKey};
    use serde::Deserialize;

    /// An HPKE config plus private key in which algorithms are identified by
    /// `hpke_dispatch`'s `Kem`/`Kdf`/`Aead` types rather than DAP identifiers.
    #[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
    #[deprecated = "use CollectorCredential instead"]
    pub struct DivviUpHpkeConfig {
        id: HpkeConfigId,
        kem: Kem,
        kdf: Kdf,
        aead: Aead,
        public_key: HpkePublicKey,
        private_key: HpkePrivateKey,
    }

    impl TryFrom<DivviUpHpkeConfig> for HpkeKeypair {
        type Error = Error;

        /// Converts the legacy representation into an [`HpkeKeypair`]; this never fails.
        fn try_from(value: DivviUpHpkeConfig) -> Result<Self, Self::Error> {
            let DivviUpHpkeConfig {
                id,
                kem,
                kdf,
                aead,
                public_key,
                private_key,
            } = value;
            // Relies on the hpke_dispatch enums' discriminants being the numeric code points
            // expected by the DAP identifier types.
            let config = HpkeConfig::new(
                id,
                (kem as u16).into(),
                (kdf as u16).into(),
                (aead as u16).into(),
                public_key,
            );
            Ok(Self::new(config, private_key))
        }
    }
}
#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub mod test_util {
    use super::{generate_hpke_config_and_private_key, HpkeKeypair};
    use janus_messages::{HpkeAeadId, HpkeConfigId, HpkeKdfId, HpkeKemId};
    use rand::random;

    /// Sample JSON for a `DivviUpHpkeConfig`, including its private key.
    pub const SAMPLE_DIVVIUP_HPKE_CONFIG: &str = r#"{
"aead": "AesGcm128",
"id": 66,
"kdf": "Sha256",
"kem": "X25519HkdfSha256",
"private_key": "uKkTvzKLfYNUPZcoKI7hV64zS06OWgBkbivBL4Sw4mo",
"public_key": "CcDghts2boltt9GQtBUxdUsVR83SCVYHikcGh33aVlU"
}
"#;

    /// Generates an X25519-HKDF-SHA256 / HKDF-SHA256 / AES-128-GCM keypair with a random
    /// config ID.
    ///
    /// # Panics
    ///
    /// Panics if key generation fails.
    pub fn generate_test_hpke_config_and_private_key() -> HpkeKeypair {
        // Delegate so the fixed algorithm choice lives in exactly one place.
        generate_test_hpke_config_and_private_key_with_id(random::<u8>())
    }

    /// Generates an X25519-HKDF-SHA256 / HKDF-SHA256 / AES-128-GCM keypair with config ID `id`.
    ///
    /// # Panics
    ///
    /// Panics if key generation fails.
    pub fn generate_test_hpke_config_and_private_key_with_id(id: u8) -> HpkeKeypair {
        generate_hpke_config_and_private_key(
            HpkeConfigId::from(id),
            HpkeKemId::X25519HkdfSha256,
            HpkeKdfId::HkdfSha256,
            HpkeAeadId::Aes128Gcm,
        )
        .expect("key generation failed for the fixed test HPKE algorithms")
    }
}
// Unit tests for seal/open, round trips across algorithm combinations, known-answer test
// vectors, and the deprecated divviup-api config format.
#[cfg(test)]
mod tests {
use super::{test_util::generate_test_hpke_config_and_private_key, HpkeApplicationInfo, Label};
// DivviUpHpkeConfig is deprecated; allow importing it without a warning.
#[allow(deprecated)]
use crate::hpke::{
open, seal, test_util::SAMPLE_DIVVIUP_HPKE_CONFIG, DivviUpHpkeConfig, HpkeKeypair,
HpkePrivateKey,
};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use hpke_dispatch::{Kem, Keypair};
use janus_messages::{
HpkeAeadId, HpkeCiphertext, HpkeConfig, HpkeConfigId, HpkeKdfId, HpkeKemId, HpkePublicKey,
Role,
};
use serde::Deserialize;
use std::collections::HashSet;
/// Sealing then opening with the same keypair, info and associated data recovers the message.
#[test]
fn exchange_message() {
let hpke_keypair = generate_test_hpke_config_and_private_key();
let application_info =
HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader);
let message = b"a message that is secret";
let associated_data = b"message associated data";
let ciphertext = seal(
hpke_keypair.config(),
&application_info,
message,
associated_data,
)
.unwrap();
let plaintext = open(
&hpke_keypair,
&application_info,
&ciphertext,
associated_data,
)
.unwrap();
assert_eq!(plaintext, message);
}
/// Opening with an unrelated, freshly generated keypair must fail.
#[test]
fn wrong_private_key() {
let hpke_keypair = generate_test_hpke_config_and_private_key();
let application_info =
HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader);
let message = b"a message that is secret";
let associated_data = b"message associated data";
let ciphertext = seal(
hpke_keypair.config(),
&application_info,
message,
associated_data,
)
.unwrap();
// Config IDs are random and may coincide; `open` does not compare IDs, so the failure
// comes from the key mismatch itself.
let wrong_hpke_keypair = generate_test_hpke_config_and_private_key();
open(
&wrong_hpke_keypair,
&application_info,
&ciphertext,
associated_data,
)
.unwrap_err();
}
/// Opening with application info built from a different label must fail.
#[test]
fn wrong_application_info() {
let hpke_keypair = generate_test_hpke_config_and_private_key();
let application_info =
HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader);
let message = b"a message that is secret";
let associated_data = b"message associated data";
let ciphertext = seal(
hpke_keypair.config(),
&application_info,
message,
associated_data,
)
.unwrap();
let wrong_application_info =
HpkeApplicationInfo::new(&Label::AggregateShare, &Role::Client, &Role::Leader);
open(
&hpke_keypair,
&wrong_application_info,
&ciphertext,
associated_data,
)
.unwrap_err();
}
/// Opening with different associated data must fail.
#[test]
fn wrong_associated_data() {
let hpke_keypair = generate_test_hpke_config_and_private_key();
let application_info =
HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader);
let message = b"a message that is secret";
let associated_data = b"message associated data";
let ciphertext = seal(
hpke_keypair.config(),
&application_info,
message,
associated_data,
)
.unwrap();
let wrong_associated_data = b"wrong associated data";
open(
&hpke_keypair,
&application_info,
&ciphertext,
wrong_associated_data,
)
.unwrap_err();
}
/// Seals and opens a fixed message with a keypair for the given algorithm combination.
///
/// Keys are generated directly through `hpke_dispatch::Kem` rather than through
/// `generate_hpke_config_and_private_key`, and the config ID is always 0.
fn round_trip_check(kem_id: HpkeKemId, kdf_id: HpkeKdfId, aead_id: HpkeAeadId) {
const ASSOCIATED_DATA: &[u8] = b"round trip test associated data";
const MESSAGE: &[u8] = b"round trip test message";
let kem = Kem::try_from(u16::from(kem_id)).unwrap();
let Keypair {
private_key,
public_key,
} = kem.gen_keypair();
let hpke_config = HpkeConfig::new(
HpkeConfigId::from(0),
kem_id,
kdf_id,
aead_id,
HpkePublicKey::from(public_key),
);
let hpke_private_key = HpkePrivateKey::new(private_key);
let hpke_keypair = HpkeKeypair::new(hpke_config, hpke_private_key);
let application_info =
HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, &Role::Leader);
let ciphertext = seal(
hpke_keypair.config(),
&application_info,
MESSAGE,
ASSOCIATED_DATA,
)
.unwrap();
let plaintext = open(
&hpke_keypair,
&application_info,
&ciphertext,
ASSOCIATED_DATA,
)
.unwrap();
assert_eq!(plaintext, MESSAGE);
}
/// Runs `round_trip_check` over 2 KEMs x 3 KDFs x 2 AEADs (12 combinations).
#[test]
fn round_trip_all_algorithms() {
for kem_id in [HpkeKemId::P256HkdfSha256, HpkeKemId::X25519HkdfSha256] {
for kdf_id in [
HpkeKdfId::HkdfSha256,
HpkeKdfId::HkdfSha384,
HpkeKdfId::HkdfSha512,
] {
for aead_id in [HpkeAeadId::Aes128Gcm, HpkeAeadId::Aes256Gcm] {
round_trip_check(kem_id, kdf_id, aead_id)
}
}
}
}
/// One encryption entry from `test-vectors.json`; all byte fields are hex-encoded.
#[derive(Deserialize)]
struct EncryptionRecord {
#[serde(with = "hex")]
aad: Vec<u8>,
#[serde(with = "hex")]
ct: Vec<u8>,
#[serde(with = "hex")]
nonce: Vec<u8>,
#[serde(with = "hex")]
pt: Vec<u8>,
}
/// One test vector from `test-vectors.json` (only the fields these tests use).
// NOTE(review): the field names (`pkRm`, `skRm`, `base_nonce`, ...) look like the RFC 9180
// test-vector format — confirm against the fixture's provenance.
#[derive(Deserialize)]
struct TestVector {
mode: u16,
kem_id: u16,
kdf_id: u16,
aead_id: u16,
#[serde(with = "hex")]
info: Vec<u8>,
#[serde(with = "hex")]
enc: Vec<u8>,
#[serde(with = "hex", rename = "pkRm")]
serialized_public_key: Vec<u8>,
#[serde(with = "hex", rename = "skRm")]
serialized_private_key: Vec<u8>,
#[serde(with = "hex")]
base_nonce: Vec<u8>,
encryptions: Vec<EncryptionRecord>,
}
/// Decrypts known-answer test vectors and checks the recovered plaintexts.
///
/// Also asserts that exactly 12 distinct (KEM, KDF, AEAD) combinations were exercised.
#[test]
fn decrypt_test_vectors() {
let test_vectors: Vec<TestVector> =
serde_json::from_str(include_str!("test-vectors.json")).unwrap();
let mut algorithms_tested = HashSet::new();
for test_vector in test_vectors {
// Only mode 0 is exercised; `open` uses `base_mode_open`, so mode 0 is presumably
// HPKE base mode.
if test_vector.mode != 0 {
continue;
}
// Only the two KEMs named here are exercised; other KEM IDs are skipped.
let kem_id = match HpkeKemId::from(test_vector.kem_id) {
kem_id @ HpkeKemId::P256HkdfSha256 | kem_id @ HpkeKemId::X25519HkdfSha256 => kem_id,
_ => {
continue;
}
};
let kdf_id = test_vector.kdf_id.into();
// AEAD ID 0xffff is skipped (presumably the export-only AEAD, which cannot encrypt).
if test_vector.aead_id == 0xffff {
continue;
}
let aead_id = test_vector.aead_id.into();
for encryption in test_vector.encryptions {
// Only encryptions whose nonce equals the base nonce are checked; `open` sets up a
// fresh context per call, so later messages in a context would not decrypt.
if encryption.nonce != test_vector.base_nonce {
continue;
}
let hpke_config = HpkeConfig::new(
HpkeConfigId::from(0),
kem_id,
kdf_id,
aead_id,
HpkePublicKey::from(test_vector.serialized_public_key.clone()),
);
let hpke_private_key = HpkePrivateKey(test_vector.serialized_private_key.clone());
let hpke_keypair = HpkeKeypair::new(hpke_config, hpke_private_key);
// Use the vector's raw info bytes directly instead of building them from a Label.
let application_info = HpkeApplicationInfo(test_vector.info.clone());
let ciphertext = HpkeCiphertext::new(
HpkeConfigId::from(0),
test_vector.enc.clone(),
encryption.ct,
);
let plaintext = open(
&hpke_keypair,
&application_info,
&ciphertext,
&encryption.aad,
)
.unwrap();
assert_eq!(plaintext, encryption.pt);
algorithms_tested.insert((
u16::from(kem_id),
u16::from(kdf_id),
u16::from(aead_id),
));
}
}
// Guards against the fixture silently losing coverage.
assert_eq!(algorithms_tested.len(), 12);
}
/// The sample divviup-api JSON deserializes and converts into the expected `HpkeKeypair`.
#[test]
#[allow(deprecated)]
fn deserialize_divviup_api_hpke_config() {
let deserialized: DivviUpHpkeConfig =
serde_json::from_str(SAMPLE_DIVVIUP_HPKE_CONFIG).unwrap();
let hpke_keypair = HpkeKeypair::try_from(deserialized).unwrap();
assert_eq!(
hpke_keypair,
HpkeKeypair::new(
HpkeConfig::new(
HpkeConfigId::from(66),
HpkeKemId::X25519HkdfSha256,
HpkeKdfId::HkdfSha256,
HpkeAeadId::Aes128Gcm,
HpkePublicKey::from(
URL_SAFE_NO_PAD
.decode("CcDghts2boltt9GQtBUxdUsVR83SCVYHikcGh33aVlU")
.unwrap()
),
),
HpkePrivateKey::from(
URL_SAFE_NO_PAD
.decode("uKkTvzKLfYNUPZcoKI7hV64zS06OWgBkbivBL4Sw4mo")
.unwrap()
)
),
);
}
}