rig/providers/openrouter/
embedding.rs1use super::{
2 Client, Usage,
3 client::{ApiErrorResponse, ApiResponse},
4};
5use crate::embeddings::EmbeddingError;
6use crate::http_client::HttpClientExt;
7use crate::wasm_compat::WasmCompatSend;
8use crate::{embeddings, http_client};
9use serde::{Deserialize, Serialize};
10use serde_json::json;
11
12#[derive(Debug, Deserialize)]
13pub struct EmbeddingResponse {
14 pub object: String,
15 pub data: Vec<EmbeddingData>,
16 pub model: String,
17 pub usage: Option<Usage>,
18 pub id: Option<String>,
19}
20
21impl From<ApiErrorResponse> for EmbeddingError {
22 fn from(err: ApiErrorResponse) -> Self {
23 EmbeddingError::ProviderError(err.message)
24 }
25}
26
27impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
28 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
29 match value {
30 ApiResponse::Ok(response) => Ok(response),
31 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
32 }
33 }
34}
35
36#[derive(Debug, Deserialize, Clone, Serialize)]
37#[serde(rename_all = "snake_case")]
38pub enum EncodingFormat {
39 Float,
40 Base64,
41}
42
43#[derive(Debug, Deserialize)]
44pub struct EmbeddingData {
45 pub object: String,
46 pub embedding: Vec<serde_json::Number>,
47 pub index: usize,
48}
49
50#[derive(Clone)]
51pub struct EmbeddingModel<T = reqwest::Client> {
52 client: Client<T>,
53 pub model: String,
54 pub encoding_format: Option<EncodingFormat>,
55 pub user: Option<String>,
56 ndims: usize,
57}
58
59impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
60where
61 T: HttpClientExt + Clone + std::fmt::Debug + Default + WasmCompatSend + 'static,
62{
63 const MAX_DOCUMENTS: usize = 1024;
64
65 type Client = Client<T>;
66
67 fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
68 let model = model.into();
69 let dims = ndims.unwrap_or_default();
70
71 Self::new(client.clone(), model, dims)
72 }
73
74 fn ndims(&self) -> usize {
75 self.ndims
76 }
77
78 async fn embed_texts(
79 &self,
80 documents: impl IntoIterator<Item = String>,
81 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
82 let documents = documents.into_iter().collect::<Vec<_>>();
83
84 let mut body = json!({
85 "model": self.model,
86 "input": documents,
87 });
88
89 let body_object = body.as_object_mut().ok_or_else(|| {
90 EmbeddingError::ResponseError("embedding request body must be a JSON object".into())
91 })?;
92
93 if self.ndims > 0 {
94 body_object.insert("dimensions".to_owned(), json!(self.ndims));
95 }
96
97 if let Some(encoding_format) = &self.encoding_format {
98 body_object.insert("encoding_format".to_owned(), json!(encoding_format));
99 }
100
101 if let Some(user) = &self.user {
102 body_object.insert("user".to_owned(), json!(user));
103 }
104
105 let body = serde_json::to_vec(&body)?;
106
107 let req = self
108 .client
109 .post("/embeddings")?
110 .body(body)
111 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
112
113 let response = self.client.send(req).await?;
114
115 if response.status().is_success() {
116 let body: Vec<u8> = response.into_body().await?;
117 let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
118
119 match body {
120 ApiResponse::Ok(response) => {
121 tracing::info!(target: "rig",
122 "OpenRouter embedding token usage: {:?}",
123 response.usage
124 );
125
126 if response.data.len() != documents.len() {
127 return Err(EmbeddingError::ResponseError(
128 "Response data length does not match input length".into(),
129 ));
130 }
131
132 Ok(response
133 .data
134 .into_iter()
135 .zip(documents.into_iter())
136 .map(|(embedding, document)| embeddings::Embedding {
137 document,
138 vec: embedding
139 .embedding
140 .into_iter()
141 .filter_map(|n| n.as_f64())
142 .collect(),
143 })
144 .collect())
145 }
146 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
147 }
148 } else {
149 let text = http_client::text(response).await?;
150 Err(EmbeddingError::ProviderError(text))
151 }
152 }
153}
154
155impl<T> EmbeddingModel<T> {
156 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
157 Self {
158 client,
159 model: model.into(),
160 encoding_format: None,
161 ndims,
162 user: None,
163 }
164 }
165
166 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
167 Self {
168 client,
169 model: model.into(),
170 encoding_format: None,
171 ndims,
172 user: None,
173 }
174 }
175
176 pub fn with_encoding_format(
177 client: Client<T>,
178 model: &str,
179 ndims: usize,
180 encoding_format: EncodingFormat,
181 ) -> Self {
182 Self {
183 client,
184 model: model.into(),
185 encoding_format: Some(encoding_format),
186 ndims,
187 user: None,
188 }
189 }
190
191 pub fn encoding_format(mut self, encoding_format: EncodingFormat) -> Self {
192 self.encoding_format = Some(encoding_format);
193 self
194 }
195
196 pub fn user(mut self, user: impl Into<String>) -> Self {
197 self.user = Some(user.into());
198 self
199 }
200}