1#![cfg_attr(docsrs, feature(doc_cfg))]
41
42use backoff::ExponentialBackoff;
43#[cfg(feature = "ohttp")]
44use bhttp::{ControlData, Message, Mode};
45use educe::Educe;
46#[cfg(feature = "ohttp")]
47use http::{header::ACCEPT, HeaderValue};
48use http::{header::CONTENT_TYPE, StatusCode};
49use itertools::Itertools;
50use janus_core::{
51 hpke::{self, is_hpke_config_supported, HpkeApplicationInfo, Label},
52 http::HttpErrorResponse,
53 retries::{http_request_exponential_backoff, retry_http_request},
54 time::{Clock, RealClock, TimeExt},
55 url_ensure_trailing_slash,
56};
57use janus_messages::{
58 Duration, HpkeConfig, HpkeConfigList, InputShareAad, PlaintextInputShare, Report, ReportId,
59 ReportMetadata, Role, TaskId, Time,
60};
61#[cfg(feature = "ohttp")]
62use ohttp::{ClientRequest, KeyConfig};
63use prio::{
64 codec::{Decode, Encode},
65 vdaf,
66};
67use rand::random;
68#[cfg(feature = "ohttp")]
69use std::io::Cursor;
70use std::{convert::Infallible, fmt::Debug, time::SystemTimeError};
71use url::Url;
72
73#[cfg(test)]
74mod tests;
75
/// Errors that this client can report.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A supplied parameter was rejected; the payload is a static description of the problem.
    #[error("invalid parameter {0}")]
    InvalidParameter(&'static str),
    /// An error raised by the `reqwest` HTTP client.
    #[error("HTTP client error: {0}")]
    HttpClient(#[from] reqwest::Error),
    /// An error raised by the `prio` codec.
    #[error("codec error: {0}")]
    Codec(#[from] prio::codec::CodecError),
    /// An HTTP error response. Boxed, presumably to keep this enum small.
    #[error("HTTP response status {0}")]
    Http(Box<HttpErrorResponse>),
    /// A URL could not be parsed or joined.
    #[error("URL parse: {0}")]
    Url(#[from] url::ParseError),
    /// An error raised by the VDAF implementation.
    #[error("VDAF error: {0}")]
    Vdaf(#[from] prio::vdaf::VdafError),
    /// An error raised by the HPKE routines in `janus_core`.
    #[error("HPKE error: {0}")]
    Hpke(#[from] janus_core::hpke::Error),
    /// A server response was well-formed HTTP but did not have the expected content; the payload
    /// is a static description of what was wrong.
    #[error("unexpected server response {0}")]
    UnexpectedServerResponse(&'static str),
    /// A `SystemTimeError` was encountered while converting a time value.
    #[error("time conversion error: {0}")]
    TimeConversion(#[from] SystemTimeError),
    /// An error raised by the `ohttp` crate.
    #[cfg(feature = "ohttp")]
    #[error("OHTTP error: {0}")]
    Ohttp(#[from] ohttp::Error),
    /// An error raised by the `bhttp` crate.
    #[cfg(feature = "ohttp")]
    #[error("BHTTP error: {0}")]
    Bhttp(#[from] bhttp::Error),
    /// None of the advertised OHTTP key configurations could be used to construct a request. The
    /// payload holds the key configurations that were considered.
    #[cfg(feature = "ohttp")]
    #[error("No supported key configurations advertised by OHTTP gateway")]
    OhttpNoSupportedKeyConfigs(Box<Vec<KeyConfig>>),
}
106
107impl From<Infallible> for Error {
108 fn from(value: Infallible) -> Self {
109 match value {}
110 }
111}
112
113impl From<Result<HttpErrorResponse, reqwest::Error>> for Error {
114 fn from(result: Result<HttpErrorResponse, reqwest::Error>) -> Self {
115 match result {
116 Ok(http_error_response) => Error::Http(Box::new(http_error_response)),
117 Err(error) => error.into(),
118 }
119 }
120}
121
/// User agent string set on the default HTTP client, of the form
/// `<crate name>/<crate version>/client`, assembled at compile time from Cargo's environment
/// variables.
static CLIENT_USER_AGENT: &str = concat!(
    env!("CARGO_PKG_NAME"),
    "/",
    env!("CARGO_PKG_VERSION"),
    "/",
    "client"
);
129
/// Media type requested (via `Accept`) and required (via `Content-Type`) when fetching OHTTP key
/// configurations.
#[cfg(feature = "ohttp")]
const OHTTP_KEYS_MEDIA_TYPE: &str = "application/ohttp-keys";
/// Media type set as `Content-Type` on encapsulated requests sent to the OHTTP relay.
#[cfg(feature = "ohttp")]
const OHTTP_REQUEST_MEDIA_TYPE: &str = "message/ohttp-req";
/// Media type requested (via `Accept`) and required (via `Content-Type`) for the relay's
/// response to an encapsulated request.
#[cfg(feature = "ohttp")]
const OHTTP_RESPONSE_MEDIA_TYPE: &str = "message/ohttp-res";
136
/// Parameters shared by the client's requests: the task, the aggregator endpoints, the time
/// precision and the HTTP retry policy.
#[derive(Clone, Educe)]
#[educe(Debug)]
struct ClientParameters {
    /// Identifier of the task reports are uploaded for.
    task_id: TaskId,
    /// Base URL of the leader aggregator. Debug output uses the URL's `Display` form.
    #[educe(Debug(method(std::fmt::Display::fmt)))]
    leader_aggregator_endpoint: Url,
    /// Base URL of the helper aggregator. Debug output uses the URL's `Display` form.
    #[educe(Debug(method(std::fmt::Display::fmt)))]
    helper_aggregator_endpoint: Url,
    /// Precision that report timestamps are rounded down to before upload.
    time_precision: Duration,
    /// Backoff parameters handed to `retry_http_request` for each HTTP request.
    http_request_retry_parameters: ExponentialBackoff,
}
155
156impl ClientParameters {
157 pub fn new(
159 task_id: TaskId,
160 leader_aggregator_endpoint: Url,
161 helper_aggregator_endpoint: Url,
162 time_precision: Duration,
163 ) -> Self {
164 Self {
165 task_id,
166 leader_aggregator_endpoint: url_ensure_trailing_slash(leader_aggregator_endpoint),
167 helper_aggregator_endpoint: url_ensure_trailing_slash(helper_aggregator_endpoint),
168 time_precision,
169 http_request_retry_parameters: http_request_exponential_backoff(),
170 }
171 }
172
173 fn aggregator_endpoint(&self, role: &Role) -> Result<&Url, Error> {
176 match role {
177 Role::Leader => Ok(&self.leader_aggregator_endpoint),
178 Role::Helper => Ok(&self.helper_aggregator_endpoint),
179 _ => Err(Error::InvalidParameter("role is not an aggregator")),
180 }
181 }
182
183 fn hpke_config_endpoint(&self, role: &Role) -> Result<Url, Error> {
188 Ok(self.aggregator_endpoint(role)?.join("hpke_config")?)
189 }
190
191 fn reports_resource_uri(&self, task_id: &TaskId) -> Result<Url, Error> {
193 Ok(self
194 .leader_aggregator_endpoint
195 .join(&format!("tasks/{task_id}/reports"))?)
196 }
197}
198
199#[tracing::instrument(err)]
202async fn aggregator_hpke_config(
203 hpke_config: Option<HpkeConfig>,
204 client_parameters: &ClientParameters,
205 aggregator_role: &Role,
206 http_client: &reqwest::Client,
207) -> Result<HpkeConfig, Error> {
208 if let Some(hpke_config) = hpke_config {
209 return Ok(hpke_config);
210 }
211
212 let mut request_url = client_parameters.hpke_config_endpoint(aggregator_role)?;
213 request_url.set_query(Some(&format!("task_id={}", client_parameters.task_id)));
214 let hpke_config_response = retry_http_request(
215 client_parameters.http_request_retry_parameters.clone(),
216 || async { http_client.get(request_url.clone()).send().await },
217 )
218 .await?;
219 let status = hpke_config_response.status();
220 if !status.is_success() {
221 return Err(Error::Http(Box::new(HttpErrorResponse::from(status))));
222 }
223
224 let hpke_configs = HpkeConfigList::get_decoded(hpke_config_response.body())?;
225
226 if hpke_configs.hpke_configs().is_empty() {
227 return Err(Error::UnexpectedServerResponse(
228 "aggregator provided empty HpkeConfigList",
229 ));
230 }
231
232 let mut first_error = None;
234 for config in hpke_configs.hpke_configs() {
235 match is_hpke_config_supported(config) {
236 Ok(()) => return Ok(config.clone()),
237 Err(e) => {
238 if first_error.is_none() {
239 first_error = Some(e);
240 }
241 }
242 }
243 }
244 Err(first_error.unwrap().into())
247}
248
/// Fetches (with retries) and decodes the OHTTP key configurations from
/// `ohttp_config.key_configs`.
///
/// # Errors
///
/// Fails if the request fails, the response status is not a success, the response's
/// `Content-Type` is not exactly `OHTTP_KEYS_MEDIA_TYPE`, or the body cannot be decoded as a key
/// configuration list.
#[tracing::instrument(err)]
#[cfg(feature = "ohttp")]
async fn ohttp_key_configs(
    http_request_retry_parameters: ExponentialBackoff,
    ohttp_config: &OhttpConfig,
    http_client: &reqwest::Client,
) -> Result<Vec<KeyConfig>, Error> {
    let response = retry_http_request(http_request_retry_parameters, || async {
        http_client
            .get(ohttp_config.key_configs.clone())
            .header(ACCEPT, OHTTP_KEYS_MEDIA_TYPE)
            .send()
            .await
    })
    .await?;

    let status = response.status();
    if !status.is_success() {
        return Err(Error::Http(Box::new(HttpErrorResponse::from(status))));
    }

    let content_type = response
        .headers()
        .get(CONTENT_TYPE)
        .map(HeaderValue::as_bytes);
    if content_type != Some(OHTTP_KEYS_MEDIA_TYPE.as_bytes()) {
        return Err(Error::UnexpectedServerResponse(
            "content type wrong for OHTTP keys",
        ));
    }

    let key_configs = KeyConfig::decode_list(response.body().as_ref())?;
    Ok(key_configs)
}
286
287pub fn default_http_client() -> Result<reqwest::Client, Error> {
289 Ok(reqwest::Client::builder()
290 .timeout(std::time::Duration::from_secs(30))
293 .connect_timeout(std::time::Duration::from_secs(10))
294 .user_agent(CLIENT_USER_AGENT)
295 .build()?)
296}
297
/// Configuration for uploading reports through OHTTP.
#[derive(Clone, Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "ohttp")))]
#[cfg(feature = "ohttp")]
pub struct OhttpConfig {
    /// URL from which the OHTTP key configurations are fetched with a GET request.
    pub key_configs: Url,

    /// URL of the relay to which encapsulated requests are sent with a POST request.
    pub relay: Url,
}
312
/// Builder for [`Client`]. Optional settings left unset are resolved when the client is built.
pub struct ClientBuilder<V: vdaf::Client<16>> {
    /// Task, endpoint, time precision and retry parameters for the client.
    parameters: ClientParameters,
    /// VDAF instance used to shard measurements.
    vdaf: V,
    /// Leader HPKE configuration, if supplied by the caller; otherwise `build` fetches it.
    leader_hpke_config: Option<HpkeConfig>,
    /// Helper HPKE configuration, if supplied by the caller; otherwise `build` fetches it.
    helper_hpke_config: Option<HpkeConfig>,
    /// OHTTP configuration, if uploads should go through an OHTTP relay.
    #[cfg(feature = "ohttp")]
    ohttp_config: Option<OhttpConfig>,
    /// HTTP client supplied by the caller; when `None`, `default_http_client()` is used.
    http_client: Option<reqwest::Client>,
}
323
324impl<V: vdaf::Client<16>> ClientBuilder<V> {
325 pub fn new(
327 task_id: TaskId,
328 leader_aggregator_endpoint: Url,
329 helper_aggregator_endpoint: Url,
330 time_precision: Duration,
331 vdaf: V,
332 ) -> Self {
333 Self {
334 parameters: ClientParameters::new(
335 task_id,
336 leader_aggregator_endpoint,
337 helper_aggregator_endpoint,
338 time_precision,
339 ),
340 vdaf,
341 leader_hpke_config: None,
342 helper_hpke_config: None,
343 #[cfg(feature = "ohttp")]
344 ohttp_config: None,
345 http_client: None,
346 }
347 }
348
349 pub async fn build(self) -> Result<Client<V>, Error> {
352 let http_client = if let Some(http_client) = self.http_client {
353 http_client
354 } else {
355 default_http_client()?
356 };
357 let (leader_hpke_config, helper_hpke_config) = tokio::try_join!(
359 aggregator_hpke_config(
360 self.leader_hpke_config,
361 &self.parameters,
362 &Role::Leader,
363 &http_client
364 ),
365 aggregator_hpke_config(
366 self.helper_hpke_config,
367 &self.parameters,
368 &Role::Helper,
369 &http_client
370 ),
371 )?;
372
373 #[cfg(feature = "ohttp")]
374 let ohttp_config = if let Some(ohttp_config) = self.ohttp_config {
375 let key_configs = ohttp_key_configs(
376 self.parameters.http_request_retry_parameters.clone(),
377 &ohttp_config,
378 &http_client,
379 )
380 .await?;
381 Some((ohttp_config, key_configs))
382 } else {
383 None
384 };
385
386 Ok(Client {
387 #[cfg(feature = "ohttp")]
388 ohttp_config,
389 parameters: self.parameters,
390 vdaf: self.vdaf,
391 http_client,
392 leader_hpke_config,
393 helper_hpke_config,
394 })
395 }
396
397 #[deprecated(
405 note = "Use `ClientBuilder::with_leader_hpke_config`, `ClientBuilder::with_helper_hpke_config` and `ClientBuilder::build` instead"
406 )]
407 pub fn build_with_hpke_configs(
408 self,
409 leader_hpke_config: HpkeConfig,
410 helper_hpke_config: HpkeConfig,
411 ) -> Result<Client<V>, Error> {
412 let http_client = if let Some(http_client) = self.http_client {
413 http_client
414 } else {
415 default_http_client()?
416 };
417 Ok(Client {
418 parameters: self.parameters,
419 vdaf: self.vdaf,
420 #[cfg(feature = "ohttp")]
421 ohttp_config: None,
422 http_client,
423 leader_hpke_config,
424 helper_hpke_config,
425 })
426 }
427
428 pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
430 self.http_client = Some(http_client);
431 self
432 }
433
434 pub fn with_backoff(mut self, http_request_retry_parameters: ExponentialBackoff) -> Self {
436 self.parameters.http_request_retry_parameters = http_request_retry_parameters;
437 self
438 }
439
440 pub fn with_leader_hpke_config(mut self, hpke_config: HpkeConfig) -> Self {
443 self.leader_hpke_config = Some(hpke_config);
444 self
445 }
446
447 pub fn with_helper_hpke_config(mut self, hpke_config: HpkeConfig) -> Self {
450 self.helper_hpke_config = Some(hpke_config);
451 self
452 }
453
454 #[cfg(feature = "ohttp")]
489 #[cfg_attr(docsrs, doc(cfg(feature = "ohttp")))]
490 pub fn with_ohttp_config(mut self, ohttp_config: OhttpConfig) -> Self {
491 self.ohttp_config = Some(ohttp_config);
492 self
493 }
494}
495
/// A client that shards measurements with a VDAF, encrypts the input shares to the leader and
/// helper aggregators, and uploads the resulting reports.
#[derive(Clone, Debug)]
pub struct Client<V: vdaf::Client<16>> {
    /// Task, endpoint, time precision and retry parameters.
    parameters: ClientParameters,
    /// VDAF instance used to shard measurements.
    vdaf: V,
    /// OHTTP configuration together with the key configurations fetched for it, if uploads go
    /// through an OHTTP relay.
    #[cfg(feature = "ohttp")]
    ohttp_config: Option<(OhttpConfig, Vec<KeyConfig>)>,
    /// HTTP client used for all requests.
    http_client: reqwest::Client,
    /// HPKE configuration used to encrypt the leader's input share.
    leader_hpke_config: HpkeConfig,
    /// HPKE configuration used to encrypt the helper's input share.
    helper_hpke_config: HpkeConfig,
}
507
508impl<V: vdaf::Client<16>> Client<V> {
509 pub async fn new(
511 task_id: TaskId,
512 leader_aggregator_endpoint: Url,
513 helper_aggregator_endpoint: Url,
514 time_precision: Duration,
515 vdaf: V,
516 ) -> Result<Self, Error> {
517 ClientBuilder::new(
518 task_id,
519 leader_aggregator_endpoint,
520 helper_aggregator_endpoint,
521 time_precision,
522 vdaf,
523 )
524 .build()
525 .await
526 }
527
528 #[deprecated(
536 note = "Use `ClientBuilder::with_leader_hpke_config`, `ClientBuilder::with_helper_hpke_config` and `ClientBuilder::build` instead"
537 )]
538 pub fn with_hpke_configs(
539 task_id: TaskId,
540 leader_aggregator_endpoint: Url,
541 helper_aggregator_endpoint: Url,
542 time_precision: Duration,
543 vdaf: V,
544 leader_hpke_config: HpkeConfig,
545 helper_hpke_config: HpkeConfig,
546 ) -> Result<Self, Error> {
547 #[allow(deprecated)]
548 ClientBuilder::new(
549 task_id,
550 leader_aggregator_endpoint,
551 helper_aggregator_endpoint,
552 time_precision,
553 vdaf,
554 )
555 .build_with_hpke_configs(leader_hpke_config, helper_hpke_config)
556 }
557
558 pub fn builder(
561 task_id: TaskId,
562 leader_aggregator_endpoint: Url,
563 helper_aggregator_endpoint: Url,
564 time_precision: Duration,
565 vdaf: V,
566 ) -> ClientBuilder<V> {
567 ClientBuilder::new(
568 task_id,
569 leader_aggregator_endpoint,
570 helper_aggregator_endpoint,
571 time_precision,
572 vdaf,
573 )
574 }
575
576 fn prepare_report(&self, measurement: &V::Measurement, time: &Time) -> Result<Report, Error> {
579 let report_id: ReportId = random();
580 let (public_share, input_shares) = self.vdaf.shard(measurement, report_id.as_ref())?;
581 assert_eq!(input_shares.len(), 2); let time = time
584 .to_batch_interval_start(&self.parameters.time_precision)
585 .map_err(|_| Error::InvalidParameter("couldn't round time down to time_precision"))?;
586 let report_metadata = ReportMetadata::new(report_id, time);
587 let encoded_public_share = public_share.get_encoded()?;
588
589 let (leader_encrypted_input_share, helper_encrypted_input_share) = [
590 (&self.leader_hpke_config, &Role::Leader),
591 (&self.helper_hpke_config, &Role::Helper),
592 ]
593 .into_iter()
594 .zip(input_shares)
595 .map(|((hpke_config, receiver_role), input_share)| {
596 hpke::seal(
597 hpke_config,
598 &HpkeApplicationInfo::new(&Label::InputShare, &Role::Client, receiver_role),
599 &PlaintextInputShare::new(
600 Vec::new(), input_share.get_encoded()?,
602 )
603 .get_encoded()?,
604 &InputShareAad::new(
605 self.parameters.task_id,
606 report_metadata.clone(),
607 encoded_public_share.clone(),
608 )
609 .get_encoded()?,
610 )
611 .map_err(Error::Hpke)
612 })
613 .collect_tuple()
614 .expect("iterator to yield two items"); Ok(Report::new(
617 report_metadata,
618 encoded_public_share,
619 leader_encrypted_input_share?,
620 helper_encrypted_input_share?,
621 ))
622 }
623
624 #[tracing::instrument(skip(measurement), err)]
629 pub async fn upload(&self, measurement: &V::Measurement) -> Result<(), Error> {
630 self.upload_with_time(measurement, Clock::now(&RealClock::default()))
631 .await
632 }
633
634 #[tracing::instrument(skip(measurement), err)]
663 pub async fn upload_with_time<T>(
664 &self,
665 measurement: &V::Measurement,
666 time: T,
667 ) -> Result<(), Error>
668 where
669 T: TryInto<Time> + Debug,
670 Error: From<<T as TryInto<Time>>::Error>,
671 {
672 let report = self
673 .prepare_report(measurement, &time.try_into()?)?
674 .get_encoded()?;
675 let upload_endpoint = self
676 .parameters
677 .reports_resource_uri(&self.parameters.task_id)?;
678
679 #[cfg(feature = "ohttp")]
680 let upload_status = self.upload_with_ohttp(&upload_endpoint, &report).await?;
681 #[cfg(not(feature = "ohttp"))]
682 let upload_status = self.put_report(&upload_endpoint, &report).await?;
683
684 if !upload_status.is_success() {
685 return Err(Error::Http(Box::new(HttpErrorResponse::from(
686 upload_status,
687 ))));
688 }
689
690 Ok(())
691 }
692
693 async fn put_report(
694 &self,
695 upload_endpoint: &Url,
696 request_body: &[u8],
697 ) -> Result<StatusCode, Error> {
698 Ok(retry_http_request(
699 self.parameters.http_request_retry_parameters.clone(),
700 || async {
701 self.http_client
702 .put(upload_endpoint.clone())
703 .header(CONTENT_TYPE, Report::MEDIA_TYPE)
704 .body(request_body.to_vec())
705 .send()
706 .await
707 },
708 )
709 .await?
710 .status())
711 }
712
713 #[cfg(feature = "ohttp")]
716 #[tracing::instrument(skip(self, request_body), err)]
717 async fn upload_with_ohttp(
718 &self,
719 upload_endpoint: &Url,
720 request_body: &[u8],
721 ) -> Result<StatusCode, Error> {
722 let (ohttp_config, key_configs) =
723 if let Some((ohttp_config, key_configs)) = &self.ohttp_config {
724 (ohttp_config, key_configs)
725 } else {
726 return self.put_report(upload_endpoint, request_body).await;
727 };
728
729 let mut message = Message::request(
731 "PUT".into(),
732 upload_endpoint.scheme().into(),
733 upload_endpoint.authority().into(),
734 upload_endpoint.path().into(),
735 );
736 message.put_header(CONTENT_TYPE.as_str(), Report::MEDIA_TYPE);
737 message.write_content(request_body);
738
739 let mut request_buf = Vec::new();
741 message.write_bhttp(Mode::KnownLength, &mut request_buf)?;
742
743 let ohttp_request = key_configs
745 .iter()
746 .cloned()
747 .find_map(|mut key_config| ClientRequest::from_config(&mut key_config).ok())
748 .ok_or_else(|| Error::OhttpNoSupportedKeyConfigs(Box::new(key_configs.to_vec())))?;
749
750 let (encapsulated_request, ohttp_response) = ohttp_request.encapsulate(&request_buf)?;
751
752 let relay_response = retry_http_request(
753 self.parameters.http_request_retry_parameters.clone(),
754 || async {
755 self.http_client
756 .post(ohttp_config.relay.clone())
757 .header(CONTENT_TYPE, OHTTP_REQUEST_MEDIA_TYPE)
758 .header(ACCEPT, OHTTP_RESPONSE_MEDIA_TYPE)
759 .body(encapsulated_request.clone())
760 .send()
761 .await
762 },
763 )
764 .await?;
765
766 if !relay_response.status().is_success() {
769 return Err(Error::Http(Box::new(HttpErrorResponse::from(
770 relay_response.status(),
771 ))));
772 }
773
774 if relay_response
775 .headers()
776 .get(CONTENT_TYPE)
777 .map(HeaderValue::as_bytes)
778 != Some(OHTTP_RESPONSE_MEDIA_TYPE.as_bytes())
779 {
780 return Err(Error::UnexpectedServerResponse(
781 "content type wrong for OHTTP response",
782 ));
783 }
784
785 let decapsulated_response = ohttp_response.decapsulate(relay_response.body().as_ref())?;
786 let message = Message::read_bhttp(&mut Cursor::new(&decapsulated_response))?;
787 let status = if let ControlData::Response(status) = message.control() {
788 StatusCode::from_u16((*status).into()).map_err(|_| {
789 Error::UnexpectedServerResponse(
790 "status in decapsulated response is not valid HTTP status",
791 )
792 })?
793 } else {
794 return Err(Error::UnexpectedServerResponse(
795 "decapsulated response control data is not a response",
796 ));
797 };
798
799 Ok(status)
800 }
801}