Skip to main content

rig/providers/openrouter/
embedding.rs

1use super::{
2    Client, Usage,
3    client::{ApiErrorResponse, ApiResponse},
4};
5use crate::embeddings::EmbeddingError;
6use crate::http_client::HttpClientExt;
7use crate::wasm_compat::WasmCompatSend;
8use crate::{embeddings, http_client};
9use serde::{Deserialize, Serialize};
10use serde_json::json;
11
12#[derive(Debug, Deserialize)]
13pub struct EmbeddingResponse {
14    pub object: String,
15    pub data: Vec<EmbeddingData>,
16    pub model: String,
17    pub usage: Option<Usage>,
18    pub id: Option<String>,
19}
20
21impl From<ApiErrorResponse> for EmbeddingError {
22    fn from(err: ApiErrorResponse) -> Self {
23        EmbeddingError::ProviderError(err.message)
24    }
25}
26
27impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
28    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
29        match value {
30            ApiResponse::Ok(response) => Ok(response),
31            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
32        }
33    }
34}
35
36#[derive(Debug, Deserialize, Clone, Serialize)]
37#[serde(rename_all = "snake_case")]
38pub enum EncodingFormat {
39    Float,
40    Base64,
41}
42
43#[derive(Debug, Deserialize)]
44pub struct EmbeddingData {
45    pub object: String,
46    pub embedding: Vec<serde_json::Number>,
47    pub index: usize,
48}
49
50#[derive(Clone)]
51pub struct EmbeddingModel<T = reqwest::Client> {
52    client: Client<T>,
53    pub model: String,
54    pub encoding_format: Option<EncodingFormat>,
55    pub user: Option<String>,
56    ndims: usize,
57}
58
59impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
60where
61    T: HttpClientExt + Clone + std::fmt::Debug + Default + WasmCompatSend + 'static,
62{
63    const MAX_DOCUMENTS: usize = 1024;
64
65    type Client = Client<T>;
66
67    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
68        let model = model.into();
69        let dims = ndims.unwrap_or_default();
70
71        Self::new(client.clone(), model, dims)
72    }
73
74    fn ndims(&self) -> usize {
75        self.ndims
76    }
77
78    async fn embed_texts(
79        &self,
80        documents: impl IntoIterator<Item = String>,
81    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
82        let documents = documents.into_iter().collect::<Vec<_>>();
83
84        let mut body = json!({
85            "model": self.model,
86            "input": documents,
87        });
88
89        let body_object = body.as_object_mut().ok_or_else(|| {
90            EmbeddingError::ResponseError("embedding request body must be a JSON object".into())
91        })?;
92
93        if self.ndims > 0 {
94            body_object.insert("dimensions".to_owned(), json!(self.ndims));
95        }
96
97        if let Some(encoding_format) = &self.encoding_format {
98            body_object.insert("encoding_format".to_owned(), json!(encoding_format));
99        }
100
101        if let Some(user) = &self.user {
102            body_object.insert("user".to_owned(), json!(user));
103        }
104
105        let body = serde_json::to_vec(&body)?;
106
107        let req = self
108            .client
109            .post("/embeddings")?
110            .body(body)
111            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
112
113        let response = self.client.send(req).await?;
114
115        if response.status().is_success() {
116            let body: Vec<u8> = response.into_body().await?;
117            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
118
119            match body {
120                ApiResponse::Ok(response) => {
121                    tracing::info!(target: "rig",
122                        "OpenRouter embedding token usage: {:?}",
123                        response.usage
124                    );
125
126                    if response.data.len() != documents.len() {
127                        return Err(EmbeddingError::ResponseError(
128                            "Response data length does not match input length".into(),
129                        ));
130                    }
131
132                    Ok(response
133                        .data
134                        .into_iter()
135                        .zip(documents.into_iter())
136                        .map(|(embedding, document)| embeddings::Embedding {
137                            document,
138                            vec: embedding
139                                .embedding
140                                .into_iter()
141                                .filter_map(|n| n.as_f64())
142                                .collect(),
143                        })
144                        .collect())
145                }
146                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
147            }
148        } else {
149            let text = http_client::text(response).await?;
150            Err(EmbeddingError::ProviderError(text))
151        }
152    }
153}
154
155impl<T> EmbeddingModel<T> {
156    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
157        Self {
158            client,
159            model: model.into(),
160            encoding_format: None,
161            ndims,
162            user: None,
163        }
164    }
165
166    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
167        Self {
168            client,
169            model: model.into(),
170            encoding_format: None,
171            ndims,
172            user: None,
173        }
174    }
175
176    pub fn with_encoding_format(
177        client: Client<T>,
178        model: &str,
179        ndims: usize,
180        encoding_format: EncodingFormat,
181    ) -> Self {
182        Self {
183            client,
184            model: model.into(),
185            encoding_format: Some(encoding_format),
186            ndims,
187            user: None,
188        }
189    }
190
191    pub fn encoding_format(mut self, encoding_format: EncodingFormat) -> Self {
192        self.encoding_format = Some(encoding_format);
193        self
194    }
195
196    pub fn user(mut self, user: impl Into<String>) -> Self {
197        self.user = Some(user.into());
198        self
199    }
200}