ai_chain_openai/
embeddings.rs

1use std::sync::Arc;
2
3use async_openai::{
4    config::OpenAIConfig,
5    error::OpenAIError,
6    types::{CreateEmbeddingRequestArgs, EmbeddingInput},
7};
8use async_trait::async_trait;
9use ai_chain::traits::{self, EmbeddingsError};
10use thiserror::Error;
11
12pub struct Embeddings {
13    client: Arc<async_openai::Client<OpenAIConfig>>,
14    model: String,
15}
16
17#[derive(Debug, Error)]
18#[error(transparent)]
19pub enum OpenAIEmbeddingsError {
20    #[error(transparent)]
21    Client(#[from] OpenAIError),
22    #[error("Request to OpenAI embeddings API was successful but response is empty")]
23    EmptyResponse,
24}
25
26impl EmbeddingsError for OpenAIEmbeddingsError {}
27
28#[async_trait]
29impl traits::Embeddings for Embeddings {
30    type Error = OpenAIEmbeddingsError;
31
32    async fn embed_texts(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, Self::Error> {
33        let req = CreateEmbeddingRequestArgs::default()
34            .model(self.model.clone())
35            .input(EmbeddingInput::from(texts))
36            .build()?;
37        self.client
38            .embeddings()
39            .create(req)
40            .await
41            .map(|r| r.data.into_iter().map(|e| e.embedding).collect())
42            .map_err(|e| e.into())
43    }
44
45    async fn embed_query(&self, query: String) -> Result<Vec<f32>, Self::Error> {
46        let req = CreateEmbeddingRequestArgs::default()
47            .model(self.model.clone())
48            .input(EmbeddingInput::from(query))
49            .build()?;
50        self.client
51            .embeddings()
52            .create(req)
53            .await
54            .map(|r| r.data.into_iter())?
55            .map(|e| e.embedding)
56            .last()
57            .ok_or(OpenAIEmbeddingsError::EmptyResponse)
58    }
59}
60
61impl Default for Embeddings {
62    fn default() -> Self {
63        let client = Arc::new(async_openai::Client::<OpenAIConfig>::new());
64        Self {
65            client,
66            model: "text-embedding-ada-002".to_string(),
67        }
68    }
69}
70
71impl Embeddings {
72    pub fn for_client(client: async_openai::Client<OpenAIConfig>, model: &str) -> Self {
73        Self {
74            client: client.into(),
75            model: model.to_string(),
76        }
77    }
78}