rig/providers/openrouter/client.rs

1use crate::{
2    client::{CompletionClient, ProviderClient},
3    impl_conversion_traits,
4};
5use serde::Deserialize;
6
7use super::completion::CompletionModel;
8
// ================================================================
// OpenRouter client
// ================================================================

/// Default base URL of the OpenRouter REST API (v1); [`Client::new`] uses it.
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
13
14#[derive(Clone)]
15pub struct Client {
16    base_url: String,
17    api_key: String,
18    http_client: reqwest::Client,
19}
20
21impl std::fmt::Debug for Client {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("Client")
24            .field("base_url", &self.base_url)
25            .field("http_client", &self.http_client)
26            .field("api_key", &"<REDACTED>")
27            .finish()
28    }
29}
30
31impl Client {
32    /// Create a new OpenRouter client with the given API key.
33    pub fn new(api_key: &str) -> Self {
34        Self::from_url(api_key, OPENROUTER_API_BASE_URL)
35    }
36
37    /// Create a new OpenRouter client with the given API key and base API URL.
38    pub fn from_url(api_key: &str, base_url: &str) -> Self {
39        Self {
40            base_url: base_url.to_string(),
41            api_key: api_key.to_string(),
42            http_client: reqwest::Client::builder()
43                .build()
44                .expect("OpenRouter reqwest client should build"),
45        }
46    }
47
48    /// Use your own `reqwest::Client`.
49    /// The required headers will be automatically attached upon trying to make a request.
50    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
51        self.http_client = client;
52
53        self
54    }
55
56    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
57        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
58        self.http_client.post(url).bearer_auth(&self.api_key)
59    }
60}
61
62impl ProviderClient for Client {
63    /// Create a new openrouter client from the `OPENROUTER_API_KEY` environment variable.
64    /// Panics if the environment variable is not set.
65    fn from_env() -> Self {
66        let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
67        Self::new(&api_key)
68    }
69}
70
71impl CompletionClient for Client {
72    type CompletionModel = CompletionModel;
73
74    /// Create a completion model with the given name.
75    ///
76    /// # Example
77    /// ```
78    /// use rig::providers::openrouter::{Client, self};
79    ///
80    /// // Initialize the openrouter client
81    /// let openrouter = Client::new("your-openrouter-api-key");
82    ///
83    /// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
84    /// ```
85    fn completion_model(&self, model: &str) -> CompletionModel {
86        CompletionModel::new(self.clone(), model)
87    }
88}
89
90impl_conversion_traits!(
91    AsEmbeddings,
92    AsTranscription,
93    AsImageGeneration,
94    AsAudioGeneration for Client
95);
96
97#[derive(Debug, Deserialize)]
98pub(crate) struct ApiErrorResponse {
99    pub message: String,
100}
101
102#[derive(Debug, Deserialize)]
103#[serde(untagged)]
104pub(crate) enum ApiResponse<T> {
105    Ok(T),
106    Err(ApiErrorResponse),
107}
108
109#[derive(Clone, Debug, Deserialize)]
110pub struct Usage {
111    pub prompt_tokens: usize,
112    pub completion_tokens: usize,
113    pub total_tokens: usize,
114}
115
116impl std::fmt::Display for Usage {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        write!(
119            f,
120            "Prompt tokens: {} Total tokens: {}",
121            self.prompt_tokens, self.total_tokens
122        )
123    }
124}