Skip to main content

rig/providers/openrouter/
client.rs

1use crate::{
2    client::{
3        self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4        ProviderClient,
5    },
6    completion::GetTokenUsage,
7    http_client,
8};
9use serde::{Deserialize, Serialize};
10use std::fmt::Debug;
11
12// ================================================================
// Main OpenRouter Client
14// ================================================================
/// Base URL of the OpenRouter API.
///
/// Exposed to the generic client machinery as `ProviderBuilder::BASE_URL` in the
/// `ProviderBuilder` implementation for `OpenRouterExtBuilder`.
const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1";
16
/// Provider extension type for OpenRouter.
///
/// A field-less unit struct: it holds no state and serves as the type on which the
/// [`Provider`], [`Capabilities`] and [`DebugExt`] implementations in this file are
/// declared. Plugging it into the generic [`client::Client`] yields the OpenRouter client
/// type alias defined in this file.
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenRouterExt;
/// Builder-side counterpart of [`OpenRouterExt`].
///
/// Also a field-less unit struct. Its [`ProviderBuilder`] implementation in this file
/// supplies the base URL and API-key type, and its `build` function produces the
/// [`OpenRouterExt`] value.
#[derive(Debug, Default, Clone, Copy)]
pub struct OpenRouterExtBuilder;
21
/// API-key type used by the OpenRouter client and builder; a plain alias for [`BearerAuth`].
type OpenRouterApiKey = BearerAuth;
23
/// OpenRouter client: the generic [`client::Client`] specialised with [`OpenRouterExt`].
///
/// `H` is the second type parameter forwarded to [`client::Client`]; it defaults to
/// [`reqwest::Client`].
pub type Client<H = reqwest::Client> = client::Client<OpenRouterExt, H>;
/// Builder for the OpenRouter client: the generic [`client::ClientBuilder`] specialised
/// with [`OpenRouterExtBuilder`] and the OpenRouter API-key type.
///
/// `H` is forwarded as the last type parameter and defaults to [`reqwest::Client`].
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<OpenRouterExtBuilder, OpenRouterApiKey, H>;
27
/// Registers [`OpenRouterExt`] as a [`Provider`] and pairs it with its builder type.
impl Provider for OpenRouterExt {
    type Builder = OpenRouterExtBuilder;

    // NOTE(review): presumably the path (relative to the base URL) that is requested to
    // check that the configured API key is valid — confirm against the `Provider` trait.
    const VERIFY_PATH: &'static str = "/key";
}
33
/// Capability table for the OpenRouter provider.
///
/// Completion and embeddings are marked [`Capable`], backed by the parent module's
/// `CompletionModel` and `EmbeddingModel`; every other capability is set to [`Nothing`].
impl<H> Capabilities<H> for OpenRouterExt {
    type Completion = Capable<super::CompletionModel<H>>;
    type Embeddings = Capable<super::EmbeddingModel<H>>;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    // Compiled only when the `image` feature is enabled.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    // Compiled only when the `audio` feature is enabled.
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
45
/// Empty implementation: [`OpenRouterExt`] overrides nothing and relies entirely on
/// whatever defaults the [`DebugExt`] trait provides.
impl DebugExt for OpenRouterExt {}
47
48impl ProviderBuilder for OpenRouterExtBuilder {
49    type Extension<H>
50        = OpenRouterExt
51    where
52        H: http_client::HttpClientExt;
53    type ApiKey = OpenRouterApiKey;
54
55    const BASE_URL: &'static str = OPENROUTER_API_BASE_URL;
56
57    fn build<H>(
58        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
59    ) -> http_client::Result<Self::Extension<H>>
60    where
61        H: http_client::HttpClientExt,
62    {
63        Ok(OpenRouterExt)
64    }
65}
66
67impl ProviderClient for Client {
68    type Input = OpenRouterApiKey;
69
70    /// Create a new openrouter client from the `OPENROUTER_API_KEY` environment variable.
71    /// Panics if the environment variable is not set.
72    fn from_env() -> Self {
73        let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY not set");
74
75        Self::new(&api_key).unwrap()
76    }
77
78    fn from_val(input: Self::Input) -> Self {
79        Self::new(input).unwrap()
80    }
81}
82
/// Error payload carried by the `Err` variant of `ApiResponse`.
///
/// Only the `message` field is captured; the struct does not opt into
/// `deny_unknown_fields`, so any other fields in the payload are ignored by serde.
#[derive(Debug, Deserialize)]
pub(crate) struct ApiErrorResponse {
    /// Error message text taken from the payload's `message` field.
    pub message: String,
}
87
/// Either a successful payload of type `T` or an [`ApiErrorResponse`].
///
/// The enum is `untagged`, so serde tries the variants in declaration order: the input is
/// first deserialized as `T`, and only if that fails as an [`ApiErrorResponse`].
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    /// The input deserialized successfully as `T`.
    Ok(T),
    /// The input did not match `T` but matched the error shape.
    Err(ApiErrorResponse),
}
94
/// Token and cost accounting, (de)serializable with serde.
///
/// `prompt_tokens` and `total_tokens` must be present when deserializing;
/// `completion_tokens` and `cost` fall back to their `Default` value when absent.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Number of prompt tokens (required on deserialization).
    pub prompt_tokens: usize,
    /// Number of completion tokens; `0` when the field is missing from the input.
    #[serde(default)]
    pub completion_tokens: usize,
    /// Total number of tokens (required on deserialization).
    pub total_tokens: usize,
    /// Reported cost; `0.0` when the field is missing from the input.
    // NOTE(review): the unit/currency is not visible in this file — confirm against the
    // OpenRouter API documentation.
    #[serde(default)]
    pub cost: f64,
}
104
105impl std::fmt::Display for Usage {
106    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107        write!(
108            f,
109            "Prompt tokens: {} Total tokens: {}",
110            self.prompt_tokens, self.total_tokens
111        )
112    }
113}
114
115impl GetTokenUsage for Usage {
116    fn token_usage(&self) -> Option<crate::completion::Usage> {
117        let mut usage = crate::completion::Usage::new();
118
119        usage.input_tokens = self.prompt_tokens as u64;
120        usage.output_tokens = self.completion_tokens as u64;
121        usage.total_tokens = self.total_tokens as u64;
122
123        Some(usage)
124    }
125}
#[cfg(test)]
mod tests {
    use crate::providers::openrouter::Client;

    /// Smoke test: a client can be constructed both directly and through the builder
    /// when given a placeholder API key.
    #[test]
    fn test_client_initialization() {
        let direct = Client::new("dummy-key");
        let _client = direct.expect("Client::new() failed");

        let built = Client::builder().api_key("dummy-key").build();
        let _client_from_builder = built.expect("Client::builder() failed");
    }
}