use std::env;
use serde_derive::{Deserialize, Serialize};
use crate::openai::get_client;
// Input structures
// Function Calls
/// Request body for a tool/function-calling completion: the target model
/// plus the tool definitions the model may invoke.
#[derive(Debug, Serialize, Clone)]
pub struct Functions {
/// Name/identifier of the model to call.
pub model: String,
/// Tool definitions offered to the model.
pub tools: Vec<FunctionCall>,
}
/// A single tool entry: a type discriminator plus the function spec.
#[derive(Debug, Serialize, Clone)]
pub struct FunctionCall {
/// Tool kind, serialized as "type" (raw identifier avoids the Rust keyword).
pub r#type: String,
/// The function definition itself.
pub function: Function,
}
/// Declaration of one callable function: its name, a natural-language
/// description, and the schema of its parameters.
#[derive(Debug, Serialize, Clone)]
pub struct Function {
/// Function name the model refers to when calling it.
pub name: String,
/// Human-readable description shown to the model.
pub description: String,
/// JSON-Schema-style description of the accepted parameters.
pub parameters: Parameters,
// NOTE(review): `required` also exists inside `Parameters`; tool schemas
// normally carry it only inside the parameter schema — confirm the API
// accepts (or ignores) this top-level copy.
pub required: Vec<String>
}
/// JSON-Schema-style parameter block: a type tag (typically "object"),
/// the property definitions, and the list of required property names.
#[derive(Debug, Serialize, Clone)]
pub struct Parameters {
/// Schema type, serialized as "type".
pub r#type: String,
/// The property definitions for this schema.
pub properties: Properties,
/// Names of properties the model must supply.
pub required: Vec<String>
}
/// Property container of the function's parameter schema.
///
/// NOTE(review): JSON Schema `properties` is normally an object keyed by
/// parameter name; serializing a `Vec` here emits a JSON array under the
/// single key `function` — confirm this matches what the API expects.
#[derive(Debug, Serialize, Clone)]
pub struct Properties {
    // Removed `#[serde(rename = "function")]`: it renamed the field to its
    // own name, which is a no-op, so serialization is unchanged.
    pub function: Vec<ParameterType>,
}
/// Schema for a single parameter value.
#[derive(Debug, Serialize, Clone)]
pub struct ParameterType {
/// JSON type of the parameter (e.g. "string").
pub r#type: String,
// NOTE(review): JSON Schema `enum` is an array of allowed values; a single
// String here serializes as one scalar — confirm this is intended.
pub r#enum: String,
/// Description of the parameter, shown to the model.
pub description: String,
}
/// Deserialized response envelope carrying embedding data plus message
/// metadata.
///
/// Renamed from `FunctionResponse`: the file declared two structs with that
/// name, which is a compile error (E0428, duplicate definition). The
/// tool-use-shaped variant keeps the original name; any caller of this
/// embedding-shaped variant must use `MessageResponse`.
#[derive(Debug, Deserialize)]
pub struct MessageResponse {
    /// One entry per embedded input.
    pub data: Vec<EmbeddingData>,
    pub id: String,
    pub model: String,
    pub stop_reason: String,
    pub role: String,
}
/// A tool-use block returned by the model: which function it wants called
/// (`name`), a unique call `id`, and the argument values in `input`.
#[derive(Debug, Deserialize)]
pub struct FunctionResponse {
/// Content-block type tag (raw identifier avoids the Rust keyword).
pub r#type: String,
/// Unique identifier of this tool-use request.
pub id: String,
/// Name of the function the model wants invoked.
pub name: String,
/// Argument values supplied by the model.
pub input: Vec<ParameterValues>
}
#[derive(Debug, Deserialize)]
pub struct ParameterValues {
#[serde(rename = "parameter")]
pub parameter: String,
}
/// Requests embeddings for `input` from the endpoint in `GPT_EMBEDDING_URL`
/// and returns one embedding vector per input string, in order.
///
/// # Errors
/// Returns `Err` when the `GPT_EMBEDDING_URL` environment variable is unset,
/// the client cannot be created, the HTTP request fails, or the response body
/// does not parse as an `EmbeddingResponse`.
pub async fn call_embedding_model(model: &str, input: &[String]) -> Result<Vec<Vec<f32>>, Box<dyn std::error::Error + Send>> {
    // Resolve the endpoint. A missing variable is reported through the
    // function's Result instead of panicking (the previous `expect` aborted
    // the process even though callers already handle `Err`).
    let url: String = env::var("GPT_EMBEDDING_URL")
        .map_err(|e| -> Box<dyn std::error::Error + Send> { Box::new(e) })?;
    let client = get_client().await?;
    // Build the request payload.
    let embedding = Embedding {
        input: input.to_vec(),
        model: model.into(),
        // NOTE(review): hard-coded output dimensionality; 1536 was noted as an
        // alternative — confirm against the deployed embedding model.
        dimensions: 384,
    };
    // POST the payload, then decode the JSON response body.
    let res = client
        .post(url)
        .json(&embedding)
        .send()
        .await
        .map_err(|e| -> Box<dyn std::error::Error + Send> { Box::new(e) })?;
    let res: EmbeddingResponse = res
        .json()
        .await
        .map_err(|e| -> Box<dyn std::error::Error + Send> { Box::new(e) })?;
    // Extract just the embedding vectors, preserving input order.
    Ok(res.data.iter().map(|e| e.embedding.clone()).collect())
}
#[cfg(test)]
mod tests {
use super::*;
/// Cosine distance between two equal-length vectors: `1 - cos(a, b)`.
/// Ranges from 0 (same direction) to 2 (opposite direction).
///
/// # Panics
/// Panics if the slices differ in length.
pub fn cosine_dist(a: &[f32], b: &[f32]) -> f32 {
    assert!(a.len() == b.len());
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let ma: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let mb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    // Bug fix: the original wrote `1.0 - dot / ma * mb`, which parses as
    // `1.0 - (dot / ma) * mb`. The cosine denominator is the *product* of
    // the magnitudes, so the division must cover `ma * mb`.
    1.0 - dot / (ma * mb)
}
/// Integration smoke test: embeds two related questions and prints their
/// cosine distance. Requires `GPT_EMBEDDING_URL`, `GPT_EMBEDDING_VERSION`,
/// and network access, so it exercises the live endpoint.
#[tokio::test]
async fn test_call_embed() {
    let messages: Vec<String> = vec![
        "What is the meaning of life?".into(),
        "What is the purpose of death?".into(),
    ];
    let model: String = std::env::var("GPT_EMBEDDING_VERSION")
        .expect("GPT_EMBEDDING_VERSION not found in environment variables");
    match call_embedding_model(&model, &messages).await {
        // Print the distance for manual inspection; success is "no error".
        Ok(answer) => println!("{}", cosine_dist(&answer[0], &answer[1])),
        // `panic!` carries the error into the test failure output instead of
        // the uninformative `assert!(false)` (clippy: assertions_on_constants).
        Err(e) => panic!("embedding call failed: {e}"),
    }
}
}