use serde::{Deserialize, Serialize};
use std::error::Error;
/// JSON body POSTed to the embeddings endpoint by `Model::get_embedding`.
///
/// Field names are the wire format: serde's derived `Serialize` emits them as the JSON keys
/// in declaration order, so renaming or reordering fields changes the request body.
#[derive(Serialize)]
struct EmbeddingRequest {
/// Text to embed.
input: String,
/// Model identifier; `Model::get_embedding` fills this from `Model::name`.
model: String,
/// Requested vector encoding; `Model::get_embedding` always sends `"float"`.
encoding_format: String,
}
/// The part of the embeddings endpoint's JSON response that this module reads.
///
/// Only `data` is declared. Because there is no `#[serde(deny_unknown_fields)]`, serde
/// ignores any other keys in the response body.
#[derive(Deserialize)]
struct EmbeddingResponse {
/// Returned embeddings; `Model::get_embedding` uses only the first entry.
data: Vec<EmbeddingData>,
}
/// One entry of the response's `data` array.
#[derive(Deserialize)]
struct EmbeddingData {
/// The embedding vector, deserialized as `f64` components.
/// NOTE(review): assumes the server returns plain JSON numbers (matching the
/// `encoding_format: "float"` request); confirm against the API documentation.
embedding: Vec<f64>,
}
/// Embedding providers that [`Model::new`] knows how to configure.
///
/// The enum is a plain tag with no data, so it is `Copy` and can be compared, hashed and
/// debug-printed freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// OpenAI's embeddings API; `Model::new` pairs it with `text-embedding-3-large`.
    OpenAI,
}
/// Connection settings for an embeddings endpoint.
///
/// Usually built with [`Model::new`]. The fields are public, so callers can also set a
/// different model name or URL directly.
#[derive(Clone)]
pub struct Model {
    /// Model identifier sent as the `model` field of each request.
    pub name: String,
    /// Endpoint the embedding request is POSTed to.
    pub url: String,
    /// Secret sent as a bearer token. It is redacted from `Debug` output.
    pub api_key: String,
}

// Hand-written rather than derived so the API key never ends up in logs or panic messages.
impl std::fmt::Debug for Model {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Model")
            .field("name", &self.name)
            .field("url", &self.url)
            .field("api_key", &"<redacted>")
            .finish()
    }
}
impl Model {
    /// Builds a [`Model`] preconfigured for the given provider.
    ///
    /// For [`ModelType::OpenAI`] this selects the `text-embedding-3-large` model and the
    /// `https://api.openai.com/v1/embeddings` endpoint. `api_key` is stored as given and sent
    /// as a bearer token by [`Model::get_embedding`].
    pub fn new(model_name: ModelType, api_key: String) -> Self {
        match model_name {
            ModelType::OpenAI => Model {
                name: "text-embedding-3-large".to_string(),
                url: "https://api.openai.com/v1/embeddings".to_string(),
                api_key,
            },
        }
    }

    /// Requests an embedding vector for `input` from the configured endpoint.
    ///
    /// Sends a JSON POST to `self.url` with `input`, `self.name` as the model and
    /// `encoding_format: "float"`, authenticated with `self.api_key` as a bearer token. It
    /// returns the `embedding` of the first entry in the response's `data` array.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    ///
    /// * the request cannot be sent;
    /// * the server answers with a 4xx/5xx status;
    /// * the body does not deserialize into the expected shape;
    /// * the response's `data` array is empty.
    pub async fn get_embedding(&self, input: &str) -> Result<Vec<f64>, Box<dyn Error>> {
        let request = EmbeddingRequest {
            input: input.to_owned(),
            model: self.name.clone(),
            encoding_format: "float".to_string(),
        };
        // NOTE(review): a fresh Client (and its connection pool) is built on every call.
        // Consider keeping one on `Model` if this runs in a loop.
        let client = reqwest::Client::new();
        let response = client
            .post(&self.url)
            .bearer_auth(&self.api_key)
            .json(&request)
            .send()
            .await?
            // Report HTTP failures (bad key, rate limit, ...) directly. Otherwise the error
            // body would fail to parse below with a misleading "missing field" error.
            .error_for_status()?
            .json::<EmbeddingResponse>()
            .await?;
        // Move the first vector out instead of indexing, which panics on an empty array,
        // and instead of cloning it.
        response
            .data
            .into_iter()
            .next()
            .map(|item| item.embedding)
            .ok_or_else(|| "embedding response contained no data".into())
    }
}