rig/providers/openai/embedding.rs

1use super::{
2    Client,
3    client::{ApiErrorResponse, ApiResponse},
4    completion::Usage,
5};
6use crate::embeddings::EmbeddingError;
7use crate::http_client::HttpClientExt;
8use crate::{embeddings, http_client};
9use serde::Deserialize;
10use serde_json::json;
11
// ================================================================
// OpenAI Embedding API
// ================================================================

/// Model identifier for OpenAI's `text-embedding-3-large` embedding model.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";

/// Model identifier for OpenAI's `text-embedding-3-small` embedding model.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";

/// Model identifier for OpenAI's `text-embedding-ada-002` embedding model.
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
21
22#[derive(Debug, Deserialize)]
23pub struct EmbeddingResponse {
24    pub object: String,
25    pub data: Vec<EmbeddingData>,
26    pub model: String,
27    pub usage: Usage,
28}
29
30impl From<ApiErrorResponse> for EmbeddingError {
31    fn from(err: ApiErrorResponse) -> Self {
32        EmbeddingError::ProviderError(err.message)
33    }
34}
35
36impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
37    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
38        match value {
39            ApiResponse::Ok(response) => Ok(response),
40            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
41        }
42    }
43}
44
45#[derive(Debug, Deserialize)]
46pub struct EmbeddingData {
47    pub object: String,
48    pub embedding: Vec<f64>,
49    pub index: usize,
50}
51
52#[derive(Clone)]
53pub struct EmbeddingModel<T = reqwest::Client> {
54    client: Client<T>,
55    pub model: String,
56    ndims: usize,
57}
58
59fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
60    match identifier {
61        TEXT_EMBEDDING_3_LARGE => Some(3_072),
62        TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => Some(1_536),
63        _ => None,
64    }
65}
66
67impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
68where
69    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
70{
71    const MAX_DOCUMENTS: usize = 1024;
72
73    type Client = Client<T>;
74
75    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
76        let model = model.into();
77        let dims = ndims
78            .or(model_dimensions_from_identifier(&model))
79            .unwrap_or_default();
80
81        Self::new(client.clone(), model, dims)
82    }
83
84    fn ndims(&self) -> usize {
85        self.ndims
86    }
87
88    async fn embed_texts(
89        &self,
90        documents: impl IntoIterator<Item = String>,
91    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
92        let documents = documents.into_iter().collect::<Vec<_>>();
93
94        let mut body = json!({
95            "model": self.model,
96            "input": documents,
97        });
98
99        if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
100            body["dimensions"] = json!(self.ndims);
101        }
102
103        let body = serde_json::to_vec(&body)?;
104
105        let req = self
106            .client
107            .post("/embeddings")?
108            .body(body)
109            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
110
111        let response = self.client.send(req).await?;
112
113        if response.status().is_success() {
114            let body: Vec<u8> = response.into_body().await?;
115            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
116
117            match body {
118                ApiResponse::Ok(response) => {
119                    tracing::info!(target: "rig",
120                        "OpenAI embedding token usage: {:?}",
121                        response.usage
122                    );
123
124                    if response.data.len() != documents.len() {
125                        return Err(EmbeddingError::ResponseError(
126                            "Response data length does not match input length".into(),
127                        ));
128                    }
129
130                    Ok(response
131                        .data
132                        .into_iter()
133                        .zip(documents.into_iter())
134                        .map(|(embedding, document)| embeddings::Embedding {
135                            document,
136                            vec: embedding.embedding,
137                        })
138                        .collect())
139                }
140                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
141            }
142        } else {
143            let text = http_client::text(response).await?;
144            Err(EmbeddingError::ProviderError(text))
145        }
146    }
147}
148
149impl<T> EmbeddingModel<T> {
150    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
151        Self {
152            client,
153            model: model.into(),
154            ndims,
155        }
156    }
157
158    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
159        Self {
160            client,
161            model: model.into(),
162            ndims,
163        }
164    }
165}