use crate::{
common::*,
crypto::{self, LocalKeyPair, RemotePublicKey},
error::*,
Cryptographer,
};
use byteorder::{BigEndian, ByteOrder};
/// Smallest record size (`rs`) accepted, both when parsing a header and when
/// splitting plaintext into records; smaller values yield `Error::InvalidRecordSize`.
const ECE_AES128GCM_MIN_RS: u32 = 18;
/// Size in bytes of the fixed part of the header: the salt, the big-endian `u32`
/// record size that follows it, and a final one-byte `keyid` length (stored at
/// index `ECE_AES128GCM_HEADER_LENGTH - 1`). The `keyid` bytes come after this.
const ECE_AES128GCM_HEADER_LENGTH: usize = 21;
/// NOTE(review): not referenced in this file; presumably the size of the padding
/// delimiter used by callers elsewhere in the crate — confirm.
pub(crate) const ECE_AES128GCM_PAD_SIZE: usize = 1;
/// Prefix of the HKDF `info` value used when deriving the Web Push input keying material.
const ECE_WEBPUSH_AES128GCM_IKM_INFO_PREFIX: &str = "WebPush: info\0";
/// Total length of the IKM `info` buffer: the prefix followed by the receiver and
/// sender raw public keys.
const ECE_WEBPUSH_AES128GCM_IKM_INFO_LENGTH: usize = 144;
/// Number of bytes of input keying material requested from HKDF.
const ECE_WEBPUSH_IKM_LENGTH: usize = 32;
/// HKDF `info` value used when deriving the content-encryption key.
const ECE_AES128GCM_KEY_INFO: &str = "Content-Encoding: aes128gcm\0";
/// HKDF `info` value used when deriving the base nonce.
const ECE_AES128GCM_NONCE_INFO: &str = "Content-Encoding: nonce\0";
/// Encrypts `plaintext` into a complete `aes128gcm` payload (header followed by
/// the encrypted records).
///
/// * `local_prv_key` - the sender's key pair; its raw public key is written into
///   the header `keyid` field.
/// * `remote_pub_key` - the receiver's public key.
/// * `auth_secret` - the shared authentication secret.
/// * `plaintext` - the message to encrypt; must not be empty.
/// * `params` - supplies the record size, requested padding and (optionally) the salt.
///
/// # Errors
///
/// Returns `Error::ZeroPlaintext` for empty input, `Error::InvalidKeyLength` if the
/// local raw public key is not `ECE_WEBPUSH_PUBLIC_KEY_LENGTH` bytes, and otherwise
/// propagates any error from salt generation, key derivation, record splitting or
/// record encryption.
///
/// # Panics
///
/// Panics if the number of bytes written differs from the size of the
/// pre-allocated output buffer, which would indicate an internal sizing bug.
pub(crate) fn encrypt(
    local_prv_key: &dyn LocalKeyPair,
    remote_pub_key: &dyn RemotePublicKey,
    auth_secret: &[u8],
    plaintext: &[u8],
    mut params: WebPushParams,
) -> Result<Vec<u8>> {
    let cryptographer = crypto::holder::get_cryptographer();
    if plaintext.is_empty() {
        return Err(Error::ZeroPlaintext);
    }

    let salt = params.take_or_generate_salt(cryptographer)?;
    let (key, nonce) = derive_key_and_nonce(
        cryptographer,
        EceMode::Encrypt,
        local_prv_key,
        remote_pub_key,
        auth_secret,
        &salt,
    )?;

    // The sender's raw public key travels in the header `keyid` field.
    let sender_pub_key = local_prv_key.pub_as_raw()?;
    if sender_pub_key.len() != ECE_WEBPUSH_PUBLIC_KEY_LENGTH {
        return Err(Error::InvalidKeyLength);
    }
    let header = Header {
        salt: &salt,
        rs: params.rs,
        keyid: &sender_pub_key,
    };

    let records = split_into_records(plaintext, params.pad_length, params.rs as usize)?;

    // Allocate the whole output up front, then fill it in place.
    let expected_len = header.encoded_size() + records.total_ciphertext_size();
    let mut ciphertext = vec![0; expected_len];
    let mut written = header.write_into(&mut ciphertext);
    for record in records {
        let destination = &mut ciphertext[written..];
        written += record.encrypt_into(cryptographer, &key, &nonce, destination)?;
    }
    assert!(written == ciphertext.len());

    Ok(ciphertext)
}
/// Decrypts a complete `aes128gcm` payload (header followed by encrypted records)
/// and returns the concatenated plaintext of all records.
///
/// * `local_prv_key` - the receiver's key pair.
/// * `auth_secret` - the shared authentication secret.
/// * `ciphertext` - the full payload, starting with the header.
///
/// # Errors
///
/// * `Error::ZeroCiphertext` if the input is empty or contains only a header.
/// * `Error::InvalidKeyLength` if the header `keyid` is not
///   `ECE_WEBPUSH_PUBLIC_KEY_LENGTH` bytes long.
/// * `Error::DecryptPadding` if any record follows the one marked as final.
/// * `Error::DecryptTruncated` if no record is marked as final.
/// * Any error from header parsing, public key import, key derivation or
///   record decryption.
pub(crate) fn decrypt(
    local_prv_key: &dyn LocalKeyPair,
    auth_secret: &[u8],
    ciphertext: &[u8],
) -> Result<Vec<u8>> {
    let cryptographer = crypto::holder::get_cryptographer();
    if ciphertext.is_empty() {
        return Err(Error::ZeroCiphertext);
    }

    // The plaintext is never longer than the ciphertext, so this capacity
    // is enough to hold the whole output.
    let mut output = Vec::<u8>::with_capacity(ciphertext.len());

    let header = Header::read_from(ciphertext)?;
    if ciphertext.len() == header.encoded_size() {
        return Err(Error::ZeroCiphertext);
    }

    // The `keyid` field carries the sender's raw public key.
    if header.keyid.len() != ECE_WEBPUSH_PUBLIC_KEY_LENGTH {
        return Err(Error::InvalidKeyLength);
    }
    let remote_pub_key = cryptographer.import_public_key(header.keyid)?;

    let (key, nonce) = derive_key_and_nonce(
        cryptographer,
        EceMode::Decrypt,
        local_prv_key,
        &*remote_pub_key,
        auth_secret,
        header.salt,
    )?;

    let record_size = header.rs as usize;
    let body = &ciphertext[header.encoded_size()..];

    // Scratch space reused for every record. `rs` comes straight from the
    // (untrusted) header and may be as large as `u32::MAX`, so size the buffer
    // by the largest record that can actually occur rather than by `rs` alone.
    // Records no longer than the tag are rejected before the buffer is touched,
    // so a zero-length buffer is fine in that case.
    let largest_record = std::cmp::min(record_size, body.len());
    let mut plaintext_buffer = vec![0u8; largest_record.saturating_sub(ECE_TAG_LENGTH)];

    let mut seen_final_record = false;
    for (sequence_number, record_ciphertext) in body.chunks(record_size).enumerate() {
        // Nothing may follow the record that was marked as final.
        if seen_final_record {
            return Err(Error::DecryptPadding);
        }
        let record = PlaintextRecord::decrypt_from(
            cryptographer,
            &key,
            &nonce,
            sequence_number,
            record_ciphertext,
            plaintext_buffer.as_mut_slice(),
        )?;
        if record.is_final {
            seen_final_record = true;
        }
        output.extend(record.plaintext)
    }
    if !seen_final_record {
        return Err(Error::DecryptTruncated);
    }

    Ok(output)
}
/// Borrowed view of the header that precedes the encrypted records in a payload.
pub(crate) struct Header<'a> {
/// Salt bytes; they occupy the first `ECE_SALT_LENGTH` bytes of the encoded header.
salt: &'a [u8],
/// Record size, encoded as a big-endian `u32` directly after the salt.
rs: u32,
/// Key identifier bytes, stored after a one-byte length; `encrypt` and `decrypt`
/// in this file require it to be a raw public key of `ECE_WEBPUSH_PUBLIC_KEY_LENGTH` bytes.
keyid: &'a [u8],
}
impl<'a> Header<'a> {
fn read_from(input: &'a [u8]) -> Result<Header<'a>> {
if input.len() < ECE_AES128GCM_HEADER_LENGTH {
return Err(Error::HeaderTooShort);
}
let keyid_len = input[ECE_AES128GCM_HEADER_LENGTH - 1] as usize;
if input.len() < ECE_AES128GCM_HEADER_LENGTH + keyid_len {
return Err(Error::HeaderTooShort);
}
let salt = &input[0..ECE_SALT_LENGTH];
let rs = BigEndian::read_u32(&input[ECE_SALT_LENGTH..]);
if rs < ECE_AES128GCM_MIN_RS {
return Err(Error::InvalidRecordSize);
}
let keyid = &input[ECE_AES128GCM_HEADER_LENGTH..ECE_AES128GCM_HEADER_LENGTH + keyid_len];
Ok(Header { salt, rs, keyid })
}
pub fn write_into(&self, output: &mut [u8]) -> usize {
output[0..ECE_SALT_LENGTH].copy_from_slice(self.salt);
BigEndian::write_u32(&mut output[ECE_SALT_LENGTH..], self.rs);
output[ECE_AES128GCM_HEADER_LENGTH - 1] = self.keyid.len() as u8;
output[ECE_AES128GCM_HEADER_LENGTH..ECE_AES128GCM_HEADER_LENGTH + self.keyid.len()]
.copy_from_slice(self.keyid);
self.encoded_size()
}
pub fn encoded_size(&self) -> usize {
ECE_AES128GCM_HEADER_LENGTH + self.keyid.len()
}
}
/// A single record of the message in plaintext form, before encryption or after decryption.
struct PlaintextRecord<'a> {
/// The message bytes carried by this record, without any padding.
plaintext: &'a [u8],
/// Number of padding bytes in the record, including the one-byte delimiter
/// (so it is always at least 1 for a record that can be encrypted).
padding: usize,
/// Zero-based position of the record in the message; used to derive the record's IV.
sequence_number: usize,
/// Whether this is the last record; selects the delimiter byte (2 if final, 1 otherwise).
is_final: bool,
}
impl<'a> PlaintextRecord<'a> {
    /// Decrypts one record and strips its padding.
    ///
    /// The unpadded plaintext is copied into the start of `plaintext_buffer`, and
    /// the returned record borrows it from there.
    ///
    /// # Errors
    ///
    /// Returns `Error::BlockTooShort` if `ciphertext` is not longer than the
    /// authentication tag, `Error::DecryptPadding` if the decrypted data is all
    /// zeros or its last non-zero byte is neither 1 nor 2, and propagates any
    /// error from the underlying AES-GCM decryption.
    ///
    /// # Panics
    ///
    /// Panics if `plaintext_buffer` is too small to hold the unpadded plaintext.
    pub(crate) fn decrypt_from(
        cryptographer: &dyn Cryptographer,
        key: &[u8],
        nonce: &[u8],
        sequence_number: usize,
        ciphertext: &[u8],
        plaintext_buffer: &'a mut [u8],
    ) -> Result<Self> {
        if ciphertext.len() <= ECE_TAG_LENGTH {
            return Err(Error::BlockTooShort);
        }
        let iv = generate_iv_for_record(nonce, sequence_number);
        let decrypted = cryptographer.aes_gcm_128_decrypt(key, &iv, ciphertext)?;

        // The last non-zero byte is the padding delimiter; its value says
        // whether this is the final record.
        let (delimiter_idx, is_final) = match decrypted.iter().rposition(|&byte| byte != 0u8) {
            Some(idx) if decrypted[idx] == 1 => (idx, false),
            Some(idx) if decrypted[idx] == 2 => (idx, true),
            _ => return Err(Error::DecryptPadding),
        };

        // Everything in front of the delimiter is the actual message data.
        plaintext_buffer[..delimiter_idx].copy_from_slice(&decrypted[..delimiter_idx]);
        Ok(PlaintextRecord {
            plaintext: &plaintext_buffer[..delimiter_idx],
            padding: decrypted.len() - delimiter_idx,
            sequence_number,
            is_final,
        })
    }

