Skip to main content

llmsdk_provider/
reranking_model.rs

1//! Reranking model trait and supporting types.
2//!
3//! Mirrors `@ai-sdk/provider/src/reranking-model/v4/*`.
4// Rust guideline compliant 2026-02-21
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
9use crate::error::Result;
10use crate::json::JsonObject;
11use crate::shared::{Headers, ProviderMetadata, ProviderOptions, ResponseInfo, Warning};
12
13/// Contract every reranking model implements.
14///
15/// Mirrors `RerankingModelV4`.
16#[async_trait]
17pub trait RerankingModel: Send + Sync + std::fmt::Debug {
18    /// Provider id, e.g. `"cohere"`.
19    fn provider(&self) -> &str;
20
21    /// Provider-specific model id, e.g. `"rerank-english-v3.0"`.
22    fn model_id(&self) -> &str;
23
24    /// Specification version (currently `"v4"`).
25    fn specification_version(&self) -> &'static str {
26        "v4"
27    }
28
29    /// Rerank a list of documents against the given query.
30    ///
31    /// # Errors
32    ///
33    /// Returns a [`crate::ProviderError`] when the upstream call fails or
34    /// the response is malformed.
35    async fn do_rerank(&self, options: RerankingOptions) -> Result<RerankingResult>;
36}
37
38/// Options for one [`RerankingModel::do_rerank`] call.
39///
40/// Mirrors `RerankingModelV4CallOptions`.
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct RerankingOptions {
43    /// Documents to rerank.
44    pub documents: RerankingDocuments,
45    /// Query to rerank documents against.
46    pub query: String,
47    /// Limit returned documents to the top N.
48    #[serde(default, rename = "topN", skip_serializing_if = "Option::is_none")]
49    pub top_n: Option<u32>,
50    /// Extra HTTP headers (HTTP providers only).
51    #[serde(default, skip_serializing_if = "Option::is_none")]
52    pub headers: Option<Headers>,
53    /// Provider-specific options.
54    #[serde(
55        default,
56        rename = "providerOptions",
57        skip_serializing_if = "Option::is_none"
58    )]
59    pub provider_options: Option<ProviderOptions>,
60}
61
62/// Documents to rerank. Two-state tagged union over plain text or JSON objects.
63///
64/// Mirrors `RerankingModelV4CallOptions['documents']`.
65#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
66#[serde(tag = "type", rename_all = "kebab-case")]
67pub enum RerankingDocuments {
68    /// Plain-text documents.
69    Text {
70        /// Text values.
71        values: Vec<String>,
72    },
73    /// JSON-object documents (Cohere passes them through with field filters).
74    Object {
75        /// Object values.
76        values: Vec<JsonObject>,
77    },
78}
79
80/// Result of [`RerankingModel::do_rerank`].
81///
82/// Mirrors `RerankingModelV4Result`.
83#[derive(Debug, Clone)]
84pub struct RerankingResult {
85    /// Reranked documents (sorted by relevance descending).
86    ///
87    /// Each entry refers back to the document's index in the input list.
88    pub ranking: Vec<RankingEntry>,
89    /// Warnings for the call.
90    pub warnings: Vec<Warning>,
91    /// Provider-specific metadata.
92    pub provider_metadata: Option<ProviderMetadata>,
93    /// Optional response info (telemetry).
94    pub response: Option<ResponseInfo>,
95}
96
97/// One entry in [`RerankingResult::ranking`].
98#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
99pub struct RankingEntry {
100    /// Index of the document in the input list before reranking.
101    pub index: u32,
102    /// Relevance score assigned by the model. Higher = more relevant.
103    #[serde(rename = "relevanceScore")]
104    pub relevance_score: f64,
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds text-document options with the given `top_n` and no optional extras.
    fn text_options(top_n: Option<u32>) -> RerankingOptions {
        RerankingOptions {
            documents: RerankingDocuments::Text {
                values: vec!["a".into(), "b".into()],
            },
            query: "q".into(),
            top_n,
            headers: None,
            provider_options: None,
        }
    }

    #[test]
    fn options_roundtrip_text_documents() {
        let opts = text_options(Some(3));
        let j = serde_json::to_value(&opts).unwrap();
        assert_eq!(j["documents"]["type"], "text");
        assert_eq!(j["documents"]["values"][0], "a");
        assert_eq!(j["topN"], 3);

        let back: RerankingOptions = serde_json::from_value(j).unwrap();
        // A real round-trip: every field must survive, not only `top_n`.
        assert_eq!(back.documents, opts.documents);
        assert_eq!(back.query, opts.query);
        assert_eq!(back.top_n, Some(3));
        assert!(back.headers.is_none());
        assert!(back.provider_options.is_none());
    }

    #[test]
    fn options_omit_unset_optional_fields() {
        let j = serde_json::to_value(text_options(None)).unwrap();
        let obj = j.as_object().unwrap();
        // `skip_serializing_if = "Option::is_none"` must drop these keys entirely
        // rather than emitting `null`.
        assert!(!obj.contains_key("topN"));
        assert!(!obj.contains_key("headers"));
        assert!(!obj.contains_key("providerOptions"));
    }

    #[test]
    fn options_deserialize_without_optional_fields() {
        let back: RerankingOptions = serde_json::from_value(json!({
            "documents": { "type": "text", "values": ["a"] },
            "query": "q"
        }))
        .unwrap();
        assert_eq!(
            back.documents,
            RerankingDocuments::Text {
                values: vec!["a".into()],
            }
        );
        assert_eq!(back.query, "q");
        assert_eq!(back.top_n, None);
        assert!(back.headers.is_none());
        assert!(back.provider_options.is_none());
    }

    #[test]
    fn documents_object_variant_kebab_tagged() {
        let docs = RerankingDocuments::Object {
            values: vec![json!({ "title": "x" }).as_object().cloned().unwrap()],
        };
        let j = serde_json::to_value(&docs).unwrap();
        assert_eq!(j["type"], "object");
        assert_eq!(j["values"][0]["title"], "x");

        let back: RerankingDocuments = serde_json::from_value(j).unwrap();
        assert_eq!(back, docs);
    }

    #[test]
    fn ranking_entry_uses_camel_case_score() {
        let e = RankingEntry {
            index: 2,
            relevance_score: 0.87,
        };
        let j = serde_json::to_value(e).unwrap();
        assert_eq!(j["index"], 2);
        assert_eq!(j["relevanceScore"], 0.87);
        assert!(j.get("relevance_score").is_none());

        let back: RankingEntry = serde_json::from_value(j).unwrap();
        assert_eq!(back, e);
    }
}