openrouter_rs/api/
chat.rs

1use std::collections::HashMap;
2
3use derive_builder::Builder;
4use futures_util::{AsyncBufReadExt, StreamExt, stream::BoxStream};
5use serde::{Deserialize, Serialize};
6use surf::http::headers::AUTHORIZATION;
7
8use crate::{
9    error::OpenRouterError,
10    strip_option_map_setter, strip_option_vec_setter,
11    types::{
12        ProviderPreferences, ReasoningConfig, ResponseFormat, Role, completion::CompletionsResponse,
13    },
14    utils::handle_error,
15};
16
17#[derive(Serialize, Deserialize, Debug, Clone)]
18pub struct Message {
19    pub role: Role,
20    pub content: String,
21}
22
23impl Message {
24    pub fn new(role: Role, content: &str) -> Self {
25        Self {
26            role,
27            content: content.to_string(),
28        }
29    }
30}
31
32#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
33#[builder(build_fn(error = "OpenRouterError"))]
34pub struct ChatCompletionRequest {
35    #[builder(setter(into))]
36    model: String,
37
38    messages: Vec<Message>,
39
40    #[builder(setter(skip), default)]
41    #[serde(skip_serializing_if = "Option::is_none")]
42    stream: Option<bool>,
43
44    #[builder(setter(strip_option), default)]
45    #[serde(skip_serializing_if = "Option::is_none")]
46    max_tokens: Option<u32>,
47
48    #[builder(setter(strip_option), default)]
49    #[serde(skip_serializing_if = "Option::is_none")]
50    temperature: Option<f64>,
51
52    #[builder(setter(strip_option), default)]
53    #[serde(skip_serializing_if = "Option::is_none")]
54    seed: Option<u32>,
55
56    #[builder(setter(strip_option), default)]
57    #[serde(skip_serializing_if = "Option::is_none")]
58    top_p: Option<f64>,
59
60    #[builder(setter(strip_option), default)]
61    #[serde(skip_serializing_if = "Option::is_none")]
62    top_k: Option<u32>,
63
64    #[builder(setter(strip_option), default)]
65    #[serde(skip_serializing_if = "Option::is_none")]
66    frequency_penalty: Option<f64>,
67
68    #[builder(setter(strip_option), default)]
69    #[serde(skip_serializing_if = "Option::is_none")]
70    presence_penalty: Option<f64>,
71
72    #[builder(setter(strip_option), default)]
73    #[serde(skip_serializing_if = "Option::is_none")]
74    repetition_penalty: Option<f64>,
75
76    #[builder(setter(custom), default)]
77    #[serde(skip_serializing_if = "Option::is_none")]
78    logit_bias: Option<HashMap<String, f64>>,
79
80    #[builder(setter(strip_option), default)]
81    #[serde(skip_serializing_if = "Option::is_none")]
82    top_logprobs: Option<u32>,
83
84    #[builder(setter(strip_option), default)]
85    #[serde(skip_serializing_if = "Option::is_none")]
86    min_p: Option<f64>,
87
88    #[builder(setter(strip_option), default)]
89    #[serde(skip_serializing_if = "Option::is_none")]
90    top_a: Option<f64>,
91
92    #[builder(setter(custom), default)]
93    #[serde(skip_serializing_if = "Option::is_none")]
94    transforms: Option<Vec<String>>,
95
96    #[builder(setter(custom), default)]
97    #[serde(skip_serializing_if = "Option::is_none")]
98    models: Option<Vec<String>>,
99
100    #[builder(setter(into, strip_option), default)]
101    #[serde(skip_serializing_if = "Option::is_none")]
102    route: Option<String>,
103
104    #[builder(setter(strip_option), default)]
105    #[serde(skip_serializing_if = "Option::is_none")]
106    provider: Option<ProviderPreferences>,
107
108    #[builder(setter(strip_option), default)]
109    #[serde(skip_serializing_if = "Option::is_none")]
110    response_format: Option<ResponseFormat>,
111
112    #[builder(setter(strip_option), default)]
113    #[serde(skip_serializing_if = "Option::is_none")]
114    reasoning: Option<ReasoningConfig>,
115
116    #[builder(setter(strip_option), default)]
117    #[serde(skip_serializing_if = "Option::is_none")]
118    include_reasoning: Option<bool>,
119}
120
121impl ChatCompletionRequestBuilder {
122    strip_option_vec_setter!(models, String);
123    strip_option_map_setter!(logit_bias, String, f64);
124    strip_option_vec_setter!(transforms, String);
125
126    /// Enable reasoning with default settings (medium effort)
127    pub fn enable_reasoning(&mut self) -> &mut Self {
128        use crate::types::ReasoningConfig;
129        self.reasoning = Some(Some(ReasoningConfig::enabled()));
130        self
131    }
132
133    /// Set reasoning effort level
134    pub fn reasoning_effort(&mut self, effort: crate::types::Effort) -> &mut Self {
135        use crate::types::ReasoningConfig;
136        self.reasoning = Some(Some(ReasoningConfig::with_effort(effort)));
137        self
138    }
139
140    /// Set reasoning max tokens
141    pub fn reasoning_max_tokens(&mut self, max_tokens: u32) -> &mut Self {
142        use crate::types::ReasoningConfig;
143        self.reasoning = Some(Some(ReasoningConfig::with_max_tokens(max_tokens)));
144        self
145    }
146
147    /// Exclude reasoning from response (use reasoning internally but don't return it)
148    pub fn exclude_reasoning(&mut self) -> &mut Self {
149        use crate::types::ReasoningConfig;
150        self.reasoning = Some(Some(ReasoningConfig::excluded()));
151        self
152    }
153}
154
155impl ChatCompletionRequest {
156    pub fn builder() -> ChatCompletionRequestBuilder {
157        ChatCompletionRequestBuilder::default()
158    }
159
160    pub fn new(model: &str, messages: Vec<Message>) -> Self {
161        Self::builder()
162            .model(model)
163            .messages(messages)
164            .build()
165            .expect("Failed to build ChatCompletionRequest")
166    }
167
168    fn stream(&self, stream: bool) -> Self {
169        let mut req = self.clone();
170        req.stream = Some(stream);
171        req
172    }
173}
174
175/// Send a chat completion request to a selected model.
176///
177/// # Arguments
178///
179/// * `base_url` - The base URL for the OpenRouter API.
180/// * `api_key` - The API key for authentication.
181/// * `x_title` - The name of the site for the request.
182/// * `http_referer` - The URL of the site for the request.
183/// * `request` - The chat completion request containing the model and messages.
184///
185/// # Returns
186///
187/// * `Result<CompletionsResponse, OpenRouterError>` - The response from the chat completion request.
188pub async fn send_chat_completion(
189    base_url: &str,
190    api_key: &str,
191    x_title: &Option<String>,
192    http_referer: &Option<String>,
193    request: &ChatCompletionRequest,
194) -> Result<CompletionsResponse, OpenRouterError> {
195    let url = format!("{base_url}/chat/completions");
196
197    // Ensure that the request is not streaming to get a single response
198    let request = request.stream(false);
199
200    let mut surf_req = surf::post(url)
201        .header(AUTHORIZATION, format!("Bearer {api_key}"))
202        .body_json(&request)?;
203
204    if let Some(x_title) = x_title {
205        surf_req = surf_req.header("X-Title", x_title);
206    }
207    if let Some(http_referer) = http_referer {
208        surf_req = surf_req.header("HTTP-Referer", http_referer);
209    }
210
211    let mut response = surf_req.await?;
212
213    if response.status().is_success() {
214        let chat_response = response.body_json().await?;
215        Ok(chat_response)
216    } else {
217        handle_error(response).await?;
218        unreachable!()
219    }
220}
221
222/// Stream chat completion events from a selected model.
223///
224/// # Arguments
225///
226/// * `base_url` - The base URL for the OpenRouter API.
227/// * `api_key` - The API key for authentication.
228/// * `request` - The chat completion request containing the model and messages.
229///
230/// # Returns
231///
232/// * `Result<BoxStream<'static, Result<CompletionsResponse, OpenRouterError>>, OpenRouterError>` - A stream of chat completion events or an error.
233pub async fn stream_chat_completion(
234    base_url: &str,
235    api_key: &str,
236    request: &ChatCompletionRequest,
237) -> Result<BoxStream<'static, Result<CompletionsResponse, OpenRouterError>>, OpenRouterError> {
238    let url = format!("{base_url}/chat/completions");
239
240    // Ensure that the request is streaming to get a continuous response
241    let request = request.stream(true);
242
243    let response = surf::post(url)
244        .header(AUTHORIZATION, format!("Bearer {api_key}"))
245        .body_json(&request)?
246        .await?;
247
248    if response.status().is_success() {
249        let lines = response
250            .lines()
251            .filter_map(async |line| match line {
252                Ok(line) => line
253                    .strip_prefix("data: ")
254                    .filter(|line| *line != "[DONE]")
255                    .map(serde_json::from_str::<CompletionsResponse>)
256                    .map(|event| event.map_err(OpenRouterError::Serialization)),
257                Err(error) => Some(Err(OpenRouterError::Io(error))),
258            })
259            .boxed();
260
261        Ok(lines)
262    } else {
263        handle_error(response).await?;
264        unreachable!()
265    }
266}