lib_client_openrouter/
client.rs

1//! OpenRouter API client implementation.
2
3use crate::auth::AuthStrategy;
4use crate::error::{OpenRouterError, Result};
5use crate::types::{
6    CreateChatCompletionRequest, CreateChatCompletionResponse, CreditsResponse, ErrorResponse,
7    GenerationStats, Model, ModelList,
8};
9use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
10use std::sync::Arc;
11
12const DEFAULT_BASE_URL: &str = "https://openrouter.ai/api/v1";
13
14/// OpenRouter API client.
15pub struct Client {
16    http: reqwest::Client,
17    auth: Arc<dyn AuthStrategy>,
18    base_url: String,
19}
20
21impl Client {
22    /// Create a new client builder.
23    pub fn builder() -> ClientBuilder<()> {
24        ClientBuilder::new()
25    }
26
27    /// Create a chat completion.
28    pub async fn create_chat_completion(
29        &self,
30        request: CreateChatCompletionRequest,
31    ) -> Result<CreateChatCompletionResponse> {
32        let url = format!("{}/chat/completions", self.base_url);
33        self.post(&url, &request).await
34    }
35
36    /// List available models.
37    pub async fn list_models(&self) -> Result<ModelList> {
38        let url = format!("{}/models", self.base_url);
39        self.get(&url).await
40    }
41
42    /// Get a specific model by ID.
43    pub async fn get_model(&self, model_id: &str) -> Result<Model> {
44        let models = self.list_models().await?;
45        models
46            .data
47            .into_iter()
48            .find(|m| m.id == model_id)
49            .ok_or_else(|| OpenRouterError::NotFound(format!("Model not found: {}", model_id)))
50    }
51
52    /// Get generation statistics by ID.
53    pub async fn get_generation(&self, generation_id: &str) -> Result<GenerationStats> {
54        let url = format!("{}/generation?id={}", self.base_url, generation_id);
55        self.get(&url).await
56    }
57
58    /// Get account credits/balance.
59    pub async fn get_credits(&self) -> Result<CreditsResponse> {
60        // Note: This endpoint is at /api/v1/auth/key
61        let url = format!("{}/auth/key", self.base_url);
62        self.get(&url).await
63    }
64
65    /// Send a GET request.
66    async fn get<T>(&self, url: &str) -> Result<T>
67    where
68        T: serde::de::DeserializeOwned,
69    {
70        let mut headers = HeaderMap::new();
71        self.auth.apply(&mut headers).await?;
72
73        tracing::debug!(url = %url, "GET request");
74
75        let response = self.http.get(url).headers(headers).send().await?;
76
77        self.handle_response(response).await
78    }
79
80    /// Send a POST request with JSON body.
81    async fn post<T, B>(&self, url: &str, body: &B) -> Result<T>
82    where
83        T: serde::de::DeserializeOwned,
84        B: serde::Serialize,
85    {
86        let mut headers = HeaderMap::new();
87        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
88        self.auth.apply(&mut headers).await?;
89
90        tracing::debug!(url = %url, "POST request");
91
92        let response = self
93            .http
94            .post(url)
95            .headers(headers)
96            .json(body)
97            .send()
98            .await?;
99
100        self.handle_response(response).await
101    }
102
103    /// Handle API response.
104    async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T>
105    where
106        T: serde::de::DeserializeOwned,
107    {
108        let status = response.status();
109        let status_code = status.as_u16();
110
111        // Extract rate limit headers before consuming response
112        let retry_after = response
113            .headers()
114            .get("retry-after")
115            .and_then(|v| v.to_str().ok())
116            .and_then(|s| s.parse().ok());
117
118        if status.is_success() {
119            let body = response.text().await?;
120            tracing::debug!(status = %status_code, "Response received");
121            serde_json::from_str(&body).map_err(OpenRouterError::from)
122        } else {
123            let body = response.text().await?;
124            tracing::warn!(status = %status_code, body = %body, "API error");
125
126            // Try to parse error response
127            if let Ok(error_response) = serde_json::from_str::<ErrorResponse>(&body) {
128                let message = error_response.error.message;
129                let code = error_response.error.code;
130
131                return Err(match status_code {
132                    401 => OpenRouterError::Unauthorized,
133                    402 => OpenRouterError::InsufficientCredits(message),
134                    403 => OpenRouterError::Forbidden(message),
135                    404 => OpenRouterError::NotFound(message),
136                    429 => OpenRouterError::RateLimited {
137                        retry_after: retry_after.unwrap_or(60),
138                    },
139                    500..=599 => OpenRouterError::ServerError(message),
140                    _ => match code {
141                        Some(400) => OpenRouterError::InvalidRequest(message),
142                        Some(404) => OpenRouterError::ModelNotAvailable(message),
143                        _ => OpenRouterError::Api {
144                            status: status_code,
145                            message,
146                        },
147                    },
148                });
149            }
150
151            Err(OpenRouterError::Api {
152                status: status_code,
153                message: body,
154            })
155        }
156    }
157}
158
/// Client builder.
///
/// Obtained from [`Client::builder`]. The type parameter `A` tracks the
/// authentication strategy: it starts as `()` and becomes the strategy type
/// once `auth` is called, which is what makes `build` available.
#[must_use = "a ClientBuilder does nothing until `.build()` is called"]
pub struct ClientBuilder<A> {
    /// Authentication strategy (`()` until one is supplied).
    auth: A,
    /// API root that endpoint paths will be appended to.
    base_url: String,
}
164
165impl ClientBuilder<()> {
166    /// Create a new client builder.
167    pub fn new() -> Self {
168        Self {
169            auth: (),
170            base_url: DEFAULT_BASE_URL.to_string(),
171        }
172    }
173
174    /// Set the authentication strategy.
175    pub fn auth<S: AuthStrategy + 'static>(self, strategy: S) -> ClientBuilder<S> {
176        ClientBuilder {
177            auth: strategy,
178            base_url: self.base_url,
179        }
180    }
181}
182
183impl Default for ClientBuilder<()> {
184    fn default() -> Self {
185        Self::new()
186    }
187}
188
189impl<A: AuthStrategy + 'static> ClientBuilder<A> {
190    /// Set a custom base URL.
191    pub fn base_url(mut self, url: impl Into<String>) -> Self {
192        self.base_url = url.into();
193        self
194    }
195
196    /// Build the client.
197    pub fn build(self) -> Client {
198        Client {
199            http: reqwest::Client::new(),
200            auth: Arc::new(self.auth),
201            base_url: self.base_url,
202        }
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::ApiKeyAuth;
    use crate::types::Message;

    #[test]
    fn test_builder() {
        let client = Client::builder()
            .auth(ApiKeyAuth::new("test-key"))
            .base_url("https://custom.api.com")
            .build();

        assert_eq!(client.base_url, "https://custom.api.com");
    }

    #[test]
    fn test_builder_defaults_to_openrouter_base_url() {
        let client = Client::builder().auth(ApiKeyAuth::new("test-key")).build();

        assert_eq!(client.base_url, DEFAULT_BASE_URL);
    }

    #[test]
    fn test_create_chat_completion_request() {
        let request =
            CreateChatCompletionRequest::new("openai/gpt-4o", vec![Message::user("Hello")])
                .with_max_tokens(1024)
                .with_temperature(0.7);

        assert_eq!(request.model, "openai/gpt-4o");
        assert_eq!(request.max_tokens, Some(1024));
        assert_eq!(request.temperature, Some(0.7));
    }

    #[test]
    fn test_auth_with_site_info() {
        let auth = ApiKeyAuth::new("sk-or-test")
            .with_site_url("https://myapp.com")
            .with_site_name("My App");

        let client = Client::builder().auth(auth).build();

        // Site metadata belongs to the auth strategy; it must not affect the
        // endpoint the client talks to.
        assert_eq!(client.base_url, DEFAULT_BASE_URL);
    }
}