rig/providers/openrouter/
client.rs1use crate::{
2 client::{ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError},
3 impl_conversion_traits,
4};
5use serde::{Deserialize, Serialize};
6
7use super::completion::CompletionModel;
8
9const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
13
14pub struct ClientBuilder<'a> {
15 api_key: &'a str,
16 base_url: &'a str,
17 http_client: Option<reqwest::Client>,
18}
19
20impl<'a> ClientBuilder<'a> {
21 pub fn new(api_key: &'a str) -> Self {
22 Self {
23 api_key,
24 base_url: OPENROUTER_API_BASE_URL,
25 http_client: None,
26 }
27 }
28
29 pub fn base_url(mut self, base_url: &'a str) -> Self {
30 self.base_url = base_url;
31 self
32 }
33
34 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
35 self.http_client = Some(client);
36 self
37 }
38
39 pub fn build(self) -> Result<Client, ClientBuilderError> {
40 let http_client = if let Some(http_client) = self.http_client {
41 http_client
42 } else {
43 reqwest::Client::builder().build()?
44 };
45
46 Ok(Client {
47 base_url: self.base_url.to_string(),
48 api_key: self.api_key.to_string(),
49 http_client,
50 })
51 }
52}
53
54#[derive(Clone)]
55pub struct Client {
56 base_url: String,
57 api_key: String,
58 http_client: reqwest::Client,
59}
60
61impl std::fmt::Debug for Client {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 f.debug_struct("Client")
64 .field("base_url", &self.base_url)
65 .field("http_client", &self.http_client)
66 .field("api_key", &"<REDACTED>")
67 .finish()
68 }
69}
70
71impl Client {
72 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
83 ClientBuilder::new(api_key)
84 }
85
86 pub fn new(api_key: &str) -> Self {
91 Self::builder(api_key)
92 .build()
93 .expect("OpenRouter client should build")
94 }
95
96 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
97 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
98 self.http_client.post(url).bearer_auth(&self.api_key)
99 }
100
101 pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
102 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
103 self.http_client.get(url).bearer_auth(&self.api_key)
104 }
105}
106
107impl ProviderClient for Client {
108 fn from_env() -> Self {
111 let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
112 Self::new(&api_key)
113 }
114
115 fn from_val(input: crate::client::ProviderValue) -> Self {
116 let crate::client::ProviderValue::Simple(api_key) = input else {
117 panic!("Incorrect provider value type")
118 };
119 Self::new(&api_key)
120 }
121}
122
123impl CompletionClient for Client {
124 type CompletionModel = CompletionModel;
125
126 fn completion_model(&self, model: &str) -> CompletionModel {
138 CompletionModel::new(self.clone(), model)
139 }
140}
141
142impl VerifyClient for Client {
143 #[cfg_attr(feature = "worker", worker::send)]
144 async fn verify(&self) -> Result<(), VerifyError> {
145 let response = self.get("/key").send().await?;
146 match response.status() {
147 reqwest::StatusCode::OK => Ok(()),
148 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
149 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
150 Err(VerifyError::ProviderError(response.text().await?))
151 }
152 _ => {
153 response.error_for_status()?;
154 Ok(())
155 }
156 }
157 }
158}
159
160impl_conversion_traits!(
161 AsEmbeddings,
162 AsTranscription,
163 AsImageGeneration,
164 AsAudioGeneration for Client
165);
166
167#[derive(Debug, Deserialize)]
168pub(crate) struct ApiErrorResponse {
169 pub message: String,
170}
171
172#[derive(Debug, Deserialize)]
173#[serde(untagged)]
174pub(crate) enum ApiResponse<T> {
175 Ok(T),
176 Err(ApiErrorResponse),
177}
178
179#[derive(Clone, Debug, Deserialize, Serialize)]
180pub struct Usage {
181 pub prompt_tokens: usize,
182 pub completion_tokens: usize,
183 pub total_tokens: usize,
184}
185
186impl std::fmt::Display for Usage {
187 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188 write!(
189 f,
190 "Prompt tokens: {} Total tokens: {}",
191 self.prompt_tokens, self.total_tokens
192 )
193 }
194}