1use serde::{Deserialize, Serialize};
4
5use super::client::Client;
6
/// Model identifier for the `gte-rerank-v2` rerank model.
pub const GTE_RERANK_V2: &str = "gte-rerank-v2";

/// Default DashScope text-rerank URL, used by `RerankModel::new` when no
/// endpoint is supplied. This is the complete request URL, not a base to append to.
pub const GTE_RERANK_V2_URL: &str = concat!(
    "https://dashscope.aliyuncs.com",
    "/api/v1/services/rerank",
    "/text-rerank/text-rerank/",
);
11
12#[derive(Debug, Serialize)]
13pub struct RerankRequest {
14 pub model: String,
15 pub input: RerankInput,
16 #[serde(skip_serializing_if = "Option::is_none")]
17 pub parameters: Option<RerankParameters>,
18}
19
20#[derive(Debug, Serialize)]
21pub struct RerankInput {
22 pub query: String,
23 pub documents: Vec<String>,
24}
25
26#[derive(Debug, Serialize)]
27pub struct RerankParameters {
28 pub return_documents: bool,
29 #[serde(skip_serializing_if = "Option::is_none")]
30 pub top_n: Option<usize>,
31}
32
33#[derive(Debug, Clone, Deserialize)]
34pub struct RerankResponse {
35 pub output: Option<Output>,
36 pub message: Option<String>,
37 pub usage: Option<Usage>,
38 pub request_id: Option<String>,
39}
40
41#[derive(Debug, Clone, Deserialize)]
42pub struct Usage {
43 pub total_tokens: Option<u32>,
44}
45
46#[derive(Debug, Clone, Deserialize)]
47pub struct Output {
48 pub results: Vec<ResultItem>,
49}
50
51#[derive(Debug, Clone, Deserialize)]
52pub struct ResultItem {
53 pub index: usize,
54 pub relevance_score: f64,
55 pub document: Option<Document>,
56}
57
58#[derive(Debug, Clone, Deserialize)]
59pub struct Document {
60 pub text: String,
61}
62
63#[derive(thiserror::Error, Debug)]
64pub enum RerankError {
65 #[error("validation error: {0}")]
66 ValidationError(String),
67 #[error("http error: {0}")]
68 Http(#[from] reqwest::Error),
69 #[error("http status {0}: {1}")]
70 HttpStatus(u16, String),
71 #[error("response error: {0}")]
72 ResponseError(String),
73}
74
75#[derive(Debug, Clone)]
77pub struct RerankModel {
78 pub(crate) client: Client<reqwest::Client>,
79 pub model: String,
80 pub endpoint: String,
82}
83
84impl RerankModel {
85 pub fn new(
89 client: Client<reqwest::Client>,
90 model: impl Into<String>,
91 endpoint_base: Option<String>, ) -> Self {
93 let model = model.into();
94 let endpoint = endpoint_base.unwrap_or_else(|| GTE_RERANK_V2_URL.to_string());
95 Self {
96 client,
97 model,
98 endpoint,
99 }
100 }
101
102 pub async fn rerank(
106 &self,
107 query: &str,
108 documents: &[String],
109 top_n: Option<usize>,
110 return_documents: bool,
111 ) -> Result<Vec<RerankResult>, RerankError> {
112 if query.trim().is_empty() {
113 return Err(RerankError::ValidationError(
114 "Query cannot be empty".to_string(),
115 ));
116 }
117 if documents.is_empty() {
118 return Err(RerankError::ValidationError(
119 "Documents cannot be empty".to_string(),
120 ));
121 }
122
123 let request = RerankRequest {
124 model: self.model.clone(),
125 input: RerankInput {
126 query: query.to_string(),
127 documents: documents.to_vec(),
128 },
129 parameters: Some(RerankParameters {
130 return_documents,
131 top_n,
132 }),
133 };
134
135 let resp = self
136 .client
137 .http_client
138 .post(&self.endpoint)
139 .bearer_auth(&self.client.api_key)
140 .header("Content-Type", "application/json")
141 .json(&request)
142 .send()
143 .await?;
144
145 let status = resp.status();
146 let raw_text = resp.text().await?;
147 let resp_json: RerankResponse = serde_json::from_str(&raw_text)
148 .map_err(|e| RerankError::ResponseError(e.to_string()))?;
149
150 if status.is_success() {
151 if let Some(output) = resp_json.output {
152 let mut results: Vec<RerankResult> = output
153 .results
154 .into_iter()
155 .map(|item| RerankResult {
156 index: item.index,
157 relevance_score: item.relevance_score,
158 text: item.document.map(|d| d.text).unwrap_or_default(),
159 })
160 .collect();
161
162 if let Some(n) = top_n {
163 results.truncate(n);
164 }
165
166 Ok(results)
167 } else {
168 Err(RerankError::ResponseError(
169 "No output in response".to_string(),
170 ))
171 }
172 } else {
173 Err(RerankError::HttpStatus(
174 status.as_u16(),
175 resp_json
176 .message
177 .unwrap_or_else(|| "Unknown HTTP error".to_string()),
178 ))
179 }
180 }
181}
182
183#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct RerankResult {
186 pub index: usize,
187 pub relevance_score: f64,
188 #[serde(default)]
190 pub text: String,
191}