openai_protocol/
rerank.rs1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use validator::Validate;
6
7use super::common::{default_model, default_true, GenerationRequest, StringOrArray, UsageInfo};
8
/// Serde default for `RerankResponse::object`: the literal tag `"rerank"`.
fn default_rerank_object() -> String {
    String::from("rerank")
}
12
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this falls back to `0`
/// instead of panicking.
fn current_timestamp() -> i64 {
    let since_epoch = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => std::time::Duration::from_secs(0),
    };
    since_epoch.as_secs() as i64
}
20
/// Request body for a rerank call: score each entry of `documents` against `query`.
///
/// Validation (via the `validator` crate):
/// - `query` must contain non-whitespace text (`validate_query`);
/// - `documents` must be non-empty (`validate_documents`);
/// - `top_k`, when present, must be at least 1;
/// - `validate_rerank_request` runs a whole-request check, which only logs a
///   warning when `top_k` exceeds the number of documents.
#[derive(Debug, Clone, Deserialize, Serialize, Validate)]
#[validate(schema(function = "validate_rerank_request"))]
pub struct RerankRequest {
    /// Query text the documents are scored against; must not be blank.
    #[validate(custom(function = "validate_query"))]
    pub query: String,

    /// Candidate documents to rank; must contain at least one entry.
    #[validate(custom(function = "validate_documents"))]
    pub documents: Vec<String>,

    /// Model name; filled from `default_model()` when absent on deserialize.
    #[serde(default = "default_model")]
    pub model: String,

    /// Maximum number of results to return; omitted from serialized output
    /// when `None`. See `effective_top_k` for the value actually applied.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(range(min = 1))]
    pub top_k: Option<usize>,

    /// Whether result entries should carry document text; filled from
    /// `default_true()` when absent (presumably `true` — defined in `common`).
    #[serde(default = "default_true")]
    pub return_documents: bool,

    /// Optional request identifier(s). Presumably the request id, since it has
    /// the same type as `RerankResponse::new`'s `request_id` — confirm with callers.
    /// Serialized as `null` when `None`.
    pub rid: Option<StringOrArray>,

    /// Optional caller-supplied user string (not read in this file).
    /// Serialized as `null` when `None`.
    pub user: Option<String>,
}
56
57impl GenerationRequest for RerankRequest {
58 fn get_model(&self) -> Option<&str> {
59 Some(&self.model)
60 }
61
62 fn is_stream(&self) -> bool {
63 false }
65
66 fn extract_text_for_routing(&self) -> String {
67 self.query.clone()
68 }
69}
70
/// Opts `RerankRequest` into `Normalizable`. The impl body is empty, so every
/// trait method falls back to the trait's default implementation.
impl super::validated::Normalizable for RerankRequest {
}
74
75fn validate_query(query: &str) -> Result<(), validator::ValidationError> {
81 if query.trim().is_empty() {
82 return Err(validator::ValidationError::new("query cannot be empty"));
83 }
84 Ok(())
85}
86
87fn validate_documents(documents: &[String]) -> Result<(), validator::ValidationError> {
89 if documents.is_empty() {
90 return Err(validator::ValidationError::new(
91 "documents list cannot be empty",
92 ));
93 }
94 Ok(())
95}
96
97#[expect(
99 clippy::unnecessary_wraps,
100 reason = "validator crate requires Result return type"
101)]
102fn validate_rerank_request(req: &RerankRequest) -> Result<(), validator::ValidationError> {
103 if let Some(k) = req.top_k {
105 if k > req.documents.len() {
106 tracing::warn!(
108 "top_k ({}) is greater than number of documents ({})",
109 k,
110 req.documents.len()
111 );
112 }
113 }
114 Ok(())
115}
116
117impl RerankRequest {
118 pub fn effective_top_k(&self) -> usize {
120 self.top_k.unwrap_or(self.documents.len())
121 }
122}
123
/// One scored document in a rerank response.
///
/// `skip_serializing_none` omits every `None` field from serialized output.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankResult {
    /// Score for this document (computed elsewhere; presumably higher means
    /// more relevant — confirm with the producer).
    pub score: f32,

    /// Document text, if included; `RerankResponse::drop_documents` clears it.
    pub document: Option<String>,

    /// Presumably this document's position in the request's `documents`
    /// list — confirm with the producer.
    pub index: usize,

    /// Optional free-form metadata as JSON values.
    pub meta_info: Option<HashMap<String, Value>>,
}
140
/// Response body for a rerank call.
///
/// This struct has no `skip_serializing_none`, so `None` fields serialize as `null`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankResponse {
    /// Scored documents; `apply_top_k` truncates this list in place.
    pub results: Vec<RerankResult>,

    /// Model name reported in the response.
    pub model: String,

    /// Usage statistics (`UsageInfo`), if any; `new` leaves this `None`.
    pub usage: Option<UsageInfo>,

    /// Object type tag. `new` sets it to `"rerank"`, and deserialization also
    /// uses `"rerank"` when the field is missing.
    #[serde(default = "default_rerank_object")]
    pub object: String,

    /// Identifier(s) for the response; `new` copies its `request_id` argument here.
    pub id: Option<StringOrArray>,

    /// Creation time in seconds since the Unix epoch (`new` uses `current_timestamp()`).
    pub created: i64,
}
163
164impl RerankResponse {
165 pub fn new(
167 results: Vec<RerankResult>,
168 model: String,
169 request_id: Option<StringOrArray>,
170 ) -> Self {
171 RerankResponse {
172 results,
173 model,
174 usage: None,
175 object: default_rerank_object(),
176 id: request_id,
177 created: current_timestamp(),
178 }
179 }
180
181 pub fn apply_top_k(&mut self, k: usize) {
183 self.results.truncate(k);
184 }
185
186 pub fn drop_documents(&mut self) {
188 for result in &mut self.results {
189 result.document = None;
190 }
191 }
192}
193
/// Minimal rerank input with only `query` and `documents`. The `V1` prefix
/// presumably marks an older API shape — confirm.
///
/// Converts into a full `RerankRequest` through its `From` impl, with all
/// other fields defaulted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct V1RerankReqInput {
    /// Query text the documents are scored against.
    pub query: String,
    /// Candidate documents to rank.
    pub documents: Vec<String>,
}
201
202impl From<V1RerankReqInput> for RerankRequest {
204 fn from(v1: V1RerankReqInput) -> Self {
205 RerankRequest {
206 query: v1.query,
207 documents: v1.documents,
208 model: default_model(),
209 top_k: None,
210 return_documents: true,
211 rid: None,
212 user: None,
213 }
214 }
215}