1use crate::{
2 auth::AuthConfig,
3 error::{ApiErrorResponse, RainyError, Result},
4 models::*,
5 retry::{retry_with_backoff, RetryConfig},
6};
7use eventsource_stream::Eventsource;
8use futures::{Stream, StreamExt};
9use reqwest::{
10 header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT},
11 Client, Response,
12};
13use secrecy::ExposeSecret;
14use std::pin::Pin;
15use std::time::Instant;
16
17#[cfg(feature = "rate-limiting")]
18use governor::{
19 clock::DefaultClock,
20 state::{InMemoryState, NotKeyed},
21 Quota, RateLimiter,
22};
23
/// Client for the HTTP API rooted at the configured base URL.
///
/// Bundles the underlying `reqwest` client with the authentication and retry
/// settings that the request methods on this type read.
pub struct RainyClient {
    /// Underlying HTTP client; `with_config` builds it with the authorization
    /// and user-agent headers installed as defaults.
    client: Client,
    /// Authentication and connection settings (base URL, API key, timeout,
    /// retry switch) supplied at construction.
    auth_config: AuthConfig,
    /// Settings passed to `retry_with_backoff` when retries are enabled.
    retry_config: RetryConfig,

    /// Request throttle awaited before chat and `make_request` calls.
    /// Only compiled in when the `rate-limiting` feature is enabled.
    #[cfg(feature = "rate-limiting")]
    rate_limiter: Option<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
}
61
62impl RainyClient {
63 pub fn with_api_key(api_key: impl Into<String>) -> Result<Self> {
76 let auth_config = AuthConfig::new(api_key);
77 Self::with_config(auth_config)
78 }
79
80 pub fn with_config(auth_config: AuthConfig) -> Result<Self> {
92 auth_config.validate()?;
94
95 let mut headers = HeaderMap::new();
97 headers.insert(
98 AUTHORIZATION,
99 HeaderValue::from_str(&format!("Bearer {}", auth_config.api_key.expose_secret()))
100 .map_err(|e| RainyError::Authentication {
101 code: "INVALID_API_KEY".to_string(),
102 message: format!("Invalid API key format: {}", e),
103 retryable: false,
104 })?,
105 );
106 headers.insert(
107 USER_AGENT,
108 HeaderValue::from_str(&auth_config.user_agent).map_err(|e| RainyError::Network {
109 message: format!("Invalid user agent: {}", e),
110 retryable: false,
111 source_error: None,
112 })?,
113 );
114
115 let client = Client::builder()
116 .use_rustls_tls()
117 .min_tls_version(reqwest::tls::Version::TLS_1_2)
118 .https_only(true)
119 .timeout(auth_config.timeout())
120 .default_headers(headers)
121 .build()
122 .map_err(|e| RainyError::Network {
123 message: format!("Failed to create HTTP client: {}", e),
124 retryable: false,
125 source_error: Some(e.to_string()),
126 })?;
127
128 let retry_config = RetryConfig::new(auth_config.max_retries);
129
130 #[cfg(feature = "rate-limiting")]
131 let rate_limiter = Some(RateLimiter::direct(Quota::per_second(
132 std::num::NonZeroU32::new(10).unwrap(),
133 )));
134
135 Ok(Self {
136 client,
137 auth_config,
138 retry_config,
139 #[cfg(feature = "rate-limiting")]
140 rate_limiter,
141 })
142 }
143
144 pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
156 self.retry_config = retry_config;
157 self
158 }
159
160 pub async fn get_available_models(&self) -> Result<AvailableModels> {
166 let url = format!("{}/api/v1/models", self.auth_config.base_url);
167
168 let operation = || async {
169 let response = self.client.get(&url).send().await?;
170 self.handle_response(response).await
171 };
172
173 if self.auth_config.enable_retry {
174 retry_with_backoff(&self.retry_config, operation).await
175 } else {
176 operation().await
177 }
178 }
179
180 pub async fn chat_completion(
191 &self,
192 request: ChatCompletionRequest,
193 ) -> Result<(ChatCompletionResponse, RequestMetadata)> {
194 #[cfg(feature = "rate-limiting")]
195 if let Some(ref limiter) = self.rate_limiter {
196 limiter.until_ready().await;
197 }
198
199 let url = format!("{}/api/v1/chat/completions", self.auth_config.base_url);
200 let start_time = Instant::now();
201
202 let operation = || async {
203 let response = self.client.post(&url).json(&request).send().await?;
204
205 let metadata = self.extract_metadata(&response, start_time);
206 let chat_response: ChatCompletionResponse = self.handle_response(response).await?;
207
208 Ok((chat_response, metadata))
209 };
210
211 if self.auth_config.enable_retry {
212 retry_with_backoff(&self.retry_config, operation).await
213 } else {
214 operation().await
215 }
216 }
217
218 pub async fn chat_completion_stream(
228 &self,
229 mut request: ChatCompletionRequest,
230 ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>> {
231 request.stream = Some(true);
233
234 #[cfg(feature = "rate-limiting")]
235 if let Some(ref limiter) = self.rate_limiter {
236 limiter.until_ready().await;
237 }
238
239 let url = format!("{}/api/v1/chat/completions", self.auth_config.base_url);
240
241 let operation = || async {
243 let response = self
244 .client
245 .post(&url)
246 .json(&request)
247 .send()
248 .await
249 .map_err(|e| RainyError::Network {
250 message: format!("Failed to send request: {}", e),
251 retryable: true,
252 source_error: Some(e.to_string()),
253 })?;
254
255 self.handle_stream_response(response).await
256 };
257
258 if self.auth_config.enable_retry {
259 retry_with_backoff(&self.retry_config, operation).await
260 } else {
261 operation().await
262 }
263 }
264
265 pub async fn simple_chat(
279 &self,
280 model: impl Into<String>,
281 prompt: impl Into<String>,
282 ) -> Result<String> {
283 let request = ChatCompletionRequest::new(model, vec![ChatMessage::user(prompt)]);
284
285 let (response, _) = self.chat_completion(request).await?;
286
287 Ok(response
288 .choices
289 .into_iter()
290 .next()
291 .map(|choice| choice.message.content)
292 .unwrap_or_default())
293 }
294
295 pub(crate) async fn handle_response<T>(&self, response: Response) -> Result<T>
300 where
301 T: serde::de::DeserializeOwned,
302 {
303 let status = response.status();
304 let headers = response.headers().clone();
305 let request_id = headers
306 .get("x-request-id")
307 .and_then(|v| v.to_str().ok())
308 .map(String::from);
309
310 if status.is_success() {
311 let text = response.text().await?;
312 serde_json::from_str(&text).map_err(|e| RainyError::Serialization {
313 message: format!("Failed to parse response: {}", e),
314 source_error: Some(e.to_string()),
315 })
316 } else {
317 let text = response.text().await.unwrap_or_default();
318
319 if let Ok(error_response) = serde_json::from_str::<ApiErrorResponse>(&text) {
321 let error = error_response.error;
322 self.map_api_error(error, status.as_u16(), request_id)
323 } else {
324 Err(RainyError::Api {
326 code: status.canonical_reason().unwrap_or("UNKNOWN").to_string(),
327 message: if text.is_empty() {
328 format!("HTTP {}", status.as_u16())
329 } else {
330 text
331 },
332 status_code: status.as_u16(),
333 retryable: status.is_server_error(),
334 request_id,
335 })
336 }
337 }
338 }
339
340 pub(crate) async fn handle_stream_response(
342 &self,
343 response: Response,
344 ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>> {
345 let status = response.status();
346 let request_id = response
347 .headers()
348 .get("x-request-id")
349 .and_then(|v| v.to_str().ok())
350 .map(String::from);
351
352 if status.is_success() {
353 let stream = response
354 .bytes_stream()
355 .eventsource()
356 .map(move |event| match event {
357 Ok(event) => {
358 if event.data == "[DONE]" {
359 return None;
360 }
361
362 match serde_json::from_str::<ChatCompletionChunk>(&event.data) {
363 Ok(chunk) => Some(Ok(chunk)),
364 Err(e) => Some(Err(RainyError::Serialization {
365 message: format!("Failed to parse stream chunk: {}", e),
366 source_error: Some(e.to_string()),
367 })),
368 }
369 }
370 Err(e) => Some(Err(RainyError::Network {
371 message: format!("Stream error: {}", e),
372 retryable: true,
373 source_error: Some(e.to_string()),
374 })),
375 })
376 .take_while(|x| futures::future::ready(x.is_some()))
377 .map(|x| x.unwrap());
378
379 Ok(Box::pin(stream))
380 } else {
381 let text = response.text().await.unwrap_or_default();
382
383 if let Ok(error_response) = serde_json::from_str::<ApiErrorResponse>(&text) {
385 let error = error_response.error;
386 self.map_api_error(error, status.as_u16(), request_id)
387 } else {
388 Err(RainyError::Api {
389 code: status.canonical_reason().unwrap_or("UNKNOWN").to_string(),
390 message: if text.is_empty() {
391 format!("HTTP {}", status.as_u16())
392 } else {
393 text
394 },
395 status_code: status.as_u16(),
396 retryable: status.is_server_error(),
397 request_id,
398 })
399 }
400 }
401 }
402
403 fn extract_metadata(&self, response: &Response, start_time: Instant) -> RequestMetadata {
407 let headers = response.headers();
408
409 RequestMetadata {
410 response_time: Some(start_time.elapsed().as_millis() as u64),
411 provider: headers
412 .get("x-provider")
413 .and_then(|v| v.to_str().ok())
414 .map(String::from),
415 tokens_used: headers
416 .get("x-tokens-used")
417 .and_then(|v| v.to_str().ok())
418 .and_then(|s| s.parse().ok()),
419 credits_used: headers
420 .get("x-credits-used")
421 .and_then(|v| v.to_str().ok())
422 .and_then(|s| s.parse().ok()),
423 credits_remaining: headers
424 .get("x-credits-remaining")
425 .and_then(|v| v.to_str().ok())
426 .and_then(|s| s.parse().ok()),
427 request_id: headers
428 .get("x-request-id")
429 .and_then(|v| v.to_str().ok())
430 .map(String::from),
431 }
432 }
433
434 fn map_api_error<T>(
438 &self,
439 error: crate::error::ApiErrorDetails,
440 status_code: u16,
441 request_id: Option<String>,
442 ) -> Result<T> {
443 let retryable = error.retryable.unwrap_or(status_code >= 500);
444
445 let rainy_error = match error.code.as_str() {
446 "INVALID_API_KEY" | "EXPIRED_API_KEY" => RainyError::Authentication {
447 code: error.code,
448 message: error.message,
449 retryable: false,
450 },
451 "INSUFFICIENT_CREDITS" => {
452 let (current_credits, estimated_cost, reset_date) =
454 if let Some(details) = error.details {
455 let current = details
456 .get("current_credits")
457 .and_then(|v| v.as_f64())
458 .unwrap_or(0.0);
459 let cost = details
460 .get("estimated_cost")
461 .and_then(|v| v.as_f64())
462 .unwrap_or(0.0);
463 let reset = details
464 .get("reset_date")
465 .and_then(|v| v.as_str())
466 .map(String::from);
467 (current, cost, reset)
468 } else {
469 (0.0, 0.0, None)
470 };
471
472 RainyError::InsufficientCredits {
473 code: error.code,
474 message: error.message,
475 current_credits,
476 estimated_cost,
477 reset_date,
478 }
479 }
480 "RATE_LIMIT_EXCEEDED" => {
481 let retry_after = error
482 .details
483 .as_ref()
484 .and_then(|d| d.get("retry_after"))
485 .and_then(|v| v.as_u64());
486
487 RainyError::RateLimit {
488 code: error.code,
489 message: error.message,
490 retry_after,
491 current_usage: None,
492 }
493 }
494 "INVALID_REQUEST" | "MISSING_REQUIRED_FIELD" | "INVALID_MODEL" => {
495 RainyError::InvalidRequest {
496 code: error.code,
497 message: error.message,
498 details: error.details,
499 }
500 }
501 "PROVIDER_ERROR" | "PROVIDER_UNAVAILABLE" => {
502 let provider = error
503 .details
504 .as_ref()
505 .and_then(|d| d.get("provider"))
506 .and_then(|v| v.as_str())
507 .unwrap_or("unknown")
508 .to_string();
509
510 RainyError::Provider {
511 code: error.code,
512 message: error.message,
513 provider,
514 retryable,
515 }
516 }
517 _ => RainyError::Api {
518 code: error.code,
519 message: error.message,
520 status_code,
521 retryable,
522 request_id: request_id.clone(),
523 },
524 };
525
526 Err(rainy_error)
527 }
528
529 pub fn auth_config(&self) -> &AuthConfig {
531 &self.auth_config
532 }
533
534 pub fn base_url(&self) -> &str {
536 &self.auth_config.base_url
537 }
538
539 pub(crate) fn http_client(&self) -> &Client {
543 &self.client
544 }
545
546 pub async fn list_available_models(&self) -> Result<AvailableModels> {
571 let url = format!("{}/api/v1/models", self.auth_config.base_url);
572
573 let operation = || async {
574 let response = self.client.get(&url).send().await?;
575 self.handle_response(response).await
576 };
577
578 if self.auth_config.enable_retry {
579 retry_with_backoff(&self.retry_config, operation).await
580 } else {
581 operation().await
582 }
583 }
584
585 pub async fn get_cowork_profile(&self) -> Result<crate::cowork::CoworkProfile> {
593 let url = format!("{}/api/v1/cowork/profile", self.auth_config.base_url);
594
595 let operation = || async {
596 let response = self.client.get(&url).send().await?;
597 self.handle_response(response).await
598 };
599
600 if self.auth_config.enable_retry {
601 retry_with_backoff(&self.retry_config, operation).await
602 } else {
603 operation().await
604 }
605 }
606
607 pub(crate) async fn make_request<T: serde::de::DeserializeOwned>(
613 &self,
614 method: reqwest::Method,
615 endpoint: &str,
616 body: Option<serde_json::Value>,
617 ) -> Result<T> {
618 #[cfg(feature = "rate-limiting")]
619 if let Some(ref limiter) = self.rate_limiter {
620 limiter.until_ready().await;
621 }
622
623 let url = format!("{}/api/v1{}", self.auth_config.base_url, endpoint);
624 let headers = self.auth_config.build_headers()?;
625
626 let mut request = self.client.request(method, &url).headers(headers);
627
628 if let Some(body) = body {
629 request = request.json(&body);
630 }
631
632 let response = request.send().await?;
633 self.handle_response(response).await
634 }
635}
636
637impl std::fmt::Debug for RainyClient {
638 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
639 f.debug_struct("RainyClient")
640 .field("base_url", &self.auth_config.base_url)
641 .field("timeout", &self.auth_config.timeout_seconds)
642 .field("max_retries", &self.retry_config.max_retries)
643 .finish()
644 }
645}