use async_trait::async_trait;
use reqwest::Client;
use secrecy::{ExposeSecret, SecretString};
use serde_json::{json, Value};
use cognis_core::embeddings::Embeddings;
use cognis_core::error::{CognisError, Result};
/// Builder for [`VoyageEmbeddings`].
///
/// Every field is optional; `build` fills in the API key from the environment
/// and the model with a default when they are not set.
#[derive(Debug)]
pub struct VoyageEmbeddingsBuilder {
// Explicit API key; when `None`, `build` falls back to environment variables.
api_key: Option<SecretString>,
// Model name; `build` uses "voyage-3" when this is `None`.
model: Option<String>,
// Default `input_type` for requests that do not supply their own.
input_type: Option<String>,
}
impl VoyageEmbeddingsBuilder {
    /// Creates an empty builder: no API key, no model, no default input type.
    pub fn new() -> Self {
        Self {
            api_key: None,
            model: None,
            input_type: None,
        }
    }

    /// Sets the API key explicitly.
    ///
    /// When a key is set here, [`build`](Self::build) does not consult any
    /// environment variable.
    #[must_use]
    pub fn api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = Some(SecretString::from(key.into()));
        self
    }

    /// Sets the model name sent with every request (defaults to `"voyage-3"`).
    #[must_use]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the default `input_type` for requests.
    ///
    /// An input type supplied per call takes precedence over this value.
    #[must_use]
    pub fn input_type(mut self, input_type: impl Into<String>) -> Self {
        self.input_type = Some(input_type.into());
        self
    }

    /// Builds a [`VoyageEmbeddings`] client.
    ///
    /// The API key is resolved from, in order: the value passed to
    /// [`api_key`](Self::api_key), the `VOYAGE_API_KEY` env var, then the
    /// `ANTHROPIC_API_KEY` env var. Env vars that are empty or not valid Unicode
    /// are treated as unset. A blank `VOYAGE_API_KEY=` therefore falls through to
    /// the next source (or fails here) instead of yielding a client that sends an
    /// empty bearer token.
    ///
    /// # Errors
    ///
    /// Returns [`CognisError::Other`] when no API key can be resolved.
    pub fn build(self) -> Result<VoyageEmbeddings> {
        let api_key = match self.api_key {
            Some(key) => key,
            None => SecretString::from(Self::api_key_from_env()?),
        };
        Ok(VoyageEmbeddings {
            api_key,
            model: self.model.unwrap_or_else(|| "voyage-3".into()),
            input_type: self.input_type,
            client: Client::new(),
        })
    }

    /// Returns the first non-empty value of `VOYAGE_API_KEY`, then `ANTHROPIC_API_KEY`.
    fn api_key_from_env() -> Result<String> {
        // NOTE(review): the ANTHROPIC_API_KEY fallback causes that credential to be
        // sent as a bearer token to the Voyage AI endpoint. Kept for backward
        // compatibility; confirm this is intended.
        ["VOYAGE_API_KEY", "ANTHROPIC_API_KEY"]
            .iter()
            .find_map(|&name| std::env::var(name).ok().filter(|value| !value.is_empty()))
            .ok_or_else(|| {
                CognisError::Other(
                    "api_key not provided and neither VOYAGE_API_KEY nor ANTHROPIC_API_KEY env var is set to a non-empty value".into(),
                )
            })
    }
}
impl Default for VoyageEmbeddingsBuilder {
fn default() -> Self {
Self::new()
}
}
/// Voyage AI embeddings client implementing [`Embeddings`].
///
/// Construct it via [`VoyageEmbeddings::builder`]. The API key is held as a
/// [`SecretString`] and is not included in this type's `Debug` output.
pub struct VoyageEmbeddings {
// Secret API key, sent as a bearer token on each request.
api_key: SecretString,
/// Model name placed in the request payload (e.g. `"voyage-3"`).
pub model: String,
/// Default `input_type` for requests that do not pass one; when neither is set,
/// the field is omitted from the payload.
pub input_type: Option<String>,
// HTTP client created once at build time and reused for every request.
client: Client,
}
impl std::fmt::Debug for VoyageEmbeddings {
    /// Formats only the public configuration.
    ///
    /// `api_key` and `client` are deliberately omitted so the secret never appears in
    /// `Debug` output. `finish_non_exhaustive` renders a trailing `..` to make the
    /// omission visible rather than implying the struct has only these two fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VoyageEmbeddings")
            .field("model", &self.model)
            .field("input_type", &self.input_type)
            .finish_non_exhaustive()
    }
}
impl VoyageEmbeddings {
pub fn builder() -> VoyageEmbeddingsBuilder {
VoyageEmbeddingsBuilder::new()
}
fn build_payload(&self, texts: &[String], input_type: Option<&str>) -> Value {
let mut payload = json!({
"model": self.model,
"input": texts,
});
let effective_input_type = input_type.or(self.input_type.as_deref());
if let Some(it) = effective_input_type {
payload["input_type"] = json!(it);
}
payload
}
async fn call_api(
&self,
texts: Vec<String>,
input_type: Option<&str>,
) -> Result<Vec<Vec<f32>>> {
let url = "https://api.voyageai.com/v1/embeddings";
let payload = self.build_payload(&texts, input_type);
let response = self
.client
.post(url)
.header(
"Authorization",
format!("Bearer {}", self.api_key.expose_secret()),
)
.header("Content-Type", "application/json")
.json(&payload)
.send()
.await
.map_err(|e| CognisError::Other(format!("HTTP request failed: {}", e)))?;
let status = response.status().as_u16();
if !(200..300).contains(&status) {
let body = response.text().await.unwrap_or_default();
return Err(CognisError::HttpError { status, body });
}
let body: Value = response
.json()
.await
.map_err(|e| CognisError::Other(format!("Failed to parse response JSON: {}", e)))?;
Self::parse_response(&body)
}
fn parse_response(body: &Value) -> Result<Vec<Vec<f32>>> {
let data = body.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
CognisError::Other("Missing 'data' array in Voyage AI embeddings response".into())
})?;
let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(data.len());
for item in data {
let embedding = item
.get("embedding")
.and_then(|v| v.as_array())
.ok_or_else(|| {
CognisError::Other("Missing 'embedding' array in response data item".into())
})?;
let vec: Vec<f32> = embedding
.iter()
.map(|v| {
v.as_f64().map(|f| f as f32).ok_or_else(|| {
CognisError::Other("Non-numeric value in embedding array".into())
})
})
.collect::<Result<Vec<f32>>>()?;
embeddings.push(vec);
}
Ok(embeddings)
}
}
#[async_trait]
impl Embeddings for VoyageEmbeddings {
    /// Embeds a batch of documents with Voyage's `"document"` input type.
    ///
    /// An empty batch returns an empty result immediately, without a network call.
    async fn embed_documents(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            Ok(Vec::new())
        } else {
            self.call_api(texts, Some("document")).await
        }
    }

    /// Embeds a single query string with Voyage's `"query"` input type.
    ///
    /// # Errors
    ///
    /// Propagates any error from the API call. Returns `CognisError::Other` if the
    /// response contains no embedding at all.
    async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
        let batch = vec![text.to_owned()];
        let mut vectors = self.call_api(batch, Some("query")).await?.into_iter();
        vectors
            .next()
            .ok_or_else(|| CognisError::Other("Empty embedding response for query".into()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that read or mutate the API-key env vars.
    ///
    /// The test harness runs tests on parallel threads, and env vars are
    /// process-global. Without this lock, `test_builder_requires_api_key` could
    /// observe a key set by a concurrently running test.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Env vars consulted by `VoyageEmbeddingsBuilder::build`.
    const KEY_VARS: [&str; 2] = ["VOYAGE_API_KEY", "ANTHROPIC_API_KEY"];

    /// Holds `ENV_LOCK` for the duration of a test.
    ///
    /// On drop it restores the original values of `KEY_VARS`, so tests neither race
    /// each other nor leak env state into the rest of the test binary.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        fn acquire() -> Self {
            // A failing test poisons the lock; the guarded data is `()`, so it is safe to recover.
            let lock = ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved = KEY_VARS
                .iter()
                .map(|&name| (name, std::env::var_os(name)))
                .collect();
            Self { saved, _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Runs before `_lock` is released, so restoration happens under the lock.
            for (name, value) in &self.saved {
                match value {
                    Some(v) => std::env::set_var(name, v),
                    None => std::env::remove_var(name),
                }
            }
        }
    }

    #[test]
    fn test_builder_defaults() {
        let embeddings = VoyageEmbeddings::builder()
            .api_key("test-key")
            .build()
            .unwrap();
        assert_eq!(embeddings.model, "voyage-3");
        assert!(embeddings.input_type.is_none());
    }

    #[test]
    fn test_builder_custom_values() {
        let embeddings = VoyageEmbeddings::builder()
            .api_key("test-key")
            .model("voyage-3-lite")
            .input_type("query")
            .build()
            .unwrap();
        assert_eq!(embeddings.model, "voyage-3-lite");
        assert_eq!(embeddings.input_type, Some("query".to_string()));
    }

    #[test]
    fn test_builder_requires_api_key() {
        let _env = EnvGuard::acquire();
        std::env::remove_var("VOYAGE_API_KEY");
        std::env::remove_var("ANTHROPIC_API_KEY");
        let result = VoyageEmbeddings::builder().build();
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("VOYAGE_API_KEY"));
        assert!(err.contains("ANTHROPIC_API_KEY"));
    }

    #[test]
    fn test_explicit_api_key_overrides_env() {
        let _env = EnvGuard::acquire();
        std::env::set_var("VOYAGE_API_KEY", "env-voyage-key");
        let embeddings = VoyageEmbeddings::builder()
            .api_key("explicit-key")
            .build()
            .unwrap();
        assert_eq!(embeddings.api_key.expose_secret(), "explicit-key");
    }

    #[test]
    fn test_build_payload_for_query() {
        let embeddings = VoyageEmbeddings::builder()
            .api_key("test-key")
            .build()
            .unwrap();
        let texts = vec!["what is machine learning?".to_string()];
        let payload = embeddings.build_payload(&texts, Some("query"));
        assert_eq!(payload["model"], "voyage-3");
        assert_eq!(payload["input"], json!(["what is machine learning?"]));
        assert_eq!(payload["input_type"], "query");
    }

    #[test]
    fn test_build_payload_for_documents() {
        let embeddings = VoyageEmbeddings::builder()
            .api_key("test-key")
            .build()
            .unwrap();
        let texts = vec!["hello".to_string(), "world".to_string()];
        let payload = embeddings.build_payload(&texts, Some("document"));
        assert_eq!(payload["model"], "voyage-3");
        assert_eq!(payload["input"], json!(["hello", "world"]));
        assert_eq!(payload["input_type"], "document");
    }

    #[test]
    fn test_parse_response() {
        let body = json!({
            "data": [
                {"embedding": [0.1, 0.2, 0.3], "index": 0},
                {"embedding": [0.4, 0.5, 0.6], "index": 1}
            ],
            "model": "voyage-3",
            "usage": {"total_tokens": 10}
        });
        let result = VoyageEmbeddings::parse_response(&body).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].len(), 3);
        assert!((result[0][0] - 0.1).abs() < 1e-6);
        assert!((result[1][2] - 0.6).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_embed_documents_empty() {
        let embeddings = VoyageEmbeddings::builder()
            .api_key("test-key")
            .build()
            .unwrap();
        let result = embeddings.embed_documents(vec![]).await.unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_api_key_from_env_voyage() {
        let _env = EnvGuard::acquire();
        std::env::set_var("VOYAGE_API_KEY", "env-voyage-key");
        std::env::remove_var("ANTHROPIC_API_KEY");
        let embeddings = VoyageEmbeddings::builder().build().unwrap();
        assert_eq!(embeddings.model, "voyage-3");
        assert_eq!(embeddings.api_key.expose_secret(), "env-voyage-key");
    }

    #[test]
    fn test_api_key_from_env_anthropic_fallback() {
        let _env = EnvGuard::acquire();
        std::env::remove_var("VOYAGE_API_KEY");
        std::env::set_var("ANTHROPIC_API_KEY", "env-anthropic-key");
        let embeddings = VoyageEmbeddings::builder().build().unwrap();
        assert_eq!(embeddings.model, "voyage-3");
        assert_eq!(embeddings.api_key.expose_secret(), "env-anthropic-key");
    }

    #[test]
    fn test_voyage_env_key_takes_precedence() {
        let _env = EnvGuard::acquire();
        std::env::set_var("VOYAGE_API_KEY", "env-voyage-key");
        std::env::set_var("ANTHROPIC_API_KEY", "env-anthropic-key");
        let embeddings = VoyageEmbeddings::builder().build().unwrap();
        assert_eq!(embeddings.api_key.expose_secret(), "env-voyage-key");
    }

    #[test]
    fn test_custom_model_name() {
        let embeddings = VoyageEmbeddings::builder()
            .api_key("test-key")
            .model("voyage-code-3")
            .build()
            .unwrap();
        assert_eq!(embeddings.model, "voyage-code-3");
        let payload = embeddings.build_payload(&["code snippet".to_string()], Some("document"));
        assert_eq!(payload["model"], "voyage-code-3");
    }

    #[test]
    fn test_debug_does_not_leak_api_key() {
        let embeddings = VoyageEmbeddings::builder()
            .api_key("super-secret-key")
            .build()
            .unwrap();
        let debug_str = format!("{:?}", embeddings);
        assert!(!debug_str.contains("super-secret-key"));
        assert!(debug_str.contains("VoyageEmbeddings"));
        assert!(debug_str.contains("voyage-3"));
    }

    #[test]
    fn test_build_payload_without_input_type() {
        let embeddings = VoyageEmbeddings::builder()
            .api_key("test-key")
            .build()
            .unwrap();
        let texts = vec!["hello".to_string()];
        let payload = embeddings.build_payload(&texts, None);
        assert_eq!(payload["model"], "voyage-3");
        assert_eq!(payload["input"], json!(["hello"]));
        assert!(payload.get("input_type").is_none());
    }

    #[test]
    fn test_build_payload_with_builder_input_type_default() {
        let embeddings = VoyageEmbeddings::builder()
            .api_key("test-key")
            .input_type("document")
            .build()
            .unwrap();
        let texts = vec!["hello".to_string()];
        let payload = embeddings.build_payload(&texts, None);
        assert_eq!(payload["input_type"], "document");
        let payload = embeddings.build_payload(&texts, Some("query"));
        assert_eq!(payload["input_type"], "query");
    }

    #[test]
    fn test_parse_response_missing_data() {
        let body = json!({"error": "something"});
        let result = VoyageEmbeddings::parse_response(&body);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("data"));
    }

    #[test]
    fn test_parse_response_missing_embedding() {
        let body = json!({"data": [{"index": 0}]});
        let err = VoyageEmbeddings::parse_response(&body).unwrap_err().to_string();
        assert!(err.contains("embedding"));
    }

    #[test]
    fn test_parse_response_non_numeric_value() {
        let body = json!({"data": [{"embedding": [0.1, "oops"], "index": 0}]});
        let err = VoyageEmbeddings::parse_response(&body).unwrap_err().to_string();
        assert!(err.contains("Non-numeric"));
    }
}