Skip to main content

objectiveai_api/chat/completions/upstream/openrouter/
client.rs

1//! OpenRouter HTTP client implementation.
2
3use eventsource_stream::Event as MessageEvent;
4use futures::{Stream, StreamExt};
5use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
6use std::time::Duration;
7
8/// Generates a unique response ID for a chat completion.
9///
10/// Combines a UUID with the creation timestamp for uniqueness.
11pub fn response_id(created: u64) -> String {
12    let uuid = uuid::Uuid::new_v4();
13    format!("chtcpl-{}-{}", uuid.simple(), created)
14}
15
16/// HTTP client for communicating with the OpenRouter API.
17#[derive(Debug, Clone)]
18pub struct Client {
19    /// The underlying HTTP client.
20    pub http_client: reqwest::Client,
21    /// Base URL for the OpenRouter API.
22    pub api_base: String,
23    /// API key for authentication with OpenRouter.
24    pub api_key: String,
25    /// Optional User-Agent header value.
26    pub user_agent: Option<String>,
27    /// Optional X-Title header value.
28    pub x_title: Option<String>,
29    /// Optional Referer header value (sent as both referer and http-referer).
30    pub referer: Option<String>,
31}
32
33impl Client {
34    /// Creates a new OpenRouter client.
35    pub fn new(
36        http_client: reqwest::Client,
37        api_base: String,
38        api_key: String,
39        user_agent: Option<String>,
40        x_title: Option<String>,
41        referer: Option<String>,
42    ) -> Self {
43        Self {
44            http_client,
45            api_base,
46            api_key,
47            user_agent,
48            x_title,
49            referer,
50        }
51    }
52
53    /// Creates a streaming chat completion request.
54    ///
55    /// Transforms the request using the Ensemble LLM's configuration and
56    /// returns a stream of chat completion chunks.
57    pub fn create_streaming_for_chat(
58        &self,
59        id: String,
60        byok: Option<&str>,
61        cost_multiplier: rust_decimal::Decimal,
62        first_chunk_timeout: Duration,
63        other_chunk_timeout: Duration,
64        ensemble_llm: &objectiveai::ensemble_llm::EnsembleLlm,
65        request: &objectiveai::chat::completions::request::ChatCompletionCreateParams,
66    ) -> impl Stream<
67        Item = Result<
68            objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
69            super::Error,
70        >,
71    > + Send
72    + 'static {
73        self.create_streaming(
74            id,
75            ensemble_llm.id.clone(),
76            byok,
77            cost_multiplier,
78            first_chunk_timeout,
79            other_chunk_timeout,
80            super::request::ChatCompletionCreateParams::new_for_chat(ensemble_llm, request),
81        )
82    }
83
84    /// Creates a streaming chat completion for LLM voting in vector completions.
85    ///
86    /// The LLM sees responses labeled with prefix keys (e.g., `` `A` ``) and responds
87    /// with its choice. The `vector_pfx_indices` maps the prefix keys shown to the LLM
88    /// to the indices of the responses in the original request.
89    pub fn create_streaming_for_vector(
90        &self,
91        id: String,
92        byok: Option<&str>,
93        cost_multiplier: rust_decimal::Decimal,
94        first_chunk_timeout: Duration,
95        other_chunk_timeout: Duration,
96        ensemble_llm: &objectiveai::ensemble_llm::EnsembleLlm,
97        request: &objectiveai::vector::completions::request::VectorCompletionCreateParams,
98        vector_pfx_indices: &[(String, usize)],
99    ) -> impl Stream<
100        Item = Result<
101            objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
102            super::Error,
103        >,
104    > + Send
105    + 'static {
106        self.create_streaming(
107            id,
108            ensemble_llm.id.clone(),
109            byok,
110            cost_multiplier,
111            first_chunk_timeout,
112            other_chunk_timeout,
113            super::request::ChatCompletionCreateParams::new_for_vector(
114                vector_pfx_indices,
115                ensemble_llm,
116                request,
117            ),
118        )
119    }
120
121    /// Internal method that creates the streaming request to OpenRouter.
122    fn create_streaming(
123        &self,
124        id: String,
125        model: String,
126        byok: Option<&str>,
127        cost_multiplier: rust_decimal::Decimal,
128        first_chunk_timeout: Duration,
129        other_chunk_timeout: Duration,
130        request: super::request::ChatCompletionCreateParams,
131    ) -> impl Stream<
132        Item = Result<
133            objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
134            super::Error,
135        >,
136    > + Send
137    + 'static {
138        let is_byok = byok.is_some();
139        let event_source =
140            self.create_streaming_event_source(byok.unwrap_or(&self.api_key), &request);
141        Self::create_streaming_stream(
142            event_source,
143            id,
144            model,
145            is_byok,
146            cost_multiplier,
147            first_chunk_timeout,
148            other_chunk_timeout,
149        )
150    }
151
152    /// Creates an SSE EventSource for the streaming request.
153    fn create_streaming_event_source(
154        &self,
155        api_key: &str,
156        request: &super::request::ChatCompletionCreateParams,
157    ) -> EventSource {
158        let mut http_request = self
159            .http_client
160            .post(format!("{}/chat/completions", self.api_base))
161            .header("authorization", format!("Bearer {}", api_key));
162        if let Some(ref user_agent) = self.user_agent {
163            http_request = http_request.header("user-agent", user_agent);
164        }
165        if let Some(ref x_title) = self.x_title {
166            http_request = http_request.header("x-title", x_title);
167        }
168        if let Some(ref referer) = self.referer {
169            http_request = http_request
170                .header("referer", referer)
171                .header("http-referer", referer);
172        }
173        http_request.json(request).eventsource().unwrap()
174    }
175
176    /// Processes the SSE EventSource into a stream of chat completion chunks.
177    ///
178    /// Handles timeouts, error responses, and transforms upstream chunks to downstream format.
179    fn create_streaming_stream(
180        mut event_source: EventSource,
181        id: String,
182        model: String,
183        is_byok: bool,
184        cost_multiplier: rust_decimal::Decimal,
185        first_chunk_timeout: Duration,
186        other_chunk_timeout: Duration,
187    ) -> impl Stream<
188        Item = Result<
189            objectiveai::chat::completions::response::streaming::ChatCompletionChunk,
190            super::Error,
191        >,
192    > + Send
193    + 'static {
194        async_stream::stream! {
195            let mut first = true;
196            while let Some(event) = tokio::time::timeout(
197                if first {
198                    first_chunk_timeout
199                } else {
200                    other_chunk_timeout
201                },
202                event_source.next(),
203            ).await.transpose() {
204                first = false;
205                match event {
206                    Ok(Ok(Event::Open)) => continue,
207                    Ok(Ok(Event::Message(MessageEvent { data, .. }))) => {
208                        if data == "[DONE]" {
209                            break;
210                        } else if data.starts_with(":") {
211                            continue; // skip comments
212                        } else if data.is_empty() {
213                            continue; // skip empty messages
214                        }
215                        let mut de = serde_json::Deserializer::from_str(&data);
216                        match serde_path_to_error::deserialize::<
217                            _,
218                            super::response::ChatCompletionChunk,
219                        >(&mut de)
220                        {
221                            Ok(chunk) => yield Ok(chunk.into_downstream(
222                                id.clone(),
223                                model.clone(),
224                                is_byok,
225                                cost_multiplier,
226                            )),
227                            Err(e) => {
228                                de = serde_json::Deserializer::from_str(&data);
229                                match serde_path_to_error::deserialize::<
230                                    _,
231                                    super::OpenRouterProviderError,
232                                >(&mut de)
233                                {
234                                    Ok(provider_error) => yield Err(
235                                        super::Error::OpenRouterProviderError(
236                                            provider_error,
237                                        ),
238                                    ),
239                                    Err(_) => yield Err(
240                                        super::Error::DeserializationError(e),
241                                    ),
242                                }
243                            }
244                        }
245                    }
246                    Ok(Err(reqwest_eventsource::Error::InvalidStatusCode(
247                        code,
248                        response,
249                    ))) => {
250                        match response.text().await {
251                            Ok(body) => {
252                                yield Err(super::Error::BadStatus {
253                                    code,
254                                    body: match serde_json::from_str::<
255                                        serde_json::Value,
256                                    >(
257                                        &body,
258                                    ) {
259                                        Ok(value) => value,
260                                        Err(_) => serde_json::Value::String(
261                                            body,
262                                        ),
263                                    },
264                                });
265                            }
266                            Err(_) => {
267                                yield Err(super::Error::BadStatus {
268                                    code,
269                                    body: serde_json::Value::Null,
270                                });
271                            }
272                        }
273                    }
274                    Ok(Err(e)) => {
275                        yield Err(super::Error::from(e));
276                    }
277                    Err(_) => {
278                        yield Err(super::Error::StreamTimeout);
279                    }
280                }
281            }
282        }
283    }
284}