langchain_rust/embedding/openai/
openai_embedder.rs1#![allow(dead_code)]
2
3use crate::embedding::{embedder_trait::Embedder, EmbedderError};
4pub use async_openai::config::{AzureConfig, Config, OpenAIConfig};
5use async_openai::{
6 types::{CreateEmbeddingRequestArgs, EmbeddingInput},
7 Client,
8};
9use async_trait::async_trait;
10
11#[derive(Debug)]
12pub struct OpenAiEmbedder<C: Config> {
13 config: C,
14 model: String,
15}
16
17impl<C: Config + Send + Sync + 'static> Into<Box<dyn Embedder>> for OpenAiEmbedder<C> {
18 fn into(self) -> Box<dyn Embedder> {
19 Box::new(self)
20 }
21}
22
23impl<C: Config> OpenAiEmbedder<C> {
24 pub fn new(config: C) -> Self {
25 OpenAiEmbedder {
26 config,
27 model: String::from("text-embedding-ada-002"),
28 }
29 }
30
31 pub fn with_model<S: Into<String>>(mut self, model: S) -> Self {
32 self.model = model.into();
33 self
34 }
35
36 pub fn with_config(mut self, config: C) -> Self {
37 self.config = config;
38 self
39 }
40}
41
42impl Default for OpenAiEmbedder<OpenAIConfig> {
43 fn default() -> Self {
44 OpenAiEmbedder::new(OpenAIConfig::default())
45 }
46}
47
48#[async_trait]
49impl<C: Config + Send + Sync> Embedder for OpenAiEmbedder<C> {
50 async fn embed_documents(&self, documents: &[String]) -> Result<Vec<Vec<f64>>, EmbedderError> {
51 let client = Client::with_config(self.config.clone());
52
53 let request = CreateEmbeddingRequestArgs::default()
54 .model(&self.model)
55 .input(EmbeddingInput::StringArray(documents.into()))
56 .build()?;
57
58 let response = client.embeddings().create(request).await?;
59
60 let embeddings = response
61 .data
62 .into_iter()
63 .map(|item| item.embedding)
64 .map(|embedding| {
65 embedding
66 .into_iter()
67 .map(|x| x as f64)
68 .collect::<Vec<f64>>()
69 })
70 .collect();
71
72 Ok(embeddings)
73 }
74
75 async fn embed_query(&self, text: &str) -> Result<Vec<f64>, EmbedderError> {
76 let client = Client::with_config(self.config.clone());
77
78 let request = CreateEmbeddingRequestArgs::default()
79 .model(&self.model)
80 .input(text)
81 .build()?;
82
83 let mut response = client.embeddings().create(request).await?;
84
85 let item = response.data.swap_remove(0);
86
87 Ok(item
88 .embedding
89 .into_iter()
90 .map(|x| x as f64)
91 .collect::<Vec<f64>>())
92 }
93}