rig/providers/openai/
embedding.rs

1use super::{
2    Client,
3    client::{ApiErrorResponse, ApiResponse},
4    completion::Usage,
5};
6use crate::embeddings::EmbeddingError;
7use crate::http_client::HttpClientExt;
8use crate::{embeddings, http_client};
9use serde::{Deserialize, Serialize};
10use serde_json::json;
11
// ================================================================
// OpenAI Embedding API
// ================================================================

/// Model identifier string for `text-embedding-3-large`.
///
/// `model_dimensions_from_identifier` maps this model to 3072 dimensions.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";

/// Model identifier string for `text-embedding-3-small`.
///
/// `model_dimensions_from_identifier` maps this model to 1536 dimensions.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";

/// Model identifier string for `text-embedding-ada-002`.
///
/// `model_dimensions_from_identifier` maps this model to 1536 dimensions, and
/// `embed_texts` never sends a `dimensions` field for this model.
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
21
22#[derive(Debug, Deserialize)]
23pub struct EmbeddingResponse {
24    pub object: String,
25    pub data: Vec<EmbeddingData>,
26    pub model: String,
27    pub usage: Usage,
28}
29
30impl From<ApiErrorResponse> for EmbeddingError {
31    fn from(err: ApiErrorResponse) -> Self {
32        EmbeddingError::ProviderError(err.message)
33    }
34}
35
36impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
37    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
38        match value {
39            ApiResponse::Ok(response) => Ok(response),
40            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
41        }
42    }
43}
44
45#[derive(Debug, Deserialize, Clone, Serialize)]
46#[serde(rename_all = "snake_case")]
47pub enum EncodingFormat {
48    Float,
49    Base64,
50}
51
52#[derive(Debug, Deserialize)]
53pub struct EmbeddingData {
54    pub object: String,
55    pub embedding: Vec<f64>,
56    pub index: usize,
57}
58
59#[derive(Clone)]
60pub struct EmbeddingModel<T = reqwest::Client> {
61    client: Client<T>,
62    pub model: String,
63    pub encoding_format: Option<EncodingFormat>,
64    pub user: Option<String>,
65    ndims: usize,
66}
67
68fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
69    match identifier {
70        TEXT_EMBEDDING_3_LARGE => Some(3_072),
71        TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => Some(1_536),
72        _ => None,
73    }
74}
75
76impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
77where
78    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
79{
80    const MAX_DOCUMENTS: usize = 1024;
81
82    type Client = Client<T>;
83
84    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
85        let model = model.into();
86        let dims = ndims
87            .or(model_dimensions_from_identifier(&model))
88            .unwrap_or_default();
89
90        Self::new(client.clone(), model, dims)
91    }
92
93    fn ndims(&self) -> usize {
94        self.ndims
95    }
96
97    async fn embed_texts(
98        &self,
99        documents: impl IntoIterator<Item = String>,
100    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
101        let documents = documents.into_iter().collect::<Vec<_>>();
102
103        let mut body = json!({
104            "model": self.model,
105            "input": documents,
106        });
107
108        if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
109            body["dimensions"] = json!(self.ndims);
110        }
111
112        if let Some(encoding_format) = &self.encoding_format {
113            body["encoding_format"] = json!(encoding_format);
114        }
115
116        if let Some(user) = &self.user {
117            body["user"] = json!(user);
118        }
119
120        let body = serde_json::to_vec(&body)?;
121
122        let req = self
123            .client
124            .post("/embeddings")?
125            .body(body)
126            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
127
128        let response = self.client.send(req).await?;
129
130        if response.status().is_success() {
131            let body: Vec<u8> = response.into_body().await?;
132            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
133
134            match body {
135                ApiResponse::Ok(response) => {
136                    tracing::info!(target: "rig",
137                        "OpenAI embedding token usage: {:?}",
138                        response.usage
139                    );
140
141                    if response.data.len() != documents.len() {
142                        return Err(EmbeddingError::ResponseError(
143                            "Response data length does not match input length".into(),
144                        ));
145                    }
146
147                    Ok(response
148                        .data
149                        .into_iter()
150                        .zip(documents.into_iter())
151                        .map(|(embedding, document)| embeddings::Embedding {
152                            document,
153                            vec: embedding.embedding,
154                        })
155                        .collect())
156                }
157                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
158            }
159        } else {
160            let text = http_client::text(response).await?;
161            Err(EmbeddingError::ProviderError(text))
162        }
163    }
164}
165
166impl<T> EmbeddingModel<T> {
167    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
168        Self {
169            client,
170            model: model.into(),
171            encoding_format: None,
172            ndims,
173            user: None,
174        }
175    }
176
177    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
178        Self {
179            client,
180            model: model.into(),
181            encoding_format: None,
182            ndims,
183            user: None,
184        }
185    }
186
187    pub fn with_encoding_format(
188        client: Client<T>,
189        model: &str,
190        ndims: usize,
191        encoding_format: EncodingFormat,
192    ) -> Self {
193        Self {
194            client,
195            model: model.into(),
196            encoding_format: Some(encoding_format),
197            ndims,
198            user: None,
199        }
200    }
201
202    pub fn encoding_format(mut self, encoding_format: EncodingFormat) -> Self {
203        self.encoding_format = Some(encoding_format);
204        self
205    }
206
207    pub fn user(mut self, user: impl Into<String>) -> Self {
208        self.user = Some(user.into());
209        self
210    }
211}