rig/providers/openrouter/client.rs

1use crate::{
2    client::{CompletionClient, ProviderClient},
3    impl_conversion_traits,
4};
5use serde::Deserialize;
6
7use super::completion::CompletionModel;
8
// ----------------------------------------------------------------
// OpenRouter client
// ----------------------------------------------------------------

/// Root of the public OpenRouter REST API, used by `Client::new`.
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
13
14#[derive(Clone, Debug)]
15pub struct Client {
16    base_url: String,
17    http_client: reqwest::Client,
18}
19
20impl Client {
21    /// Create a new OpenRouter client with the given API key.
22    pub fn new(api_key: &str) -> Self {
23        Self::from_url(api_key, OPENROUTER_API_BASE_URL)
24    }
25
26    /// Create a new OpenRouter client with the given API key and base API URL.
27    pub fn from_url(api_key: &str, base_url: &str) -> Self {
28        Self {
29            base_url: base_url.to_string(),
30            http_client: reqwest::Client::builder()
31                .default_headers({
32                    let mut headers = reqwest::header::HeaderMap::new();
33                    headers.insert(
34                        "Authorization",
35                        format!("Bearer {api_key}")
36                            .parse()
37                            .expect("Bearer token should parse"),
38                    );
39                    headers
40                })
41                .build()
42                .expect("OpenRouter reqwest client should build"),
43        }
44    }
45
46    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
47        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
48        self.http_client.post(url)
49    }
50}
51
52impl ProviderClient for Client {
53    /// Create a new openrouter client from the `OPENROUTER_API_KEY` environment variable.
54    /// Panics if the environment variable is not set.
55    fn from_env() -> Self {
56        let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
57        Self::new(&api_key)
58    }
59}
60
61impl CompletionClient for Client {
62    type CompletionModel = CompletionModel;
63
64    /// Create a completion model with the given name.
65    ///
66    /// # Example
67    /// ```
68    /// use rig::providers::openrouter::{Client, self};
69    ///
70    /// // Initialize the openrouter client
71    /// let openrouter = Client::new("your-openrouter-api-key");
72    ///
73    /// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B);
74    /// ```
75    fn completion_model(&self, model: &str) -> CompletionModel {
76        CompletionModel::new(self.clone(), model)
77    }
78}
79
80impl_conversion_traits!(
81    AsEmbeddings,
82    AsTranscription,
83    AsImageGeneration,
84    AsAudioGeneration for Client
85);
86
87#[derive(Debug, Deserialize)]
88pub(crate) struct ApiErrorResponse {
89    pub message: String,
90}
91
92#[derive(Debug, Deserialize)]
93#[serde(untagged)]
94pub(crate) enum ApiResponse<T> {
95    Ok(T),
96    Err(ApiErrorResponse),
97}
98
99#[derive(Clone, Debug, Deserialize)]
100pub struct Usage {
101    pub prompt_tokens: usize,
102    pub completion_tokens: usize,
103    pub total_tokens: usize,
104}
105
106impl std::fmt::Display for Usage {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        write!(
109            f,
110            "Prompt tokens: {} Total tokens: {}",
111            self.prompt_tokens, self.total_tokens
112        )
113    }
114}