1use async_trait::async_trait;
67use serde::{Deserialize, Serialize};
68
69use crate::error::Result;
70
/// A request asking a ranking provider to order `documents` by relevance to `query`.
///
/// Built with [`RankingRequest::new`] and refined with the `with_*` builder methods;
/// every optional setting starts out as `None`.
#[derive(Debug, Clone)]
pub struct RankingRequest {
    /// Identifier of the ranking model, e.g. `"rerank-english-v3.0"`.
    pub model: String,
    /// Text the documents are scored against.
    pub query: String,
    /// Candidate documents, in the order the caller supplied them.
    pub documents: Vec<String>,
    /// Maximum number of ranked results wanted; `None` when not set.
    pub top_k: Option<usize>,
    /// `Some(true)` requests that document text be included in the results;
    /// `None` when not set.
    pub return_documents: Option<bool>,
    /// Per-document chunk limit; `None` when not set.
    pub max_chunks_per_doc: Option<usize>,
}

impl RankingRequest {
    /// Builds a request with every optional setting left unset.
    pub fn new(
        model: impl Into<String>,
        query: impl Into<String>,
        documents: Vec<impl Into<String>>,
    ) -> Self {
        Self {
            model: model.into(),
            query: query.into(),
            documents: documents.into_iter().map(Into::into).collect(),
            top_k: None,
            return_documents: None,
            max_chunks_per_doc: None,
        }
    }

    /// Limits the response to at most `top_k` results.
    pub fn with_top_k(self, top_k: usize) -> Self {
        Self {
            top_k: Some(top_k),
            ..self
        }
    }

    /// Sets `return_documents` to `Some(true)` so results can carry document text.
    pub fn with_documents(self) -> Self {
        Self {
            return_documents: Some(true),
            ..self
        }
    }

    /// Sets the per-document chunk limit.
    pub fn with_max_chunks_per_doc(self, max_chunks: usize) -> Self {
        Self {
            max_chunks_per_doc: Some(max_chunks),
            ..self
        }
    }
}
127
/// Result of a ranking call.
#[derive(Debug, Clone)]
pub struct RankingResponse {
    /// Ranked results; [`RankingResponse::top`] treats the first entry as the best match.
    pub results: Vec<RankedDocument>,
    /// Model that produced the ranking.
    pub model: String,
    /// Optional provider metadata; `None` after [`RankingResponse::new`].
    pub meta: Option<RankingMeta>,
}

impl RankingResponse {
    /// Wraps `results` produced by `model`, with no metadata attached.
    pub fn new(model: impl Into<String>, results: Vec<RankedDocument>) -> Self {
        let model = model.into();
        Self {
            results,
            model,
            meta: None,
        }
    }

    /// Returns the first result (assumed to be the best-ranked one), or `None`
    /// when there are no results.
    pub fn top(&self) -> Option<&RankedDocument> {
        match self.results.as_slice() {
            [best, ..] => Some(best),
            [] => None,
        }
    }

    /// Returns each result's `index`, in result order.
    pub fn ranked_indices(&self) -> Vec<usize> {
        self.results
            .iter()
            .map(|doc| doc.index)
            .collect::<Vec<usize>>()
    }
}

/// One entry of a ranking result.
#[derive(Debug, Clone)]
pub struct RankedDocument {
    /// Index of the document in the request (presumably into `RankingRequest::documents`).
    pub index: usize,
    /// Relevance score reported for this document.
    pub score: f32,
    /// Document text, when it was returned.
    pub document: Option<String>,
}

impl RankedDocument {
    /// Creates an entry without document text.
    pub fn new(index: usize, score: f32) -> Self {
        Self {
            document: None,
            index,
            score,
        }
    }

    /// Attaches the document text to this entry.
    pub fn with_document(self, document: impl Into<String>) -> Self {
        Self {
            document: Some(document.into()),
            ..self
        }
    }
}

