// langchain_rust/embedding/openai/openai_embedder.rs

1#![allow(dead_code)]
2
3use crate::embedding::{embedder_trait::Embedder, EmbedderError};
4pub use async_openai::config::{AzureConfig, Config, OpenAIConfig};
5use async_openai::{
6    types::{CreateEmbeddingRequestArgs, EmbeddingInput},
7    Client,
8};
9use async_trait::async_trait;
10
/// An [`Embedder`] backed by the embeddings endpoint of `async_openai`.
///
/// `config` is the client configuration (any `async_openai` `Config`, e.g. the
/// re-exported `OpenAIConfig` or `AzureConfig`). `model` is the model name sent
/// with every embedding request.
///
/// The `Config` bound lives on the `impl` blocks that need it rather than on the
/// struct itself, so the type can be named without repeating the bound.
#[derive(Debug)]
pub struct OpenAiEmbedder<C> {
    config: C,
    model: String,
}
16
17impl<C: Config + Send + Sync + 'static> Into<Box<dyn Embedder>> for OpenAiEmbedder<C> {
18    fn into(self) -> Box<dyn Embedder> {
19        Box::new(self)
20    }
21}
22
23impl<C: Config> OpenAiEmbedder<C> {
24    pub fn new(config: C) -> Self {
25        OpenAiEmbedder {
26            config,
27            model: String::from("text-embedding-ada-002"),
28        }
29    }
30
31    pub fn with_model<S: Into<String>>(mut self, model: S) -> Self {
32        self.model = model.into();
33        self
34    }
35
36    pub fn with_config(mut self, config: C) -> Self {
37        self.config = config;
38        self
39    }
40}
41
42impl Default for OpenAiEmbedder<OpenAIConfig> {
43    fn default() -> Self {
44        OpenAiEmbedder::new(OpenAIConfig::default())
45    }
46}
47
48#[async_trait]
49impl<C: Config + Send + Sync> Embedder for OpenAiEmbedder<C> {
50    async fn embed_documents(&self, documents: &[String]) -> Result<Vec<Vec<f64>>, EmbedderError> {
51        let client = Client::with_config(self.config.clone());
52
53        let request = CreateEmbeddingRequestArgs::default()
54            .model(&self.model)
55            .input(EmbeddingInput::StringArray(documents.into()))
56            .build()?;
57
58        let response = client.embeddings().create(request).await?;
59
60        let embeddings = response
61            .data
62            .into_iter()
63            .map(|item| item.embedding)
64            .map(|embedding| {
65                embedding
66                    .into_iter()
67                    .map(|x| x as f64)
68                    .collect::<Vec<f64>>()
69            })
70            .collect();
71
72        Ok(embeddings)
73    }
74
75    async fn embed_query(&self, text: &str) -> Result<Vec<f64>, EmbedderError> {
76        let client = Client::with_config(self.config.clone());
77
78        let request = CreateEmbeddingRequestArgs::default()
79            .model(&self.model)
80            .input(text)
81            .build()?;
82
83        let mut response = client.embeddings().create(request).await?;
84
85        let item = response.data.swap_remove(0);
86
87        Ok(item
88            .embedding
89            .into_iter()
90            .map(|x| x as f64)
91            .collect::<Vec<f64>>())
92    }
93}