Skip to main content

rust_ynab/ynab/
client.rs

1use crate::ynab::errors::{Error, ErrorResponse};
2use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
3use std::fmt;
4use std::num::NonZeroU32;
5use std::sync::Arc;
6use std::time::Duration;
7
/// Client is the YNAB API client. Use Client::new() to create one.
///
/// The builder-style methods on this type (`with_timeout`, `with_base_url`,
/// `with_rate_limiter`) each consume the client and return an updated one.
/// The `Debug` implementation prints a placeholder instead of the API key.
pub struct Client {
    /// Root URL that endpoint paths are appended to when building a request.
    base_url: reqwest::Url,
    /// Underlying `reqwest` client used to send every request.
    http_client: reqwest::Client,
    /// Optional rate limiter; when present it is awaited before each request.
    limiter: Option<Arc<DefaultDirectRateLimiter>>,
    /// Personal Access Token used to authenticate against the API.
    api_key: String,
    /// Request timeout most recently set through `with_timeout`, if any.
    timeout: Option<Duration>,
}
16
17impl fmt::Debug for Client {
18    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19        f.debug_struct("Client")
20            .field("base_url", &self.base_url)
21            .field("api_key", &"[redacted]")
22            .finish()
23    }
24}
25
26impl Client {
27    /// Creates a new Client with the given Personal Access Token.
28    pub fn new(api_key: impl Into<String>) -> Result<Self, Error> {
29        let api_key = api_key.into();
30        let http_client = Self::build_http_client(&api_key, None)?;
31        Ok(Self {
32            base_url: reqwest::Url::parse("https://api.ynab.com/v1").unwrap(),
33            http_client,
34            limiter: None,
35            api_key,
36            timeout: None,
37        })
38    }
39
40    fn build_http_client(
41        api_key: &str,
42        timeout: Option<Duration>,
43    ) -> Result<reqwest::Client, Error> {
44        let mut headers = reqwest::header::HeaderMap::new();
45        headers.insert(
46            reqwest::header::AUTHORIZATION,
47            format!("Bearer {}", api_key)
48                .parse()
49                .expect("api key must be valid ASCII"),
50        );
51        let mut builder = reqwest::Client::builder().default_headers(headers);
52        if let Some(t) = timeout {
53            builder = builder.timeout(t);
54        }
55        builder.build().map_err(Into::into)
56    }
57
58    /// Sets the request timeout. Rebuilds the underlying HTTP client.
59    pub fn with_timeout(mut self, timeout: Duration) -> Result<Self, Error> {
60        self.http_client = Self::build_http_client(&self.api_key, Some(timeout))?;
61        self.timeout = Some(timeout);
62        Ok(self)
63    }
64
65    /// Overrides the base URL. Primarily useful for testing.
66    pub fn with_base_url(mut self, base_url: impl AsRef<str>) -> Result<Self, Error> {
67        self.base_url = reqwest::Url::parse(base_url.as_ref())?;
68        Ok(self)
69    }
70
71    /// Configures a token bucket rate limiter on the client.
72    /// The YNAB API enforces a rolling window of 200 requests per hour.
73    ///
74    /// `requests_per_hour` is the total allowed requests per hour.
75    /// `burst_volume` optionally allows a number of requests to be made immediately
76    /// before throttling begins. The effective sustained rate becomes
77    /// `requests_per_hour - burst_volume` to account for burst consumption.
78    /// If `None`, no burst is allowed and the full rate is sustained evenly.
79    ///
80    /// Example: `with_rate_limiter(200, Some(10))` allows 10 immediate requests,
81    /// then throttles to 190 requests per hour.
82    pub fn with_rate_limiter(
83        mut self,
84        requests_per_hour: usize,
85        burst_volume: Option<usize>,
86    ) -> Result<Self, Error> {
87        let requests = NonZeroU32::new(requests_per_hour as u32)
88            .ok_or_else(|| Error::InvalidRateLimit("requests_per_hour must be non-zero".into()))?;
89
90        let quota = match burst_volume {
91            None => Quota::per_hour(requests),
92            Some(burst) => {
93                let effective = (requests_per_hour as u32)
94                    .checked_sub(burst as u32)
95                    .ok_or_else(|| {
96                        Error::InvalidRateLimit(
97                            "requests_per_hour must be greater than burst_volume".into(),
98                        )
99                    })?;
100                let effective_rate = NonZeroU32::new(effective).ok_or_else(|| {
101                    Error::InvalidRateLimit(
102                        "requests_per_hour - burst_volume must be non-zero".into(),
103                    )
104                })?;
105                let burst = NonZeroU32::new(burst as u32).ok_or_else(|| {
106                    Error::InvalidRateLimit("burst_volume must be non-zero".into())
107                })?;
108                Quota::per_hour(effective_rate).allow_burst(burst)
109            }
110        };
111
112        self.limiter = Some(Arc::new(RateLimiter::direct(quota)));
113        Ok(self)
114    }
115
116    pub(crate) async fn get<T: serde::de::DeserializeOwned, Q: serde::ser::Serialize + ?Sized>(
117        &self,
118        endpoint: &str,
119        params: Option<&Q>,
120    ) -> Result<T, Error> {
121        if let Some(limiter) = &self.limiter {
122            limiter.until_ready().await;
123        }
124
125        let mut url = self.base_url.clone();
126        url.path_segments_mut()
127            .expect("base URL must be a valid base")
128            .extend(endpoint.split('/'));
129
130        let mut builder = self.http_client.get(url);
131        if let Some(p) = params {
132            builder = builder.query(p);
133        }
134        let res = builder.send().await?;
135        let status = res.status();
136
137        if !status.is_success() {
138            let err_body: ErrorResponse = res.json().await?;
139            return Err(Error::new_api_error(status, err_body.error));
140        }
141
142        res.json().await.map_err(Into::into)
143    }
144
145    pub(crate) async fn post<T: serde::de::DeserializeOwned, B: serde::ser::Serialize>(
146        &self,
147        endpoint: &str,
148        body: B,
149    ) -> Result<T, Error> {
150        if let Some(limiter) = &self.limiter {
151            limiter.until_ready().await;
152        }
153        let mut url = self.base_url.clone();
154        url.path_segments_mut()
155            .expect("base URL must be a valid base")
156            .extend(endpoint.split('/'));
157
158        let res = self.http_client.post(url).json(&body).send().await?;
159        let status = res.status();
160
161        if !status.is_success() {
162            let err_body: ErrorResponse = res.json().await?;
163            return Err(Error::new_api_error(status, err_body.error));
164        }
165
166        res.json().await.map_err(Into::into)
167    }
168
169    pub(crate) async fn patch<T: serde::de::DeserializeOwned, B: serde::ser::Serialize>(
170        &self,
171        endpoint: &str,
172        body: B,
173    ) -> Result<T, Error> {
174        if let Some(limiter) = &self.limiter {
175            limiter.until_ready().await;
176        }
177        let mut url = self.base_url.clone();
178        url.path_segments_mut()
179            .expect("base URL must be a valid base")
180            .extend(endpoint.split('/'));
181
182        let res = self.http_client.patch(url).json(&body).send().await?;
183        let status = res.status();
184
185        if !status.is_success() {
186            let err_body: ErrorResponse = res.json().await?;
187            return Err(Error::new_api_error(status, err_body.error));
188        }
189
190        res.json().await.map_err(Into::into)
191    }
192
193    pub(crate) async fn put<T: serde::de::DeserializeOwned, B: serde::ser::Serialize>(
194        &self,
195        endpoint: &str,
196        body: B,
197    ) -> Result<T, Error> {
198        if let Some(limiter) = &self.limiter {
199            limiter.until_ready().await;
200        }
201        let mut url = self.base_url.clone();
202        url.path_segments_mut()
203            .expect("base URL must be a valid base")
204            .extend(endpoint.split('/'));
205
206        let res = self.http_client.put(url).json(&body).send().await?;
207        let status = res.status();
208
209        if !status.is_success() {
210            let err_body: ErrorResponse = res.json().await?;
211            return Err(Error::new_api_error(status, err_body.error));
212        }
213
214        res.json().await.map_err(Into::into)
215    }
216
217    pub(crate) async fn delete<T: serde::de::DeserializeOwned>(
218        &self,
219        endpoint: &str,
220    ) -> Result<T, Error> {
221        if let Some(limiter) = &self.limiter {
222            limiter.until_ready().await;
223        }
224        let mut url = self.base_url.clone();
225        url.path_segments_mut()
226            .expect("base URL must be a valid base")
227            .extend(endpoint.split('/'));
228
229        let res = self.http_client.delete(url).send().await?;
230        let status = res.status();
231
232        if !status.is_success() {
233            let err_body: ErrorResponse = res.json().await?;
234            return Err(Error::new_api_error(status, err_body.error));
235        }
236
237        res.json().await.map_err(Into::into)
238    }
239}