use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::json::JsonObject;
use crate::shared::{Headers, ProviderMetadata, ProviderOptions, ResponseInfo, Warning};
/// A model that reorders a set of documents according to their relevance to a query.
///
/// Implementations must be `Debug + Send + Sync` so they can be shared across threads and
/// logged; the single async entry point is [`RerankingModel::do_rerank`].
#[async_trait]
pub trait RerankingModel: std::fmt::Debug + Send + Sync {
    /// Version tag of the reranking-model specification implemented; `"v4"` unless overridden.
    fn specification_version(&self) -> &'static str {
        "v4"
    }

    /// Identifier of the provider serving this model.
    fn provider(&self) -> &str;

    /// Identifier of the model within its provider.
    fn model_id(&self) -> &str;

    /// Reranks `options.documents` against `options.query`.
    ///
    /// # Errors
    ///
    /// Returns the crate's error type when reranking fails; the concrete failure conditions are
    /// implementation-defined.
    async fn do_rerank(&self, options: RerankingOptions) -> Result<RerankingResult>;
}
/// Request handed to [`RerankingModel::do_rerank`].
///
/// Keys are serialized in camelCase (`documents`, `query`, `topN`, `headers`,
/// `providerOptions`). Every optional field is omitted from the output when `None` and
/// deserializes to `None` when its key is absent.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RerankingOptions {
    /// Documents to rerank.
    pub documents: RerankingDocuments,
    /// Query the documents are reranked against.
    pub query: String,
    /// Optional limit, serialized as `topN`.
    ///
    /// NOTE(review): presumably caps the number of returned entries — exact semantics are
    /// provider-defined; confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_n: Option<u32>,
    /// Optional extra headers (presumably forwarded with the provider request — confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub headers: Option<Headers>,
    /// Optional provider-specific options, serialized as `providerOptions`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<ProviderOptions>,
}
/// The documents to rerank, internally tagged by a `"type"` key.
///
/// Wire shape: `{"type": "text", "values": [...]}` or `{"type": "object", "values": [...]}`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum RerankingDocuments {
    /// Plain-text documents.
    #[serde(rename = "text")]
    Text { values: Vec<String> },
    /// Structured documents, one JSON object each.
    #[serde(rename = "object")]
    Object { values: Vec<JsonObject> },
}
/// Output of [`RerankingModel::do_rerank`].
///
/// Derives [`Default`] (every field is a `Vec` or an `Option`), so implementations can fill
/// only what they have, e.g. `RerankingResult { ranking, ..Default::default() }`.
#[derive(Debug, Clone, Default)]
pub struct RerankingResult {
    /// Scored documents.
    ///
    /// NOTE(review): ordering (e.g. by descending score) is not enforced by this type —
    /// confirm against provider implementations.
    pub ranking: Vec<RankingEntry>,
    /// Warnings reported alongside the ranking.
    pub warnings: Vec<Warning>,
    /// Optional provider-specific metadata.
    pub provider_metadata: Option<ProviderMetadata>,
    /// Optional information about the underlying response, when available.
    pub response: Option<ResponseInfo>,
}
/// One scored document in a [`RerankingResult`].
///
/// Serialized as `{"index": ..., "relevanceScore": ...}`.
#[derive(Clone, Copy, Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RankingEntry {
    /// Position of the scored document.
    ///
    /// NOTE(review): presumably an index into the request's `RerankingDocuments` values —
    /// confirm against providers.
    pub index: u32,
    /// Relevance score for the document; serialized as `relevanceScore`.
    pub relevance_score: f64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Minimal two-document text request shared by the option tests.
    fn text_options(top_n: Option<u32>) -> RerankingOptions {
        RerankingOptions {
            documents: RerankingDocuments::Text {
                values: vec!["a".into(), "b".into()],
            },
            query: "q".into(),
            top_n,
            headers: None,
            provider_options: None,
        }
    }

    #[test]
    fn options_roundtrip_text_documents() {
        let opts = text_options(Some(3));
        let j = serde_json::to_value(&opts).unwrap();
        assert_eq!(j["documents"]["type"], "text");
        assert_eq!(j["documents"]["values"][0], "a");
        assert_eq!(j["query"], "q");
        assert_eq!(j["topN"], 3);
        let back: RerankingOptions = serde_json::from_value(j).unwrap();
        assert_eq!(back.top_n, Some(3));
        assert_eq!(back.query, "q");
        assert_eq!(back.documents, opts.documents);
    }

    #[test]
    fn options_omit_unset_optional_fields() {
        // `skip_serializing_if = "Option::is_none"` must drop these keys entirely.
        let j = serde_json::to_value(text_options(None)).unwrap();
        let obj = j.as_object().unwrap();
        assert!(!obj.contains_key("topN"));
        assert!(!obj.contains_key("headers"));
        assert!(!obj.contains_key("providerOptions"));
    }

    #[test]
    fn options_deserialize_without_optional_fields() {
        // Absent optional keys must deserialize to `None` rather than failing.
        let back: RerankingOptions = serde_json::from_value(json!({
            "documents": { "type": "text", "values": [] },
            "query": "q"
        }))
        .unwrap();
        assert_eq!(back.top_n, None);
        assert!(back.headers.is_none());
        assert!(back.provider_options.is_none());
    }

    #[test]
    fn documents_object_variant_kebab_tagged() {
        let docs = RerankingDocuments::Object {
            values: vec![json!({ "title": "x" }).as_object().cloned().unwrap()],
        };
        let j = serde_json::to_value(&docs).unwrap();
        assert_eq!(j["type"], "object");
        assert_eq!(j["values"][0]["title"], "x");
        let back: RerankingDocuments = serde_json::from_value(j).unwrap();
        assert_eq!(back, docs);
    }

    #[test]
    fn documents_reject_unknown_type_tag() {
        let parsed = serde_json::from_value::<RerankingDocuments>(json!({
            "type": "image",
            "values": []
        }));
        assert!(parsed.is_err());
    }

    #[test]
    fn ranking_entry_uses_camel_case_score() {
        let e = RankingEntry {
            index: 2,
            relevance_score: 0.87,
        };
        let j = serde_json::to_value(e).unwrap();
        assert_eq!(j["index"], 2);
        assert_eq!(j["relevanceScore"], 0.87);
        assert!(j.get("relevance_score").is_none());
        let back: RankingEntry = serde_json::from_value(j).unwrap();
        assert_eq!(back, e);
    }
}