1use crate::{
2 auth::AuthConfig,
3 error::{ApiErrorResponse, RainyError, Result},
4 models::*,
5 retry::{retry_with_backoff, RetryConfig},
6};
7use reqwest::{
8 header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT},
9 Client, Response,
10};
11use secrecy::ExposeSecret;
12use std::time::Instant;
13
14#[cfg(feature = "rate-limiting")]
15use governor::{
16 clock::DefaultClock,
17 state::{InMemoryState, NotKeyed},
18 Quota, RateLimiter,
19};
20
21pub struct RainyClient {
46 client: Client,
48 auth_config: AuthConfig,
50 retry_config: RetryConfig,
52
53 #[cfg(feature = "rate-limiting")]
56 rate_limiter: Option<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
57}
58
59impl RainyClient {
60 pub fn with_api_key(api_key: impl Into<String>) -> Result<Self> {
73 let auth_config = AuthConfig::new(api_key);
74 Self::with_config(auth_config)
75 }
76
77 pub fn with_config(auth_config: AuthConfig) -> Result<Self> {
89 auth_config.validate()?;
91
92 let mut headers = HeaderMap::new();
94 headers.insert(
95 AUTHORIZATION,
96 HeaderValue::from_str(&format!("Bearer {}", auth_config.api_key.expose_secret()))
97 .map_err(|e| RainyError::Authentication {
98 code: "INVALID_API_KEY".to_string(),
99 message: format!("Invalid API key format: {}", e),
100 retryable: false,
101 })?,
102 );
103 headers.insert(
104 USER_AGENT,
105 HeaderValue::from_str(&auth_config.user_agent).map_err(|e| RainyError::Network {
106 message: format!("Invalid user agent: {}", e),
107 retryable: false,
108 source_error: None,
109 })?,
110 );
111
112 let client = Client::builder()
113 .use_rustls_tls()
114 .min_tls_version(reqwest::tls::Version::TLS_1_2)
115 .https_only(true)
116 .timeout(auth_config.timeout())
117 .default_headers(headers)
118 .build()
119 .map_err(|e| RainyError::Network {
120 message: format!("Failed to create HTTP client: {}", e),
121 retryable: false,
122 source_error: Some(e.to_string()),
123 })?;
124
125 let retry_config = RetryConfig::new(auth_config.max_retries);
126
127 #[cfg(feature = "rate-limiting")]
128 let rate_limiter = Some(RateLimiter::direct(Quota::per_second(
129 std::num::NonZeroU32::new(10).unwrap(),
130 )));
131
132 Ok(Self {
133 client,
134 auth_config,
135 retry_config,
136 #[cfg(feature = "rate-limiting")]
137 rate_limiter,
138 })
139 }
140
141 pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
153 self.retry_config = retry_config;
154 self
155 }
156
157 pub async fn get_available_models(&self) -> Result<AvailableModels> {
163 let url = format!("{}/api/v1/models", self.auth_config.base_url);
164
165 let operation = || async {
166 let response = self.client.get(&url).send().await?;
167 self.handle_response(response).await
168 };
169
170 if self.auth_config.enable_retry {
171 retry_with_backoff(&self.retry_config, operation).await
172 } else {
173 operation().await
174 }
175 }
176
177 pub async fn chat_completion(
188 &self,
189 request: ChatCompletionRequest,
190 ) -> Result<(ChatCompletionResponse, RequestMetadata)> {
191 #[cfg(feature = "rate-limiting")]
192 if let Some(ref limiter) = self.rate_limiter {
193 limiter.until_ready().await;
194 }
195
196 let url = format!("{}/api/v1/chat/completions", self.auth_config.base_url);
197 let start_time = Instant::now();
198
199 let operation = || async {
200 let response = self.client.post(&url).json(&request).send().await?;
201
202 let metadata = self.extract_metadata(&response, start_time);
203 let chat_response: ChatCompletionResponse = self.handle_response(response).await?;
204
205 Ok((chat_response, metadata))
206 };
207
208 if self.auth_config.enable_retry {
209 retry_with_backoff(&self.retry_config, operation).await
210 } else {
211 operation().await
212 }
213 }
214
215 pub async fn simple_chat(
229 &self,
230 model: impl Into<String>,
231 prompt: impl Into<String>,
232 ) -> Result<String> {
233 let request = ChatCompletionRequest::new(model, vec![ChatMessage::user(prompt)]);
234
235 let (response, _) = self.chat_completion(request).await?;
236
237 Ok(response
238 .choices
239 .into_iter()
240 .next()
241 .map(|choice| choice.message.content)
242 .unwrap_or_default())
243 }
244
245 pub(crate) async fn handle_response<T>(&self, response: Response) -> Result<T>
250 where
251 T: serde::de::DeserializeOwned,
252 {
253 let status = response.status();
254 let headers = response.headers().clone();
255 let request_id = headers
256 .get("x-request-id")
257 .and_then(|v| v.to_str().ok())
258 .map(String::from);
259
260 if status.is_success() {
261 let text = response.text().await?;
262 serde_json::from_str(&text).map_err(|e| RainyError::Serialization {
263 message: format!("Failed to parse response: {}", e),
264 source_error: Some(e.to_string()),
265 })
266 } else {
267 let text = response.text().await.unwrap_or_default();
268
269 if let Ok(error_response) = serde_json::from_str::<ApiErrorResponse>(&text) {
271 let error = error_response.error;
272 self.map_api_error(error, status.as_u16(), request_id)
273 } else {
274 Err(RainyError::Api {
276 code: status.canonical_reason().unwrap_or("UNKNOWN").to_string(),
277 message: if text.is_empty() {
278 format!("HTTP {}", status.as_u16())
279 } else {
280 text
281 },
282 status_code: status.as_u16(),
283 retryable: status.is_server_error(),
284 request_id,
285 })
286 }
287 }
288 }
289
290 fn extract_metadata(&self, response: &Response, start_time: Instant) -> RequestMetadata {
294 let headers = response.headers();
295
296 RequestMetadata {
297 response_time: Some(start_time.elapsed().as_millis() as u64),
298 provider: headers
299 .get("x-provider")
300 .and_then(|v| v.to_str().ok())
301 .map(String::from),
302 tokens_used: headers
303 .get("x-tokens-used")
304 .and_then(|v| v.to_str().ok())
305 .and_then(|s| s.parse().ok()),
306 credits_used: headers
307 .get("x-credits-used")
308 .and_then(|v| v.to_str().ok())
309 .and_then(|s| s.parse().ok()),
310 credits_remaining: headers
311 .get("x-credits-remaining")
312 .and_then(|v| v.to_str().ok())
313 .and_then(|s| s.parse().ok()),
314 request_id: headers
315 .get("x-request-id")
316 .and_then(|v| v.to_str().ok())
317 .map(String::from),
318 }
319 }
320
321 fn map_api_error<T>(
325 &self,
326 error: crate::error::ApiErrorDetails,
327 status_code: u16,
328 request_id: Option<String>,
329 ) -> Result<T> {
330 let retryable = error.retryable.unwrap_or(status_code >= 500);
331
332 let rainy_error = match error.code.as_str() {
333 "INVALID_API_KEY" | "EXPIRED_API_KEY" => RainyError::Authentication {
334 code: error.code,
335 message: error.message,
336 retryable: false,
337 },
338 "INSUFFICIENT_CREDITS" => {
339 let (current_credits, estimated_cost, reset_date) =
341 if let Some(details) = error.details {
342 let current = details
343 .get("current_credits")
344 .and_then(|v| v.as_f64())
345 .unwrap_or(0.0);
346 let cost = details
347 .get("estimated_cost")
348 .and_then(|v| v.as_f64())
349 .unwrap_or(0.0);
350 let reset = details
351 .get("reset_date")
352 .and_then(|v| v.as_str())
353 .map(String::from);
354 (current, cost, reset)
355 } else {
356 (0.0, 0.0, None)
357 };
358
359 RainyError::InsufficientCredits {
360 code: error.code,
361 message: error.message,
362 current_credits,
363 estimated_cost,
364 reset_date,
365 }
366 }
367 "RATE_LIMIT_EXCEEDED" => {
368 let retry_after = error
369 .details
370 .as_ref()
371 .and_then(|d| d.get("retry_after"))
372 .and_then(|v| v.as_u64());
373
374 RainyError::RateLimit {
375 code: error.code,
376 message: error.message,
377 retry_after,
378 current_usage: None,
379 }
380 }
381 "INVALID_REQUEST" | "MISSING_REQUIRED_FIELD" | "INVALID_MODEL" => {
382 RainyError::InvalidRequest {
383 code: error.code,
384 message: error.message,
385 details: error.details,
386 }
387 }
388 "PROVIDER_ERROR" | "PROVIDER_UNAVAILABLE" => {
389 let provider = error
390 .details
391 .as_ref()
392 .and_then(|d| d.get("provider"))
393 .and_then(|v| v.as_str())
394 .unwrap_or("unknown")
395 .to_string();
396
397 RainyError::Provider {
398 code: error.code,
399 message: error.message,
400 provider,
401 retryable,
402 }
403 }
404 _ => RainyError::Api {
405 code: error.code,
406 message: error.message,
407 status_code,
408 retryable,
409 request_id: request_id.clone(),
410 },
411 };
412
413 Err(rainy_error)
414 }
415
416 pub fn auth_config(&self) -> &AuthConfig {
418 &self.auth_config
419 }
420
421 pub fn base_url(&self) -> &str {
423 &self.auth_config.base_url
424 }
425
426 pub(crate) fn http_client(&self) -> &Client {
430 &self.client
431 }
432
433 pub async fn list_available_models(&self) -> Result<AvailableModels> {
458 let url = format!("{}/api/v1/models", self.auth_config.base_url);
459
460 let operation = || async {
461 let response = self.client.get(&url).send().await?;
462 self.handle_response(response).await
463 };
464
465 if self.auth_config.enable_retry {
466 retry_with_backoff(&self.retry_config, operation).await
467 } else {
468 operation().await
469 }
470 }
471
472 pub(crate) async fn make_request<T: serde::de::DeserializeOwned>(
478 &self,
479 method: reqwest::Method,
480 endpoint: &str,
481 body: Option<serde_json::Value>,
482 ) -> Result<T> {
483 #[cfg(feature = "rate-limiting")]
484 if let Some(ref limiter) = self.rate_limiter {
485 limiter.until_ready().await;
486 }
487
488 let url = format!("{}/api/v1{}", self.auth_config.base_url, endpoint);
489 let headers = self.auth_config.build_headers()?;
490
491 let mut request = self.client.request(method, &url).headers(headers);
492
493 if let Some(body) = body {
494 request = request.json(&body);
495 }
496
497 let response = request.send().await?;
498 self.handle_response(response).await
499 }
500}
501
502impl std::fmt::Debug for RainyClient {
503 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504 f.debug_struct("RainyClient")
505 .field("base_url", &self.auth_config.base_url)
506 .field("timeout", &self.auth_config.timeout_seconds)
507 .field("max_retries", &self.retry_config.max_retries)
508 .finish()
509 }
510}