Skip to main content

openai_protocol/
rerank.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator::Validate;
6
7use super::common::{default_model, default_true, GenerationRequest, StringOrArray, UsageInfo};
8
/// Returns the value used for the `object` field of a rerank response.
fn default_rerank_object() -> String {
    String::from("rerank")
}
12
/// Returns the current Unix time in whole seconds.
///
/// Yields `0` when the system clock reports a time earlier than the epoch.
///
/// TODO: Create timestamp should not be in protocol layer
fn current_timestamp() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs() as i64,
        // Clock is set before 1970-01-01: same fallback as a zero duration.
        Err(_) => 0,
    }
}
20
21// ============================================================================
22// Rerank API
23// ============================================================================
24
25#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
26#[validate(schema(function = "validate_rerank_request"))]
27pub struct RerankRequest {
28    /// The query text to rank documents against
29    #[validate(custom(function = "validate_query"))]
30    pub query: String,
31
32    /// List of documents to be ranked
33    #[validate(custom(function = "validate_documents"))]
34    pub documents: Vec<String>,
35
36    /// Model to use for reranking
37    #[serde(default = "default_model")]
38    pub model: String,
39
40    /// Maximum number of documents to return (optional)
41    #[serde(skip_serializing_if = "Option::is_none")]
42    #[validate(range(min = 1))]
43    pub top_k: Option<usize>,
44
45    /// Whether to return documents in addition to scores
46    #[serde(default = "default_true")]
47    pub return_documents: bool,
48
49    // SGLang specific extensions
50    /// Request ID for tracking
51    pub rid: Option<StringOrArray>,
52
53    /// User identifier
54    pub user: Option<String>,
55}
56
57impl GenerationRequest for RerankRequest {
58    fn get_model(&self) -> Option<&str> {
59        Some(&self.model)
60    }
61
62    fn is_stream(&self) -> bool {
63        false // Reranking doesn't support streaming
64    }
65
66    fn extract_text_for_routing(&self) -> String {
67        self.query.clone()
68    }
69}
70
71impl super::validated::Normalizable for RerankRequest {
72    // Use default no-op normalization
73}
74
75// ============================================================================
76// Validation Functions
77// ============================================================================
78
79/// Validates that the query is not empty
80fn validate_query(query: &str) -> Result<(), validator::ValidationError> {
81    if query.trim().is_empty() {
82        return Err(validator::ValidationError::new("query cannot be empty"));
83    }
84    Ok(())
85}
86
87/// Validates that the documents list is not empty
88fn validate_documents(documents: &[String]) -> Result<(), validator::ValidationError> {
89    if documents.is_empty() {
90        return Err(validator::ValidationError::new(
91            "documents list cannot be empty",
92        ));
93    }
94    Ok(())
95}
96
97/// Schema-level validation for cross-field dependencies
98#[expect(
99    clippy::unnecessary_wraps,
100    reason = "validator crate requires Result return type"
101)]
102fn validate_rerank_request(req: &RerankRequest) -> Result<(), validator::ValidationError> {
103    // Validate top_k if specified
104    if let Some(k) = req.top_k {
105        if k > req.documents.len() {
106            // This is allowed but we log a warning
107            tracing::warn!(
108                "top_k ({}) is greater than number of documents ({})",
109                k,
110                req.documents.len()
111            );
112        }
113    }
114    Ok(())
115}
116
117impl RerankRequest {
118    /// Get the effective top_k value
119    pub fn effective_top_k(&self) -> usize {
120        self.top_k.unwrap_or(self.documents.len())
121    }
122}
123
124/// Individual rerank result
125#[serde_with::skip_serializing_none]
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct RerankResult {
128    /// Relevance score for the document
129    pub score: f32,
130
131    /// The document text (if return_documents was true)
132    pub document: Option<String>,
133
134    /// Original index of the document in the request
135    pub index: usize,
136
137    /// Additional metadata about the ranking
138    pub meta_info: Option<HashMap<String, Value>>,
139}
140
141/// Rerank response containing sorted results
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct RerankResponse {
144    /// Ranked results sorted by score (highest first)
145    pub results: Vec<RerankResult>,
146
147    /// Model used for reranking
148    pub model: String,
149
150    /// Usage information
151    pub usage: Option<UsageInfo>,
152
153    /// Response object type
154    #[serde(default = "default_rerank_object")]
155    pub object: String,
156
157    /// Response ID
158    pub id: Option<StringOrArray>,
159
160    /// Creation timestamp
161    pub created: i64,
162}
163
164impl RerankResponse {
165    /// Create a new RerankResponse with the given results and model
166    pub fn new(
167        results: Vec<RerankResult>,
168        model: String,
169        request_id: Option<StringOrArray>,
170    ) -> Self {
171        RerankResponse {
172            results,
173            model,
174            usage: None,
175            object: default_rerank_object(),
176            id: request_id,
177            created: current_timestamp(),
178        }
179    }
180
181    /// Apply top_k limit to results
182    pub fn apply_top_k(&mut self, k: usize) {
183        self.results.truncate(k);
184    }
185
186    /// Drop documents from results (when return_documents is false)
187    pub fn drop_documents(&mut self) {
188        for result in &mut self.results {
189            result.document = None;
190        }
191    }
192}
193
194/// V1 API compatibility format for rerank requests
195/// Matches Python's V1RerankReqInput
196#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct V1RerankReqInput {
198    pub query: String,
199    pub documents: Vec<String>,
200}
201
202/// Convert V1RerankReqInput to RerankRequest
203impl From<V1RerankReqInput> for RerankRequest {
204    fn from(v1: V1RerankReqInput) -> Self {
205        RerankRequest {
206            query: v1.query,
207            documents: v1.documents,
208            model: default_model(),
209            top_k: None,
210            return_documents: true,
211            rid: None,
212            user: None,
213        }
214    }
215}