objectiveai_api/chat/completions/upstream/openrouter/
client.rs1use eventsource_stream::Event as MessageEvent;
4use futures::{Stream, StreamExt};
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
6use std::time::Duration;
7
8pub fn response_id(created: u64) -> String {
12 let uuid = uuid::Uuid::new_v4();
13 format!("chtcpl-{}-{}", uuid.simple(), created)
14}
15
16#[derive(Debug, Clone)]
18pub struct Client {
19 pub http_client: reqwest::Client,
21 pub api_base: String,
23 pub api_key: String,
25 pub user_agent: Option<String>,
27 pub x_title: Option<String>,
29 pub referer: Option<String>,
31}
32
33impl Client {
34 pub fn new(
36 http_client: reqwest::Client,
37 api_base: String,
38 api_key: String,
39 user_agent: Option<String>,
40 x_title: Option<String>,
41 referer: Option<String>,
42 ) -> Self {
43 Self {
44 http_client,
45 api_base,
46 api_key,
47 user_agent,
48 x_title,
49 referer,
50 }
51 }
52
53 pub fn create_streaming_for_chat(
58 &self,
59 id: String,
60 byok: Option<&str>,
61 cost_multiplier: rust_decimal::Decimal,
62 first_chunk_timeout: Duration,
63 other_chunk_timeout: Duration,
64 ensemble_llm: &objectiveai::ensemble_llm::EnsembleLlm,
65 request: &objectiveai::chat::completions::request::ChatCompletionCreateParams,
66 ) -> impl Stream<
67 Item = Result<
68 objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
69 super::Error,
70 >,
71 > + Send
72 + 'static {
73 self.create_streaming(
74 id,
75 ensemble_llm.id.clone(),
76 byok,
77 cost_multiplier,
78 first_chunk_timeout,
79 other_chunk_timeout,
80 super::request::ChatCompletionCreateParams::new_for_chat(ensemble_llm, request),
81 )
82 }
83
84 pub fn create_streaming_for_vector(
90 &self,
91 id: String,
92 byok: Option<&str>,
93 cost_multiplier: rust_decimal::Decimal,
94 first_chunk_timeout: Duration,
95 other_chunk_timeout: Duration,
96 ensemble_llm: &objectiveai::ensemble_llm::EnsembleLlm,
97 request: &objectiveai::vector::completions::request::VectorCompletionCreateParams,
98 vector_pfx_indices: &[(String, usize)],
99 ) -> impl Stream<
100 Item = Result<
101 objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
102 super::Error,
103 >,
104 > + Send
105 + 'static {
106 self.create_streaming(
107 id,
108 ensemble_llm.id.clone(),
109 byok,
110 cost_multiplier,
111 first_chunk_timeout,
112 other_chunk_timeout,
113 super::request::ChatCompletionCreateParams::new_for_vector(
114 vector_pfx_indices,
115 ensemble_llm,
116 request,
117 ),
118 )
119 }
120
121 fn create_streaming(
123 &self,
124 id: String,
125 model: String,
126 byok: Option<&str>,
127 cost_multiplier: rust_decimal::Decimal,
128 first_chunk_timeout: Duration,
129 other_chunk_timeout: Duration,
130 request: super::request::ChatCompletionCreateParams,
131 ) -> impl Stream<
132 Item = Result<
133 objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
134 super::Error,
135 >,
136 > + Send
137 + 'static {
138 let is_byok = byok.is_some();
139 let event_source =
140 self.create_streaming_event_source(byok.unwrap_or(&self.api_key), &request);
141 Self::create_streaming_stream(
142 event_source,
143 id,
144 model,
145 is_byok,
146 cost_multiplier,
147 first_chunk_timeout,
148 other_chunk_timeout,
149 )
150 }
151
152 fn create_streaming_event_source(
154 &self,
155 api_key: &str,
156 request: &super::request::ChatCompletionCreateParams,
157 ) -> EventSource {
158 let mut http_request = self
159 .http_client
160 .post(format!("{}/chat/completions", self.api_base))
161 .header("authorization", format!("Bearer {}", api_key));
162 if let Some(ref user_agent) = self.user_agent {
163 http_request = http_request.header("user-agent", user_agent);
164 }
165 if let Some(ref x_title) = self.x_title {
166 http_request = http_request.header("x-title", x_title);
167 }
168 if let Some(ref referer) = self.referer {
169 http_request = http_request
170 .header("referer", referer)
171 .header("http-referer", referer);
172 }
173 http_request.json(request).eventsource().unwrap()
174 }
175
176 fn create_streaming_stream(
180 mut event_source: EventSource,
181 id: String,
182 model: String,
183 is_byok: bool,
184 cost_multiplier: rust_decimal::Decimal,
185 first_chunk_timeout: Duration,
186 other_chunk_timeout: Duration,
187 ) -> impl Stream<
188 Item = Result<
189 objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
190 super::Error,
191 >,
192 > + Send
193 + 'static {
194 async_stream::stream! {
195 let mut first = true;
196 while let Some(event) = tokio::time::timeout(
197 if first {
198 first_chunk_timeout
199 } else {
200 other_chunk_timeout
201 },
202 event_source.next(),
203 ).await.transpose() {
204 first = false;
205 match event {
206 Ok(Ok(Event::Open)) => continue,
207 Ok(Ok(Event::Message(MessageEvent { data, .. }))) => {
208 if data == "[DONE]" {
209 break;
210 } else if data.starts_with(":") {
211 continue; } else if data.is_empty() {
213 continue; }
215 let mut de = serde_json::Deserializer::from_str(&data);
216 match serde_path_to_error::deserialize::<
217 _,
218 super::response::ChatCompletionChunk,
219 >(&mut de)
220 {
221 Ok(chunk) => yield Ok(chunk.into_downstream(
222 id.clone(),
223 model.clone(),
224 is_byok,
225 cost_multiplier,
226 )),
227 Err(e) => {
228 de = serde_json::Deserializer::from_str(&data);
229 match serde_path_to_error::deserialize::<
230 _,
231 super::OpenRouterProviderError,
232 >(&mut de)
233 {
234 Ok(provider_error) => yield Err(
235 super::Error::OpenRouterProviderError(
236 provider_error,
237 ),
238 ),
239 Err(_) => yield Err(
240 super::Error::DeserializationError(e),
241 ),
242 }
243 }
244 }
245 }
246 Ok(Err(reqwest_eventsource::Error::InvalidStatusCode(
247 code,
248 response,
249 ))) => {
250 match response.text().await {
251 Ok(body) => {
252 yield Err(super::Error::BadStatus {
253 code,
254 body: match serde_json::from_str::<
255 serde_json::Value,
256 >(
257 &body,
258 ) {
259 Ok(value) => value,
260 Err(_) => serde_json::Value::String(
261 body,
262 ),
263 },
264 });
265 }
266 Err(_) => {
267 yield Err(super::Error::BadStatus {
268 code,
269 body: serde_json::Value::Null,
270 });
271 }
272 }
273 }
274 Ok(Err(e)) => {
275 yield Err(super::Error::from(e));
276 }
277 Err(_) => {
278 yield Err(super::Error::StreamTimeout);
279 }
280 }
281 }
282 }
283 }
284}