1use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE, USER_AGENT};
2use reqwest::{Client, Method, Response};
3use serde::de::DeserializeOwned;
4use serde::Serialize;
5use std::time::Duration;
6use url::Url;
7
8use crate::constants::{
9 API_BASE_URL, API_PATH_PREFIX, API_SANDBOX_BASE_URL, DEFAULT_TIMEOUT_SECONDS, USER_AGENT as UA,
10};
11use crate::credentials::Credentials;
12use crate::error::{Error, Result};
13use crate::rest::{AccountsApi, ConvertApi, DataApi, FeesApi, FuturesApi, OrdersApi, PaymentMethodsApi, PerpetualsApi, PortfoliosApi, ProductsApi, PublicApi};
14use crate::jwt::generate_jwt;
15use crate::rate_limit::RateLimiter;
16
17#[derive(Debug, Clone)]
19pub struct RestClientBuilder {
20 credentials: Option<Credentials>,
21 sandbox: bool,
22 timeout: Duration,
23 rate_limiting: bool,
24}
25
26impl Default for RestClientBuilder {
27 fn default() -> Self {
28 Self::new()
29 }
30}
31
32impl RestClientBuilder {
33 pub fn new() -> Self {
35 Self {
36 credentials: None,
37 sandbox: false,
38 timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECONDS),
39 rate_limiting: false,
40 }
41 }
42
43 pub fn credentials(mut self, credentials: Credentials) -> Self {
47 self.credentials = Some(credentials);
48 self
49 }
50
51 pub fn sandbox(mut self, enabled: bool) -> Self {
55 self.sandbox = enabled;
56 self
57 }
58
59 pub fn timeout(mut self, timeout: Duration) -> Self {
63 self.timeout = timeout;
64 self
65 }
66
67 pub fn rate_limiting(mut self, enabled: bool) -> Self {
72 self.rate_limiting = enabled;
73 self
74 }
75
76 pub fn build(self) -> Result<RestClient> {
78 let base_url = if self.sandbox {
79 API_SANDBOX_BASE_URL
80 } else {
81 API_BASE_URL
82 };
83
84 let http_client = Client::builder()
85 .timeout(self.timeout)
86 .build()
87 .map_err(|e| Error::config(format!("Failed to create HTTP client: {}", e)))?;
88
89 let rate_limiter = if self.rate_limiting {
90 Some(RateLimiter::for_private_rest())
91 } else {
92 None
93 };
94
95 Ok(RestClient {
96 http_client,
97 base_url: base_url.to_string(),
98 credentials: self.credentials,
99 rate_limiter,
100 })
101 }
102}
103
104#[derive(Clone)]
106pub struct RestClient {
107 http_client: Client,
108 base_url: String,
109 credentials: Option<Credentials>,
110 rate_limiter: Option<RateLimiter>,
111}
112
113impl RestClient {
114 pub fn builder() -> RestClientBuilder {
116 RestClientBuilder::new()
117 }
118
119 pub fn has_credentials(&self) -> bool {
121 self.credentials.is_some()
122 }
123
124 pub fn accounts(&self) -> AccountsApi<'_> {
140 AccountsApi::new(self)
141 }
142
143 pub fn products(&self) -> ProductsApi<'_> {
159 ProductsApi::new(self)
160 }
161
162 pub fn public(&self) -> PublicApi<'_> {
179 PublicApi::new(self)
180 }
181
182 pub fn orders(&self) -> OrdersApi<'_> {
198 OrdersApi::new(self)
199 }
200
201 pub fn fees(&self) -> FeesApi<'_> {
218 FeesApi::new(self)
219 }
220
221 pub fn data(&self) -> DataApi<'_> {
238 DataApi::new(self)
239 }
240
241 pub fn payment_methods(&self) -> PaymentMethodsApi<'_> {
260 PaymentMethodsApi::new(self)
261 }
262
263 pub fn portfolios(&self) -> PortfoliosApi<'_> {
282 PortfoliosApi::new(self)
283 }
284
285 pub fn convert(&self) -> ConvertApi<'_> {
302 ConvertApi::new(self)
303 }
304
305 pub fn perpetuals(&self) -> PerpetualsApi<'_> {
321 PerpetualsApi::new(self)
322 }
323
324 pub fn futures(&self) -> FuturesApi<'_> {
340 FuturesApi::new(self)
341 }
342
343 pub fn base_url(&self) -> &str {
345 &self.base_url
346 }
347
348 fn build_url(&self, endpoint: &str) -> Result<Url> {
350 let path = format!("{}{}", API_PATH_PREFIX, endpoint);
351 let url_str = format!("{}{}", self.base_url, path);
352 Url::parse(&url_str).map_err(Error::Url)
353 }
354
355 fn build_auth_headers(&self, method: &str, path: &str) -> Result<HeaderMap> {
357 let mut headers = HeaderMap::new();
358
359 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
360 headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
361 headers.insert(
362 USER_AGENT,
363 HeaderValue::from_static(UA),
364 );
365
366 if let Some(ref credentials) = self.credentials {
367 let jwt = generate_jwt(credentials, method, path)?;
368 let auth_value = format!("Bearer {}", jwt);
369 headers.insert(
370 AUTHORIZATION,
371 HeaderValue::from_str(&auth_value)
372 .map_err(|e| Error::request(format!("Invalid auth header: {}", e)))?,
373 );
374 }
375
376 Ok(headers)
377 }
378
379 pub async fn get<T: DeserializeOwned>(&self, endpoint: &str) -> Result<T> {
381 self.request::<(), T>(Method::GET, endpoint, None).await
382 }
383
384 pub async fn get_with_query<Q: Serialize, T: DeserializeOwned>(
386 &self,
387 endpoint: &str,
388 query: &Q,
389 ) -> Result<T> {
390 self.request_with_query::<Q, (), T>(Method::GET, endpoint, Some(query), None)
391 .await
392 }
393
394 pub async fn post<B: Serialize, T: DeserializeOwned>(
396 &self,
397 endpoint: &str,
398 body: &B,
399 ) -> Result<T> {
400 self.request(Method::POST, endpoint, Some(body)).await
401 }
402
403 pub async fn put<B: Serialize, T: DeserializeOwned>(
405 &self,
406 endpoint: &str,
407 body: &B,
408 ) -> Result<T> {
409 self.request(Method::PUT, endpoint, Some(body)).await
410 }
411
412 pub async fn delete<T: DeserializeOwned>(&self, endpoint: &str) -> Result<T> {
414 self.request::<(), T>(Method::DELETE, endpoint, None).await
415 }
416
417 async fn request<B: Serialize, T: DeserializeOwned>(
419 &self,
420 method: Method,
421 endpoint: &str,
422 body: Option<&B>,
423 ) -> Result<T> {
424 self.request_with_query::<(), B, T>(method, endpoint, None, body)
425 .await
426 }
427
428 async fn request_with_query<Q: Serialize, B: Serialize, T: DeserializeOwned>(
430 &self,
431 method: Method,
432 endpoint: &str,
433 query: Option<&Q>,
434 body: Option<&B>,
435 ) -> Result<T> {
436 if let Some(ref limiter) = self.rate_limiter {
438 limiter.acquire().await;
439 }
440
441 let mut url = self.build_url(endpoint)?;
442
443 if let Some(q) = query {
445 let query_string = serde_urlencoded::to_string(q)
446 .map_err(|e| Error::request(format!("Failed to encode query: {}", e)))?;
447 if !query_string.is_empty() {
448 url.set_query(Some(&query_string));
449 }
450 }
451
452 let path = if let Some(q) = url.query() {
454 format!("{}?{}", url.path(), q)
455 } else {
456 url.path().to_string()
457 };
458
459 let headers = self.build_auth_headers(method.as_str(), &path)?;
460
461 let mut request = self.http_client.request(method, url).headers(headers);
462
463 if let Some(b) = body {
464 request = request.json(b);
465 }
466
467 let response = request
468 .send()
469 .await
470 .map_err(Error::Http)?;
471
472 self.handle_response(response).await
473 }
474
475 pub async fn public_get<T: DeserializeOwned>(&self, endpoint: &str) -> Result<T> {
477 self.public_request::<(), T>(Method::GET, endpoint, None)
478 .await
479 }
480
481 pub async fn public_get_with_query<Q: Serialize, T: DeserializeOwned>(
483 &self,
484 endpoint: &str,
485 query: &Q,
486 ) -> Result<T> {
487 self.public_request_with_query::<Q, (), T>(Method::GET, endpoint, Some(query), None)
488 .await
489 }
490
491 async fn public_request<B: Serialize, T: DeserializeOwned>(
493 &self,
494 method: Method,
495 endpoint: &str,
496 body: Option<&B>,
497 ) -> Result<T> {
498 self.public_request_with_query::<(), B, T>(method, endpoint, None, body)
499 .await
500 }
501
502 async fn public_request_with_query<Q: Serialize, B: Serialize, T: DeserializeOwned>(
504 &self,
505 method: Method,
506 endpoint: &str,
507 query: Option<&Q>,
508 body: Option<&B>,
509 ) -> Result<T> {
510 if let Some(ref limiter) = self.rate_limiter {
512 limiter.acquire().await;
513 }
514
515 let mut url = self.build_url(endpoint)?;
516
517 if let Some(q) = query {
518 let query_string = serde_urlencoded::to_string(q)
519 .map_err(|e| Error::request(format!("Failed to encode query: {}", e)))?;
520 if !query_string.is_empty() {
521 url.set_query(Some(&query_string));
522 }
523 }
524
525 let mut headers = HeaderMap::new();
526 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
527 headers.insert(ACCEPT, HeaderValue::from_static("application/json"));
528 headers.insert(USER_AGENT, HeaderValue::from_static(UA));
529
530 let mut request = self.http_client.request(method, url).headers(headers);
531
532 if let Some(b) = body {
533 request = request.json(b);
534 }
535
536 let response = request.send().await.map_err(Error::Http)?;
537
538 self.handle_response(response).await
539 }
540
541 async fn handle_response<T: DeserializeOwned>(&self, response: Response) -> Result<T> {
543 let status = response.status();
544
545 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
547 let retry_after = response
548 .headers()
549 .get("retry-after")
550 .and_then(|v| v.to_str().ok())
551 .and_then(|s| s.parse::<u64>().ok())
552 .map(Duration::from_secs);
553
554 return Err(Error::RateLimited { retry_after });
555 }
556
557 let body = response.text().await.map_err(Error::Http)?;
558
559 if !status.is_success() {
561 let message = serde_json::from_str::<serde_json::Value>(&body)
563 .ok()
564 .and_then(|v| {
565 v.get("message")
566 .or_else(|| v.get("error"))
567 .or_else(|| v.get("error_description"))
568 .and_then(|m| m.as_str())
569 .map(String::from)
570 })
571 .unwrap_or_else(|| format!("HTTP {} error", status.as_u16()));
572
573 return Err(Error::api(status.as_u16(), message, Some(body)));
574 }
575
576 serde_json::from_str(&body).map_err(|e| {
578 Error::parse(
579 format!("Failed to parse response: {}", e),
580 Some(body),
581 )
582 })
583 }
584}
585
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client with the given sandbox flag, panicking if the build fails.
    fn client_with_sandbox(enabled: bool) -> RestClient {
        RestClient::builder().sandbox(enabled).build().unwrap()
    }

    /// A fresh builder carries no credentials and targets production.
    #[test]
    fn test_builder_defaults() {
        let RestClientBuilder {
            credentials,
            sandbox,
            ..
        } = RestClientBuilder::new();

        assert!(credentials.is_none());
        assert!(!sandbox);
    }

    /// Enabling the sandbox selects the sandbox base URL.
    #[test]
    fn test_builder_sandbox() {
        let sandbox_client = client_with_sandbox(true);
        assert_eq!(sandbox_client.base_url(), API_SANDBOX_BASE_URL);
    }

    /// Disabling the sandbox selects the production base URL.
    #[test]
    fn test_builder_production() {
        let production_client = client_with_sandbox(false);
        assert_eq!(production_client.base_url(), API_BASE_URL);
    }

    /// Endpoints are appended to the base URL and the API path prefix.
    #[test]
    fn test_build_url() {
        let default_client = RestClient::builder().build().unwrap();
        let accounts_url = default_client.build_url("/accounts").unwrap();

        assert_eq!(
            accounts_url.as_str(),
            "https://api.coinbase.com/api/v3/brokerage/accounts"
        );
    }
}
618}