use self::query_type::{FixedSize, QueryType, TimeInterval};
use anyhow::anyhow;
use base64::{display::Base64Display, engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use derivative::Derivative;
use num_enum::{FromPrimitive, IntoPrimitive, TryFromPrimitive};
use prio::{
codec::{
decode_u16_items, decode_u32_items, encode_u16_items, encode_u32_items, CodecError, Decode,
Encode,
},
topology::ping_pong::PingPongMessage,
};
use rand::{distributions::Standard, prelude::Distribution, Rng};
use serde::{
de::{self, Visitor},
Deserialize, Serialize, Serializer,
};
use std::{
fmt::{self, Debug, Display, Formatter},
io::{Cursor, Read},
num::TryFromIntError,
str,
str::FromStr,
time::{SystemTime, SystemTimeError},
};
pub use prio::codec;
pub mod problem_type;
pub mod query_type;
pub mod taskprov;
#[cfg(test)]
mod tests;
/// Errors produced while constructing or parsing the message types defined in this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A caller-supplied value was rejected; the message explains which one and why.
    #[error("{0}")]
    InvalidParameter(&'static str),
    /// A time/duration computation is not representable (e.g. `start + duration` overflows).
    #[error("{0}")]
    IllegalTimeArithmetic(&'static str),
    /// A base64url-encoded string could not be decoded.
    #[error("base64 decode failure: {0}")]
    Base64Decode(#[from] base64::DecodeError),
}
/// A URL as carried in DAP messages.
///
/// On the wire this is an opaque byte string with a u16 length prefix. Values are only built via
/// `TryFrom<&[u8]>` (which decoding also goes through), so the contents are guaranteed to be
/// non-empty, at most [`Url::MAX_LEN`] bytes long, and ASCII.
#[derive(Clone, PartialEq, Eq)]
pub struct Url(Vec<u8>);

impl Url {
    /// Maximum length in bytes, bounded by the u16 length prefix used on the wire.
    const MAX_LEN: usize = u16::MAX as usize;
}

impl Encode for Url {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        encode_u16_items(bytes, &(), &self.0)
    }

    fn encoded_len(&self) -> Option<usize> {
        // 2-byte length prefix followed by the raw bytes.
        Some(2 + self.0.len())
    }
}

impl Decode for Url {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        // Route through `TryFrom` so decoded values obey the same invariants as constructed ones.
        Url::try_from(decode_u16_items(&(), bytes)?.as_ref())
    }
}

impl Debug for Url {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Debug and Display render identically: the bare URL text.
        Display::fmt(self, f)
    }
}

impl Display for Url {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Construction guarantees ASCII, so this conversion cannot fail in practice; if it ever
        // did, the failure is reported as `fmt::Error`.
        let text = str::from_utf8(&self.0).map_err(|_| fmt::Error)?;
        f.write_str(text)
    }
}

impl TryFrom<&[u8]> for Url {
    type Error = CodecError;

    /// Validates `value` as a DAP URL: non-empty, at most [`Url::MAX_LEN`] bytes, and ASCII.
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        if value.is_empty() {
            Err(CodecError::Other(
                anyhow!("Url must be at least 1 byte long").into(),
            ))
        } else if value.len() > Url::MAX_LEN {
            // The bound is inclusive: exactly MAX_LEN bytes is accepted.
            Err(CodecError::Other(
                anyhow!("Url must be at most {} bytes long", Url::MAX_LEN).into(),
            ))
        } else if !value.is_ascii() {
            Err(CodecError::Other(
                anyhow!("Url must be ASCII encoded").into(),
            ))
        } else {
            Ok(Self(value.to_vec()))
        }
    }
}

impl TryFrom<&Url> for url::Url {
    type Error = url::ParseError;

    /// Parses the URL text into a structured [`url::Url`].
    fn try_from(value: &Url) -> Result<Self, Self::Error> {
        url::Url::parse(
            str::from_utf8(&value.0).expect("Url contents are validated as ASCII on construction"),
        )
    }
}
/// A span of time measured in whole seconds; encoded on the wire as a u64.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Duration(u64);

impl Duration {
    /// A duration of zero seconds.
    pub const ZERO: Duration = Duration(0);

    /// Creates a duration spanning `seconds` seconds.
    pub const fn from_seconds(seconds: u64) -> Self {
        Duration(seconds)
    }

    /// Returns the length of this duration in seconds.
    pub fn as_seconds(&self) -> u64 {
        let Duration(seconds) = *self;
        seconds
    }
}

impl Encode for Duration {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        u64::encode(&self.as_seconds(), bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        u64::encoded_len(&self.as_seconds())
    }
}

impl Decode for Duration {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        u64::decode(bytes).map(Duration::from_seconds)
    }
}

impl Display for Duration {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} seconds", self.as_seconds())
    }
}
/// A point in time, expressed as whole seconds since the Unix epoch; encoded as a u64.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Time(u64);

impl Time {
    /// Creates a time `timestamp` seconds after the Unix epoch.
    pub const fn from_seconds_since_epoch(timestamp: u64) -> Self {
        Time(timestamp)
    }

    /// Returns the number of seconds since the Unix epoch.
    pub fn as_seconds_since_epoch(&self) -> u64 {
        let Time(seconds) = *self;
        seconds
    }
}

impl Display for Time {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_seconds_since_epoch())
    }
}

impl Encode for Time {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        u64::encode(&self.as_seconds_since_epoch(), bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        u64::encoded_len(&self.as_seconds_since_epoch())
    }
}

impl Decode for Time {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        u64::decode(bytes).map(Time::from_seconds_since_epoch)
    }
}

impl TryFrom<SystemTime> for Time {
    type Error = SystemTimeError;

    /// Converts a system time, truncating sub-second precision.
    ///
    /// Fails if `time` is earlier than the Unix epoch.
    fn try_from(time: SystemTime) -> Result<Self, Self::Error> {
        time.duration_since(SystemTime::UNIX_EPOCH)
            .map(|since_epoch| Time::from_seconds_since_epoch(since_epoch.as_secs()))
    }
}
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Interval {
start: Time,
duration: Duration,
}
impl Interval {
pub const EMPTY: Self = Self {
start: Time::from_seconds_since_epoch(0),
duration: Duration::ZERO,
};
pub fn new(start: Time, duration: Duration) -> Result<Self, Error> {
start
.0
.checked_add(duration.0)
.ok_or(Error::IllegalTimeArithmetic("duration overflows time"))?;
Ok(Self { start, duration })
}
pub fn start(&self) -> &Time {
&self.start
}
pub fn duration(&self) -> &Duration {
&self.duration
}
}
impl Encode for Interval {
fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
self.start.encode(bytes)?;
self.duration.encode(bytes)
}
fn encoded_len(&self) -> Option<usize> {
Some(self.start.encoded_len()? + self.duration.encoded_len()?)
}
}
impl Decode for Interval {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
let start = Time::decode(bytes)?;
let duration = Duration::decode(bytes)?;
Self::new(start, duration).map_err(|e| CodecError::Other(Box::new(e)))
}
}
impl Display for Interval {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "start: {} duration: {}", self.start, self.duration)
}
}
/// Identifier of a batch (used by fixed-size queries): [`BatchId::LEN`] opaque bytes.
///
/// `Debug`, `Display` and `FromStr` use unpadded base64url.
#[derive(Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct BatchId([u8; Self::LEN]);

impl BatchId {
    /// Length of a batch ID in bytes.
    pub const LEN: usize = 32;
}

impl From<[u8; Self::LEN]> for BatchId {
    fn from(batch_id: [u8; Self::LEN]) -> Self {
        Self(batch_id)
    }
}

impl<'a> TryFrom<&'a [u8]> for BatchId {
    type Error = Error;

    /// Fails with [`Error::InvalidParameter`] unless `value` is exactly [`BatchId::LEN`] bytes.
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        <[u8; BatchId::LEN]>::try_from(value)
            .map(BatchId)
            .map_err(|_| {
                Error::InvalidParameter("byte slice has incorrect length for BatchId")
            })
    }
}

impl FromStr for BatchId {
    type Err = Error;

    /// Parses an unpadded base64url string into a batch ID.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let decoded = URL_SAFE_NO_PAD.decode(s)?;
        BatchId::try_from(decoded.as_slice())
    }
}

impl AsRef<[u8; Self::LEN]> for BatchId {
    fn as_ref(&self) -> &[u8; Self::LEN] {
        &self.0
    }
}

impl Display for BatchId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", Base64Display::new(&self.0, &URL_SAFE_NO_PAD))
    }
}

impl Debug for BatchId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "BatchId({})", Base64Display::new(&self.0, &URL_SAFE_NO_PAD))
    }
}

impl Encode for BatchId {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        // Fixed-length field: raw bytes, no length prefix.
        bytes.extend_from_slice(&self.0);
        Ok(())
    }

    fn encoded_len(&self) -> Option<usize> {
        Some(BatchId::LEN)
    }
}

impl Decode for BatchId {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        let mut buf = [0u8; BatchId::LEN];
        bytes.read_exact(&mut buf)?;
        Ok(BatchId::from(buf))
    }
}

