divviup_client/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(
3    clippy::dbg_macro,
4    missing_copy_implementations,
5    missing_debug_implementations,
6    nonstandard_style
7)]
8#![warn(clippy::perf, clippy::cargo)]
9#![allow(clippy::cargo_common_metadata)]
10#![allow(clippy::multiple_crate_versions)]
11
12mod account;
13mod aggregator;
14mod api_token;
15mod collector_credentials;
16pub mod dp_strategy;
17mod membership;
18mod protocol;
19mod task;
20mod validation_errors;
21
22pub const CONTENT_TYPE: &str = "application/vnd.divviup+json;version=0.1";
23pub const DEFAULT_URL: &str = "https://api.divviup.org/";
24pub const USER_AGENT: &str = concat!("divviup-client/", env!("CARGO_PKG_VERSION"));
25
26use base64::{engine::general_purpose::STANDARD, Engine};
27use serde::{de::DeserializeOwned, Serialize};
28use serde_json::json;
29use std::{fmt::Display, future::Future, pin::Pin};
30use time::format_description::well_known::Rfc3339;
31use trillium_http::{HeaderName, HeaderValues};
32
33pub use account::Account;
34pub use aggregator::{Aggregator, CollectorAuthenticationToken, NewAggregator, Role};
35pub use api_token::ApiToken;
36pub use collector_credentials::CollectorCredential;
37pub use janus_messages::{
38    codec::{CodecError, Decode, Encode},
39    HpkeConfig, HpkePublicKey,
40};
41pub use membership::Membership;
42pub use num_bigint::BigUint;
43pub use num_rational::Ratio;
44pub use protocol::Protocol;
45pub use task::{Histogram, NewTask, SumVec, Task, Vdaf};
46pub use time::OffsetDateTime;
47pub use trillium_client;
48pub use trillium_client::Client;
49pub use trillium_client::Conn;
50pub use trillium_http::{HeaderValue, Headers, KnownHeaderName, Method, Status};
51pub use url::Url;
52pub use uuid::Uuid;
53pub use validation_errors::ValidationErrors;
54
55#[cfg(feature = "admin")]
56pub use aggregator::NewSharedAggregator;
57
/// Method-form of `map_err(Into::into)`: converts the error side of a `Result` into a
/// different error type through its `From` impl, leaving `Ok` values untouched.
///
/// Handy at the tail of a builder chain, where the target error type is fixed by the
/// enclosing function's return type.
trait ErrInto<T, E1, E2> {
    /// Returns `Ok(v)` unchanged, and turns `Err(e)` into `Err(E2::from(e))`.
    fn err_into(self) -> Result<T, E2>;
}

