use serde::{Deserialize, Serialize};
/// Request body for a rerank call.
///
/// The five optional fields are marked `skip_serializing_if = "Option::is_none"`,
/// so any of them left as `None` is omitted from the serialized output entirely.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankRequest {
/// Model identifier (the tests in this file use e.g. `"BAAI/bge-reranker-v2-m3"`).
pub model: String,
/// Query text that the documents are ranked against.
pub query: String,
/// Candidate documents, as plain strings.
pub documents: Vec<String>,
/// Optional instruction text. NOTE(review): how the service uses it is not visible
/// here — confirm against the provider's API documentation.
#[serde(skip_serializing_if = "Option::is_none")]
pub instruction: Option<String>,
/// Optional result limit — presumably the number of top-ranked results to return; confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_n: Option<u32>,
/// Optional flag — presumably asks the service to echo document text back in
/// [`RerankResult::document`]; confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub return_documents: Option<bool>,
/// Optional per-document chunk limit — presumably for long documents that the service
/// splits into chunks; confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_chunks_per_doc: Option<u32>,
/// Optional token overlap between adjacent chunks — presumably used together with
/// `max_chunks_per_doc`; confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub overlap_tokens: Option<u32>,
}
/// Response body of a rerank call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankResponse {
/// Response identifier as reported by the service.
pub id: String,
/// Scored results, kept in the order they were received.
/// NOTE(review): code that treats the first entry as the best match assumes the
/// service sorts by descending `relevance_score` — confirm against the API docs.
pub results: Vec<RerankResult>,
/// Token accounting reported for this call.
pub tokens: RerankTokenUsage,
}
/// One scored entry of a [`RerankResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankResult {
/// Document text, if present; omitted from serialized output when `None`.
/// NOTE(review): presumably only populated when the request set
/// `return_documents` — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub document: Option<RerankDocument>,
/// Position of the scored document — presumably an index into the request's
/// `documents` list; confirm.
pub index: u32,
/// Relevance score for this document. NOTE(review): the scale/range is not defined
/// here; higher presumably means more relevant — confirm.
pub relevance_score: f64,
}
/// Document payload carried inside a [`RerankResult`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankDocument {
/// The document's text.
pub text: String,
}
/// Token counts attached to a [`RerankResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankTokenUsage {
/// Number of input tokens, as reported by the service.
pub input_tokens: u32,
/// Number of output tokens, as reported by the service.
pub output_tokens: u32,
}
// Every constructor/builder below takes `self` by value and returns the updated
// request, so discarding the return value silently throws the request away.
// `#[must_use]` turns that mistake into a compiler warning.
impl RerankRequest {
    /// Creates a request that reranks `documents` against `query` using `model`.
    ///
    /// Every optional parameter starts as `None`, which keeps it out of the
    /// serialized request until it is set through one of the `with_*` methods.
    #[must_use]
    pub fn new(model: String, query: String, documents: Vec<String>) -> Self {
        Self {
            model,
            query,
            documents,
            instruction: None,
            top_n: None,
            return_documents: None,
            max_chunks_per_doc: None,
            overlap_tokens: None,
        }
    }

    /// Returns the request with `instruction` set to `Some(instruction)`.
    #[must_use]
    pub fn with_instruction(mut self, instruction: String) -> Self {
        self.instruction = Some(instruction);
        self
    }

    /// Returns the request with `top_n` set to `Some(top_n)`.
    ///
    /// NOTE(review): presumably limits how many results the service returns — confirm.
    #[must_use]
    pub fn with_top_n(mut self, top_n: u32) -> Self {
        self.top_n = Some(top_n);
        self
    }

    /// Returns the request with `return_documents` set to `Some(return_documents)`.
    #[must_use]
    pub fn with_return_documents(mut self, return_documents: bool) -> Self {
        self.return_documents = Some(return_documents);
        self
    }