/// Optional metadata attached to a ranking response.
#[derive(Debug, Clone, Default)]
pub struct RankingMeta {
    /// Billed units reported by the provider, if any.
    pub billed_units: Option<u64>,
    /// API version reported by the provider, if any.
    pub api_version: Option<String>,
}
196
197#[async_trait]
199pub trait RankingProvider: Send + Sync {
200 fn name(&self) -> &str;
202
203 async fn rank(&self, request: RankingRequest) -> Result<RankingResponse>;
205
206 fn default_ranking_model(&self) -> Option<&str> {
208 None
209 }
210
211 fn max_documents(&self) -> usize {
213 1000
214 }
215
216 fn max_query_length(&self) -> usize {
218 2048
219 }
220}
221
/// A request to check content with a moderation model.
#[derive(Debug, Clone)]
pub struct ModerationRequest {
    /// Moderation model identifier.
    pub model: String,
    /// Text to moderate.
    pub input: String,
    /// Optional structured inputs (text and/or images); `None` after [`ModerationRequest::new`].
    pub inputs: Option<Vec<ModerationInput>>,
}

impl ModerationRequest {
    /// Builds a text-only request.
    pub fn new(model: impl Into<String>, input: impl Into<String>) -> Self {
        let model = model.into();
        let input = input.into();
        Self {
            model,
            input,
            inputs: None,
        }
    }

    /// Attaches structured inputs, replacing any previously attached list.
    pub fn with_inputs(self, inputs: Vec<ModerationInput>) -> Self {
        Self {
            inputs: Some(inputs),
            ..self
        }
    }
}

/// A single piece of content to moderate.
#[derive(Debug, Clone)]
pub enum ModerationInput {
    /// Plain text.
    Text(String),
    /// Image referenced by URL.
    ImageUrl(String),
    /// Inline image given as base64 `data` plus its `media_type`.
    ImageBase64 { data: String, media_type: String },
}
264
265#[derive(Debug, Clone)]
267pub struct ModerationResponse {
268 pub flagged: bool,
270 pub categories: ModerationCategories,
272 pub category_scores: ModerationScores,
274 pub model: String,
276}
277
278impl ModerationResponse {
279 pub fn new(flagged: bool) -> Self {
281 Self {
282 flagged,
283 categories: ModerationCategories::default(),
284 category_scores: ModerationScores::default(),
285 model: String::new(),
286 }
287 }
288
289 pub fn with_model(mut self, model: impl Into<String>) -> Self {
291 self.model = model.into();
292 self
293 }
294
295 pub fn with_categories(mut self, categories: ModerationCategories) -> Self {
297 self.categories = categories;
298 self
299 }
300
301 pub fn with_scores(mut self, scores: ModerationScores) -> Self {
303 self.category_scores = scores;
304 self
305 }
306
307 pub fn flagged_categories(&self) -> Vec<&'static str> {
309 let mut result = Vec::new();
310 if self.categories.hate {
311 result.push("hate");
312 }
313 if self.categories.hate_threatening {
314 result.push("hate/threatening");
315 }
316 if self.categories.harassment {
317 result.push("harassment");
318 }
319 if self.categories.harassment_threatening {
320 result.push("harassment/threatening");
321 }
322 if self.categories.self_harm {
323 result.push("self-harm");
324 }
325 if self.categories.self_harm_intent {
326 result.push("self-harm/intent");
327 }
328 if self.categories.self_harm_instructions {
329 result.push("self-harm/instructions");
330 }
331 if self.categories.sexual {
332 result.push("sexual");
333 }
334 if self.categories.sexual_minors {
335 result.push("sexual/minors");
336 }
337 if self.categories.violence {
338 result.push("violence");
339 }
340 if self.categories.violence_graphic {
341 result.push("violence/graphic");
342 }
343 if self.categories.illicit {
344 result.push("illicit");
345 }
346 if self.categories.illicit_violent {
347 result.push("illicit/violent");
348 }
349 result
350 }
351}
352
353#[derive(Debug, Clone, Default, Serialize, Deserialize)]
355pub struct ModerationCategories {
356 pub hate: bool,
358 #[serde(rename = "hate/threatening")]
360 pub hate_threatening: bool,
361 pub harassment: bool,
363 #[serde(rename = "harassment/threatening")]
365 pub harassment_threatening: bool,
366 #[serde(rename = "self-harm")]
368 pub self_harm: bool,
369 #[serde(rename = "self-harm/intent")]
371 pub self_harm_intent: bool,
372 #[serde(rename = "self-harm/instructions")]
374 pub self_harm_instructions: bool,
375 pub sexual: bool,
377 #[serde(rename = "sexual/minors")]
379 pub sexual_minors: bool,
380 pub violence: bool,
382 #[serde(rename = "violence/graphic")]
384 pub violence_graphic: bool,
385 #[serde(default)]
387 pub illicit: bool,
388 #[serde(default, rename = "illicit/violent")]
390 pub illicit_violent: bool,
391}
392
393#[derive(Debug, Clone, Default, Serialize, Deserialize)]
395pub struct ModerationScores {
396 pub hate: f32,
398 #[serde(rename = "hate/threatening")]
400 pub hate_threatening: f32,
401 pub harassment: f32,
403 #[serde(rename = "harassment/threatening")]
405 pub harassment_threatening: f32,
406 #[serde(rename = "self-harm")]
408 pub self_harm: f32,
409 #[serde(rename = "self-harm/intent")]
411 pub self_harm_intent: f32,
412 #[serde(rename = "self-harm/instructions")]
414 pub self_harm_instructions: f32,
415 pub sexual: f32,
417 #[serde(rename = "sexual/minors")]
419 pub sexual_minors: f32,
420 pub violence: f32,
422 #[serde(rename = "violence/graphic")]
424 pub violence_graphic: f32,
425 #[serde(default)]
427 pub illicit: f32,
428 #[serde(default, rename = "illicit/violent")]
430 pub illicit_violent: f32,
431}
432
433#[async_trait]
435pub trait ModerationProvider: Send + Sync {
436 fn name(&self) -> &str;
438
439 async fn moderate(&self, request: ModerationRequest) -> Result<ModerationResponse>;
441
442 fn default_moderation_model(&self) -> Option<&str> {
444 None
445 }
446
447 fn supports_multimodal(&self) -> bool {
449 false
450 }
451}
452
/// A request to classify `input` into one (or, with multi-label, several) of `labels`.
#[derive(Debug, Clone)]
pub struct ClassificationRequest {
    /// Classification model identifier.
    pub model: String,
    /// Text to classify.
    pub input: String,
    /// Candidate labels.
    pub labels: Vec<String>,
    /// `Some(true)` requests multi-label classification; `None` when not set.
    pub multi_label: Option<bool>,
    /// Optional labelled examples (few-shot); `None` when not set.
    pub examples: Option<Vec<ClassificationExample>>,
}

