use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::Result;
/// A request to rerank `documents` against `query` using the ranking `model`.
#[derive(Debug, Clone)]
pub struct RankingRequest {
    /// Identifier of the ranking model.
    pub model: String,
    /// Query the documents are ranked against.
    pub query: String,
    /// Candidate documents to rank.
    pub documents: Vec<String>,
    /// Optional `top_k` limit; `None` means "not specified".
    pub top_k: Option<usize>,
    /// Whether results should include document text; `None` means "not specified".
    pub return_documents: Option<bool>,
    /// Optional per-document chunk limit; `None` means "not specified".
    pub max_chunks_per_doc: Option<usize>,
}
impl RankingRequest {
    /// Creates a request with every optional setting left unset (`None`).
    ///
    /// Each element of `documents` is converted into an owned `String`.
    pub fn new(
        model: impl Into<String>,
        query: impl Into<String>,
        documents: Vec<impl Into<String>>,
    ) -> Self {
        Self {
            model: model.into(),
            query: query.into(),
            documents: documents.into_iter().map(Into::into).collect(),
            top_k: None,
            return_documents: None,
            max_chunks_per_doc: None,
        }
    }
    /// Sets `top_k` to `Some(top_k)`.
    pub fn with_top_k(mut self, top_k: usize) -> Self {
        self.top_k = Some(top_k);
        self
    }
    /// Sets `return_documents` to `Some(true)`.
    ///
    /// Despite its name this does not add documents (those are passed to
    /// [`RankingRequest::new`]); it is shorthand for `with_return_documents(true)`.
    pub fn with_documents(self) -> Self {
        self.with_return_documents(true)
    }
    /// Explicitly sets `return_documents`.
    ///
    /// Unlike [`with_documents`](Self::with_documents), which can only turn the
    /// flag on, this also lets callers send an explicit `false`.
    pub fn with_return_documents(mut self, return_documents: bool) -> Self {
        self.return_documents = Some(return_documents);
        self
    }
    /// Sets `max_chunks_per_doc` to `Some(max_chunks)`.
    pub fn with_max_chunks_per_doc(mut self, max_chunks: usize) -> Self {
        self.max_chunks_per_doc = Some(max_chunks);
        self
    }
}
/// Output of a reranking call: the ranked results, the model that produced
/// them, and optional provider metadata.
#[derive(Debug, Clone)]
pub struct RankingResponse {
    /// Results, kept in the order they were supplied to [`RankingResponse::new`].
    pub results: Vec<RankedDocument>,
    /// Model that produced the ranking.
    pub model: String,
    /// Optional provider metadata; `None` after construction.
    pub meta: Option<RankingMeta>,
}
impl RankingResponse {
    /// Wraps `results` for `model`, with no metadata attached.
    pub fn new(model: impl Into<String>, results: Vec<RankedDocument>) -> Self {
        let model = model.into();
        Self {
            results,
            model,
            meta: None,
        }
    }
    /// Returns the first result, or `None` if there are no results.
    ///
    /// This is the best-scoring document only if the results are ordered
    /// best-first; the order is not checked here.
    pub fn top(&self) -> Option<&RankedDocument> {
        match self.results.as_slice() {
            [leader, ..] => Some(leader),
            [] => None,
        }
    }
    /// Returns each result's `index`, in result order.
    pub fn ranked_indices(&self) -> Vec<usize> {
        let mut order = Vec::with_capacity(self.results.len());
        for ranked in &self.results {
            order.push(ranked.index);
        }
        order
    }
}
/// A single reranked entry: the document's `index`, its relevance `score`,
/// and, optionally, the document text.
#[derive(Debug, Clone)]
pub struct RankedDocument {
    /// Index of the document (presumably into the request's `documents`; confirm per provider).
    pub index: usize,
    /// Relevance score assigned by the provider.
    pub score: f32,
    /// Document text, when attached via [`RankedDocument::with_document`].
    pub document: Option<String>,
}
impl RankedDocument {
    /// Creates an entry without document text.
    pub fn new(index: usize, score: f32) -> Self {
        Self {
            document: None,
            index,
            score,
        }
    }
    /// Attaches `document` as the entry's text, replacing any previous value.
    pub fn with_document(self, document: impl Into<String>) -> Self {
        Self {
            document: Some(document.into()),
            ..self
        }
    }
}
/// Optional metadata attached to a [`RankingResponse`]; every field defaults to `None`.
#[derive(Debug, Clone, Default)]
pub struct RankingMeta {
    /// Billed units for the call, when known (the unit is provider-specific — confirm per provider).
    pub billed_units: Option<u64>,
    /// API version string, when known.
    pub api_version: Option<String>,
}
/// A backend that reranks documents against a query.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait RankingProvider: Send + Sync {
    /// Short identifier of this provider.
    fn name(&self) -> &str;
    /// Ranks `request.documents` against `request.query`.
    ///
    /// # Errors
    /// Returns the crate's error type when ranking fails.
    async fn rank(&self, request: RankingRequest) -> Result<RankingResponse>;
    /// Default ranking model for this provider, if any. Returns `None` unless overridden.
    fn default_ranking_model(&self) -> Option<&str> {
        None
    }
    /// Maximum number of documents per request. Defaults to 1000.
    fn max_documents(&self) -> usize {
        1000
    }
    /// Maximum query length. Defaults to 2048.
    ///
    /// NOTE(review): the unit is not stated here; presumably tokens, like
    /// `RankingModelInfo::max_query_tokens` — confirm.
    fn max_query_length(&self) -> usize {
        2048
    }
}
/// A moderation request: a text `input` for `model`, plus an optional list of
/// typed `inputs`.
///
/// How a provider combines `input` and `inputs` is not defined here.
#[derive(Debug, Clone)]
pub struct ModerationRequest {
    /// Identifier of the moderation model.
    pub model: String,
    /// Text to moderate.
    pub input: String,
    /// Optional typed inputs; `None` after [`ModerationRequest::new`].
    pub inputs: Option<Vec<ModerationInput>>,
}
impl ModerationRequest {
    /// Creates a text-only request; `inputs` starts out as `None`.
    pub fn new(model: impl Into<String>, input: impl Into<String>) -> Self {
        let (model, input) = (model.into(), input.into());
        Self {
            model,
            input,
            inputs: None,
        }
    }
    /// Sets the typed inputs, replacing any previously set list.
    pub fn with_inputs(self, inputs: Vec<ModerationInput>) -> Self {
        Self {
            inputs: Some(inputs),
            ..self
        }
    }
}
/// One item to be moderated.
#[derive(Debug, Clone)]
pub enum ModerationInput {
    /// Plain text.
    Text(String),
    /// An image referenced by URL.
    ImageUrl(String),
    /// An inline image: base64-encoded `data` plus its `media_type`
    /// (presumably a MIME type such as `image/png` — confirm with the provider).
    ImageBase64 { data: String, media_type: String },
}
/// Result of a moderation call: the overall verdict plus per-category flags
/// and scores.
#[derive(Debug, Clone)]
pub struct ModerationResponse {
    /// Overall verdict.
    pub flagged: bool,
    /// Per-category flags.
    pub categories: ModerationCategories,
    /// Per-category scores.
    pub category_scores: ModerationScores,
    /// Model that produced the result; empty until set with [`ModerationResponse::with_model`].
    pub model: String,
}
impl ModerationResponse {
    /// Creates a response with the given verdict. Categories and scores start
    /// as their `Default` values, and the model name starts empty.
    pub fn new(flagged: bool) -> Self {
        Self {
            model: String::new(),
            category_scores: ModerationScores::default(),
            categories: ModerationCategories::default(),
            flagged,
        }
    }
    /// Sets the model name.
    pub fn with_model(self, model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..self
        }
    }
    /// Replaces the per-category flags.
    pub fn with_categories(self, categories: ModerationCategories) -> Self {
        Self { categories, ..self }
    }
    /// Replaces the per-category scores.
    pub fn with_scores(self, scores: ModerationScores) -> Self {
        Self {
            category_scores: scores,
            ..self
        }
    }
    /// Returns the names of every category flagged in `self.categories`.
    ///
    /// Names use the same spellings as the serde names of
    /// `ModerationCategories` (for example `"self-harm/intent"`). They are
    /// returned in the fixed order of the table below.
    pub fn flagged_categories(&self) -> Vec<&'static str> {
        let c = &self.categories;
        let table: [(bool, &'static str); 13] = [
            (c.hate, "hate"),
            (c.hate_threatening, "hate/threatening"),
            (c.harassment, "harassment"),
            (c.harassment_threatening, "harassment/threatening"),
            (c.self_harm, "self-harm"),
            (c.self_harm_intent, "self-harm/intent"),
            (c.self_harm_instructions, "self-harm/instructions"),
            (c.sexual, "sexual"),
            (c.sexual_minors, "sexual/minors"),
            (c.violence, "violence"),
            (c.violence_graphic, "violence/graphic"),
            (c.illicit, "illicit"),
            (c.illicit_violent, "illicit/violent"),
        ];
        table
            .iter()
            .filter(|(is_flagged, _)| *is_flagged)
            .map(|&(_, name)| name)
            .collect()
    }
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModerationCategories {
pub hate: bool,
#[serde(rename = "hate/threatening")]
pub hate_threatening: bool,
pub harassment: bool,
#[serde(rename = "harassment/threatening")]
pub harassment_threatening: bool,
#[serde(rename = "self-harm")]
pub self_harm: bool,
#[serde(rename = "self-harm/intent")]
pub self_harm_intent: bool,
#[serde(rename = "self-harm/instructions")]
pub self_harm_instructions: bool,
pub sexual: bool,
#[serde(rename = "sexual/minors")]
pub sexual_minors: bool,
pub violence: bool,
#[serde(rename = "violence/graphic")]
pub violence_graphic: bool,
#[serde(default)]
pub illicit: bool,
#[serde(default, rename = "illicit/violent")]
pub illicit_violent: bool,
}
/// Per-category moderation scores, (de)serialized with the same category names
/// used by `ModerationCategories` (`"hate/threatening"`, `"self-harm"`, ...).
///
/// `illicit` and `illicit/violent` default to `0.0` when missing from the input.
/// All other fields are required when deserializing.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModerationScores {
    pub hate: f32,
    #[serde(rename = "hate/threatening")]
    pub hate_threatening: f32,
    pub harassment: f32,
    #[serde(rename = "harassment/threatening")]
    pub harassment_threatening: f32,
    #[serde(rename = "self-harm")]
    pub self_harm: f32,
    #[serde(rename = "self-harm/intent")]
    pub self_harm_intent: f32,
    #[serde(rename = "self-harm/instructions")]
    pub self_harm_instructions: f32,
    pub sexual: f32,
    #[serde(rename = "sexual/minors")]
    pub sexual_minors: f32,
    pub violence: f32,
    #[serde(rename = "violence/graphic")]
    pub violence_graphic: f32,
    // Optional on input: defaults to 0.0 when absent.
    #[serde(default)]
    pub illicit: f32,
    // Optional on input: defaults to 0.0 when absent.
    #[serde(default, rename = "illicit/violent")]
    pub illicit_violent: f32,
}
/// A backend that performs content moderation.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait ModerationProvider: Send + Sync {
    /// Short identifier of this provider.
    fn name(&self) -> &str;
    /// Moderates the content in `request`.
    ///
    /// # Errors
    /// Returns the crate's error type when moderation fails.
    async fn moderate(&self, request: ModerationRequest) -> Result<ModerationResponse>;
    /// Default moderation model for this provider, if any. Returns `None` unless overridden.
    fn default_moderation_model(&self) -> Option<&str> {
        None
    }
    /// Whether non-text inputs are supported. Returns `false` unless overridden.
    ///
    /// NOTE(review): presumably this covers the image variants of `ModerationInput` — confirm.
    fn supports_multimodal(&self) -> bool {
        false
    }
}
/// A request to classify `input` against the candidate `labels` using `model`.
#[derive(Debug, Clone)]
pub struct ClassificationRequest {
    /// Identifier of the classification model.
    pub model: String,
    /// Text to classify.
    pub input: String,
    /// Candidate labels.
    pub labels: Vec<String>,
    /// Whether multiple labels may apply; `None` means "not specified".
    pub multi_label: Option<bool>,
    /// Optional labelled examples; `None` means "not specified".
    pub examples: Option<Vec<ClassificationExample>>,
}
impl ClassificationRequest {
    /// Creates a request with every optional setting left unset (`None`).
    ///
    /// Each element of `labels` is converted into an owned `String`.
    pub fn new(
        model: impl Into<String>,
        input: impl Into<String>,
        labels: Vec<impl Into<String>>,
    ) -> Self {
        Self {
            model: model.into(),
            input: input.into(),
            labels: labels.into_iter().map(Into::into).collect(),
            multi_label: None,
            examples: None,
        }
    }
    /// Sets `multi_label` to `Some(true)`; shorthand for `with_multi_label_enabled(true)`.
    pub fn with_multi_label(self) -> Self {
        self.with_multi_label_enabled(true)
    }
    /// Explicitly sets `multi_label`.
    ///
    /// Unlike [`with_multi_label`](Self::with_multi_label), which can only turn
    /// the flag on, this also lets callers send an explicit `false`.
    pub fn with_multi_label_enabled(mut self, enabled: bool) -> Self {
        self.multi_label = Some(enabled);
        self
    }
    /// Sets the labelled examples, replacing any previously set list.
    pub fn with_examples(mut self, examples: Vec<ClassificationExample>) -> Self {
        self.examples = Some(examples);
        self
    }
}
/// A labelled sample (`text` paired with its `label`) that can be supplied
/// through `ClassificationRequest::with_examples`.
#[derive(Debug, Clone)]
pub struct ClassificationExample {
    /// Example text.
    pub text: String,
    /// Label assigned to `text`.
    pub label: String,
}
impl ClassificationExample {
    /// Builds an example from any string-like text and label.
    pub fn new(text: impl Into<String>, label: impl Into<String>) -> Self {
        let text: String = text.into();
        let label: String = label.into();
        Self { text, label }
    }
}
/// Output of a classification call: the predictions and the model that produced them.
#[derive(Debug, Clone)]
pub struct ClassificationResponse {
    /// Predictions, kept in the order they were supplied to [`ClassificationResponse::new`].
    pub predictions: Vec<ClassificationPrediction>,
    /// Model that produced the predictions.
    pub model: String,
}
impl ClassificationResponse {
    /// Wraps `predictions` for `model`.
    pub fn new(model: impl Into<String>, predictions: Vec<ClassificationPrediction>) -> Self {
        let model = model.into();
        Self { predictions, model }
    }
    /// Returns the first prediction, or `None` if there are none.
    ///
    /// This is the highest-scoring prediction only if predictions are ordered
    /// best-first; the order is not checked here.
    pub fn top(&self) -> Option<&ClassificationPrediction> {
        self.predictions.split_first().map(|(head, _rest)| head)
    }
    /// Label of the first prediction (see [`ClassificationResponse::top`]).
    pub fn label(&self) -> Option<&str> {
        self.top().map(|head| head.label.as_str())
    }
    /// Score of the first prediction whose label equals `label` exactly, if any.
    pub fn score_for(&self, label: &str) -> Option<f32> {
        self.predictions.iter().find_map(|candidate| {
            if candidate.label == label {
                Some(candidate.score)
            } else {
                None
            }
        })
    }
}
/// One predicted `label` together with its `score`.
#[derive(Debug, Clone)]
pub struct ClassificationPrediction {
    /// Predicted label.
    pub label: String,
    /// Score assigned to the label.
    pub score: f32,
}
impl ClassificationPrediction {
    /// Builds a prediction from any string-like label and a score.
    pub fn new(label: impl Into<String>, score: f32) -> Self {
        let label = label.into();
        ClassificationPrediction { label, score }
    }
}
/// A backend that classifies text against a set of labels.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait ClassificationProvider: Send + Sync {
    /// Short identifier of this provider.
    fn name(&self) -> &str;
    /// Classifies `request.input` against `request.labels`.
    ///
    /// # Errors
    /// Returns the crate's error type when classification fails.
    async fn classify(&self, request: ClassificationRequest) -> Result<ClassificationResponse>;
    /// Default classification model for this provider, if any. Returns `None` unless overridden.
    fn default_classification_model(&self) -> Option<&str> {
        None
    }
    /// Maximum number of labels per request. Defaults to 100.
    fn max_labels(&self) -> usize {
        100
    }
    /// Whether few-shot examples are supported. Returns `false` unless overridden.
    ///
    /// NOTE(review): presumably this indicates whether `ClassificationRequest::examples`
    /// is honoured — confirm.
    fn supports_few_shot(&self) -> bool {
        false
    }
}
/// Static description of a reranking model, as listed in `RANKING_MODELS`.
#[derive(Debug, Clone)]
pub struct RankingModelInfo {
    /// Model identifier; `get_ranking_model_info` matches it exactly.
    pub id: &'static str,
    /// Provider name, e.g. `"cohere"`.
    pub provider: &'static str,
    /// Maximum number of documents per request.
    pub max_documents: usize,
    /// Maximum query length, in tokens.
    pub max_query_tokens: usize,
    /// Listed price per 1,000 searches (currency not stated here — presumably USD; confirm).
    pub price_per_1k_searches: f64,
}
/// Static catalogue of known reranking models, searched by `get_ranking_model_info`.
///
/// Lookups return the first entry with a matching `id`, so ids should be unique.
///
/// NOTE(review): the `price_per_1k_searches` values for the `voyage` and `jina`
/// entries look like they may be per-token prices (e.g. per 1M tokens) rather
/// than per-1k-search prices. Confirm against the providers' pricing pages
/// before using them for cost estimates.
pub static RANKING_MODELS: &[RankingModelInfo] = &[
    // Cohere models.
    RankingModelInfo {
        id: "rerank-english-v3.0",
        provider: "cohere",
        max_documents: 1000,
        max_query_tokens: 2048,
        price_per_1k_searches: 2.00,
    },
    RankingModelInfo {
        id: "rerank-multilingual-v3.0",
        provider: "cohere",
        max_documents: 1000,
        max_query_tokens: 2048,
        price_per_1k_searches: 2.00,
    },
    // Voyage models.
    RankingModelInfo {
        id: "rerank-2",
        provider: "voyage",
        max_documents: 1000,
        max_query_tokens: 4000,
        price_per_1k_searches: 0.05,
    },
    RankingModelInfo {
        id: "rerank-lite-2",
        provider: "voyage",
        max_documents: 1000,
        max_query_tokens: 4000,
        price_per_1k_searches: 0.02,
    },
    // Jina models.
    RankingModelInfo {
        id: "jina-reranker-v2-base-multilingual",
        provider: "jina",
        max_documents: 500,
        max_query_tokens: 8192,
        price_per_1k_searches: 0.02,
    },
];
/// Static description of a moderation model, as listed in `MODERATION_MODELS`.
#[derive(Debug, Clone)]
pub struct ModerationModelInfo {
    /// Model identifier; `get_moderation_model_info` matches it exactly.
    pub id: &'static str,
    /// Provider name, e.g. `"openai"`.
    pub provider: &'static str,
    /// Whether the model accepts image inputs.
    pub supports_images: bool,
    /// Listed price per 1,000 requests (currency not stated here — presumably USD; confirm).
    pub price_per_1k_requests: f64,
}
/// Static catalogue of known moderation models, searched by `get_moderation_model_info`.
///
/// Lookups return the first entry with a matching `id`, so ids should be unique.
/// NOTE(review): a price of `0.0` presumably means there is no per-request
/// charge — confirm with the provider.
pub static MODERATION_MODELS: &[ModerationModelInfo] = &[
    // Accepts images as well as text.
    ModerationModelInfo {
        id: "omni-moderation-latest",
        provider: "openai",
        supports_images: true,
        price_per_1k_requests: 0.0,
    },
    // Text-only models.
    ModerationModelInfo {
        id: "text-moderation-latest",
        provider: "openai",
        supports_images: false,
        price_per_1k_requests: 0.0,
    },
    ModerationModelInfo {
        id: "text-moderation-stable",
        provider: "openai",
        supports_images: false,
        price_per_1k_requests: 0.0,
    },
];
pub fn get_ranking_model_info(model_id: &str) -> Option<&'static RankingModelInfo> {
RANKING_MODELS.iter().find(|m| m.id == model_id)
}
pub fn get_moderation_model_info(model_id: &str) -> Option<&'static ModerationModelInfo> {
MODERATION_MODELS.iter().find(|m| m.id == model_id)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ranking_request_builder() {
        let request = RankingRequest::new(
            "rerank-english-v3.0",
            "What is the capital?",
            vec!["Paris is the capital", "Berlin is a city"],
        )
        .with_top_k(5)
        .with_documents();
        assert_eq!(request.model, "rerank-english-v3.0");
        assert_eq!(request.query, "What is the capital?");
        assert_eq!(request.documents.len(), 2);
        assert_eq!(request.top_k, Some(5));
        assert_eq!(request.return_documents, Some(true));
    }

    #[test]
    fn test_ranking_request_defaults() {
        // A fresh request leaves every optional setting unset.
        let request = RankingRequest::new("m", "q", vec!["a"]);
        assert_eq!(request.top_k, None);
        assert_eq!(request.return_documents, None);
        assert_eq!(request.max_chunks_per_doc, None);
        let request = request.with_max_chunks_per_doc(3);
        assert_eq!(request.max_chunks_per_doc, Some(3));
    }

    #[test]
    fn test_ranking_response() {
        let results = vec![
            RankedDocument::new(1, 0.95).with_document("Top doc"),
            RankedDocument::new(0, 0.8),
        ];
        let response = RankingResponse::new("rerank-english-v3.0", results);
        assert_eq!(response.top().unwrap().score, 0.95);
        assert_eq!(response.ranked_indices(), vec![1, 0]);
    }

    #[test]
    fn test_ranking_response_empty() {
        let response = RankingResponse::new("m", Vec::new());
        assert!(response.top().is_none());
        assert!(response.ranked_indices().is_empty());
        assert!(response.meta.is_none());
    }

    #[test]
    fn test_moderation_request() {
        let request = ModerationRequest::new("omni-moderation-latest", "Some text to check");
        assert_eq!(request.model, "omni-moderation-latest");
        assert_eq!(request.input, "Some text to check");
    }

    #[test]
    fn test_moderation_request_with_inputs() {
        let request = ModerationRequest::new("omni-moderation-latest", "text").with_inputs(vec![
            ModerationInput::Text("more text".into()),
            ModerationInput::ImageUrl("https://example.com/image.png".into()),
        ]);
        assert_eq!(request.inputs.as_ref().map(Vec::len), Some(2));
        assert_eq!(request.input, "text");
    }

    #[test]
    fn test_moderation_response() {
        let categories = ModerationCategories {
            hate: true,
            violence: true,
            ..Default::default()
        };
        let response = ModerationResponse::new(true)
            .with_model("omni-moderation-latest")
            .with_categories(categories);
        assert!(response.flagged);
        let flagged = response.flagged_categories();
        assert!(flagged.contains(&"hate"));
        assert!(flagged.contains(&"violence"));
        assert!(!flagged.contains(&"sexual"));
    }

    #[test]
    fn test_flagged_categories_order_and_names() {
        assert!(ModerationResponse::new(false).flagged_categories().is_empty());
        // Names use the slash/hyphen spellings and come back in declaration order.
        let categories = ModerationCategories {
            self_harm_intent: true,
            sexual_minors: true,
            illicit_violent: true,
            ..Default::default()
        };
        let response = ModerationResponse::new(true).with_categories(categories);
        assert_eq!(
            response.flagged_categories(),
            vec!["self-harm/intent", "sexual/minors", "illicit/violent"]
        );
    }

    #[test]
    fn test_moderation_response_scores() {
        let scores = ModerationScores {
            violence: 0.75,
            ..Default::default()
        };
        let response = ModerationResponse::new(false).with_scores(scores);
        assert_eq!(response.category_scores.violence, 0.75);
        assert_eq!(response.category_scores.hate, 0.0);
        assert!(response.model.is_empty());
    }

    #[test]
    fn test_classification_request_builder() {
        let request = ClassificationRequest::new(
            "embed-english-v3.0",
            "I love this product!",
            vec!["positive", "negative", "neutral"],
        )
        .with_multi_label()
        .with_examples(vec![
            ClassificationExample::new("Great!", "positive"),
            ClassificationExample::new("Terrible", "negative"),
        ]);
        assert_eq!(request.model, "embed-english-v3.0");
        assert_eq!(request.input, "I love this product!");
        assert_eq!(request.labels.len(), 3);
        assert_eq!(request.multi_label, Some(true));
        assert_eq!(request.examples.as_ref().unwrap().len(), 2);
    }

    #[test]
    fn test_classification_response() {
        let predictions = vec![
            ClassificationPrediction::new("positive", 0.92),
            ClassificationPrediction::new("neutral", 0.06),
            ClassificationPrediction::new("negative", 0.02),
        ];
        let response = ClassificationResponse::new("model", predictions);
        assert_eq!(response.label(), Some("positive"));
        assert_eq!(response.top().unwrap().score, 0.92);
        assert_eq!(response.score_for("neutral"), Some(0.06));
        assert_eq!(response.score_for("unknown"), None);
    }

    #[test]
    fn test_classification_response_empty() {
        let response = ClassificationResponse::new("model", Vec::new());
        assert!(response.top().is_none());
        assert_eq!(response.label(), None);
        assert_eq!(response.score_for("anything"), None);
    }

    #[test]
    fn test_ranking_model_registry() {
        let model = get_ranking_model_info("rerank-english-v3.0");
        assert!(model.is_some());
        let model = model.unwrap();
        assert_eq!(model.provider, "cohere");
        assert_eq!(model.max_documents, 1000);
    }

    #[test]
    fn test_moderation_model_registry() {
        let model = get_moderation_model_info("omni-moderation-latest");
        assert!(model.is_some());
        let model = model.unwrap();
        assert_eq!(model.provider, "openai");
        assert!(model.supports_images);
    }

    #[test]
    fn test_unknown_models_are_not_found() {
        assert!(get_ranking_model_info("no-such-model").is_none());
        assert!(get_moderation_model_info("no-such-model").is_none());
        // Lookup is by exact id, so case matters.
        assert!(get_ranking_model_info("RERANK-ENGLISH-V3.0").is_none());
    }

    #[test]
    fn test_registries_are_well_formed() {
        // Lookups return the first match, so a duplicate id would shadow later entries.
        for (i, model) in RANKING_MODELS.iter().enumerate() {
            assert!(
                RANKING_MODELS[i + 1..].iter().all(|other| other.id != model.id),
                "duplicate ranking model id: {}",
                model.id
            );
            assert!(model.max_documents > 0 && model.max_query_tokens > 0);
        }
        for (i, model) in MODERATION_MODELS.iter().enumerate() {
            assert!(
                MODERATION_MODELS[i + 1..].iter().all(|other| other.id != model.id),
                "duplicate moderation model id: {}",
                model.id
            );
        }
    }
}