impl Distribution<BatchId> for Standard {
    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> BatchId {
        let bytes: [u8; BatchId::LEN] = rng.gen();
        BatchId::from(bytes)
    }
}
/// Identifier of a client report: [`ReportId::LEN`] opaque bytes.
///
/// `Debug`, `Display` and `FromStr` use unpadded base64url.
#[derive(Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ReportId([u8; Self::LEN]);

impl ReportId {
    /// Length of a report ID in bytes.
    pub const LEN: usize = 16;
}

impl From<[u8; Self::LEN]> for ReportId {
    fn from(report_id: [u8; Self::LEN]) -> Self {
        Self(report_id)
    }
}

impl<'a> TryFrom<&'a [u8]> for ReportId {
    type Error = Error;

    /// Fails with [`Error::InvalidParameter`] unless `value` is exactly [`ReportId::LEN`] bytes.
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        <[u8; ReportId::LEN]>::try_from(value)
            .map(ReportId)
            .map_err(|_| {
                Error::InvalidParameter("byte slice has incorrect length for ReportId")
            })
    }
}

impl AsRef<[u8; Self::LEN]> for ReportId {
    fn as_ref(&self) -> &[u8; Self::LEN] {
        &self.0
    }
}

impl Display for ReportId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", Base64Display::new(&self.0, &URL_SAFE_NO_PAD))
    }
}

impl Debug for ReportId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "ReportId({})", Base64Display::new(&self.0, &URL_SAFE_NO_PAD))
    }
}

impl Encode for ReportId {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        // Fixed-length field: raw bytes, no length prefix.
        bytes.extend_from_slice(&self.0);
        Ok(())
    }

    fn encoded_len(&self) -> Option<usize> {
        Some(ReportId::LEN)
    }
}

impl Decode for ReportId {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        let mut buf = [0u8; ReportId::LEN];
        bytes.read_exact(&mut buf)?;
        Ok(ReportId::from(buf))
    }
}

impl FromStr for ReportId {
    type Err = Error;

    /// Parses an unpadded base64url string into a report ID.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let decoded = URL_SAFE_NO_PAD.decode(s)?;
        ReportId::try_from(decoded.as_slice())
    }
}

impl Distribution<ReportId> for Standard {
    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> ReportId {
        let bytes: [u8; ReportId::LEN] = rng.gen();
        ReportId::from(bytes)
    }
}
/// A [`ReportIdChecksum::LEN`]-byte checksum value (presumably computed over a set of report
/// IDs; the computation is not defined here). `Display` renders it as hex.
#[derive(Copy, Clone, Debug, Default, Hash, PartialEq, Eq)]
pub struct ReportIdChecksum([u8; Self::LEN]);

impl ReportIdChecksum {
    /// Length of a checksum in bytes.
    pub const LEN: usize = 32;
}

impl From<[u8; Self::LEN]> for ReportIdChecksum {
    fn from(checksum: [u8; Self::LEN]) -> Self {
        Self(checksum)
    }
}

impl<'a> TryFrom<&'a [u8]> for ReportIdChecksum {
    type Error = Error;

    /// Fails with [`Error::InvalidParameter`] unless `value` is exactly
    /// [`ReportIdChecksum::LEN`] bytes.
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        <[u8; ReportIdChecksum::LEN]>::try_from(value)
            .map(ReportIdChecksum)
            .map_err(|_| {
                Error::InvalidParameter("byte slice has incorrect length for ReportIdChecksum")
            })
    }
}

impl AsRef<[u8]> for ReportIdChecksum {
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}

impl AsMut<[u8]> for ReportIdChecksum {
    fn as_mut(&mut self) -> &mut [u8] {
        &mut self.0
    }
}

impl Display for ReportIdChecksum {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let hex_digits = hex::encode(self.0);
        write!(f, "{}", hex_digits)
    }
}

impl Encode for ReportIdChecksum {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        // Fixed-length field: raw bytes, no length prefix.
        bytes.extend_from_slice(&self.0);
        Ok(())
    }

    fn encoded_len(&self) -> Option<usize> {
        Some(ReportIdChecksum::LEN)
    }
}

impl Decode for ReportIdChecksum {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        let mut buf = [0u8; ReportIdChecksum::LEN];
        bytes.read_exact(&mut buf)?;
        Ok(ReportIdChecksum::from(buf))
    }
}

#[cfg(feature = "test-util")]
impl Distribution<ReportIdChecksum> for Standard {
    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> ReportIdChecksum {
        let bytes: [u8; ReportIdChecksum::LEN] = rng.gen();
        ReportIdChecksum::from(bytes)
    }
}
/// A protocol participant. The explicit discriminants are the single-byte wire encoding.
#[derive(Copy, Clone, Debug, PartialEq, Eq, TryFromPrimitive, Serialize, Deserialize)]
#[repr(u8)]
pub enum Role {
    Collector = 0,
    Client = 1,
    Leader = 2,
    Helper = 3,
}

impl Role {
    /// Returns true for the two aggregator roles (leader and helper).
    pub fn is_aggregator(&self) -> bool {
        // Exactly the aggregator roles have an index.
        self.index().is_some()
    }

    /// Position of an aggregator role (leader = 0, helper = 1); `None` for non-aggregators.
    pub fn index(&self) -> Option<usize> {
        match self {
            Role::Leader => Some(0),
            Role::Helper => Some(1),
            Role::Collector | Role::Client => None,
        }
    }

    /// Lowercase name of the role; `FromStr` accepts exactly these strings.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Collector => "collector",
            Self::Client => "client",
            Self::Leader => "leader",
            Self::Helper => "helper",
        }
    }
}

/// Returned by `Role::from_str` when the input is not a known role name.
#[derive(Debug, thiserror::Error)]
#[error("unknown role {0}")]
pub struct RoleParseError(String);

impl FromStr for Role {
    type Err = RoleParseError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        [Self::Collector, Self::Client, Self::Leader, Self::Helper]
            .iter()
            .copied()
            .find(|role| role.as_str() == s)
            .ok_or_else(|| RoleParseError(s.to_owned()))
    }
}

impl Encode for Role {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        let discriminant = *self as u8;
        discriminant.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        Some(1)
    }
}

impl Decode for Role {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        let raw = u8::decode(bytes)?;
        Role::try_from(raw)
            .map_err(|_| CodecError::Other(anyhow!("unexpected Role value {}", raw).into()))
    }
}

impl Display for Role {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Single-byte identifier of an HPKE configuration.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct HpkeConfigId(u8);

impl Display for HpkeConfigId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", u8::from(*self))
    }
}

impl Encode for HpkeConfigId {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        u8::from(*self).encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        u8::from(*self).encoded_len()
    }
}

impl Decode for HpkeConfigId {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        u8::decode(bytes).map(HpkeConfigId)
    }
}

impl From<u8> for HpkeConfigId {
    fn from(value: u8) -> HpkeConfigId {
        Self(value)
    }
}

impl From<HpkeConfigId> for u8 {
    fn from(id: HpkeConfigId) -> u8 {
        let HpkeConfigId(raw) = id;
        raw
    }
}

impl Distribution<HpkeConfigId> for Standard {
    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> HpkeConfigId {
        let raw: u8 = rng.gen();
        HpkeConfigId::from(raw)
    }
}
#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct TaskId([u8; Self::LEN]);
impl TaskId {
pub const LEN: usize = 32;
}
impl Debug for TaskId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
"TaskId({})",
Base64Display::new(&self.0, &URL_SAFE_NO_PAD)
)
}
}
impl Display for TaskId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", Base64Display::new(&self.0, &URL_SAFE_NO_PAD))
}
}
impl Encode for TaskId {
fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
bytes.extend_from_slice(&self.0);
Ok(())
}
fn encoded_len(&self) -> Option<usize> {
Some(Self::LEN)
}
}
impl Decode for TaskId {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
let mut decoded = [0u8; Self::LEN];
bytes.read_exact(&mut decoded)?;
Ok(Self(decoded))
}
}
impl From<[u8; Self::LEN]> for TaskId {
fn from(task_id: [u8; Self::LEN]) -> Self {
Self(task_id)
}
}
impl<'a> TryFrom<&'a [u8]> for TaskId {
type Error = Error;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
Ok(Self(value.try_into().map_err(|_| {
Error::InvalidParameter("byte slice has incorrect length for TaskId")
})?))
}
}
impl AsRef<[u8; Self::LEN]> for TaskId {
fn as_ref(&self) -> &[u8; Self::LEN] {
&self.0
}
}
impl FromStr for TaskId {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::try_from(URL_SAFE_NO_PAD.decode(s)?.as_ref())
}
}
impl Distribution<TaskId> for Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> TaskId {
TaskId(rng.gen())
}
}
impl Serialize for TaskId {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let encoded = URL_SAFE_NO_PAD.encode(self.as_ref());
serializer.serialize_str(&encoded)
}
}
struct TaskIdVisitor;
impl<'de> Visitor<'de> for TaskIdVisitor {
type Value = TaskId;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a base64url-encoded string that decodes to 32 bytes")
}
fn visit_str<E>(self, value: &str) -> Result<TaskId, E>
where
E: de::Error,
{
let decoded = URL_SAFE_NO_PAD
.decode(value)
.map_err(|_| E::custom("invalid base64url value"))?;
TaskId::try_from(decoded.as_slice()).map_err(|e| E::custom(e))
}
}
impl<'de> Deserialize<'de> for TaskId {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_str(TaskIdVisitor)
}
}
/// HPKE KEM algorithm identifier (u16 on the wire).
///
/// Unrecognized code points are kept in `Other`, so decoding never fails on an unknown value.
#[derive(
    Clone, Copy, Debug, PartialEq, Eq, FromPrimitive, IntoPrimitive, Serialize, Deserialize,
)]
#[repr(u16)]
#[non_exhaustive]
pub enum HpkeKemId {
    P256HkdfSha256 = 0x0010,
    P384HkdfSha384 = 0x0011,
    P521HkdfSha512 = 0x0012,
    X25519HkdfSha256 = 0x0020,
    X448HkdfSha512 = 0x0021,
    #[num_enum(catch_all)]
    Other(u16),
}