impl<T, E1, E2> ErrInto<T, E1, E2> for Result<T, E1>
where
    E2: From<E1>,
{
    fn err_into(self) -> Result<T, E2> {
        match self {
            Ok(value) => Ok(value),
            Err(source) => Err(E2::from(source)),
        }
    }
}
69
70#[derive(Debug, Clone)]
71pub struct DivviupClient(Client);
72
73impl DivviupClient {
74    pub fn new(token: impl Display, http_client: impl Into<Client>) -> Self {
75        Self(
76            http_client
77                .into()
78                .with_default_header(KnownHeaderName::UserAgent, USER_AGENT)
79                .with_default_header(KnownHeaderName::Accept, CONTENT_TYPE)
80                .with_default_header(KnownHeaderName::Authorization, format!("Bearer {token}"))
81                .with_base(DEFAULT_URL),
82        )
83    }
84
85    pub fn with_default_pool(mut self) -> Self {
86        self.0 = self.0.with_default_pool();
87        self
88    }
89
90    pub fn with_header(
91        mut self,
92        name: impl Into<HeaderName<'static>>,
93        value: impl Into<HeaderValues>,
94    ) -> Self {
95        self.insert_header(name, value);
96        self
97    }
98
99    pub fn insert_header(
100        &mut self,
101        name: impl Into<HeaderName<'static>>,
102        value: impl Into<HeaderValues>,
103    ) {
104        self.headers_mut().insert(name, value);
105    }
106
107    pub fn headers(&self) -> &Headers {
108        self.0.default_headers()
109    }
110
111    pub fn headers_mut(&mut self) -> &mut Headers {
112        self.0.default_headers_mut()
113    }
114
115    pub fn with_url(mut self, url: Url) -> Self {
116        self.set_url(url);
117        self
118    }
119
120    pub fn set_url(&mut self, url: Url) {
121        self.0.set_base(url).unwrap();
122    }
123
124    async fn get<T>(&self, path: &str) -> ClientResult<T>
125    where
126        T: DeserializeOwned,
127    {
128        self.0
129            .get(path)
130            .success_or_error()
131            .await?
132            .response_json()
133            .await
134            .err_into()
135    }
136
137    async fn patch<T>(&self, path: &str, body: &impl Serialize) -> ClientResult<T>
138    where
139        T: DeserializeOwned,
140    {
141        self.0
142            .patch(path)
143            .with_json_body(body)?
144            .with_request_header(KnownHeaderName::ContentType, CONTENT_TYPE)
145            .success_or_error()
146            .await?
147            .response_json()
148            .await
149            .err_into()
150    }
151
152    async fn post<T>(&self, path: &str, body: Option<&impl Serialize>) -> ClientResult<T>
153    where
154        T: DeserializeOwned,
155    {
156        let mut conn = self.0.post(path);
157
158        if let Some(body) = body {
159            conn = conn
160                .with_json_body(body)?
161                .with_request_header(KnownHeaderName::ContentType, CONTENT_TYPE);
162        }
163
164        conn.success_or_error()
165            .await?
166            .response_json()
167            .await
168            .err_into()
169    }
170
171    async fn delete(&self, path: &str) -> ClientResult {
172        let _ = self.0.delete(path).success_or_error().await?;
173        Ok(())
174    }
175
176    pub async fn accounts(&self) -> ClientResult<Vec<Account>> {
177        self.get("api/accounts").await
178    }
179
180    pub async fn rename_account(&self, account_id: Uuid, new_name: &str) -> ClientResult<Account> {
181        self.patch(
182            &format!("api/accounts/{account_id}"),
183            &json!({ "name": new_name }),
184        )
185        .await
186    }
187
188    pub async fn aggregator(&self, aggregator_id: Uuid) -> ClientResult<Aggregator> {
189        self.get(&format!("api/aggregators/{aggregator_id}")).await
190    }
191
192    pub async fn aggregators(&self, account_id: Uuid) -> ClientResult<Vec<Aggregator>> {
193        self.get(&format!("api/accounts/{account_id}/aggregators"))
194            .await
195    }
196
197    pub async fn create_aggregator(
198        &self,
199        account_id: Uuid,
200        aggregator: NewAggregator,
201    ) -> ClientResult<Aggregator> {
202        self.post(
203            &format!("api/accounts/{account_id}/aggregators"),
204            Some(&aggregator),
205        )
206        .await
207    }
208
209    pub async fn rename_aggregator(
210        &self,
211        aggregator_id: Uuid,
212        new_name: &str,
213    ) -> ClientResult<Aggregator> {
214        self.patch(
215            &format!("api/aggregators/{aggregator_id}"),
216            &json!({ "name": new_name }),
217        )
218        .await
219    }
220
221    pub async fn rotate_aggregator_bearer_token(
222        &self,
223        aggregator_id: Uuid,
224        new_bearer_token: &str,
225    ) -> ClientResult<Aggregator> {
226        self.patch(
227            &format!("api/aggregators/{aggregator_id}"),
228            &json!({ "bearer_token": new_bearer_token }),
229        )
230        .await
231    }
232
233    pub async fn update_aggregator_configuration(
234        &self,
235        aggregator_id: Uuid,
236    ) -> ClientResult<Aggregator> {
237        self.patch(&format!("api/aggregators/{aggregator_id}"), &json!({}))
238            .await
239    }
240
241    pub async fn delete_aggregator(&self, aggregator_id: Uuid) -> ClientResult {
242        self.delete(&format!("api/aggregators/{aggregator_id}"))
243            .await
244    }
245
246    pub async fn memberships(&self, account_id: Uuid) -> ClientResult<Vec<Membership>> {
247        self.get(&format!("api/accounts/{account_id}/memberships"))
248            .await
249    }
250
251    pub async fn delete_membership(&self, membership_id: Uuid) -> ClientResult {
252        self.delete(&format!("api/memberships/{membership_id}"))
253            .await
254    }
255
256    pub async fn create_membership(
257        &self,
258        account_id: Uuid,
259        email: &str,
260    ) -> ClientResult<Membership> {
261        self.post(
262            &format!("api/accounts/{account_id}/memberships"),
263            Some(&json!({ "user_email": email })),
264        )
265        .await
266    }
267
268    pub async fn tasks(&self, account_id: Uuid) -> ClientResult<Vec<Task>> {
269        self.get(&format!("api/accounts/{account_id}/tasks")).await
270    }
271
272    pub async fn task(&self, task_id: &str) -> ClientResult<Task> {
273        self.get(&format!("api/tasks/{task_id}")).await
274    }
275
276    pub async fn create_task(&self, account_id: Uuid, task: NewTask) -> ClientResult<Task> {
277        self.post(&format!("api/accounts/{account_id}/tasks"), Some(&task))
278            .await
279    }
280
281    pub async fn task_collector_auth_tokens(
282        &self,
283        task_id: &str,
284    ) -> ClientResult<Vec<CollectorAuthenticationToken>> {
285        self.get(&format!("api/tasks/{task_id}/collector_auth_tokens"))
286            .await
287    }
288
289    pub async fn rename_task(&self, task_id: &str, new_name: &str) -> ClientResult<Task> {
290        self.patch(&format!("api/tasks/{task_id}"), &json!({"name": new_name}))
291            .await
292    }
293
294    pub async fn set_task_expiration(
295        &self,
296        task_id: &str,
297        expiration: Option<&OffsetDateTime>,
298    ) -> ClientResult<Task> {
299        self.patch(
300            &format!("api/tasks/{task_id}"),
301            &json!({
302                "expiration": expiration.map(|e| e.format(&Rfc3339)).transpose()?
303            }),
304        )
305        .await
306    }
307
308    pub async fn delete_task(&self, task_id: &str) -> ClientResult<()> {
309        self.delete(&format!("api/tasks/{task_id}")).await
310    }
311
312    pub async fn force_delete_task(&self, task_id: &str) -> ClientResult<()> {
313        self.delete(&format!("api/tasks/{task_id}?force=true"))
314            .await
315    }
316
317    pub async fn api_tokens(&self, account_id: Uuid) -> ClientResult<Vec<ApiToken>> {
318        self.get(&format!("api/accounts/{account_id}/api_tokens"))
319            .await
320    }
321
322    pub async fn create_api_token(&self, account_id: Uuid) -> ClientResult<ApiToken> {
323        self.post(
324            &format!("api/accounts/{account_id}/api_tokens"),
325            Option::<&()>::None,
326        )
327        .await
328    }
329
330    pub async fn delete_api_token(&self, api_token_id: Uuid) -> ClientResult {
331        self.delete(&format!("api/api_tokens/{api_token_id}")).await
332    }
333
334    pub async fn collector_credentials(
335        &self,
336        account_id: Uuid,
337    ) -> ClientResult<Vec<CollectorCredential>> {
338        self.get(&format!("api/accounts/{account_id}/collector_credentials"))
339            .await
340    }
341
342    pub async fn rename_collector_credential(
343        &self,
344        collector_credential_id: Uuid,
345        new_name: &str,
346    ) -> ClientResult<CollectorCredential> {
347        self.patch(
348            &format!("api/collector_credentials/{collector_credential_id}"),
349            &json!({"name": new_name}),
350        )
351        .await
352    }
353
354    pub async fn create_collector_credential(
355        &self,
356        account_id: Uuid,
357        hpke_config: &HpkeConfig,
358        name: Option<&str>,
359    ) -> ClientResult<CollectorCredential> {
360        self.post(
361            &format!("api/accounts/{account_id}/collector_credentials"),
362            Some(&json!({
363                "name": name,
364                "hpke_config": STANDARD.encode(hpke_config.get_encoded()?)
365            })),
366        )
367        .await
368    }
369
370    pub async fn delete_collector_credential(&self, collector_credential_id: Uuid) -> ClientResult {
371        self.delete(&format!(
372            "api/collector_credentials/{collector_credential_id}"
373        ))
374        .await
375    }
376
377    pub async fn shared_aggregators(&self) -> ClientResult<Vec<Aggregator>> {
378        self.get("api/aggregators").await
379    }
380}
381
/// Administrative endpoints, compiled only with the `admin` feature.
#[cfg(feature = "admin")]
impl DivviupClient {
    /// Creates an account named `name` (`POST api/accounts`).
    pub async fn create_account(&self, name: &str) -> ClientResult<Account> {
        let body = json!({ "name": name });
        self.post("api/accounts", Some(&body)).await
    }

