use super::provider::{DocumentToRerank, RerankResult, Reranker};
use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Reranker backed by the Voyage AI `/rerank` HTTP endpoint.
///
/// `Debug` is implemented by hand (instead of derived) so that formatting a
/// reranker with `{:?}` never writes the API key into logs or error output.
pub struct VoyageReranker {
    /// Secret bearer token sent in the `Authorization` header.
    api_key: String,
    /// Model identifier sent in every request body (e.g. `rerank-2.5`).
    model: String,
    /// API root; `/rerank` is appended to it for each request.
    base_url: String,
    /// HTTP client reused for every request made by this instance.
    client: reqwest::Client,
}

impl std::fmt::Debug for VoyageReranker {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VoyageReranker")
            // Never expose the credential, even in debug output.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("client", &self.client)
            .finish()
    }
}
impl VoyageReranker {
    /// Creates a reranker that calls `model` on the Voyage API rooted at `base_url`.
    ///
    /// Any trailing `/` characters are stripped from `base_url`, so that
    /// appending `/rerank` later cannot produce a double slash. Each instance
    /// owns its own freshly built `reqwest::Client`.
    pub fn new(api_key: &str, model: &str, base_url: &str) -> Self {
        let normalized_base = base_url.trim_end_matches('/');
        let http = reqwest::Client::new();
        Self {
            api_key: String::from(api_key),
            model: String::from(model),
            base_url: String::from(normalized_base),
            client: http,
        }
    }
}
#[async_trait]
impl Reranker for VoyageReranker {
    fn model_name(&self) -> &str {
        &self.model
    }

    /// Scores `documents` against `query` via `POST {base_url}/rerank`.
    ///
    /// Returns one [`RerankResult`] per entry in the API's `data` array, in the
    /// order the API returned them. Each entry's `index` is mapped back to the
    /// caller's document `id`. An empty `documents` list returns `Ok(vec![])`
    /// without making a network call.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent, if the API answers with
    /// a non-success status (the response body is included in the message),
    /// if the response body cannot be parsed, or if the response references a
    /// document index outside the submitted list.
    async fn rerank(
        &self,
        query: &str,
        documents: Vec<DocumentToRerank>,
    ) -> Result<Vec<RerankResult>> {
        if documents.is_empty() {
            return Ok(vec![]);
        }

        let doc_contents: Vec<&str> = documents.iter().map(|d| d.content.as_str()).collect();
        let request = VoyageRerankRequest {
            model: &self.model,
            query,
            documents: &doc_contents,
            // Only indices and scores are used below; don't ask for the texts back.
            return_documents: false,
        };

        let url = format!("{}/rerank", self.base_url);
        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await
            .context("Failed to send rerank request to Voyage AI")?;

        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable error body must not hide the status code.
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Voyage rerank API error ({}): {}", status, error_text);
        }

        let response: VoyageRerankResponse = response
            .json()
            .await
            .context("Failed to parse Voyage rerank response")?;

        // `index` comes from the remote service. Look it up with `get` instead
        // of indexing directly, so a malformed response becomes an error
        // rather than a panic.
        response
            .data
            .into_iter()
            .map(|d| {
                documents
                    .get(d.index)
                    .map(|doc| RerankResult {
                        id: doc.id,
                        relevance_score: d.relevance_score,
                    })
                    .with_context(|| {
                        format!(
                            "Voyage rerank response referenced document index {} but only {} documents were sent",
                            d.index,
                            documents.len()
                        )
                    })
            })
            .collect::<Result<Vec<_>>>()
    }
}
/// JSON request body for the Voyage `/rerank` endpoint.
///
/// All fields borrow from the caller, so building a request copies no text.
#[derive(Debug, Serialize)]
struct VoyageRerankRequest<'a> {
    /// Reranker model identifier.
    model: &'a str,
    /// Query the documents are scored against.
    query: &'a str,
    /// Document texts. The response's `index` values are treated as positions in this slice.
    documents: &'a [&'a str],
    /// Left out of the JSON entirely when `false`: serde skips the field
    /// whenever `Not::not(&value)` is true.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    return_documents: bool,
}
/// Successful JSON response body from the Voyage `/rerank` endpoint.
#[derive(Debug, Deserialize)]
struct VoyageRerankResponse {
    /// Scored entries, each pointing back to a submitted document by position.
    data: Vec<VoyageRerankData>,
    // Deserialized (and allowed to be absent) but never read in this file,
    // hence `dead_code` is allowed. NOTE(review): presumably kept for future
    // usage accounting — confirm before removing.
    #[allow(dead_code)]
    usage: Option<VoyageRerankUsage>,
}
/// One scored document in a rerank response.
#[derive(Debug, Deserialize)]
struct VoyageRerankData {
    /// Zero-based position of the scored document in the request's `documents` list.
    index: usize,
    /// Relevance score the API assigned to that document for the query.
    relevance_score: f64,
}
/// `usage` object from a rerank response; only `total_tokens` is deserialized.
#[derive(Debug, Deserialize)]
struct VoyageRerankUsage {
    // Deserialized but never read in this file, hence `dead_code` is allowed.
    #[allow(dead_code)]
    total_tokens: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The model passed to `new` is reported back by `model_name`.
    #[test]
    fn test_reranker_creation() {
        let reranker = VoyageReranker::new("test-key", "rerank-2.5", "https://api.voyageai.com/v1");
        assert_eq!(reranker.model_name(), "rerank-2.5");
    }

    /// Trailing slashes are stripped, so `{base_url}/rerank` has no `//`.
    #[test]
    fn test_base_url_trailing_slashes_are_trimmed() {
        let reranker =
            VoyageReranker::new("test-key", "rerank-2.5", "https://api.voyageai.com/v1//");
        assert_eq!(reranker.base_url, "https://api.voyageai.com/v1");
    }

    /// A base URL without a trailing slash is stored unchanged.
    #[test]
    fn test_base_url_without_trailing_slash_is_unchanged() {
        let reranker = VoyageReranker::new("test-key", "rerank-2.5", "https://api.voyageai.com/v1");
        assert_eq!(reranker.base_url, "https://api.voyageai.com/v1");
    }
}