use reqwest::Client;
use serde_json::json;
use crate::response::ResponseError;
/// Embedding models that can be requested from the embeddings endpoint.
///
/// The variants carry no data, so the type is cheap to copy and can be
/// compared for equality or used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EmbeddingModel {
    /// The `text-embedding-3-small` model.
    TextEmbedding3Small,
    /// The `text-embedding-3-large` model.
    TextEmbedding3Large,
}
impl EmbeddingModel {
pub fn as_str(&self) -> &str {
match self {
EmbeddingModel::TextEmbedding3Small => "text-embedding-3-small",
EmbeddingModel::TextEmbedding3Large => "text-embedding-3-large",
}
}
pub fn dimensions(&self) -> usize {
match self {
EmbeddingModel::TextEmbedding3Small => 1536,
EmbeddingModel::TextEmbedding3Large => 3072,
}
}
}
/// An embedding vector together with the model identifier and the input text
/// it belongs to.
///
/// `PartialEq` is derived so values can be compared directly; `Eq` is not
/// available because the vector holds `f32` values.
#[derive(Debug, Clone, PartialEq)]
pub struct Embedding {
    /// The embedding values.
    pub embedding: Vec<f32>,
    /// Identifier string of the model associated with this embedding.
    pub model: String,
    /// The input text associated with this embedding.
    pub input: String,
}
impl Embedding {
pub fn new(input: String, model: EmbeddingModel) -> Self {
let embedding = vec![0.0; model.dimensions()];
Self {
embedding,
model: model.as_str().to_string(),
input
}
}
}
/// Requests an embedding for `input` from the OpenAI embeddings endpoint.
///
/// Sends a POST to `https://api.openai.com/v1/embeddings` with `input` and the
/// model identifier as the JSON body, using `api_key` as a bearer token, and
/// reads the vector from `data[0].embedding` in the JSON response.
///
/// # Errors
///
/// * [`ResponseError::NetworkError`] if sending the request fails.
/// * [`ResponseError::RequestError`] if the response status is not a success;
///   the message carries the response body (or a note that it was unreadable).
/// * [`ResponseError::ParseError`] if the body is not valid JSON, if
///   `data[0].embedding` is missing or not an array, if any element of that
///   array is not a number, or if the vector length differs from
///   `model.dimensions()`.
pub async fn get_embedding(
    input: &str,
    model: EmbeddingModel,
    api_key: &str,
) -> Result<Embedding, ResponseError> {
    let client = Client::new();
    let url = "https://api.openai.com/v1/embeddings";

    let response = client
        .post(url)
        // Let reqwest build the `Authorization: Bearer <key>` header.
        .bearer_auth(api_key)
        .json(&json!({
            "input": input,
            "model": model.as_str(),
        }))
        .send()
        .await
        .map_err(|e| ResponseError::NetworkError(e.to_string()))?;

    if !response.status().is_success() {
        // Best effort: if the error body itself cannot be read, report that instead.
        let error_text = response
            .text()
            .await
            .unwrap_or_else(|e| format!("Failed to get error response: {}", e));
        return Err(ResponseError::RequestError(format!(
            "API request failed: {}",
            error_text
        )));
    }

    let data = response
        .json::<serde_json::Value>()
        .await
        .map_err(|e| ResponseError::ParseError(e.to_string()))?;

    let embedding_data = data
        .get("data")
        .and_then(|d| d.get(0))
        .and_then(|d| d.get("embedding"))
        .and_then(|e| e.as_array())
        .ok_or_else(|| {
            ResponseError::ParseError("***Failed to extract embedding from response".to_string())
        })?;

    // Convert every element fallibly. Silently skipping non-numeric entries
    // would hide a malformed payload behind a misleading length mismatch.
    let embedding = embedding_data
        .iter()
        .enumerate()
        .map(|(index, value)| {
            value
                .as_f64()
                // Narrowing to f32 is intentional: `Embedding` stores `Vec<f32>`.
                .map(|number| number as f32)
                .ok_or_else(|| {
                    ResponseError::ParseError(format!(
                        "Embedding value at index {} is not a number",
                        index
                    ))
                })
        })
        .collect::<Result<Vec<f32>, ResponseError>>()?;

    if embedding.len() != model.dimensions() {
        return Err(ResponseError::ParseError(format!(
            "***Expected {} dimensions but got {}",
            model.dimensions(),
            embedding.len()
        )));
    }

    Ok(Embedding {
        embedding,
        model: model.as_str().to_string(),
        input: input.to_string(),
    })
}