openai_protocol/
rerank.rs1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator::Validate;
6
7use super::common::{default_true, GenerationRequest, StringOrArray, UsageInfo};
8
/// Returns the string `"rerank"`.
///
/// Serves as the serde default for `RerankResponse::object` and is also what
/// `RerankResponse::new` stores in that field.
fn default_rerank_object() -> String {
    String::from("rerank")
}
12
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Falls back to `0` when the system clock reports a time before the epoch.
/// The second count is converted with a checked conversion and saturates at
/// `i64::MAX` rather than silently wrapping to a negative value.
fn current_timestamp() -> i64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        // An `Err` here means the clock is set before 1970; treat that as 0.
        .map_or(0, |elapsed| {
            i64::try_from(elapsed.as_secs()).unwrap_or(i64::MAX)
        })
}
20
21#[derive(Debug, Clone, Deserialize, Serialize, Validate, schemars::JsonSchema)]
26#[validate(schema(function = "validate_rerank_request"))]
27pub struct RerankRequest {
28 #[validate(custom(function = "validate_query"))]
30 pub query: String,
31
32 #[validate(custom(function = "validate_documents"))]
34 pub documents: Vec<String>,
35
36 pub model: String,
38
39 #[serde(skip_serializing_if = "Option::is_none")]
41 #[validate(range(min = 1))]
42 pub top_k: Option<usize>,
43
44 #[serde(default = "default_true")]
46 pub return_documents: bool,
47
48 pub rid: Option<StringOrArray>,
51
52 pub user: Option<String>,
54}
55
56impl GenerationRequest for RerankRequest {
57 fn get_model(&self) -> Option<&str> {
58 Some(&self.model)
59 }
60
61 fn is_stream(&self) -> bool {
62 false }
64
65 fn extract_text_for_routing(&self) -> String {
66 self.query.clone()
67 }
68}
69
70impl super::validated::Normalizable for RerankRequest {
71 }
73
74fn validate_query(query: &str) -> Result<(), validator::ValidationError> {
80 if query.trim().is_empty() {
81 return Err(validator::ValidationError::new("query cannot be empty"));
82 }
83 Ok(())
84}
85
86fn validate_documents(documents: &[String]) -> Result<(), validator::ValidationError> {
88 if documents.is_empty() {
89 return Err(validator::ValidationError::new(
90 "documents list cannot be empty",
91 ));
92 }
93 Ok(())
94}
95
96#[expect(
98 clippy::unnecessary_wraps,
99 reason = "validator crate requires Result return type"
100)]
101fn validate_rerank_request(req: &RerankRequest) -> Result<(), validator::ValidationError> {
102 if let Some(k) = req.top_k {
104 if k > req.documents.len() {
105 tracing::warn!(
107 "top_k ({}) is greater than number of documents ({})",
108 k,
109 req.documents.len()
110 );
111 }
112 }
113 Ok(())
114}
115
116impl RerankRequest {
117 pub fn effective_top_k(&self) -> usize {
119 self.top_k.unwrap_or(self.documents.len())
120 }
121}
122
123#[serde_with::skip_serializing_none]
125#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
126pub struct RerankResult {
127 pub score: f32,
129
130 pub document: Option<String>,
132
133 pub index: usize,
135
136 pub meta_info: Option<HashMap<String, Value>>,
138}
139
140#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
142pub struct RerankResponse {
143 pub results: Vec<RerankResult>,
145
146 pub model: String,
148
149 pub usage: Option<UsageInfo>,
151
152 #[serde(default = "default_rerank_object")]
154 pub object: String,
155
156 pub id: Option<StringOrArray>,
158
159 pub created: i64,
161}
162
163impl RerankResponse {
164 pub fn new(
166 results: Vec<RerankResult>,
167 model: String,
168 request_id: Option<StringOrArray>,
169 ) -> Self {
170 RerankResponse {
171 results,
172 model,
173 usage: None,
174 object: default_rerank_object(),
175 id: request_id,
176 created: current_timestamp(),
177 }
178 }
179
180 pub fn apply_top_k(&mut self, k: usize) {
182 self.results.truncate(k);
183 }
184
185 pub fn drop_documents(&mut self) {
187 for result in &mut self.results {
188 result.document = None;
189 }
190 }
191}
192
193#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
196pub struct V1RerankReqInput {
197 pub query: String,
198 pub documents: Vec<String>,
199}
200
201impl From<V1RerankReqInput> for RerankRequest {
203 fn from(v1: V1RerankReqInput) -> Self {
204 RerankRequest {
205 query: v1.query,
206 documents: v1.documents,
207 model: super::UNKNOWN_MODEL_ID.to_string(),
208 top_k: None,
209 return_documents: true,
210 rid: None,
211 user: None,
212 }
213 }
214}