use crate::clients::open_ai::model::embeddings::{
BatchEmbeddingRequest, EmbeddingObject, EmbeddingRequest, EmbeddingResponse,
};
use crate::clients::open_ai::model::errors::OpenAIError;
use crate::clients::open_ai::open_ai_core::OpenAIHttpClient;
use crate::clients::traits::AsyncEmbeddingClient;
use crate::common::{Chunk, Chunks, Embedding, OpenAIEmbeddingModel};
use std::env::VarError;
/// Default endpoint that `OpenAIEmbeddingClient::try_new` points new clients at.
const OPENAI_EMBEDDING_URL: &str = "https://api.openai.com/v1/embeddings";

/// Client that turns text chunks into embeddings through the OpenAI embeddings endpoint.
pub struct OpenAIEmbeddingClient {
    /// HTTP client whose `send_request` performs the actual POST to `url`.
    client: OpenAIHttpClient,
    /// Model sent with every embedding request.
    embedding_model: OpenAIEmbeddingModel,
    /// Endpoint requests are sent to; the tests overwrite it with a mock server URL.
    url: String,
}
impl OpenAIEmbeddingClient {
    /// Creates a client for `embedding_model` that targets the public embeddings endpoint.
    ///
    /// # Errors
    ///
    /// Returns the [`VarError`] propagated from `OpenAIHttpClient::try_new`
    /// (presumably raised when the API-key environment variable is unavailable — confirm there).
    pub fn try_new(
        embedding_model: OpenAIEmbeddingModel,
    ) -> Result<OpenAIEmbeddingClient, VarError> {
        let http_client = OpenAIHttpClient::try_new()?;
        let embedding_client = OpenAIEmbeddingClient {
            client: http_client,
            embedding_model,
            url: String::from(OPENAI_EMBEDDING_URL),
        };
        Ok(embedding_client)
    }

    /// Pairs every input chunk with the embedding vector at the same position in `response.data`.
    ///
    /// Pairing is purely positional. If the two sequences differ in length, the surplus items on
    /// the longer side are dropped silently.
    fn handle_embedding_success_response(
        input_text: Chunks,
        response: EmbeddingResponse,
    ) -> Vec<Embedding> {
        let vectors: Vec<EmbeddingObject> = response.data;
        input_text
            .into_iter()
            .zip(vectors)
            .map(|(chunk, object)| Embedding::new(chunk, object.embedding))
            .collect()
    }
}
impl AsyncEmbeddingClient for OpenAIEmbeddingClient {
type ErrorType = OpenAIError;
async fn generate_embeddings(&self, text: Chunks) -> Result<Vec<Embedding>, OpenAIError> {
let input_text: Vec<String> = text
.iter()
.map(|chunk| (*chunk).content().to_string())
.collect();
let request_body = BatchEmbeddingRequest::builder()
.input(input_text)
.model(self.embedding_model)
.build();
let response: EmbeddingResponse = self.client.send_request(request_body, &self.url).await?;
Ok(Self::handle_embedding_success_response(text, response))
}
async fn generate_embedding(&self, text: Chunk) -> Result<Embedding, Self::ErrorType> {
let request_body = EmbeddingRequest::builder()
.input(text.content().to_string())
.model(self.embedding_model)
.build();
let response: EmbeddingResponse = self.client.send_request(request_body, &self.url).await?;
Ok(Self::handle_embedding_success_response(vec![text], response)[0].clone())
}
}
#[cfg(test)]
mod embedding_client_tests {
    use super::*;
    use crate::clients::open_ai::model::errors::{OpenAIErrorBody, OpenAIErrorData};
    use mockito::{Mock, Server, ServerGuard};
    use std::sync::Once;

    /// Canned success body: two identical 4-element embeddings at indices 0 and 1.
    const EMBEDDING_RESPONSE: &str = r#"
{
"data": [
{
"embedding": [
-0.006929283495992422,
-0.005336422007530928,
-0.009327292,
-0.024047505110502243
],
"index": 0,
"object": "embedding"
},
{
"embedding": [
-0.006929283495992422,
-0.005336422007530928,
-0.009327292,
-0.024047505110502243
],
"index": 1,
"object": "embedding"
}
],
"model": "text-embedding-ada-002",
"object": "list",
"usage": {
"prompt_tokens": 5,
"total_tokens": 5
}
}
"#;