impl Encode for HpkeKemId {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        let code_point: u16 = (*self).into();
        code_point.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        Some(2)
    }
}

impl Decode for HpkeKemId {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        u16::decode(bytes).map(HpkeKemId::from)
    }
}
/// HPKE KDF algorithm identifier (u16 on the wire).
///
/// Unrecognized code points are kept in `Other`, so decoding never fails on an unknown value.
#[derive(
    Clone, Copy, Debug, PartialEq, Eq, FromPrimitive, IntoPrimitive, Serialize, Deserialize,
)]
#[repr(u16)]
#[non_exhaustive]
pub enum HpkeKdfId {
    HkdfSha256 = 0x0001,
    HkdfSha384 = 0x0002,
    HkdfSha512 = 0x0003,
    #[num_enum(catch_all)]
    Other(u16),
}

impl Encode for HpkeKdfId {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        let code_point: u16 = (*self).into();
        code_point.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        Some(2)
    }
}

impl Decode for HpkeKdfId {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        u16::decode(bytes).map(HpkeKdfId::from)
    }
}
/// HPKE AEAD algorithm identifier (u16 on the wire).
///
/// Unrecognized code points are kept in `Other`, so decoding never fails on an unknown value.
#[derive(
    Clone, Copy, Debug, PartialEq, Eq, FromPrimitive, IntoPrimitive, Serialize, Deserialize,
)]
#[repr(u16)]
#[non_exhaustive]
pub enum HpkeAeadId {
    Aes128Gcm = 0x0001,
    Aes256Gcm = 0x0002,
    ChaCha20Poly1305 = 0x0003,
    #[num_enum(catch_all)]
    Other(u16),
}

impl Encode for HpkeAeadId {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        let code_point: u16 = (*self).into();
        code_point.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        Some(2)
    }
}

impl Decode for HpkeAeadId {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        u16::decode(bytes).map(HpkeAeadId::from)
    }
}
/// A typed extension carrying opaque data.
///
/// Wire format: the extension type, then the data with a u16 length prefix.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Extension {
    extension_type: ExtensionType,
    extension_data: Vec<u8>,
}

impl Extension {
    /// Creates an extension of the given type carrying `extension_data`.
    pub fn new(extension_type: ExtensionType, extension_data: Vec<u8>) -> Extension {
        Self {
            extension_data,
            extension_type,
        }
    }

    /// Returns the extension's type.
    pub fn extension_type(&self) -> &ExtensionType {
        &self.extension_type
    }

    /// Returns the extension's opaque payload.
    pub fn extension_data(&self) -> &[u8] {
        &self.extension_data
    }
}

impl Encode for Extension {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.extension_type.encode(bytes)?;
        encode_u16_items(bytes, &(), &self.extension_data)
    }

    fn encoded_len(&self) -> Option<usize> {
        let type_len = self.extension_type.encoded_len()?;
        let data_len = 2 + self.extension_data.len();
        Some(type_len + data_len)
    }
}

impl Decode for Extension {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        let extension_type = ExtensionType::decode(bytes)?;
        let extension_data = decode_u16_items(&(), bytes)?;
        Ok(Extension::new(extension_type, extension_data))
    }
}
/// Extension type code point (u16 on the wire).
///
/// Decoding fails on values that do not correspond to a listed variant.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, TryFromPrimitive)]
#[repr(u16)]
#[non_exhaustive]
pub enum ExtensionType {
    Tbd = 0,
    Taskprov = 0xFF00,
}

impl Encode for ExtensionType {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        let code_point = *self as u16;
        code_point.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        Some(2)
    }
}

impl Decode for ExtensionType {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        let raw = u16::decode(bytes)?;
        ExtensionType::try_from(raw).map_err(|_| {
            CodecError::Other(anyhow!("unexpected ExtensionType value {}", raw).into())
        })
    }
}
/// An HPKE ciphertext tagged with the ID of the HPKE configuration it was produced under.
///
/// Wire format: config ID, encapsulated key (u16 length prefix), payload (u32 length prefix).
/// `Debug` output omits the encapsulated key and the payload.
#[derive(Clone, Derivative, Eq, PartialEq)]
#[derivative(Debug)]
pub struct HpkeCiphertext {
    config_id: HpkeConfigId,
    #[derivative(Debug = "ignore")]
    encapsulated_key: Vec<u8>,
    #[derivative(Debug = "ignore")]
    payload: Vec<u8>,
}

impl HpkeCiphertext {
    /// Assembles a ciphertext from its parts.
    pub fn new(
        config_id: HpkeConfigId,
        encapsulated_key: Vec<u8>,
        payload: Vec<u8>,
    ) -> HpkeCiphertext {
        Self {
            payload,
            encapsulated_key,
            config_id,
        }
    }

    /// Returns the HPKE configuration ID.
    pub fn config_id(&self) -> &HpkeConfigId {
        &self.config_id
    }

    /// Returns the encapsulated key bytes.
    pub fn encapsulated_key(&self) -> &[u8] {
        &self.encapsulated_key
    }

    /// Returns the encrypted payload bytes.
    pub fn payload(&self) -> &[u8] {
        &self.payload
    }
}

impl Encode for HpkeCiphertext {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.config_id.encode(bytes)?;
        encode_u16_items(bytes, &(), &self.encapsulated_key)?;
        encode_u32_items(bytes, &(), &self.payload)
    }

    fn encoded_len(&self) -> Option<usize> {
        let config_id_len = self.config_id.encoded_len()?;
        let key_len = 2 + self.encapsulated_key.len();
        let payload_len = 4 + self.payload.len();
        Some(config_id_len + key_len + payload_len)
    }
}

impl Decode for HpkeCiphertext {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        let config_id = HpkeConfigId::decode(bytes)?;
        let encapsulated_key = decode_u16_items(&(), bytes)?;
        let payload = decode_u32_items(&(), bytes)?;
        Ok(HpkeCiphertext::new(config_id, encapsulated_key, payload))
    }
}
/// An HPKE public key as raw bytes.
///
/// The wire format uses a u16 length prefix; `Debug`, `Display`, `FromStr` and serde use
/// unpadded base64url.
#[derive(Clone, PartialEq, Eq)]
pub struct HpkePublicKey(Vec<u8>);

impl From<Vec<u8>> for HpkePublicKey {
    fn from(key: Vec<u8>) -> Self {
        HpkePublicKey(key)
    }
}

impl AsRef<[u8]> for HpkePublicKey {
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}

impl Encode for HpkePublicKey {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        encode_u16_items(bytes, &(), &self.0)
    }

    fn encoded_len(&self) -> Option<usize> {
        Some(2 + self.0.len())
    }
}

impl Decode for HpkePublicKey {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        decode_u16_items(&(), bytes).map(HpkePublicKey)
    }
}

impl Display for HpkePublicKey {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{}", Base64Display::new(&self.0, &URL_SAFE_NO_PAD))
    }
}

impl Debug for HpkePublicKey {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "HpkePublicKey({})", Base64Display::new(&self.0, &URL_SAFE_NO_PAD))
    }
}

impl FromStr for HpkePublicKey {
    type Err = Error;

    /// Decodes an unpadded base64url string; any byte length is accepted.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let key = URL_SAFE_NO_PAD.decode(s)?;
        Ok(HpkePublicKey(key))
    }
}

impl Serialize for HpkePublicKey {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&URL_SAFE_NO_PAD.encode(&self.0))
    }
}

/// Serde visitor accepting any unpadded base64url string.
struct HpkePublicKeyVisitor;

impl<'de> Visitor<'de> for HpkePublicKeyVisitor {
    type Value = HpkePublicKey;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a base64url-encoded string")
    }

    fn visit_str<E>(self, value: &str) -> Result<HpkePublicKey, E>
    where
        E: de::Error,
    {
        URL_SAFE_NO_PAD
            .decode(value)
            .map(HpkePublicKey::from)
            .map_err(|_| E::custom("invalid base64url value"))
    }
}

impl<'de> Deserialize<'de> for HpkePublicKey {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        deserializer.deserialize_str(HpkePublicKeyVisitor)
    }
}
/// An HPKE configuration: ID, KEM/KDF/AEAD algorithm identifiers, and public key.
///
/// Wire format: the five fields in declaration order.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct HpkeConfig {
    id: HpkeConfigId,
    kem_id: HpkeKemId,
    kdf_id: HpkeKdfId,
    aead_id: HpkeAeadId,
    public_key: HpkePublicKey,
}

