rig/providers/openai/
embedding.rs

1use super::{ApiErrorResponse, ApiResponse, Client, completion::Usage};
2use crate::embeddings::EmbeddingError;
3use crate::http_client::HttpClientExt;
4use crate::{embeddings, http_client};
5use serde::Deserialize;
6use serde_json::json;
7
8// ================================================================
9// OpenAI Embedding API
10// ================================================================
/// Model identifier string for OpenAI's `text-embedding-3-large` embedding model.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";

/// Model identifier string for OpenAI's `text-embedding-3-small` embedding model.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";

/// Model identifier string for OpenAI's `text-embedding-ada-002` embedding model.
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
17
18#[derive(Debug, Deserialize)]
19pub struct EmbeddingResponse {
20    pub object: String,
21    pub data: Vec<EmbeddingData>,
22    pub model: String,
23    pub usage: Usage,
24}
25
26impl From<ApiErrorResponse> for EmbeddingError {
27    fn from(err: ApiErrorResponse) -> Self {
28        EmbeddingError::ProviderError(err.message)
29    }
30}
31
32impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
33    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
34        match value {
35            ApiResponse::Ok(response) => Ok(response),
36            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
37        }
38    }
39}
40
41#[derive(Debug, Deserialize)]
42pub struct EmbeddingData {
43    pub object: String,
44    pub embedding: Vec<f64>,
45    pub index: usize,
46}
47
48#[derive(Clone)]
49pub struct EmbeddingModel<T = reqwest::Client> {
50    client: Client<T>,
51    pub model: String,
52    ndims: usize,
53}
54
55impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
56where
57    T: HttpClientExt + Clone + std::fmt::Debug + Send + 'static,
58{
59    const MAX_DOCUMENTS: usize = 1024;
60
61    fn ndims(&self) -> usize {
62        self.ndims
63    }
64
65    #[cfg_attr(feature = "worker", worker::send)]
66    async fn embed_texts(
67        &self,
68        documents: impl IntoIterator<Item = String>,
69    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
70        let documents = documents.into_iter().collect::<Vec<_>>();
71
72        let mut body = json!({
73            "model": self.model,
74            "input": documents,
75        });
76
77        if self.ndims > 0 {
78            body["dimensions"] = json!(self.ndims);
79        }
80
81        let body = serde_json::to_vec(&body)?;
82
83        let req = self
84            .client
85            .post("/embeddings")?
86            .header("Content-Type", "application/json")
87            .body(body)
88            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
89
90        let response = self.client.send(req).await?;
91
92        if response.status().is_success() {
93            let body: Vec<u8> = response.into_body().await?;
94            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
95
96            match body {
97                ApiResponse::Ok(response) => {
98                    tracing::info!(target: "rig",
99                        "OpenAI embedding token usage: {:?}",
100                        response.usage
101                    );
102
103                    if response.data.len() != documents.len() {
104                        return Err(EmbeddingError::ResponseError(
105                            "Response data length does not match input length".into(),
106                        ));
107                    }
108
109                    Ok(response
110                        .data
111                        .into_iter()
112                        .zip(documents.into_iter())
113                        .map(|(embedding, document)| embeddings::Embedding {
114                            document,
115                            vec: embedding.embedding,
116                        })
117                        .collect())
118                }
119                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
120            }
121        } else {
122            let text = http_client::text(response).await?;
123            Err(EmbeddingError::ProviderError(text))
124        }
125    }
126}
127
128impl<T> EmbeddingModel<T> {
129    pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
130        Self {
131            client,
132            model: model.to_string(),
133            ndims,
134        }
135    }
136}