use async_openai::{config::OpenAIConfig, types::CreateEmbeddingRequestArgs, Client};
use serde::{Deserialize, Serialize};
use std::{error::Error, sync::Arc};
/// Identifies which embedding provider to use and carries its credentials.
#[derive(Debug, Clone)]
pub enum ModelType {
    /// OpenAI backend; the wrapped `String` is the API key.
    OpenAI(String),
}
/// An embedding model handle wrapping a configured OpenAI API client.
pub struct Model {
    // Arc allows this handle to be shared cheaply across tasks/threads;
    // the client itself holds the API configuration (key, base URL, etc.).
    client: Arc<Client<OpenAIConfig>>,
}
impl Model {
    /// Builds a `Model` for the provider described by `model_name`.
    ///
    /// For [`ModelType::OpenAI`], the wrapped `String` is the API key used
    /// to configure the underlying client.
    pub fn new(model_name: ModelType) -> Self {
        match model_name {
            ModelType::OpenAI(api_key) => Model {
                client: Arc::new(Client::with_config(
                    OpenAIConfig::new().with_api_key(api_key),
                )),
            },
        }
    }

    /// Requests an embedding vector for `input` using the
    /// `text-embedding-3-large` model.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be built, the API call fails,
    /// or the response contains no embedding data.
    pub async fn get_embedding(&self, input: &str) -> Result<Vec<f32>, Box<dyn Error>> {
        // Propagate builder errors with `?` instead of panicking: this
        // function already returns `Result`, so `unwrap()` here would turn
        // recoverable failures (bad request construction, network/API
        // errors) into process aborts.
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-large")
            .input([input])
            .build()?;

        let response = self.client.embeddings().create(request).await?;

        // Take ownership of the first embedding rather than indexing
        // (`data[0]` would panic on an empty response) and cloning.
        response
            .data
            .into_iter()
            .next()
            .map(|d| d.embedding)
            .ok_or_else(|| "embedding response contained no data".into())
    }
}