rig/providers/openai/embedding.rs

use super::{
    Client,
    client::{ApiErrorResponse, ApiResponse},
    completion::Usage,
};
use crate::embeddings::EmbeddingError;
use crate::http_client::HttpClientExt;
use crate::{embeddings, http_client};
use serde::{Deserialize, Serialize};
use serde_json::json;

/// The `text-embedding-3-large` embedding model identifier.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// The `text-embedding-3-small` embedding model identifier.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// The `text-embedding-ada-002` embedding model identifier.
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}

impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

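/// Wire format for returned embedding vectors, mirroring OpenAI's
/// `encoding_format` request parameter: `float` (a JSON array of numbers,
/// the API default) or `base64` (base64-encoded packed floats).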
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum EncodingFormat {
    Float,
    Base64,
}

#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

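/// An embedding model backed by OpenAI's `/embeddings` endpoint, generic over
/// the HTTP client used by the shared provider [`Client`].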
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    pub encoding_format: Option<EncodingFormat>,
    pub user: Option<String>,
    ndims: usize,
}

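/// Default output dimensionality for the known OpenAI embedding models;
/// returns `None` for identifiers this provider does not recognize.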
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    match identifier {
        TEXT_EMBEDDING_3_LARGE => Some(3_072),
        TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => Some(1_536),
        _ => None,
    }
}

impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client<T>;

    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
        let model = model.into();
        // Fall back to the model's known dimensionality when the caller does
        // not specify one; unrecognized models default to 0 (unknown).
        let dims = ndims
            .or(model_dimensions_from_identifier(&model))
            .unwrap_or_default();

        Self::new(client.clone(), model, dims)
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents = documents.into_iter().collect::<Vec<_>>();

        let mut body = json!({
            "model": self.model,
            "input": documents,
        });

        // `text-embedding-ada-002` does not accept the `dimensions` parameter,
        // and a value of 0 means the dimensionality is unknown, so the field
        // is omitted in both cases.
        if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
            body["dimensions"] = json!(self.ndims);
        }

        if let Some(encoding_format) = &self.encoding_format {
            body["encoding_format"] = json!(encoding_format);
        }

        if let Some(user) = &self.user {
            body["user"] = json!(user);
        }

        let body = serde_json::to_vec(&body)?;

        let req = self
            .client
            .post("/embeddings")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "OpenAI embedding token usage: {:?}",
                        response.usage
                    );

                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    // Pair each returned vector with the document it embeds,
                    // relying on the API returning results in input order.
                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}

impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            encoding_format: None,
            ndims,
            user: None,
        }
    }

    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            encoding_format: None,
            ndims,
            user: None,
        }
    }

    pub fn with_encoding_format(
        client: Client<T>,
        model: &str,
        ndims: usize,
        encoding_format: EncodingFormat,
    ) -> Self {
        Self {
            client,
            model: model.into(),
            encoding_format: Some(encoding_format),
            ndims,
            user: None,
        }
    }

    pub fn encoding_format(mut self, encoding_format: EncodingFormat) -> Self {
        self.encoding_format = Some(encoding_format);
        self
    }

    pub fn user(mut self, user: impl Into<String>) -> Self {
        self.user = Some(user.into());
        self
    }
}
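
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the provider implementation):
// given an already-constructed provider `Client`, build an embedding model for
// `text-embedding-3-small` by hand and embed a batch of documents. How the
// `Client` itself is created (API key, base URL, HTTP backend) is defined in
// the parent module and is assumed here, as is `reqwest::Client` satisfying
// the `HttpClientExt` bound.
//
// async fn embed_documents(
//     client: Client<reqwest::Client>,
//     docs: Vec<String>,
// ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
//     // Bring the trait's `embed_texts` method into scope without shadowing
//     // the local `EmbeddingModel` struct.
//     use crate::embeddings::EmbeddingModel as _;
//
//     let model = EmbeddingModel::new(client, TEXT_EMBEDDING_3_SMALL, 1_536)
//         .encoding_format(EncodingFormat::Float);
//
//     // `MAX_DOCUMENTS` (1024) bounds how many inputs one request may carry.
//     model.embed_texts(docs).await
// }
// ---------------------------------------------------------------------------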