impl ClassificationRequest {
    /// Builds a request with multi-label and examples left unset.
    pub fn new(
        model: impl Into<String>,
        input: impl Into<String>,
        labels: Vec<impl Into<String>>,
    ) -> Self {
        Self {
            model: model.into(),
            input: input.into(),
            labels: labels.into_iter().map(Into::into).collect(),
            multi_label: None,
            examples: None,
        }
    }

    /// Sets `multi_label` to `Some(true)`.
    pub fn with_multi_label(self) -> Self {
        Self {
            multi_label: Some(true),
            ..self
        }
    }

    /// Attaches labelled examples, replacing any previously attached list.
    pub fn with_examples(self, examples: Vec<ClassificationExample>) -> Self {
        Self {
            examples: Some(examples),
            ..self
        }
    }
}

/// A labelled example text.
#[derive(Debug, Clone)]
pub struct ClassificationExample {
    /// Example text.
    pub text: String,
    /// Label assigned to `text`.
    pub label: String,
}

impl ClassificationExample {
    /// Pairs `text` with its `label`.
    pub fn new(text: impl Into<String>, label: impl Into<String>) -> Self {
        let text = text.into();
        let label = label.into();
        Self { text, label }
    }
}
519
/// Result of a classification call.
#[derive(Debug, Clone)]
pub struct ClassificationResponse {
    /// Predictions; [`ClassificationResponse::top`] treats the first entry as the best one.
    pub predictions: Vec<ClassificationPrediction>,
    /// Model that produced the predictions.
    pub model: String,
}

impl ClassificationResponse {
    /// Wraps `predictions` produced by `model`.
    pub fn new(model: impl Into<String>, predictions: Vec<ClassificationPrediction>) -> Self {
        let model = model.into();
        Self { predictions, model }
    }

    /// Returns the first prediction (assumed best), or `None` when there are none.
    pub fn top(&self) -> Option<&ClassificationPrediction> {
        match self.predictions.as_slice() {
            [best, ..] => Some(best),
            [] => None,
        }
    }

