1use eventsource_stream::Eventsource;
10use futures::{FutureExt, StreamExt, stream::BoxStream};
11use latchlm_core::{AiModel, AiProvider, AiRequest, AiResponse, BoxFuture, Error, ModelId, Result};
12use reqwest::{Client, Url};
13use secrecy::{ExposeSecret, SecretString};
14use std::{borrow::Cow, env::VarError, future::ready, sync::Arc};
15
16mod response;
17pub use response::*;
18
/// An OpenRouter model name.
///
/// The wrapped string is sent verbatim as the `model` field of chat-completion
/// requests. Equality and hashing compare that name, so models can be used as
/// map or set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct OpenrouterModel(String);

impl AsRef<str> for OpenrouterModel {
    /// Borrows the wrapped model name.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
28
29impl AiModel for OpenrouterModel {
30 fn as_any(&self) -> &dyn std::any::Any {
31 self
32 }
33 fn model_id(&self) -> ModelId<'_> {
34 ModelId {
35 id: Cow::Borrowed(&self.0),
36 name: Cow::Borrowed(&self.0),
37 }
38 }
39}
40
41impl OpenrouterModel {
42 pub fn new<S: Into<String>>(model_name: S) -> Self {
43 Self(model_name.into())
44 }
45}
46
/// Errors specific to configuring and using the OpenRouter provider.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OpenrouterError {
    /// The builder was finished without an HTTP client.
    MissingClientError,
    /// The builder was finished without an API key.
    MissingApiKeyError,
    /// A configured header value could not be parsed; the payload describes
    /// the failure.
    HeaderParseError(String),
}

