Skip to main content

openai_protocol/
rerank.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator::Validate;
6
7use super::common::{default_true, GenerationRequest, StringOrArray, UsageInfo};
8
/// Returns the value used for the `object` field of a rerank response.
///
/// Used both as the serde default for `RerankResponse::object` and by
/// `RerankResponse::new`.
fn default_rerank_object() -> String {
    String::from("rerank")
}
12
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, `0` is returned.
/// A second count that does not fit in `i64` saturates at `i64::MAX` instead
/// of wrapping to a negative value.
///
/// TODO: Create timestamp should not be in protocol layer
fn current_timestamp() -> i64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        // `duration_since` fails when the clock is set before the epoch;
        // treat that as zero elapsed seconds.
        .map_or(0, |elapsed| {
            // Checked conversion rather than an `as` cast: `as` would silently
            // wrap an out-of-range `u64` into a negative `i64`.
            i64::try_from(elapsed.as_secs()).unwrap_or(i64::MAX)
        })
}
20
// ============================================================================
// Rerank API
// ============================================================================
24
25#[derive(Debug, Clone, Deserialize, Serialize, Validate, schemars::JsonSchema)]
26#[validate(schema(function = "validate_rerank_request"))]
27pub struct RerankRequest {
28    /// The query text to rank documents against
29    #[validate(custom(function = "validate_query"))]
30    pub query: String,
31
32    /// List of documents to be ranked
33    #[validate(custom(function = "validate_documents"))]
34    pub documents: Vec<String>,
35
36    /// Model to use for reranking
37    pub model: String,
38
39    /// Maximum number of documents to return (optional)
40    #[serde(skip_serializing_if = "Option::is_none")]
41    #[validate(range(min = 1))]
42    pub top_k: Option<usize>,
43
44    /// Whether to return documents in addition to scores
45    #[serde(default = "default_true")]
46    pub return_documents: bool,
47
48    // SGLang specific extensions
49    /// Request ID for tracking
50    pub rid: Option<StringOrArray>,
51
52    /// User identifier
53    pub user: Option<String>,
54}
55
56impl GenerationRequest for RerankRequest {
57    fn get_model(&self) -> Option<&str> {
58        Some(&self.model)
59    }
60
61    fn is_stream(&self) -> bool {
62        false // Reranking doesn't support streaming
63    }
64
65    fn extract_text_for_routing(&self) -> String {
66        self.query.clone()
67    }
68}
69
70impl super::validated::Normalizable for RerankRequest {
71    // Use default no-op normalization
72}
73
// ============================================================================
// Validation Functions
// ============================================================================
77
78/// Validates that the query is not empty
79fn validate_query(query: &str) -> Result<(), validator::ValidationError> {
80    if query.trim().is_empty() {
81        return Err(validator::ValidationError::new("query cannot be empty"));
82    }
83    Ok(())
84}
85
86/// Validates that the documents list is not empty
87fn validate_documents(documents: &[String]) -> Result<(), validator::ValidationError> {
88    if documents.is_empty() {
89        return Err(validator::ValidationError::new(
90            "documents list cannot be empty",
91        ));
92    }
93    Ok(())
94}
95
96/// Schema-level validation for cross-field dependencies
97#[expect(
98    clippy::unnecessary_wraps,
99    reason = "validator crate requires Result return type"
100)]
101fn validate_rerank_request(req: &RerankRequest) -> Result<(), validator::ValidationError> {
102    // Validate top_k if specified
103    if let Some(k) = req.top_k {
104        if k > req.documents.len() {
105            // This is allowed but we log a warning
106            tracing::warn!(
107                "top_k ({}) is greater than number of documents ({})",
108                k,
109                req.documents.len()
110            );
111        }
112    }
113    Ok(())
114}
115
116impl RerankRequest {
117    /// Get the effective top_k value
118    pub fn effective_top_k(&self) -> usize {
119        self.top_k.unwrap_or(self.documents.len())
120    }
121}
122
123/// Individual rerank result
124#[serde_with::skip_serializing_none]
125#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
126pub struct RerankResult {
127    /// Relevance score for the document
128    pub score: f32,
129
130    /// The document text (if return_documents was true)
131    pub document: Option<String>,
132
133    /// Original index of the document in the request
134    pub index: usize,
135
136    /// Additional metadata about the ranking
137    pub meta_info: Option<HashMap<String, Value>>,
138}
139
140/// Rerank response containing sorted results
141#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
142pub struct RerankResponse {
143    /// Ranked results sorted by score (highest first)
144    pub results: Vec<RerankResult>,
145
146    /// Model used for reranking
147    pub model: String,
148
149    /// Usage information
150    pub usage: Option<UsageInfo>,
151
152    /// Response object type
153    #[serde(default = "default_rerank_object")]
154    pub object: String,
155
156    /// Response ID
157    pub id: Option<StringOrArray>,
158
159    /// Creation timestamp
160    pub created: i64,
161}
162
163impl RerankResponse {
164    /// Create a new RerankResponse with the given results and model
165    pub fn new(
166        results: Vec<RerankResult>,
167        model: String,
168        request_id: Option<StringOrArray>,
169    ) -> Self {
170        RerankResponse {
171            results,
172            model,
173            usage: None,
174            object: default_rerank_object(),
175            id: request_id,
176            created: current_timestamp(),
177        }
178    }
179
180    /// Apply top_k limit to results
181    pub fn apply_top_k(&mut self, k: usize) {
182        self.results.truncate(k);
183    }
184
185    /// Drop documents from results (when return_documents is false)
186    pub fn drop_documents(&mut self) {
187        for result in &mut self.results {
188            result.document = None;
189        }
190    }
191}
192
193/// V1 API compatibility format for rerank requests
194/// Matches Python's V1RerankReqInput
195#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
196pub struct V1RerankReqInput {
197    pub query: String,
198    pub documents: Vec<String>,
199}
200
201/// Convert V1RerankReqInput to RerankRequest
202impl From<V1RerankReqInput> for RerankRequest {
203    fn from(v1: V1RerankReqInput) -> Self {
204        RerankRequest {
205            query: v1.query,
206            documents: v1.documents,
207            model: super::UNKNOWN_MODEL_ID.to_string(),
208            top_k: None,
209            return_documents: true,
210            rid: None,
211            user: None,
212        }
213    }
214}