    /// Pads and encrypts this record into the start of `output`, returning the
    /// number of ciphertext bytes written.
    ///
    /// `output` doubles as scratch space: the padded plaintext is assembled in it
    /// first and then overwritten with the ciphertext.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying AES-GCM encryption.
    ///
    /// # Panics
    ///
    /// Panics if `padding` is zero or if `output` is too small for the padded
    /// plaintext or the resulting ciphertext.
    pub(crate) fn encrypt_into(
        &self,
        cryptographer: &dyn Cryptographer,
        key: &[u8],
        nonce: &[u8],
        output: &mut [u8],
    ) -> Result<usize> {
        let data_len = self.plaintext.len();
        let padded_len = data_len + self.padding;

        // Layout: message data, then the delimiter byte, then zero fill.
        output[..data_len].copy_from_slice(self.plaintext);
        assert!(self.padding >= 1);
        let delimiter = match self.is_final {
            true => 2,
            false => 1,
        };
        output[data_len] = delimiter;
        output[data_len + 1..padded_len].fill(0);

        let iv = generate_iv_for_record(nonce, self.sequence_number);
        let encrypted = cryptographer.aes_gcm_128_encrypt(key, &iv, &output[..padded_len])?;
        let encrypted_len = encrypted.len();
        output[..encrypted_len].copy_from_slice(&encrypted);
        Ok(encrypted_len)
    }
}
/// Plans how `plaintext` and `pad_length` bytes of padding are distributed over
/// records of (encrypted) size `rs`, returning an iterator over the records.
///
/// Every record needs at least one padding byte for its delimiter, so the
/// effective padding may be larger than `pad_length`; this is the only reason
/// the padding is ever expanded. An empty `plaintext` with no requested padding
/// yields a single final record consisting of just the delimiter.
///
/// # Errors
///
/// Returns `Error::InvalidRecordSize` if `rs` is below `ECE_AES128GCM_MIN_RS`.
///
/// # Panics
///
/// Panics if the computed record count is below the minimum needed to hold the
/// plaintext, which would indicate a bug in the chunking arithmetic.
fn split_into_records(
    plaintext: &[u8],
    pad_length: usize,
    rs: usize,
) -> Result<PlaintextRecordIterator<'_>> {
    if rs < ECE_AES128GCM_MIN_RS as usize {
        return Err(Error::InvalidRecordSize);
    }
    // From here on `rs` is the room available for plaintext plus padding,
    // i.e. the record size without the authentication tag.
    let rs = rs - ECE_TAG_LENGTH;

    // Each record can carry at most `rs - 1` plaintext bytes, because one byte
    // is always taken by the padding delimiter.
    let mut min_num_records = plaintext.len() / (rs - 1);
    if plaintext.len() % (rs - 1) != 0 {
        min_num_records += 1;
    }

    // One padding byte per record at minimum, and never less than one in total:
    // without this, empty plaintext with zero padding would produce zero
    // records and divide by zero below.
    let min_pad_length = std::cmp::max(min_num_records, 1);
    let pad_length = std::cmp::max(pad_length, min_pad_length);

    // The total amount of data determines the number of records.
    let total_size = plaintext.len() + pad_length;
    let mut num_records = total_size / rs;
    let size_of_final_record = total_size % rs;
    if size_of_final_record > 0 {
        num_records += 1;
    }
    assert!(
        num_records >= min_num_records,
        "record chunking error: we miscalculated the minimum number of records ({} < {})",
        num_records,
        min_num_records,
    );

    // Spread the plaintext evenly; whatever does not divide evenly is handed
    // out as "extra" on top of the per-record share.
    let plaintext_per_record = plaintext.len() / num_records;
    let mut extra_plaintext = plaintext.len() % num_records;
    // A short final record may not have room for an equal share (it still needs
    // its delimiter byte); move the excess to the other records.
    if size_of_final_record > 0 && plaintext_per_record > size_of_final_record - 1 {
        extra_plaintext += plaintext_per_record - (size_of_final_record - 1)
    }

    Ok(PlaintextRecordIterator {
        plaintext,
        pad_length,
        plaintext_per_record,
        extra_plaintext,
        rs,
        sequence_number: 0,
        num_records,
        total_size,
    })
}
/// Iterator that hands out the records planned by `split_into_records`, one at a time.
struct PlaintextRecordIterator<'a> {
/// Plaintext that has not yet been assigned to a record.
plaintext: &'a [u8],
/// Padding bytes that have not yet been assigned to a record.
pad_length: usize,
/// Base number of plaintext bytes given to each record.
plaintext_per_record: usize,
/// Leftover plaintext still to be distributed on top of the base share.
extra_plaintext: usize,
/// Total plaintext plus padding across all records, excluding authentication tags.
total_size: usize,
/// Room for plaintext plus padding in one record (record size minus `ECE_TAG_LENGTH`).
rs: usize,
/// Total number of records that will be produced.
num_records: usize,
/// Sequence number of the next record to be produced.
sequence_number: usize,
}
impl<'a> PlaintextRecordIterator<'a> {
    /// Total number of bytes all records occupy once encrypted: the plaintext
    /// and padding plus one authentication tag per record.
    pub(crate) fn total_ciphertext_size(&self) -> usize {
        let tag_overhead = self.num_records * ECE_TAG_LENGTH;
        self.total_size + tag_overhead
    }
}
impl<'a> Iterator for PlaintextRecordIterator<'a> {
    type Item = PlaintextRecord<'a>;

