ai_chain_openai/embeddings.rs

use std::sync::Arc;

use async_openai::{
    config::OpenAIConfig,
    error::OpenAIError,
    types::{CreateEmbeddingRequestArgs, EmbeddingInput},
};
use async_trait::async_trait;
use ai_chain::traits::{self, EmbeddingsError};
use thiserror::Error;

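/// OpenAI-backed embeddings provider: holds a shared `async_openai` client
/// and the name of the embedding model to request.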
pub struct Embeddings {
    client: Arc<async_openai::Client<OpenAIConfig>>,
    model: String,
}

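/// Errors that can occur while requesting embeddings from the OpenAI API.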
#[derive(Debug, Error)]
pub enum OpenAIEmbeddingsError {
    #[error(transparent)]
    Client(#[from] OpenAIError),
    #[error("Request to OpenAI embeddings API was successful but response is empty")]
    EmptyResponse,
}

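// Marker impl so `OpenAIEmbeddingsError` can serve as the associated error
// type of the `Embeddings` trait implementation below.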
impl EmbeddingsError for OpenAIEmbeddingsError {}

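// `embed_texts` sends the whole batch in a single embeddings request and
// returns one vector per response entry; `embed_query` embeds a single string
// and yields `EmptyResponse` if the API returns no data.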
#[async_trait]
impl traits::Embeddings for Embeddings {
    type Error = OpenAIEmbeddingsError;

    async fn embed_texts(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, Self::Error> {
        let req = CreateEmbeddingRequestArgs::default()
            .model(self.model.clone())
            .input(EmbeddingInput::from(texts))
            .build()?;
        self.client
            .embeddings()
            .create(req)
            .await
            .map(|r| r.data.into_iter().map(|e| e.embedding).collect())
            .map_err(|e| e.into())
    }

    async fn embed_query(&self, query: String) -> Result<Vec<f32>, Self::Error> {
        let req = CreateEmbeddingRequestArgs::default()
            .model(self.model.clone())
            .input(EmbeddingInput::from(query))
            .build()?;
        self.client
            .embeddings()
            .create(req)
            .await
            .map(|r| r.data.into_iter())?
            .map(|e| e.embedding)
            .last()
            .ok_or(OpenAIEmbeddingsError::EmptyResponse)
    }
}

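// The default configuration uses a client built from the environment (the
// default `async_openai` client reads `OPENAI_API_KEY`) and the
// `text-embedding-ada-002` model.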
impl Default for Embeddings {
    fn default() -> Self {
        let client = Arc::new(async_openai::Client::<OpenAIConfig>::new());
        Self {
            client,
            model: "text-embedding-ada-002".to_string(),
        }
    }
}

impl Embeddings {
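    /// Builds an `Embeddings` instance from an existing `async_openai` client
    /// and the name of the embedding model to use.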
    pub fn for_client(client: async_openai::Client<OpenAIConfig>, model: &str) -> Self {
        Self {
            client: client.into(),
            model: model.to_string(),
        }
    }
}
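
// --- Usage sketch (not part of the original file) ---
// A minimal example of how this module might be called from application code.
// It assumes `OPENAI_API_KEY` is set in the environment (which the default
// `async_openai` client expects) and that an async runtime such as tokio
// drives the future; the function name is purely illustrative.
#[allow(dead_code)]
async fn embeddings_usage_example() -> Result<(), OpenAIEmbeddingsError> {
    // Bring the trait into scope so its methods are callable.
    use ai_chain::traits::Embeddings as _;

    let embeddings = Embeddings::default();

    // Embed a batch of documents: one vector comes back per input text.
    let vectors = embeddings
        .embed_texts(vec!["first document".to_string(), "second document".to_string()])
        .await?;
    assert_eq!(vectors.len(), 2);

    // Embed a single query string.
    let query = embeddings
        .embed_query("what is in the first document?".to_string())
        .await?;
    println!("query embedding has {} dimensions", query.len());
    Ok(())
}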