#![cfg_attr(docsrs, feature(doc_cfg))]
use backoff::ExponentialBackoff;
#[cfg(feature = "ohttp")]
use bhttp::{ControlData, Message, Mode};
use educe::Educe;
#[cfg(feature = "ohttp")]
use http::{header::ACCEPT, HeaderValue};
use http::{header::CONTENT_TYPE, StatusCode};
use itertools::Itertools;
use janus_core::{
hpke::{self, is_hpke_config_supported, HpkeApplicationInfo, Label},
http::{cached_resource::CachedResource, HttpErrorResponse},
retries::{http_request_exponential_backoff, retry_http_request},
time::{Clock, RealClock, TimeExt},
url_ensure_trailing_slash,
};
use janus_messages::{
Duration, HpkeConfig, HpkeConfigList, InputShareAad, PlaintextInputShare, Report, ReportId,
ReportMetadata, Role, TaskId, Time,
};
#[cfg(feature = "ohttp")]
use ohttp::{ClientRequest, KeyConfig};
use prio::{codec::Encode, vdaf};
use rand::random;
#[cfg(feature = "ohttp")]
use std::io::Cursor;
use std::{convert::Infallible, fmt::Debug, sync::Arc, time::SystemTimeError};
use tokio::{sync::Mutex, try_join};
use url::Url;
#[cfg(test)]
mod tests;
/// Errors that may arise while constructing a [`Client`] or uploading reports.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A caller-supplied or derived parameter was invalid; the payload names it.
    #[error("invalid parameter {0}")]
    InvalidParameter(&'static str),
    /// Transport-level HTTP failure reported by `reqwest`.
    #[error("HTTP client error: {0}")]
    HttpClient(#[from] reqwest::Error),
    /// Failure to encode or decode a message.
    #[error("codec error: {0}")]
    Codec(#[from] prio::codec::CodecError),
    /// The server replied with a non-success HTTP status. Boxed to keep the
    /// error enum small.
    #[error("HTTP response status {0}")]
    Http(Box<HttpErrorResponse>),
    /// A URL could not be parsed or joined.
    #[error("URL parse: {0}")]
    Url(#[from] url::ParseError),
    /// VDAF-level failure (e.g. while sharding a measurement).
    #[error("VDAF error: {0}")]
    Vdaf(#[from] prio::vdaf::VdafError),
    /// HPKE failure (e.g. while encrypting an input share).
    #[error("HPKE error: {0}")]
    Hpke(#[from] janus_core::hpke::Error),
    /// Failure while fetching or refreshing a cached resource such as an
    /// aggregator's HPKE configuration list.
    #[error("Cached resource error: {0}")]
    CachedResource(#[from] janus_core::http::cached_resource::Error),
    /// The server responded, but with content this client cannot use; the
    /// payload describes what was wrong.
    #[error("unexpected server response {0}")]
    UnexpectedServerResponse(&'static str),
    /// A system-time value could not be converted.
    #[error("time conversion error: {0}")]
    TimeConversion(#[from] SystemTimeError),
    /// OHTTP protocol failure.
    #[cfg(feature = "ohttp")]
    #[error("OHTTP error: {0}")]
    Ohttp(#[from] ohttp::Error),
    /// Binary HTTP (BHTTP) encoding or decoding failure.
    #[cfg(feature = "ohttp")]
    #[error("BHTTP error: {0}")]
    Bhttp(#[from] bhttp::Error),
    /// None of the key configurations advertised by the OHTTP gateway was
    /// usable; the payload carries the advertised configs (boxed for size).
    #[cfg(feature = "ohttp")]
    #[error("No supported key configurations advertised by OHTTP gateway")]
    OhttpNoSupportedKeyConfigs(Box<Vec<KeyConfig>>),
}
/// Conversion from the uninhabited [`Infallible`] type. No value of
/// `Infallible` can ever exist, so this conversion can never actually run; it
/// exists only to satisfy trait bounds (e.g. `?` on `Result<_, Infallible>`).
impl From<Infallible> for Error {
    fn from(never: Infallible) -> Self {
        // Matching on an uninhabited type has no arms, which proves to the
        // compiler that this point is unreachable.
        match never {}
    }
}
impl From<Result<HttpErrorResponse, reqwest::Error>> for Error {
fn from(result: Result<HttpErrorResponse, reqwest::Error>) -> Self {
match result {
Ok(http_error_response) => Error::Http(Box::new(http_error_response)),
Err(error) => error.into(),
}
}
}
/// User-Agent header value sent on outgoing HTTP requests, of the form
/// `<crate name>/<crate version>/client` (resolved at compile time).
static CLIENT_USER_AGENT: &str = concat!(
    env!("CARGO_PKG_NAME"),
    "/",
    env!("CARGO_PKG_VERSION"),
    "/",
    "client"
);
/// Media type for an encoded list of OHTTP key configurations (RFC 9458).
#[cfg(feature = "ohttp")]
const OHTTP_KEYS_MEDIA_TYPE: &str = "application/ohttp-keys";
/// Media type for an encapsulated OHTTP request sent to a relay (RFC 9458).
#[cfg(feature = "ohttp")]
const OHTTP_REQUEST_MEDIA_TYPE: &str = "message/ohttp-req";
/// Media type for an encapsulated OHTTP response from a relay (RFC 9458).
#[cfg(feature = "ohttp")]
const OHTTP_RESPONSE_MEDIA_TYPE: &str = "message/ohttp-res";
/// Parameters identifying a DAP task and how this client reaches its
/// aggregators.
#[derive(Clone, Educe)]
#[educe(Debug)]
struct ClientParameters {
    /// Unique identifier for the task.
    task_id: TaskId,
    /// Base URL of the leader aggregator; normalized to end in `/` by
    /// [`ClientParameters::new`]. Debug-formatted via `Display` for readability.
    #[educe(Debug(method(std::fmt::Display::fmt)))]
    leader_aggregator_endpoint: Url,
    /// Base URL of the helper aggregator; normalized to end in `/` by
    /// [`ClientParameters::new`]. Debug-formatted via `Display` for readability.
    #[educe(Debug(method(std::fmt::Display::fmt)))]
    helper_aggregator_endpoint: Url,
    /// Granularity to which report timestamps are rounded down.
    time_precision: Duration,
    /// Exponential-backoff policy applied to retried HTTP requests.
    http_request_retry_parameters: ExponentialBackoff,
}
impl ClientParameters {
    /// Creates a new set of task parameters, normalizing both aggregator
    /// endpoints to end with a trailing slash (so relative joins behave) and
    /// installing the default HTTP retry policy.
    pub fn new(
        task_id: TaskId,
        leader_aggregator_endpoint: Url,
        helper_aggregator_endpoint: Url,
        time_precision: Duration,
    ) -> Self {
        let leader_aggregator_endpoint = url_ensure_trailing_slash(leader_aggregator_endpoint);
        let helper_aggregator_endpoint = url_ensure_trailing_slash(helper_aggregator_endpoint);
        Self {
            task_id,
            leader_aggregator_endpoint,
            helper_aggregator_endpoint,
            time_precision,
            http_request_retry_parameters: http_request_exponential_backoff(),
        }
    }

    /// Returns the base endpoint URL of the aggregator filling `role`, or
    /// [`Error::InvalidParameter`] if the role is not an aggregator role.
    fn aggregator_endpoint(&self, role: &Role) -> Result<&Url, Error> {
        match role {
            Role::Leader => Ok(&self.leader_aggregator_endpoint),
            Role::Helper => Ok(&self.helper_aggregator_endpoint),
            _ => Err(Error::InvalidParameter("role is not an aggregator")),
        }
    }

    /// URL from which the HPKE configuration of the aggregator in `role` may
    /// be fetched.
    fn hpke_config_endpoint(&self, role: &Role) -> Result<Url, Error> {
        let base = self.aggregator_endpoint(role)?;
        Ok(base.join("hpke_config")?)
    }

    /// URL of the reports resource for `task_id`, always rooted at the leader.
    fn reports_resource_uri(&self, task_id: &TaskId) -> Result<Url, Error> {
        let relative_path = format!("tasks/{task_id}/reports");
        Ok(self.leader_aggregator_endpoint.join(&relative_path)?)
    }
}
/// Fetches and decodes the OHTTP gateway's key configurations from the URL in
/// `ohttp_config.key_configs`, retrying transient HTTP failures per
/// `http_request_retry_parameters`.
///
/// # Errors
///
/// Fails if the request does not succeed, if the response is not labeled with
/// the `application/ohttp-keys` media type, or if the body cannot be decoded
/// as a list of [`KeyConfig`]s.
// NOTE: `#[cfg]` intentionally precedes `#[tracing::instrument]` (matching
// `upload_with_ohttp`) so the item is cfg-stripped before the proc macro
// expands when the "ohttp" feature is disabled.
#[cfg(feature = "ohttp")]
#[tracing::instrument(err)]
async fn ohttp_key_configs(
    http_request_retry_parameters: ExponentialBackoff,
    ohttp_config: &OhttpConfig,
    http_client: &reqwest::Client,
) -> Result<Vec<KeyConfig>, Error> {
    let keys_response = retry_http_request(http_request_retry_parameters, || async {
        http_client
            .get(ohttp_config.key_configs.clone())
            .header(ACCEPT, OHTTP_KEYS_MEDIA_TYPE)
            .send()
            .await
    })
    .await?;
    if !keys_response.status().is_success() {
        return Err(Error::Http(Box::new(HttpErrorResponse::from(
            keys_response.status(),
        ))));
    }
    // A missing Content-Type header also fails this comparison (None != Some),
    // which is intended: only a correctly-labeled response is trusted.
    if keys_response
        .headers()
        .get(CONTENT_TYPE)
        .map(HeaderValue::as_bytes)
        != Some(OHTTP_KEYS_MEDIA_TYPE.as_bytes())
    {
        return Err(Error::UnexpectedServerResponse(
            "content type wrong for OHTTP keys",
        ));
    }
    Ok(KeyConfig::decode_list(keys_response.body().as_ref())?)
}
/// Constructs a [`reqwest::Client`] with this crate's defaults: a 10 second
/// connect timeout, a 30 second overall request timeout, and the crate's
/// User-Agent string.
pub fn default_http_client() -> Result<reqwest::Client, Error> {
    let builder = reqwest::Client::builder()
        .user_agent(CLIENT_USER_AGENT)
        .connect_timeout(std::time::Duration::from_secs(10))
        .timeout(std::time::Duration::from_secs(30));
    Ok(builder.build()?)
}
/// Configuration for uploading reports via Oblivious HTTP.
#[derive(Clone, Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "ohttp")))]
#[cfg(feature = "ohttp")]
pub struct OhttpConfig {
    /// URL from which the OHTTP gateway's key configurations are fetched.
    pub key_configs: Url,
    /// URL of the OHTTP relay through which encapsulated requests are sent.
    pub relay: Url,
}
/// Builder for [`Client`]. Create one with [`Client::builder`] or
/// [`ClientBuilder::new`], customize it with the `with_*` methods, then call
/// [`ClientBuilder::build`].
pub struct ClientBuilder<V: vdaf::Client<16>> {
    /// Task parameters the eventual client will use.
    parameters: ClientParameters,
    /// VDAF used to shard measurements into input shares.
    vdaf: V,
    /// If set, used instead of fetching the leader's HPKE config over HTTP.
    leader_hpke_config: Option<HpkeConfig>,
    /// If set, used instead of fetching the helper's HPKE config over HTTP.
    helper_hpke_config: Option<HpkeConfig>,
    /// If set, uploads are routed through an OHTTP relay.
    #[cfg(feature = "ohttp")]
    ohttp_config: Option<OhttpConfig>,
    /// If unset, [`default_http_client`] is used at build time.
    http_client: Option<reqwest::Client>,
}
impl<V: vdaf::Client<16>> ClientBuilder<V> {
    /// Creates a builder from the task's core parameters. HPKE configurations,
    /// OHTTP settings, and the HTTP client are left unset; defaults are
    /// supplied at build time unless overridden via the `with_*` methods.
    pub fn new(
        task_id: TaskId,
        leader_aggregator_endpoint: Url,
        helper_aggregator_endpoint: Url,
        time_precision: Duration,
        vdaf: V,
    ) -> Self {
        Self {
            parameters: ClientParameters::new(
                task_id,
                leader_aggregator_endpoint,
                helper_aggregator_endpoint,
                time_precision,
            ),
            vdaf,
            leader_hpke_config: None,
            helper_hpke_config: None,
            #[cfg(feature = "ohttp")]
            ohttp_config: None,
            http_client: None,
        }
    }

    /// Finalizes the builder into a [`Client`], performing any network fetches
    /// needed: aggregator HPKE configurations that were not provided statically
    /// and, when configured, the OHTTP gateway's key configurations.
    pub async fn build(self) -> Result<Client<V>, Error> {
        // Use the caller-provided HTTP client if set; otherwise construct the
        // crate default.
        let http_client = if let Some(http_client) = self.http_client {
            http_client
        } else {
            default_http_client()?
        };
        // Statically-provided configs are wrapped as-is; otherwise fetch from
        // the aggregator named by `role`.
        let fetch_hpke_config = async |hpke_config, role| match hpke_config {
            Some(hpke_config) => Ok(HpkeConfiguration::new_static(hpke_config)),
            None => HpkeConfiguration::new(&self.parameters, role, &http_client).await,
        };
        // Resolve both aggregators' configs concurrently.
        let (leader_hpke_config, helper_hpke_config) = tokio::try_join!(
            fetch_hpke_config(self.leader_hpke_config, &Role::Leader),
            fetch_hpke_config(self.helper_hpke_config, &Role::Helper),
        )?;
        // When OHTTP is configured, fetch the gateway's key configs up front so
        // uploads don't have to.
        #[cfg(feature = "ohttp")]
        let ohttp_config = if let Some(ohttp_config) = self.ohttp_config {
            let key_configs = ohttp_key_configs(
                self.parameters.http_request_retry_parameters.clone(),
                &ohttp_config,
                &http_client,
            )
            .await?;
            Some((ohttp_config, key_configs))
        } else {
            None
        };
        Ok(Client {
            #[cfg(feature = "ohttp")]
            ohttp_config,
            parameters: self.parameters,
            vdaf: self.vdaf,
            http_client,
            leader_hpke_config: Arc::new(Mutex::new(leader_hpke_config)),
            helper_hpke_config: Arc::new(Mutex::new(helper_hpke_config)),
        })
    }

    /// Finalizes the builder using caller-provided HPKE configurations for
    /// both aggregators, performing no network requests.
    #[deprecated(
        note = "Use `ClientBuilder::with_leader_hpke_config`, `ClientBuilder::with_helper_hpke_config` and `ClientBuilder::build` instead"
    )]
    pub fn build_with_hpke_configs(
        self,
        leader_hpke_config: HpkeConfig,
        helper_hpke_config: HpkeConfig,
    ) -> Result<Client<V>, Error> {
        let http_client = if let Some(http_client) = self.http_client {
            http_client
        } else {
            default_http_client()?
        };
        Ok(Client {
            parameters: self.parameters,
            vdaf: self.vdaf,
            // NOTE: any OHTTP config set on the builder is discarded by this
            // deprecated path.
            #[cfg(feature = "ohttp")]
            ohttp_config: None,
            http_client,
            leader_hpke_config: Arc::new(Mutex::new(HpkeConfiguration::new_static(
                leader_hpke_config,
            ))),
            helper_hpke_config: Arc::new(Mutex::new(HpkeConfiguration::new_static(
                helper_hpke_config,
            ))),
        })
    }

    /// Overrides the HTTP client used for all requests.
    pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
        self.http_client = Some(http_client);
        self
    }

    /// Overrides the retry/backoff policy applied to HTTP requests.
    pub fn with_backoff(mut self, http_request_retry_parameters: ExponentialBackoff) -> Self {
        self.parameters.http_request_retry_parameters = http_request_retry_parameters;
        self
    }

    /// Supplies the leader's HPKE configuration, skipping the network fetch.
    pub fn with_leader_hpke_config(mut self, hpke_config: HpkeConfig) -> Self {
        self.leader_hpke_config = Some(hpke_config);
        self
    }

    /// Supplies the helper's HPKE configuration, skipping the network fetch.
    pub fn with_helper_hpke_config(mut self, hpke_config: HpkeConfig) -> Self {
        self.helper_hpke_config = Some(hpke_config);
        self
    }

    /// Routes report uploads through the given OHTTP relay/gateway.
    #[cfg(feature = "ohttp")]
    #[cfg_attr(docsrs, doc(cfg(feature = "ohttp")))]
    pub fn with_ohttp_config(mut self, ohttp_config: OhttpConfig) -> Self {
        self.ohttp_config = Some(ohttp_config);
        self
    }
}
/// A DAP client for a single task. Cloning is cheap where it matters: clones
/// share the cached HPKE configurations through [`Arc`]s.
#[derive(Clone, Debug)]
pub struct Client<V: vdaf::Client<16>> {
    /// Task identity, endpoints, and retry policy.
    parameters: ClientParameters,
    /// VDAF used to shard measurements into input shares.
    vdaf: V,
    /// OHTTP relay configuration plus the gateway's key configs, when uploads
    /// are routed through OHTTP.
    #[cfg(feature = "ohttp")]
    ohttp_config: Option<(OhttpConfig, Vec<KeyConfig>)>,
    http_client: reqwest::Client,
    /// Cached leader HPKE config; a (tokio) Mutex because refreshing the cache
    /// needs mutable access and may be held across awaits.
    leader_hpke_config: Arc<Mutex<HpkeConfiguration>>,
    /// Cached helper HPKE config; see `leader_hpke_config`.
    helper_hpke_config: Arc<Mutex<HpkeConfiguration>>,
}
impl<V: vdaf::Client<16>> Client<V> {
    /// Constructs a [`Client`] with default options, fetching HPKE
    /// configurations from both aggregators over the network. Shorthand for
    /// `Client::builder(...).build()`.
    pub async fn new(
        task_id: TaskId,
        leader_aggregator_endpoint: Url,
        helper_aggregator_endpoint: Url,
        time_precision: Duration,
        vdaf: V,
    ) -> Result<Self, Error> {
        ClientBuilder::new(
            task_id,
            leader_aggregator_endpoint,
            helper_aggregator_endpoint,
            time_precision,
            vdaf,
        )
        .build()
        .await
    }

    /// Constructs a [`Client`] from caller-provided HPKE configurations,
    /// performing no network requests.
    #[deprecated(
        note = "Use `ClientBuilder::with_leader_hpke_config`, `ClientBuilder::with_helper_hpke_config` and `ClientBuilder::build` instead"
    )]
    pub fn with_hpke_configs(
        task_id: TaskId,
        leader_aggregator_endpoint: Url,
        helper_aggregator_endpoint: Url,
        time_precision: Duration,
        vdaf: V,
        leader_hpke_config: HpkeConfig,
        helper_hpke_config: HpkeConfig,
    ) -> Result<Self, Error> {
        // Delegates to the deprecated builder path; the allow silences the
        // deprecation warning this crate itself would otherwise trigger.
        #[allow(deprecated)]
        ClientBuilder::new(
            task_id,
            leader_aggregator_endpoint,
            helper_aggregator_endpoint,
            time_precision,
            vdaf,
        )
        .build_with_hpke_configs(leader_hpke_config, helper_hpke_config)
    }

    /// Returns a [`ClientBuilder`] initialized with this task's parameters.
    pub fn builder(
        task_id: TaskId,
        leader_aggregator_endpoint: Url,
        helper_aggregator_endpoint: Url,
        time_precision: Duration,
        vdaf: V,
    ) -> ClientBuilder<V> {
        ClientBuilder::new(
            task_id,
            leader_aggregator_endpoint,
            helper_aggregator_endpoint,
            time_precision,
            vdaf,
        )
    }

    /// Shards `measurement` into input shares and assembles an encrypted
    /// [`Report`] with a timestamp rounded down to the task's time precision.
    fn prepare_report(
        &self,
        measurement: &V::Measurement,
        time: &Time,
        leader_hpke_config: &HpkeConfig,
        helper_hpke_config: &HpkeConfig,
    ) -> Result<Report, Error> {
        let report_id: ReportId = random();
        let (public_share, input_shares) = self.vdaf.shard(measurement, report_id.as_ref())?;
        // This code only handles the two-aggregator (leader + helper) case.
        assert_eq!(input_shares.len(), 2);
        // Round the timestamp down to a multiple of the task's time_precision.
        let time = time
            .to_batch_interval_start(&self.parameters.time_precision)
            .map_err(|_| Error::InvalidParameter("couldn't round time down to time_precision"))?;
        let report_metadata = ReportMetadata::new(report_id, time);
        let encoded_public_share = public_share.get_encoded()?;
        // Encrypt each input share to the corresponding aggregator's HPKE key,
        // binding the task ID, report metadata and public share via the AAD.
        let (leader_encrypted_input_share, helper_encrypted_input_share) = [
            (leader_hpke_config, &Role::Leader),
            (helper_hpke_config, &Role::Helper),
        ]
        .into_iter()
        .zip(input_shares)
        .map(|((hpke_config, receiver_role), input_share)| {
            hpke::seal(
                hpke_config,
                &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, receiver_role),
                // The first argument (the share's extensions list) is left
                // empty.
                &PlaintextInputShare::new(
                    Vec::new(), input_share.get_encoded()?,
                )
                .get_encoded()?,
                &InputShareAad::new(
                    self.parameters.task_id,
                    report_metadata.clone(),
                    encoded_public_share.clone(),
                )
                .get_encoded()?,
            )
            .map_err(Error::Hpke)
        })
        .collect_tuple()
        // Safe: the source array has exactly two elements (asserted above).
        .expect("iterator to yield two items");
        Ok(Report::new(
            report_metadata,
            encoded_public_share,
            leader_encrypted_input_share?,
            helper_encrypted_input_share?,
        ))
    }

    /// Uploads `measurement` to the leader, timestamped with the current
    /// system time.
    #[tracing::instrument(skip(measurement), err)]
    pub async fn upload(&self, measurement: &V::Measurement) -> Result<(), Error> {
        self.upload_with_time(measurement, Clock::now(&RealClock::default()))
            .await
    }

    /// Uploads `measurement` with a caller-supplied timestamp: obtains both
    /// aggregators' current HPKE configurations, prepares the encrypted
    /// report, and sends it to the leader (via OHTTP when compiled in and
    /// configured, otherwise a direct PUT).
    ///
    /// Returns an error if report preparation fails or if the upload yields a
    /// non-success HTTP status.
    #[tracing::instrument(skip(measurement), err)]
    pub async fn upload_with_time<T>(
        &self,
        measurement: &V::Measurement,
        time: T,
    ) -> Result<(), Error>
    where
        T: TryInto<Time> + Debug,
        Error: From<<T as TryInto<Time>>::Error>,
    {
        // Lock both cached configs (tokio mutexes, so holding across await is
        // fine), then resolve them concurrently.
        let mut leader_hpke_config = self.leader_hpke_config.lock().await;
        let mut helper_hpke_config = self.helper_hpke_config.lock().await;
        let (leader_hpke_config, helper_hpke_config) =
            try_join!(leader_hpke_config.get(), helper_hpke_config.get())?;
        let report = self
            .prepare_report(
                measurement,
                &time.try_into()?,
                leader_hpke_config,
                helper_hpke_config,
            )?
            .get_encoded()?;
        let upload_endpoint = self
            .parameters
            .reports_resource_uri(&self.parameters.task_id)?;
        // With the "ohttp" feature, route through upload_with_ohttp — which
        // itself falls back to a direct PUT when no OHTTP config was provided.
        #[cfg(feature = "ohttp")]
        let upload_status = self.upload_with_ohttp(&upload_endpoint, &report).await?;
        #[cfg(not(feature = "ohttp"))]
        let upload_status = self.put_report(&upload_endpoint, &report).await?;
        if !upload_status.is_success() {
            return Err(Error::Http(Box::new(HttpErrorResponse::from(
                upload_status,
            ))));
        }
        Ok(())
    }

    /// PUTs an encoded report directly to the leader's upload endpoint,
    /// retrying transient failures, and returns the final HTTP status.
    async fn put_report(
        &self,
        upload_endpoint: &Url,
        request_body: &[u8],
    ) -> Result<StatusCode, Error> {
        Ok(retry_http_request(
            self.parameters.http_request_retry_parameters.clone(),
            || async {
                self.http_client
                    .put(upload_endpoint.clone())
                    .header(CONTENT_TYPE, Report::MEDIA_TYPE)
                    // Copied per attempt so retries can resend the body.
                    .body(request_body.to_vec())
                    .send()
                    .await
            },
        )
        .await?
        .status())
    }

    /// Uploads an encoded report through the configured OHTTP relay, or falls
    /// back to a direct PUT if no OHTTP config is present. Returns the HTTP
    /// status of the (decapsulated) upload response.
    #[cfg(feature = "ohttp")]
    #[tracing::instrument(skip(self, request_body), err)]
    async fn upload_with_ohttp(
        &self,
        upload_endpoint: &Url,
        request_body: &[u8],
    ) -> Result<StatusCode, Error> {
        let (ohttp_config, key_configs) =
            if let Some((ohttp_config, key_configs)) = &self.ohttp_config {
                (ohttp_config, key_configs)
            } else {
                // No OHTTP configuration: upload directly.
                return self.put_report(upload_endpoint, request_body).await;
            };
        // Build a binary HTTP (BHTTP) encoding of the PUT request.
        let mut message = Message::request(
            "PUT".into(),
            upload_endpoint.scheme().into(),
            upload_endpoint.authority().into(),
            upload_endpoint.path().into(),
        );
        message.put_header(CONTENT_TYPE.as_str(), Report::MEDIA_TYPE);
        message.write_content(request_body);
        let mut request_buf = Vec::new();
        message.write_bhttp(Mode::KnownLength, &mut request_buf)?;
        // Use the first advertised key config we can construct a client
        // request from; if none work, report all advertised configs.
        let ohttp_request = key_configs
            .iter()
            .cloned()
            .find_map(|mut key_config| ClientRequest::from_config(&mut key_config).ok())
            .ok_or_else(|| Error::OhttpNoSupportedKeyConfigs(Box::new(key_configs.to_vec())))?;
        // Encapsulate (encrypt) the BHTTP request to the gateway's key;
        // `ohttp_response` is the state needed to decrypt the reply.
        let (encapsulated_request, ohttp_response) = ohttp_request.encapsulate(&request_buf)?;
        // POST the encapsulated request to the relay, with retries.
        let relay_response = retry_http_request(
            self.parameters.http_request_retry_parameters.clone(),
            || async {
                self.http_client
                    .post(ohttp_config.relay.clone())
                    .header(CONTENT_TYPE, OHTTP_REQUEST_MEDIA_TYPE)
                    .header(ACCEPT, OHTTP_RESPONSE_MEDIA_TYPE)
                    .body(encapsulated_request.clone())
                    .send()
                    .await
            },
        )
        .await?;
        if !relay_response.status().is_success() {
            return Err(Error::Http(Box::new(HttpErrorResponse::from(
                relay_response.status(),
            ))));
        }
        // A missing Content-Type header also fails this comparison, which is
        // intended: only a correctly-labeled response is trusted.
        if relay_response
            .headers()
            .get(CONTENT_TYPE)
            .map(HeaderValue::as_bytes)
            != Some(OHTTP_RESPONSE_MEDIA_TYPE.as_bytes())
        {
            return Err(Error::UnexpectedServerResponse(
                "content type wrong for OHTTP response",
            ));
        }
        // Decrypt the relayed response and parse the BHTTP message inside.
        let decapsulated_response = ohttp_response.decapsulate(relay_response.body().as_ref())?;
        let message = Message::read_bhttp(&mut Cursor::new(&decapsulated_response))?;
        // Extract the inner HTTP status carried by the decapsulated response.
        let status = if let ControlData::Response(status) = message.control() {
            StatusCode::from_u16((*status).into()).map_err(|_| {
                Error::UnexpectedServerResponse(
                    "status in decapsulated response is not valid HTTP status",
                )
            })?
        } else {
            return Err(Error::UnexpectedServerResponse(
                "decapsulated response control data is not a response",
            ));
        };
        Ok(status)
    }
}
/// An aggregator's HPKE configuration list: either fetched from the aggregator
/// and cached, or supplied statically by the caller (never refreshed).
#[derive(Debug, Clone)]
pub(crate) struct HpkeConfiguration {
    hpke_config_list: CachedResource<HpkeConfigList>,
}
impl HpkeConfiguration {
    /// Fetches the HPKE configuration list of the aggregator in
    /// `aggregator_role` over HTTP, wrapping it in a cache that follows
    /// `client_parameters`' retry policy.
    pub(crate) async fn new(
        client_parameters: &ClientParameters,
        aggregator_role: &Role,
        http_client: &reqwest::Client,
    ) -> Result<Self, Error> {
        let mut hpke_config_url = client_parameters.hpke_config_endpoint(aggregator_role)?;
        // Scope the request to this client's task.
        hpke_config_url.set_query(Some(&format!("task_id={}", client_parameters.task_id)));
        Ok(Self {
            hpke_config_list: CachedResource::new(
                hpke_config_url,
                HpkeConfigList::MEDIA_TYPE,
                http_client,
                client_parameters.http_request_retry_parameters.clone(),
            )
            .await?,
        })
    }

    /// Wraps a caller-provided HPKE config as a static, never-refreshed list.
    pub(crate) fn new_static(hpke_configuration: HpkeConfig) -> Self {
        Self {
            hpke_config_list: CachedResource::Static(HpkeConfigList::new(vec![hpke_configuration])),
        }
    }

    /// Returns the first HPKE config in the (possibly refreshed) list that
    /// this client supports.
    ///
    /// # Errors
    ///
    /// Fails if the cached resource cannot be (re)fetched, if the server
    /// provided an empty list, or — when no config is supported — with the
    /// first unsupported-config error encountered.
    pub(crate) async fn get(&mut self) -> Result<&HpkeConfig, Error> {
        let hpke_config_list = self.hpke_config_list.resource().await?;
        if hpke_config_list.hpke_configs().is_empty() {
            return Err(Error::UnexpectedServerResponse(
                "aggregator provided empty HpkeConfigList",
            ));
        }
        // Scan for the first supported config, remembering the first failure
        // so a meaningful error can be reported if nothing is usable.
        let mut first_error = None;
        for config in hpke_config_list.hpke_configs() {
            match is_hpke_config_supported(config) {
                Ok(()) => return Ok(config),
                Err(e) => {
                    if first_error.is_none() {
                        first_error = Some(e);
                    }
                }
            }
        }
        // Invariant: the list is non-empty and no config returned Ok, so at
        // least one iteration recorded an error above.
        Err(first_error
            .expect("non-empty HpkeConfigList yielded neither a supported config nor an error")
            .into())
    }
}