use async_trait::async_trait;
use reqwest::Client;
use seedframe::embeddings::{model::EmbeddingModel, EmbedderError};
use serde::{Deserialize, Serialize};
use serde_json::json;
const DEFAULT_API_KEY_VAR_NAME: &str = "VOYAGEAI_API_KEY";
const DEFAULT_URL: &str = "https://api.voyageai.com/v1/embeddings";
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
struct ModelConfig {
api_key_var: Option<String>,
api_url: Option<String>,
model: String,
}
pub struct VoyageAIEmbedding {
api_key: String,
api_url: String,
model: String,
client: Client,
}
impl VoyageAIEmbedding {
#[must_use]
pub fn new(json_config: Option<&str>) -> Self {
let (api_key_var, api_url, model) = if let Some(json) = json_config {
let config: ModelConfig = serde_json::from_str(json).unwrap();
(
config
.api_key_var
.unwrap_or(DEFAULT_API_KEY_VAR_NAME.to_string()),
config.api_url.unwrap_or(DEFAULT_URL.to_string()),
config.model,
)
} else {
panic!(
"VoyageAIEmbedding expects a config json with atleast the required model field!"
);
};
let api_key = std::env::var(&api_key_var)
.unwrap_or_else(|_| panic!("Failed to fetch env var `{api_key_var}` !"));
Self {
api_key,
api_url,
client: reqwest::Client::new(),
model,
}
}
}
#[derive(Deserialize)]
struct VoyageAIEmbeddingResponse {
pub data: Vec<VoyageAIEmbeddingData>,
}
#[derive(Deserialize)]
struct VoyageAIEmbeddingData {
pub embedding: Vec<f64>,
}
#[async_trait]
impl EmbeddingModel for VoyageAIEmbedding {
async fn embed(&self, data: &str) -> Result<Vec<f64>, EmbedderError> {
let request_body = json!({
"input": data,
"model": self.model,
});
let response = self
.client
.post(&self.api_url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request_body)
.send()
.await
.map_err(|e| EmbedderError::RequestError(e.to_string()))?;
if response.status().is_success() {
let response = response
.json::<VoyageAIEmbeddingResponse>()
.await
.map_err(|e| EmbedderError::ParseError(e.to_string()))?;
Ok(response
.data
.into_iter()
.flat_map(|d| d.embedding)
.collect())
} else {
let error_message = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
Err(EmbedderError::ProviderError(error_message))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
#[ignore]
async fn simple_openai_embed_request() {
let openai_embedding_model = VoyageAIEmbedding::new(None);
let response = openai_embedding_model.embed("test").await;
assert!(response.is_ok());
}
}