rig/providers/openrouter/
client.rs1use crate::{
2 client::{CompletionClient, ProviderClient},
3 impl_conversion_traits,
4};
5use serde::Deserialize;
6
7use super::completion::CompletionModel;
8
/// Default base URL of the OpenRouter HTTP API; `Client::new` sends requests here.
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
13
14#[derive(Clone)]
15pub struct Client {
16 base_url: String,
17 api_key: String,
18 http_client: reqwest::Client,
19}
20
21impl std::fmt::Debug for Client {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("Client")
24 .field("base_url", &self.base_url)
25 .field("http_client", &self.http_client)
26 .field("api_key", &"<REDACTED>")
27 .finish()
28 }
29}
30
31impl Client {
32 pub fn new(api_key: &str) -> Self {
34 Self::from_url(api_key, OPENROUTER_API_BASE_URL)
35 }
36
37 pub fn from_url(api_key: &str, base_url: &str) -> Self {
39 Self {
40 base_url: base_url.to_string(),
41 api_key: api_key.to_string(),
42 http_client: reqwest::Client::builder()
43 .build()
44 .expect("OpenRouter reqwest client should build"),
45 }
46 }
47
48 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
51 self.http_client = client;
52
53 self
54 }
55
56 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
57 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
58 self.http_client.post(url).bearer_auth(&self.api_key)
59 }
60}
61
62impl ProviderClient for Client {
63 fn from_env() -> Self {
66 let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
67 Self::new(&api_key)
68 }
69
70 fn from_val(input: crate::client::ProviderValue) -> Self {
71 let crate::client::ProviderValue::Simple(api_key) = input else {
72 panic!("Incorrect provider value type")
73 };
74 Self::new(&api_key)
75 }
76}
77
78impl CompletionClient for Client {
79 type CompletionModel = CompletionModel;
80
81 fn completion_model(&self, model: &str) -> CompletionModel {
93 CompletionModel::new(self.clone(), model)
94 }
95}
96
// Invoke the crate's `impl_conversion_traits!` macro for `Client` with the `As*` traits below;
// the expansion is defined elsewhere in the crate.
// NOTE(review): presumably this supplies the conversion traits for capabilities this client
// does not implement directly (only `CompletionClient` is implemented in this file) — confirm
// against the macro definition.
impl_conversion_traits!(
    AsEmbeddings,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
103
/// Error body used by `ApiResponse::Err` when a response does not deserialize as the
/// expected success type.
///
/// Only a top-level `message` string is read.
// NOTE(review): OpenRouter's API reference appears to nest errors as
// `{"error": {"code": ..., "message": ...}}`; confirm this flat shape matches what the
// endpoint actually returns, otherwise error bodies will match neither `ApiResponse` variant.
#[derive(Debug, Deserialize)]
pub(crate) struct ApiErrorResponse {
    /// Error message text taken from the response body.
    pub message: String,
}
108
/// Response envelope that is either the expected payload `T` or an [`ApiErrorResponse`].
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order. The body is
/// parsed as `T` first and only falls back to the error shape if that fails, so `Ok` must
/// stay listed before `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    /// The body deserialized as the expected payload type.
    Ok(T),
    /// The body did not match `T` but matched the error shape.
    Err(ApiErrorResponse),
}
115
/// Token usage counts deserialized from an API response.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    /// Tokens counted for the prompt (input), as reported by the provider.
    pub prompt_tokens: usize,
    /// Tokens counted for the generated completion (output), as reported by the provider.
    pub completion_tokens: usize,
    /// Total tokens as reported by the provider; presumably prompt + completion — not
    /// recomputed or checked here.
    pub total_tokens: usize,
}
122
123impl std::fmt::Display for Usage {
124 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125 write!(
126 f,
127 "Prompt tokens: {} Total tokens: {}",
128 self.prompt_tokens, self.total_tokens
129 )
130 }
131}