    /// Label of the first prediction, if any.
    pub fn label(&self) -> Option<&str> {
        self.top().map(|best| best.label.as_str())
    }

    /// Score of the first prediction whose label equals `label`, if any.
    pub fn score_for(&self, label: &str) -> Option<f32> {
        self.predictions
            .iter()
            .filter(|prediction| prediction.label == label)
            .map(|prediction| prediction.score)
            .next()
    }
}

/// A label with its score.
#[derive(Debug, Clone)]
pub struct ClassificationPrediction {
    /// Predicted label.
    pub label: String,
    /// Score for `label`.
    pub score: f32,
}

impl ClassificationPrediction {
    /// Pairs `label` with `score`.
    pub fn new(label: impl Into<String>, score: f32) -> Self {
        let label = label.into();
        Self { label, score }
    }
}
575
576#[async_trait]
578pub trait ClassificationProvider: Send + Sync {
579 fn name(&self) -> &str;
581
582 async fn classify(&self, request: ClassificationRequest) -> Result<ClassificationResponse>;
584
585 fn default_classification_model(&self) -> Option<&str> {
587 None
588 }
589
590 fn max_labels(&self) -> usize {
592 100
593 }
594
595 fn supports_few_shot(&self) -> bool {
597 false
598 }
599}
600
/// Static description of a reranking model.
#[derive(Debug, Clone)]
pub struct RankingModelInfo {
    /// Model identifier as passed to the provider.
    pub id: &'static str,
    /// Provider key, e.g. `"cohere"`, `"voyage"`, `"jina"`.
    pub provider: &'static str,
    /// Maximum number of documents per request.
    pub max_documents: usize,
    /// Maximum query length in tokens.
    pub max_query_tokens: usize,
    /// Listed price per 1,000 searches.
    pub price_per_1k_searches: f64,
}

/// Known reranking models; look entries up with [`get_ranking_model_info`].
pub static RANKING_MODELS: &[RankingModelInfo] = &[
    // Cohere
    RankingModelInfo {
        id: "rerank-english-v3.0",
        provider: "cohere",
        max_documents: 1000,
        max_query_tokens: 2048,
        price_per_1k_searches: 2.00,
    },
    RankingModelInfo {
        id: "rerank-multilingual-v3.0",
        provider: "cohere",
        max_documents: 1000,
        max_query_tokens: 2048,
        price_per_1k_searches: 2.00,
    },
    // Voyage AI
    RankingModelInfo {
        id: "rerank-2",
        provider: "voyage",
        max_documents: 1000,
        max_query_tokens: 4000,
        price_per_1k_searches: 0.05,
    },
    // Voyage names its lite v2 reranker "rerank-2-lite" (version before the
    // "lite" suffix); this entry was previously misspelled "rerank-lite-2".
    RankingModelInfo {
        id: "rerank-2-lite",
        provider: "voyage",
        max_documents: 1000,
        // NOTE(review): Voyage documents a smaller query-token limit for the lite
        // model than for rerank-2 — confirm this value.
        max_query_tokens: 4000,
        price_per_1k_searches: 0.02,
    },
    // Jina AI
    RankingModelInfo {
        id: "jina-reranker-v2-base-multilingual",
        provider: "jina",
        max_documents: 500,
        max_query_tokens: 8192,
        price_per_1k_searches: 0.02,
    },
];

/// Static description of a moderation model.
#[derive(Debug, Clone)]
pub struct ModerationModelInfo {
    /// Model identifier as passed to the provider.
    pub id: &'static str,
    /// Provider key, e.g. `"openai"`.
    pub provider: &'static str,
    /// Whether the model accepts image inputs.
    pub supports_images: bool,
    /// Listed price per 1,000 requests (0.0 = no charge recorded).
    pub price_per_1k_requests: f64,
}

/// Known moderation models; look entries up with [`get_moderation_model_info`].
pub static MODERATION_MODELS: &[ModerationModelInfo] = &[
    ModerationModelInfo {
        id: "omni-moderation-latest",
        provider: "openai",
        supports_images: true,
        price_per_1k_requests: 0.0,
    },
    ModerationModelInfo {
        id: "text-moderation-latest",
        provider: "openai",
        supports_images: false,
        price_per_1k_requests: 0.0,
    },
    ModerationModelInfo {
        id: "text-moderation-stable",
        provider: "openai",
        supports_images: false,
        price_per_1k_requests: 0.0,
    },
];

