//! OpenAI embedding model support (`rig/providers/openai/embedding.rs`).
1use super::{
2    client::{ApiErrorResponse, ApiResponse},
3    completion::Usage,
4};
5use crate::embeddings::EmbeddingError;
6use crate::http_client::HttpClientExt;
7use crate::{embeddings, http_client};
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10
// ================================================================
// OpenAI Embedding API
// ================================================================
/// `text-embedding-3-large` embedding model (default width 3,072; see
/// `model_dimensions_from_identifier`)
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model (default width 1,536)
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model (default width 1,536; does not
/// take a `dimensions` request parameter — see `embed_texts`)
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
20
/// Deserialized body of a successful OpenAI `/embeddings` response.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    // Object-type discriminator string returned by the API.
    pub object: String,
    // One entry per input document; order is validated against the request
    // batch length in `embed_texts`.
    pub data: Vec<EmbeddingData>,
    // Name of the model that produced the embeddings.
    pub model: String,
    // Token usage block, logged by `embed_texts`.
    pub usage: Usage,
}
28
29impl From<ApiErrorResponse> for EmbeddingError {
30    fn from(err: ApiErrorResponse) -> Self {
31        EmbeddingError::ProviderError(err.message)
32    }
33}
34
35impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
36    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
37        match value {
38            ApiResponse::Ok(response) => Ok(response),
39            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
40        }
41    }
42}
43
/// Requested wire format for the embedding vectors.
///
/// Serialized in `snake_case` (`"float"` / `"base64"`) when inserted into the
/// request body by `embed_texts`; when the field is unset the request simply
/// omits `encoding_format`.
#[derive(Debug, Deserialize, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum EncodingFormat {
    // Vectors returned as plain JSON numbers.
    Float,
    // Vectors returned base64-encoded. NOTE(review): `EmbeddingData.embedding`
    // deserializes as `Vec<serde_json::Number>`, which does not look
    // compatible with a base64 string payload — confirm this path is handled.
    Base64,
}
50
/// A single embedding entry from the response `data` array.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    // Object-type discriminator string returned by the API.
    pub object: String,
    // The raw embedding vector; converted to `f64` in `embed_texts`.
    pub embedding: Vec<serde_json::Number>,
    // Position of the corresponding input in the request batch. Currently
    // unused by `embed_texts`, which relies on response ordering instead.
    pub index: usize,
}
57
/// OpenAI embedding model handle, generic over the provider extension and the
/// underlying HTTP client type. Use the [`EmbeddingModel`] alias externally.
#[doc(hidden)]
#[derive(Clone)]
pub struct GenericEmbeddingModel<Ext = super::OpenAIResponsesExt, H = reqwest::Client> {
    // Shared API client used to issue `/embeddings` requests.
    client: crate::client::Client<Ext, H>,
    // Model identifier sent in the request body.
    pub model: String,
    // Optional `encoding_format` request parameter; omitted when `None`.
    pub encoding_format: Option<EncodingFormat>,
    // Optional `user` tracking field sent with each request.
    pub user: Option<String>,
    // Configured vector width; `0` means "unspecified" and suppresses the
    // `dimensions` request parameter.
    ndims: usize,
}
67
/// The embedding model struct for OpenAI's Embeddings API.
///
/// This preserves the historical public generic shape where the first generic
/// parameter is the HTTP client type (defaulting to `reqwest::Client`).
pub type EmbeddingModel<H = reqwest::Client> = GenericEmbeddingModel<super::OpenAIResponsesExt, H>;
73
74fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
75    match identifier {
76        TEXT_EMBEDDING_3_LARGE => Some(3_072),
77        TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => Some(1_536),
78        _ => None,
79    }
80}
81
// Wire the OpenAI `/embeddings` endpoint into rig's generic embedding trait
// for any client/extension pair capable of performing HTTP requests.
impl<Ext, H> embeddings::EmbeddingModel for GenericEmbeddingModel<Ext, H>
where
    crate::client::Client<Ext, H>: HttpClientExt + Clone + std::fmt::Debug + Send + 'static,
    Ext: crate::client::Provider + Clone + 'static,
    H: Clone + Default + std::fmt::Debug + 'static,
{
    // Upper bound on documents per request used by the trait's batching.
    // NOTE(review): presumably mirrors OpenAI's per-request input limit —
    // confirm against current API documentation.
    const MAX_DOCUMENTS: usize = 1024;

    type Client = crate::client::Client<Ext, H>;

    /// Construct a model handle from a shared client.
    ///
    /// Dimension resolution order: explicit `ndims`, then the known default
    /// for the model identifier, then `0` (meaning "unspecified").
    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
        let model = model.into();
        let dims = ndims
            .or(model_dimensions_from_identifier(&model))
            .unwrap_or_default();

        Self::new(client.clone(), model, dims)
    }

    /// Configured vector width for this model (`0` = unspecified).
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embed a batch of texts via `POST /embeddings`.
    ///
    /// Returns one [`embeddings::Embedding`] per input document (paired in
    /// response order), or an [`EmbeddingError`] for serialization, transport,
    /// HTTP, or provider-level failures.
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        // Materialize so the input count can be validated against the response.
        let documents = documents.into_iter().collect::<Vec<_>>();

        let mut body = json!({
            "model": self.model,
            "input": documents,
        });

        let body_object = body.as_object_mut().ok_or_else(|| {
            EmbeddingError::ResponseError("embedding request body must be a JSON object".into())
        })?;

        // Only send `dimensions` when configured (> 0) and supported; the
        // older ada-002 model does not accept this parameter.
        if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
            body_object.insert("dimensions".to_owned(), json!(self.ndims));
        }

        if let Some(encoding_format) = &self.encoding_format {
            body_object.insert("encoding_format".to_owned(), json!(encoding_format));
        }

        if let Some(user) = &self.user {
            body_object.insert("user".to_owned(), json!(user));
        }

        let body = serde_json::to_vec(&body)?;

        let req = self
            .client
            .post("/embeddings")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "OpenAI embedding token usage: {:?}",
                        response.usage
                    );

                    // Guard against a truncated or padded response before
                    // zipping results with inputs.
                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    // Pair each returned vector with its source document.
                    // NOTE(review): pairing relies on the API returning
                    // entries in input order; the `index` field is ignored.
                    // Also, `filter_map(as_f64)` silently drops any JSON
                    // number not representable as f64, which would shorten a
                    // vector without error — confirm this is acceptable.
                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding
                                .embedding
                                .into_iter()
                                .filter_map(|n| n.as_f64())
                                .collect(),
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            // Non-2xx status: surface the raw response body as a provider error.
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}
181
182impl<Ext, H> GenericEmbeddingModel<Ext, H>
183where
184    Ext: crate::client::Provider,
185{
186    pub fn new(
187        client: crate::client::Client<Ext, H>,
188        model: impl Into<String>,
189        ndims: usize,
190    ) -> Self {
191        Self {
192            client,
193            model: model.into(),
194            encoding_format: None,
195            ndims,
196            user: None,
197        }
198    }
199
200    pub fn with_model(client: crate::client::Client<Ext, H>, model: &str, ndims: usize) -> Self {
201        Self {
202            client,
203            model: model.into(),
204            encoding_format: None,
205            ndims,
206            user: None,
207        }
208    }
209
210    pub fn with_encoding_format(
211        client: crate::client::Client<Ext, H>,
212        model: &str,
213        ndims: usize,
214        encoding_format: EncodingFormat,
215    ) -> Self {
216        Self {
217            client,
218            model: model.into(),
219            encoding_format: Some(encoding_format),
220            ndims,
221            user: None,
222        }
223    }
224
225    pub fn encoding_format(mut self, encoding_format: EncodingFormat) -> Self {
226        self.encoding_format = Some(encoding_format);
227        self
228    }
229
230    pub fn user(mut self, user: impl Into<String>) -> Self {
231        self.user = Some(user.into());
232        self
233    }
234}