impl HpkeConfig {
    /// Assembles a configuration from its parts.
    pub fn new(
        id: HpkeConfigId,
        kem_id: HpkeKemId,
        kdf_id: HpkeKdfId,
        aead_id: HpkeAeadId,
        public_key: HpkePublicKey,
    ) -> HpkeConfig {
        Self {
            public_key,
            aead_id,
            kdf_id,
            kem_id,
            id,
        }
    }

    /// Returns the configuration ID.
    pub fn id(&self) -> &HpkeConfigId {
        &self.id
    }

    /// Returns the KEM identifier.
    pub fn kem_id(&self) -> &HpkeKemId {
        &self.kem_id
    }

    /// Returns the KDF identifier.
    pub fn kdf_id(&self) -> &HpkeKdfId {
        &self.kdf_id
    }

    /// Returns the AEAD identifier.
    pub fn aead_id(&self) -> &HpkeAeadId {
        &self.aead_id
    }

    /// Returns the public key.
    pub fn public_key(&self) -> &HpkePublicKey {
        &self.public_key
    }
}

impl Encode for HpkeConfig {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.id.encode(bytes)?;
        self.kem_id.encode(bytes)?;
        self.kdf_id.encode(bytes)?;
        self.aead_id.encode(bytes)?;
        self.public_key.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        let algorithm_ids_len = self.kem_id.encoded_len()?
            + self.kdf_id.encoded_len()?
            + self.aead_id.encoded_len()?;
        Some(self.id.encoded_len()? + algorithm_ids_len + self.public_key.encoded_len()?)
    }
}

impl Decode for HpkeConfig {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        // Fields must be read in wire order.
        let id = HpkeConfigId::decode(bytes)?;
        let kem_id = HpkeKemId::decode(bytes)?;
        let kdf_id = HpkeKdfId::decode(bytes)?;
        let aead_id = HpkeAeadId::decode(bytes)?;
        let public_key = HpkePublicKey::decode(bytes)?;
        Ok(HpkeConfig::new(id, kem_id, kdf_id, aead_id, public_key))
    }
}
/// A list of HPKE configurations, encoded with a u16 byte-length prefix.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HpkeConfigList(Vec<HpkeConfig>);

impl HpkeConfigList {
    /// Media type for an encoded `HpkeConfigList`.
    pub const MEDIA_TYPE: &'static str = "application/dap-hpke-config-list";

    /// Wraps the given configurations.
    pub fn new(hpke_configs: Vec<HpkeConfig>) -> Self {
        HpkeConfigList(hpke_configs)
    }

    /// Returns the contained configurations.
    pub fn hpke_configs(&self) -> &[HpkeConfig] {
        &self.0
    }
}

impl Encode for HpkeConfigList {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        encode_u16_items(bytes, &(), self.hpke_configs())
    }

    fn encoded_len(&self) -> Option<usize> {
        // 2-byte prefix plus each config; `None` if any config's length is unknown.
        self.0
            .iter()
            .try_fold(2, |total, config| Some(total + config.encoded_len()?))
    }
}

impl Decode for HpkeConfigList {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        decode_u16_items(&(), bytes).map(HpkeConfigList::new)
    }
}
/// Report metadata: the report ID followed by the report timestamp.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ReportMetadata {
    report_id: ReportId,
    time: Time,
}

impl ReportMetadata {
    /// Creates metadata for the given report ID and time.
    pub fn new(report_id: ReportId, time: Time) -> Self {
        ReportMetadata { time, report_id }
    }

    /// Returns the report ID.
    pub fn id(&self) -> &ReportId {
        &self.report_id
    }

    /// Returns the report timestamp.
    pub fn time(&self) -> &Time {
        &self.time
    }
}

impl Encode for ReportMetadata {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.report_id.encode(bytes)?;
        self.time.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        let id_len = self.report_id.encoded_len()?;
        let time_len = self.time.encoded_len()?;
        Some(id_len + time_len)
    }
}

impl Decode for ReportMetadata {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        let report_id = ReportId::decode(bytes)?;
        let time = Time::decode(bytes)?;
        Ok(ReportMetadata::new(report_id, time))
    }
}
/// An unencrypted input share.
///
/// Wire format: extensions (u16 length prefix), then payload (u32 length prefix).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PlaintextInputShare {
    extensions: Vec<Extension>,
    payload: Vec<u8>,
}

impl PlaintextInputShare {
    /// Creates an input share from its extensions and payload.
    pub fn new(extensions: Vec<Extension>, payload: Vec<u8>) -> Self {
        PlaintextInputShare {
            payload,
            extensions,
        }
    }

    /// Returns the extensions.
    pub fn extensions(&self) -> &[Extension] {
        &self.extensions
    }

    /// Returns the payload bytes.
    pub fn payload(&self) -> &[u8] {
        &self.payload
    }
}

impl Encode for PlaintextInputShare {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        encode_u16_items(bytes, &(), &self.extensions)?;
        encode_u32_items(bytes, &(), &self.payload)
    }

    fn encoded_len(&self) -> Option<usize> {
        let extensions_len = self
            .extensions
            .iter()
            .try_fold(0, |total, extension| Some(total + extension.encoded_len()?))?;
        Some(2 + extensions_len + 4 + self.payload.len())
    }
}

impl Decode for PlaintextInputShare {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        let extensions = decode_u16_items(&(), bytes)?;
        let payload = decode_u32_items(&(), bytes)?;
        Ok(PlaintextInputShare::new(extensions, payload))
    }
}
/// A client report.
///
/// Wire format: metadata, public share (u32 length prefix), then the leader's and the helper's
/// encrypted input shares.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Report {
    metadata: ReportMetadata,
    public_share: Vec<u8>,
    leader_encrypted_input_share: HpkeCiphertext,
    helper_encrypted_input_share: HpkeCiphertext,
}

impl Report {
    /// Media type for an encoded `Report`.
    pub const MEDIA_TYPE: &'static str = "application/dap-report";

    /// Assembles a report from its parts.
    pub fn new(
        metadata: ReportMetadata,
        public_share: Vec<u8>,
        leader_encrypted_input_share: HpkeCiphertext,
        helper_encrypted_input_share: HpkeCiphertext,
    ) -> Self {
        Report {
            helper_encrypted_input_share,
            leader_encrypted_input_share,
            public_share,
            metadata,
        }
    }

    /// Returns the report metadata.
    pub fn metadata(&self) -> &ReportMetadata {
        &self.metadata
    }

    /// Returns the public share bytes.
    pub fn public_share(&self) -> &[u8] {
        &self.public_share
    }

    /// Returns the leader's encrypted input share.
    pub fn leader_encrypted_input_share(&self) -> &HpkeCiphertext {
        &self.leader_encrypted_input_share
    }

    /// Returns the helper's encrypted input share.
    pub fn helper_encrypted_input_share(&self) -> &HpkeCiphertext {
        &self.helper_encrypted_input_share
    }
}

impl Encode for Report {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.metadata.encode(bytes)?;
        encode_u32_items(bytes, &(), &self.public_share)?;
        self.leader_encrypted_input_share.encode(bytes)?;
        self.helper_encrypted_input_share.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        let metadata_len = self.metadata.encoded_len()?;
        let public_share_len = 4 + self.public_share.len();
        let leader_len = self.leader_encrypted_input_share.encoded_len()?;
        let helper_len = self.helper_encrypted_input_share.encoded_len()?;
        Some(metadata_len + public_share_len + leader_len + helper_len)
    }
}

impl Decode for Report {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        // Fields must be read in wire order.
        let metadata = ReportMetadata::decode(bytes)?;
        let public_share = decode_u32_items(&(), bytes)?;
        let leader_share = HpkeCiphertext::decode(bytes)?;
        let helper_share = HpkeCiphertext::decode(bytes)?;
        Ok(Report::new(metadata, public_share, leader_share, helper_share))
    }
}
/// Body of a fixed-size query.
///
/// On the wire a one-byte tag (0 = by batch ID, 1 = current batch) is followed by the batch ID
/// for the `ByBatchId` variant.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FixedSizeQuery {
    ByBatchId { batch_id: BatchId },
    CurrentBatch,
}

impl Encode for FixedSizeQuery {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        match self {
            Self::CurrentBatch => 1u8.encode(bytes),
            Self::ByBatchId { batch_id } => {
                0u8.encode(bytes)?;
                batch_id.encode(bytes)
            }
        }
    }

    fn encoded_len(&self) -> Option<usize> {
        let body_len = match self {
            Self::ByBatchId { batch_id } => batch_id.encoded_len()?,
            Self::CurrentBatch => 0,
        };
        // One tag byte precedes the body.
        Some(1 + body_len)
    }
}

impl Decode for FixedSizeQuery {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        match u8::decode(bytes)? {
            0 => BatchId::decode(bytes).map(|batch_id| Self::ByBatchId { batch_id }),
            1 => Ok(Self::CurrentBatch),
            unknown => Err(CodecError::Other(
                anyhow!("unexpected FixedSizeQueryType value {}", unknown).into(),
            )),
        }
    }
}
/// A query, parameterized by query type `Q`.
///
/// Wire format: `Q::CODE` (counted as 1 byte by `encoded_len`), then the query body.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Query<Q: QueryType> {
    query_body: Q::QueryBody,
}

