lighter_rust/client/
api_client.rs1use crate::config::Config;
2use crate::error::{LighterError, Result};
3use reqwest::{Client, Method, Response};
4use serde::de::DeserializeOwned;
5use serde::Serialize;
6use std::time::Duration;
7use tokio::time::sleep;
8use tracing::{debug, error, warn};
9use url::Url;
10
11#[derive(Debug, Clone)]
12pub struct ApiClient {
13 client: Client,
14 config: Config,
15}
16
17impl ApiClient {
18 pub fn new(config: Config) -> Result<Self> {
19 let client = Client::builder()
20 .timeout(Duration::from_secs(config.timeout_secs))
21 .pool_max_idle_per_host(10) .pool_idle_timeout(Duration::from_secs(90)) .tcp_keepalive(Duration::from_secs(60)) .tcp_nodelay(true) .http2_prior_knowledge() .connection_verbose(false)
27 .build()
28 .map_err(|e| LighterError::Http(Box::new(e)))?;
29
30 Ok(Self { client, config })
31 }
32
33 pub async fn get<T>(&self, endpoint: &str) -> Result<T>
34 where
35 T: DeserializeOwned,
36 {
37 self.request(Method::GET, endpoint, None::<()>).await
38 }
39
40 pub async fn post<T, B>(&self, endpoint: &str, body: Option<B>) -> Result<T>
41 where
42 T: DeserializeOwned,
43 B: Serialize + Clone,
44 {
45 self.request(Method::POST, endpoint, body).await
46 }
47
48 pub async fn put<T, B>(&self, endpoint: &str, body: Option<B>) -> Result<T>
49 where
50 T: DeserializeOwned,
51 B: Serialize + Clone,
52 {
53 self.request(Method::PUT, endpoint, body).await
54 }
55
56 pub async fn delete<T>(&self, endpoint: &str) -> Result<T>
57 where
58 T: DeserializeOwned,
59 {
60 self.request(Method::DELETE, endpoint, None::<()>).await
61 }
62
63 async fn request<T, B>(&self, method: Method, endpoint: &str, body: Option<B>) -> Result<T>
64 where
65 T: DeserializeOwned,
66 B: Serialize + Clone,
67 {
68 let url = self.build_url(endpoint)?;
69 let mut retries = 0;
70 let max_retries = self.config.max_retries;
71
72 loop {
73 let mut request_builder = self.client.request(method.clone(), url.clone());
74
75 if let Some(api_key) = &self.config.api_key {
76 request_builder =
77 request_builder.header("Authorization", format!("Bearer {}", api_key));
78 }
79
80 request_builder = request_builder.header("Content-Type", "application/json");
81 request_builder = request_builder.header("User-Agent", "lighter-rust/0.1.0");
82
83 if let Some(ref body) = body {
84 request_builder = request_builder.json(body);
85 }
86
87 debug!("Sending {} request to {}", method, url);
88
89 match request_builder.send().await {
90 Ok(response) => {
91 let status = response.status();
92
93 if status.as_u16() == 429 || (status.is_server_error() && retries < max_retries)
95 {
96 retries += 1;
97 let delay = self.calculate_backoff_delay(retries);
98
99 warn!(
100 "Request failed with status {}. Retrying in {:?} (attempt {}/{})",
101 status, delay, retries, max_retries
102 );
103
104 sleep(delay).await;
105 continue;
106 }
107
108 return self.handle_response(response).await;
109 }
110 Err(e) if retries < max_retries => {
111 retries += 1;
112 let delay = self.calculate_backoff_delay(retries);
113
114 warn!(
115 "Request failed: {}. Retrying in {:?} (attempt {}/{})",
116 e, delay, retries, max_retries
117 );
118
119 sleep(delay).await;
120 continue;
121 }
122 Err(e) => {
123 error!("Request failed after {} retries: {}", max_retries, e);
124 return Err(LighterError::Http(Box::new(e)));
125 }
126 }
127 }
128 }
129
130 fn calculate_backoff_delay(&self, retry_count: u32) -> Duration {
131 let base_delay_ms = 100;
133 let max_delay_ms = 10000; let delay_ms = std::cmp::min(base_delay_ms * 2_u64.pow(retry_count - 1), max_delay_ms);
136
137 let jitter = (delay_ms as f64 * 0.25 * rand::random::<f64>()) as u64;
139 let final_delay = if rand::random::<bool>() {
140 delay_ms + jitter
141 } else {
142 delay_ms.saturating_sub(jitter)
143 };
144
145 Duration::from_millis(final_delay)
146 }
147
148 async fn handle_response<T>(&self, response: Response) -> Result<T>
149 where
150 T: DeserializeOwned,
151 {
152 let status = response.status();
153 let body = response
154 .text()
155 .await
156 .map_err(|e| LighterError::Http(Box::new(e)))?;
157
158 if status.is_success() {
159 serde_json::from_str(&body).map_err(LighterError::Json)
160 } else {
161 match status.as_u16() {
162 429 => Err(LighterError::RateLimit),
163 401 => Err(LighterError::Auth("Unauthorized".to_string())),
164 _ => {
165 let error_message = serde_json::from_str::<serde_json::Value>(&body)
166 .ok()
167 .and_then(|v| v.get("message").and_then(|m| m.as_str().map(String::from)))
168 .unwrap_or_else(|| body);
169
170 Err(LighterError::Api {
171 status: status.as_u16(),
172 message: error_message,
173 })
174 }
175 }
176 }
177 }
178
179 fn build_url(&self, endpoint: &str) -> Result<Url> {
180 let endpoint = endpoint.trim_start_matches('/');
181 self.config
182 .base_url
183 .join(endpoint)
184 .map_err(|e| LighterError::Config(format!("Invalid endpoint URL: {}", e)))
185 }
186}