divviup_client/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(
3    clippy::dbg_macro,
4    missing_copy_implementations,
5    missing_debug_implementations,
6    nonstandard_style
7)]
8#![warn(clippy::perf, clippy::cargo)]
9#![allow(clippy::cargo_common_metadata)]
10#![allow(clippy::multiple_crate_versions)]
11
12mod account;
13mod aggregator;
14mod api_token;
15mod collector_credentials;
16pub mod dp_strategy;
17mod membership;
18mod protocol;
19mod task;
20mod validation_errors;
21
22pub const CONTENT_TYPE: &str = "application/vnd.divviup+json;version=0.1";
23pub const DEFAULT_URL: &str = "https://api.divviup.org/";
24pub const USER_AGENT: &str = concat!("divviup-client/", env!("CARGO_PKG_VERSION"));
25
26use base64::{engine::general_purpose::STANDARD, Engine};
27use serde::{de::DeserializeOwned, Serialize};
28use serde_json::json;
29use std::{fmt::Display, future::Future, pin::Pin};
30use time::format_description::well_known::Rfc3339;
31use trillium_http::{HeaderName, HeaderValues};
32
33pub use account::Account;
34pub use aggregator::{Aggregator, CollectorAuthenticationToken, NewAggregator, Role};
35pub use api_token::ApiToken;
36pub use collector_credentials::CollectorCredential;
37pub use janus_messages::{
38    codec::{CodecError, Decode, Encode},
39    HpkeConfig, HpkePublicKey,
40};
41pub use membership::Membership;
42pub use num_bigint::BigUint;
43pub use num_rational::Ratio;
44pub use protocol::Protocol;
45pub use task::{Histogram, NewTask, SumVec, Task, Vdaf};
46pub use time::OffsetDateTime;
47pub use trillium_client;
48pub use trillium_client::Client;
49pub use trillium_client::Conn;
50pub use trillium_http::{HeaderValue, Headers, KnownHeaderName, Method, Status};
51pub use url::Url;
52pub use uuid::Uuid;
53pub use validation_errors::ValidationErrors;
54
55#[cfg(feature = "admin")]
56pub use aggregator::NewSharedAggregator;
57
/// Crate-private helper for converting only the error side of a [`Result`].
///
/// `T` is the success type, `E1` the original error type and `E2` the target
/// error type.
trait ErrInto<T, E1, E2> {
    /// Maps `Err(e)` to `Err(E2::from(e))` and passes `Ok` values through
    /// unchanged.
    fn err_into(self) -> Result<T, E2>;
}

