1#![forbid(unsafe_code)]
2#![deny(
3 clippy::dbg_macro,
4 missing_copy_implementations,
5 missing_debug_implementations,
6 nonstandard_style
7)]
8#![warn(clippy::perf, clippy::cargo)]
9#![allow(clippy::cargo_common_metadata)]
10#![allow(clippy::multiple_crate_versions)]
11
12mod account;
13mod aggregator;
14mod api_token;
15mod collector_credentials;
16pub mod dp_strategy;
17mod membership;
18mod protocol;
19mod task;
20mod validation_errors;
21
/// Media type sent as the `Accept` header of every request, and as the `Content-Type`
/// header of requests that carry a JSON body.
pub const CONTENT_TYPE: &str = "application/vnd.divviup+json;version=0.1";
/// Base URL that [`DivviupClient::new`] resolves API paths against until it is replaced
/// with [`DivviupClient::with_url`] or [`DivviupClient::set_url`].
pub const DEFAULT_URL: &str = "https://api.divviup.org/";
/// Value of the `User-Agent` header: `divviup-client/<crate version>`.
pub const USER_AGENT: &str = concat!("divviup-client/", env!("CARGO_PKG_VERSION"));
25
26use base64::{engine::general_purpose::STANDARD, Engine};
27use serde::{de::DeserializeOwned, Serialize};
28use serde_json::json;
29use std::{fmt::Display, future::Future, pin::Pin};
30use time::format_description::well_known::Rfc3339;
31use trillium_http::{HeaderName, HeaderValues};
32
33pub use account::Account;
34pub use aggregator::{Aggregator, CollectorAuthenticationToken, NewAggregator, Role};
35pub use api_token::ApiToken;
36pub use collector_credentials::CollectorCredential;
37pub use janus_messages::{
38 codec::{CodecError, Decode, Encode},
39 HpkeConfig, HpkePublicKey,
40};
41pub use membership::Membership;
42pub use num_bigint::BigUint;
43pub use num_rational::Ratio;
44pub use protocol::Protocol;
45pub use task::{Histogram, NewTask, SumVec, Task, Vdaf};
46pub use time::OffsetDateTime;
47pub use trillium_client;
48pub use trillium_client::Client;
49pub use trillium_client::Conn;
50pub use trillium_http::{HeaderValue, Headers, KnownHeaderName, Method, Status};
51pub use url::Url;
52pub use uuid::Uuid;
53pub use validation_errors::ValidationErrors;
54
55#[cfg(feature = "admin")]
56pub use aggregator::NewSharedAggregator;
57
/// Converts only the error half of a `Result` through `From`, leaving `Ok` untouched.
///
/// Useful at the tail of a method chain, where `?` followed by re-wrapping in `Ok`
/// would be noisier than a single trailing call.
trait ErrInto<T, E1, E2> {
    /// Maps `Err(e)` to `Err(E2::from(e))`; `Ok(v)` passes through unchanged.
    fn err_into(self) -> Result<T, E2>;
}

