1#![forbid(unsafe_code)]
2#![deny(
3 clippy::dbg_macro,
4 missing_copy_implementations,
5 missing_debug_implementations,
6 nonstandard_style
7)]
8#![warn(clippy::perf, clippy::cargo)]
9#![allow(clippy::cargo_common_metadata)]
10#![allow(clippy::multiple_crate_versions)]
11
12mod account;
13mod aggregator;
14mod api_token;
15mod collector_credentials;
16pub mod dp_strategy;
17mod membership;
18mod protocol;
19mod task;
20mod validation_errors;
21
/// Media type sent as the default `Accept` header and as the `Content-Type` of
/// requests that carry a JSON body.
pub const CONTENT_TYPE: &str = "application/vnd.divviup+json;version=0.1";
/// Base URL installed by [`DivviupClient::new`]; can be replaced with
/// `DivviupClient::with_url` / `DivviupClient::set_url`.
pub const DEFAULT_URL: &str = "https://api.divviup.org/";
/// Default `User-Agent` header value: `divviup-client/` followed by this crate's
/// version as reported by Cargo at compile time.
pub const USER_AGENT: &str = concat!("divviup-client/", env!("CARGO_PKG_VERSION"));
25
26use base64::{engine::general_purpose::STANDARD, Engine};
27use serde::{de::DeserializeOwned, Serialize};
28use serde_json::json;
29use std::{fmt::Display, future::Future, pin::Pin};
30use time::format_description::well_known::Rfc3339;
31use trillium_http::{HeaderName, HeaderValues};
32
33pub use account::Account;
34pub use aggregator::{Aggregator, CollectorAuthenticationToken, NewAggregator, Role};
35pub use api_token::ApiToken;
36pub use collector_credentials::CollectorCredential;
37pub use janus_messages::{
38 codec::{CodecError, Decode, Encode},
39 HpkeConfig, HpkePublicKey,
40};
41pub use membership::Membership;
42pub use num_bigint::BigUint;
43pub use num_rational::Ratio;
44pub use protocol::Protocol;
45pub use task::{Histogram, NewTask, SumVec, Task, Vdaf};
46pub use time::OffsetDateTime;
47pub use trillium_client;
48pub use trillium_client::Client;
49pub use trillium_client::Conn;
50pub use trillium_http::{HeaderValue, Headers, KnownHeaderName, Method, Status};
51pub use url::Url;
52pub use uuid::Uuid;
53pub use validation_errors::ValidationErrors;
54
55#[cfg(feature = "admin")]
56pub use aggregator::NewSharedAggregator;
57
/// Converts only the error side of a `Result` through `From`, leaving `Ok`
/// values untouched.
///
/// `E1` is carried on the trait so the blanket impl below can name the source
/// error type; the target `E2` is normally chosen by type inference at the
/// call site.
trait ErrInto<T, E1, E2> {
    /// Maps `Err(e)` to `Err(E2::from(e))`; `Ok(v)` passes through unchanged.
    fn err_into(self) -> Result<T, E2>;
}

