rig/providers/openai/embedding.rs

use super::{
    Client,
    client::{ApiErrorResponse, ApiResponse},
    completion::Usage,
};
use crate::embeddings::EmbeddingError;
use crate::http_client::HttpClientExt;
use crate::{embeddings, http_client};
use serde::Deserialize;
use serde_json::json;

// ================================================================
// OpenAI Embedding API
// ================================================================
/// `text-embedding-3-large` embedding model
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";

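/// Response body returned by the OpenAI `/embeddings` endpoint.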
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    pub object: String,
    pub data: Vec<EmbeddingData>,
    pub model: String,
    pub usage: Usage,
}

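// Surface provider-reported API errors as embedding errors.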
impl From<ApiErrorResponse> for EmbeddingError {
    fn from(err: ApiErrorResponse) -> Self {
        EmbeddingError::ProviderError(err.message)
    }
}

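// Collapse the OpenAI Ok/Err response envelope into a Result.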
impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
        match value {
            ApiResponse::Ok(response) => Ok(response),
            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
        }
    }
}

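/// A single embedding vector and its position within the request batch.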
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    pub index: usize,
}

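/// An OpenAI embedding model bound to a [`Client`], generic over the HTTP transport `T`
/// (defaults to `reqwest::Client`).
///
/// Usage sketch, assuming an already-configured `client` and that the crate is consumed
/// as `rig` (the client construction itself is outside this module):
///
/// ```ignore
/// use rig::embeddings::EmbeddingModel as _;
///
/// let model = EmbeddingModel::new(client, TEXT_EMBEDDING_3_SMALL, 1536);
/// let embeddings = model.embed_texts(vec!["hello".to_string()]).await?;
/// ```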
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    client: Client<T>,
    pub model: String,
    ndims: usize,
}

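/// Default output dimensionality of the known OpenAI embedding models,
/// or `None` for an unrecognized identifier.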
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    match identifier {
        TEXT_EMBEDDING_3_LARGE => Some(3_072),
        TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => Some(1_536),
        _ => None,
    }
}

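// Implementation of the provider-agnostic embedding trait for any HTTP transport
// that implements `HttpClientExt`.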
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
{
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client<T>;

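    // Build a model handle, inferring the dimension count from the model name
    // when none is provided.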
    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
        let model = model.into();
        let dims = ndims
            .or_else(|| model_dimensions_from_identifier(&model))
            .unwrap_or_default();

        Self::new(client.clone(), model, dims)
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

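    // Embed a batch of documents with a single request to the `/embeddings` endpoint.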
    #[cfg_attr(feature = "worker", worker::send)]
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        // Collect up front: the input is reused for the request body and for
        // validating the length of the response.
        let documents = documents.into_iter().collect::<Vec<_>>();

        let mut body = json!({
            "model": self.model,
            "input": documents,
        });

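        // Only the text-embedding-3 models accept a `dimensions` override;
        // ada-002's output size is fixed, so the field is omitted for it.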
        if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
            body["dimensions"] = json!(self.ndims);
        }

        let body = serde_json::to_vec(&body)?;

        let req = self
            .client
            .post("/embeddings")?
            .body(body)
            .map_err(|e| EmbeddingError::HttpError(e.into()))?;

        let response = self.client.send(req).await?;

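        // On success, deserialize into the Ok/Err envelope; otherwise surface the raw error body.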
        if response.status().is_success() {
            let body: Vec<u8> = response.into_body().await?;
            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;

            match body {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "OpenAI embedding token usage: {:?}",
                        response.usage
                    );

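                    // The API is expected to return exactly one embedding per input document.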
                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

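                    // Pair each returned vector with its source document positionally.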
                    Ok(response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect())
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            let text = http_client::text(response).await?;
            Err(EmbeddingError::ProviderError(text))
        }
    }
}

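// Inherent constructors, usable without the `embeddings::EmbeddingModel` trait in scope.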
impl<T> EmbeddingModel<T> {
    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }

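    /// Equivalent to [`Self::new`], taking the model name as a `&str`.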
    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
        Self {
            client,
            model: model.into(),
            ndims,
        }
    }
}