    /// Canned 400 body reporting an invalid API key.
    const ERROR_RESPONSE: &str = r#"
{
"error": {
"message": "Incorrect API key provided: fdas. You can find your API key at https://platform.openai.com/account/api-keys.",
"type": "invalid_request_error",
"param": null,
"code": "invalid_api_key"
}
}
"#;

    /// Ensures the fake API key is written to the process environment only once.
    static SET_FAKE_API_KEY: Once = Once::new();

    /// A 200 response is mapped to one `Embedding` per input chunk, for both the batch and the
    /// single-chunk entry points.
    #[tokio::test]
    async fn test_correct_response_succeeds() {
        let (client, mut server) = with_mocked_client().await;
        let mock = with_mocked_request(&mut server, 200, EMBEDDING_RESPONSE);
        let expected_embedding = vec![
            -0.006929283495992422,
            -0.005336422007530928,
            -0.009327292,
            -0.024047505110502243,
        ];
        let chunks: Chunks = vec![Chunk::new("Test-0"), Chunk::new("Test-1")];
        let response = client.generate_embeddings(chunks).await.unwrap();
        mock.assert();
        for (i, embedding) in response.into_iter().enumerate() {
            assert_eq!(*embedding.chunk(), Chunk::new(format!("Test-{}", i)));
            assert_eq!(embedding.vector(), expected_embedding);
        }
        let chunk = Chunk::new("Test-0");
        let response = client.generate_embedding(chunk).await.unwrap();
        assert_eq!(*response.chunk(), Chunk::new("Test-0"));
        assert_eq!(response.vector(), expected_embedding);
    }

    /// A 400 response surfaces as `OpenAIError::CODE400` with the parsed error body.
    #[tokio::test]
    async fn test_400_gives_correct_error() {
        let (client, mut server) = with_mocked_client().await;
        let mock = with_mocked_request(&mut server, 400, ERROR_RESPONSE);
        let expected_response = OpenAIError::CODE400(OpenAIErrorBody {
            error: OpenAIErrorData {
                message: "Incorrect API key provided: fdas. You can find your API key at https://platform.openai.com/account/api-keys.".to_string(),
                error_type: "invalid_request_error".to_string(),
                param: None,
                code: "invalid_api_key".to_string(),
            },
        });
        let chunks: Chunks = vec![Chunk::new("Test-0"), Chunk::new("Test-1")];
        let response = client.generate_embeddings(chunks).await.unwrap_err();
        mock.assert();
        assert_eq!(response, expected_response);
        let chunk = Chunk::new("Test-0");
        let response = client.generate_embedding(chunk).await.unwrap_err();
        assert_eq!(response, expected_response);
    }

    /// Registers a mock `POST /` on `server` that replies with `status_code` and a JSON body.
    fn with_mocked_request(
        server: &mut ServerGuard,
        status_code: usize,
        response_body: &str,
    ) -> Mock {
        server
            .mock("POST", "/")
            .with_status(status_code)
            .with_header("content-type", "application/json")
            .with_body(response_body)
            .create()
    }

    /// Builds a client whose `url` points at a fresh mock server.
    async fn with_mocked_client() -> (OpenAIEmbeddingClient, ServerGuard) {
        // `cargo test` runs tests on parallel threads. Write the variable once instead of having
        // every test mutate the process environment concurrently.
        SET_FAKE_API_KEY.call_once(|| std::env::set_var("OPENAI_API_KEY", "fake key"));
        let server = Server::new_async().await;
        let url = server.url();
        let model = OpenAIEmbeddingModel::TextEmbeddingAda002;
        let mut client = OpenAIEmbeddingClient::try_new(model).unwrap();
        client.url = url;
        (client, server)
    }
}