    /// Produces the next record, or `None` once `num_records` records have been
    /// handed out.
    ///
    /// # Panics
    ///
    /// Panics with a "record chunking error" message if the bookkeeping
    /// invariants are violated: plaintext, extra plaintext or padding left over
    /// at the end, plaintext or padding running out early, or a non-final record
    /// that does not fill the whole record size.
    fn next(&mut self) -> Option<Self::Item> {
        let records_remaining = self.num_records - self.sequence_number;
        if records_remaining == 0 {
            // Everything must have been handed out by now.
            assert!(
                self.plaintext.is_empty(),
                "record chunking error: the plaintext was not fully consumed"
            );
            assert!(
                self.extra_plaintext == 0,
                "record chunking error: the extra plaintext was not fully consumed"
            );
            assert!(
                self.pad_length == 0,
                "record chunking error: the padding was not fully consumed"
            );
            return None;
        }

        let plaintext_share = if self.plaintext_per_record > self.plaintext.len() {
            // Only the final record may receive less than the base share.
            assert!(
                records_remaining == 1,
                "record chunking error: the plaintext was consumed too early"
            );
            self.plaintext.len()
        } else if self.extra_plaintext > 0 {
            // Spread the extra plaintext over the records before the final one,
            // rounding up so that it is used up in time.
            let sharers = records_remaining - 1;
            let mut extra_share = self.extra_plaintext / sharers;
            if self.extra_plaintext % sharers != 0 {
                extra_share += 1;
            }
            self.extra_plaintext -= extra_share;
            self.plaintext_per_record + extra_share
        } else {
            self.plaintext_per_record
        };

        let plaintext = &self.plaintext[..plaintext_share];
        self.plaintext = &self.plaintext[plaintext_share..];

        // Fill whatever room is left in the record with padding, as far as
        // the remaining padding budget allows.
        let room_for_padding = self.rs - plaintext_share;
        let padding_share = std::cmp::min(self.pad_length, room_for_padding);
        self.pad_length -= padding_share;
        assert!(
            padding_share > 0,
            "record chunking error: the padding was consumed too early"
        );

        let sequence_number = self.sequence_number;
        self.sequence_number += 1;
        let is_final = self.sequence_number == self.num_records;
        assert!(
            is_final || plaintext.len() + padding_share == self.rs,
            "record chunking error: non-final record is too short"
        );

        Some(PlaintextRecord {
            plaintext,
            padding: padding_share,
            sequence_number,
            is_final,
        })
    }
}
/// Derives the content-encryption key and base nonce for a message.
///
/// An ECDH shared secret is computed from the two keys, combined with
/// `auth_secret` and an info string containing both public keys to obtain the
/// input keying material, from which the key and nonce are derived using `salt`.
///
/// `ece_mode` decides which public key counts as the receiver's: the remote key
/// when encrypting, the local key when decrypting.
///
/// # Errors
///
/// Returns `Error::InvalidAuthSecret` if `auth_secret` is not
/// `ECE_WEBPUSH_AUTH_SECRET_LENGTH` bytes, `Error::InvalidSalt` if `salt` is not
/// `ECE_SALT_LENGTH` bytes, and propagates errors from the cryptographer, key
/// serialization and info-string generation.
fn derive_key_and_nonce(
    cryptographer: &dyn Cryptographer,
    ece_mode: EceMode,
    local_prv_key: &dyn LocalKeyPair,
    remote_pub_key: &dyn RemotePublicKey,
    auth_secret: &[u8],
    salt: &[u8],
) -> Result<KeyAndNonce> {
    if auth_secret.len() != ECE_WEBPUSH_AUTH_SECRET_LENGTH {
        return Err(Error::InvalidAuthSecret);
    }
    if salt.len() != ECE_SALT_LENGTH {
        return Err(Error::InvalidSalt);
    }

    let shared_secret = cryptographer.compute_ecdh_secret(remote_pub_key, local_prv_key)?;
    let raw_remote_pub_key = remote_pub_key.as_raw()?;
    let raw_local_pub_key = local_prv_key.pub_as_raw()?;

    // The info string always lists the receiver's key first, then the sender's.
    let (raw_recv_pub_key, raw_sender_pub_key) = match ece_mode {
        EceMode::Encrypt => (&raw_remote_pub_key, &raw_local_pub_key),
        EceMode::Decrypt => (&raw_local_pub_key, &raw_remote_pub_key),
    };
    let ikm_info = generate_info(raw_recv_pub_key, raw_sender_pub_key)?;

    let ikm = cryptographer.hkdf_sha256(
        auth_secret,
        &shared_secret,
        &ikm_info,
        ECE_WEBPUSH_IKM_LENGTH,
    )?;

    // Key and nonce come from the same salt and IKM; only the info string and
    // the output length differ.
    let expand = |info: &str, len: usize| cryptographer.hkdf_sha256(salt, &ikm, info.as_bytes(), len);
    let key = expand(ECE_AES128GCM_KEY_INFO, ECE_AES_KEY_LENGTH)?;
    let nonce = expand(ECE_AES128GCM_NONCE_INFO, ECE_NONCE_LENGTH)?;

    Ok((key, nonce))
}
fn generate_info(
raw_recv_pub_key: &[u8],
raw_sender_pub_key: &[u8],
) -> Result<[u8; ECE_WEBPUSH_AES128GCM_IKM_INFO_LENGTH]> {
let mut info = [0u8; ECE_WEBPUSH_AES128GCM_IKM_INFO_LENGTH];
let prefix = ECE_WEBPUSH_AES128GCM_IKM_INFO_PREFIX.as_bytes();
let mut offset = prefix.len();
info[0..offset].copy_from_slice(prefix);
info[offset..offset + ECE_WEBPUSH_PUBLIC_KEY_LENGTH].copy_from_slice(raw_recv_pub_key);
offset += ECE_WEBPUSH_PUBLIC_KEY_LENGTH;
info[offset..].copy_from_slice(raw_sender_pub_key);
Ok(info)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits an all-zero payload of `payload_len` bytes with the requested
    /// `padding` and a tag-exclusive record size of `rs`, and summarizes each
    /// resulting record as `(plaintext length, padding length)`.
    fn split_and_summarize(payload_len: usize, padding: usize, rs: usize) -> Vec<(usize, usize)> {
        let payload = vec![0u8; payload_len];
        split_into_records(&payload, padding, rs + ECE_TAG_LENGTH)
            .unwrap()
            .map(|record| (record.plaintext.len(), record.padding))
            .collect()
    }

    // A payload that fits into one record yields a single, final record.
    #[test]
    fn test_split_into_records_17_0_20() {
        let records: Vec<_> = split_into_records(&[0u8; 17], 0, 20 + ECE_TAG_LENGTH)
            .unwrap()
            .collect();
        assert_eq!(records.len(), 1);
        let only = &records[0];
        assert_eq!(only.plaintext.len(), 17);
        assert_eq!(only.padding, 1);
        assert_eq!(only.sequence_number, 0);
        assert!(only.is_final);
    }

    // A payload that divides evenly yields identical records; only the last is final.
    #[test]
    fn test_split_into_records_15_0_6() {
        let records: Vec<_> = split_into_records(&[0u8; 15], 0, 6 + ECE_TAG_LENGTH)
            .unwrap()
            .collect();
        assert_eq!(records.len(), 3);
        for (idx, record) in records.iter().enumerate() {
            assert_eq!(record.plaintext.len(), 5);
            assert_eq!(record.padding, 1);
            assert_eq!(record.sequence_number, idx);
            assert_eq!(record.is_final, idx == 2);
        }
    }

    #[test]
    fn test_split_into_records_8_2_3() {
        let expected = vec![(2, 1), (2, 1), (2, 1), (2, 1)];
        assert_eq!(split_and_summarize(8, 2, 3), expected);
    }

    #[test]
    fn test_split_into_records_8_0_8() {
        let expected = vec![(7, 1), (1, 1)];
        assert_eq!(split_and_summarize(8, 0, 8), expected);
    }

    #[test]
    fn test_split_into_records_24_6_8() {
        let expected = vec![(7, 1), (6, 2), (6, 2), (5, 1)];
        assert_eq!(split_and_summarize(24, 6, 8), expected);
    }

    #[test]
    fn test_split_into_records_8_6_3() {
        let expected = vec![(2, 1), (2, 1), (2, 1), (1, 2), (1, 1)];
        assert_eq!(split_and_summarize(8, 6, 3), expected);
    }

    #[test]
    fn test_split_into_records_3_25_8() {
        let expected = vec![(1, 7), (1, 7), (1, 7), (0, 4)];
        assert_eq!(split_and_summarize(3, 25, 8), expected);
    }

    #[test]
    fn test_split_into_records_3_35_8() {
        let expected = vec![(1, 7), (1, 7), (1, 7), (0, 8), (0, 6)];
        assert_eq!(split_and_summarize(3, 35, 8), expected);
    }

    #[test]
    fn test_split_into_records_19_6_8() {
        let expected = vec![(7, 1), (6, 2), (6, 2), (0, 1)];
        assert_eq!(split_and_summarize(19, 6, 8), expected);
    }
}