llmkit/specialized.rs

//! Specialized AI APIs for ranking, moderation, and classification.
//!
//! This module provides unified interfaces for specialized AI tasks that go beyond
//! text generation, including document ranking/reranking, content moderation,
//! and text classification.
//!
//! # Ranking Example
//!
//! ```ignore
//! use llmkit::{CohereProvider, RankingProvider, RankingRequest};
//!
//! let provider = CohereProvider::from_env()?;
//!
//! let request = RankingRequest::new(
//!     "rerank-english-v3.0",
//!     "What is the capital of France?",
//!     vec![
//!         "Paris is the capital of France.",
//!         "Berlin is the capital of Germany.",
//!         "London is the capital of England.",
//!     ],
//! )
//! .with_documents();
//!
//! let response = provider.rank(request).await?;
//! for result in &response.results {
//!     println!("Score: {:.4} - {}", result.score, result.document.as_ref().unwrap());
//! }
//! ```
//!
//! # Moderation Example
//!
//! ```ignore
//! use llmkit::{ModerationProvider, ModerationRequest, OpenAIModerationProvider};
//!
//! let provider = OpenAIModerationProvider::from_env()?;
//!
//! let request = ModerationRequest::new("omni-moderation-latest", "Some user content to check");
//!
//! let response = provider.moderate(request).await?;
//! if response.flagged {
//!     println!("Content was flagged!");
//!     for category in response.flagged_categories() {
//!         println!("  - {}", category);
//!     }
//! }
//! ```
//!
//! # Classification Example
//!
//! ```ignore
//! use llmkit::{ClassificationProvider, ClassificationRequest, CohereClassifyProvider};
//!
//! let provider = CohereClassifyProvider::from_env()?;
//!
//! let request = ClassificationRequest::new(
//!     "embed-english-v3.0",
//!     "I love this product!",
//!     vec!["positive", "negative", "neutral"],
//! );
//!
//! let response = provider.classify(request).await?;
//! if let Some(top) = response.top() {
//!     println!("Predicted: {} (score: {:.4})", top.label, top.score);
//! }
//! ```

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::error::Result;

// ============================================================================
// Ranking / Reranking
// ============================================================================

/// Request for ranking/reranking documents.
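///
/// # Example
///
/// A minimal sketch of building a request with the optional builder methods
/// (the model name is one of the rerank models listed in [`RANKING_MODELS`]):
///
/// ```ignore
/// let request = RankingRequest::new(
///     "rerank-english-v3.0",
///     "What is the capital of France?",
///     vec!["Paris is the capital of France.", "Berlin is the capital of Germany."],
/// )
/// .with_top_k(1)     // keep only the single best match
/// .with_documents(); // echo document text back in the results
/// ```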
#[derive(Debug, Clone)]
pub struct RankingRequest {
    /// The model to use for ranking.
    pub model: String,
    /// The query to rank documents against.
    pub query: String,
    /// The documents to rank.
    pub documents: Vec<String>,
    /// Maximum number of results to return.
    pub top_k: Option<usize>,
    /// Whether to return document content in results.
    pub return_documents: Option<bool>,
    /// Maximum number of chunks per document (for long-document truncation).
    pub max_chunks_per_doc: Option<usize>,
}

impl RankingRequest {
    /// Create a new ranking request.
    pub fn new(
        model: impl Into<String>,
        query: impl Into<String>,
        documents: Vec<impl Into<String>>,
    ) -> Self {
        Self {
            model: model.into(),
            query: query.into(),
            documents: documents.into_iter().map(|d| d.into()).collect(),
            top_k: None,
            return_documents: None,
            max_chunks_per_doc: None,
        }
    }

    /// Set the maximum number of results to return.
    pub fn with_top_k(mut self, top_k: usize) -> Self {
        self.top_k = Some(top_k);
        self
    }

    /// Include document content in results.
    pub fn with_documents(mut self) -> Self {
        self.return_documents = Some(true);
        self
    }

    /// Set max chunks per document for long document handling.
    pub fn with_max_chunks_per_doc(mut self, max_chunks: usize) -> Self {
        self.max_chunks_per_doc = Some(max_chunks);
        self
    }
}

/// Response from a ranking request.
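///
/// # Example
///
/// A sketch of consuming a response; `response` is assumed to come from
/// [`RankingProvider::rank`] and `docs` to be the original `Vec<String>` of
/// documents sent in the request:
///
/// ```ignore
/// // Highest-scoring document, if any results were returned.
/// if let Some(best) = response.top() {
///     println!("best index: {} (score {:.4})", best.index, best.score);
/// }
///
/// // Reorder the caller's own documents by relevance using the ranked indices.
/// let reordered: Vec<&String> = response
///     .ranked_indices()
///     .into_iter()
///     .map(|i| &docs[i])
///     .collect();
/// ```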
#[derive(Debug, Clone)]
pub struct RankingResponse {
    /// Ranked results, sorted by relevance (highest first).
    pub results: Vec<RankedDocument>,
    /// API metadata (billing, API version, etc.).
    pub meta: Option<RankingMeta>,
    /// The model used for ranking.
    pub model: String,
}

impl RankingResponse {
    /// Create a new ranking response.
    pub fn new(model: impl Into<String>, results: Vec<RankedDocument>) -> Self {
        Self {
            model: model.into(),
            results,
            meta: None,
        }
    }

    /// Get the top-ranked document.
    pub fn top(&self) -> Option<&RankedDocument> {
        self.results.first()
    }

    /// Get the original indices of the documents in ranked order.
    pub fn ranked_indices(&self) -> Vec<usize> {
        self.results.iter().map(|r| r.index).collect()
    }
}

/// A ranked document with its relevance score.
#[derive(Debug, Clone)]
pub struct RankedDocument {
    /// Original index in the input documents array.
    pub index: usize,
    /// Relevance score (higher is more relevant).
    pub score: f32,
    /// The document content (if `return_documents` was requested).
    pub document: Option<String>,
}

impl RankedDocument {
    /// Create a new ranked document.
    pub fn new(index: usize, score: f32) -> Self {
        Self {
            index,
            score,
            document: None,
        }
    }

    /// Set the document content.
    pub fn with_document(mut self, document: impl Into<String>) -> Self {
        self.document = Some(document.into());
        self
    }
}

/// Metadata from a ranking API response.
#[derive(Debug, Clone, Default)]
pub struct RankingMeta {
    /// Billed units consumed (e.g., search units).
    pub billed_units: Option<u64>,
    /// API version.
    pub api_version: Option<String>,
}

/// Trait for providers that support document ranking/reranking.
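///
/// # Implementing the trait
///
/// A toy, in-process implementation for illustration only (real providers call
/// out to a ranking API); it scores documents by naive term overlap with the query:
///
/// ```ignore
/// struct NaiveRanker;
///
/// #[async_trait]
/// impl RankingProvider for NaiveRanker {
///     fn name(&self) -> &str {
///         "naive"
///     }
///
///     async fn rank(&self, request: RankingRequest) -> Result<RankingResponse> {
///         let query_terms: Vec<&str> = request.query.split_whitespace().collect();
///         let mut results: Vec<RankedDocument> = request
///             .documents
///             .iter()
///             .enumerate()
///             .map(|(i, doc)| {
///                 // Score = fraction of query terms that appear verbatim in the document.
///                 let hits = query_terms.iter().filter(|t| doc.contains(*t)).count();
///                 RankedDocument::new(i, hits as f32 / query_terms.len().max(1) as f32)
///             })
///             .collect();
///         // RankingResponse expects results sorted best-first.
///         results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
///         Ok(RankingResponse::new(request.model.clone(), results))
///     }
/// }
/// ```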
#[async_trait]
pub trait RankingProvider: Send + Sync {
    /// Get the provider name.
    fn name(&self) -> &str;

    /// Rank documents by relevance to a query.
    async fn rank(&self, request: RankingRequest) -> Result<RankingResponse>;

    /// Get the default model for ranking.
    fn default_ranking_model(&self) -> Option<&str> {
        None
    }

    /// Get the maximum number of documents that can be ranked in one request.
    fn max_documents(&self) -> usize {
        1000
    }

    /// Get the maximum query length in characters.
    fn max_query_length(&self) -> usize {
        2048
    }
}

// ============================================================================
// Moderation
// ============================================================================

/// Request for content moderation.
#[derive(Debug, Clone)]
pub struct ModerationRequest {
    /// The model to use for moderation.
    pub model: String,
    /// The text content to moderate.
    pub input: String,
    /// Additional inputs (for multi-modal moderation).
    pub inputs: Option<Vec<ModerationInput>>,
}

impl ModerationRequest {
    /// Create a new moderation request.
    pub fn new(model: impl Into<String>, input: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            input: input.into(),
            inputs: None,
        }
    }

    /// Add multiple inputs for moderation.
    pub fn with_inputs(mut self, inputs: Vec<ModerationInput>) -> Self {
        self.inputs = Some(inputs);
        self
    }
}

/// Input types for multi-modal moderation.
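///
/// # Example
///
/// A sketch of attaching mixed inputs to a [`ModerationRequest`]; whether image
/// inputs are accepted depends on the provider (see
/// [`ModerationProvider::supports_multimodal`]), and the URL here is illustrative:
///
/// ```ignore
/// let request = ModerationRequest::new("omni-moderation-latest", "caption text")
///     .with_inputs(vec![
///         ModerationInput::Text("caption text".to_string()),
///         ModerationInput::ImageUrl("https://example.com/image.png".to_string()),
///     ]);
/// ```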
#[derive(Debug, Clone)]
pub enum ModerationInput {
    /// Text content.
    Text(String),
    /// Image URL.
    ImageUrl(String),
    /// Base64-encoded image.
    ImageBase64 { data: String, media_type: String },
}

/// Response from a moderation request.
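///
/// # Example
///
/// A sketch of how a provider adapter might assemble a response; the category
/// values here are illustrative:
///
/// ```ignore
/// let response = ModerationResponse::new(true)
///     .with_model("omni-moderation-latest")
///     .with_categories(ModerationCategories {
///         harassment: true,
///         ..Default::default()
///     })
///     .with_scores(ModerationScores {
///         harassment: 0.91,
///         ..Default::default()
///     });
///
/// assert_eq!(response.flagged_categories(), vec!["harassment"]);
/// ```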
#[derive(Debug, Clone)]
pub struct ModerationResponse {
    /// Whether any content was flagged.
    pub flagged: bool,
    /// Category flags indicating which categories were triggered.
    pub categories: ModerationCategories,
    /// Category scores (0.0 to 1.0).
    pub category_scores: ModerationScores,
    /// Model used for moderation.
    pub model: String,
}

impl ModerationResponse {
    /// Create a new moderation response.
    pub fn new(flagged: bool) -> Self {
        Self {
            flagged,
            categories: ModerationCategories::default(),
            category_scores: ModerationScores::default(),
            model: String::new(),
        }
    }

    /// Set the model name.
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Set category flags.
    pub fn with_categories(mut self, categories: ModerationCategories) -> Self {
        self.categories = categories;
        self
    }

    /// Set category scores.
    pub fn with_scores(mut self, scores: ModerationScores) -> Self {
        self.category_scores = scores;
        self
    }

    /// Get a list of flagged category names.
    pub fn flagged_categories(&self) -> Vec<&'static str> {
        let mut result = Vec::new();
        if self.categories.hate {
            result.push("hate");
        }
        if self.categories.hate_threatening {
            result.push("hate/threatening");
        }
        if self.categories.harassment {
            result.push("harassment");
        }
        if self.categories.harassment_threatening {
            result.push("harassment/threatening");
        }
        if self.categories.self_harm {
            result.push("self-harm");
        }
        if self.categories.self_harm_intent {
            result.push("self-harm/intent");
        }
        if self.categories.self_harm_instructions {
            result.push("self-harm/instructions");
        }
        if self.categories.sexual {
            result.push("sexual");
        }
        if self.categories.sexual_minors {
            result.push("sexual/minors");
        }
        if self.categories.violence {
            result.push("violence");
        }
        if self.categories.violence_graphic {
            result.push("violence/graphic");
        }
        if self.categories.illicit {
            result.push("illicit");
        }
        if self.categories.illicit_violent {
            result.push("illicit/violent");
        }
        result
    }
}

/// Moderation category flags.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModerationCategories {
    /// Hate speech.
    pub hate: bool,
    /// Hate speech with threatening language.
    #[serde(rename = "hate/threatening")]
    pub hate_threatening: bool,
    /// Harassment.
    pub harassment: bool,
    /// Harassment with threatening language.
    #[serde(rename = "harassment/threatening")]
    pub harassment_threatening: bool,
    /// Self-harm content.
    #[serde(rename = "self-harm")]
    pub self_harm: bool,
    /// Self-harm with intent.
    #[serde(rename = "self-harm/intent")]
    pub self_harm_intent: bool,
    /// Self-harm instructions.
    #[serde(rename = "self-harm/instructions")]
    pub self_harm_instructions: bool,
    /// Sexual content.
    pub sexual: bool,
    /// Sexual content involving minors.
    #[serde(rename = "sexual/minors")]
    pub sexual_minors: bool,
    /// Violent content.
    pub violence: bool,
    /// Graphic violence.
    #[serde(rename = "violence/graphic")]
    pub violence_graphic: bool,
    /// Illicit activity.
    #[serde(default)]
    pub illicit: bool,
    /// Violent illicit activity.
    #[serde(default, rename = "illicit/violent")]
    pub illicit_violent: bool,
}

/// Moderation category confidence scores (0.0 to 1.0).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModerationScores {
    /// Hate speech score.
    pub hate: f32,
    /// Hate/threatening score.
    #[serde(rename = "hate/threatening")]
    pub hate_threatening: f32,
    /// Harassment score.
    pub harassment: f32,
    /// Harassment/threatening score.
    #[serde(rename = "harassment/threatening")]
    pub harassment_threatening: f32,
    /// Self-harm score.
    #[serde(rename = "self-harm")]
    pub self_harm: f32,
    /// Self-harm/intent score.
    #[serde(rename = "self-harm/intent")]
    pub self_harm_intent: f32,
    /// Self-harm/instructions score.
    #[serde(rename = "self-harm/instructions")]
    pub self_harm_instructions: f32,
    /// Sexual score.
    pub sexual: f32,
    /// Sexual/minors score.
    #[serde(rename = "sexual/minors")]
    pub sexual_minors: f32,
    /// Violence score.
    pub violence: f32,
    /// Violence/graphic score.
    #[serde(rename = "violence/graphic")]
    pub violence_graphic: f32,
    /// Illicit score.
    #[serde(default)]
    pub illicit: f32,
    /// Illicit/violent score.
    #[serde(default, rename = "illicit/violent")]
    pub illicit_violent: f32,
}

/// Trait for providers that support content moderation.
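///
/// # Example
///
/// A sketch of gating user content on any moderation provider; `is_allowed` is
/// a hypothetical helper written against this trait:
///
/// ```ignore
/// async fn is_allowed(provider: &dyn ModerationProvider, text: &str) -> Result<bool> {
///     // Fall back to a known model if the provider does not advertise a default.
///     let model = provider
///         .default_moderation_model()
///         .unwrap_or("omni-moderation-latest")
///         .to_string();
///     let response = provider.moderate(ModerationRequest::new(model, text)).await?;
///     Ok(!response.flagged)
/// }
/// ```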
#[async_trait]
pub trait ModerationProvider: Send + Sync {
    /// Get the provider name.
    fn name(&self) -> &str;

    /// Moderate content for policy violations.
    async fn moderate(&self, request: ModerationRequest) -> Result<ModerationResponse>;

    /// Get the default model for moderation.
    fn default_moderation_model(&self) -> Option<&str> {
        None
    }

    /// Check if the provider supports multi-modal moderation (images).
    fn supports_multimodal(&self) -> bool {
        false
    }
}

// ============================================================================
// Classification
// ============================================================================

/// Request for text classification.
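///
/// # Example
///
/// A sketch combining the optional knobs; whether few-shot examples and
/// multi-label output are honored depends on the provider (see
/// [`ClassificationProvider::supports_few_shot`]):
///
/// ```ignore
/// let request = ClassificationRequest::new(
///     "embed-english-v3.0",
///     "The package arrived two weeks late.",
///     vec!["shipping", "billing", "product quality"],
/// )
/// .with_multi_label()
/// .with_examples(vec![
///     ClassificationExample::new("I was charged twice", "billing"),
///     ClassificationExample::new("The box never showed up", "shipping"),
/// ]);
/// ```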
#[derive(Debug, Clone)]
pub struct ClassificationRequest {
    /// The model to use for classification.
    pub model: String,
    /// The text to classify.
    pub input: String,
    /// The possible labels/classes.
    pub labels: Vec<String>,
    /// Whether to allow multiple labels per input.
    pub multi_label: Option<bool>,
    /// Optional examples for few-shot classification.
    pub examples: Option<Vec<ClassificationExample>>,
}

impl ClassificationRequest {
    /// Create a new classification request.
    pub fn new(
        model: impl Into<String>,
        input: impl Into<String>,
        labels: Vec<impl Into<String>>,
    ) -> Self {
        Self {
            model: model.into(),
            input: input.into(),
            labels: labels.into_iter().map(|l| l.into()).collect(),
            multi_label: None,
            examples: None,
        }
    }

    /// Enable multi-label classification.
    pub fn with_multi_label(mut self) -> Self {
        self.multi_label = Some(true);
        self
    }

    /// Add examples for few-shot classification.
    pub fn with_examples(mut self, examples: Vec<ClassificationExample>) -> Self {
        self.examples = Some(examples);
        self
    }
}

/// Example for few-shot classification.
#[derive(Debug, Clone)]
pub struct ClassificationExample {
    /// The example text.
    pub text: String,
    /// The label for this example.
    pub label: String,
}

impl ClassificationExample {
    /// Create a new classification example.
    pub fn new(text: impl Into<String>, label: impl Into<String>) -> Self {
        Self {
            text: text.into(),
            label: label.into(),
        }
    }
}

/// Response from a classification request.
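///
/// # Example
///
/// A sketch of inspecting a response; `response` is assumed to come from
/// [`ClassificationProvider::classify`] with the labels `"positive"`,
/// `"negative"`, and `"neutral"`:
///
/// ```ignore
/// if let Some(label) = response.label() {
///     println!("predicted: {label}");
/// }
///
/// // Confidence for one specific label, regardless of its rank.
/// let negative = response.score_for("negative").unwrap_or(0.0);
/// ```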
#[derive(Debug, Clone)]
pub struct ClassificationResponse {
    /// Predictions sorted by confidence (highest first).
    pub predictions: Vec<ClassificationPrediction>,
    /// The model used for classification.
    pub model: String,
}

impl ClassificationResponse {
    /// Create a new classification response.
    pub fn new(model: impl Into<String>, predictions: Vec<ClassificationPrediction>) -> Self {
        Self {
            model: model.into(),
            predictions,
        }
    }

    /// Get the top prediction.
    pub fn top(&self) -> Option<&ClassificationPrediction> {
        self.predictions.first()
    }

    /// Get the predicted label.
    pub fn label(&self) -> Option<&str> {
        self.predictions.first().map(|p| p.label.as_str())
    }

    /// Get the confidence score for a specific label.
    pub fn score_for(&self, label: &str) -> Option<f32> {
        self.predictions
            .iter()
            .find(|p| p.label == label)
            .map(|p| p.score)
    }
}

/// A classification prediction with confidence score.
#[derive(Debug, Clone)]
pub struct ClassificationPrediction {
    /// The predicted label.
    pub label: String,
    /// Confidence score (0.0 to 1.0).
    pub score: f32,
}

impl ClassificationPrediction {
    /// Create a new classification prediction.
    pub fn new(label: impl Into<String>, score: f32) -> Self {
        Self {
            label: label.into(),
            score,
        }
    }
}

/// Trait for providers that support text classification.
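///
/// # Example
///
/// A sketch of calling any classification provider while respecting its label
/// limit; `classify_sentiment` is a hypothetical helper written against this trait:
///
/// ```ignore
/// async fn classify_sentiment(
///     provider: &dyn ClassificationProvider,
///     text: &str,
/// ) -> Result<ClassificationResponse> {
///     let labels = vec!["positive", "negative", "neutral"];
///     // Stay within the provider's advertised label limit.
///     assert!(labels.len() <= provider.max_labels());
///     let model = provider
///         .default_classification_model()
///         .unwrap_or("embed-english-v3.0")
///         .to_string();
///     provider.classify(ClassificationRequest::new(model, text, labels)).await
/// }
/// ```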
#[async_trait]
pub trait ClassificationProvider: Send + Sync {
    /// Get the provider name.
    fn name(&self) -> &str;

    /// Classify text into one or more labels.
    async fn classify(&self, request: ClassificationRequest) -> Result<ClassificationResponse>;

    /// Get the default model for classification.
    fn default_classification_model(&self) -> Option<&str> {
        None
    }

    /// Get the maximum number of labels supported.
    fn max_labels(&self) -> usize {
        100
    }

    /// Check if few-shot classification is supported.
    fn supports_few_shot(&self) -> bool {
        false
    }
}

// ============================================================================
// Model Registry
// ============================================================================

/// Information about a ranking model.
#[derive(Debug, Clone)]
pub struct RankingModelInfo {
    /// Model ID.
    pub id: &'static str,
    /// Provider.
    pub provider: &'static str,
    /// Max documents per request.
    pub max_documents: usize,
    /// Max query tokens.
    pub max_query_tokens: usize,
    /// Price per 1000 searches (USD).
    pub price_per_1k_searches: f64,
}

/// Registry of known ranking models.
pub static RANKING_MODELS: &[RankingModelInfo] = &[
    // Cohere
    RankingModelInfo {
        id: "rerank-english-v3.0",
        provider: "cohere",
        max_documents: 1000,
        max_query_tokens: 2048,
        price_per_1k_searches: 2.00,
    },
    RankingModelInfo {
        id: "rerank-multilingual-v3.0",
        provider: "cohere",
        max_documents: 1000,
        max_query_tokens: 2048,
        price_per_1k_searches: 2.00,
    },
    // Voyage
    RankingModelInfo {
        id: "rerank-2",
        provider: "voyage",
        max_documents: 1000,
        max_query_tokens: 4000,
        price_per_1k_searches: 0.05,
    },
    RankingModelInfo {
        id: "rerank-lite-2",
        provider: "voyage",
        max_documents: 1000,
        max_query_tokens: 4000,
        price_per_1k_searches: 0.02,
    },
    // Jina
    RankingModelInfo {
        id: "jina-reranker-v2-base-multilingual",
        provider: "jina",
        max_documents: 500,
        max_query_tokens: 8192,
        price_per_1k_searches: 0.02,
    },
];

/// Information about a moderation model.
#[derive(Debug, Clone)]
pub struct ModerationModelInfo {
    /// Model ID.
    pub id: &'static str,
    /// Provider.
    pub provider: &'static str,
    /// Supports multi-modal (images).
    pub supports_images: bool,
    /// Price per 1000 requests (USD).
    pub price_per_1k_requests: f64,
}

/// Registry of known moderation models.
pub static MODERATION_MODELS: &[ModerationModelInfo] = &[
    // OpenAI
    ModerationModelInfo {
        id: "omni-moderation-latest",
        provider: "openai",
        supports_images: true,
        price_per_1k_requests: 0.0, // Free
    },
    ModerationModelInfo {
        id: "text-moderation-latest",
        provider: "openai",
        supports_images: false,
        price_per_1k_requests: 0.0, // Free
    },
    ModerationModelInfo {
        id: "text-moderation-stable",
        provider: "openai",
        supports_images: false,
        price_per_1k_requests: 0.0, // Free
    },
];

/// Get ranking model info by ID.
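///
/// # Example
///
/// A sketch of checking a request against the registry before sending it;
/// `request` is assumed to be a [`RankingRequest`] about to be submitted:
///
/// ```ignore
/// if let Some(info) = get_ranking_model_info("rerank-english-v3.0") {
///     assert!(request.documents.len() <= info.max_documents);
///     println!("~${:.2} per 1k searches on {}", info.price_per_1k_searches, info.provider);
/// }
/// ```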
pub fn get_ranking_model_info(model_id: &str) -> Option<&'static RankingModelInfo> {
    RANKING_MODELS.iter().find(|m| m.id == model_id)
}

/// Get moderation model info by ID.
pub fn get_moderation_model_info(model_id: &str) -> Option<&'static ModerationModelInfo> {
    MODERATION_MODELS.iter().find(|m| m.id == model_id)
}

#[cfg(test)]
mod tests {
    use super::*;

    // Ranking tests
    #[test]
    fn test_ranking_request_builder() {
        let request = RankingRequest::new(
            "rerank-english-v3.0",
            "What is the capital?",
            vec!["Paris is the capital", "Berlin is a city"],
        )
        .with_top_k(5)
        .with_documents();

        assert_eq!(request.model, "rerank-english-v3.0");
        assert_eq!(request.query, "What is the capital?");
        assert_eq!(request.documents.len(), 2);
        assert_eq!(request.top_k, Some(5));
        assert_eq!(request.return_documents, Some(true));
    }

    #[test]
    fn test_ranking_response() {
        let results = vec![
            RankedDocument::new(1, 0.95).with_document("Top doc"),
            RankedDocument::new(0, 0.8),
        ];
        let response = RankingResponse::new("rerank-english-v3.0", results);

        assert_eq!(response.top().unwrap().score, 0.95);
        assert_eq!(response.ranked_indices(), vec![1, 0]);
    }

    // Moderation tests
    #[test]
    fn test_moderation_request() {
        let request = ModerationRequest::new("omni-moderation-latest", "Some text to check");
        assert_eq!(request.model, "omni-moderation-latest");
        assert_eq!(request.input, "Some text to check");
    }

    #[test]
    fn test_moderation_response() {
        let categories = ModerationCategories {
            hate: true,
            violence: true,
            ..Default::default()
        };

        let response = ModerationResponse::new(true)
            .with_model("omni-moderation-latest")
            .with_categories(categories);

        assert!(response.flagged);
        let flagged = response.flagged_categories();
        assert!(flagged.contains(&"hate"));
        assert!(flagged.contains(&"violence"));
        assert!(!flagged.contains(&"sexual"));
    }

    // Classification tests
    #[test]
    fn test_classification_request_builder() {
        let request = ClassificationRequest::new(
            "embed-english-v3.0",
            "I love this product!",
            vec!["positive", "negative", "neutral"],
        )
        .with_multi_label()
        .with_examples(vec![
            ClassificationExample::new("Great!", "positive"),
            ClassificationExample::new("Terrible", "negative"),
        ]);

        assert_eq!(request.model, "embed-english-v3.0");
        assert_eq!(request.input, "I love this product!");
        assert_eq!(request.labels.len(), 3);
        assert_eq!(request.multi_label, Some(true));
        assert_eq!(request.examples.as_ref().unwrap().len(), 2);
    }

    #[test]
    fn test_classification_response() {
        let predictions = vec![
            ClassificationPrediction::new("positive", 0.92),
            ClassificationPrediction::new("neutral", 0.06),
            ClassificationPrediction::new("negative", 0.02),
        ];
        let response = ClassificationResponse::new("model", predictions);

        assert_eq!(response.label(), Some("positive"));
        assert_eq!(response.top().unwrap().score, 0.92);
        assert_eq!(response.score_for("neutral"), Some(0.06));
        assert_eq!(response.score_for("unknown"), None);
    }

    // Registry tests
    #[test]
    fn test_ranking_model_registry() {
        let model = get_ranking_model_info("rerank-english-v3.0");
        assert!(model.is_some());
        let model = model.unwrap();
        assert_eq!(model.provider, "cohere");
        assert_eq!(model.max_documents, 1000);
    }

    #[test]
    fn test_moderation_model_registry() {
        let model = get_moderation_model_info("omni-moderation-latest");
        assert!(model.is_some());
        let model = model.unwrap();
        assert_eq!(model.provider, "openai");
        assert!(model.supports_images);
    }
}