impl std::fmt::Display for OpenrouterError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::HeaderParseError(reason) => write!(f, "Failed to parse header: {reason}"),
            Self::MissingApiKeyError => f.write_str("API key is required"),
            Self::MissingClientError => f.write_str("HTTP client is required"),
        }
    }
}
64
65impl From<OpenrouterError> for Error {
66 fn from(value: OpenrouterError) -> Self {
67 match value {
68 OpenrouterError::MissingClientError => Self::ProviderError {
69 provider: "OpenRouter".to_owned(),
70 error: "Missing reqwest::Client".to_owned(),
71 },
72 OpenrouterError::MissingApiKeyError => Self::ProviderError {
73 provider: "OpenRouter".to_owned(),
74 error: "Missing API key".to_owned(),
75 },
76 OpenrouterError::HeaderParseError(err) => Self::ProviderError {
77 provider: "OpenRouter".to_owned(),
78 error: format!("Failed to parse header: {err}"),
79 },
80 }
81 }
82}
83
84impl std::error::Error for OpenrouterError {}
85
86#[derive(Debug, Clone, Default)]
88pub struct OpenrouterBuilder {
89 client: Option<Client>,
90 api_key: Option<SecretString>,
91 http_referer: Option<String>,
92 x_title: Option<String>,
93}
94
95impl OpenrouterBuilder {
96 #[must_use]
101 pub fn new() -> Self {
102 Self::default()
103 }
104
105 #[must_use]
115 pub fn client(mut self, client: Client) -> Self {
116 self.client = Some(client);
117 self
118 }
119
120 #[must_use]
130 pub fn api_key(mut self, api_key: SecretString) -> Self {
131 self.api_key = Some(api_key);
132 self
133 }
134
135 pub fn api_key_from_env(mut self) -> std::result::Result<Self, VarError> {
137 let api_key = std::env::var("OPENROUTER_API_KEY")?;
138
139 self.api_key = Some(SecretString::from(api_key));
140 Ok(self)
141 }
142
143 #[must_use]
153 pub fn http_referer(mut self, http_referer: String) -> Self {
154 self.http_referer = Some(http_referer);
155 self
156 }
157
158 #[must_use]
168 pub fn x_title(mut self, x_title: String) -> Self {
169 self.x_title = Some(x_title);
170 self
171 }
172
173 pub fn build(self) -> Result<Openrouter> {
179 let client = self.client.ok_or(OpenrouterError::MissingClientError)?;
180 let api_key = self.api_key.ok_or(OpenrouterError::MissingApiKeyError)?;
181
182 Ok(Openrouter::new(
183 client,
184 api_key,
185 self.http_referer,
186 self.x_title,
187 ))
188 }
189}
190
191#[derive(Debug, Clone)]
193pub struct Openrouter {
194 base_url: Url,
195 client: Client,
196 api_key: Arc<SecretString>,
197 http_referer: Option<String>,
198 x_title: Option<String>,
199}
200
201impl Openrouter {
202 const BASE_URL: &str = "https://openrouter.ai/api/v1/";
203
204 #[allow(clippy::expect_used)]
217 #[must_use]
218 pub fn new(
219 client: Client,
220 api_key: SecretString,
221 http_referer: Option<String>,
222 x_title: Option<String>,
223 ) -> Self {
224 Self {
225 base_url: Url::parse(Self::BASE_URL).expect("Invalid base URL"),
226 client,
227 api_key: Arc::new(api_key),
228 http_referer,
229 x_title,
230 }
231 }
232
233 #[cfg(feature = "test-utils")]
248 #[must_use]
249 pub fn new_with_base_url(client: Client, base_url: Url, api_key: SecretString) -> Self {
250 Self {
251 base_url,
252 client,
253 api_key: Arc::new(api_key),
254 http_referer: None,
255 x_title: None,
256 }
257 }
258
259 #[must_use]
261 pub fn builder() -> OpenrouterBuilder {
262 OpenrouterBuilder::new()
263 }
264
265 #[allow(clippy::expect_used)]
324 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
325 pub async fn request(
326 &self,
327 model: OpenrouterModel,
328 request: AiRequest,
329 ) -> Result<OpenrouterResponse> {
330 let mut headers = reqwest::header::HeaderMap::new();
331 headers.insert(
332 "Content-Type",
333 reqwest::header::HeaderValue::from_static("application/json"),
334 );
335
336 if let Some(http_referer) = &self.http_referer {
337 headers.insert(
338 "HTTP-Referer",
339 http_referer.parse().expect("Failed to parse http-referer"),
340 );
341 }
342
343 if let Some(x_title) = &self.x_title {
344 headers.insert("X-Title", x_title.parse().expect("Failed to parse x-title"));
345 }
346
347 let request = serde_json::json!({
348 "model": model.as_ref(),
349 "messages": [
350 {
351 "role": "user",
352 "content": request.text
353 }
354 ],
355 });
356
357 let url = self
358 .base_url
359 .join("chat/completions")
360 .expect("Failed to join URL");
361
362 let response = self
363 .client
364 .post(url)
365 .headers(headers)
366 .bearer_auth(self.api_key.expose_secret())
367 .json(&request)
368 .send()
369 .await?;
370
371 if !response.status().is_success() {
372 let status = response.status().as_u16();
373 let message = response.text().await?;
374
375 #[cfg(feature = "tracing")]
376 tracing::error!("API error: {}", message);
377
378 return Err(Error::ApiError { status, message });
379 }
380
381 let bytes = response.bytes().await?;
382
383 #[cfg(feature = "tracing")]
384 tracing::debug!("Received response: {bytes:?}");
385
386 let response = serde_json::from_slice(&bytes)?;
387
388 Ok(response)
389 }
390
391 #[allow(clippy::expect_used)]
398 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
399 pub async fn streaming_request(
400 &self,
401 model: OpenrouterModel,
402 request: AiRequest,
403 ) -> Result<BoxStream<'_, Result<OpenrouterStreamResponse>>> {
404 let mut headers = reqwest::header::HeaderMap::new();
405 headers.insert(
406 "Content-Type",
407 reqwest::header::HeaderValue::from_static("application/json"),
408 );
409
410 if let Some(http_referer) = &self.http_referer {
411 headers.insert(
412 "HTTP-Referer",
413 http_referer.parse().expect("Failed to parse http-referer"),
414 );
415 }
416
417 if let Some(x_title) = &self.x_title {
418 headers.insert("X-Title", x_title.parse().expect("Failed to parse x-title"));
419 }
420
421 let request = serde_json::json!({
422 "model": model.as_ref(),
423 "messages": [
424 {
425 "role": "user",
426 "content": request.text
427 }
428 ],
429 "stream": true
430 });
431
432 let url = self
433 .base_url
434 .join("chat/completions")
435 .expect("Failed to join URL");
436
437 let response = self
438 .client
439 .post(url)
440 .headers(headers)
441 .bearer_auth(self.api_key.expose_secret())
442 .json(&request)
443 .send()
444 .await?;
445
446 if !response.status().is_success() {
447 #[cfg(feature = "tracing")]
448 tracing::error!("OpenRouter API error: {}", response.status());
449
450 return Err(Error::ApiError {
451 status: response.status().as_u16(),
452 message: response.text().await?,
453 });
454 }
455
456 let stream = response
457 .bytes_stream()
458 .eventsource()
459 .filter_map(|result| async {
460 let event = match result {
461 Ok(event) => {
462 #[cfg(feature = "tracing")]
463 tracing::debug!("OpenRouter API event: {:?}", event);
464
465 event
466 }
467 Err(err) => {
468 #[cfg(feature = "tracing")]
469 tracing::error!("OpenRouter error: {}", err);
470
471 return Some(Err(Error::ProviderError {
472 provider: "OpenRouter".to_string(),
473 error: err.to_string(),
474 }));
475 }
476 };
477 let data = event.data;
478
479 if data.contains("[DONE]") {
480 return None;
481 }
482
483 Some(serde_json::from_str::<OpenrouterStreamResponse>(&data).map_err(Into::into))
484 });
485
486 Ok(Box::pin(stream))
487 }
488
489 #[allow(clippy::expect_used)]
502 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
503 pub async fn models(&self) -> Result<Vec<ModelId<'_>>> {
504 let url = self.base_url.join("models").expect("Failed to join URL");
505 let response = self.client.get(url).send().await?;
506
507 if !response.status().is_success() {
508 let status = response.status().as_u16();
509 let message = response.text().await?;
510
511 #[cfg(feature = "tracing")]
512 tracing::error!("API request failed: {}", &message);
513
514 return Err(Error::ApiError { status, message });
515 }
516
517 let response: ModelsList = response.json().await?;
518
519 Ok(response.into())
520 }
521}
522
523impl AiProvider for Openrouter {
524 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, model)))]
525 fn send_request(
526 &self,
527 model: &dyn AiModel,
528 request: AiRequest,
529 ) -> BoxFuture<'_, Result<AiResponse>> {
530 let Some(model) = model.downcast::<OpenrouterModel>() else {
531 let model_name = model.as_ref();
532
533 #[cfg(feature = "tracing")]
534 tracing::error!("Invalid model type: {}", model_name);
535
536 return Box::pin(ready(Err(Error::InvalidModelError(model_name.into()))));
537 };
538
539 let model = model.clone();
540 Box::pin(async move { self.request(model, request).await.map(Into::into) })
541 }
542
543 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, model)))]
544 fn send_streaming(
545 &self,
546 model: &dyn AiModel,
547 request: AiRequest,
548 ) -> BoxStream<'_, Result<AiResponse>> {
549 let Some(model) = model.downcast::<OpenrouterModel>() else {
550 let model_name = model.as_ref().to_owned();
551
552 #[cfg(feature = "tracing")]
553 tracing::error!("Invalid model type: {}", model_name);
554
555 return Box::pin(futures::stream::once(async {
556 Err(Error::InvalidModelError(model_name))
557 }));
558 };
559
560 Box::pin(
561 async move {
562 match self.streaming_request(model, request).await {
563 Ok(stream) => stream.map(|res| res.map(Into::into)).boxed(),
564 Err(err) => futures::stream::once(async move { Err(err) }).boxed(),
565 }
566 }
567 .flatten_stream(),
568 )
569 }
570}