// rig/providers/openai/embedding.rs
use super::{ApiErrorResponse, ApiResponse, Client, completion::Usage};
use crate::embeddings;
use crate::embeddings::EmbeddingError;
use serde::Deserialize;
use serde_json::json;
6
/// `text-embedding-3-large` embedding model identifier.
pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large";
/// `text-embedding-3-small` embedding model identifier.
pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small";
/// `text-embedding-ada-002` embedding model identifier.
pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";
16
/// Deserialized response body from OpenAI's `/embeddings` endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    // Object type marker returned by the API (presumably "list" — confirm against API docs).
    pub object: String,
    // One entry per input document; order is matched back to inputs by position.
    pub data: Vec<EmbeddingData>,
    // Model name echoed back by the provider.
    pub model: String,
    // Token usage reported by the provider (logged in `embed_texts`).
    pub usage: Usage,
}
24
25impl From<ApiErrorResponse> for EmbeddingError {
26 fn from(err: ApiErrorResponse) -> Self {
27 EmbeddingError::ProviderError(err.message)
28 }
29}
30
31impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
32 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
33 match value {
34 ApiResponse::Ok(response) => Ok(response),
35 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
36 }
37 }
38}
39
/// A single embedding entry within an [`EmbeddingResponse`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    // Object type marker for this entry (presumably "embedding" — confirm against API docs).
    pub object: String,
    // The embedding vector for one input document.
    pub embedding: Vec<f64>,
    // Position of the corresponding document in the request input.
    pub index: usize,
}
46
/// An OpenAI embedding model handle, bound to a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel {
    // HTTP client used to reach the OpenAI API.
    client: Client,
    /// Model identifier sent with each request (e.g. [`TEXT_EMBEDDING_3_SMALL`]).
    pub model: String,
    // Dimensionality reported by `ndims()`; set by the caller at construction.
    ndims: usize,
}
53
54impl embeddings::EmbeddingModel for EmbeddingModel {
55 const MAX_DOCUMENTS: usize = 1024;
56
57 fn ndims(&self) -> usize {
58 self.ndims
59 }
60
61 #[cfg_attr(feature = "worker", worker::send)]
62 async fn embed_texts(
63 &self,
64 documents: impl IntoIterator<Item = String>,
65 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
66 let documents = documents.into_iter().collect::<Vec<_>>();
67
68 let response = self
69 .client
70 .post("/embeddings")
71 .json(&json!({
72 "model": self.model,
73 "input": documents,
74 }))
75 .send()
76 .await?;
77
78 if response.status().is_success() {
79 match response.json::<ApiResponse<EmbeddingResponse>>().await? {
80 ApiResponse::Ok(response) => {
81 tracing::info!(target: "rig",
82 "OpenAI embedding token usage: {:?}",
83 response.usage
84 );
85
86 if response.data.len() != documents.len() {
87 return Err(EmbeddingError::ResponseError(
88 "Response data length does not match input length".into(),
89 ));
90 }
91
92 Ok(response
93 .data
94 .into_iter()
95 .zip(documents.into_iter())
96 .map(|(embedding, document)| embeddings::Embedding {
97 document,
98 vec: embedding.embedding,
99 })
100 .collect())
101 }
102 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
103 }
104 } else {
105 Err(EmbeddingError::ProviderError(response.text().await?))
106 }
107 }
108}
109
110impl EmbeddingModel {
111 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
112 Self {
113 client,
114 model: model.to_string(),
115 ndims,
116 }
117 }
118}