Skip to main content

latchlm_openrouter/
lib.rs

1// This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
2// If a copy of the MPL was not distributed with this file, You can obtain one at
3// https://mozilla.org/MPL/2.0/.
4
5//! Provider implementation for OpenRouter.
6//!
7//! This crate implements a client for interacting with the OpenRouter API.
8
9use eventsource_stream::Eventsource;
10use futures::{FutureExt, StreamExt, stream::BoxStream};
11use latchlm_core::{AiModel, AiProvider, AiRequest, AiResponse, BoxFuture, Error, ModelId, Result};
12use reqwest::{Client, Url};
13use secrecy::{ExposeSecret, SecretString};
14use std::{borrow::Cow, env::VarError, future::ready, sync::Arc};
15
16mod response;
17pub use response::*;
18
/// Identifier of a model served through OpenRouter (for example `"openai/gpt-oss-20b"`).
///
/// The wrapped string is sent verbatim as the `model` field of chat-completion requests.
#[derive(Debug, Clone)]
pub struct OpenrouterModel(String);

impl AsRef<str> for OpenrouterModel {
    /// Borrows the raw model identifier.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
28
29impl AiModel for OpenrouterModel {
30    fn as_any(&self) -> &dyn std::any::Any {
31        self
32    }
33    fn model_id(&self) -> ModelId<'_> {
34        ModelId {
35            id: Cow::Borrowed(&self.0),
36            name: Cow::Borrowed(&self.0),
37        }
38    }
39}
40
41impl OpenrouterModel {
42    pub fn new<S: Into<String>>(model_name: S) -> Self {
43        Self(model_name.into())
44    }
45}
46
/// Errors produced while configuring or using the [`Openrouter`] client.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OpenrouterError {
    /// The builder was asked to build without an HTTP client.
    MissingClientError,
    /// The builder was asked to build without an API key.
    MissingApiKeyError,
    /// A header value could not be parsed; the payload describes what failed.
    HeaderParseError(String),
}

