// openrouter_rs/api/models.rs

1use serde::{Deserialize, Serialize};
2use surf::http::headers::AUTHORIZATION;
3use urlencoding::encode;
4
5use crate::{
6    error::OpenRouterError,
7    types::{ApiResponse, ModelCategory, SupportedParameters},
8    utils::handle_error,
9};
10
11#[derive(Serialize, Deserialize, Debug)]
12pub struct Model {
13    pub id: String,
14    pub name: String,
15    pub created: f64,
16    pub description: String,
17    pub context_length: f64,
18    pub architecture: Architecture,
19    pub top_provider: TopProvider,
20    pub pricing: Pricing,
21    pub per_request_limits: Option<std::collections::HashMap<String, String>>,
22}
23
24#[derive(Serialize, Deserialize, Debug)]
25pub struct Architecture {
26    pub modality: String,
27    pub tokenizer: String,
28    pub instruct_type: Option<String>,
29}
30
31#[derive(Serialize, Deserialize, Debug)]
32pub struct TopProvider {
33    pub context_length: Option<f64>,
34    pub max_completion_tokens: Option<f64>,
35    pub is_moderated: bool,
36}
37
38#[derive(Serialize, Deserialize, Debug)]
39pub struct Pricing {
40    pub prompt: String,
41    pub completion: String,
42    pub image: Option<String>,
43    pub request: Option<String>,
44    pub input_cache_read: Option<String>,
45    pub input_cache_write: Option<String>,
46    pub web_search: Option<String>,
47    pub internal_reasoning: Option<String>,
48}
49
50#[derive(Serialize, Deserialize, Debug)]
51pub struct Endpoint {
52    pub name: String,
53    pub context_length: f64,
54    pub pricing: EndpointPricing,
55    pub provider_name: String,
56    pub supported_parameters: Vec<String>,
57    pub quantization: Option<String>,
58    pub max_completion_tokens: Option<f64>,
59    pub max_prompt_tokens: Option<f64>,
60    pub status: Option<serde_json::Value>,
61}
62
63#[derive(Serialize, Deserialize, Debug)]
64pub struct EndpointPricing {
65    pub request: String,
66    pub image: String,
67    pub prompt: String,
68    pub completion: String,
69}
70
71#[derive(Serialize, Deserialize, Debug)]
72pub struct EndpointData {
73    pub id: String,
74    pub name: String,
75    pub created: f64,
76    pub description: String,
77    pub architecture: EndpointArchitecture,
78    pub endpoints: Vec<Endpoint>,
79}
80
81#[derive(Serialize, Deserialize, Debug)]
82pub struct EndpointArchitecture {
83    pub tokenizer: Option<String>,
84    pub instruct_type: Option<String>,
85    pub modality: Option<String>,
86}
87
88/// Returns a list of models available through the API
89///
90/// # Arguments
91///
92/// * `base_url` - The base URL of the OpenRouter API.
93/// * `api_key` - The API key for authentication.
94/// * `category` - The category of the models.
95///
96/// # Returns
97///
98/// * `Result<Vec<Model>, OpenRouterError>` - A list of models or an error.
99pub async fn list_models(
100    base_url: &str,
101    api_key: &str,
102    category: Option<ModelCategory>,
103    supported_parameters: Option<SupportedParameters>,
104) -> Result<Vec<Model>, OpenRouterError> {
105    let url = match (category, supported_parameters) {
106        (Some(category), None) => {
107            format!("{base_url}/models?category={category}")
108        }
109        (None, Some(supported_parameters)) => {
110            format!("{base_url}/models?supported_parameters={supported_parameters}")
111        }
112        _ => {
113            format!("{base_url}/models")
114        }
115    };
116
117    let mut response = surf::get(url)
118        .header(AUTHORIZATION, format!("Bearer {api_key}"))
119        .await?;
120
121    if response.status().is_success() {
122        let model_list_response: ApiResponse<_> = response.body_json().await?;
123        Ok(model_list_response.data)
124    } else {
125        handle_error(response).await?;
126        unreachable!()
127    }
128}
129
130/// Returns details about the endpoints for a specific model
131///
132/// # Arguments
133///
134/// * `base_url` - The base URL of the OpenRouter API.
135/// * `api_key` - The API key for authentication.
136/// * `author` - The author of the model.
137/// * `slug` - The slug identifier for the model.
138///
139/// # Returns
140///
141/// * `Result<EndpointData, OpenRouterError>` - The endpoint data or an error.
142pub async fn list_model_endpoints(
143    base_url: &str,
144    api_key: &str,
145    author: &str,
146    slug: &str,
147) -> Result<EndpointData, OpenRouterError> {
148    let encoded_author = encode(author);
149    let encoded_slug = encode(slug);
150    let url = format!("{base_url}/models/{encoded_author}/{encoded_slug}/endpoints");
151    println!("URL: {url}");
152
153    let mut response = surf::get(&url)
154        .header(AUTHORIZATION, format!("Bearer {api_key}"))
155        .await?;
156
157    if response.status().is_success() {
158        let endpoint_list_response: ApiResponse<_> = response.body_json().await?;
159        Ok(endpoint_list_response.data)
160    } else {
161        handle_error(response).await?;
162        unreachable!()
163    }
164}