// rig_core/providers/openai/embedding.rs

1use super::{
2    client::{ApiErrorResponse, ApiResponse},
3    completion::Usage,
4};
5use crate::embeddings::EmbeddingError;
6use crate::http_client::HttpClientExt;
7use crate::{embeddings, http_client};
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10
11// ================================================================
12// OpenAI Embedding API
13// ================================================================
/// Model identifier for the `text-embedding-3-large` embedding model.
///
/// `model_dimensions_from_identifier` maps this identifier to 3,072 dimensions.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// Model identifier for the `text-embedding-3-small` embedding model.
///
/// `model_dimensions_from_identifier` maps this identifier to 1,536 dimensions.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// Model identifier for the `text-embedding-ada-002` embedding model.
///
/// `model_dimensions_from_identifier` maps this identifier to 1,536 dimensions.
/// The `dimensions` request field is never sent when this model is selected.
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
20
/// Success payload deserialized from the body of an `/embeddings` response.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// The `object` field of the response, kept verbatim as a string.
    pub object: String,
    /// The returned embedding entries; `embed_texts_with_usage` rejects the
    /// response when their count differs from the number of documents sent.
    pub data: Vec<EmbeddingData>,
    /// The `model` field of the response.
    pub model: String,
    /// Token usage reported alongside the embeddings.
    pub usage: Usage,
}
28
29impl From<ApiErrorResponse> for EmbeddingError {
30    fn from(err: ApiErrorResponse) -> Self {
31        EmbeddingError::ProviderError(err.message)
32    }
33}
34
35impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
36    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
37        match value {
38            ApiResponse::Ok(response) => Ok(response),
39            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
40        }
41    }
42}
43
/// Value of the optional `encoding_format` field in an embeddings request.
///
/// Variants (de)serialize in `snake_case`, i.e. `"float"` and `"base64"`.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum EncodingFormat {
    /// Serialized as `"float"`.
    Float,
    /// Serialized as `"base64"`.
    ///
    /// NOTE(review): `EmbeddingData::embedding` is a `Vec<serde_json::Number>`;
    /// presumably a base64-encoded embedding would not deserialize into it.
    /// Confirm before relying on this variant.
    Base64,
}
50
/// A single embedding entry inside [`EmbeddingResponse::data`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// The `object` field of the entry, kept verbatim as a string.
    pub object: String,
    /// Raw numeric components of the embedding; converted to `f64` when the
    /// response is turned into `embeddings::Embedding` values.
    pub embedding: Vec<serde_json::Number>,
    /// The `index` field of the entry. It is deserialized but not consulted
    /// when pairing entries with input documents (pairing is positional).
    pub index: usize,
}
57
/// Embedding model handle generic over the provider extension `Ext` and the
/// HTTP client type `H`.
#[doc(hidden)]
#[derive(Clone)]
pub struct GenericEmbeddingModel<Ext = super::OpenAIResponsesExt, H = reqwest::Client> {
    /// Client used to build and send the `/embeddings` request.
    client: crate::client::Client<Ext, H>,
    /// Model identifier sent as the request's `model` field.
    pub model: String,
    /// When set, sent as the request's `encoding_format` field.
    pub encoding_format: Option<EncodingFormat>,
    /// When set, sent as the request's `user` field.
    pub user: Option<String>,
    /// Dimension count reported by `ndims()`. A value of `0` means "unknown";
    /// the `dimensions` request field is only sent when this is positive and
    /// the model is not `text-embedding-ada-002`.
    ndims: usize,
}
67
/// The embedding model struct for OpenAI's Embeddings API.
///
/// This preserves the historical public generic shape where the first generic
/// parameter is the HTTP client type: the provider extension is fixed to
/// `super::OpenAIResponsesExt`, leaving only `H` (default `reqwest::Client`)
/// for callers to choose.
pub type EmbeddingModel<H = reqwest::Client> = GenericEmbeddingModel<super::OpenAIResponsesExt, H>;
73
74fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
75    match identifier {
76        TEXT_EMBEDDING_3_LARGE => Some(3_072),
77        TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => Some(1_536),
78        _ => None,
79    }
80}
81
82impl<Ext, H> embeddings::EmbeddingModel for GenericEmbeddingModel<Ext, H>
83where
84    crate::client::Client<Ext, H>: HttpClientExt + Clone + std::fmt::Debug + Send + 'static,
85    Ext: crate::client::Provider + Clone + 'static,
86    H: Clone + Default + std::fmt::Debug + 'static,
87{
88    const MAX_DOCUMENTS: usize = 1024;
89
90    type Client = crate::client::Client<Ext, H>;
91
92    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
93        let model = model.into();
94        let dims = ndims
95            .or(model_dimensions_from_identifier(&model))
96            .unwrap_or_default();
97
98        Self::new(client.clone(), model, dims)
99    }
100
101    fn ndims(&self) -> usize {
102        self.ndims
103    }
104
105    async fn embed_texts(
106        &self,
107        documents: impl IntoIterator<Item = String>,
108    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
109        let documents: Vec<String> = documents.into_iter().collect();
110        let response = self.embed_texts_with_usage(documents).await?;
111        Ok(response.embeddings)
112    }
113
114    async fn embed_texts_with_usage(
115        &self,
116        documents: impl IntoIterator<Item = String>,
117    ) -> Result<embeddings::EmbeddingResponse, EmbeddingError> {
118        let documents: Vec<String> = documents.into_iter().collect();
119
120        let mut body = json!({
121            "model": self.model,
122            "input": documents,
123        });
124
125        let body_object = body.as_object_mut().ok_or_else(|| {
126            EmbeddingError::ResponseError("embedding request body must be a JSON object".into())
127        })?;
128
129        if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
130            body_object.insert("dimensions".to_owned(), json!(self.ndims));
131        }
132
133        if let Some(encoding_format) = &self.encoding_format {
134            body_object.insert("encoding_format".to_owned(), json!(encoding_format));
135        }
136
137        if let Some(user) = &self.user {
138            body_object.insert("user".to_owned(), json!(user));
139        }
140
141        let body = serde_json::to_vec(&body)?;
142
143        let req = self
144            .client
145            .post("/embeddings")?
146            .body(body)
147            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
148
149        let response = self.client.send(req).await?;
150
151        if response.status().is_success() {
152            let body: Vec<u8> = response.into_body().await?;
153            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
154
155            match body {
156                ApiResponse::Ok(response) => {
157                    tracing::info!(target: "rig",
158                        "OpenAI embedding token usage: {:?}",
159                        response.usage
160                    );
161
162                    if response.data.len() != documents.len() {
163                        return Err(EmbeddingError::ResponseError(
164                            "Response data length does not match input length".into(),
165                        ));
166                    }
167
168                    let usage = crate::completion::Usage {
169                        input_tokens: response.usage.prompt_tokens as u64,
170                        output_tokens: 0,
171                        total_tokens: response.usage.total_tokens as u64,
172                        cached_input_tokens: response
173                            .usage
174                            .prompt_tokens_details
175                            .as_ref()
176                            .map_or(0, |d| d.cached_tokens as u64),
177                        cache_creation_input_tokens: 0,
178                        tool_use_prompt_tokens: 0,
179                        reasoning_tokens: 0,
180                    };
181
182                    let embeddings = response
183                        .data
184                        .into_iter()
185                        .zip(documents.into_iter())
186                        .map(|(embedding, document)| embeddings::Embedding {
187                            document,
188                            vec: embedding
189                                .embedding
190                                .into_iter()
191                                .filter_map(|n| n.as_f64())
192                                .collect(),
193                        })
194                        .collect();
195
196                    Ok(embeddings::EmbeddingResponse { embeddings, usage })
197                }
198                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
199            }
200        } else {
201            let text = http_client::text(response).await?;
202            Err(EmbeddingError::ProviderError(text))
203        }
204    }
205}
206
207impl<Ext, H> GenericEmbeddingModel<Ext, H>
208where
209    Ext: crate::client::Provider,
210{
211    pub fn new(
212        client: crate::client::Client<Ext, H>,
213        model: impl Into<String>,
214        ndims: usize,
215    ) -> Self {
216        Self {
217            client,
218            model: model.into(),
219            encoding_format: None,
220            ndims,
221            user: None,
222        }
223    }
224
225    pub fn with_model(client: crate::client::Client<Ext, H>, model: &str, ndims: usize) -> Self {
226        Self {
227            client,
228            model: model.into(),
229            encoding_format: None,
230            ndims,
231            user: None,
232        }
233    }
234
235    pub fn with_encoding_format(
236        client: crate::client::Client<Ext, H>,
237        model: &str,
238        ndims: usize,
239        encoding_format: EncodingFormat,
240    ) -> Self {
241        Self {
242            client,
243            model: model.into(),
244            encoding_format: Some(encoding_format),
245            ndims,
246            user: None,
247        }
248    }
249
250    pub fn encoding_format(mut self, encoding_format: EncodingFormat) -> Self {
251        self.encoding_format = Some(encoding_format);
252        self
253    }
254
255    pub fn user(mut self, user: impl Into<String>) -> Self {
256        self.user = Some(user.into());
257        self
258    }
259}