impl<Q: QueryType> Query<Q> {
    /// Wraps a query body.
    pub fn new(query_body: Q::QueryBody) -> Self {
        Query { query_body }
    }

    /// Returns the query body.
    pub fn query_body(&self) -> &Q::QueryBody {
        &self.query_body
    }
}

impl Query<TimeInterval> {
    /// Creates a time-interval query over `batch_interval`.
    pub fn new_time_interval(batch_interval: Interval) -> Self {
        Query::new(batch_interval)
    }

    /// Returns the queried batch interval.
    pub fn batch_interval(&self) -> &Interval {
        &self.query_body
    }
}

impl Query<FixedSize> {
    /// Creates a fixed-size query.
    pub fn new_fixed_size(fixed_size_query: FixedSizeQuery) -> Self {
        Query::new(fixed_size_query)
    }

    /// Returns the fixed-size query body.
    pub fn fixed_size_query(&self) -> &FixedSizeQuery {
        &self.query_body
    }
}

impl<Q: QueryType> Encode for Query<Q> {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        Q::CODE.encode(bytes)?;
        self.query_body.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        self.query_body.encoded_len().map(|body_len| 1 + body_len)
    }
}

impl<Q: QueryType> Decode for Query<Q> {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        // Reject input whose query-type code does not match `Q`.
        query_type::Code::decode_expecting_value(bytes, Q::CODE)?;
        Q::QueryBody::decode(bytes).map(Query::new)
    }
}
/// A collection request: a query plus an opaque aggregation parameter.
///
/// Wire format: the query, then the aggregation parameter with a u32 length prefix. `Debug`
/// omits the aggregation parameter.
#[derive(Clone, Derivative, PartialEq, Eq)]
#[derivative(Debug)]
pub struct CollectionReq<Q: QueryType> {
    query: Query<Q>,
    #[derivative(Debug = "ignore")]
    aggregation_parameter: Vec<u8>,
}

impl<Q: QueryType> CollectionReq<Q> {
    /// Media type for an encoded `CollectionReq`.
    pub const MEDIA_TYPE: &'static str = "application/dap-collect-req";

    /// Creates a collection request.
    pub fn new(query: Query<Q>, aggregation_parameter: Vec<u8>) -> Self {
        CollectionReq {
            aggregation_parameter,
            query,
        }
    }

    /// Returns the query.
    pub fn query(&self) -> &Query<Q> {
        &self.query
    }

    /// Returns the aggregation parameter bytes.
    pub fn aggregation_parameter(&self) -> &[u8] {
        &self.aggregation_parameter
    }
}

impl<Q: QueryType> Encode for CollectionReq<Q> {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.query.encode(bytes)?;
        encode_u32_items(bytes, &(), &self.aggregation_parameter)
    }

    fn encoded_len(&self) -> Option<usize> {
        let query_len = self.query.encoded_len()?;
        Some(query_len + 4 + self.aggregation_parameter.len())
    }
}

impl<Q: QueryType> Decode for CollectionReq<Q> {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        let query = Query::<Q>::decode(bytes)?;
        let aggregation_parameter = decode_u32_items(&(), bytes)?;
        Ok(CollectionReq::new(query, aggregation_parameter))
    }
}
/// A partial batch selector for query type `Q`.
///
/// Wire format: `Q::CODE` (counted as 1 byte by `encoded_len`), then the partial batch
/// identifier (unit for time-interval, a batch ID for fixed-size).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PartialBatchSelector<Q: QueryType> {
    batch_identifier: Q::PartialBatchIdentifier,
}

impl<Q: QueryType> PartialBatchSelector<Q> {
    /// Wraps a partial batch identifier.
    pub fn new(batch_identifier: Q::PartialBatchIdentifier) -> Self {
        PartialBatchSelector { batch_identifier }
    }

    /// Returns the partial batch identifier.
    pub fn batch_identifier(&self) -> &Q::PartialBatchIdentifier {
        &self.batch_identifier
    }
}

impl PartialBatchSelector<TimeInterval> {
    /// Creates the (identifier-less) selector for time-interval tasks.
    pub fn new_time_interval() -> Self {
        PartialBatchSelector::new(())
    }
}

impl PartialBatchSelector<FixedSize> {
    /// Creates a selector naming the given batch.
    pub fn new_fixed_size(batch_id: BatchId) -> Self {
        PartialBatchSelector::new(batch_id)
    }

    /// Returns the selected batch ID.
    pub fn batch_id(&self) -> &BatchId {
        &self.batch_identifier
    }
}

impl<Q: QueryType> Encode for PartialBatchSelector<Q> {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        Q::CODE.encode(bytes)?;
        self.batch_identifier.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        self.batch_identifier
            .encoded_len()
            .map(|identifier_len| 1 + identifier_len)
    }
}

impl<Q: QueryType> Decode for PartialBatchSelector<Q> {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        // Reject input whose query-type code does not match `Q`.
        query_type::Code::decode_expecting_value(bytes, Q::CODE)?;
        Q::PartialBatchIdentifier::decode(bytes).map(PartialBatchSelector::new)
    }
}
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CollectionJobId([u8; Self::LEN]);
impl CollectionJobId {
pub const LEN: usize = 16;
}
impl AsRef<[u8; Self::LEN]> for CollectionJobId {
fn as_ref(&self) -> &[u8; Self::LEN] {
&self.0
}
}
impl TryFrom<&[u8]> for CollectionJobId {
type Error = Error;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
Ok(Self(value.try_into().map_err(|_| {
Error::InvalidParameter("byte slice has incorrect length for CollectionId")
})?))
}
}
impl FromStr for CollectionJobId {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::try_from(URL_SAFE_NO_PAD.decode(s)?.as_ref())
}
}
impl Debug for CollectionJobId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
"CollectionId({})",
Base64Display::new(&self.0, &URL_SAFE_NO_PAD)
)
}
}
impl Display for CollectionJobId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", Base64Display::new(&self.0, &URL_SAFE_NO_PAD))
}
}
impl Distribution<CollectionJobId> for Standard {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> CollectionJobId {
CollectionJobId(rng.gen())
}
}
/// Collection message (media type `application/dap-collection`). In encoded order it carries a
/// partial batch selector, a report count, an interval, and two encrypted aggregate shares.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Collection<Q: QueryType> {
    partial_batch_selector: PartialBatchSelector<Q>,
    report_count: u64,
    interval: Interval,
    leader_encrypted_agg_share: HpkeCiphertext,
    helper_encrypted_agg_share: HpkeCiphertext,
}

impl<Q: QueryType> Collection<Q> {
    /// HTTP media type for this message.
    pub const MEDIA_TYPE: &'static str = "application/dap-collection";

    /// Assembles a collection message from its parts.
    pub fn new(
        partial_batch_selector: PartialBatchSelector<Q>,
        report_count: u64,
        interval: Interval,
        leader_encrypted_agg_share: HpkeCiphertext,
        helper_encrypted_agg_share: HpkeCiphertext,
    ) -> Self {
        Self {
            partial_batch_selector,
            report_count,
            interval,
            leader_encrypted_agg_share,
            helper_encrypted_agg_share,
        }
    }

    /// Borrows the partial batch selector.
    pub fn partial_batch_selector(&self) -> &PartialBatchSelector<Q> {
        &self.partial_batch_selector
    }

    /// Returns the report count carried in the message.
    pub fn report_count(&self) -> u64 {
        self.report_count
    }

    /// Borrows the interval carried in the message.
    pub fn interval(&self) -> &Interval {
        &self.interval
    }

    /// Borrows the leader's encrypted aggregate share.
    pub fn leader_encrypted_aggregate_share(&self) -> &HpkeCiphertext {
        &self.leader_encrypted_agg_share
    }

    /// Borrows the helper's encrypted aggregate share.
    pub fn helper_encrypted_aggregate_share(&self) -> &HpkeCiphertext {
        &self.helper_encrypted_agg_share
    }
}

impl<Q: QueryType> Encode for Collection<Q> {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.partial_batch_selector.encode(bytes)?;
        self.report_count.encode(bytes)?;
        self.interval.encode(bytes)?;
        self.leader_encrypted_agg_share.encode(bytes)?;
        self.helper_encrypted_agg_share.encode(bytes)?;
        Ok(())
    }

    fn encoded_len(&self) -> Option<usize> {
        // Summing `Option<usize>` yields `None` as soon as any field lacks a length hint.
        [
            self.partial_batch_selector.encoded_len(),
            self.report_count.encoded_len(),
            self.interval.encoded_len(),
            self.leader_encrypted_agg_share.encoded_len(),
            self.helper_encrypted_agg_share.encoded_len(),
        ]
        .iter()
        .copied()
        .sum()
    }
}

