use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::errors::{RagError, Result};
use crate::rerank::SimilarityReranker;
use crate::vector_store::Similarity;
/// Reranker backed by Cohere's hosted rerank endpoint.
///
/// `Debug` is implemented by hand so the API key is never printed.
#[derive(Clone)]
pub struct CohereReranker {
    client: Client,
    api_key: String,
    model: String,
}

impl std::fmt::Debug for CohereReranker {
    // Not derived: a derived impl would leak `api_key` into logs and panic messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CohereReranker")
            .field("model", &self.model)
            .field("api_key", &"<redacted>")
            .finish_non_exhaustive()
    }
}
impl CohereReranker {
    /// Creates a reranker that authenticates with `api_key` and uses the
    /// `rerank-english-v3.0` model by default.
    pub fn new(api_key: String) -> Self {
        Self { client: Client::new(), api_key, model: "rerank-english-v3.0".to_string() }
    }

    /// Overrides the model name sent with every rerank request.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }
}
/// JSON body posted to Cohere's `/v2/rerank` endpoint.
#[derive(Serialize)]
struct CohereRequest {
    /// Query the documents are scored against.
    query: String,
    /// Document texts, in the same order as the items being reranked.
    documents: Vec<String>,
    /// Cohere model identifier.
    model: String,
}
/// The part of Cohere's rerank response that this module reads;
/// any other fields in the payload are ignored.
#[derive(Deserialize)]
struct CohereResponse {
    /// Ranked entries, each pointing back into the submitted `documents`.
    results: Vec<CohereResult>,
}
/// One ranked document in a Cohere rerank response.
#[derive(Deserialize)]
struct CohereResult {
    /// Relevance of the document to the query, as reported by Cohere.
    relevance_score: f32,
    /// Position of the document within the request's `documents` list.
    index: usize,
}
#[async_trait]
impl SimilarityReranker for CohereReranker {
    /// Reranks `items` against `query` via Cohere's `/v2/rerank` endpoint.
    ///
    /// The returned items follow the order of `results` in the response, and each
    /// item's `score` is replaced by the returned `relevance_score`. Items the
    /// response does not reference are dropped. An out-of-range or repeated index
    /// is ignored. An empty `items` is returned unchanged without a network call.
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be sent or if the body cannot be decoded.
    /// It also fails if the API answers with a non-success status; that failure
    /// is reported as `RagError::EmbeddingError` and carries the HTTP status and
    /// the response body.
    async fn rerank(&self, query: &str, items: Vec<Similarity>) -> Result<Vec<Similarity>> {
        if items.is_empty() {
            return Ok(items);
        }
        let req = CohereRequest {
            query: query.to_string(),
            documents: items.iter().map(|i| i.document.content.clone()).collect(),
            model: self.model.clone(),
        };
        let resp = self
            .client
            .post("https://api.cohere.com/v2/rerank")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&req)
            .send()
            .await?;
        let status = resp.status();
        if !status.is_success() {
            // Keep the status even when the body cannot be read, so the failure stays diagnosable.
            let body = resp
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read response body: {e}>"));
            // NOTE(review): reranker failures reuse `EmbeddingError`; a dedicated variant may be clearer.
            return Err(RagError::EmbeddingError(format!(
                "Cohere rerank request failed with HTTP {status}: {body}"
            )));
        }
        let data: CohereResponse = resp.json().await?;
        // `items` is owned, so move each entry out instead of cloning it.
        // `take` also stops a repeated index from duplicating an item.
        let mut slots: Vec<Option<Similarity>> = items.into_iter().map(Some).collect();
        let reranked: Vec<Similarity> = data
            .results
            .into_iter()
            .filter_map(|r| {
                let mut item = slots.get_mut(r.index)?.take()?;
                item.score = r.relevance_score;
                Some(item)
            })
            .collect();
        Ok(reranked)
    }
}
/// Reranker backed by Voyage AI's hosted rerank endpoint.
///
/// `Debug` is implemented by hand so the API key is never printed.
#[derive(Clone)]
pub struct VoyageReranker {
    client: Client,
    api_key: String,
    model: String,
}

impl std::fmt::Debug for VoyageReranker {
    // Not derived: a derived impl would leak `api_key` into logs and panic messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VoyageReranker")
            .field("model", &self.model)
            .field("api_key", &"<redacted>")
            .finish_non_exhaustive()
    }
}
impl VoyageReranker {
    /// Creates a reranker that authenticates with `api_key` and uses the
    /// `rerank-lite-1` model by default.
    pub fn new(api_key: String) -> Self {
        Self { client: Client::new(), api_key, model: "rerank-lite-1".to_string() }
    }

