1use std::collections::HashMap;
2
3use derive_builder::Builder;
4use futures_util::{AsyncBufReadExt, StreamExt, stream::BoxStream};
5use serde::{Deserialize, Serialize};
6use surf::http::headers::AUTHORIZATION;
7
8use crate::{
9 error::OpenRouterError,
10 strip_option_map_setter, strip_option_vec_setter,
11 types::{
12 ProviderPreferences, ReasoningConfig, ResponseFormat, Role, completion::CompletionsResponse,
13 },
14 utils::handle_error,
15};
16
/// A single chat message: a [`Role`] together with its text content.
///
/// Serialized and deserialized by serde using the field names `role` and `content`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Message {
    /// Role attached to this message (see [`Role`]).
    pub role: Role,
    /// Text content of the message.
    pub content: String,
}
22
23impl Message {
24 pub fn new(role: Role, content: &str) -> Self {
25 Self {
26 role,
27 content: content.to_string(),
28 }
29 }
30}
31
/// JSON request body POSTed to `{base_url}/chat/completions` by `send_chat_completion`
/// and `stream_chat_completion`.
///
/// Construct it with `ChatCompletionRequest::builder()` (a `derive_builder` builder whose
/// `build()` returns `Result<_, OpenRouterError>`) or with `ChatCompletionRequest::new`.
///
/// `model` and `messages` are required. Every `Option` field defaults to `None` and is
/// omitted from the serialized JSON while it is `None`. The optional fields are sent
/// under the same JSON key as the field name. Their meaning and valid ranges are defined
/// by the remote API, and this struct performs no validation of them.
#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
#[builder(build_fn(error = "OpenRouterError"))]
pub struct ChatCompletionRequest {
    /// Model identifier (required). The builder setter accepts anything `Into<String>`.
    #[builder(setter(into))]
    model: String,

    /// Conversation messages (required).
    messages: Vec<Message>,

    /// Streaming flag. It has no builder setter (`setter(skip)`), so it starts as `None`.
    /// The send functions serialize a copy of the request with it set to `Some(false)`
    /// or `Some(true)`.
    #[builder(setter(skip), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,

    /// Optional `max_tokens` value.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,

    /// Optional `temperature` value.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,

    /// Optional `seed` value.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u32>,

    /// Optional `top_p` value.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f64>,

    /// Optional `top_k` value.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,

    /// Optional `frequency_penalty` value.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f64>,

    /// Optional `presence_penalty` value.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f64>,

    /// Optional `repetition_penalty` value.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    repetition_penalty: Option<f64>,

    /// Optional `logit_bias` map. Its builder setter is custom (`setter(custom)`) and is
    /// provided by the `strip_option_map_setter!` invocation on the builder.
    #[builder(setter(custom), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    logit_bias: Option<HashMap<String, f64>>,

    /// Optional `top_logprobs` value.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    top_logprobs: Option<u32>,

    /// Optional `min_p` value.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    min_p: Option<f64>,

    /// Optional `top_a` value.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    top_a: Option<f64>,

    /// Optional `transforms` list. Its builder setter is custom (`setter(custom)`) and is
    /// provided by a `strip_option_vec_setter!` invocation on the builder.
    #[builder(setter(custom), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    transforms: Option<Vec<String>>,

    /// Optional `models` list. Its builder setter is custom (`setter(custom)`) and is
    /// provided by a `strip_option_vec_setter!` invocation on the builder.
    #[builder(setter(custom), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    models: Option<Vec<String>>,

    /// Optional `route` value. The builder setter accepts anything `Into<String>`.
    #[builder(setter(into, strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    route: Option<String>,

    /// Optional provider preferences, sent as `provider`.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    provider: Option<ProviderPreferences>,

    /// Optional `response_format` value.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,

    /// Optional reasoning configuration. Besides the generated setter, it can be set
    /// through the builder helpers `enable_reasoning`, `reasoning_effort`,
    /// `reasoning_max_tokens` and `exclude_reasoning`.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning: Option<ReasoningConfig>,

