1use super::{
2 client::{ApiErrorResponse, ApiResponse},
3 completion::Usage,
4};
5use crate::embeddings::EmbeddingError;
6use crate::http_client::HttpClientExt;
7use crate::{embeddings, http_client};
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10
/// Model identifier string `text-embedding-3-large`.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// Model identifier string `text-embedding-3-small`.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// Model identifier string `text-embedding-ada-002`.
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
20
21#[derive(Debug, Deserialize)]
22pub struct EmbeddingResponse {
23 pub object: String,
24 pub data: Vec<EmbeddingData>,
25 pub model: String,
26 pub usage: Usage,
27}
28
29impl From<ApiErrorResponse> for EmbeddingError {
30 fn from(err: ApiErrorResponse) -> Self {
31 EmbeddingError::ProviderError(err.message)
32 }
33}
34
35impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
36 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
37 match value {
38 ApiResponse::Ok(response) => Ok(response),
39 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
40 }
41 }
42}
43
44#[derive(Debug, Deserialize, Clone, Serialize)]
45#[serde(rename_all = "snake_case")]
46pub enum EncodingFormat {
47 Float,
48 Base64,
49}
50
51#[derive(Debug, Deserialize)]
52pub struct EmbeddingData {
53 pub object: String,
54 pub embedding: Vec<serde_json::Number>,
55 pub index: usize,
56}
57
58#[doc(hidden)]
59#[derive(Clone)]
60pub struct GenericEmbeddingModel<Ext = super::OpenAIResponsesExt, H = reqwest::Client> {
61 client: crate::client::Client<Ext, H>,
62 pub model: String,
63 pub encoding_format: Option<EncodingFormat>,
64 pub user: Option<String>,
65 ndims: usize,
66}
67
68pub type EmbeddingModel<H = reqwest::Client> = GenericEmbeddingModel<super::OpenAIResponsesExt, H>;
73
74fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
75 match identifier {
76 TEXT_EMBEDDING_3_LARGE => Some(3_072),
77 TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => Some(1_536),
78 _ => None,
79 }
80}
81
82impl<Ext, H> embeddings::EmbeddingModel for GenericEmbeddingModel<Ext, H>
83where
84 crate::client::Client<Ext, H>: HttpClientExt + Clone + std::fmt::Debug + Send + 'static,
85 Ext: crate::client::Provider + Clone + 'static,
86 H: Clone + Default + std::fmt::Debug + 'static,
87{
88 const MAX_DOCUMENTS: usize = 1024;
89
90 type Client = crate::client::Client<Ext, H>;
91
92 fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
93 let model = model.into();
94 let dims = ndims
95 .or(model_dimensions_from_identifier(&model))
96 .unwrap_or_default();
97
98 Self::new(client.clone(), model, dims)
99 }
100
101 fn ndims(&self) -> usize {
102 self.ndims
103 }
104
105 async fn embed_texts(
106 &self,
107 documents: impl IntoIterator<Item = String>,
108 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
109 let documents = documents.into_iter().collect::<Vec<_>>();
110
111 let mut body = json!({
112 "model": self.model,
113 "input": documents,
114 });
115
116 let body_object = body.as_object_mut().ok_or_else(|| {
117 EmbeddingError::ResponseError("embedding request body must be a JSON object".into())
118 })?;
119
120 if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
121 body_object.insert("dimensions".to_owned(), json!(self.ndims));
122 }
123
124 if let Some(encoding_format) = &self.encoding_format {
125 body_object.insert("encoding_format".to_owned(), json!(encoding_format));
126 }
127
128 if let Some(user) = &self.user {
129 body_object.insert("user".to_owned(), json!(user));
130 }
131
132 let body = serde_json::to_vec(&body)?;
133
134 let req = self
135 .client
136 .post("/embeddings")?
137 .body(body)
138 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
139
140 let response = self.client.send(req).await?;
141
142 if response.status().is_success() {
143 let body: Vec<u8> = response.into_body().await?;
144 let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
145
146 match body {
147 ApiResponse::Ok(response) => {
148 tracing::info!(target: "rig",
149 "OpenAI embedding token usage: {:?}",
150 response.usage
151 );
152
153 if response.data.len() != documents.len() {
154 return Err(EmbeddingError::ResponseError(
155 "Response data length does not match input length".into(),
156 ));
157 }
158
159 Ok(response
160 .data
161 .into_iter()
162 .zip(documents.into_iter())
163 .map(|(embedding, document)| embeddings::Embedding {
164 document,
165 vec: embedding
166 .embedding
167 .into_iter()
168 .filter_map(|n| n.as_f64())
169 .collect(),
170 })
171 .collect())
172 }
173 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
174 }
175 } else {
176 let text = http_client::text(response).await?;
177 Err(EmbeddingError::ProviderError(text))
178 }
179 }
180}
181
182impl<Ext, H> GenericEmbeddingModel<Ext, H>
183where
184 Ext: crate::client::Provider,
185{
186 pub fn new(
187 client: crate::client::Client<Ext, H>,
188 model: impl Into<String>,
189 ndims: usize,
190 ) -> Self {
191 Self {
192 client,
193 model: model.into(),
194 encoding_format: None,
195 ndims,
196 user: None,
197 }
198 }
199
200 pub fn with_model(client: crate::client::Client<Ext, H>, model: &str, ndims: usize) -> Self {
201 Self {
202 client,
203 model: model.into(),
204 encoding_format: None,
205 ndims,
206 user: None,
207 }
208 }
209
210 pub fn with_encoding_format(
211 client: crate::client::Client<Ext, H>,
212 model: &str,
213 ndims: usize,
214 encoding_format: EncodingFormat,
215 ) -> Self {
216 Self {
217 client,
218 model: model.into(),
219 encoding_format: Some(encoding_format),
220 ndims,
221 user: None,
222 }
223 }
224
225 pub fn encoding_format(mut self, encoding_format: EncodingFormat) -> Self {
226 self.encoding_format = Some(encoding_format);
227 self
228 }
229
230 pub fn user(mut self, user: impl Into<String>) -> Self {
231 self.user = Some(user.into());
232 self
233 }
234}