impl<T, E1, E2> ErrInto<T, E1, E2> for Result<T, E1>
where
    E2: From<E1>,
{
    fn err_into(self) -> Result<T, E2> {
        match self {
            Ok(value) => Ok(value),
            Err(source) => Err(E2::from(source)),
        }
    }
}
69
70#[derive(Debug, Clone)]
71pub struct DivviupClient(Client);
72
73impl DivviupClient {
74    pub fn new(token: impl Display, http_client: impl Into<Client>) -> Self {
75        Self(
76            http_client
77                .into()
78                .with_default_header(KnownHeaderName::UserAgent, USER_AGENT)
79                .with_default_header(KnownHeaderName::Accept, CONTENT_TYPE)
80                .with_default_header(KnownHeaderName::Authorization, format!("Bearer {token}"))
81                .with_base(DEFAULT_URL),
82        )
83    }
84
85    pub fn with_default_pool(mut self) -> Self {
86        self.0 = self.0.with_default_pool();
87        self
88    }
89
90    pub fn with_header(
91        mut self,
92        name: impl Into<HeaderName<'static>>,
93        value: impl Into<HeaderValues>,
94    ) -> Self {
95        self.insert_header(name, value);
96        self
97    }
98
99    pub fn insert_header(
100        &mut self,
101        name: impl Into<HeaderName<'static>>,
102        value: impl Into<HeaderValues>,
103    ) {
104        self.headers_mut().insert(name, value);
105    }
106
107    pub fn headers(&self) -> &Headers {
108        self.0.default_headers()
109    }
110
111    pub fn headers_mut(&mut self) -> &mut Headers {
112        self.0.default_headers_mut()
113    }
114
115    pub fn with_url(mut self, url: Url) -> Self {
116        self.set_url(url);
117        self
118    }
119
120    pub fn set_url(&mut self, url: Url) {
121        self.0.set_base(url).unwrap();
122    }
123
124    async fn get<T>(&self, path: &str) -> ClientResult<T>
125    where
126        T: DeserializeOwned,
127    {
128        self.0
129            .get(path)
130            .success_or_error()
131            .await?
132            .response_json()
133            .await
134            .err_into()
135    }
136
137    async fn patch<T>(&self, path: &str, body: &impl Serialize) -> ClientResult<T>
138    where
139        T: DeserializeOwned,
140    {
141        self.0
142            .patch(path)
143            .with_json_body(body)?
144            .with_request_header(KnownHeaderName::ContentType, CONTENT_TYPE)
145            .success_or_error()
146            .await?
147            .response_json()
148            .await
149            .err_into()
150    }
151
152    async fn post<T>(&self, path: &str, body: Option<&impl Serialize>) -> ClientResult<T>
153    where
154        T: DeserializeOwned,
155    {
156        let mut conn = self.0.post(path);
157
158        if let Some(body) = body {
159            conn = conn
160                .with_json_body(body)?
161                .with_request_header(KnownHeaderName::ContentType, CONTENT_TYPE);
162        }
163
164        conn.success_or_error()
165            .await?
166            .response_json()
167            .await
168            .err_into()
169    }
170
171    async fn delete(&self, path: &str) -> ClientResult {
172        let _ = self.0.delete(path).success_or_error().await?;
173        Ok(())
174    }
175
176    pub async fn accounts(&self) -> ClientResult<Vec<Account>> {
177        self.get("api/accounts").await
178    }
179
180    pub async fn rename_account(&self, account_id: Uuid, new_name: &str) -> ClientResult<Account> {
181        self.patch(
182            &format!("api/accounts/{account_id}"),
183            &json!({ "name": new_name }),
184        )
185        .await
186    }
187
188    pub async fn aggregator(&self, aggregator_id: Uuid) -> ClientResult<Aggregator> {
189        self.get(&format!("api/aggregators/{aggregator_id}")).await
190    }
191
192    pub async fn aggregators(&self, account_id: Uuid) -> ClientResult<Vec<Aggregator>> {
193        self.get(&format!("api/accounts/{account_id}/aggregators"))
194            .await
195    }
196
197    pub async fn create_aggregator(
198        &self,
199        account_id: Uuid,
200        aggregator: NewAggregator,
201    ) -> ClientResult<Aggregator> {
202        self.post(
203            &format!("api/accounts/{account_id}/aggregators"),
204            Some(&aggregator),
205        )
206        .await
207    }
208
209    pub async fn rename_aggregator(
210        &self,
211        aggregator_id: Uuid,
212        new_name: &str,
213    ) -> ClientResult<Aggregator> {
214        self.patch(
215            &format!("api/aggregators/{aggregator_id}"),
216            &json!({ "name": new_name }),
217        )
218        .await
219    }
220
221    pub async fn rotate_aggregator_bearer_token(
222        &self,
223        aggregator_id: Uuid,
224        new_bearer_token: &str,
225    ) -> ClientResult<Aggregator> {
226        self.patch(
227            &format!("api/aggregators/{aggregator_id}"),
228            &json!({ "bearer_token": new_bearer_token }),
229        )
230        .await
231    }
232
233    pub async fn update_aggregator_configuration(
234        &self,
235        aggregator_id: Uuid,
236    ) -> ClientResult<Aggregator> {
237        self.patch(&format!("api/aggregators/{aggregator_id}"), &json!({}))
238            .await
239    }
240
241    pub async fn delete_aggregator(&self, aggregator_id: Uuid) -> ClientResult {
242        self.delete(&format!("api/aggregators/{aggregator_id}"))
243            .await
244    }
245
246    pub async fn memberships(&self, account_id: Uuid) -> ClientResult<Vec<Membership>> {
247        self.get(&format!("api/accounts/{account_id}/memberships"))
248            .await
249    }
250
251    pub async fn delete_membership(&self, membership_id: Uuid) -> ClientResult {
252        self.delete(&format!("api/memberships/{membership_id}"))
253            .await
254    }
255
256    pub async fn create_membership(
257        &self,
258        account_id: Uuid,
259        email: &str,
260    ) -> ClientResult<Membership> {
261        self.post(
262            &format!("api/accounts/{account_id}/memberships"),
263            Some(&json!({ "user_email": email })),
264        )
265        .await
266    }
267
268    pub async fn tasks(&self, account_id: Uuid) -> ClientResult<Vec<Task>> {
269        self.get(&format!("api/accounts/{account_id}/tasks")).await
270    }
271
272    pub async fn task(&self, task_id: &str) -> ClientResult<Task> {
273        self.get(&format!("api/tasks/{task_id}")).await
274    }
275
276    pub async fn create_task(&self, account_id: Uuid, task: NewTask) -> ClientResult<Task> {
277        self.post(&format!("api/accounts/{account_id}/tasks"), Some(&task))
278            .await
279    }
280
281    pub async fn rename_task(&self, task_id: &str, new_name: &str) -> ClientResult<Task> {
282        self.patch(&format!("api/tasks/{task_id}"), &json!({"name": new_name}))
283            .await
284    }
285
286    pub async fn set_task_expiration(
287        &self,
288        task_id: &str,
289        expiration: Option<&OffsetDateTime>,
290    ) -> ClientResult<Task> {
291        self.patch(
292            &format!("api/tasks/{task_id}"),
293            &json!({
294                "expiration": expiration.map(|e| e.format(&Rfc3339)).transpose()?
295            }),
296        )
297        .await
298    }
299
300    pub async fn delete_task(&self, task_id: &str) -> ClientResult<()> {
301        self.delete(&format!("api/tasks/{task_id}")).await
302    }
303
304    pub async fn force_delete_task(&self, task_id: &str) -> ClientResult<()> {
305        self.delete(&format!("api/tasks/{task_id}?force=true"))
306            .await
307    }
308
309    pub async fn api_tokens(&self, account_id: Uuid) -> ClientResult<Vec<ApiToken>> {
310        self.get(&format!("api/accounts/{account_id}/api_tokens"))
311            .await
312    }
313
314    pub async fn create_api_token(&self, account_id: Uuid) -> ClientResult<ApiToken> {
315        self.post(
316            &format!("api/accounts/{account_id}/api_tokens"),
317            Option::<&()>::None,
318        )
319        .await
320    }
321
322    pub async fn delete_api_token(&self, api_token_id: Uuid) -> ClientResult {
323        self.delete(&format!("api/api_tokens/{api_token_id}")).await
324    }
325
326    pub async fn collector_credentials(
327        &self,
328        account_id: Uuid,
329    ) -> ClientResult<Vec<CollectorCredential>> {
330        self.get(&format!("api/accounts/{account_id}/collector_credentials"))
331            .await
332    }
333
334    pub async fn rename_collector_credential(
335        &self,
336        collector_credential_id: Uuid,
337        new_name: &str,
338    ) -> ClientResult<CollectorCredential> {
339        self.patch(
340            &format!("api/collector_credentials/{collector_credential_id}"),
341            &json!({"name": new_name}),
342        )
343        .await
344    }
345
346    pub async fn create_collector_credential(
347        &self,
348        account_id: Uuid,
349        hpke_config: &HpkeConfig,
350        name: Option<&str>,
351    ) -> ClientResult<CollectorCredential> {
352        self.post(
353            &format!("api/accounts/{account_id}/collector_credentials"),
354            Some(&json!({
355                "name": name,
356                "hpke_config": STANDARD.encode(hpke_config.get_encoded()?)
357            })),
358        )
359        .await
360    }
361
362    pub async fn delete_collector_credential(&self, collector_credential_id: Uuid) -> ClientResult {
363        self.delete(&format!(
364            "api/collector_credentials/{collector_credential_id}"
365        ))
366        .await
367    }
368
369    pub async fn shared_aggregators(&self) -> ClientResult<Vec<Aggregator>> {
370        self.get("api/aggregators").await
371    }
372}
373
/// Methods compiled only when the `admin` feature is enabled.
#[cfg(feature = "admin")]
impl DivviupClient {
    /// POSTs to `api/accounts` with `name` set to `name`.
    pub async fn create_account(&self, name: &str) -> ClientResult<Account> {
        let payload = json!({ "name": name });
        self.post("api/accounts", Some(&payload)).await
    }