    /// Optional `include_reasoning` flag.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    include_reasoning: Option<bool>,
}
120
121impl ChatCompletionRequestBuilder {
122 strip_option_vec_setter!(models, String);
123 strip_option_map_setter!(logit_bias, String, f64);
124 strip_option_vec_setter!(transforms, String);
125
126 pub fn enable_reasoning(&mut self) -> &mut Self {
128 use crate::types::ReasoningConfig;
129 self.reasoning = Some(Some(ReasoningConfig::enabled()));
130 self
131 }
132
133 pub fn reasoning_effort(&mut self, effort: crate::types::Effort) -> &mut Self {
135 use crate::types::ReasoningConfig;
136 self.reasoning = Some(Some(ReasoningConfig::with_effort(effort)));
137 self
138 }
139
140 pub fn reasoning_max_tokens(&mut self, max_tokens: u32) -> &mut Self {
142 use crate::types::ReasoningConfig;
143 self.reasoning = Some(Some(ReasoningConfig::with_max_tokens(max_tokens)));
144 self
145 }
146
147 pub fn exclude_reasoning(&mut self) -> &mut Self {
149 use crate::types::ReasoningConfig;
150 self.reasoning = Some(Some(ReasoningConfig::excluded()));
151 self
152 }
153}
154
155impl ChatCompletionRequest {
156 pub fn builder() -> ChatCompletionRequestBuilder {
157 ChatCompletionRequestBuilder::default()
158 }
159
160 pub fn new(model: &str, messages: Vec<Message>) -> Self {
161 Self::builder()
162 .model(model)
163 .messages(messages)
164 .build()
165 .expect("Failed to build ChatCompletionRequest")
166 }
167
168 fn stream(&self, stream: bool) -> Self {
169 let mut req = self.clone();
170 req.stream = Some(stream);
171 req
172 }
173}
174
175pub async fn send_chat_completion(
189 base_url: &str,
190 api_key: &str,
191 x_title: &Option<String>,
192 http_referer: &Option<String>,
193 request: &ChatCompletionRequest,
194) -> Result<CompletionsResponse, OpenRouterError> {
195 let url = format!("{base_url}/chat/completions");
196
197 let request = request.stream(false);
199
200 let mut surf_req = surf::post(url)
201 .header(AUTHORIZATION, format!("Bearer {api_key}"))
202 .body_json(&request)?;
203
204 if let Some(x_title) = x_title {
205 surf_req = surf_req.header("X-Title", x_title);
206 }
207 if let Some(http_referer) = http_referer {
208 surf_req = surf_req.header("HTTP-Referer", http_referer);
209 }
210
211 let mut response = surf_req.await?;
212
213 if response.status().is_success() {
214 let chat_response = response.body_json().await?;
215 Ok(chat_response)
216 } else {
217 handle_error(response).await?;
218 unreachable!()
219 }
220}
221
222pub async fn stream_chat_completion(
234 base_url: &str,
235 api_key: &str,
236 request: &ChatCompletionRequest,
237) -> Result<BoxStream<'static, Result<CompletionsResponse, OpenRouterError>>, OpenRouterError> {
238 let url = format!("{base_url}/chat/completions");
239
240 let request = request.stream(true);
242
243 let response = surf::post(url)
244 .header(AUTHORIZATION, format!("Bearer {api_key}"))
245 .body_json(&request)?
246 .await?;
247
248 if response.status().is_success() {
249 let lines = response
250 .lines()
251 .filter_map(async |line| match line {
252 Ok(line) => line
253 .strip_prefix("data: ")
254 .filter(|line| *line != "[DONE]")
255 .map(serde_json::from_str::<CompletionsResponse>)
256 .map(|event| event.map_err(OpenRouterError::Serialization)),
257 Err(error) => Some(Err(OpenRouterError::Io(error))),
258 })
259 .boxed();
260
261 Ok(lines)
262 } else {
263 handle_error(response).await?;
264 unreachable!()
265 }
266}