impl<Q: QueryType> Decode for Collection<Q> {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        // Field initializers run top to bottom, matching the encoded order.
        Ok(Self {
            partial_batch_selector: PartialBatchSelector::decode(bytes)?,
            report_count: u64::decode(bytes)?,
            interval: Interval::decode(bytes)?,
            leader_encrypted_agg_share: HpkeCiphertext::decode(bytes)?,
            helper_encrypted_agg_share: HpkeCiphertext::decode(bytes)?,
        })
    }
}
/// Associated data for an input share: task ID, report metadata, and the public share, encoded
/// in that order (the public share with a 32-bit length prefix).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct InputShareAad {
    task_id: TaskId,
    metadata: ReportMetadata,
    public_share: Vec<u8>,
}

impl InputShareAad {
    /// Builds the AAD from its three components.
    pub fn new(task_id: TaskId, metadata: ReportMetadata, public_share: Vec<u8>) -> Self {
        Self {
            task_id,
            metadata,
            public_share,
        }
    }

    /// Borrows the task ID.
    pub fn task_id(&self) -> &TaskId {
        &self.task_id
    }

    /// Borrows the report metadata.
    pub fn metadata(&self) -> &ReportMetadata {
        &self.metadata
    }

    /// Borrows the opaque public share bytes.
    pub fn public_share(&self) -> &[u8] {
        &self.public_share
    }
}

impl Encode for InputShareAad {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.task_id.encode(bytes)?;
        self.metadata.encode(bytes)?;
        encode_u32_items(bytes, &(), &self.public_share)
    }

    fn encoded_len(&self) -> Option<usize> {
        let header_len = self.task_id.encoded_len()? + self.metadata.encoded_len()?;
        // 4-byte length prefix, then the raw public share.
        Some(header_len + 4 + self.public_share.len())
    }
}

impl Decode for InputShareAad {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        Ok(Self {
            task_id: TaskId::decode(bytes)?,
            metadata: ReportMetadata::decode(bytes)?,
            public_share: decode_u32_items(&(), bytes)?,
        })
    }
}
/// Associated data for an aggregate share: task ID, aggregation parameter (32-bit length
/// prefixed), and batch selector, encoded in that order.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AggregateShareAad<Q: QueryType> {
    task_id: TaskId,
    aggregation_parameter: Vec<u8>,
    batch_selector: BatchSelector<Q>,
}

impl<Q: QueryType> AggregateShareAad<Q> {
    /// Builds the AAD from its three components.
    pub fn new(
        task_id: TaskId,
        aggregation_parameter: Vec<u8>,
        batch_selector: BatchSelector<Q>,
    ) -> Self {
        Self {
            task_id,
            aggregation_parameter,
            batch_selector,
        }
    }

    /// Borrows the task ID.
    pub fn task_id(&self) -> &TaskId {
        &self.task_id
    }

    /// Borrows the opaque aggregation parameter bytes.
    pub fn aggregation_parameter(&self) -> &[u8] {
        &self.aggregation_parameter
    }

    /// Borrows the batch selector.
    pub fn batch_selector(&self) -> &BatchSelector<Q> {
        &self.batch_selector
    }
}

impl<Q: QueryType> Encode for AggregateShareAad<Q> {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.task_id.encode(bytes)?;
        encode_u32_items(bytes, &(), &self.aggregation_parameter)?;
        self.batch_selector.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        let task_id_len = self.task_id.encoded_len()?;
        // 4-byte length prefix plus the raw aggregation parameter.
        let agg_param_len = 4 + self.aggregation_parameter.len();
        let selector_len = self.batch_selector.encoded_len()?;
        Some(task_id_len + agg_param_len + selector_len)
    }
}

impl<Q: QueryType> Decode for AggregateShareAad<Q> {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        Ok(Self {
            task_id: TaskId::decode(bytes)?,
            aggregation_parameter: decode_u32_items(&(), bytes)?,
            batch_selector: BatchSelector::decode(bytes)?,
        })
    }
}
/// A report's metadata, public share, and one encrypted input share. The public share is
/// omitted from `Debug` output.
#[derive(Derivative, Clone, PartialEq, Eq)]
#[derivative(Debug)]
pub struct ReportShare {
    metadata: ReportMetadata,
    #[derivative(Debug = "ignore")]
    public_share: Vec<u8>,
    encrypted_input_share: HpkeCiphertext,
}

impl ReportShare {
    /// Assembles a report share from its parts.
    pub fn new(
        metadata: ReportMetadata,
        public_share: Vec<u8>,
        encrypted_input_share: HpkeCiphertext,
    ) -> Self {
        Self {
            metadata,
            public_share,
            encrypted_input_share,
        }
    }

    /// Borrows the report metadata.
    pub fn metadata(&self) -> &ReportMetadata {
        &self.metadata
    }

    /// Borrows the opaque public share bytes.
    pub fn public_share(&self) -> &[u8] {
        &self.public_share
    }

    /// Borrows the encrypted input share.
    pub fn encrypted_input_share(&self) -> &HpkeCiphertext {
        &self.encrypted_input_share
    }
}

impl Encode for ReportShare {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.metadata.encode(bytes)?;
        // The public share is written with a 32-bit length prefix.
        encode_u32_items(bytes, &(), &self.public_share)?;
        self.encrypted_input_share.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        let metadata_len = self.metadata.encoded_len()?;
        let public_share_len = 4 + self.public_share.len();
        let ciphertext_len = self.encrypted_input_share.encoded_len()?;
        Some(metadata_len + public_share_len + ciphertext_len)
    }
}

impl Decode for ReportShare {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        Ok(Self {
            metadata: ReportMetadata::decode(bytes)?,
            public_share: decode_u32_items(&(), bytes)?,
            encrypted_input_share: HpkeCiphertext::decode(bytes)?,
        })
    }
}
/// Initial preparation request for one report: the report share plus the first ping-pong
/// message, which is embedded as a 32-bit length-prefixed byte string.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PrepareInit {
    report_share: ReportShare,
    message: PingPongMessage,
}

impl PrepareInit {
    /// Pairs a report share with its initial ping-pong message.
    pub fn new(report_share: ReportShare, message: PingPongMessage) -> Self {
        Self {
            report_share,
            message,
        }
    }

    /// Borrows the report share.
    pub fn report_share(&self) -> &ReportShare {
        &self.report_share
    }

    /// Borrows the ping-pong message.
    pub fn message(&self) -> &PingPongMessage {
        &self.message
    }
}

impl Encode for PrepareInit {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.report_share.encode(bytes)?;
        // Serialize the message on its own, then embed it with a length prefix.
        encode_u32_items(bytes, &(), &self.message.get_encoded()?)
    }

    fn encoded_len(&self) -> Option<usize> {
        self.report_share
            .encoded_len()
            .zip(self.message.encoded_len())
            .map(|(share_len, message_len)| share_len + 4 + message_len)
    }
}

impl Decode for PrepareInit {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        let report_share = ReportShare::decode(bytes)?;
        let encoded_message: Vec<u8> = decode_u32_items(&(), bytes)?;
        let message = PingPongMessage::get_decoded(&encoded_message)?;
        Ok(Self::new(report_share, message))
    }
}
/// Per-report preparation response: the report ID followed by the step result.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PrepareResp {
    report_id: ReportId,
    result: PrepareStepResult,
}

impl PrepareResp {
    /// Pairs a report ID with its preparation step result.
    pub fn new(report_id: ReportId, result: PrepareStepResult) -> Self {
        Self { report_id, result }
    }

    /// Borrows the report ID.
    pub fn report_id(&self) -> &ReportId {
        &self.report_id
    }

    /// Borrows the step result.
    pub fn result(&self) -> &PrepareStepResult {
        &self.result
    }
}

impl Encode for PrepareResp {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.report_id.encode(bytes)?;
        self.result.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        let id_len = self.report_id.encoded_len()?;
        let result_len = self.result.encoded_len()?;
        Some(id_len + result_len)
    }
}

impl Decode for PrepareResp {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        // Call arguments are evaluated left to right, so the report ID is read first.
        Ok(Self::new(
            ReportId::decode(bytes)?,
            PrepareStepResult::decode(bytes)?,
        ))
    }
}
/// Result of a preparation step for one report. On the wire it is a one-byte tag
/// (0 = `Continue`, 1 = `Finished`, 2 = `Reject`) followed by the variant's payload.
#[derive(Clone, Derivative, PartialEq, Eq)]
#[derivative(Debug)]
pub enum PrepareStepResult {
    /// Carries the next ping-pong message, which is omitted from `Debug` output.
    Continue {
        #[derivative(Debug = "ignore")]
        message: PingPongMessage,
    },
    /// No payload.
    Finished,
    /// Carries the rejection reason.
    Reject(PrepareError),
}

impl Encode for PrepareStepResult {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        match self {
            Self::Continue { message } => {
                0u8.encode(bytes)?;
                // The message is embedded as a 32-bit length-prefixed byte string.
                encode_u32_items(bytes, &(), &message.get_encoded()?)
            }
            Self::Finished => 1u8.encode(bytes),
            Self::Reject(error) => {
                2u8.encode(bytes)?;
                error.encode(bytes)
            }
        }
    }

    fn encoded_len(&self) -> Option<usize> {
        let payload_len = match self {
            Self::Continue { message } => 4 + message.encoded_len()?,
            Self::Finished => 0,
            Self::Reject(error) => error.encoded_len()?,
        };
        // Every variant begins with its one-byte tag.
        Some(1 + payload_len)
    }
}