impl<T, E1, E2> ErrInto<T, E1, E2> for Result<T, E1>
where
    E2: From<E1>,
{
    fn err_into(self) -> Result<T, E2> {
        match self {
            Ok(value) => Ok(value),
            Err(source) => Err(E2::from(source)),
        }
    }
}
69
70#[derive(Debug, Clone)]
71pub struct DivviupClient(Client);
72
73impl DivviupClient {
74 pub fn new(token: impl Display, http_client: impl Into<Client>) -> Self {
75 Self(
76 http_client
77 .into()
78 .with_default_header(KnownHeaderName::UserAgent, USER_AGENT)
79 .with_default_header(KnownHeaderName::Accept, CONTENT_TYPE)
80 .with_default_header(KnownHeaderName::Authorization, format!("Bearer {token}"))
81 .with_base(DEFAULT_URL),
82 )
83 }
84
85 pub fn with_default_pool(mut self) -> Self {
86 self.0 = self.0.with_default_pool();
87 self
88 }
89
90 pub fn with_header(
91 mut self,
92 name: impl Into<HeaderName<'static>>,
93 value: impl Into<HeaderValues>,
94 ) -> Self {
95 self.insert_header(name, value);
96 self
97 }
98
99 pub fn insert_header(
100 &mut self,
101 name: impl Into<HeaderName<'static>>,
102 value: impl Into<HeaderValues>,
103 ) {
104 self.headers_mut().insert(name, value);
105 }
106
107 pub fn headers(&self) -> &Headers {
108 self.0.default_headers()
109 }
110
111 pub fn headers_mut(&mut self) -> &mut Headers {
112 self.0.default_headers_mut()
113 }
114
115 pub fn with_url(mut self, url: Url) -> Self {
116 self.set_url(url);
117 self
118 }
119
120 pub fn set_url(&mut self, url: Url) {
121 self.0.set_base(url).unwrap();
122 }
123
124 async fn get<T>(&self, path: &str) -> ClientResult<T>
125 where
126 T: DeserializeOwned,
127 {
128 self.0
129 .get(path)
130 .success_or_error()
131 .await?
132 .response_json()
133 .await
134 .err_into()
135 }
136
137 async fn patch<T>(&self, path: &str, body: &impl Serialize) -> ClientResult<T>
138 where
139 T: DeserializeOwned,
140 {
141 self.0
142 .patch(path)
143 .with_json_body(body)?
144 .with_request_header(KnownHeaderName::ContentType, CONTENT_TYPE)
145 .success_or_error()
146 .await?
147 .response_json()
148 .await
149 .err_into()
150 }
151
152 async fn post<T>(&self, path: &str, body: Option<&impl Serialize>) -> ClientResult<T>
153 where
154 T: DeserializeOwned,
155 {
156 let mut conn = self.0.post(path);
157
158 if let Some(body) = body {
159 conn = conn
160 .with_json_body(body)?
161 .with_request_header(KnownHeaderName::ContentType, CONTENT_TYPE);
162 }
163
164 conn.success_or_error()
165 .await?
166 .response_json()
167 .await
168 .err_into()
169 }
170
171 async fn delete(&self, path: &str) -> ClientResult {
172 let _ = self.0.delete(path).success_or_error().await?;
173 Ok(())
174 }
175
176 pub async fn accounts(&self) -> ClientResult<Vec<Account>> {
177 self.get("api/accounts").await
178 }
179
180 pub async fn rename_account(&self, account_id: Uuid, new_name: &str) -> ClientResult<Account> {
181 self.patch(
182 &format!("api/accounts/{account_id}"),
183 &json!({ "name": new_name }),
184 )
185 .await
186 }
187
188 pub async fn aggregator(&self, aggregator_id: Uuid) -> ClientResult<Aggregator> {
189 self.get(&format!("api/aggregators/{aggregator_id}")).await
190 }
191
192 pub async fn aggregators(&self, account_id: Uuid) -> ClientResult<Vec<Aggregator>> {
193 self.get(&format!("api/accounts/{account_id}/aggregators"))
194 .await
195 }
196
197 pub async fn create_aggregator(
198 &self,
199 account_id: Uuid,
200 aggregator: NewAggregator,
201 ) -> ClientResult<Aggregator> {
202 self.post(
203 &format!("api/accounts/{account_id}/aggregators"),
204 Some(&aggregator),
205 )
206 .await
207 }
208
209 pub async fn rename_aggregator(
210 &self,
211 aggregator_id: Uuid,
212 new_name: &str,
213 ) -> ClientResult<Aggregator> {
214 self.patch(
215 &format!("api/aggregators/{aggregator_id}"),
216 &json!({ "name": new_name }),
217 )
218 .await
219 }
220
221 pub async fn rotate_aggregator_bearer_token(
222 &self,
223 aggregator_id: Uuid,
224 new_bearer_token: &str,
225 ) -> ClientResult<Aggregator> {
226 self.patch(
227 &format!("api/aggregators/{aggregator_id}"),
228 &json!({ "bearer_token": new_bearer_token }),
229 )
230 .await
231 }
232
233 pub async fn update_aggregator_configuration(
234 &self,
235 aggregator_id: Uuid,
236 ) -> ClientResult<Aggregator> {
237 self.patch(&format!("api/aggregators/{aggregator_id}"), &json!({}))
238 .await
239 }
240
241 pub async fn delete_aggregator(&self, aggregator_id: Uuid) -> ClientResult {
242 self.delete(&format!("api/aggregators/{aggregator_id}"))
243 .await
244 }
245
246 pub async fn memberships(&self, account_id: Uuid) -> ClientResult<Vec<Membership>> {
247 self.get(&format!("api/accounts/{account_id}/memberships"))
248 .await
249 }
250
251 pub async fn delete_membership(&self, membership_id: Uuid) -> ClientResult {
252 self.delete(&format!("api/memberships/{membership_id}"))
253 .await
254 }
255
256 pub async fn create_membership(
257 &self,
258 account_id: Uuid,
259 email: &str,
260 ) -> ClientResult<Membership> {
261 self.post(
262 &format!("api/accounts/{account_id}/memberships"),
263 Some(&json!({ "user_email": email })),
264 )
265 .await
266 }
267
268 pub async fn tasks(&self, account_id: Uuid) -> ClientResult<Vec<Task>> {
269 self.get(&format!("api/accounts/{account_id}/tasks")).await
270 }
271
272 pub async fn task(&self, task_id: &str) -> ClientResult<Task> {
273 self.get(&format!("api/tasks/{task_id}")).await
274 }
275
276 pub async fn create_task(&self, account_id: Uuid, task: NewTask) -> ClientResult<Task> {
277 self.post(&format!("api/accounts/{account_id}/tasks"), Some(&task))
278 .await
279 }
280
281 pub async fn rename_task(&self, task_id: &str, new_name: &str) -> ClientResult<Task> {
282 self.patch(&format!("api/tasks/{task_id}"), &json!({"name": new_name}))
283 .await
284 }
285
286 pub async fn set_task_expiration(
287 &self,
288 task_id: &str,
289 expiration: Option<&OffsetDateTime>,
290 ) -> ClientResult<Task> {
291 self.patch(
292 &format!("api/tasks/{task_id}"),
293 &json!({
294 "expiration": expiration.map(|e| e.format(&Rfc3339)).transpose()?
295 }),
296 )
297 .await
298 }
299
300 pub async fn delete_task(&self, task_id: &str) -> ClientResult<()> {
301 self.delete(&format!("api/tasks/{task_id}")).await
302 }
303
304 pub async fn force_delete_task(&self, task_id: &str) -> ClientResult<()> {
305 self.delete(&format!("api/tasks/{task_id}?force=true"))
306 .await
307 }
308
309 pub async fn api_tokens(&self, account_id: Uuid) -> ClientResult<Vec<ApiToken>> {
310 self.get(&format!("api/accounts/{account_id}/api_tokens"))
311 .await
312 }
313
314 pub async fn create_api_token(&self, account_id: Uuid) -> ClientResult<ApiToken> {
315 self.post(
316 &format!("api/accounts/{account_id}/api_tokens"),
317 Option::<&()>::None,
318 )
319 .await
320 }
321
322 pub async fn delete_api_token(&self, api_token_id: Uuid) -> ClientResult {
323 self.delete(&format!("api/api_tokens/{api_token_id}")).await
324 }
325
326 pub async fn collector_credentials(
327 &self,
328 account_id: Uuid,
329 ) -> ClientResult<Vec<CollectorCredential>> {
330 self.get(&format!("api/accounts/{account_id}/collector_credentials"))
331 .await
332 }
333
334 pub async fn rename_collector_credential(
335 &self,
336 collector_credential_id: Uuid,
337 new_name: &str,
338 ) -> ClientResult<CollectorCredential> {
339 self.patch(
340 &format!("api/collector_credentials/{collector_credential_id}"),
341 &json!({"name": new_name}),
342 )
343 .await
344 }
345
346 pub async fn create_collector_credential(
347 &self,
348 account_id: Uuid,
349 hpke_config: &HpkeConfig,
350 name: Option<&str>,
351 ) -> ClientResult<CollectorCredential> {
352 self.post(
353 &format!("api/accounts/{account_id}/collector_credentials"),
354 Some(&json!({
355 "name": name,
356 "hpke_config": STANDARD.encode(hpke_config.get_encoded()?)
357 })),
358 )
359 .await
360 }
361
362 pub async fn delete_collector_credential(&self, collector_credential_id: Uuid) -> ClientResult {
363 self.delete(&format!(
364 "api/collector_credentials/{collector_credential_id}"
365 ))
366 .await
367 }
368
369 pub async fn shared_aggregators(&self) -> ClientResult<Vec<Aggregator>> {
370 self.get("api/aggregators").await
371 }
372}
373
/// Administrative endpoints, only compiled with the `admin` cargo feature.
#[cfg(feature = "admin")]
impl DivviupClient {
    /// Creates an account named `name` (`POST api/accounts` with `{"name": name}`).
    pub async fn create_account(&self, name: &str) -> ClientResult<Account> {
        let body = json!({ "name": name });
        self.post("api/accounts", Some(&body)).await
    }

