rig/providers/openai/
embedding.rs1use super::{
2 Client,
3 client::{ApiErrorResponse, ApiResponse},
4 completion::Usage,
5};
6use crate::embeddings::EmbeddingError;
7use crate::http_client::HttpClientExt;
8use crate::{embeddings, http_client};
9use serde::Deserialize;
10use serde_json::json;
11
/// Identifier of OpenAI's `text-embedding-3-large` embedding model.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// Identifier of OpenAI's `text-embedding-3-small` embedding model.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// Identifier of OpenAI's `text-embedding-ada-002` embedding model.
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
21
22#[derive(Debug, Deserialize)]
23pub struct EmbeddingResponse {
24 pub object: String,
25 pub data: Vec<EmbeddingData>,
26 pub model: String,
27 pub usage: Usage,
28}
29
30impl From<ApiErrorResponse> for EmbeddingError {
31 fn from(err: ApiErrorResponse) -> Self {
32 EmbeddingError::ProviderError(err.message)
33 }
34}
35
36impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
37 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
38 match value {
39 ApiResponse::Ok(response) => Ok(response),
40 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
41 }
42 }
43}
44
45#[derive(Debug, Deserialize)]
46pub struct EmbeddingData {
47 pub object: String,
48 pub embedding: Vec<f64>,
49 pub index: usize,
50}
51
52#[derive(Clone)]
53pub struct EmbeddingModel<T = reqwest::Client> {
54 client: Client<T>,
55 pub model: String,
56 ndims: usize,
57}
58
/// Looks up the built-in dimensionality for a known model identifier.
///
/// Returns `Some(3072)` for [`TEXT_EMBEDDING_3_LARGE`], `Some(1536)` for
/// [`TEXT_EMBEDDING_3_SMALL`] and [`TEXT_EMBEDDING_ADA_002`], and `None` for
/// any other identifier.
fn model_dimensions_from_identifier(identifier: &str) -> Option<usize> {
    let dimensions = match identifier {
        TEXT_EMBEDDING_3_LARGE => 3_072,
        TEXT_EMBEDDING_3_SMALL => 1_536,
        TEXT_EMBEDDING_ADA_002 => 1_536,
        // Unknown model: there is no table entry to fall back on.
        _ => return None,
    };

    Some(dimensions)
}
66
67impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
68where
69 T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
70{
71 const MAX_DOCUMENTS: usize = 1024;
72
73 type Client = Client<T>;
74
75 fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
76 let model = model.into();
77 let dims = ndims
78 .or(model_dimensions_from_identifier(&model))
79 .unwrap_or_default();
80
81 Self::new(client.clone(), model, dims)
82 }
83
84 fn ndims(&self) -> usize {
85 self.ndims
86 }
87
88 #[cfg_attr(feature = "worker", worker::send)]
89 async fn embed_texts(
90 &self,
91 documents: impl IntoIterator<Item = String>,
92 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
93 let documents = documents.into_iter().collect::<Vec<_>>();
94
95 let mut body = json!({
96 "model": self.model,
97 "input": documents,
98 });
99
100 if self.ndims > 0 && self.model.as_str() != TEXT_EMBEDDING_ADA_002 {
101 body["dimensions"] = json!(self.ndims);
102 }
103
104 let body = serde_json::to_vec(&body)?;
105
106 let req = self
107 .client
108 .post("/embeddings")?
109 .body(body)
110 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
111
112 let response = self.client.send(req).await?;
113
114 if response.status().is_success() {
115 let body: Vec<u8> = response.into_body().await?;
116 let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
117
118 match body {
119 ApiResponse::Ok(response) => {
120 tracing::info!(target: "rig",
121 "OpenAI embedding token usage: {:?}",
122 response.usage
123 );
124
125 if response.data.len() != documents.len() {
126 return Err(EmbeddingError::ResponseError(
127 "Response data length does not match input length".into(),
128 ));
129 }
130
131 Ok(response
132 .data
133 .into_iter()
134 .zip(documents.into_iter())
135 .map(|(embedding, document)| embeddings::Embedding {
136 document,
137 vec: embedding.embedding,
138 })
139 .collect())
140 }
141 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
142 }
143 } else {
144 let text = http_client::text(response).await?;
145 Err(EmbeddingError::ProviderError(text))
146 }
147 }
148}
149
150impl<T> EmbeddingModel<T> {
151 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
152 Self {
153 client,
154 model: model.into(),
155 ndims,
156 }
157 }
158
159 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
160 Self {
161 client,
162 model: model.into(),
163 ndims,
164 }
165 }
166}