rig_tei/
rerank.rs

1use rig::http_client::{self, HttpClientExt};
2use serde::{Deserialize, Serialize};
3use serde_json::json;
4
5use super::client::Client;
6
7#[derive(Debug, Deserialize, Serialize, Clone)]
8pub struct RerankResult {
9    pub index: usize,
10    #[serde(default)]
11    pub text: Option<String>,
12    #[serde(alias = "score", alias = "relevance_score")]
13    pub relevance_score: f32,
14}
15
16#[derive(thiserror::Error, Debug)]
17pub enum RerankError {
18    #[error("http error: {0}")]
19    Http(#[from] http_client::Error),
20    #[error("provider error: {0}")]
21    Provider(String),
22    #[error("response error: {0}")]
23    Response(String),
24}
25
26impl Client<reqwest::Client> {
27    /// Rerank endpoint (customizable via ClientBuilder): POST {endpoints.rerank}
28    pub async fn rerank(
29        &self,
30        query: &str,
31        texts: impl IntoIterator<Item = String>,
32        top_n: Option<usize>,
33    ) -> Result<Vec<RerankResult>, RerankError> {
34        let texts: Vec<String> = texts.into_iter().collect();
35
36        let mut payload = json!({
37            "query": query,
38            "texts": texts,
39        });
40        if let Some(k) = top_n {
41            payload["top_n"] = json!(k);
42        }
43
44        let body =
45            serde_json::to_vec(&payload).map_err(|e| RerankError::Response(e.to_string()))?;
46
47        let req = self
48            .post_full(&self.endpoints.rerank)
49            .header("Content-Type", "application/json")
50            .body(body)
51            .map_err(|e| RerankError::Http(e.into()))?;
52
53        let response = HttpClientExt::send(&self.http_client, req).await?;
54        if !response.status().is_success() {
55            let text = http_client::text(response).await?;
56            return Err(RerankError::Provider(text));
57        }
58
59        let bytes: Vec<u8> = response.into_body().await?;
60        let parsed: Vec<RerankResult> = serde_json::from_slice(&bytes).map_err(|e| {
61            RerankError::Response(format!("Failed to parse TEI rerank response: {e}"))
62        })?;
63        Ok(parsed)
64    }
65}