1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#![allow(dead_code)]

use crate::embedding::{embedder_trait::Embedder, EmbedderError};
pub use async_openai::config::{AzureConfig, Config, OpenAIConfig};
use async_openai::{
    types::{CreateEmbeddingRequestArgs, EmbeddingInput},
    Client,
};
use async_trait::async_trait;

/// An [`Embedder`] that obtains vectors from the async-openai client's
/// embeddings API.
///
/// `C` is the client configuration type, e.g. [`OpenAIConfig`] or
/// [`AzureConfig`].
#[derive(Debug)]
pub struct OpenAiEmbedder<C>
where
    C: Config,
{
    /// Client configuration used to build the HTTP client for each request.
    config: C,
    /// Name of the embedding model sent with every request.
    model: String,
}

/// Boxes an [`OpenAiEmbedder`] as a `dyn Embedder` trait object.
///
/// Implemented as `From` rather than `Into` (clippy `from_over_into`): the
/// standard library's blanket `impl<T, U: From<T>> Into<U> for T` still makes
/// `embedder.into()` work, and callers additionally get
/// `Box::<dyn Embedder>::from(embedder)`.
impl<C: Config + Send + Sync + 'static> From<OpenAiEmbedder<C>> for Box<dyn Embedder> {
    fn from(embedder: OpenAiEmbedder<C>) -> Self {
        Box::new(embedder)
    }
}

impl<C> OpenAiEmbedder<C>
where
    C: Config,
{
    /// Creates an embedder for `config` that uses the
    /// `text-embedding-ada-002` model until changed with
    /// [`with_model`](Self::with_model).
    pub fn new(config: C) -> Self {
        let model = "text-embedding-ada-002".to_string();
        Self { config, model }
    }

    /// Returns this embedder with its model name replaced by `model`.
    ///
    /// Consumes `self` so calls can be chained builder-style.
    pub fn with_model<S: Into<String>>(self, model: S) -> Self {
        Self {
            model: model.into(),
            ..self
        }
    }

    /// Returns this embedder with its client configuration replaced by
    /// `config`.
    ///
    /// Consumes `self` so calls can be chained builder-style.
    pub fn with_config(self, config: C) -> Self {
        Self { config, ..self }
    }
}

/// Builds an embedder from `OpenAIConfig::default()` and the default model
/// chosen by [`OpenAiEmbedder::new`].
impl Default for OpenAiEmbedder<OpenAIConfig> {
    fn default() -> Self {
        let config = OpenAIConfig::default();
        Self::new(config)
    }
}

#[async_trait]
impl<C: Config + Send + Sync> Embedder for OpenAiEmbedder<C> {
    async fn embed_documents(&self, documents: &[String]) -> Result<Vec<Vec<f64>>, EmbedderError> {
        let client = Client::with_config(self.config.clone());

        let request = CreateEmbeddingRequestArgs::default()
            .model(&self.model)
            .input(EmbeddingInput::StringArray(documents.into()))
            .build()?;

        let response = client.embeddings().create(request).await?;

        let embeddings = response
            .data
            .into_iter()
            .map(|item| item.embedding)
            .map(|embedding| {
                embedding
                    .into_iter()
                    .map(|x| x as f64)
                    .collect::<Vec<f64>>()
            })
            .collect();

        Ok(embeddings)
    }

    async fn embed_query(&self, text: &str) -> Result<Vec<f64>, EmbedderError> {
        let client = Client::with_config(self.config.clone());

        let request = CreateEmbeddingRequestArgs::default()
            .model(&self.model)
            .input(text)
            .build()?;

        let mut response = client.embeddings().create(request).await?;

        let item = response.data.swap_remove(0);

        Ok(item
            .embedding
            .into_iter()
            .map(|x| x as f64)
            .collect::<Vec<f64>>())
    }
}