vectus/model/mod.rs
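
//! HTTP client for remote embedding models; currently this wraps the OpenAI
//! embeddings endpoint.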

use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::{error::Error, sync::Arc};

/// JSON request body sent to the embeddings endpoint.
#[derive(Serialize)]
struct EmbeddingRequest {
    input: String,
    model: String,
    encoding_format: String,
}

/// The subset of the embeddings response this module cares about: a top-level
/// `data` array of embedding objects. serde ignores any response fields that
/// are not declared here.
#[derive(Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>,
}

#[derive(Deserialize)]
struct EmbeddingData {
    embedding: Vec<f64>,
}

/// Supported embedding providers; each variant carries the provider's API key.
pub enum ModelType {
    OpenAI(String),
}

/// A remote embedding model: the model name, the endpoint URL, and the API key
/// used to authenticate requests.
pub struct Model {
    pub name: Arc<String>,
    pub url: Arc<String>,
    pub api_key: Arc<String>,
}

impl Model {
    /// Builds a `Model` for the given provider. The OpenAI variant currently
    /// targets `text-embedding-3-large` via the `/v1/embeddings` endpoint.
    pub fn new(model_type: ModelType) -> Self {
        match model_type {
            ModelType::OpenAI(api_key) => Model {
                name: "text-embedding-3-large".to_string().into(),
                url: "https://api.openai.com/v1/embeddings".to_string().into(),
                api_key: api_key.into(),
            },
        }
    }

    /// Requests an embedding for `input` and returns the first embedding
    /// vector in the response.
    pub async fn get_embedding(&self, input: &str) -> Result<Vec<f64>, Box<dyn Error>> {
        let request = EmbeddingRequest {
            input: input.to_owned(),
            model: self.name.to_string(),
            encoding_format: "float".to_string(),
        };

        // A fresh client is built per call; storing one on `Model` would let
        // reqwest reuse connections across requests.
        let client = Client::new();

        let response = client
            .post(self.url.as_str())
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&request)
            .send()
            .await?
            // Turn non-2xx responses into errors instead of failing later
            // during deserialization.
            .error_for_status()?
            .json::<EmbeddingResponse>()
            .await?;

        // Return an error rather than panicking if `data` comes back empty.
        let embedding = response
            .data
            .into_iter()
            .next()
            .ok_or("embedding response contained no data")?
            .embedding;

        Ok(embedding)
    }
}
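
// A minimal integration-test sketch showing how this module is intended to be
// used. It is only a sketch: it assumes tokio (with the `macros` and `rt`
// features) is available as a dev-dependency and that an OPENAI_API_KEY
// environment variable holds a valid key. The test is `#[ignore]`d because it
// needs network access and a real key; run it with `cargo test -- --ignored`.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore]
    async fn fetches_an_embedding() {
        // Read the key from the environment rather than hard-coding it.
        let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
        let model = Model::new(ModelType::OpenAI(api_key));

        // A short prompt is enough to exercise the request/response plumbing.
        let embedding = model.get_embedding("hello world").await.unwrap();
        assert!(!embedding.is_empty());
    }
}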