use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::{error::Error, sync::Arc};
/// JSON body POSTed to the embeddings endpoint.
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    /// Text to embed.
    input: String,
    /// Model identifier, e.g. `text-embedding-3-large`.
    model: String,
    /// Requested vector encoding; `Model::get_embedding` sends `"float"`.
    encoding_format: String,
}
/// Subset of the embeddings API response that this client reads.
///
/// Only `data` is declared; any other fields in the JSON body are ignored by serde.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    /// Embedding entries returned by the API; `Model::get_embedding` uses only the first.
    data: Vec<EmbeddingData>,
}
/// A single entry of `EmbeddingResponse::data`.
#[derive(Debug, Deserialize)]
struct EmbeddingData {
    /// The embedding vector, decoded from a JSON array of numbers.
    embedding: Vec<f64>,
}
/// Embedding provider selection passed to `Model::new`.
#[derive(Clone)]
pub enum ModelType {
    /// OpenAI embeddings API; the wrapped string is the API key.
    OpenAI(String),
}

// `Debug` is written by hand (instead of derived) so the API key is never printed.
impl std::fmt::Debug for ModelType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ModelType::OpenAI(_) => f.debug_tuple("OpenAI").field(&"<redacted>").finish(),
        }
    }
}
/// Connection settings for an embeddings endpoint.
///
/// Every field is reference-counted, so cloning a `Model` only bumps refcounts.
#[derive(Clone)]
pub struct Model {
    /// Model identifier sent in the request body (e.g. `text-embedding-3-large`).
    pub name: Arc<String>,
    /// Endpoint URL the embedding request is POSTed to.
    pub url: Arc<String>,
    /// Secret API key, sent as a bearer token.
    pub api_key: Arc<String>,
}

// `Debug` is written by hand (instead of derived) so the API key never ends up in logs.
impl std::fmt::Debug for Model {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Model")
            .field("name", &self.name)
            .field("url", &self.url)
            .field("api_key", &"<redacted>")
            .finish()
    }
}
impl Model {
pub fn new(model_name: ModelType) -> Self {
match model_name {
ModelType::OpenAI(api_key) => Model {
name: "text-embedding-3-large".to_string().into(),
url: "https://api.openai.com/v1/embeddings".to_string().into(),
api_key: api_key.into(),
},
}
}
pub async fn get_embedding(&self, input: &String) -> Result<Vec<f64>, Box<dyn Error>> {
let request = EmbeddingRequest {
input: input.clone(),
model: self.name.to_string(),
encoding_format: "float".to_string(),
};
let client = Client::new();
let response = client
.post(&*self.url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&request)
.send()
.await?
.json::<EmbeddingResponse>()
.await?;
Ok(response.data[0].embedding.clone())
}
}