lib_client_openrouter/
client.rs1use crate::auth::AuthStrategy;
4use crate::error::{OpenRouterError, Result};
5use crate::types::{
6 CreateChatCompletionRequest, CreateChatCompletionResponse, CreditsResponse, ErrorResponse,
7 GenerationStats, Model, ModelList,
8};
9use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
10use std::sync::Arc;
11
12const DEFAULT_BASE_URL: &str = "https://openrouter.ai/api/v1";
13
14pub struct Client {
16 http: reqwest::Client,
17 auth: Arc<dyn AuthStrategy>,
18 base_url: String,
19}
20
21impl Client {
22 pub fn builder() -> ClientBuilder<()> {
24 ClientBuilder::new()
25 }
26
27 pub async fn create_chat_completion(
29 &self,
30 request: CreateChatCompletionRequest,
31 ) -> Result<CreateChatCompletionResponse> {
32 let url = format!("{}/chat/completions", self.base_url);
33 self.post(&url, &request).await
34 }
35
36 pub async fn list_models(&self) -> Result<ModelList> {
38 let url = format!("{}/models", self.base_url);
39 self.get(&url).await
40 }
41
42 pub async fn get_model(&self, model_id: &str) -> Result<Model> {
44 let models = self.list_models().await?;
45 models
46 .data
47 .into_iter()
48 .find(|m| m.id == model_id)
49 .ok_or_else(|| OpenRouterError::NotFound(format!("Model not found: {}", model_id)))
50 }
51
52 pub async fn get_generation(&self, generation_id: &str) -> Result<GenerationStats> {
54 let url = format!("{}/generation?id={}", self.base_url, generation_id);
55 self.get(&url).await
56 }
57
58 pub async fn get_credits(&self) -> Result<CreditsResponse> {
60 let url = format!("{}/auth/key", self.base_url);
62 self.get(&url).await
63 }
64
65 async fn get<T>(&self, url: &str) -> Result<T>
67 where
68 T: serde::de::DeserializeOwned,
69 {
70 let mut headers = HeaderMap::new();
71 self.auth.apply(&mut headers).await?;
72
73 tracing::debug!(url = %url, "GET request");
74
75 let response = self.http.get(url).headers(headers).send().await?;
76
77 self.handle_response(response).await
78 }
79
80 async fn post<T, B>(&self, url: &str, body: &B) -> Result<T>
82 where
83 T: serde::de::DeserializeOwned,
84 B: serde::Serialize,
85 {
86 let mut headers = HeaderMap::new();
87 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
88 self.auth.apply(&mut headers).await?;
89
90 tracing::debug!(url = %url, "POST request");
91
92 let response = self
93 .http
94 .post(url)
95 .headers(headers)
96 .json(body)
97 .send()
98 .await?;
99
100 self.handle_response(response).await
101 }
102
103 async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T>
105 where
106 T: serde::de::DeserializeOwned,
107 {
108 let status = response.status();
109 let status_code = status.as_u16();
110
111 let retry_after = response
113 .headers()
114 .get("retry-after")
115 .and_then(|v| v.to_str().ok())
116 .and_then(|s| s.parse().ok());
117
118 if status.is_success() {
119 let body = response.text().await?;
120 tracing::debug!(status = %status_code, "Response received");
121 serde_json::from_str(&body).map_err(OpenRouterError::from)
122 } else {
123 let body = response.text().await?;
124 tracing::warn!(status = %status_code, body = %body, "API error");
125
126 if let Ok(error_response) = serde_json::from_str::<ErrorResponse>(&body) {
128 let message = error_response.error.message;
129 let code = error_response.error.code;
130
131 return Err(match status_code {
132 401 => OpenRouterError::Unauthorized,
133 402 => OpenRouterError::InsufficientCredits(message),
134 403 => OpenRouterError::Forbidden(message),
135 404 => OpenRouterError::NotFound(message),
136 429 => OpenRouterError::RateLimited {
137 retry_after: retry_after.unwrap_or(60),
138 },
139 500..=599 => OpenRouterError::ServerError(message),
140 _ => match code {
141 Some(400) => OpenRouterError::InvalidRequest(message),
142 Some(404) => OpenRouterError::ModelNotAvailable(message),
143 _ => OpenRouterError::Api {
144 status: status_code,
145 message,
146 },
147 },
148 });
149 }
150
151 Err(OpenRouterError::Api {
152 status: status_code,
153 message: body,
154 })
155 }
156 }
157}
158
/// Builder for [`Client`].
///
/// The type parameter `A` carries the authentication strategy: it is `()` for
/// a builder returned by `ClientBuilder::new` and becomes the strategy's type
/// once `ClientBuilder::auth` has been called.
#[must_use = "a ClientBuilder does nothing until `build` is called"]
pub struct ClientBuilder<A> {
    /// Authentication strategy, or `()` before one has been set.
    auth: A,
    /// Root URL that endpoint paths are appended to.
    base_url: String,
}
164
165impl ClientBuilder<()> {
166 pub fn new() -> Self {
168 Self {
169 auth: (),
170 base_url: DEFAULT_BASE_URL.to_string(),
171 }
172 }
173
174 pub fn auth<S: AuthStrategy + 'static>(self, strategy: S) -> ClientBuilder<S> {
176 ClientBuilder {
177 auth: strategy,
178 base_url: self.base_url,
179 }
180 }
181}
182
183impl Default for ClientBuilder<()> {
184 fn default() -> Self {
185 Self::new()
186 }
187}
188
189impl<A: AuthStrategy + 'static> ClientBuilder<A> {
190 pub fn base_url(mut self, url: impl Into<String>) -> Self {
192 self.base_url = url.into();
193 self
194 }
195
196 pub fn build(self) -> Client {
198 Client {
199 http: reqwest::Client::new(),
200 auth: Arc::new(self.auth),
201 base_url: self.base_url,
202 }
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::ApiKeyAuth;
    use crate::types::Message;

    #[test]
    fn test_builder() {
        let client = Client::builder()
            .auth(ApiKeyAuth::new("test-key"))
            .base_url("https://custom.api.com")
            .build();

        assert_eq!(client.base_url, "https://custom.api.com");
    }

    #[test]
    fn test_create_chat_completion_request() {
        let request =
            CreateChatCompletionRequest::new("openai/gpt-4o", vec![Message::user("Hello")])
                .with_max_tokens(1024)
                .with_temperature(0.7);

        assert_eq!(request.model, "openai/gpt-4o");
        assert_eq!(request.max_tokens, Some(1024));
        assert_eq!(request.temperature, Some(0.7));
    }

    #[test]
    fn test_auth_with_site_info() {
        let auth = ApiKeyAuth::new("sk-or-test")
            .with_site_url("https://myapp.com")
            .with_site_name("My App");

        let client = Client::builder().auth(auth).build();

        // Site metadata must not change which API root the client targets.
        assert_eq!(client.base_url, DEFAULT_BASE_URL);
    }
}