impl std::fmt::Display for OpenrouterError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::HeaderParseError(detail) => write!(f, "Failed to parse header: {detail}"),
            Self::MissingApiKeyError => f.write_str("API key is required"),
            Self::MissingClientError => f.write_str("HTTP client is required"),
        }
    }
}
64
65impl From<OpenrouterError> for Error {
66    fn from(value: OpenrouterError) -> Self {
67        match value {
68            OpenrouterError::MissingClientError => Self::ProviderError {
69                provider: "OpenRouter".to_owned(),
70                error: "Missing reqwest::Client".to_owned(),
71            },
72            OpenrouterError::MissingApiKeyError => Self::ProviderError {
73                provider: "OpenRouter".to_owned(),
74                error: "Missing API key".to_owned(),
75            },
76            OpenrouterError::HeaderParseError(err) => Self::ProviderError {
77                provider: "OpenRouter".to_owned(),
78                error: format!("Failed to parse header: {err}"),
79            },
80        }
81    }
82}
83
84impl std::error::Error for OpenrouterError {}
85
86/// A builder for creating an [`Openrouter`] client.
87#[derive(Debug, Clone, Default)]
88pub struct OpenrouterBuilder {
89    client: Option<Client>,
90    api_key: Option<SecretString>,
91    http_referer: Option<String>,
92    x_title: Option<String>,
93}
94
95impl OpenrouterBuilder {
96    /// Creates a new OpenRouter client builder.
97    ///
98    /// # Returns
99    /// An new [`OpenrouterBuilder`] instance.
100    #[must_use]
101    pub fn new() -> Self {
102        Self::default()
103    }
104
105    /// Sets the HTTP client to use for making requests.
106    ///
107    /// # Arguments
108    ///
109    /// * `client` - The HTTP client to use.
110    ///
111    /// # Returns
112    ///
113    /// The updated [`OpenrouterBuilder`] instance.
114    #[must_use]
115    pub fn client(mut self, client: Client) -> Self {
116        self.client = Some(client);
117        self
118    }
119
120    /// Sets the API key to use for authentication.
121    ///
122    /// # Arguments
123    ///
124    /// * `api_key` - The API key to use.
125    ///
126    /// # Returns
127    ///
128    /// The updated [`OpenrouterBuilder`] instance.
129    #[must_use]
130    pub fn api_key(mut self, api_key: SecretString) -> Self {
131        self.api_key = Some(api_key);
132        self
133    }
134
135    /// Sets the API key to be used by the [`Openrouter`] client from the environment variable `OPENROUTER_API_KEY`.
136    pub fn api_key_from_env(mut self) -> std::result::Result<Self, VarError> {
137        let api_key = std::env::var("OPENROUTER_API_KEY")?;
138
139        self.api_key = Some(SecretString::from(api_key));
140        Ok(self)
141    }
142
143    /// Sets the `HTTP-Referer` header to be used by the [`Openrouter`] client.
144    ///
145    /// # Arguments
146    ///
147    /// * `http_referer` - The `HTTP-Referer` header to use.
148    ///
149    /// # Returns
150    ///
151    /// The updated [`OpenrouterBuilder`] instance.
152    #[must_use]
153    pub fn http_referer(mut self, http_referer: String) -> Self {
154        self.http_referer = Some(http_referer);
155        self
156    }
157
158    /// Sets the `X-Title` header to be used by the [`Openrouter`] client.
159    ///
160    /// # Arguments
161    ///
162    /// * `x_title` - The `X-Title` header to use.
163    ///
164    /// # Returns
165    ///
166    /// The updated [`OpenrouterBuilder`] instance.
167    #[must_use]
168    pub fn x_title(mut self, x_title: String) -> Self {
169        self.x_title = Some(x_title);
170        self
171    }
172
173    /// Builds the [`Openrouter`] client.
174    ///
175    /// # Returns
176    ///
177    /// The [`Openrouter`] client.
178    pub fn build(self) -> Result<Openrouter> {
179        let client = self.client.ok_or(OpenrouterError::MissingClientError)?;
180        let api_key = self.api_key.ok_or(OpenrouterError::MissingApiKeyError)?;
181
182        Ok(Openrouter::new(
183            client,
184            api_key,
185            self.http_referer,
186            self.x_title,
187        ))
188    }
189}
190
191/// A client for the OpenRouter API.
192#[derive(Debug, Clone)]
193pub struct Openrouter {
194    base_url: Url,
195    client: Client,
196    api_key: Arc<SecretString>,
197    http_referer: Option<String>,
198    x_title: Option<String>,
199}
200
201impl Openrouter {
202    const BASE_URL: &str = "https://openrouter.ai/api/v1/";
203
204    /// Creates a new [`Openrouter`] client.
205    ///
206    /// # Arguments
207    ///
208    /// * `client` - The HTTP client to use.
209    /// * `api_key` - The API key to use.
210    /// * `http_referer` - The HTTP referer to use.
211    /// * `x_title` - The X-Title header to use.
212    ///
213    /// # Returns
214    ///
215    /// The [`Openrouter`] client.
216    #[allow(clippy::expect_used)]
217    #[must_use]
218    pub fn new(
219        client: Client,
220        api_key: SecretString,
221        http_referer: Option<String>,
222        x_title: Option<String>,
223    ) -> Self {
224        Self {
225            base_url: Url::parse(Self::BASE_URL).expect("Invalid base URL"),
226            client,
227            api_key: Arc::new(api_key),
228            http_referer,
229            x_title,
230        }
231    }
232
233    /// Creates a new [`Openrouter`] client with a custom base URL for testing.
234    ///
235    /// This constructor is intended for testing and mocking scenarios and should **never**
236    /// be used in production code.
237    ///
238    /// # Arguments
239    ///
240    /// * `client` - The HTTP client to use.
241    /// * `base_url` - The base URL to use.
242    /// * `api_key` - The API key to use.
243    ///
244    /// # Returns
245    ///
246    /// The [`Openrouter`] client.
247    #[cfg(feature = "test-utils")]
248    #[must_use]
249    pub fn new_with_base_url(client: Client, base_url: Url, api_key: SecretString) -> Self {
250        Self {
251            base_url,
252            client,
253            api_key: Arc::new(api_key),
254            http_referer: None,
255            x_title: None,
256        }
257    }
258
259    /// Creates a new [`Openrouter`] client builder.
260    #[must_use]
261    pub fn builder() -> OpenrouterBuilder {
262        OpenrouterBuilder::new()
263    }
264
265    /// Sends a request to the OpenRouter API to generate content.
266    ///
267    /// This method constructs a request to OpenRouter's API, handles authentication
268    /// and returns the parsed response containing the generated content.
269    ///
270    /// # Arguments
271    ///
272    /// * `model` - The [`OpenrouterModel`] to use for the request.
273    /// * `request` - The [`AiRequest`] containing the request prompt and settings to send to the API.
274    ///
275    /// # Returns
276    ///
277    /// The [`OpenrouterResponse`] containing the generated content.
278    ///
279    /// # Errors
280    ///
281    /// Returns an [`Error`] if:
282    /// - The HTTP request fails (network issues, timeouts, etc.)
283    /// - The API returns a non-success status code
284    /// - The response body cannot be parsed as valid JSON
285    /// - The API key is invalid or missing required permissions
286    ///
287    /// # Example
288    ///
289    /// ```toml
290    /// [dependencies]
291    /// latchlm_core = "*"
292    /// latchlm_openrouter = "*"
293    /// secrecy = "*"
294    /// tokio = { version = "*", features = ["full"] }
295    /// ```
296    ///
297    /// ```no_run
298    /// use latchlm_core::AiRequest;
299    /// use latchlm_openrouter::{Openrouter, OpenrouterModel};
300    /// use secrecy::SecretString;
301    ///
302    /// #[tokio::main]
303    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
304    ///     let openrouter = Openrouter::builder()
305    ///         .client(reqwest::Client::new())
306    ///         .api_key(SecretString::new("your-api-key".into()))
307    ///         .build()?;
308    ///
309    ///     let response = openrouter.request(
310    ///         OpenrouterModel::new("openai/gpt-oss-20b"),
311    ///         AiRequest {
312    ///             text: "Hello".into(),
313    ///         }
314    ///     ).await?;
315    ///
316    ///     println!("Generated: {}", response.extract_text());
317    ///     Ok(())
318    /// }
319    /// ```
320    ///
321    /// [`AiRequest`]: latchlm_core::AiRequest
322    /// [`Error`]: latchlm_core::Error
323    #[allow(clippy::expect_used)]
324    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
325    pub async fn request(
326        &self,
327        model: OpenrouterModel,
328        request: AiRequest,
329    ) -> Result<OpenrouterResponse> {
330        let mut headers = reqwest::header::HeaderMap::new();
331        headers.insert(
332            "Content-Type",
333            reqwest::header::HeaderValue::from_static("application/json"),
334        );
335
336        if let Some(http_referer) = &self.http_referer {
337            headers.insert(
338                "HTTP-Referer",
339                http_referer.parse().expect("Failed to parse http-referer"),
340            );
341        }
342
343        if let Some(x_title) = &self.x_title {
344            headers.insert("X-Title", x_title.parse().expect("Failed to parse x-title"));
345        }
346
347        let request = serde_json::json!({
348            "model": model.as_ref(),
349            "messages": [
350                {
351                    "role": "user",
352                    "content": request.text
353                }
354            ],
355        });
356
357        let url = self
358            .base_url
359            .join("chat/completions")
360            .expect("Failed to join URL");
361
362        let response = self
363            .client
364            .post(url)
365            .headers(headers)
366            .bearer_auth(self.api_key.expose_secret())
367            .json(&request)
368            .send()
369            .await?;
370
371        if !response.status().is_success() {
372            let status = response.status().as_u16();
373            let message = response.text().await?;
374
375            #[cfg(feature = "tracing")]
376            tracing::error!("API error: {}", message);
377
378            return Err(Error::ApiError { status, message });
379        }
380
381        let bytes = response.bytes().await?;
382
383        #[cfg(feature = "tracing")]
384        tracing::debug!("Received response: {bytes:?}");
385
386        let response = serde_json::from_slice(&bytes)?;
387
388        Ok(response)
389    }
390
391    /// Sends a streaming request to the OpenRouter and returns a stream of responses.
392    ///
393    /// # Arguments
394    ///
395    /// * `model` - The model to use for the request.
396    /// * `request` - The request to send.
397    #[allow(clippy::expect_used)]
398    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
399    pub async fn streaming_request(
400        &self,
401        model: OpenrouterModel,
402        request: AiRequest,
403    ) -> Result<BoxStream<'_, Result<OpenrouterStreamResponse>>> {
404        let mut headers = reqwest::header::HeaderMap::new();
405        headers.insert(
406            "Content-Type",
407            reqwest::header::HeaderValue::from_static("application/json"),
408        );
409
410        if let Some(http_referer) = &self.http_referer {
411            headers.insert(
412                "HTTP-Referer",
413                http_referer.parse().expect("Failed to parse http-referer"),
414            );
415        }
416
417        if let Some(x_title) = &self.x_title {
418            headers.insert("X-Title", x_title.parse().expect("Failed to parse x-title"));
419        }
420
421        let request = serde_json::json!({
422            "model": model.as_ref(),
423            "messages": [
424                {
425                    "role": "user",
426                    "content": request.text
427                }
428            ],
429            "stream": true
430        });
431
432        let url = self
433            .base_url
434            .join("chat/completions")
435            .expect("Failed to join URL");
436
437        let response = self
438            .client
439            .post(url)
440            .headers(headers)
441            .bearer_auth(self.api_key.expose_secret())
442            .json(&request)
443            .send()
444            .await?;
445
446        if !response.status().is_success() {
447            #[cfg(feature = "tracing")]
448            tracing::error!("OpenRouter API error: {}", response.status());
449
450            return Err(Error::ApiError {
451                status: response.status().as_u16(),
452                message: response.text().await?,
453            });
454        }
455
456        let stream = response
457            .bytes_stream()
458            .eventsource()
459            .filter_map(|result| async {
460                let event = match result {
461                    Ok(event) => {
462                        #[cfg(feature = "tracing")]
463                        tracing::debug!("OpenRouter API event: {:?}", event);
464
465                        event
466                    }
467                    Err(err) => {
468                        #[cfg(feature = "tracing")]
469                        tracing::error!("OpenRouter error: {}", err);
470
471                        return Some(Err(Error::ProviderError {
472                            provider: "OpenRouter".to_string(),
473                            error: err.to_string(),
474                        }));
475                    }
476                };
477                let data = event.data;
478
479                if data.contains("[DONE]") {
480                    return None;
481                }
482
483                Some(serde_json::from_str::<OpenrouterStreamResponse>(&data).map_err(Into::into))
484            });
485
486        Ok(Box::pin(stream))
487    }
488
489    /// Returns a list of available models.
490    ///
491    /// This function fetches the list of available models from the OpenRouter API.
492    ///
493    /// # Errors
494    ///
495    /// Returns an [`Error`] if:
496    /// - The API request fails.
497    /// - The response is not successful.
498    /// - The response cannot be parsed.
499    ///
500    /// [`Error`]: latchlm_core::Error
501    #[allow(clippy::expect_used)]
502    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self)))]
503    pub async fn models(&self) -> Result<Vec<ModelId<'_>>> {
504        let url = self.base_url.join("models").expect("Failed to join URL");
505        let response = self.client.get(url).send().await?;
506
507        if !response.status().is_success() {
508            let status = response.status().as_u16();
509            let message = response.text().await?;
510
511            #[cfg(feature = "tracing")]
512            tracing::error!("API request failed: {}", &message);
513
514            return Err(Error::ApiError { status, message });
515        }
516
517        let response: ModelsList = response.json().await?;
518
519        Ok(response.into())
520    }
521}
522
523impl AiProvider for Openrouter {
524    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, model)))]
525    fn send_request(
526        &self,
527        model: &dyn AiModel,
528        request: AiRequest,
529    ) -> BoxFuture<'_, Result<AiResponse>> {
530        let Some(model) = model.downcast::<OpenrouterModel>() else {
531            let model_name = model.as_ref();
532
533            #[cfg(feature = "tracing")]
534            tracing::error!("Invalid model type: {}", model_name);
535
536            return Box::pin(ready(Err(Error::InvalidModelError(model_name.into()))));
537        };
538
539        let model = model.clone();
540        Box::pin(async move { self.request(model, request).await.map(Into::into) })
541    }
542
543    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, model)))]
544    fn send_streaming(
545        &self,
546        model: &dyn AiModel,
547        request: AiRequest,
548    ) -> BoxStream<'_, Result<AiResponse>> {
549        let Some(model) = model.downcast::<OpenrouterModel>() else {
550            let model_name = model.as_ref().to_owned();
551
552            #[cfg(feature = "tracing")]
553            tracing::error!("Invalid model type: {}", model_name);
554
555            return Box::pin(futures::stream::once(async {
556                Err(Error::InvalidModelError(model_name))
557            }));
558        };
559
560        Box::pin(
561            async move {
562                match self.streaming_request(model, request).await {
563                    Ok(stream) => stream.map(|res| res.map(Into::into)).boxed(),
564                    Err(err) => futures::stream::once(async move { Err(err) }).boxed(),
565                }
566            }
567            .flatten_stream(),
568        )
569    }
570}