impl<T, E1, E2> ErrInto<T, E1, E2> for Result<T, E1>
where
    E2: From<E1>,
{
    fn err_into(self) -> Result<T, E2> {
        match self {
            Ok(value) => Ok(value),
            Err(source) => Err(E2::from(source)),
        }
    }
}
69
/// Client for the divviup API.
///
/// Wraps a [`trillium_client::Client`] whose default headers (user agent,
/// `Accept`, bearer `Authorization`) and base URL are set up by
/// [`DivviupClient::new`].
#[derive(Debug, Clone)]
pub struct DivviupClient(Client);
72
73impl DivviupClient {
74 pub fn new(token: impl Display, http_client: impl Into<Client>) -> Self {
75 Self(
76 http_client
77 .into()
78 .with_default_header(KnownHeaderName::UserAgent, USER_AGENT)
79 .with_default_header(KnownHeaderName::Accept, CONTENT_TYPE)
80 .with_default_header(KnownHeaderName::Authorization, format!("Bearer {token}"))
81 .with_base(DEFAULT_URL),
82 )
83 }
84
85 pub fn with_default_pool(mut self) -> Self {
86 self.0 = self.0.with_default_pool();
87 self
88 }
89
90 pub fn with_header(
91 mut self,
92 name: impl Into<HeaderName<'static>>,
93 value: impl Into<HeaderValues>,
94 ) -> Self {
95 self.insert_header(name, value);
96 self
97 }
98
99 pub fn insert_header(
100 &mut self,
101 name: impl Into<HeaderName<'static>>,
102 value: impl Into<HeaderValues>,
103 ) {
104 self.headers_mut().insert(name, value);
105 }
106
107 pub fn headers(&self) -> &Headers {
108 self.0.default_headers()
109 }
110
111 pub fn headers_mut(&mut self) -> &mut Headers {
112 self.0.default_headers_mut()
113 }
114
115 pub fn with_url(mut self, url: Url) -> Self {
116 self.set_url(url);
117 self
118 }
119
120 pub fn set_url(&mut self, url: Url) {
121 self.0.set_base(url).unwrap();
122 }
123
124 async fn get<T>(&self, path: &str) -> ClientResult<T>
125 where
126 T: DeserializeOwned,
127 {
128 self.0
129 .get(path)
130 .success_or_error()
131 .await?
132 .response_json()
133 .await
134 .err_into()
135 }
136
137 async fn patch<T>(&self, path: &str, body: &impl Serialize) -> ClientResult<T>
138 where
139 T: DeserializeOwned,
140 {
141 self.0
142 .patch(path)
143 .with_json_body(body)?
144 .with_request_header(KnownHeaderName::ContentType, CONTENT_TYPE)
145 .success_or_error()
146 .await?
147 .response_json()
148 .await
149 .err_into()
150 }
151
152 async fn post<T>(&self, path: &str, body: Option<&impl Serialize>) -> ClientResult<T>
153 where
154 T: DeserializeOwned,
155 {
156 let mut conn = self.0.post(path);
157
158 if let Some(body) = body {
159 conn = conn
160 .with_json_body(body)?
161 .with_request_header(KnownHeaderName::ContentType, CONTENT_TYPE);
162 }
163
164 conn.success_or_error()
165 .await?
166 .response_json()
167 .await
168 .err_into()
169 }
170
171 async fn delete(&self, path: &str) -> ClientResult {
172 let _ = self.0.delete(path).success_or_error().await?;
173 Ok(())
174 }
175
176 pub async fn accounts(&self) -> ClientResult<Vec<Account>> {
177 self.get("api/accounts").await
178 }
179
180 pub async fn rename_account(&self, account_id: Uuid, new_name: &str) -> ClientResult<Account> {
181 self.patch(
182 &format!("api/accounts/{account_id}"),
183 &json!({ "name": new_name }),
184 )
185 .await
186 }
187
188 pub async fn aggregator(&self, aggregator_id: Uuid) -> ClientResult<Aggregator> {
189 self.get(&format!("api/aggregators/{aggregator_id}")).await
190 }
191
192 pub async fn aggregators(&self, account_id: Uuid) -> ClientResult<Vec<Aggregator>> {
193 self.get(&format!("api/accounts/{account_id}/aggregators"))
194 .await
195 }
196
197 pub async fn create_aggregator(
198 &self,
199 account_id: Uuid,
200 aggregator: NewAggregator,
201 ) -> ClientResult<Aggregator> {
202 self.post(
203 &format!("api/accounts/{account_id}/aggregators"),
204 Some(&aggregator),
205 )
206 .await
207 }
208
209 pub async fn rename_aggregator(
210 &self,
211 aggregator_id: Uuid,
212 new_name: &str,
213 ) -> ClientResult<Aggregator> {
214 self.patch(
215 &format!("api/aggregators/{aggregator_id}"),
216 &json!({ "name": new_name }),
217 )
218 .await
219 }
220
221 pub async fn rotate_aggregator_bearer_token(
222 &self,
223 aggregator_id: Uuid,
224 new_bearer_token: &str,
225 ) -> ClientResult<Aggregator> {
226 self.patch(
227 &format!("api/aggregators/{aggregator_id}"),
228 &json!({ "bearer_token": new_bearer_token }),
229 )
230 .await
231 }
232
233 pub async fn update_aggregator_configuration(
234 &self,
235 aggregator_id: Uuid,
236 ) -> ClientResult<Aggregator> {
237 self.patch(&format!("api/aggregators/{aggregator_id}"), &json!({}))
238 .await
239 }
240
241 pub async fn delete_aggregator(&self, aggregator_id: Uuid) -> ClientResult {
242 self.delete(&format!("api/aggregators/{aggregator_id}"))
243 .await
244 }
245
246 pub async fn memberships(&self, account_id: Uuid) -> ClientResult<Vec<Membership>> {
247 self.get(&format!("api/accounts/{account_id}/memberships"))
248 .await
249 }
250
251 pub async fn delete_membership(&self, membership_id: Uuid) -> ClientResult {
252 self.delete(&format!("api/memberships/{membership_id}"))
253 .await
254 }
255
256 pub async fn create_membership(
257 &self,
258 account_id: Uuid,
259 email: &str,
260 ) -> ClientResult<Membership> {
261 self.post(
262 &format!("api/accounts/{account_id}/memberships"),
263 Some(&json!({ "user_email": email })),
264 )
265 .await
266 }
267
268 pub async fn tasks(&self, account_id: Uuid) -> ClientResult<Vec<Task>> {
269 self.get(&format!("api/accounts/{account_id}/tasks")).await
270 }
271
272 pub async fn task(&self, task_id: &str) -> ClientResult<Task> {
273 self.get(&format!("api/tasks/{task_id}")).await
274 }
275
276 pub async fn create_task(&self, account_id: Uuid, task: NewTask) -> ClientResult<Task> {
277 self.post(&format!("api/accounts/{account_id}/tasks"), Some(&task))
278 .await
279 }
280
281 pub async fn task_collector_auth_tokens(
282 &self,
283 task_id: &str,
284 ) -> ClientResult<Vec<CollectorAuthenticationToken>> {
285 self.get(&format!("api/tasks/{task_id}/collector_auth_tokens"))
286 .await
287 }
288
289 pub async fn rename_task(&self, task_id: &str, new_name: &str) -> ClientResult<Task> {
290 self.patch(&format!("api/tasks/{task_id}"), &json!({"name": new_name}))
291 .await
292 }
293
294 pub async fn set_task_expiration(
295 &self,
296 task_id: &str,
297 expiration: Option<&OffsetDateTime>,
298 ) -> ClientResult<Task> {
299 self.patch(
300 &format!("api/tasks/{task_id}"),
301 &json!({
302 "expiration": expiration.map(|e| e.format(&Rfc3339)).transpose()?
303 }),
304 )
305 .await
306 }
307
308 pub async fn delete_task(&self, task_id: &str) -> ClientResult<()> {
309 self.delete(&format!("api/tasks/{task_id}")).await
310 }
311
312 pub async fn force_delete_task(&self, task_id: &str) -> ClientResult<()> {
313 self.delete(&format!("api/tasks/{task_id}?force=true"))
314 .await
315 }
316
317 pub async fn api_tokens(&self, account_id: Uuid) -> ClientResult<Vec<ApiToken>> {
318 self.get(&format!("api/accounts/{account_id}/api_tokens"))
319 .await
320 }
321
322 pub async fn create_api_token(&self, account_id: Uuid) -> ClientResult<ApiToken> {
323 self.post(
324 &format!("api/accounts/{account_id}/api_tokens"),
325 Option::<&()>::None,
326 )
327 .await
328 }
329
330 pub async fn delete_api_token(&self, api_token_id: Uuid) -> ClientResult {
331 self.delete(&format!("api/api_tokens/{api_token_id}")).await
332 }
333
334 pub async fn collector_credentials(
335 &self,
336 account_id: Uuid,
337 ) -> ClientResult<Vec<CollectorCredential>> {
338 self.get(&format!("api/accounts/{account_id}/collector_credentials"))
339 .await
340 }
341
342 pub async fn rename_collector_credential(
343 &self,
344 collector_credential_id: Uuid,
345 new_name: &str,
346 ) -> ClientResult<CollectorCredential> {
347 self.patch(
348 &format!("api/collector_credentials/{collector_credential_id}"),
349 &json!({"name": new_name}),
350 )
351 .await
352 }
353
354 pub async fn create_collector_credential(
355 &self,
356 account_id: Uuid,
357 hpke_config: &HpkeConfig,
358 name: Option<&str>,
359 ) -> ClientResult<CollectorCredential> {
360 self.post(
361 &format!("api/accounts/{account_id}/collector_credentials"),
362 Some(&json!({
363 "name": name,
364 "hpke_config": STANDARD.encode(hpke_config.get_encoded()?)
365 })),
366 )
367 .await
368 }
369
370 pub async fn delete_collector_credential(&self, collector_credential_id: Uuid) -> ClientResult {
371 self.delete(&format!(
372 "api/collector_credentials/{collector_credential_id}"
373 ))
374 .await
375 }
376
377 pub async fn shared_aggregators(&self) -> ClientResult<Vec<Aggregator>> {
378 self.get("api/aggregators").await
379 }
380}
381
// Operations that are only compiled when the `admin` cargo feature is enabled.
#[cfg(feature = "admin")]
impl DivviupClient {
    /// Creates an account (`POST api/accounts` with `{"name": name}`).
    pub async fn create_account(&self, name: &str) -> ClientResult<Account> {
        let request_body = json!({ "name": name });
        self.post("api/accounts", Some(&request_body)).await
    }