    /// POSTs `aggregator` to `api/aggregators`.
    pub async fn create_shared_aggregator(
        &self,
        aggregator: NewSharedAggregator,
    ) -> ClientResult<Aggregator> {
        let payload = &aggregator;
        self.post("api/aggregators", Some(payload)).await
    }
}
388
389pub type ClientResult<T = ()> = Result<T, Error>;
390
391#[derive(thiserror::Error, Debug)]
392pub enum Error {
393    #[error(transparent)]
394    Http(#[from] trillium_http::Error),
395
396    #[error(transparent)]
397    Client(#[from] trillium_client::ClientSerdeError),
398
399    #[error(transparent)]
400    Url(#[from] url::ParseError),
401
402    #[error(transparent)]
403    Json(#[from] serde_json::Error),
404
405    #[error("unexpected http status {method} {url} {status:?}: {body}")]
406    HttpStatusNotSuccess {
407        method: Method,
408        url: Url,
409        status: Option<Status>,
410        body: String,
411    },
412
413    #[error("Validation errors:\n{0}")]
414    ValidationErrors(ValidationErrors),
415
416    #[error(transparent)]
417    Codec(#[from] CodecError),
418
419    #[error("time formatting error: {0}")]
420    TimeFormat(#[from] time::error::Format),
421}
422
423pub trait ClientConnExt: Sized {
424    fn success_or_error(self)
425        -> Pin<Box<dyn Future<Output = ClientResult<Self>> + Send + 'static>>;
426}
427impl ClientConnExt for Conn {
428    fn success_or_error(
429        self,
430    ) -> Pin<Box<dyn Future<Output = ClientResult<Self>> + Send + 'static>> {
431        Box::pin(async move {
432            let mut error = match self.await?.success() {
433                Ok(conn) => return Ok(conn),
434                Err(error) => error,
435            };
436
437            let status = error.status();
438            if let Some(Status::BadRequest) = status {
439                let body = error.response_body().read_string().await?;
440                log::trace!("{body}");
441                Err(Error::ValidationErrors(serde_json::from_str(&body)?))
442            } else {
443                let url = error.url().clone();
444                let method = error.method();
445                let body = error.response_body().await?;
446                Err(Error::HttpStatusNotSuccess {
447                    method,
448                    url,
449                    status,
450                    body,
451                })
452            }
453        })
454    }
455}