impl Decode for PrepareStepResult {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        match u8::decode(bytes)? {
            0 => {
                let encoded_message: Vec<u8> = decode_u32_items(&(), bytes)?;
                let message = PingPongMessage::get_decoded(&encoded_message)?;
                Ok(Self::Continue { message })
            }
            1 => Ok(Self::Finished),
            2 => PrepareError::decode(bytes).map(Self::Reject),
            _ => Err(CodecError::UnexpectedValue),
        }
    }
}
/// Rejection reason carried by [`PrepareStepResult::Reject`], encoded as a single byte holding
/// the variant's discriminant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, TryFromPrimitive)]
#[repr(u8)]
pub enum PrepareError {
    BatchCollected = 0,
    ReportReplayed = 1,
    ReportDropped = 2,
    HpkeUnknownConfigId = 3,
    HpkeDecryptError = 4,
    VdafPrepError = 5,
    BatchSaturated = 6,
    TaskExpired = 7,
    InvalidMessage = 8,
    ReportTooEarly = 9,
}

impl Encode for PrepareError {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        // `#[repr(u8)]` on a fieldless enum makes this cast lossless.
        (*self as u8).encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        Some(1)
    }
}

impl Decode for PrepareError {
    /// Decodes one byte and maps it to the matching variant.
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::Other`] if the byte is not a known discriminant.
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        let val = u8::decode(bytes)?;
        Self::try_from(val).map_err(|_| {
            // Name this type; the previous message referred to a nonexistent
            // "ReportShareError".
            CodecError::Other(anyhow!("unexpected PrepareError value {}", val).into())
        })
    }
}
/// Continuation of preparation for one report: the report ID plus the next ping-pong message,
/// which is embedded as a 32-bit length-prefixed byte string.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PrepareContinue {
    report_id: ReportId,
    message: PingPongMessage,
}

impl PrepareContinue {
    /// Pairs a report ID with its next ping-pong message.
    pub fn new(report_id: ReportId, message: PingPongMessage) -> Self {
        Self { report_id, message }
    }

    /// Borrows the report ID.
    pub fn report_id(&self) -> &ReportId {
        &self.report_id
    }

    /// Borrows the ping-pong message.
    pub fn message(&self) -> &PingPongMessage {
        &self.message
    }
}

impl Encode for PrepareContinue {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.report_id.encode(bytes)?;
        encode_u32_items(bytes, &(), &self.message.get_encoded()?)
    }

    fn encoded_len(&self) -> Option<usize> {
        self.report_id
            .encoded_len()
            .zip(self.message.encoded_len())
            .map(|(id_len, message_len)| id_len + 4 + message_len)
    }
}

impl Decode for PrepareContinue {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        let report_id = ReportId::decode(bytes)?;
        let encoded_message: Vec<u8> = decode_u32_items(&(), bytes)?;
        PingPongMessage::get_decoded(&encoded_message).map(|message| Self::new(report_id, message))
    }
}
/// Identifier of an aggregation job: [`AggregationJobId::LEN`] opaque bytes, displayed and
/// parsed as unpadded URL-safe base64.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct AggregationJobId([u8; Self::LEN]);

impl AggregationJobId {
    /// Length in bytes of an aggregation job ID.
    pub const LEN: usize = 16;
}

impl From<[u8; Self::LEN]> for AggregationJobId {
    fn from(aggregation_job_id: [u8; Self::LEN]) -> Self {
        Self(aggregation_job_id)
    }
}

impl TryFrom<&[u8]> for AggregationJobId {
    type Error = Error;

    /// Builds an aggregation job ID from a slice of exactly [`AggregationJobId::LEN`] bytes.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidParameter`] on any other slice length.
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        <[u8; Self::LEN]>::try_from(value).map(Self).map_err(|_| {
            Error::InvalidParameter("byte slice has incorrect length for AggregationJobId")
        })
    }
}

impl AsRef<[u8; Self::LEN]> for AggregationJobId {
    fn as_ref(&self) -> &[u8; Self::LEN] {
        &self.0
    }
}

impl FromStr for AggregationJobId {
    type Err = Error;

    /// Parses the unpadded URL-safe base64 form produced by `Display`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let raw = URL_SAFE_NO_PAD.decode(s)?;
        Self::try_from(raw.as_slice())
    }
}

impl Debug for AggregationJobId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Reuse the base64 `Display` rendering inside the type-name wrapper.
        write!(f, "AggregationJobId({self})")
    }
}

impl Display for AggregationJobId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", Base64Display::new(&self.0, &URL_SAFE_NO_PAD))
    }
}

impl Distribution<AggregationJobId> for Standard {
    /// Samples a uniformly random aggregation job ID.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> AggregationJobId {
        let raw: [u8; AggregationJobId::LEN] = rng.gen();
        AggregationJobId::from(raw)
    }
}
#[derive(Clone, Derivative, PartialEq, Eq)]
#[derivative(Debug)]
pub struct AggregationJobInitializeReq<Q: QueryType> {
#[derivative(Debug = "ignore")]
aggregation_parameter: Vec<u8>,
partial_batch_selector: PartialBatchSelector<Q>,
prepare_inits: Vec<PrepareInit>,
}
impl<Q: QueryType> AggregationJobInitializeReq<Q> {
pub const MEDIA_TYPE: &'static str = "application/dap-aggregation-job-init-req";
pub fn new(
aggregation_parameter: Vec<u8>,
partial_batch_selector: PartialBatchSelector<Q>,
prepare_inits: Vec<PrepareInit>,
) -> Self {
Self {
aggregation_parameter,
partial_batch_selector,
prepare_inits,
}
}
pub fn aggregation_parameter(&self) -> &[u8] {
&self.aggregation_parameter
}
pub fn batch_selector(&self) -> &PartialBatchSelector<Q> {
&self.partial_batch_selector
}
pub fn prepare_inits(&self) -> &[PrepareInit] {
&self.prepare_inits
}
}
impl<Q: QueryType> Encode for AggregationJobInitializeReq<Q> {
fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
encode_u32_items(bytes, &(), &self.aggregation_parameter)?;
self.partial_batch_selector.encode(bytes)?;
encode_u32_items(bytes, &(), &self.prepare_inits)
}
fn encoded_len(&self) -> Option<usize> {
let mut length = 4 + self.aggregation_parameter.len();
length += self.partial_batch_selector.encoded_len()?;
length += 4;
for prepare_init in &self.prepare_inits {
length += prepare_init.encoded_len()?;
}
Some(length)
}
}
impl<Q: QueryType> Decode for AggregationJobInitializeReq<Q> {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
let aggregation_parameter = decode_u32_items(&(), bytes)?;
let partial_batch_selector = PartialBatchSelector::decode(bytes)?;
let prepare_inits = decode_u32_items(&(), bytes)?;
Ok(Self {
aggregation_parameter,
partial_batch_selector,
prepare_inits,
})
}
}
/// Step counter of an aggregation job, encoded on the wire as a `u16`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct AggregationJobStep(u16);

impl AggregationJobStep {
    /// Returns the step immediately following this one.
    ///
    /// # Panics
    ///
    /// Panics if this step is already `u16::MAX`. Plain `+` would panic only when overflow
    /// checks are enabled and could otherwise silently wrap back to step 0. The explicit check
    /// makes every build profile behave the same way.
    pub fn increment(&self) -> Self {
        let next = self
            .0
            .checked_add(1)
            .expect("aggregation job step overflowed u16::MAX");
        Self(next)
    }
}

impl Display for AggregationJobStep {
    /// Writes the step number in decimal.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl Encode for AggregationJobStep {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.0.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        self.0.encoded_len()
    }
}

impl Decode for AggregationJobStep {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        Ok(Self(u16::decode(bytes)?))
    }
}

impl From<u16> for AggregationJobStep {
    fn from(value: u16) -> Self {
        Self(value)
    }
}

impl From<AggregationJobStep> for u16 {
    fn from(value: AggregationJobStep) -> Self {
        value.0
    }
}

impl TryFrom<i32> for AggregationJobStep {
    type Error = TryFromIntError;

    /// Converts from `i32`.
    ///
    /// # Errors
    ///
    /// Fails if `value` is negative or greater than `u16::MAX`.
    fn try_from(value: i32) -> Result<Self, Self::Error> {
        Ok(AggregationJobStep(u16::try_from(value)?))
    }
}
/// Aggregation job continuation request: the step number followed by a 32-bit
/// length-prefixed list of [`PrepareContinue`]s.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AggregationJobContinueReq {
    step: AggregationJobStep,
    prepare_continues: Vec<PrepareContinue>,
}

impl AggregationJobContinueReq {
    /// HTTP media type for this message.
    pub const MEDIA_TYPE: &'static str = "application/dap-aggregation-job-continue-req";

    /// Assembles a continuation request for the given step.
    pub fn new(step: AggregationJobStep, prepare_continues: Vec<PrepareContinue>) -> Self {
        Self {
            step,
            prepare_continues,
        }
    }

