// rig/providers/openai/embedding.rs

1use super::{ApiErrorResponse, ApiResponse, Client, completion::Usage};
2use crate::embeddings;
3use crate::embeddings::EmbeddingError;
4use serde::Deserialize;
5use serde_json::json;
6
7// ================================================================
8// OpenAI Embedding API
9// ================================================================
/// `text-embedding-3-large` embedding model identifier.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model identifier.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` (legacy) embedding model identifier.
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
16
/// Deserialized body of a successful OpenAI `/embeddings` response.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// Object type tag reported by the API; not inspected by this module.
    pub object: String,
    /// One entry per input document (length is validated against the input in `embed_texts`).
    pub data: Vec<EmbeddingData>,
    /// Name of the model that produced the embeddings.
    pub model: String,
    /// Token usage reported by the API for this request.
    pub usage: Usage,
}
24
25impl From<ApiErrorResponse> for EmbeddingError {
26    fn from(err: ApiErrorResponse) -> Self {
27        EmbeddingError::ProviderError(err.message)
28    }
29}
30
31impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
32    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
33        match value {
34            ApiResponse::Ok(response) => Ok(response),
35            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
36        }
37    }
38}
39
/// A single embedding entry inside an [`EmbeddingResponse`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Object type tag reported by the API; not inspected by this module.
    pub object: String,
    /// The embedding vector for one input document.
    pub embedding: Vec<f64>,
    /// Position of the corresponding document in the request's input list.
    pub index: usize,
}
46
/// An OpenAI embedding model bound to a configured [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel {
    // HTTP client used for requests to the OpenAI API.
    client: Client,
    /// Model identifier, e.g. one of the `TEXT_EMBEDDING_*` constants.
    pub model: String,
    // Dimensionality of the produced embeddings; exposed via `ndims()`.
    ndims: usize,
}
53
54impl embeddings::EmbeddingModel for EmbeddingModel {
55    const MAX_DOCUMENTS: usize = 1024;
56
57    fn ndims(&self) -> usize {
58        self.ndims
59    }
60
61    #[cfg_attr(feature = "worker", worker::send)]
62    async fn embed_texts(
63        &self,
64        documents: impl IntoIterator<Item = String>,
65    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
66        let documents = documents.into_iter().collect::<Vec<_>>();
67
68        let response = self
69            .client
70            .post("/embeddings")
71            .json(&json!({
72                "model": self.model,
73                "input": documents,
74            }))
75            .send()
76            .await?;
77
78        if response.status().is_success() {
79            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
80                ApiResponse::Ok(response) => {
81                    tracing::info!(target: "rig",
82                        "OpenAI embedding token usage: {:?}",
83                        response.usage
84                    );
85
86                    if response.data.len() != documents.len() {
87                        return Err(EmbeddingError::ResponseError(
88                            "Response data length does not match input length".into(),
89                        ));
90                    }
91
92                    Ok(response
93                        .data
94                        .into_iter()
95                        .zip(documents.into_iter())
96                        .map(|(embedding, document)| embeddings::Embedding {
97                            document,
98                            vec: embedding.embedding,
99                        })
100                        .collect())
101                }
102                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
103            }
104        } else {
105            Err(EmbeddingError::ProviderError(response.text().await?))
106        }
107    }
108}
109
110impl EmbeddingModel {
111    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
112        Self {
113            client,
114            model: model.to_string(),
115            ndims,
116        }
117    }
118}