1use crate::{
2 auth::AuthConfig,
3 error::{ApiErrorResponse, RainyError, Result},
4 models::*,
5 retry::{retry_with_backoff, RetryConfig},
6};
7use eventsource_stream::Eventsource;
8use futures::{Stream, StreamExt};
9use reqwest::{
10 header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT},
11 Client, Response,
12};
13use secrecy::ExposeSecret;
14use serde::Deserialize;
15use std::pin::Pin;
16use std::time::Instant;
17
18#[cfg(feature = "rate-limiting")]
19use governor::{
20 clock::DefaultClock,
21 state::{InMemoryState, NotKeyed},
22 Quota, RateLimiter,
23};
24
25pub struct RainyClient {
50 client: Client,
52 auth_config: AuthConfig,
54 retry_config: RetryConfig,
56
57 #[cfg(feature = "rate-limiting")]
60 rate_limiter: Option<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
61}
62
63impl RainyClient {
64 pub(crate) fn root_url(&self, path: &str) -> String {
65 let normalized = if path.starts_with('/') {
66 path.to_string()
67 } else {
68 format!("/{path}")
69 };
70 format!(
71 "{}{}",
72 self.auth_config.base_url.trim_end_matches('/'),
73 normalized
74 )
75 }
76
77 pub(crate) fn api_v1_url(&self, path: &str) -> String {
78 let normalized = if path.starts_with('/') {
79 path.to_string()
80 } else {
81 format!("/{path}")
82 };
83 format!(
84 "{}/api/v1{}",
85 self.auth_config.base_url.trim_end_matches('/'),
86 normalized
87 )
88 }
89
90 pub fn with_api_key(api_key: impl Into<String>) -> Result<Self> {
103 let auth_config = AuthConfig::new(api_key);
104 Self::with_config(auth_config)
105 }
106
107 pub fn with_config(auth_config: AuthConfig) -> Result<Self> {
119 auth_config.validate()?;
121
122 let mut headers = HeaderMap::new();
124 headers.insert(
125 AUTHORIZATION,
126 HeaderValue::from_str(&format!("Bearer {}", auth_config.api_key.expose_secret()))
127 .map_err(|e| RainyError::Authentication {
128 code: "INVALID_API_KEY".to_string(),
129 message: format!("Invalid API key format: {}", e),
130 retryable: false,
131 })?,
132 );
133 headers.insert(
134 USER_AGENT,
135 HeaderValue::from_str(&auth_config.user_agent).map_err(|e| RainyError::Network {
136 message: format!("Invalid user agent: {}", e),
137 retryable: false,
138 source_error: None,
139 })?,
140 );
141
142 let client = Client::builder()
143 .use_rustls_tls()
144 .min_tls_version(reqwest::tls::Version::TLS_1_2)
145 .https_only(true)
146 .timeout(auth_config.timeout())
147 .default_headers(headers)
148 .build()
149 .map_err(|e| RainyError::Network {
150 message: format!("Failed to create HTTP client: {}", e),
151 retryable: false,
152 source_error: Some(e.to_string()),
153 })?;
154
155 let retry_config = RetryConfig::new(auth_config.max_retries);
156
157 #[cfg(feature = "rate-limiting")]
158 let rate_limiter = Some(RateLimiter::direct(Quota::per_second(
159 std::num::NonZeroU32::new(10).unwrap(),
160 )));
161
162 Ok(Self {
163 client,
164 auth_config,
165 retry_config,
166 #[cfg(feature = "rate-limiting")]
167 rate_limiter,
168 })
169 }
170
171 pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
183 self.retry_config = retry_config;
184 self
185 }
186
187 pub async fn get_available_models(&self) -> Result<AvailableModels> {
193 #[derive(Deserialize)]
194 struct ModelListItem {
195 id: String,
196 }
197 #[derive(Deserialize)]
198 struct ModelsData {
199 data: Vec<ModelListItem>,
200 }
201 #[derive(Deserialize)]
202 struct Envelope {
203 data: ModelsData,
204 }
205
206 let url = self.api_v1_url("/models");
207
208 let operation = || async {
209 let response = self.client.get(&url).send().await?;
210 let envelope: Envelope = self.handle_response(response).await?;
211
212 let mut providers = std::collections::HashMap::<String, Vec<String>>::new();
213 for item in envelope.data.data {
214 let provider = item
215 .id
216 .split_once('/')
217 .map(|(p, _)| p.to_string())
218 .unwrap_or_else(|| "rainy".to_string());
219 providers.entry(provider).or_default().push(item.id);
220 }
221
222 let total_models = providers.values().map(std::vec::Vec::len).sum();
223 let mut active_providers = providers.keys().cloned().collect::<Vec<_>>();
224 active_providers.sort();
225
226 Ok(AvailableModels {
227 providers,
228 total_models,
229 active_providers,
230 })
231 };
232
233 if self.auth_config.enable_retry {
234 retry_with_backoff(&self.retry_config, operation).await
235 } else {
236 operation().await
237 }
238 }
239
240 pub async fn chat_completion(
251 &self,
252 request: ChatCompletionRequest,
253 ) -> Result<(ChatCompletionResponse, RequestMetadata)> {
254 #[cfg(feature = "rate-limiting")]
255 if let Some(ref limiter) = self.rate_limiter {
256 limiter.until_ready().await;
257 }
258
259 let url = self.api_v1_url("/chat/completions");
260 let start_time = Instant::now();
261
262 let operation = || async {
263 let response = self.client.post(&url).json(&request).send().await?;
264
265 let metadata = self.extract_metadata(&response, start_time);
266 let chat_response: ChatCompletionResponse = self.handle_response(response).await?;
267
268 Ok((chat_response, metadata))
269 };
270
271 if self.auth_config.enable_retry {
272 retry_with_backoff(&self.retry_config, operation).await
273 } else {
274 operation().await
275 }
276 }
277
278 pub async fn chat_completion_stream(
288 &self,
289 mut request: ChatCompletionRequest,
290 ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>> {
291 request.stream = Some(true);
293
294 #[cfg(feature = "rate-limiting")]
295 if let Some(ref limiter) = self.rate_limiter {
296 limiter.until_ready().await;
297 }
298
299 let url = self.api_v1_url("/chat/completions");
300
301 let operation = || async {
303 let response = self
304 .client
305 .post(&url)
306 .json(&request)
307 .send()
308 .await
309 .map_err(|e| RainyError::Network {
310 message: format!("Failed to send request: {}", e),
311 retryable: true,
312 source_error: Some(e.to_string()),
313 })?;
314
315 self.handle_stream_response(response).await
316 };
317
318 if self.auth_config.enable_retry {
319 retry_with_backoff(&self.retry_config, operation).await
320 } else {
321 operation().await
322 }
323 }
324
325 pub async fn simple_chat(
339 &self,
340 model: impl Into<String>,
341 prompt: impl Into<String>,
342 ) -> Result<String> {
343 let request = ChatCompletionRequest::new(model, vec![ChatMessage::user(prompt)]);
344
345 let (response, _) = self.chat_completion(request).await?;
346
347 Ok(response
348 .choices
349 .into_iter()
350 .next()
351 .map(|choice| choice.message.content)
352 .unwrap_or_default())
353 }
354
355 pub(crate) async fn handle_response<T>(&self, response: Response) -> Result<T>
360 where
361 T: serde::de::DeserializeOwned,
362 {
363 let status = response.status();
364 let headers = response.headers().clone();
365 let request_id = headers
366 .get("x-request-id")
367 .and_then(|v| v.to_str().ok())
368 .map(String::from);
369
370 if status.is_success() {
371 let text = response.text().await?;
372 serde_json::from_str(&text).map_err(|e| RainyError::Serialization {
373 message: format!("Failed to parse response: {}", e),
374 source_error: Some(e.to_string()),
375 })
376 } else {
377 let text = response.text().await.unwrap_or_default();
378
379 if let Ok(error_response) = serde_json::from_str::<ApiErrorResponse>(&text) {
381 let error = error_response.error;
382 self.map_api_error(error, status.as_u16(), request_id)
383 } else {
384 Err(RainyError::Api {
386 code: status.canonical_reason().unwrap_or("UNKNOWN").to_string(),
387 message: if text.is_empty() {
388 format!("HTTP {}", status.as_u16())
389 } else {
390 text
391 },
392 status_code: status.as_u16(),
393 retryable: status.is_server_error(),
394 request_id,
395 })
396 }
397 }
398 }
399
400 pub(crate) async fn handle_stream_response(
402 &self,
403 response: Response,
404 ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>> {
405 let status = response.status();
406 let request_id = response
407 .headers()
408 .get("x-request-id")
409 .and_then(|v| v.to_str().ok())
410 .map(String::from);
411
412 if status.is_success() {
413 let stream = response
414 .bytes_stream()
415 .eventsource()
416 .map(move |event| match event {
417 Ok(event) => {
418 if event.data == "[DONE]" {
419 return None;
420 }
421
422 match serde_json::from_str::<ChatCompletionChunk>(&event.data) {
423 Ok(chunk) => Some(Ok(chunk)),
424 Err(e) => Some(Err(RainyError::Serialization {
425 message: format!("Failed to parse stream chunk: {}", e),
426 source_error: Some(e.to_string()),
427 })),
428 }
429 }
430 Err(e) => Some(Err(RainyError::Network {
431 message: format!("Stream error: {}", e),
432 retryable: true,
433 source_error: Some(e.to_string()),
434 })),
435 })
436 .take_while(|x| futures::future::ready(x.is_some()))
437 .map(|x| x.unwrap());
438
439 Ok(Box::pin(stream))
440 } else {
441 let text = response.text().await.unwrap_or_default();
442
443 if let Ok(error_response) = serde_json::from_str::<ApiErrorResponse>(&text) {
445 let error = error_response.error;
446 self.map_api_error(error, status.as_u16(), request_id)
447 } else {
448 Err(RainyError::Api {
449 code: status.canonical_reason().unwrap_or("UNKNOWN").to_string(),
450 message: if text.is_empty() {
451 format!("HTTP {}", status.as_u16())
452 } else {
453 text
454 },
455 status_code: status.as_u16(),
456 retryable: status.is_server_error(),
457 request_id,
458 })
459 }
460 }
461 }
462
463 fn extract_metadata(&self, response: &Response, start_time: Instant) -> RequestMetadata {
467 let headers = response.headers();
468
469 RequestMetadata {
470 response_time: Some(start_time.elapsed().as_millis() as u64),
471 provider: headers
472 .get("x-provider")
473 .and_then(|v| v.to_str().ok())
474 .map(String::from),
475 tokens_used: headers
476 .get("x-tokens-used")
477 .and_then(|v| v.to_str().ok())
478 .and_then(|s| s.parse().ok()),
479 credits_used: headers
480 .get("x-credits-used")
481 .and_then(|v| v.to_str().ok())
482 .and_then(|s| s.parse().ok()),
483 credits_remaining: headers
484 .get("x-credits-remaining")
485 .and_then(|v| v.to_str().ok())
486 .and_then(|s| s.parse().ok()),
487 request_id: headers
488 .get("x-request-id")
489 .and_then(|v| v.to_str().ok())
490 .map(String::from),
491 }
492 }
493
494 fn map_api_error<T>(
498 &self,
499 error: crate::error::ApiErrorDetails,
500 status_code: u16,
501 request_id: Option<String>,
502 ) -> Result<T> {
503 let retryable = error.retryable.unwrap_or(status_code >= 500);
504
505 let rainy_error = match error.code.as_str() {
506 "INVALID_API_KEY" | "EXPIRED_API_KEY" => RainyError::Authentication {
507 code: error.code,
508 message: error.message,
509 retryable: false,
510 },
511 "INSUFFICIENT_CREDITS" => {
512 let (current_credits, estimated_cost, reset_date) =
514 if let Some(details) = error.details {
515 let current = details
516 .get("current_credits")
517 .and_then(|v| v.as_f64())
518 .unwrap_or(0.0);
519 let cost = details
520 .get("estimated_cost")
521 .and_then(|v| v.as_f64())
522 .unwrap_or(0.0);
523 let reset = details
524 .get("reset_date")
525 .and_then(|v| v.as_str())
526 .map(String::from);
527 (current, cost, reset)
528 } else {
529 (0.0, 0.0, None)
530 };
531
532 RainyError::InsufficientCredits {
533 code: error.code,
534 message: error.message,
535 current_credits,
536 estimated_cost,
537 reset_date,
538 }
539 }
540 "RATE_LIMIT_EXCEEDED" => {
541 let retry_after = error
542 .details
543 .as_ref()
544 .and_then(|d| d.get("retry_after"))
545 .and_then(|v| v.as_u64());
546
547 RainyError::RateLimit {
548 code: error.code,
549 message: error.message,
550 retry_after,
551 current_usage: None,
552 }
553 }
554 "INVALID_REQUEST" | "MISSING_REQUIRED_FIELD" | "INVALID_MODEL" => {
555 RainyError::InvalidRequest {
556 code: error.code,
557 message: error.message,
558 details: error.details,
559 }
560 }
561 "PROVIDER_ERROR" | "PROVIDER_UNAVAILABLE" => {
562 let provider = error
563 .details
564 .as_ref()
565 .and_then(|d| d.get("provider"))
566 .and_then(|v| v.as_str())
567 .unwrap_or("unknown")
568 .to_string();
569
570 RainyError::Provider {
571 code: error.code,
572 message: error.message,
573 provider,
574 retryable,
575 }
576 }
577 _ => RainyError::Api {
578 code: error.code,
579 message: error.message,
580 status_code,
581 retryable,
582 request_id: request_id.clone(),
583 },
584 };
585
586 Err(rainy_error)
587 }
588
589 pub fn auth_config(&self) -> &AuthConfig {
591 &self.auth_config
592 }
593
594 pub fn base_url(&self) -> &str {
596 &self.auth_config.base_url
597 }
598
599 pub(crate) fn http_client(&self) -> &Client {
603 &self.client
604 }
605
606 pub async fn list_available_models(&self) -> Result<AvailableModels> {
631 self.get_available_models().await
632 }
633
/// Fetches the Cowork profile from `/api/v1/cowork/profile`.
///
/// Only compiled with the `cowork` feature. The request goes through
/// `retry_with_backoff` when `auth_config.enable_retry` is set and is attempted
/// once otherwise.
///
/// # Errors
///
/// Propagates errors from sending the request and from `handle_response`.
#[cfg(feature = "cowork")]
#[deprecated(
    note = "Cowork endpoints are legacy and not supported by Rainy API v3. Migrate to v3 session/org endpoints."
)]
pub async fn get_cowork_profile(&self) -> Result<crate::cowork::CoworkProfile> {
    let endpoint = self.api_v1_url("/cowork/profile");

    let fetch = || async {
        let raw = self.client.get(&endpoint).send().await?;
        self.handle_response(raw).await
    };

    if !self.auth_config.enable_retry {
        return fetch().await;
    }
    retry_with_backoff(&self.retry_config, fetch).await
}
659
660 pub(crate) async fn make_request<T: serde::de::DeserializeOwned>(
666 &self,
667 method: reqwest::Method,
668 endpoint: &str,
669 body: Option<serde_json::Value>,
670 ) -> Result<T> {
671 #[cfg(feature = "rate-limiting")]
672 if let Some(ref limiter) = self.rate_limiter {
673 limiter.until_ready().await;
674 }
675
676 let url = self.api_v1_url(endpoint);
677 let headers = self.auth_config.build_headers()?;
678
679 let mut request = self.client.request(method, &url).headers(headers);
680
681 if let Some(body) = body {
682 request = request.json(&body);
683 }
684
685 let response = request.send().await?;
686 self.handle_response(response).await
687 }
688}
689
690impl std::fmt::Debug for RainyClient {
691 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
692 f.debug_struct("RainyClient")
693 .field("base_url", &self.auth_config.base_url)
694 .field("timeout", &self.auth_config.timeout_seconds)
695 .field("max_retries", &self.retry_config.max_retries)
696 .finish()
697 }
698}