    /// Returns the step number.
    pub fn step(&self) -> AggregationJobStep {
        self.step
    }

    /// Borrows the per-report continuation messages.
    pub fn prepare_steps(&self) -> &[PrepareContinue] {
        &self.prepare_continues
    }
}

impl Encode for AggregationJobContinueReq {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.step.encode(bytes)?;
        encode_u32_items(bytes, &(), &self.prepare_continues)
    }

    fn encoded_len(&self) -> Option<usize> {
        // Start from the step plus the list's 4-byte length prefix, then add each element.
        let initial = self.step.encoded_len()? + 4;
        self.prepare_continues
            .iter()
            .try_fold(initial, |total, prepare_continue| {
                Some(total + prepare_continue.encoded_len()?)
            })
    }
}

impl Decode for AggregationJobContinueReq {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        // Arguments evaluate left to right: step first, then the list.
        Ok(Self::new(
            AggregationJobStep::decode(bytes)?,
            decode_u32_items(&(), bytes)?,
        ))
    }
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AggregationJobResp {
prepare_resps: Vec<PrepareResp>,
}
impl AggregationJobResp {
pub const MEDIA_TYPE: &'static str = "application/dap-aggregation-job-resp";
pub fn new(prepare_resps: Vec<PrepareResp>) -> Self {
Self { prepare_resps }
}
pub fn prepare_resps(&self) -> &[PrepareResp] {
&self.prepare_resps
}
}
impl Encode for AggregationJobResp {
fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
encode_u32_items(bytes, &(), &self.prepare_resps)
}
fn encoded_len(&self) -> Option<usize> {
let mut length = 4;
for prepare_resp in self.prepare_resps.iter() {
length += prepare_resp.encoded_len()?;
}
Some(length)
}
}
impl Decode for AggregationJobResp {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
let prepare_resps = decode_u32_items(&(), bytes)?;
Ok(Self { prepare_resps })
}
}
/// Batch selector: on the wire, the query type's code followed by the query type's batch
/// identifier.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BatchSelector<Q: QueryType> {
    /// An [`Interval`] for [`TimeInterval`] and a [`BatchId`] for [`FixedSize`], as the typed
    /// constructors in this block show.
    batch_identifier: Q::BatchIdentifier,
}

impl<Q: QueryType> BatchSelector<Q> {
    /// Wraps an already-built batch identifier.
    pub fn new(batch_identifier: Q::BatchIdentifier) -> Self {
        Self { batch_identifier }
    }

    /// Borrows the wrapped batch identifier.
    pub fn batch_identifier(&self) -> &Q::BatchIdentifier {
        &self.batch_identifier
    }
}

impl BatchSelector<TimeInterval> {
    /// Builds a time-interval selector for the given batch interval.
    pub fn new_time_interval(batch_interval: Interval) -> Self {
        Self {
            batch_identifier: batch_interval,
        }
    }

    /// Borrows the batch interval.
    pub fn batch_interval(&self) -> &Interval {
        &self.batch_identifier
    }
}

impl BatchSelector<FixedSize> {
    /// Builds a fixed-size selector for the given batch ID.
    pub fn new_fixed_size(batch_id: BatchId) -> Self {
        Self {
            batch_identifier: batch_id,
        }
    }

    /// Borrows the batch ID.
    pub fn batch_id(&self) -> &BatchId {
        &self.batch_identifier
    }
}

impl<Q: QueryType> Encode for BatchSelector<Q> {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        Q::CODE.encode(bytes)?;
        self.batch_identifier.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        // One byte of query type code plus the identifier's own length.
        self.batch_identifier
            .encoded_len()
            .map(|identifier_len| identifier_len + 1)
    }
}

impl<Q: QueryType> Decode for BatchSelector<Q> {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        // Consume the query type code, checked against `Q::CODE` by `decode_expecting_value`.
        query_type::Code::decode_expecting_value(bytes, Q::CODE)?;
        Q::BatchIdentifier::decode(bytes).map(Self::new)
    }
}
/// Aggregate share request, encoded in this order: batch selector, aggregation parameter
/// (32-bit length prefixed, omitted from `Debug`), report count, and report ID checksum.
#[derive(Clone, Derivative, PartialEq, Eq)]
#[derivative(Debug)]
pub struct AggregateShareReq<Q: QueryType> {
    batch_selector: BatchSelector<Q>,
    #[derivative(Debug = "ignore")]
    aggregation_parameter: Vec<u8>,
    report_count: u64,
    checksum: ReportIdChecksum,
}

impl<Q: QueryType> AggregateShareReq<Q> {
    /// HTTP media type for this message.
    pub const MEDIA_TYPE: &'static str = "application/dap-aggregate-share-req";

    /// Assembles an aggregate share request from its parts.
    pub fn new(
        batch_selector: BatchSelector<Q>,
        aggregation_parameter: Vec<u8>,
        report_count: u64,
        checksum: ReportIdChecksum,
    ) -> Self {
        Self {
            batch_selector,
            aggregation_parameter,
            report_count,
            checksum,
        }
    }

    /// Borrows the batch selector.
    pub fn batch_selector(&self) -> &BatchSelector<Q> {
        &self.batch_selector
    }

    /// Borrows the opaque aggregation parameter bytes.
    pub fn aggregation_parameter(&self) -> &[u8] {
        &self.aggregation_parameter
    }

    /// Returns the report count carried in the request.
    pub fn report_count(&self) -> u64 {
        self.report_count
    }

    /// Borrows the report ID checksum.
    pub fn checksum(&self) -> &ReportIdChecksum {
        &self.checksum
    }
}

impl<Q: QueryType> Encode for AggregateShareReq<Q> {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.batch_selector.encode(bytes)?;
        encode_u32_items(bytes, &(), &self.aggregation_parameter)?;
        self.report_count.encode(bytes)?;
        self.checksum.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        // `None` from any component makes the whole sum `None`.
        [
            self.batch_selector.encoded_len(),
            Some(4 + self.aggregation_parameter.len()),
            self.report_count.encoded_len(),
            self.checksum.encoded_len(),
        ]
        .iter()
        .copied()
        .sum()
    }
}

impl<Q: QueryType> Decode for AggregateShareReq<Q> {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        Ok(Self {
            batch_selector: BatchSelector::decode(bytes)?,
            aggregation_parameter: decode_u32_items(&(), bytes)?,
            report_count: u64::decode(bytes)?,
            checksum: ReportIdChecksum::decode(bytes)?,
        })
    }
}
/// Aggregate share message: a single encrypted aggregate share, encoded exactly as its
/// [`HpkeCiphertext`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AggregateShare {
    encrypted_aggregate_share: HpkeCiphertext,
}

impl AggregateShare {
    /// HTTP media type for this message.
    pub const MEDIA_TYPE: &'static str = "application/dap-aggregate-share";

    /// Wraps an encrypted aggregate share.
    pub fn new(encrypted_aggregate_share: HpkeCiphertext) -> Self {
        Self {
            encrypted_aggregate_share,
        }
    }

    /// Borrows the encrypted aggregate share.
    pub fn encrypted_aggregate_share(&self) -> &HpkeCiphertext {
        &self.encrypted_aggregate_share
    }
}

impl Encode for AggregateShare {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        self.encrypted_aggregate_share.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        self.encrypted_aggregate_share.encoded_len()
    }
}

impl Decode for AggregateShare {
    fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
        HpkeCiphertext::decode(bytes).map(Self::new)
    }
}
/// Test helper that checks each `(value, hex)` pair in three ways:
///
/// - `value` encodes to exactly the bytes spelled by `hex`;
/// - those bytes decode back to `value`;
/// - `encoded_len` reports the encoded byte count.
///
/// # Panics
///
/// Panics, failing the calling test, in any of these cases:
///
/// - any of the checks above fails;
/// - `hex` is not valid hex;
/// - encoding or decoding returns an error;
/// - `encoded_len` returns `None`.
#[cfg(test)]
pub(crate) fn roundtrip_encoding<T>(vals_and_encodings: &[(T, &str)])
where
    T: Encode + Decode + Debug + Eq,
{
    /// Byte-buffer wrapper whose `Debug` prints two-digit hex, so assertion diffs are readable.
    #[derive(PartialEq, Eq)]
    struct HexBytes<T>(T);

    impl<T: Debug> Debug for HexBytes<T> {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            write!(f, "{:02x?}", &self.0)
        }
    }

    for (val, hex_encoding) in vals_and_encodings.iter() {
        let mut buffer = Vec::new();
        val.encode(&mut buffer).unwrap();
        let expected = HexBytes(hex::decode(hex_encoding).unwrap());
        let actual = HexBytes(buffer);

        pretty_assertions::assert_eq!(
            actual,
            expected,
            "Couldn't roundtrip (encoded value differs): {val:?}"
        );

        let decoded = T::get_decoded(&actual.0).unwrap();
        pretty_assertions::assert_eq!(
            &decoded,
            val,
            "Couldn't roundtrip (decoded value differs): {val:?}"
        );

        let length_hint = val.encoded_len().expect("No encoded length hint");
        pretty_assertions::assert_eq!(
            actual.0.len(),
            length_hint,
            "Encoded length hint is incorrect: {val:?}"
        )
    }
}