// openrouter_rs/api/chat.rs

1use crate::{
2    error::OpenRouterError,
3    setter,
4    types::{ProviderPreferences, ReasoningConfig, Role},
5    utils::handle_error,
6};
7use reqwest::Client;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11#[derive(Serialize, Deserialize, Debug)]
12pub struct Message {
13    role: Role,
14    content: String,
15}
16
17impl Message {
18    pub fn new(role: Role, content: &str) -> Self {
19        Self {
20            role,
21            content: content.to_string(),
22        }
23    }
24}
25
26#[derive(Serialize, Deserialize, Debug)]
27pub struct ChatCompletionRequest {
28    model: String,
29    messages: Vec<Message>,
30    stream: Option<bool>,
31    max_tokens: Option<u32>,
32    temperature: Option<f64>,
33    seed: Option<u32>,
34    top_p: Option<f64>,
35    top_k: Option<u32>,
36    frequency_penalty: Option<f64>,
37    presence_penalty: Option<f64>,
38    repetition_penalty: Option<f64>,
39    logit_bias: Option<HashMap<String, f64>>,
40    top_logprobs: Option<u32>,
41    min_p: Option<f64>,
42    top_a: Option<f64>,
43    transforms: Option<Vec<String>>,
44    models: Option<Vec<String>>,
45    route: Option<String>,
46    provider: Option<ProviderPreferences>,
47    reasoning: Option<ReasoningConfig>,
48}
49
impl ChatCompletionRequest {
    /// Create a request for `model` with the given conversation `messages`.
    ///
    /// All optional tuning parameters start as `None`; use the setter
    /// methods below to populate them in builder style.
    pub fn new(model: &str, messages: Vec<Message>) -> Self {
        Self {
            model: model.to_string(),
            messages,
            stream: None,
            max_tokens: None,
            temperature: None,
            seed: None,
            top_p: None,
            top_k: None,
            frequency_penalty: None,
            presence_penalty: None,
            repetition_penalty: None,
            logit_bias: None,
            top_logprobs: None,
            min_p: None,
            top_a: None,
            transforms: None,
            models: None,
            route: None,
            provider: None,
            reasoning: None,
        }
    }

    // Builder-style setters generated by the crate's `setter!` macro
    // (see `crate::setter`), one per optional field above.
    setter!(stream, bool);
    setter!(max_tokens, u32);
    setter!(temperature, f64);
    setter!(seed, u32);
    setter!(top_p, f64);
    setter!(top_k, u32);
    setter!(frequency_penalty, f64);
    setter!(presence_penalty, f64);
    setter!(repetition_penalty, f64);
    setter!(logit_bias, HashMap<String, f64>);
    setter!(top_logprobs, u32);
    setter!(min_p, f64);
    setter!(top_a, f64);
    setter!(transforms, Vec<String>);
    setter!(models, Vec<String>);
    setter!(route, String);
    setter!(provider, ProviderPreferences);
    setter!(reasoning, ReasoningConfig);
}
95
96#[derive(Serialize, Deserialize, Debug)]
97pub struct ChatCompletionResponse {
98    id: Option<String>,
99    choices: Option<Vec<Choice>>,
100}
101
102#[derive(Serialize, Deserialize, Debug)]
103pub struct Choice {
104    message: Option<Message>,
105}
106
107/// Send a chat completion request to a selected model.
108///
109/// # Arguments
110///
111/// * `client` - The HTTP client to use for the request.
112/// * `api_key` - The API key for authentication.
113/// * `request` - The chat completion request containing the model and messages.
114///
115/// # Returns
116///
117/// * `Result<ChatCompletionResponse, OpenRouterError>` - The response from the chat completion request.
118pub async fn send_chat_completion(
119    client: &Client,
120    api_key: &str,
121    request: &ChatCompletionRequest,
122) -> Result<ChatCompletionResponse, OpenRouterError> {
123    let url = "https://openrouter.ai/api/v1/chat/completions";
124
125    let response = client
126        .post(url)
127        .bearer_auth(api_key)
128        .json(request)
129        .send()
130        .await?;
131
132    if response.status().is_success() {
133        let chat_response = response.json::<ChatCompletionResponse>().await?;
134        Ok(chat_response)
135    } else {
136        handle_error(response).await?;
137        unreachable!()
138    }
139}