use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use validator::Validate;
use super::common::{default_model, default_true, GenerationRequest, StringOrArray, UsageInfo};
/// Value of the `object` tag on rerank responses.
///
/// Also used as the serde default for `RerankResponse::object` when the field
/// is missing from incoming JSON.
fn default_rerank_object() -> String {
    String::from("rerank")
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, `0` is returned
/// instead of failing.
fn current_timestamp() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let elapsed_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0);
    elapsed_secs as i64
}
/// A rerank request: score each entry of `documents` against `query`.
///
/// Checked with the `validator` crate through three kinds of check:
/// - field-level `validate_query` (query must not be blank);
/// - field-level `validate_documents` (at least one document) and a range
///   check on `top_k` (at least 1 when present);
/// - the struct-level `validate_rerank_request`, which only logs a warning
///   (never rejects) when `top_k` exceeds the number of documents.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
#[validate(schema(function = "validate_rerank_request"))]
pub struct RerankRequest {
/// Text the documents are scored against; must not be empty or whitespace-only.
#[validate(custom(function = "validate_query"))]
pub query: String,
/// Candidate documents to rerank; must contain at least one entry.
#[validate(custom(function = "validate_documents"))]
pub documents: Vec<String>,
/// Model name; filled from `default_model()` when absent from the payload.
#[serde(default = "default_model")]
pub model: String,
/// Requested number of top results; must be >= 1 when set. When `None`,
/// `effective_top_k` falls back to the number of documents. Omitted from
/// serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
#[validate(range(min = 1))]
pub top_k: Option<usize>,
/// Defaults via `default_true()` when absent. Presumably controls whether
/// results keep their `document` text (cf. `RerankResponse::drop_documents`);
/// confirm with the handler.
#[serde(default = "default_true")]
pub return_documents: bool,
/// Optional id(s) for the request. Presumably the request id later passed to
/// `RerankResponse::new`; verify against the caller.
pub rid: Option<StringOrArray>,
/// Optional caller-supplied user string; not read anywhere in this file.
pub user: Option<String>,
}
impl GenerationRequest for RerankRequest {
    /// Always `Some`: the model field is populated (defaulted by serde if missing).
    fn get_model(&self) -> Option<&str> {
        Some(self.model.as_str())
    }

    /// Rerank requests are never streamed.
    fn is_stream(&self) -> bool {
        false
    }

    /// The routing key is an owned copy of the query text.
    fn extract_text_for_routing(&self) -> String {
        self.query.to_owned()
    }
}
// Empty impl: rerank requests use the trait's default `Normalizable`
// behavior and add no request-specific normalization of their own.
impl super::validated::Normalizable for RerankRequest {
}
/// Field validator for `RerankRequest::query`.
///
/// # Errors
/// Fails when the query is empty or consists only of whitespace; the error
/// code is `"query cannot be empty"`.
fn validate_query(query: &str) -> Result<(), validator::ValidationError> {
    match query.trim() {
        "" => Err(validator::ValidationError::new("query cannot be empty")),
        _ => Ok(()),
    }
}
/// Field validator for `RerankRequest::documents`.
///
/// Only the list length is checked; individual documents may be empty strings.
///
/// # Errors
/// Fails when no documents were supplied; the error code is
/// `"documents list cannot be empty"`.
fn validate_documents(documents: &[String]) -> Result<(), validator::ValidationError> {
    if !documents.is_empty() {
        return Ok(());
    }
    Err(validator::ValidationError::new(
        "documents list cannot be empty",
    ))
}
/// Struct-level validator for [`RerankRequest`].
///
/// A `top_k` larger than the number of documents is tolerated: it is logged
/// as a warning and never turned into a validation error. This function
/// always returns `Ok(())`.
fn validate_rerank_request(req: &RerankRequest) -> Result<(), validator::ValidationError> {
    let doc_count = req.documents.len();
    if let Some(k) = req.top_k.filter(|&k| k > doc_count) {
        tracing::warn!(
            "top_k ({}) is greater than number of documents ({})",
            k,
            doc_count
        );
    }
    Ok(())
}
impl RerankRequest {
    /// Number of results this request should produce.
    ///
    /// Uses `top_k` when set, otherwise every document. The value is capped at
    /// `documents.len()`. Request validation deliberately accepts a `top_k`
    /// larger than the document count (it only logs a warning). Returning that
    /// raw value would over-report the result count and could make callers
    /// slice past the end of the results.
    pub fn effective_top_k(&self) -> usize {
        let doc_count = self.documents.len();
        self.top_k.map_or(doc_count, |k| k.min(doc_count))
    }
}
/// One scored document within a [`RerankResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankResult {
/// Score assigned to this document (higher presumably means more relevant —
/// confirm with the scoring backend).
pub score: f32,
/// Document text, if retained; omitted from serialized output when `None`
/// (for example after `RerankResponse::drop_documents`).
#[serde(skip_serializing_if = "Option::is_none")]
pub document: Option<String>,
/// Presumably the position of this document in the request's `documents`
/// list — verify against the producer of these results.
pub index: usize,
/// Optional free-form metadata; omitted from serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub meta_info: Option<HashMap<String, Value>>,
}
/// Response body returned for a rerank request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankResponse {
/// Scored results. `apply_top_k` keeps the leading entries, so this list is
/// presumably ordered best-first — confirm with the producer.
pub results: Vec<RerankResult>,
/// Name of the model reported in the response.
pub model: String,
/// Usage information, if any; `new` leaves it `None`. Serialized as `null`
/// when `None` (no `skip_serializing_if`).
pub usage: Option<UsageInfo>,
/// Object type tag; `"rerank"` from `new`, and also the serde default when
/// the field is missing on deserialization.
#[serde(default = "default_rerank_object")]
pub object: String,
/// Request id(s) echoed back, if any; serialized as `null` when `None`.
pub id: Option<StringOrArray>,
/// Creation time; `new` sets it to whole seconds since the Unix epoch.
/// Required when deserializing (no serde default).
pub created: i64,
}
impl RerankResponse {
    /// Builds a response stamped with the current time.
    ///
    /// `usage` starts out as `None`, `object` is set to `"rerank"`, and
    /// `request_id` is stored as the response `id`.
    pub fn new(
        results: Vec<RerankResult>,
        model: String,
        request_id: Option<StringOrArray>,
    ) -> Self {
        let created = current_timestamp();
        Self {
            object: default_rerank_object(),
            id: request_id,
            usage: None,
            created,
            model,
            results,
        }
    }

    /// Keeps only the first `k` results. A `k` at or beyond the current
    /// length leaves the list untouched.
    pub fn apply_top_k(&mut self, k: usize) {
        Vec::truncate(&mut self.results, k);
    }

    /// Clears the document text from every result, leaving scores, indices,
    /// and metadata intact.
    pub fn drop_documents(&mut self) {
        self.results
            .iter_mut()
            .for_each(|entry| entry.document = None);
    }
}
/// Minimal rerank input (the `V1` request shape) carrying only a query and its
/// candidate documents.
///
/// Converts into a full [`RerankRequest`] via `From`, with every other field
/// filled with its default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct V1RerankReqInput {
/// Text the documents are scored against.
pub query: String,
/// Candidate documents to rerank.
pub documents: Vec<String>,
}
impl From<V1RerankReqInput> for RerankRequest {
fn from(v1: V1RerankReqInput) -> Self {
RerankRequest {
query: v1.query,
documents: v1.documents,
model: default_model(),
top_k: None,
return_documents: true,
rid: None,
user: None,
}
}
}