    /// Overrides the model name sent with every rerank request.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }
}
/// JSON body posted to Voyage AI's `/v1/rerank` endpoint.
#[derive(Serialize)]
struct VoyageRequest {
    /// Query the documents are scored against.
    query: String,
    /// Document texts, in the same order as the items being reranked.
    documents: Vec<String>,
    /// Voyage model identifier.
    model: String,
}
/// The part of Voyage AI's rerank response that this module reads;
/// any other fields in the payload are ignored.
#[derive(Deserialize)]
struct VoyageResponse {
    /// Ranked entries, each pointing back into the submitted `documents`.
    data: Vec<VoyageResult>,
}
/// One ranked document in a Voyage AI rerank response.
#[derive(Deserialize)]
struct VoyageResult {
    /// Relevance of the document to the query, as reported by Voyage.
    relevance_score: f32,
    /// Position of the document within the request's `documents` list.
    index: usize,
}
#[async_trait]
impl SimilarityReranker for VoyageReranker {
    /// Reranks `items` against `query` via Voyage AI's `/v1/rerank` endpoint.
    ///
    /// The returned items follow the order of `data` in the response, and each
    /// item's `score` is replaced by the returned `relevance_score`. Items the
    /// response does not reference are dropped. An out-of-range or repeated index
    /// is ignored. An empty `items` is returned unchanged without a network call.
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be sent or if the body cannot be decoded.
    /// It also fails if the API answers with a non-success status; that failure
    /// is reported as `RagError::EmbeddingError` and carries the HTTP status and
    /// the response body.
    async fn rerank(&self, query: &str, items: Vec<Similarity>) -> Result<Vec<Similarity>> {
        if items.is_empty() {
            return Ok(items);
        }
        let req = VoyageRequest {
            query: query.to_string(),
            documents: items.iter().map(|i| i.document.content.clone()).collect(),
            model: self.model.clone(),
        };
        let resp = self
            .client
            .post("https://api.voyageai.com/v1/rerank")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&req)
            .send()
            .await?;
        let status = resp.status();
        if !status.is_success() {
            // Keep the status even when the body cannot be read, so the failure stays diagnosable.
            let body = resp
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read response body: {e}>"));
            // NOTE(review): reranker failures reuse `EmbeddingError`; a dedicated variant may be clearer.
            return Err(RagError::EmbeddingError(format!(
                "Voyage AI rerank request failed with HTTP {status}: {body}"
            )));
        }
        let data: VoyageResponse = resp.json().await?;
        // `items` is owned, so move each entry out instead of cloning it.
        // `take` also stops a repeated index from duplicating an item.
        let mut slots: Vec<Option<Similarity>> = items.into_iter().map(Some).collect();
        let reranked: Vec<Similarity> = data
            .data
            .into_iter()
            .filter_map(|r| {
                let mut item = slots.get_mut(r.index)?.take()?;
                item.score = r.relevance_score;
                Some(item)
            })
            .collect();
        Ok(reranked)
    }
}
/// Reranker backed by Mixedbread's hosted reranking endpoint.
///
/// `Debug` is implemented by hand so the API key is never printed.
#[derive(Clone)]
pub struct MixedBreadReranker {
    client: Client,
    api_key: String,
    model: String,
}

impl std::fmt::Debug for MixedBreadReranker {
    // Not derived: a derived impl would leak `api_key` into logs and panic messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MixedBreadReranker")
            .field("model", &self.model)
            .field("api_key", &"<redacted>")
            .finish_non_exhaustive()
    }
}
impl MixedBreadReranker {
    /// Creates a reranker that authenticates with `api_key` and uses the
    /// `mixedbread-ai/mxbai-rerank-large-v1` model by default.
    pub fn new(api_key: String) -> Self {
        Self {
            client: Client::new(),
            api_key,
            model: "mixedbread-ai/mxbai-rerank-large-v1".to_string(),
        }
    }

    /// Overrides the model name sent with every reranking request.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }
}
#[derive(Serialize)]
struct MixedBreadRequest {
query: String,
documents: Vec<String>,
model: String,
}
/// The part of Mixedbread's reranking response that this module reads;
/// any other fields in the payload are ignored.
#[derive(Deserialize)]
struct MixedBreadResponse {
    /// Ranked entries, each pointing back into the submitted documents.
    data: Vec<MixedBreadResult>,
}
/// One ranked document in a Mixedbread reranking response.
#[derive(Deserialize)]
struct MixedBreadResult {
    /// Relevance of the document to the query, as reported by Mixedbread.
    score: f32,
    /// Position of the document within the submitted document list.
    index: usize,
}
#[async_trait]
impl SimilarityReranker for MixedBreadReranker {
    /// Reranks `items` against `query` via Mixedbread's `/v1/reranking` endpoint.
    ///
    /// The returned items follow the order of `data` in the response, and each
    /// item's `score` is replaced by the returned `score`. Items the response does
    /// not reference are dropped. An out-of-range or repeated index is ignored.
    /// An empty `items` is returned unchanged without a network call.
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be sent or if the body cannot be decoded.
    /// It also fails if the API answers with a non-success status; that failure
    /// is reported as `RagError::EmbeddingError` and carries the HTTP status and
    /// the response body.
    async fn rerank(&self, query: &str, items: Vec<Similarity>) -> Result<Vec<Similarity>> {
        if items.is_empty() {
            return Ok(items);
        }
        let req = MixedBreadRequest {
            query: query.to_string(),
            documents: items.iter().map(|i| i.document.content.clone()).collect(),
            model: self.model.clone(),
        };
        let resp = self
            .client
            .post("https://api.mixedbread.ai/v1/reranking")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&req)
            .send()
            .await?;
        let status = resp.status();
        if !status.is_success() {
            // Keep the status even when the body cannot be read, so the failure stays diagnosable.
            let body = resp
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read response body: {e}>"));
            // NOTE(review): reranker failures reuse `EmbeddingError`; a dedicated variant may be clearer.
            return Err(RagError::EmbeddingError(format!(
                "Mixedbread rerank request failed with HTTP {status}: {body}"
            )));
        }
        let data: MixedBreadResponse = resp.json().await?;
        // `items` is owned, so move each entry out instead of cloning it.
        // `take` also stops a repeated index from duplicating an item.
        let mut slots: Vec<Option<Similarity>> = items.into_iter().map(Some).collect();
        let reranked: Vec<Similarity> = data
            .data
            .into_iter()
            .filter_map(|r| {
                let mut item = slots.get_mut(r.index)?.take()?;
                item.score = r.score;
                Some(item)
            })
            .collect();
        Ok(reranked)
    }
}