rig/providers/openrouter/
client.rs1use crate::{
2 client::{ClientBuilderError, CompletionClient, ProviderClient},
3 impl_conversion_traits,
4};
5use serde::{Deserialize, Serialize};
6
7use super::completion::CompletionModel;
8
/// Default base URL for the OpenRouter REST API.
///
/// [`ClientBuilder::new`] starts from this value; [`ClientBuilder::base_url`] can replace it.
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";

14pub struct ClientBuilder<'a> {
15 api_key: &'a str,
16 base_url: &'a str,
17 http_client: Option<reqwest::Client>,
18}
19
20impl<'a> ClientBuilder<'a> {
21 pub fn new(api_key: &'a str) -> Self {
22 Self {
23 api_key,
24 base_url: OPENROUTER_API_BASE_URL,
25 http_client: None,
26 }
27 }
28
29 pub fn base_url(mut self, base_url: &'a str) -> Self {
30 self.base_url = base_url;
31 self
32 }
33
34 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
35 self.http_client = Some(client);
36 self
37 }
38
39 pub fn build(self) -> Result<Client, ClientBuilderError> {
40 let http_client = if let Some(http_client) = self.http_client {
41 http_client
42 } else {
43 reqwest::Client::builder().build()?
44 };
45
46 Ok(Client {
47 base_url: self.base_url.to_string(),
48 api_key: self.api_key.to_string(),
49 http_client,
50 })
51 }
52}
53
54#[derive(Clone)]
55pub struct Client {
56 base_url: String,
57 api_key: String,
58 http_client: reqwest::Client,
59}
60
61impl std::fmt::Debug for Client {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 f.debug_struct("Client")
64 .field("base_url", &self.base_url)
65 .field("http_client", &self.http_client)
66 .field("api_key", &"<REDACTED>")
67 .finish()
68 }
69}
70
71impl Client {
72 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
83 ClientBuilder::new(api_key)
84 }
85
86 pub fn new(api_key: &str) -> Self {
91 Self::builder(api_key)
92 .build()
93 .expect("OpenRouter client should build")
94 }
95
96 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
97 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
98 self.http_client.post(url).bearer_auth(&self.api_key)
99 }
100}
101
102impl ProviderClient for Client {
103 fn from_env() -> Self {
106 let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
107 Self::new(&api_key)
108 }
109
110 fn from_val(input: crate::client::ProviderValue) -> Self {
111 let crate::client::ProviderValue::Simple(api_key) = input else {
112 panic!("Incorrect provider value type")
113 };
114 Self::new(&api_key)
115 }
116}
117
118impl CompletionClient for Client {
119 type CompletionModel = CompletionModel;
120
121 fn completion_model(&self, model: &str) -> CompletionModel {
133 CompletionModel::new(self.clone(), model)
134 }
135}
136
137impl_conversion_traits!(
138 AsEmbeddings,
139 AsTranscription,
140 AsImageGeneration,
141 AsAudioGeneration for Client
142);
143
144#[derive(Debug, Deserialize)]
145pub(crate) struct ApiErrorResponse {
146 pub message: String,
147}
148
149#[derive(Debug, Deserialize)]
150#[serde(untagged)]
151pub(crate) enum ApiResponse<T> {
152 Ok(T),
153 Err(ApiErrorResponse),
154}
155
156#[derive(Clone, Debug, Deserialize, Serialize)]
157pub struct Usage {
158 pub prompt_tokens: usize,
159 pub completion_tokens: usize,
160 pub total_tokens: usize,
161}
162
163impl std::fmt::Display for Usage {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 write!(
166 f,
167 "Prompt tokens: {} Total tokens: {}",
168 self.prompt_tokens, self.total_tokens
169 )
170 }
171}