    /// Returns the request with `max_chunks_per_doc` set to `Some(max_chunks)`.
    #[must_use]
    pub fn with_max_chunks_per_doc(mut self, max_chunks: u32) -> Self {
        self.max_chunks_per_doc = Some(max_chunks);
        self
    }

    /// Returns the request with `overlap_tokens` set to `Some(overlap)`.
    #[must_use]
    pub fn with_overlap_tokens(mut self, overlap: u32) -> Self {
        self.overlap_tokens = Some(overlap);
        self
    }
}
impl RerankResponse {
pub fn top_result_index(&self) -> Option<u32> {
self.results.first().map(|r| r.index)
}
pub fn sorted_indices(&self) -> Vec<u32> {
self.results.iter().map(|r| r.index).collect()
}
pub fn relevance_scores(&self) -> Vec<f64> {
self.results.iter().map(|r| r.relevance_score).collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores the required fields verbatim and leaves every option unset.
    #[test]
    fn test_rerank_request_creation() {
        let request = RerankRequest::new(
            "BAAI/bge-reranker-v2-m3".to_string(),
            "Apple".to_string(),
            vec!["apple".to_string(), "banana".to_string()],
        );
        assert_eq!(request.model, "BAAI/bge-reranker-v2-m3");
        assert_eq!(request.query, "Apple");
        assert_eq!(
            request.documents,
            vec!["apple".to_string(), "banana".to_string()]
        );
        assert!(request.instruction.is_none());
        assert!(request.top_n.is_none());
        assert!(request.return_documents.is_none());
        assert!(request.max_chunks_per_doc.is_none());
        assert!(request.overlap_tokens.is_none());
    }

    /// The instruction / top_n / return_documents builders set their fields.
    #[test]
    fn test_rerank_request_builder() {
        let request = RerankRequest::new(
            "test-model".to_string(),
            "test query".to_string(),
            vec!["doc1".to_string()],
        )
        .with_instruction("Please rerank".to_string())
        .with_top_n(5)
        .with_return_documents(true);
        assert_eq!(request.instruction, Some("Please rerank".to_string()));
        assert_eq!(request.top_n, Some(5));
        assert_eq!(request.return_documents, Some(true));
    }

    /// The chunking builders set their fields and leave the other options untouched.
    #[test]
    fn test_rerank_request_builder_chunking_options() {
        let request = RerankRequest::new(
            "test-model".to_string(),
            "test query".to_string(),
            vec!["doc1".to_string()],
        )
        .with_max_chunks_per_doc(3)
        .with_overlap_tokens(16);
        assert_eq!(request.max_chunks_per_doc, Some(3));
        assert_eq!(request.overlap_tokens, Some(16));
        assert!(request.instruction.is_none());
        assert!(request.top_n.is_none());
        assert!(request.return_documents.is_none());
    }

    /// Accessors report results in stored order.
    #[test]
    fn test_rerank_response_methods() {
        let response = RerankResponse {
            id: "test-id".to_string(),
            results: vec![
                RerankResult {
                    document: None,
                    index: 2,
                    relevance_score: 0.9,
                },
                RerankResult {
                    document: None,
                    index: 0,
                    relevance_score: 0.7,
                },
            ],
            tokens: RerankTokenUsage {
                input_tokens: 100,
                output_tokens: 10,
            },
        };
        assert_eq!(response.top_result_index(), Some(2));
        assert_eq!(response.sorted_indices(), vec![2, 0]);
        assert_eq!(response.relevance_scores(), vec![0.9, 0.7]);
    }

    /// With no results, the accessors return `None` / empty vectors instead of panicking.
    #[test]
    fn test_rerank_response_methods_empty_results() {
        let response = RerankResponse {
            id: "empty".to_string(),
            results: Vec::new(),
            tokens: RerankTokenUsage {
                input_tokens: 0,
                output_tokens: 0,
            },
        };
        assert_eq!(response.top_result_index(), None);
        assert!(response.sorted_indices().is_empty());
        assert!(response.relevance_scores().is_empty());
    }
}