rig_bailian/
rerank.rs

1//! Category: rerank.rs (text reranking, DashScope gte-rerank-v2)
2
3use serde::{Deserialize, Serialize};
4
5use super::client::Client;
6
/// Model identifier for DashScope's `gte-rerank-v2` text-rerank model.
///
/// This is the value sent in the request body's `model` field.
pub const GTE_RERANK_V2: &str = "gte-rerank-v2";

/// Default DashScope text-rerank endpoint, used by `RerankModel::new` when no endpoint is given.
///
/// The URL does not encode the model; the model is selected through the request body.
pub const GTE_RERANK_V2_URL: &str =
    "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank/";
11
12#[derive(Debug, Serialize)]
13pub struct RerankRequest {
14    pub model: String,
15    pub input: RerankInput,
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub parameters: Option<RerankParameters>,
18}
19
20#[derive(Debug, Serialize)]
21pub struct RerankInput {
22    pub query: String,
23    pub documents: Vec<String>,
24}
25
26#[derive(Debug, Serialize)]
27pub struct RerankParameters {
28    pub return_documents: bool,
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub top_n: Option<usize>,
31}
32
33#[derive(Debug, Clone, Deserialize)]
34pub struct RerankResponse {
35    pub output: Option<Output>,
36    pub message: Option<String>,
37    pub usage: Option<Usage>,
38    pub request_id: Option<String>,
39}
40
41#[derive(Debug, Clone, Deserialize)]
42pub struct Usage {
43    pub total_tokens: Option<u32>,
44}
45
46#[derive(Debug, Clone, Deserialize)]
47pub struct Output {
48    pub results: Vec<ResultItem>,
49}
50
51#[derive(Debug, Clone, Deserialize)]
52pub struct ResultItem {
53    pub index: usize,
54    pub relevance_score: f64,
55    pub document: Option<Document>,
56}
57
58#[derive(Debug, Clone, Deserialize)]
59pub struct Document {
60    pub text: String,
61}
62
63#[derive(thiserror::Error, Debug)]
64pub enum RerankError {
65    #[error("validation error: {0}")]
66    ValidationError(String),
67    #[error("http error: {0}")]
68    Http(#[from] reqwest::Error),
69    #[error("http status {0}: {1}")]
70    HttpStatus(u16, String),
71    #[error("response error: {0}")]
72    ResponseError(String),
73}
74
75/// Rerank model bound to Bailian client
76#[derive(Debug, Clone)]
77pub struct RerankModel {
78    pub(crate) client: Client<reqwest::Client>,
79    pub model: String,
80    /// Full endpoint URL (base + model), e.g. ".../text-re-rank/gte-rerank-v2"
81    pub endpoint: String,
82}
83
84impl RerankModel {
85    /// Create a rerank model using the Bailian client.
86    /// - `endpoint_base`: optional base URL (defaults to DashScope base)
87    /// - final endpoint = endpoint_base + model
88    pub fn new(
89        client: Client<reqwest::Client>,
90        model: impl Into<String>,
91        endpoint_base: Option<String>, // base URL, not the full endpoint
92    ) -> Self {
93        let model = model.into();
94        let endpoint = endpoint_base.unwrap_or_else(|| GTE_RERANK_V2_URL.to_string());
95        Self {
96            client,
97            model,
98            endpoint,
99        }
100    }
101
102    /// Rerank the given documents based on the query.
103    ///
104    /// Returns a Vec<RerankResult>. If top_n is provided, the result will be truncated accordingly.
105    pub async fn rerank(
106        &self,
107        query: &str,
108        documents: &[String],
109        top_n: Option<usize>,
110        return_documents: bool,
111    ) -> Result<Vec<RerankResult>, RerankError> {
112        if query.trim().is_empty() {
113            return Err(RerankError::ValidationError(
114                "Query cannot be empty".to_string(),
115            ));
116        }
117        if documents.is_empty() {
118            return Err(RerankError::ValidationError(
119                "Documents cannot be empty".to_string(),
120            ));
121        }
122
123        let request = RerankRequest {
124            model: self.model.clone(),
125            input: RerankInput {
126                query: query.to_string(),
127                documents: documents.to_vec(),
128            },
129            parameters: Some(RerankParameters {
130                return_documents,
131                top_n,
132            }),
133        };
134
135        let resp = self
136            .client
137            .http_client
138            .post(&self.endpoint)
139            .bearer_auth(&self.client.api_key)
140            .header("Content-Type", "application/json")
141            .json(&request)
142            .send()
143            .await?;
144
145        let status = resp.status();
146        let raw_text = resp.text().await?;
147        let resp_json: RerankResponse = serde_json::from_str(&raw_text)
148            .map_err(|e| RerankError::ResponseError(e.to_string()))?;
149
150        if status.is_success() {
151            if let Some(output) = resp_json.output {
152                let mut results: Vec<RerankResult> = output
153                    .results
154                    .into_iter()
155                    .map(|item| RerankResult {
156                        index: item.index,
157                        relevance_score: item.relevance_score,
158                        text: item.document.map(|d| d.text).unwrap_or_default(),
159                    })
160                    .collect();
161
162                if let Some(n) = top_n {
163                    results.truncate(n);
164                }
165
166                Ok(results)
167            } else {
168                Err(RerankError::ResponseError(
169                    "No output in response".to_string(),
170                ))
171            }
172        } else {
173            Err(RerankError::HttpStatus(
174                status.as_u16(),
175                resp_json
176                    .message
177                    .unwrap_or_else(|| "Unknown HTTP error".to_string()),
178            ))
179        }
180    }
181}
182
183/// Public result returned by rerank()
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct RerankResult {
186    pub index: usize,
187    pub relevance_score: f64,
188    /// Flattened document text (empty if server omitted it)
189    #[serde(default)]
190    pub text: String,
191}