    /// Creates a shared aggregator (`POST api/aggregators`).
    pub async fn create_shared_aggregator(
        &self,
        aggregator: NewSharedAggregator,
    ) -> ClientResult<Aggregator> {
        let body = Some(&aggregator);
        self.post("api/aggregators", body).await
    }
}
388
/// Result type returned by [`DivviupClient`] methods.
///
/// `T` defaults to `()`, which is what the `delete_*` methods return on success.
pub type ClientResult<T = ()> = Result<T, Error>;
390
/// Errors produced by [`DivviupClient`].
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// An error from `trillium_http`, e.g. while sending a request or reading a response body.
    #[error(transparent)]
    Http(#[from] trillium_http::Error),

    /// A `trillium_client` JSON error while serializing a request body or
    /// deserializing a response body.
    #[error(transparent)]
    Client(#[from] trillium_client::ClientSerdeError),

    /// A URL parse error.
    #[error(transparent)]
    Url(#[from] url::ParseError),

    /// A `serde_json` error.
    #[error(transparent)]
    Json(#[from] serde_json::Error),

    /// The server answered with a status that is not a success. Carries the request
    /// method and URL, the response status (if any) and the response body text.
    #[error("unexpected http status {method} {url} {status:?}: {body}")]
    HttpStatusNotSuccess {
        method: Method,
        url: Url,
        status: Option<Status>,
        body: String,
    },

    /// The server answered `400 Bad Request` with a body that parsed as validation errors.
    #[error("Validation errors:\n{0}")]
    ValidationErrors(ValidationErrors),

    /// A DAP codec error, e.g. while encoding an [`HpkeConfig`].
    #[error(transparent)]
    Codec(#[from] CodecError),

    /// A timestamp could not be formatted (used for RFC 3339 task expirations).
    #[error("time formatting error: {0}")]
    TimeFormat(#[from] time::error::Format),
}
422
/// Extension trait that turns non-success HTTP responses into [`Error`]s.
pub trait ClientConnExt: Sized {
    /// Resolves to `Ok(self)` when the response status is a success, and otherwise to an
    /// [`Error`] describing the failed response.
    fn success_or_error(self)
        -> Pin<Box<dyn Future<Output = ClientResult<Self>> + Send + 'static>>;
}
427impl ClientConnExt for Conn {
428 fn success_or_error(
429 self,
430 ) -> Pin<Box<dyn Future<Output = ClientResult<Self>> + Send + 'static>> {
431 Box::pin(async move {
432 let mut error = match self.await?.success() {
433 Ok(conn) => return Ok(conn),
434 Err(error) => error,
435 };
436
437 let status = error.status();
438 if let Some(Status::BadRequest) = status {
439 let body = error.response_body().read_string().await?;
440 log::trace!("{body}");
441 Err(Error::ValidationErrors(serde_json::from_str(&body)?))
442 } else {
443 let url = error.url().clone();
444 let method = error.method();
445 let body = error.response_body().await?;
446 Err(Error::HttpStatusNotSuccess {
447 method,
448 url,
449 status,
450 body,
451 })
452 }
453 })
454 }
455}