    /// Creates a shared aggregator (`POST api/aggregators` with `aggregator` as
    /// the JSON body).
    pub async fn create_shared_aggregator(
        &self,
        aggregator: NewSharedAggregator,
    ) -> ClientResult<Aggregator> {
        let request_body = Some(&aggregator);
        self.post("api/aggregators", request_body).await
    }
}
396
/// Result type returned by [`DivviupClient`] operations; `T` defaults to `()`.
pub type ClientResult<T = ()> = Result<T, Error>;

/// Errors produced by [`DivviupClient`] operations.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Error from `trillium_http`, e.g. while sending a request or reading a
    /// response body.
    #[error(transparent)]
    Http(#[from] trillium_http::Error),

    /// Error from `trillium_client`'s serde helpers, e.g. when decoding a JSON
    /// response body.
    #[error(transparent)]
    Client(#[from] trillium_client::ClientSerdeError),

    /// A URL failed to parse.
    #[error(transparent)]
    Url(#[from] url::ParseError),

    /// A `serde_json` (de)serialization error.
    #[error(transparent)]
    Json(#[from] serde_json::Error),

    /// The server answered with a non-success status.
    ///
    /// Carries the request method and URL, the status (if one was received),
    /// and the response body read as a string.
    #[error("unexpected http status {method} {url} {status:?}: {body}")]
    HttpStatusNotSuccess {
        method: Method,
        url: Url,
        status: Option<Status>,
        body: String,
    },

    /// Validation errors parsed from a `400 Bad Request` response body.
    #[error("Validation errors:\n{0}")]
    ValidationErrors(ValidationErrors),

    /// Encoding a janus message (such as an [`HpkeConfig`]) failed.
    #[error(transparent)]
    Codec(#[from] CodecError),

    /// Formatting a timestamp (e.g. as RFC 3339) failed.
    #[error("time formatting error: {0}")]
    TimeFormat(#[from] time::error::Format),
}
430
/// Extension trait that turns non-success HTTP responses into [`Error`]s.
pub trait ClientConnExt: Sized {
    /// Resolves to the connection itself when the response status indicates
    /// success, and to an [`Error`] describing the failure otherwise.
    ///
    /// The returned future is boxed, `Send` and `'static`, so it can be awaited
    /// or spawned independently of the caller's borrows.
    fn success_or_error(self)
        -> Pin<Box<dyn Future<Output = ClientResult<Self>> + Send + 'static>>;
}
435impl ClientConnExt for Conn {
436 fn success_or_error(
437 self,
438 ) -> Pin<Box<dyn Future<Output = ClientResult<Self>> + Send + 'static>> {
439 Box::pin(async move {
440 let mut error = match self.await?.success() {
441 Ok(conn) => return Ok(conn),
442 Err(error) => error,
443 };
444
445 let status = error.status();
446 if let Some(Status::BadRequest) = status {
447 let body = error.response_body().read_string().await?;
448 log::trace!("{body}");
449 Err(Error::ValidationErrors(serde_json::from_str(&body)?))
450 } else {
451 let url = error.url().clone();
452 let method = error.method();
453 let body = error.response_body().await?;
454 Err(Error::HttpStatusNotSuccess {
455 method,
456 url,
457 status,
458 body,
459 })
460 }
461 })
462 }
463}