orb_billing/
client.rs

1// Copyright Materialize, Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License in the LICENSE file at the
6// root of this repository, or online at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use async_stream::try_stream;
17use futures_core::Stream;
18use reqwest::{Method, RequestBuilder, Url};
19use serde::de::DeserializeOwned;
20use serde::Deserialize;
21
22use crate::config::ListParams;
23use crate::error::ApiError;
24use crate::{ClientBuilder, ClientConfig, Error};
25
26pub mod customers;
27pub mod events;
28pub mod invoices;
29pub mod marketplaces;
30pub mod plans;
31pub mod subscriptions;
32pub mod taxes;
33
34/// An API client for Orb.
35///
36/// The API client is designed to be wrapped in an [`Arc`] and used from
37/// multiple threads simultaneously.
38///
39/// [`Arc`]: std::sync::Arc
40#[derive(Debug)]
41pub struct Client {
42    pub(crate) inner: reqwest::Client,
43    pub(crate) api_key: String,
44    pub(crate) endpoint: Url,
45}
46
47impl Client {
48    /// Creates a new `Client` from its required configuration parameters.
49    pub fn new(config: ClientConfig) -> Client {
50        ClientBuilder::default().build(config)
51    }
52
53    /// Creates a builder for a `Client` that allows for customization of
54    /// optional parameters.
55    pub fn builder() -> ClientBuilder {
56        ClientBuilder::default()
57    }
58
59    fn build_request<P>(&self, method: Method, path: P) -> RequestBuilder
60    where
61        P: IntoIterator,
62        P::Item: AsRef<str>,
63    {
64        let mut url = self.endpoint.clone();
65        url.path_segments_mut()
66            .expect("builder validated URL can be a base")
67            .extend(path);
68        self.inner.request(method, url).bearer_auth(&self.api_key)
69    }
70
71    async fn send_request<T>(&self, req: RequestBuilder) -> Result<T, Error>
72    where
73        T: DeserializeOwned,
74    {
75        #[derive(Deserialize)]
76        struct ErrorResponse {
77            title: String,
78            #[serde(default)]
79            detail: Option<String>,
80            #[serde(default)]
81            validation_errors: Vec<String>,
82        }
83
84        let res = req.send().await?;
85        let status_code = res.status();
86        if status_code.is_success() {
87            Ok(res.json().await?)
88        } else {
89            let res_body = res.text().await?;
90            match serde_json::from_str::<ErrorResponse>(&res_body) {
91                Ok(e) => Err(Error::Api(ApiError {
92                    status_code,
93                    title: e.title,
94                    detail: e.detail,
95                    validation_errors: e.validation_errors,
96                })),
97                Err(e) => {
98                    eprintln!("There's been an API error! {e:?} from {res_body:?}");
99                    Err(Error::Api(ApiError {
100                        status_code,
101                        title: "decoding failure".into(),
102                        detail: Some("unable to decode API response as JSON".into()),
103                        validation_errors: vec![],
104                    }))
105                }
106            }
107        }
108    }
109
110    fn stream_paginated_request<'a, T>(
111        &'a self,
112        params: &ListParams,
113        req: RequestBuilder,
114    ) -> impl Stream<Item = Result<T, Error>> + 'a
115    where
116        T: DeserializeOwned + 'a,
117    {
118        #[derive(Deserialize)]
119        struct Paginated<T> {
120            data: Vec<T>,
121            pagination_metadata: PaginationMetadata,
122        }
123
124        #[derive(Deserialize)]
125        struct PaginationMetadata {
126            next_cursor: Option<String>,
127        }
128
129        let req = req.query(&[("limit", params.page_size)]);
130        try_stream! {
131            let mut cursor = None;
132            loop {
133                let mut current_req = req.try_clone().expect("request is clonable");
134                if let Some(cursor) = cursor {
135                    current_req = current_req.query(&[("cursor", cursor)]);
136                }
137                let res: Paginated<T> = self.send_request(current_req).await?;
138                for datum in res.data {
139                    yield datum;
140                }
141                match res.pagination_metadata.next_cursor {
142                    None => break,
143                    Some(next_cursor) => cursor = Some(next_cursor),
144                }
145            }
146        }
147    }
148}