// rainy_sdk/client.rs

1use crate::{
2    auth::AuthConfig,
3    error::{ApiErrorResponse, RainyError, Result},
4    models::*,
5    retry::{retry_with_backoff, RetryConfig},
6};
7use eventsource_stream::Eventsource;
8use futures::{Stream, StreamExt};
9use reqwest::{
10    header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT},
11    Client, Response,
12};
13use secrecy::ExposeSecret;
14use std::pin::Pin;
15use std::time::Instant;
16
17#[cfg(feature = "rate-limiting")]
18use governor::{
19    clock::DefaultClock,
20    state::{InMemoryState, NotKeyed},
21    Quota, RateLimiter,
22};
23
24/// The main client for interacting with the Rainy API.
25///
26/// `RainyClient` provides a convenient and high-level interface for making requests
27/// to the various endpoints of the Rainy API. It handles authentication, rate limiting,
28/// and retries automatically.
29///
30/// # Examples
31///
32/// ```rust,no_run
33/// use rainy_sdk::{RainyClient, Result};
34///
35/// #[tokio::main]
36/// async fn main() -> Result<()> {
37///     // Create a client using an API key from an environment variable
38///     let api_key = std::env::var("RAINY_API_KEY").expect("RAINY_API_KEY not set");
39///     let client = RainyClient::with_api_key(api_key)?;
40///
41///     // Use the client to make API calls
42///     let models = client.get_available_models().await?;
43///     println!("Available models: {:?}", models);
44///
45///     Ok(())
46/// }
47/// ```
48pub struct RainyClient {
49    /// The underlying `reqwest::Client` used for making HTTP requests.
50    client: Client,
51    /// The authentication configuration for the client.
52    auth_config: AuthConfig,
53    /// The retry configuration for handling failed requests.
54    retry_config: RetryConfig,
55
56    /// An optional rate limiter to control the request frequency.
57    /// This is only available when the `rate-limiting` feature is enabled.
58    #[cfg(feature = "rate-limiting")]
59    rate_limiter: Option<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
60}
61
62impl RainyClient {
63    /// Creates a new `RainyClient` with the given API key.
64    ///
65    /// This is the simplest way to create a client. It uses default settings for the base URL,
66    /// timeout, and retries.
67    ///
68    /// # Arguments
69    ///
70    /// * `api_key` - Your Rainy API key.
71    ///
72    /// # Returns
73    ///
74    /// A `Result` containing the new `RainyClient` or a `RainyError` if initialization fails.
75    pub fn with_api_key(api_key: impl Into<String>) -> Result<Self> {
76        let auth_config = AuthConfig::new(api_key);
77        Self::with_config(auth_config)
78    }
79
80    /// Creates a new `RainyClient` with a custom `AuthConfig`.
81    ///
82    /// This allows for more advanced configuration, such as setting a custom base URL or timeout.
83    ///
84    /// # Arguments
85    ///
86    /// * `auth_config` - The authentication configuration to use.
87    ///
88    /// # Returns
89    ///
90    /// A `Result` containing the new `RainyClient` or a `RainyError` if initialization fails.
91    pub fn with_config(auth_config: AuthConfig) -> Result<Self> {
92        // Validate configuration
93        auth_config.validate()?;
94
95        // Build HTTP client
96        let mut headers = HeaderMap::new();
97        headers.insert(
98            AUTHORIZATION,
99            HeaderValue::from_str(&format!("Bearer {}", auth_config.api_key.expose_secret()))
100                .map_err(|e| RainyError::Authentication {
101                    code: "INVALID_API_KEY".to_string(),
102                    message: format!("Invalid API key format: {}", e),
103                    retryable: false,
104                })?,
105        );
106        headers.insert(
107            USER_AGENT,
108            HeaderValue::from_str(&auth_config.user_agent).map_err(|e| RainyError::Network {
109                message: format!("Invalid user agent: {}", e),
110                retryable: false,
111                source_error: None,
112            })?,
113        );
114
115        let client = Client::builder()
116            .use_rustls_tls()
117            .min_tls_version(reqwest::tls::Version::TLS_1_2)
118            .https_only(true)
119            .timeout(auth_config.timeout())
120            .default_headers(headers)
121            .build()
122            .map_err(|e| RainyError::Network {
123                message: format!("Failed to create HTTP client: {}", e),
124                retryable: false,
125                source_error: Some(e.to_string()),
126            })?;
127
128        let retry_config = RetryConfig::new(auth_config.max_retries);
129
130        #[cfg(feature = "rate-limiting")]
131        let rate_limiter = Some(RateLimiter::direct(Quota::per_second(
132            std::num::NonZeroU32::new(10).unwrap(),
133        )));
134
135        Ok(Self {
136            client,
137            auth_config,
138            retry_config,
139            #[cfg(feature = "rate-limiting")]
140            rate_limiter,
141        })
142    }
143
144    /// Sets a custom retry configuration for the client.
145    ///
146    /// This allows you to override the default retry behavior.
147    ///
148    /// # Arguments
149    ///
150    /// * `retry_config` - The new retry configuration.
151    ///
152    /// # Returns
153    ///
154    /// The `RainyClient` instance with the updated retry configuration.
155    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
156        self.retry_config = retry_config;
157        self
158    }
159
160    /// Retrieves the list of available models and providers from the API.
161    ///
162    /// # Returns
163    ///
164    /// A `Result` containing an `AvailableModels` struct on success, or a `RainyError` on failure.
165    pub async fn get_available_models(&self) -> Result<AvailableModels> {
166        let url = format!("{}/api/v1/models", self.auth_config.base_url);
167
168        let operation = || async {
169            let response = self.client.get(&url).send().await?;
170            self.handle_response(response).await
171        };
172
173        if self.auth_config.enable_retry {
174            retry_with_backoff(&self.retry_config, operation).await
175        } else {
176            operation().await
177        }
178    }
179
180    /// Creates a chat completion based on the provided request.
181    ///
182    /// # Arguments
183    ///
184    /// * `request` - A `ChatCompletionRequest` containing the model, messages, and other parameters.
185    ///
186    /// # Returns
187    ///
188    /// A `Result` containing a tuple of `(ChatCompletionResponse, RequestMetadata)` on success,
189    /// or a `RainyError` on failure.
190    pub async fn chat_completion(
191        &self,
192        request: ChatCompletionRequest,
193    ) -> Result<(ChatCompletionResponse, RequestMetadata)> {
194        #[cfg(feature = "rate-limiting")]
195        if let Some(ref limiter) = self.rate_limiter {
196            limiter.until_ready().await;
197        }
198
199        let url = format!("{}/api/v1/chat/completions", self.auth_config.base_url);
200        let start_time = Instant::now();
201
202        let operation = || async {
203            let response = self.client.post(&url).json(&request).send().await?;
204
205            let metadata = self.extract_metadata(&response, start_time);
206            let chat_response: ChatCompletionResponse = self.handle_response(response).await?;
207
208            Ok((chat_response, metadata))
209        };
210
211        if self.auth_config.enable_retry {
212            retry_with_backoff(&self.retry_config, operation).await
213        } else {
214            operation().await
215        }
216    }
217
218    /// Creates a streaming chat completion based on the provided request.
219    ///
220    /// # Arguments
221    ///
222    /// * `request` - A `ChatCompletionRequest` containing the model, messages, and other parameters.
223    ///
224    /// # Returns
225    ///
226    /// A `Result` containing a stream of `ChatCompletionChunk`s on success, or a `RainyError` on failure.
227    pub async fn chat_completion_stream(
228        &self,
229        mut request: ChatCompletionRequest,
230    ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>> {
231        // Ensure stream is set to true
232        request.stream = Some(true);
233
234        #[cfg(feature = "rate-limiting")]
235        if let Some(ref limiter) = self.rate_limiter {
236            limiter.until_ready().await;
237        }
238
239        let url = format!("{}/api/v1/chat/completions", self.auth_config.base_url);
240
241        // Note: Retries are more complex with streams, so we only retry the initial connection
242        let operation = || async {
243            let response = self
244                .client
245                .post(&url)
246                .json(&request)
247                .send()
248                .await
249                .map_err(|e| RainyError::Network {
250                    message: format!("Failed to send request: {}", e),
251                    retryable: true,
252                    source_error: Some(e.to_string()),
253                })?;
254
255            self.handle_stream_response(response).await
256        };
257
258        if self.auth_config.enable_retry {
259            retry_with_backoff(&self.retry_config, operation).await
260        } else {
261            operation().await
262        }
263    }
264
265    /// Creates a simple chat completion with a single user prompt.
266    ///
267    /// This is a convenience method for simple use cases where you only need to send a single
268    /// prompt to a model and get a text response.
269    ///
270    /// # Arguments
271    ///
272    /// * `model` - The name of the model to use for the completion.
273    /// * `prompt` - The user's prompt.
274    ///
275    /// # Returns
276    ///
277    /// A `Result` containing the `String` response from the model, or a `RainyError` on failure.
278    pub async fn simple_chat(
279        &self,
280        model: impl Into<String>,
281        prompt: impl Into<String>,
282    ) -> Result<String> {
283        let request = ChatCompletionRequest::new(model, vec![ChatMessage::user(prompt)]);
284
285        let (response, _) = self.chat_completion(request).await?;
286
287        Ok(response
288            .choices
289            .into_iter()
290            .next()
291            .map(|choice| choice.message.content)
292            .unwrap_or_default())
293    }
294
295    /// Handles the HTTP response, deserializing the body into a given type `T` on success,
296    /// or mapping the error to a `RainyError` on failure.
297    ///
298    /// This is an internal method used by the various endpoint functions.
299    pub(crate) async fn handle_response<T>(&self, response: Response) -> Result<T>
300    where
301        T: serde::de::DeserializeOwned,
302    {
303        let status = response.status();
304        let headers = response.headers().clone();
305        let request_id = headers
306            .get("x-request-id")
307            .and_then(|v| v.to_str().ok())
308            .map(String::from);
309
310        if status.is_success() {
311            let text = response.text().await?;
312            serde_json::from_str(&text).map_err(|e| RainyError::Serialization {
313                message: format!("Failed to parse response: {}", e),
314                source_error: Some(e.to_string()),
315            })
316        } else {
317            let text = response.text().await.unwrap_or_default();
318
319            // Try to parse structured error response
320            if let Ok(error_response) = serde_json::from_str::<ApiErrorResponse>(&text) {
321                let error = error_response.error;
322                self.map_api_error(error, status.as_u16(), request_id)
323            } else {
324                // Fallback to generic error
325                Err(RainyError::Api {
326                    code: status.canonical_reason().unwrap_or("UNKNOWN").to_string(),
327                    message: if text.is_empty() {
328                        format!("HTTP {}", status.as_u16())
329                    } else {
330                        text
331                    },
332                    status_code: status.as_u16(),
333                    retryable: status.is_server_error(),
334                    request_id,
335                })
336            }
337        }
338    }
339
340    /// Handles the HTTP response for streaming requests.
341    pub(crate) async fn handle_stream_response(
342        &self,
343        response: Response,
344    ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>> {
345        let status = response.status();
346        let request_id = response
347            .headers()
348            .get("x-request-id")
349            .and_then(|v| v.to_str().ok())
350            .map(String::from);
351
352        if status.is_success() {
353            let stream = response
354                .bytes_stream()
355                .eventsource()
356                .map(move |event| match event {
357                    Ok(event) => {
358                        if event.data == "[DONE]" {
359                            return None;
360                        }
361
362                        match serde_json::from_str::<ChatCompletionChunk>(&event.data) {
363                            Ok(chunk) => Some(Ok(chunk)),
364                            Err(e) => Some(Err(RainyError::Serialization {
365                                message: format!("Failed to parse stream chunk: {}", e),
366                                source_error: Some(e.to_string()),
367                            })),
368                        }
369                    }
370                    Err(e) => Some(Err(RainyError::Network {
371                        message: format!("Stream error: {}", e),
372                        retryable: true,
373                        source_error: Some(e.to_string()),
374                    })),
375                })
376                .take_while(|x| futures::future::ready(x.is_some()))
377                .map(|x| x.unwrap());
378
379            Ok(Box::pin(stream))
380        } else {
381            let text = response.text().await.unwrap_or_default();
382
383            // Try to parse structured error response
384            if let Ok(error_response) = serde_json::from_str::<ApiErrorResponse>(&text) {
385                let error = error_response.error;
386                self.map_api_error(error, status.as_u16(), request_id)
387            } else {
388                Err(RainyError::Api {
389                    code: status.canonical_reason().unwrap_or("UNKNOWN").to_string(),
390                    message: if text.is_empty() {
391                        format!("HTTP {}", status.as_u16())
392                    } else {
393                        text
394                    },
395                    status_code: status.as_u16(),
396                    retryable: status.is_server_error(),
397                    request_id,
398                })
399            }
400        }
401    }
402
403    /// Extracts request metadata from the HTTP response headers.
404    ///
405    /// This is an internal method.
406    fn extract_metadata(&self, response: &Response, start_time: Instant) -> RequestMetadata {
407        let headers = response.headers();
408
409        RequestMetadata {
410            response_time: Some(start_time.elapsed().as_millis() as u64),
411            provider: headers
412                .get("x-provider")
413                .and_then(|v| v.to_str().ok())
414                .map(String::from),
415            tokens_used: headers
416                .get("x-tokens-used")
417                .and_then(|v| v.to_str().ok())
418                .and_then(|s| s.parse().ok()),
419            credits_used: headers
420                .get("x-credits-used")
421                .and_then(|v| v.to_str().ok())
422                .and_then(|s| s.parse().ok()),
423            credits_remaining: headers
424                .get("x-credits-remaining")
425                .and_then(|v| v.to_str().ok())
426                .and_then(|s| s.parse().ok()),
427            request_id: headers
428                .get("x-request-id")
429                .and_then(|v| v.to_str().ok())
430                .map(String::from),
431        }
432    }
433
434    /// Maps a structured API error response to a `RainyError`.
435    ///
436    /// This is an internal method.
437    fn map_api_error<T>(
438        &self,
439        error: crate::error::ApiErrorDetails,
440        status_code: u16,
441        request_id: Option<String>,
442    ) -> Result<T> {
443        let retryable = error.retryable.unwrap_or(status_code >= 500);
444
445        let rainy_error = match error.code.as_str() {
446            "INVALID_API_KEY" | "EXPIRED_API_KEY" => RainyError::Authentication {
447                code: error.code,
448                message: error.message,
449                retryable: false,
450            },
451            "INSUFFICIENT_CREDITS" => {
452                // Extract credit info from details if available
453                let (current_credits, estimated_cost, reset_date) =
454                    if let Some(details) = error.details {
455                        let current = details
456                            .get("current_credits")
457                            .and_then(|v| v.as_f64())
458                            .unwrap_or(0.0);
459                        let cost = details
460                            .get("estimated_cost")
461                            .and_then(|v| v.as_f64())
462                            .unwrap_or(0.0);
463                        let reset = details
464                            .get("reset_date")
465                            .and_then(|v| v.as_str())
466                            .map(String::from);
467                        (current, cost, reset)
468                    } else {
469                        (0.0, 0.0, None)
470                    };
471
472                RainyError::InsufficientCredits {
473                    code: error.code,
474                    message: error.message,
475                    current_credits,
476                    estimated_cost,
477                    reset_date,
478                }
479            }
480            "RATE_LIMIT_EXCEEDED" => {
481                let retry_after = error
482                    .details
483                    .as_ref()
484                    .and_then(|d| d.get("retry_after"))
485                    .and_then(|v| v.as_u64());
486
487                RainyError::RateLimit {
488                    code: error.code,
489                    message: error.message,
490                    retry_after,
491                    current_usage: None,
492                }
493            }
494            "INVALID_REQUEST" | "MISSING_REQUIRED_FIELD" | "INVALID_MODEL" => {
495                RainyError::InvalidRequest {
496                    code: error.code,
497                    message: error.message,
498                    details: error.details,
499                }
500            }
501            "PROVIDER_ERROR" | "PROVIDER_UNAVAILABLE" => {
502                let provider = error
503                    .details
504                    .as_ref()
505                    .and_then(|d| d.get("provider"))
506                    .and_then(|v| v.as_str())
507                    .unwrap_or("unknown")
508                    .to_string();
509
510                RainyError::Provider {
511                    code: error.code,
512                    message: error.message,
513                    provider,
514                    retryable,
515                }
516            }
517            _ => RainyError::Api {
518                code: error.code,
519                message: error.message,
520                status_code,
521                retryable,
522                request_id: request_id.clone(),
523            },
524        };
525
526        Err(rainy_error)
527    }
528
529    /// Returns a reference to the current authentication configuration.
530    pub fn auth_config(&self) -> &AuthConfig {
531        &self.auth_config
532    }
533
534    /// Returns the base URL being used by the client.
535    pub fn base_url(&self) -> &str {
536        &self.auth_config.base_url
537    }
538
539    /// Returns a reference to the underlying `reqwest::Client`.
540    ///
541    /// This is intended for internal use by the endpoint modules.
542    pub(crate) fn http_client(&self) -> &Client {
543        &self.client
544    }
545
546    /// Retrieves the list of available models from the API.
547    ///
548    /// This method returns information about all models that are currently available
549    /// through the Rainy API, including their compatibility status and supported parameters.
550    ///
551    /// # Returns
552    ///
553    /// A `Result` containing a `AvailableModels` struct with model information.
554    ///
555    /// # Example
556    ///
557    /// ```rust,no_run
558    /// # use rainy_sdk::RainyClient;
559    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
560    /// let client = RainyClient::with_api_key("your-api-key")?;
561    /// let models = client.list_available_models().await?;
562    ///
563    /// println!("Total models: {}", models.total_models);
564    /// for (provider, model_list) in &models.providers {
565    ///     println!("Provider {}: {:?}", provider, model_list);
566    /// }
567    /// # Ok(())
568    /// # }
569    /// ```
570    pub async fn list_available_models(&self) -> Result<AvailableModels> {
571        let url = format!("{}/api/v1/models", self.auth_config.base_url);
572
573        let operation = || async {
574            let response = self.client.get(&url).send().await?;
575            self.handle_response(response).await
576        };
577
578        if self.auth_config.enable_retry {
579            retry_with_backoff(&self.retry_config, operation).await
580        } else {
581            operation().await
582        }
583    }
584
585    /// Retrieves the Cowork profile for the current user.
586    ///
587    /// This includes subscription plan details, usage statistics, and feature flags.
588    ///
589    /// # Returns
590    ///
591    /// A `Result` containing a `CoworkProfile` struct on success, or a `RainyError` on failure.
592    pub async fn get_cowork_profile(&self) -> Result<crate::cowork::CoworkProfile> {
593        let url = format!("{}/api/v1/cowork/profile", self.auth_config.base_url);
594
595        let operation = || async {
596            let response = self.client.get(&url).send().await?;
597            self.handle_response(response).await
598        };
599
600        if self.auth_config.enable_retry {
601            retry_with_backoff(&self.retry_config, operation).await
602        } else {
603            operation().await
604        }
605    }
606
607    // Legacy methods for backward compatibility
608
609    /// Makes a generic HTTP request to the API.
610    ///
611    /// This is an internal method kept for compatibility with endpoint implementations.
612    pub(crate) async fn make_request<T: serde::de::DeserializeOwned>(
613        &self,
614        method: reqwest::Method,
615        endpoint: &str,
616        body: Option<serde_json::Value>,
617    ) -> Result<T> {
618        #[cfg(feature = "rate-limiting")]
619        if let Some(ref limiter) = self.rate_limiter {
620            limiter.until_ready().await;
621        }
622
623        let url = format!("{}/api/v1{}", self.auth_config.base_url, endpoint);
624        let headers = self.auth_config.build_headers()?;
625
626        let mut request = self.client.request(method, &url).headers(headers);
627
628        if let Some(body) = body {
629            request = request.json(&body);
630        }
631
632        let response = request.send().await?;
633        self.handle_response(response).await
634    }
635}
636
637impl std::fmt::Debug for RainyClient {
638    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
639        f.debug_struct("RainyClient")
640            .field("base_url", &self.auth_config.base_url)
641            .field("timeout", &self.auth_config.timeout_seconds)
642            .field("max_retries", &self.retry_config.max_retries)
643            .finish()
644    }
645}