// rig/providers/openai/embedding.rs
use super::{ApiErrorResponse, ApiResponse, Client, completion::Usage};
2use crate::embeddings::EmbeddingError;
3use crate::http_client::HttpClientExt;
4use crate::{embeddings, http_client};
5use serde::Deserialize;
6use serde_json::json;
7
/// Model identifier for OpenAI's `text-embedding-3-large` embedding model.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// Model identifier for OpenAI's `text-embedding-3-small` embedding model.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// Model identifier for OpenAI's legacy `text-embedding-ada-002` embedding model.
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
17
/// Deserialized body of a successful response from the provider's
/// `/embeddings` endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    // Object-type marker returned by the API (opaque to this module).
    pub object: String,
    // One entry per input document; pairing with inputs happens in
    // `embed_texts`.
    pub data: Vec<EmbeddingData>,
    // Name of the model that actually served the request.
    pub model: String,
    // Token accounting reported by the API; logged in `embed_texts`.
    pub usage: Usage,
}
25
26impl From<ApiErrorResponse> for EmbeddingError {
27 fn from(err: ApiErrorResponse) -> Self {
28 EmbeddingError::ProviderError(err.message)
29 }
30}
31
32impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
33 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
34 match value {
35 ApiResponse::Ok(response) => Ok(response),
36 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
37 }
38 }
39}
40
/// A single embedding entry from the response's `data` array.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    // Object-type marker returned by the API (opaque to this module).
    pub object: String,
    // The embedding vector itself.
    pub embedding: Vec<f64>,
    // Position of the corresponding input within the request batch.
    pub index: usize,
}
47
/// An OpenAI embedding model handle: a client plus the model name and the
/// requested output dimensionality.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    // Configured API client used to issue the `/embeddings` request.
    client: Client<T>,
    /// Model identifier sent as the `model` field of each request
    /// (e.g. [`TEXT_EMBEDDING_3_SMALL`]).
    pub model: String,
    // Requested embedding dimensionality; 0 means "do not send the
    // `dimensions` field" (see `embed_texts`).
    ndims: usize,
}
54
55impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
56where
57 T: HttpClientExt + Clone + std::fmt::Debug + Send + 'static,
58{
59 const MAX_DOCUMENTS: usize = 1024;
60
61 fn ndims(&self) -> usize {
62 self.ndims
63 }
64
65 #[cfg_attr(feature = "worker", worker::send)]
66 async fn embed_texts(
67 &self,
68 documents: impl IntoIterator<Item = String>,
69 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
70 let documents = documents.into_iter().collect::<Vec<_>>();
71
72 let mut body = json!({
73 "model": self.model,
74 "input": documents,
75 });
76
77 if self.ndims > 0 {
78 body["dimensions"] = json!(self.ndims);
79 }
80
81 let body = serde_json::to_vec(&body)?;
82
83 let req = self
84 .client
85 .post("/embeddings")?
86 .header("Content-Type", "application/json")
87 .body(body)
88 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
89
90 let response = self.client.send(req).await?;
91
92 if response.status().is_success() {
93 let body: Vec<u8> = response.into_body().await?;
94 let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
95
96 match body {
97 ApiResponse::Ok(response) => {
98 tracing::info!(target: "rig",
99 "OpenAI embedding token usage: {:?}",
100 response.usage
101 );
102
103 if response.data.len() != documents.len() {
104 return Err(EmbeddingError::ResponseError(
105 "Response data length does not match input length".into(),
106 ));
107 }
108
109 Ok(response
110 .data
111 .into_iter()
112 .zip(documents.into_iter())
113 .map(|(embedding, document)| embeddings::Embedding {
114 document,
115 vec: embedding.embedding,
116 })
117 .collect())
118 }
119 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
120 }
121 } else {
122 let text = http_client::text(response).await?;
123 Err(EmbeddingError::ProviderError(text))
124 }
125 }
126}
127
128impl<T> EmbeddingModel<T> {
129 pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
130 Self {
131 client,
132 model: model.to_string(),
133 ndims,
134 }
135 }
136}