Skip to main content

rainy_sdk/
client.rs

1use crate::{
2    auth::AuthConfig,
3    error::{ApiErrorResponse, RainyError, Result},
4    models::*,
5    retry::{retry_with_backoff, RetryConfig},
6};
7use eventsource_stream::Eventsource;
8use futures::{Stream, StreamExt};
9use reqwest::{
10    header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT},
11    Client, Response,
12};
13use secrecy::ExposeSecret;
14use serde::Deserialize;
15use std::pin::Pin;
16use std::time::Instant;
17
18#[cfg(feature = "rate-limiting")]
19use governor::{
20    clock::DefaultClock,
21    state::{InMemoryState, NotKeyed},
22    Quota, RateLimiter,
23};
24
25/// The main client for interacting with the Rainy API.
26///
27/// `RainyClient` provides a convenient and high-level interface for making requests
28/// to the various endpoints of the Rainy API. It handles authentication, rate limiting,
29/// and retries automatically.
30///
31/// # Examples
32///
33/// ```rust,no_run
34/// use rainy_sdk::{RainyClient, Result};
35///
36/// #[tokio::main]
37/// async fn main() -> Result<()> {
38///     // Create a client using an API key from an environment variable
39///     let api_key = std::env::var("RAINY_API_KEY").expect("RAINY_API_KEY not set");
40///     let client = RainyClient::with_api_key(api_key)?;
41///
42///     // Use the client to make API calls
43///     let models = client.get_available_models().await?;
44///     println!("Available models: {:?}", models);
45///
46///     Ok(())
47/// }
48/// ```
49pub struct RainyClient {
50    /// The underlying `reqwest::Client` used for making HTTP requests.
51    client: Client,
52    /// The authentication configuration for the client.
53    auth_config: AuthConfig,
54    /// The retry configuration for handling failed requests.
55    retry_config: RetryConfig,
56
57    /// An optional rate limiter to control the request frequency.
58    /// This is only available when the `rate-limiting` feature is enabled.
59    #[cfg(feature = "rate-limiting")]
60    rate_limiter: Option<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
61}
62
63impl RainyClient {
64    pub(crate) fn root_url(&self, path: &str) -> String {
65        let normalized = if path.starts_with('/') {
66            path.to_string()
67        } else {
68            format!("/{path}")
69        };
70        format!(
71            "{}{}",
72            self.auth_config.base_url.trim_end_matches('/'),
73            normalized
74        )
75    }
76
77    pub(crate) fn api_v1_url(&self, path: &str) -> String {
78        let normalized = if path.starts_with('/') {
79            path.to_string()
80        } else {
81            format!("/{path}")
82        };
83        format!(
84            "{}/api/v1{}",
85            self.auth_config.base_url.trim_end_matches('/'),
86            normalized
87        )
88    }
89
90    /// Creates a new `RainyClient` with the given API key.
91    ///
92    /// This is the simplest way to create a client. It uses default settings for the base URL,
93    /// timeout, and retries.
94    ///
95    /// # Arguments
96    ///
97    /// * `api_key` - Your Rainy API key.
98    ///
99    /// # Returns
100    ///
101    /// A `Result` containing the new `RainyClient` or a `RainyError` if initialization fails.
102    pub fn with_api_key(api_key: impl Into<String>) -> Result<Self> {
103        let auth_config = AuthConfig::new(api_key);
104        Self::with_config(auth_config)
105    }
106
107    /// Creates a new `RainyClient` with a custom `AuthConfig`.
108    ///
109    /// This allows for more advanced configuration, such as setting a custom base URL or timeout.
110    ///
111    /// # Arguments
112    ///
113    /// * `auth_config` - The authentication configuration to use.
114    ///
115    /// # Returns
116    ///
117    /// A `Result` containing the new `RainyClient` or a `RainyError` if initialization fails.
118    pub fn with_config(auth_config: AuthConfig) -> Result<Self> {
119        // Validate configuration
120        auth_config.validate()?;
121
122        // Build HTTP client
123        let mut headers = HeaderMap::new();
124        headers.insert(
125            AUTHORIZATION,
126            HeaderValue::from_str(&format!("Bearer {}", auth_config.api_key.expose_secret()))
127                .map_err(|e| RainyError::Authentication {
128                    code: "INVALID_API_KEY".to_string(),
129                    message: format!("Invalid API key format: {}", e),
130                    retryable: false,
131                })?,
132        );
133        headers.insert(
134            USER_AGENT,
135            HeaderValue::from_str(&auth_config.user_agent).map_err(|e| RainyError::Network {
136                message: format!("Invalid user agent: {}", e),
137                retryable: false,
138                source_error: None,
139            })?,
140        );
141
142        let client = Client::builder()
143            .use_rustls_tls()
144            .min_tls_version(reqwest::tls::Version::TLS_1_2)
145            .https_only(true)
146            .timeout(auth_config.timeout())
147            .default_headers(headers)
148            .build()
149            .map_err(|e| RainyError::Network {
150                message: format!("Failed to create HTTP client: {}", e),
151                retryable: false,
152                source_error: Some(e.to_string()),
153            })?;
154
155        let retry_config = RetryConfig::new(auth_config.max_retries);
156
157        #[cfg(feature = "rate-limiting")]
158        let rate_limiter = Some(RateLimiter::direct(Quota::per_second(
159            std::num::NonZeroU32::new(10).unwrap(),
160        )));
161
162        Ok(Self {
163            client,
164            auth_config,
165            retry_config,
166            #[cfg(feature = "rate-limiting")]
167            rate_limiter,
168        })
169    }
170
171    /// Sets a custom retry configuration for the client.
172    ///
173    /// This allows you to override the default retry behavior.
174    ///
175    /// # Arguments
176    ///
177    /// * `retry_config` - The new retry configuration.
178    ///
179    /// # Returns
180    ///
181    /// The `RainyClient` instance with the updated retry configuration.
182    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
183        self.retry_config = retry_config;
184        self
185    }
186
187    /// Retrieves the list of available models and providers from the API.
188    ///
189    /// # Returns
190    ///
191    /// A `Result` containing an `AvailableModels` struct on success, or a `RainyError` on failure.
192    pub async fn get_available_models(&self) -> Result<AvailableModels> {
193        #[derive(Deserialize)]
194        struct ModelListItem {
195            id: String,
196        }
197        #[derive(Deserialize)]
198        struct ModelsData {
199            data: Vec<ModelListItem>,
200        }
201        #[derive(Deserialize)]
202        struct Envelope {
203            data: ModelsData,
204        }
205
206        let url = self.api_v1_url("/models");
207
208        let operation = || async {
209            let response = self.client.get(&url).send().await?;
210            let envelope: Envelope = self.handle_response(response).await?;
211
212            let mut providers = std::collections::HashMap::<String, Vec<String>>::new();
213            for item in envelope.data.data {
214                let provider = item
215                    .id
216                    .split_once('/')
217                    .map(|(p, _)| p.to_string())
218                    .unwrap_or_else(|| "rainy".to_string());
219                providers.entry(provider).or_default().push(item.id);
220            }
221
222            let total_models = providers.values().map(std::vec::Vec::len).sum();
223            let mut active_providers = providers.keys().cloned().collect::<Vec<_>>();
224            active_providers.sort();
225
226            Ok(AvailableModels {
227                providers,
228                total_models,
229                active_providers,
230            })
231        };
232
233        if self.auth_config.enable_retry {
234            retry_with_backoff(&self.retry_config, operation).await
235        } else {
236            operation().await
237        }
238    }
239
240    /// Creates a chat completion based on the provided request.
241    ///
242    /// # Arguments
243    ///
244    /// * `request` - A `ChatCompletionRequest` containing the model, messages, and other parameters.
245    ///
246    /// # Returns
247    ///
248    /// A `Result` containing a tuple of `(ChatCompletionResponse, RequestMetadata)` on success,
249    /// or a `RainyError` on failure.
250    pub async fn chat_completion(
251        &self,
252        request: ChatCompletionRequest,
253    ) -> Result<(ChatCompletionResponse, RequestMetadata)> {
254        #[cfg(feature = "rate-limiting")]
255        if let Some(ref limiter) = self.rate_limiter {
256            limiter.until_ready().await;
257        }
258
259        let url = self.api_v1_url("/chat/completions");
260        let start_time = Instant::now();
261
262        let operation = || async {
263            let response = self.client.post(&url).json(&request).send().await?;
264
265            let metadata = self.extract_metadata(&response, start_time);
266            let chat_response: ChatCompletionResponse = self.handle_response(response).await?;
267
268            Ok((chat_response, metadata))
269        };
270
271        if self.auth_config.enable_retry {
272            retry_with_backoff(&self.retry_config, operation).await
273        } else {
274            operation().await
275        }
276    }
277
278    /// Creates a streaming chat completion based on the provided request.
279    ///
280    /// # Arguments
281    ///
282    /// * `request` - A `ChatCompletionRequest` containing the model, messages, and other parameters.
283    ///
284    /// # Returns
285    ///
286    /// A `Result` containing a stream of `ChatCompletionChunk`s on success, or a `RainyError` on failure.
287    pub async fn chat_completion_stream(
288        &self,
289        mut request: ChatCompletionRequest,
290    ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>> {
291        // Ensure stream is set to true
292        request.stream = Some(true);
293
294        #[cfg(feature = "rate-limiting")]
295        if let Some(ref limiter) = self.rate_limiter {
296            limiter.until_ready().await;
297        }
298
299        let url = self.api_v1_url("/chat/completions");
300
301        // Note: Retries are more complex with streams, so we only retry the initial connection
302        let operation = || async {
303            let response = self
304                .client
305                .post(&url)
306                .json(&request)
307                .send()
308                .await
309                .map_err(|e| RainyError::Network {
310                    message: format!("Failed to send request: {}", e),
311                    retryable: true,
312                    source_error: Some(e.to_string()),
313                })?;
314
315            self.handle_stream_response(response).await
316        };
317
318        if self.auth_config.enable_retry {
319            retry_with_backoff(&self.retry_config, operation).await
320        } else {
321            operation().await
322        }
323    }
324
325    /// Creates a simple chat completion with a single user prompt.
326    ///
327    /// This is a convenience method for simple use cases where you only need to send a single
328    /// prompt to a model and get a text response.
329    ///
330    /// # Arguments
331    ///
332    /// * `model` - The name of the model to use for the completion.
333    /// * `prompt` - The user's prompt.
334    ///
335    /// # Returns
336    ///
337    /// A `Result` containing the `String` response from the model, or a `RainyError` on failure.
338    pub async fn simple_chat(
339        &self,
340        model: impl Into<String>,
341        prompt: impl Into<String>,
342    ) -> Result<String> {
343        let request = ChatCompletionRequest::new(model, vec![ChatMessage::user(prompt)]);
344
345        let (response, _) = self.chat_completion(request).await?;
346
347        Ok(response
348            .choices
349            .into_iter()
350            .next()
351            .map(|choice| choice.message.content)
352            .unwrap_or_default())
353    }
354
355    /// Handles the HTTP response, deserializing the body into a given type `T` on success,
356    /// or mapping the error to a `RainyError` on failure.
357    ///
358    /// This is an internal method used by the various endpoint functions.
359    pub(crate) async fn handle_response<T>(&self, response: Response) -> Result<T>
360    where
361        T: serde::de::DeserializeOwned,
362    {
363        let status = response.status();
364        let headers = response.headers().clone();
365        let request_id = headers
366            .get("x-request-id")
367            .and_then(|v| v.to_str().ok())
368            .map(String::from);
369
370        if status.is_success() {
371            let text = response.text().await?;
372            serde_json::from_str(&text).map_err(|e| RainyError::Serialization {
373                message: format!("Failed to parse response: {}", e),
374                source_error: Some(e.to_string()),
375            })
376        } else {
377            let text = response.text().await.unwrap_or_default();
378
379            // Try to parse structured error response
380            if let Ok(error_response) = serde_json::from_str::<ApiErrorResponse>(&text) {
381                let error = error_response.error;
382                self.map_api_error(error, status.as_u16(), request_id)
383            } else {
384                // Fallback to generic error
385                Err(RainyError::Api {
386                    code: status.canonical_reason().unwrap_or("UNKNOWN").to_string(),
387                    message: if text.is_empty() {
388                        format!("HTTP {}", status.as_u16())
389                    } else {
390                        text
391                    },
392                    status_code: status.as_u16(),
393                    retryable: status.is_server_error(),
394                    request_id,
395                })
396            }
397        }
398    }
399
400    /// Handles the HTTP response for streaming requests.
401    pub(crate) async fn handle_stream_response(
402        &self,
403        response: Response,
404    ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionChunk>> + Send>>> {
405        let status = response.status();
406        let request_id = response
407            .headers()
408            .get("x-request-id")
409            .and_then(|v| v.to_str().ok())
410            .map(String::from);
411
412        if status.is_success() {
413            let stream = response
414                .bytes_stream()
415                .eventsource()
416                .map(move |event| match event {
417                    Ok(event) => {
418                        if event.data == "[DONE]" {
419                            return None;
420                        }
421
422                        match serde_json::from_str::<ChatCompletionChunk>(&event.data) {
423                            Ok(chunk) => Some(Ok(chunk)),
424                            Err(e) => Some(Err(RainyError::Serialization {
425                                message: format!("Failed to parse stream chunk: {}", e),
426                                source_error: Some(e.to_string()),
427                            })),
428                        }
429                    }
430                    Err(e) => Some(Err(RainyError::Network {
431                        message: format!("Stream error: {}", e),
432                        retryable: true,
433                        source_error: Some(e.to_string()),
434                    })),
435                })
436                .take_while(|x| futures::future::ready(x.is_some()))
437                .map(|x| x.unwrap());
438
439            Ok(Box::pin(stream))
440        } else {
441            let text = response.text().await.unwrap_or_default();
442
443            // Try to parse structured error response
444            if let Ok(error_response) = serde_json::from_str::<ApiErrorResponse>(&text) {
445                let error = error_response.error;
446                self.map_api_error(error, status.as_u16(), request_id)
447            } else {
448                Err(RainyError::Api {
449                    code: status.canonical_reason().unwrap_or("UNKNOWN").to_string(),
450                    message: if text.is_empty() {
451                        format!("HTTP {}", status.as_u16())
452                    } else {
453                        text
454                    },
455                    status_code: status.as_u16(),
456                    retryable: status.is_server_error(),
457                    request_id,
458                })
459            }
460        }
461    }
462
463    /// Extracts request metadata from the HTTP response headers.
464    ///
465    /// This is an internal method.
466    fn extract_metadata(&self, response: &Response, start_time: Instant) -> RequestMetadata {
467        let headers = response.headers();
468
469        RequestMetadata {
470            response_time: Some(start_time.elapsed().as_millis() as u64),
471            provider: headers
472                .get("x-provider")
473                .and_then(|v| v.to_str().ok())
474                .map(String::from),
475            tokens_used: headers
476                .get("x-tokens-used")
477                .and_then(|v| v.to_str().ok())
478                .and_then(|s| s.parse().ok()),
479            credits_used: headers
480                .get("x-credits-used")
481                .and_then(|v| v.to_str().ok())
482                .and_then(|s| s.parse().ok()),
483            credits_remaining: headers
484                .get("x-credits-remaining")
485                .and_then(|v| v.to_str().ok())
486                .and_then(|s| s.parse().ok()),
487            request_id: headers
488                .get("x-request-id")
489                .and_then(|v| v.to_str().ok())
490                .map(String::from),
491        }
492    }
493
494    /// Maps a structured API error response to a `RainyError`.
495    ///
496    /// This is an internal method.
497    fn map_api_error<T>(
498        &self,
499        error: crate::error::ApiErrorDetails,
500        status_code: u16,
501        request_id: Option<String>,
502    ) -> Result<T> {
503        let retryable = error.retryable.unwrap_or(status_code >= 500);
504
505        let rainy_error = match error.code.as_str() {
506            "INVALID_API_KEY" | "EXPIRED_API_KEY" => RainyError::Authentication {
507                code: error.code,
508                message: error.message,
509                retryable: false,
510            },
511            "INSUFFICIENT_CREDITS" => {
512                // Extract credit info from details if available
513                let (current_credits, estimated_cost, reset_date) =
514                    if let Some(details) = error.details {
515                        let current = details
516                            .get("current_credits")
517                            .and_then(|v| v.as_f64())
518                            .unwrap_or(0.0);
519                        let cost = details
520                            .get("estimated_cost")
521                            .and_then(|v| v.as_f64())
522                            .unwrap_or(0.0);
523                        let reset = details
524                            .get("reset_date")
525                            .and_then(|v| v.as_str())
526                            .map(String::from);
527                        (current, cost, reset)
528                    } else {
529                        (0.0, 0.0, None)
530                    };
531
532                RainyError::InsufficientCredits {
533                    code: error.code,
534                    message: error.message,
535                    current_credits,
536                    estimated_cost,
537                    reset_date,
538                }
539            }
540            "RATE_LIMIT_EXCEEDED" => {
541                let retry_after = error
542                    .details
543                    .as_ref()
544                    .and_then(|d| d.get("retry_after"))
545                    .and_then(|v| v.as_u64());
546
547                RainyError::RateLimit {
548                    code: error.code,
549                    message: error.message,
550                    retry_after,
551                    current_usage: None,
552                }
553            }
554            "INVALID_REQUEST" | "MISSING_REQUIRED_FIELD" | "INVALID_MODEL" => {
555                RainyError::InvalidRequest {
556                    code: error.code,
557                    message: error.message,
558                    details: error.details,
559                }
560            }
561            "PROVIDER_ERROR" | "PROVIDER_UNAVAILABLE" => {
562                let provider = error
563                    .details
564                    .as_ref()
565                    .and_then(|d| d.get("provider"))
566                    .and_then(|v| v.as_str())
567                    .unwrap_or("unknown")
568                    .to_string();
569
570                RainyError::Provider {
571                    code: error.code,
572                    message: error.message,
573                    provider,
574                    retryable,
575                }
576            }
577            _ => RainyError::Api {
578                code: error.code,
579                message: error.message,
580                status_code,
581                retryable,
582                request_id: request_id.clone(),
583            },
584        };
585
586        Err(rainy_error)
587    }
588
589    /// Returns a reference to the current authentication configuration.
590    pub fn auth_config(&self) -> &AuthConfig {
591        &self.auth_config
592    }
593
594    /// Returns the base URL being used by the client.
595    pub fn base_url(&self) -> &str {
596        &self.auth_config.base_url
597    }
598
599    /// Returns a reference to the underlying `reqwest::Client`.
600    ///
601    /// This is intended for internal use by the endpoint modules.
602    pub(crate) fn http_client(&self) -> &Client {
603        &self.client
604    }
605
606    /// Retrieves the list of available models from the API.
607    ///
608    /// This method returns information about all models that are currently available
609    /// through the Rainy API, including their compatibility status and supported parameters.
610    ///
611    /// # Returns
612    ///
613    /// A `Result` containing a `AvailableModels` struct with model information.
614    ///
615    /// # Example
616    ///
617    /// ```rust,no_run
618    /// # use rainy_sdk::RainyClient;
619    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
620    /// let client = RainyClient::with_api_key("your-api-key")?;
621    /// let models = client.list_available_models().await?;
622    ///
623    /// println!("Total models: {}", models.total_models);
624    /// for (provider, model_list) in &models.providers {
625    ///     println!("Provider {}: {:?}", provider, model_list);
626    /// }
627    /// # Ok(())
628    /// # }
629    /// ```
630    pub async fn list_available_models(&self) -> Result<AvailableModels> {
631        self.get_available_models().await
632    }
633
634    /// Retrieves the Cowork profile for the current user.
635    ///
636    /// This includes subscription plan details, usage statistics, and feature flags.
637    ///
638    /// # Returns
639    ///
640    /// A `Result` containing a `CoworkProfile` struct on success, or a `RainyError` on failure.
641    #[cfg(feature = "cowork")]
642    #[deprecated(
643        note = "Cowork endpoints are legacy and not supported by Rainy API v3. Migrate to v3 session/org endpoints."
644    )]
645    pub async fn get_cowork_profile(&self) -> Result<crate::cowork::CoworkProfile> {
646        let url = self.api_v1_url("/cowork/profile");
647
648        let operation = || async {
649            let response = self.client.get(&url).send().await?;
650            self.handle_response(response).await
651        };
652
653        if self.auth_config.enable_retry {
654            retry_with_backoff(&self.retry_config, operation).await
655        } else {
656            operation().await
657        }
658    }
659
660    // Legacy methods for backward compatibility
661
662    /// Makes a generic HTTP request to the API.
663    ///
664    /// This is an internal method kept for compatibility with endpoint implementations.
665    pub(crate) async fn make_request<T: serde::de::DeserializeOwned>(
666        &self,
667        method: reqwest::Method,
668        endpoint: &str,
669        body: Option<serde_json::Value>,
670    ) -> Result<T> {
671        #[cfg(feature = "rate-limiting")]
672        if let Some(ref limiter) = self.rate_limiter {
673            limiter.until_ready().await;
674        }
675
676        let url = self.api_v1_url(endpoint);
677        let headers = self.auth_config.build_headers()?;
678
679        let mut request = self.client.request(method, &url).headers(headers);
680
681        if let Some(body) = body {
682            request = request.json(&body);
683        }
684
685        let response = request.send().await?;
686        self.handle_response(response).await
687    }
688}
689
690impl std::fmt::Debug for RainyClient {
691    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
692        f.debug_struct("RainyClient")
693            .field("base_url", &self.auth_config.base_url)
694            .field("timeout", &self.auth_config.timeout_seconds)
695            .field("max_retries", &self.retry_config.max_retries)
696            .finish()
697    }
698}