rig/providers/openrouter/
client.rs1use crate::{
2 client::{CompletionClient, ProviderClient},
3 impl_conversion_traits,
4};
5use serde::Deserialize;
6
7use super::completion::CompletionModel;
8
/// Base URL used by [`Client::new`] when no explicit endpoint is supplied.
const OPENROUTER_API_BASE_URL: &'static str = "https://openrouter.ai/api/v1";
13
14#[derive(Clone)]
15pub struct Client {
16 base_url: String,
17 api_key: String,
18 http_client: reqwest::Client,
19}
20
21impl std::fmt::Debug for Client {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("Client")
24 .field("base_url", &self.base_url)
25 .field("http_client", &self.http_client)
26 .field("api_key", &"<REDACTED>")
27 .finish()
28 }
29}
30
31impl Client {
32 pub fn new(api_key: &str) -> Self {
34 Self::from_url(api_key, OPENROUTER_API_BASE_URL)
35 }
36
37 pub fn from_url(api_key: &str, base_url: &str) -> Self {
39 Self {
40 base_url: base_url.to_string(),
41 api_key: api_key.to_string(),
42 http_client: reqwest::Client::builder()
43 .build()
44 .expect("OpenRouter reqwest client should build"),
45 }
46 }
47
48 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
51 self.http_client = client;
52
53 self
54 }
55
56 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
57 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
58 self.http_client.post(url).bearer_auth(&self.api_key)
59 }
60}
61
62impl ProviderClient for Client {
63 fn from_env() -> Self {
66 let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
67 Self::new(&api_key)
68 }
69}
70
71impl CompletionClient for Client {
72 type CompletionModel = CompletionModel;
73
74 fn completion_model(&self, model: &str) -> CompletionModel {
86 CompletionModel::new(self.clone(), model)
87 }
88}
89
90impl_conversion_traits!(
91 AsEmbeddings,
92 AsTranscription,
93 AsImageGeneration,
94 AsAudioGeneration for Client
95);
96
97#[derive(Debug, Deserialize)]
98pub(crate) struct ApiErrorResponse {
99 pub message: String,
100}
101
102#[derive(Debug, Deserialize)]
103#[serde(untagged)]
104pub(crate) enum ApiResponse<T> {
105 Ok(T),
106 Err(ApiErrorResponse),
107}
108
109#[derive(Clone, Debug, Deserialize)]
110pub struct Usage {
111 pub prompt_tokens: usize,
112 pub completion_tokens: usize,
113 pub total_tokens: usize,
114}
115
116impl std::fmt::Display for Usage {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 write!(
119 f,
120 "Prompt tokens: {} Total tokens: {}",
121 self.prompt_tokens, self.total_tokens
122 )
123 }
124}