rig/providers/openrouter/client.rs

1use crate::{
2    client::{CompletionClient, ProviderClient},
3    impl_conversion_traits,
4};
5use serde::Deserialize;
6
7use super::completion::CompletionModel;
8
// ----------------------------------------------------------------
// OpenRouter client
// ----------------------------------------------------------------

/// Default base URL of the OpenRouter REST API; request paths are appended to it.
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
13
14#[derive(Clone)]
15pub struct Client {
16    base_url: String,
17    api_key: String,
18    http_client: reqwest::Client,
19}
20
21impl std::fmt::Debug for Client {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("Client")
24            .field("base_url", &self.base_url)
25            .field("http_client", &self.http_client)
26            .field("api_key", &"<REDACTED>")
27            .finish()
28    }
29}
30
31impl Client {
32    /// Create a new OpenRouter client with the given API key.
33    pub fn new(api_key: &str) -> Self {
34        Self::from_url(api_key, OPENROUTER_API_BASE_URL)
35    }
36
37    /// Create a new OpenRouter client with the given API key and base API URL.
38    pub fn from_url(api_key: &str, base_url: &str) -> Self {
39        Self {
40            base_url: base_url.to_string(),
41            api_key: api_key.to_string(),
42            http_client: reqwest::Client::builder()
43                .build()
44                .expect("OpenRouter reqwest client should build"),
45        }
46    }
47
48    /// Use your own `reqwest::Client`.
49    /// The required headers will be automatically attached upon trying to make a request.
50    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
51        self.http_client = client;
52
53        self
54    }
55
56    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
57        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
58        self.http_client.post(url).bearer_auth(&self.api_key)
59    }
60}
61
62impl ProviderClient for Client {
63    /// Create a new openrouter client from the `OPENROUTER_API_KEY` environment variable.
64    /// Panics if the environment variable is not set.
65    fn from_env() -> Self {
66        let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
67        Self::new(&api_key)
68    }
69
70    fn from_val(input: crate::client::ProviderValue) -> Self {
71        let crate::client::ProviderValue::Simple(api_key) = input else {
72            panic!("Incorrect provider value type")
73        };
74        Self::new(&api_key)
75    }
76}
77
78impl CompletionClient for Client {
79    type CompletionModel = CompletionModel;
80
81    /// Create a completion model with the given name.
82    ///
83    /// # Example
84    /// ```
85    /// use rig::providers::openrouter::{Client, self};
86    ///
87    /// // Initialize the openrouter client
88    /// let openrouter = Client::new("your-openrouter-api-key");
89    ///
90    /// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
91    /// ```
92    fn completion_model(&self, model: &str) -> CompletionModel {
93        CompletionModel::new(self.clone(), model)
94    }
95}
96
97impl_conversion_traits!(
98    AsEmbeddings,
99    AsTranscription,
100    AsImageGeneration,
101    AsAudioGeneration for Client
102);
103
104#[derive(Debug, Deserialize)]
105pub(crate) struct ApiErrorResponse {
106    pub message: String,
107}
108
109#[derive(Debug, Deserialize)]
110#[serde(untagged)]
111pub(crate) enum ApiResponse<T> {
112    Ok(T),
113    Err(ApiErrorResponse),
114}
115
116#[derive(Clone, Debug, Deserialize)]
117pub struct Usage {
118    pub prompt_tokens: usize,
119    pub completion_tokens: usize,
120    pub total_tokens: usize,
121}
122
123impl std::fmt::Display for Usage {
124    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125        write!(
126            f,
127            "Prompt tokens: {} Total tokens: {}",
128            self.prompt_tokens, self.total_tokens
129        )
130    }
131}