use std::env;
use serde_derive::{Deserialize, Serialize};
use crate::openai::get_client;
/// Request payload for the embedding endpoint (OpenAI-compatible schema).
#[derive(Debug, Serialize, Clone)]
pub struct Embedding {
/// Batch of input texts to embed; one embedding is returned per entry.
pub input: Vec<String>,
/// Model identifier (the caller reads it from `GPT_EMBEDDING_VERSION`).
pub model: String,
/// Requested output dimensionality (hard-coded to 384 by the caller).
pub dimensions: usize,
}
/// Top-level response body returned by the embedding endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
/// One entry per input string; `index` ties each entry to its input.
pub data: Vec<EmbeddingData>,
/// Model that actually served the request (echoed by the API).
pub model: String,
}
/// A single embedding result within `EmbeddingResponse::data`.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
/// Object type tag from the API (e.g. "embedding") — not used by callers here.
pub object: String,
/// Position of the corresponding input in the request batch.
pub index: usize,
/// The embedding vector itself.
pub embedding: Vec<f32>,
}
/// Calls the embedding endpoint named by `GPT_EMBEDDING_URL` and returns one
/// embedding vector per input string, ordered to match `input`.
///
/// # Errors
/// Returns `Err` when `GPT_EMBEDDING_URL` is unset, when the client cannot be
/// built, when the HTTP request fails, or when the response body cannot be
/// decoded as `EmbeddingResponse`.
pub async fn call_embedding_model(model: &str, input: &[String]) -> Result<Vec<Vec<f32>>, Box<dyn std::error::Error + Send>> {
    // Local helper: box any sendable error into this function's error type,
    // replacing the repeated `map_err` closures.
    fn boxed<E: std::error::Error + Send + 'static>(e: E) -> Box<dyn std::error::Error + Send> {
        Box::new(e)
    }
    // A missing env var is a recoverable configuration error in a function
    // that returns Result — report it as Err instead of panicking.
    let url: String = env::var("GPT_EMBEDDING_URL").map_err(|_| {
        boxed(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            "GPT_EMBEDDING_URL not found in environment variables",
        ))
    })?;
    let client = get_client().await?;
    let embedding = Embedding {
        input: input.to_vec(),
        model: model.into(),
        // NOTE(review): dimensionality is hard-coded; confirm the endpoint honors it.
        dimensions: 384,
    };
    let res = client
        .post(url)
        .json(&embedding)
        .send()
        .await
        .map_err(boxed)?;
    let res: EmbeddingResponse = res.json().await.map_err(boxed)?;
    // The API reports an `index` per item because response order is not
    // guaranteed to match request order — sort before collecting, and move
    // the vectors out instead of cloning each one.
    let mut data = res.data;
    data.sort_unstable_by_key(|d| d.index);
    Ok(data.into_iter().map(|d| d.embedding).collect())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Cosine *distance* between two equal-length vectors: `1 - cos(a, b)`.
    /// 0.0 for parallel vectors, 1.0 for orthogonal, 2.0 for opposite.
    ///
    /// # Panics
    /// Panics if the slices differ in length. (Divides by zero if either
    /// vector has zero magnitude — acceptable for a test helper.)
    pub fn cosine_dist(a: &[f32], b: &[f32]) -> f32 {
        assert!(a.len() == b.len());
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let ma: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let mb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        // BUG FIX: `dot / ma * mb` parses as `(dot / ma) * mb`; both norms
        // belong in the denominator of the cosine similarity.
        1.0 - dot / (ma * mb)
    }

    /// Smoke test: needs GPT_EMBEDDING_URL / GPT_EMBEDDING_VERSION set and a
    /// reachable endpoint; panics (fails) on any error from the call.
    #[tokio::test]
    async fn test_call_embed() {
        let messages: Vec<String> = vec![
            "What is the meaining of life?".into(),
            "What is the purpose of death?".into(),
        ];
        let model: String = std::env::var("GPT_EMBEDDING_VERSION")
            .expect("GPT_EMBEDDING_VERSION not found in environment variables");
        match call_embedding_model(&model, &messages).await {
            // `assert!(true)` was a no-op; just report the distance on success.
            Ok(answer) => println!("{}", cosine_dist(&answer[0], &answer[1])),
            // Panic with context instead of `println!` + `assert!(false)`.
            Err(e) => panic!("call_embedding_model failed: {e}"),
        }
    }
}