/// Looks up a reranking model by its exact id.
///
/// The legacy misspelling `"rerank-lite-2"` is still accepted and resolves to
/// the `"rerank-2-lite"` entry, so existing callers keep working.
pub fn get_ranking_model_info(model_id: &str) -> Option<&'static RankingModelInfo> {
    let model_id = match model_id {
        "rerank-lite-2" => "rerank-2-lite",
        other => other,
    };
    RANKING_MODELS.iter().find(|m| m.id == model_id)
}

/// Looks up a moderation model by its exact id.
pub fn get_moderation_model_info(model_id: &str) -> Option<&'static ModerationModelInfo> {
    MODERATION_MODELS.iter().find(|m| m.id == model_id)
}
707
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ranking_request_builder() {
        let docs = vec!["Paris is the capital", "Berlin is a city"];
        let req = RankingRequest::new("rerank-english-v3.0", "What is the capital?", docs)
            .with_top_k(5)
            .with_documents();

        assert_eq!(req.model, "rerank-english-v3.0");
        assert_eq!(req.query, "What is the capital?");
        assert_eq!(req.documents.len(), 2);
        assert_eq!(req.top_k, Some(5));
        assert_eq!(req.return_documents, Some(true));
    }

    #[test]
    fn test_ranking_response() {
        let best = RankedDocument::new(1, 0.95).with_document("Top doc");
        let runner_up = RankedDocument::new(0, 0.8);
        let resp = RankingResponse::new("rerank-english-v3.0", vec![best, runner_up]);

        let top = resp.top().expect("response should have a top result");
        assert_eq!(top.score, 0.95);
        assert_eq!(resp.ranked_indices(), vec![1, 0]);
    }

    #[test]
    fn test_moderation_request() {
        let req = ModerationRequest::new("omni-moderation-latest", "Some text to check");

        assert_eq!(req.model, "omni-moderation-latest");
        assert_eq!(req.input, "Some text to check");
    }

    #[test]
    fn test_moderation_response() {
        let resp = ModerationResponse::new(true)
            .with_model("omni-moderation-latest")
            .with_categories(ModerationCategories {
                hate: true,
                violence: true,
                ..Default::default()
            });

        assert!(resp.flagged);
        let names = resp.flagged_categories();
        for expected in ["hate", "violence"].iter() {
            assert!(names.contains(expected));
        }
        assert!(!names.contains(&"sexual"));
    }

    #[test]
    fn test_classification_request_builder() {
        let examples = vec![
            ClassificationExample::new("Great!", "positive"),
            ClassificationExample::new("Terrible", "negative"),
        ];
        let req = ClassificationRequest::new(
            "embed-english-v3.0",
            "I love this product!",
            vec!["positive", "negative", "neutral"],
        )
        .with_multi_label()
        .with_examples(examples);

        assert_eq!(req.model, "embed-english-v3.0");
        assert_eq!(req.input, "I love this product!");
        assert_eq!(req.labels.len(), 3);
        assert_eq!(req.multi_label, Some(true));
        let attached = req.examples.as_ref().expect("examples should be set");
        assert_eq!(attached.len(), 2);
    }

    #[test]
    fn test_classification_response() {
        let resp = ClassificationResponse::new(
            "model",
            vec![
                ClassificationPrediction::new("positive", 0.92),
                ClassificationPrediction::new("neutral", 0.06),
                ClassificationPrediction::new("negative", 0.02),
            ],
        );

        assert_eq!(resp.label(), Some("positive"));
        assert_eq!(resp.top().expect("has predictions").score, 0.92);
        assert_eq!(resp.score_for("neutral"), Some(0.06));
        assert_eq!(resp.score_for("unknown"), None);
    }

    #[test]
    fn test_ranking_model_registry() {
        let info = get_ranking_model_info("rerank-english-v3.0")
            .expect("rerank-english-v3.0 should be registered");

        assert_eq!(info.provider, "cohere");
        assert_eq!(info.max_documents, 1000);
    }

    #[test]
    fn test_moderation_model_registry() {
        let info = get_moderation_model_info("omni-moderation-latest")
            .expect("omni-moderation-latest should be registered");

        assert_eq!(info.provider, "openai");
        assert!(info.supports_images);
    }
}