llmsdk_provider/
reranking_model.rs1use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
9use crate::error::Result;
10use crate::json::JsonObject;
11use crate::shared::{Headers, ProviderMetadata, ProviderOptions, ResponseInfo, Warning};
12
13#[async_trait]
17pub trait RerankingModel: Send + Sync + std::fmt::Debug {
18 fn provider(&self) -> &str;
20
21 fn model_id(&self) -> &str;
23
24 fn specification_version(&self) -> &'static str {
26 "v4"
27 }
28
29 async fn do_rerank(&self, options: RerankingOptions) -> Result<RerankingResult>;
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct RerankingOptions {
43 pub documents: RerankingDocuments,
45 pub query: String,
47 #[serde(default, rename = "topN", skip_serializing_if = "Option::is_none")]
49 pub top_n: Option<u32>,
50 #[serde(default, skip_serializing_if = "Option::is_none")]
52 pub headers: Option<Headers>,
53 #[serde(
55 default,
56 rename = "providerOptions",
57 skip_serializing_if = "Option::is_none"
58 )]
59 pub provider_options: Option<ProviderOptions>,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
66#[serde(tag = "type", rename_all = "kebab-case")]
67pub enum RerankingDocuments {
68 Text {
70 values: Vec<String>,
72 },
73 Object {
75 values: Vec<JsonObject>,
77 },
78}
79
80#[derive(Debug, Clone)]
84pub struct RerankingResult {
85 pub ranking: Vec<RankingEntry>,
89 pub warnings: Vec<Warning>,
91 pub provider_metadata: Option<ProviderMetadata>,
93 pub response: Option<ResponseInfo>,
95}
96
97#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
99pub struct RankingEntry {
100 pub index: u32,
102 #[serde(rename = "relevanceScore")]
104 pub relevance_score: f64,
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Text documents and `topN` serialize with the expected keys and round-trip.
    #[test]
    fn options_roundtrip_text_documents() {
        let opts = RerankingOptions {
            documents: RerankingDocuments::Text {
                values: vec!["a".into(), "b".into()],
            },
            query: "q".into(),
            top_n: Some(3),
            headers: None,
            provider_options: None,
        };
        let j = serde_json::to_value(&opts).unwrap();
        assert_eq!(j["documents"]["type"], "text");
        assert_eq!(j["documents"]["values"][0], "a");
        assert_eq!(j["topN"], 3);
        let back: RerankingOptions = serde_json::from_value(j).unwrap();
        assert_eq!(back.top_n, Some(3));
        assert_eq!(back.documents, opts.documents);
        assert_eq!(back.query, "q");
    }

    /// `skip_serializing_if` must drop every `None` optional field from the output.
    #[test]
    fn options_omit_unset_optional_fields() {
        let opts = RerankingOptions {
            documents: RerankingDocuments::Text { values: vec![] },
            query: "q".into(),
            top_n: None,
            headers: None,
            provider_options: None,
        };
        let j = serde_json::to_value(&opts).unwrap();
        let obj = j.as_object().unwrap();
        assert!(!obj.contains_key("topN"));
        assert!(!obj.contains_key("headers"));
        assert!(!obj.contains_key("providerOptions"));
        assert_eq!(j["query"], "q");
    }

    /// Optional fields may be absent on input; they deserialize to `None`.
    #[test]
    fn options_deserialize_without_optional_fields() {
        let back: RerankingOptions = serde_json::from_value(json!({
            "documents": { "type": "text", "values": ["x"] },
            "query": "q"
        }))
        .unwrap();
        assert_eq!(
            back.documents,
            RerankingDocuments::Text {
                values: vec!["x".to_string()]
            }
        );
        assert_eq!(back.query, "q");
        assert!(back.top_n.is_none());
        assert!(back.headers.is_none());
        assert!(back.provider_options.is_none());
    }

    /// The object variant is tagged `"object"` and round-trips losslessly.
    #[test]
    fn documents_object_variant_kebab_tagged() {
        let docs = RerankingDocuments::Object {
            values: vec![json!({ "title": "x" }).as_object().cloned().unwrap()],
        };
        let j = serde_json::to_value(&docs).unwrap();
        assert_eq!(j["type"], "object");
        assert_eq!(j["values"][0]["title"], "x");
        let back: RerankingDocuments = serde_json::from_value(j).unwrap();
        assert_eq!(back, docs);
    }

    /// `relevance_score` is exposed on the wire as `relevanceScore` in both directions.
    #[test]
    fn ranking_entry_uses_camel_case_score() {
        let e = RankingEntry {
            index: 2,
            relevance_score: 0.87,
        };
        let j = serde_json::to_value(e).unwrap();
        assert_eq!(j["index"], 2);
        assert_eq!(j["relevanceScore"], 0.87);
        let back: RankingEntry = serde_json::from_value(j).unwrap();
        assert_eq!(back, e);
    }
}