koios_sdk/client/
mod.rs

1//! Client implementation for the Koios API
2//!
3//! This module provides the main client implementation for interacting with the Koios API.
4//! It includes the core [`Client`] struct, along with the builder pattern for configuration
5//! via [`ClientBuilder`].
6//!
7//! # Examples
8//!
9//! Basic usage with default configuration (Mainnet):
10//!
11//! ```rust,no_run
12//! use koios_sdk::Client;
13//!
14//! #[tokio::main]
15//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
16//!     let client = Client::new()?;
17//!     let tip = client.get_tip().await?;
18//!     println!("Current block: {:?}", tip);
19//!     Ok(())
20//! }
21//! ```
22//!
23//! Using the builder pattern with network selection and custom configuration:
24//!
25//! ```rust,no_run
26//! use koios_sdk::{ClientBuilder, Network};
27//! use std::time::Duration;
28//!
29//! #[tokio::main]
30//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
31//!     let client = ClientBuilder::new()
32//!         .network(Network::Preprod)
33//!         .auth_token("your-jwt-token")
34//!         .timeout(Duration::from_secs(60))
35//!         .build()?;
36//!     
37//!     Ok(())
38//! }
39//! ```
40//!
41//! Custom base URL (overrides network setting):
42//!
43//! ```rust,no_run
44//! use koios_sdk::ClientBuilder;
45//!
46//! #[tokio::main]
47//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
48//!     let client = ClientBuilder::new()
49//!         .base_url("https://custom.api.com")
50//!         .build()?;
51//!     
52//!     Ok(())
53//! }
54//! ```
55
56mod builder;
57mod config;
58mod rate_limit;
59
60pub use builder::ClientBuilder;
61pub use config::{AuthConfig, Config};
62use rate_limit::RateLimiter;
63
64use crate::error::{Error, Result};
65use reqwest::{Client as ReqwestClient, RequestBuilder, StatusCode};
66use serde::{de::DeserializeOwned, Serialize};
67use std::sync::Arc;
68use tokio::sync::RwLock;
69
70/// Client for interacting with the Koios API
71#[derive(Debug, Clone)]
72pub struct Client {
73    http_client: ReqwestClient,
74    base_url: String,
75    auth: Arc<RwLock<Option<AuthConfig>>>,
76    rate_limiter: Arc<RateLimiter>,
77}
78
79impl Client {
80    /// Create a new client with default configuration
81    pub fn new() -> Result<Self> {
82        Self::builder().build()
83    }
84
85    /// Create a new client builder
86    pub fn builder() -> ClientBuilder {
87        ClientBuilder::new()
88    }
89
90    /// Create a new client with custom configuration
91    pub(crate) fn with_config(config: Config) -> Result<Self> {
92        let http_client = ReqwestClient::builder().timeout(config.timeout).build()?;
93
94        let base_url = config.base_url();
95
96        Ok(Self {
97            http_client,
98            base_url,
99            auth: Arc::new(RwLock::new(config.auth)),
100            rate_limiter: Arc::new(RateLimiter::new(config.rate_limit)),
101        })
102    }
103
104    /// Get the base URL of the client
105    pub fn base_url(&self) -> &str {
106        &self.base_url
107    }
108
109    /// Get the authentication token if set (async version)
110    pub async fn get_auth_token(&self) -> Option<String> {
111        let auth = self.auth.read().await;
112        auth.as_ref().and_then(|auth| {
113            if auth.is_valid() {
114                Some(auth.token.clone())
115            } else {
116                None
117            }
118        })
119    }
120
121    /// Get the authentication token if set (for compatibility)
122    pub fn auth_token(&self) -> Option<String> {
123        futures::executor::block_on(self.get_auth_token())
124    }
125
126    /// Set authentication token with optional expiry
127    pub async fn set_auth_token(&self, token: String) {
128        let mut auth = self.auth.write().await;
129        *auth = Some(AuthConfig::new(token));
130    }
131
132    /// Set authentication token with expiry
133    pub async fn set_auth_token_with_expiry(
134        &self,
135        token: String,
136        expiry: chrono::DateTime<chrono::Utc>,
137    ) {
138        let mut auth = self.auth.write().await;
139        *auth = Some(AuthConfig::with_expiry(token, expiry));
140    }
141
142    /// Clear authentication token
143    pub async fn clear_auth_token(&self) {
144        let mut auth = self.auth.write().await;
145        *auth = None;
146    }
147
148    /// Check if client has valid authentication
149    pub async fn has_valid_auth(&self) -> bool {
150        let auth = self.auth.read().await;
151        auth.as_ref().map_or(false, |auth| auth.is_valid())
152    }
153
154    /// Build request with authentication if available
155    async fn build_request(&self, request: RequestBuilder) -> RequestBuilder {
156        let auth = self.auth.read().await;
157        if let Some(auth) = auth.as_ref() {
158            if auth.is_valid() {
159                return request.header("Authorization", format!("Bearer {}", auth.token));
160            }
161        }
162        request
163    }
164
165    /// Make a GET request to the API
166    pub(crate) async fn get<T>(&self, endpoint: &str) -> Result<T>
167    where
168        T: DeserializeOwned,
169    {
170        // Apply rate limiting
171        self.rate_limiter
172            .check()
173            .await
174            .map_err(|e| Error::RateLimit(e.wait_time().as_secs()))?;
175
176        let url = format!("{}{}", self.base_url, endpoint);
177        let request = self.http_client.get(&url);
178        let request = self.build_request(request).await;
179
180        let response = request.send().await?;
181
182        match response.status() {
183            StatusCode::OK => Ok(response.json().await?),
184            status => {
185                let message = response
186                    .text()
187                    .await
188                    .unwrap_or_else(|_| "Unknown error".to_string());
189                Err(Error::Api {
190                    status: status.as_u16(),
191                    message,
192                })
193            }
194        }
195    }
196
197    /// Make a POST request to the API
198    pub(crate) async fn post<T, B>(&self, endpoint: &str, body: &B) -> Result<T>
199    where
200        T: DeserializeOwned,
201        B: Serialize,
202    {
203        // Apply rate limiting
204        self.rate_limiter
205            .check()
206            .await
207            .map_err(|e| Error::RateLimit(e.wait_time().as_secs()))?;
208
209        let url = format!("{}{}", self.base_url, endpoint);
210        let request = self.http_client.post(&url).json(body);
211        let request = self.build_request(request).await;
212
213        let response = request.send().await?;
214        let status = response.status();
215        let text = response.text().await?;
216
217        // Parse JSON-RPC error response
218        if !status.is_success() {
219            if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
220                if let Some(error) = json.get("error") {
221                    if let Some(message) = error.get("message").and_then(|m| m.as_str()) {
222                        return Err(Error::Api {
223                            status: status.as_u16(),
224                            message: message.to_string(),
225                        });
226                    }
227                }
228            }
229            return Err(Error::Api {
230                status: status.as_u16(),
231                message: text,
232            });
233        }
234
235        // Parse successful response
236        match serde_json::from_str(&text) {
237            Ok(value) => Ok(value),
238            Err(e) => Err(Error::Json(e)),
239        }
240    }
241
242    /// Make a POST request with CBOR data to the API
243    pub(crate) async fn post_cbor<T>(&self, endpoint: &str, data: &[u8]) -> Result<T>
244    where
245        T: DeserializeOwned,
246    {
247        // Apply rate limiting
248        self.rate_limiter
249            .check()
250            .await
251            .map_err(|e| Error::RateLimit(e.wait_time().as_secs()))?;
252
253        let url = format!("{}{}", self.base_url, endpoint);
254        let request = self
255            .http_client
256            .post(&url)
257            .header("Content-Type", "application/cbor")
258            .body(data.to_vec());
259
260        let request = self.build_request(request).await;
261
262        let response = request.send().await?;
263
264        match response.status() {
265            StatusCode::OK | StatusCode::ACCEPTED => Ok(response.json().await?),
266            status => {
267                let message = response
268                    .text()
269                    .await
270                    .unwrap_or_else(|_| "Unknown error".to_string());
271                Err(Error::Api {
272                    status: status.as_u16(),
273                    message,
274                })
275            }
276        }
277    }
278
279    /// Check if response indicates a rate limit error
280    fn is_rate_limit_error(status: StatusCode, _text: &str) -> Option<u64> {
281        if status == StatusCode::TOO_MANY_REQUESTS {
282            // Try to parse retry-after header or default to 60 seconds
283            return Some(60);
284        }
285        None
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Start a mock server and build a client whose base URL points at it.
    ///
    /// The server is returned alongside the client so it stays alive for the whole test.
    async fn client_with_server() -> (MockServer, Client) {
        let server = MockServer::start().await;
        let client = Client::builder().base_url(server.uri()).build().unwrap();
        (server, client)
    }

    /// Mount a `GET /test` mock that answers with `status` and the plain-text `body`.
    async fn mount_text_get(server: &MockServer, status: u16, body: &str) {
        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(ResponseTemplate::new(status).set_body_string(body))
            .mount(server)
            .await;
    }

    #[tokio::test]
    async fn test_get_request() {
        let (server, client) = client_with_server().await;
        let expected = json!({ "data": "test" });

        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&expected))
            .mount(&server)
            .await;

        let actual: serde_json::Value = client.get("/test").await.unwrap();
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn test_post_request() {
        let (server, client) = client_with_server().await;
        let payload = json!({ "test": "data" });
        let expected = json!({ "result": "success" });

        Mock::given(method("POST"))
            .and(path("/test"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&expected))
            .mount(&server)
            .await;

        let actual: serde_json::Value = client.post("/test", &payload).await.unwrap();
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn test_auth_token() {
        let (server, client) = client_with_server().await;

        client.set_auth_token("test-token".to_string()).await;
        assert!(client.has_valid_auth().await);
        assert_eq!(client.get_auth_token().await.as_deref(), Some("test-token"));

        let expected = json!({ "data": "test" });

        // Only matches when the bearer token is actually attached to the request.
        Mock::given(method("GET"))
            .and(path("/test"))
            .and(header("Authorization", "Bearer test-token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&expected))
            .mount(&server)
            .await;

        let actual: serde_json::Value = client.get("/test").await.unwrap();
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn test_error_handling() {
        let (server, client) = client_with_server().await;
        mount_text_get(&server, 404, "Not Found").await;

        let Error::Api { status, message } =
            client.get::<serde_json::Value>("/test").await.unwrap_err()
        else {
            panic!("Expected API error");
        };
        assert_eq!(status, 404);
        assert_eq!(message, "Not Found");
    }

    #[tokio::test]
    async fn test_rate_limit() {
        let (server, client) = client_with_server().await;
        mount_text_get(&server, 429, "Too Many Requests").await;

        // A server-side 429 is surfaced as a plain API error carrying the body text.
        let Error::Api { status, message } =
            client.get::<serde_json::Value>("/test").await.unwrap_err()
        else {
            panic!("Expected rate limit error");
        };
        assert_eq!(status, 429);
        assert_eq!(message, "Too Many Requests");
    }
}