    /// Registers a shared aggregator (`POST api/aggregators`).
    pub async fn create_shared_aggregator(
        &self,
        aggregator: NewSharedAggregator,
    ) -> ClientResult<Aggregator> {
        let body = Some(&aggregator);
        self.post("api/aggregators", body).await
    }
}
396
397pub type ClientResult<T = ()> = Result<T, Error>;
398
399#[derive(thiserror::Error, Debug)]
400pub enum Error {
401    #[error(transparent)]
402    Http(#[from] trillium_http::Error),
403
404    #[error(transparent)]
405    Client(#[from] trillium_client::ClientSerdeError),
406
407    #[error(transparent)]
408    Url(#[from] url::ParseError),
409
410    #[error(transparent)]
411    Json(#[from] serde_json::Error),
412
413    #[error("unexpected http status {method} {url} {status:?}: {body}")]
414    HttpStatusNotSuccess {
415        method: Method,
416        url: Url,
417        status: Option<Status>,
418        body: String,
419    },
420
421    #[error("Validation errors:\n{0}")]
422    ValidationErrors(ValidationErrors),
423
424    #[error(transparent)]
425    Codec(#[from] CodecError),
426
427    #[error("time formatting error: {0}")]
428    TimeFormat(#[from] time::error::Format),
429}
430
431pub trait ClientConnExt: Sized {
432    fn success_or_error(self)
433        -> Pin<Box<dyn Future<Output = ClientResult<Self>> + Send + 'static>>;
434}
435impl ClientConnExt for Conn {
436    fn success_or_error(
437        self,
438    ) -> Pin<Box<dyn Future<Output = ClientResult<Self>> + Send + 'static>> {
439        Box::pin(async move {
440            let mut error = match self.await?.success() {
441                Ok(conn) => return Ok(conn),
442                Err(error) => error,
443            };
444
445            let status = error.status();
446            if let Some(Status::BadRequest) = status {
447                let body = error.response_body().read_string().await?;
448                log::trace!("{body}");
449                Err(Error::ValidationErrors(serde_json::from_str(&body)?))
450            } else {
451                let url = error.url().clone();
452                let method = error.method();
453                let body = error.response_body().await?;
454                Err(Error::HttpStatusNotSuccess {
455                    method,
456                    url,
457                    status,
458                    body,
459                })
460            }
461        })
462    }
463}