//! Classify API protocol definitions.
//!
//! This module defines the request and response types for the `/v1/classify` API,
//! which is compatible with vLLM's classification endpoint.
//!
//! Classification reuses the embedding backend - the scheduler returns logits as
//! "embeddings", and the classify layer applies softmax + label mapping.

use serde::{Deserialize, Serialize};
use serde_json::Value;

use super::common::{GenerationRequest, UsageInfo};

// ============================================================================
// Classify API
// ============================================================================

18/// Classification request - compatible with vLLM's /v1/classify API
19#[serde_with::skip_serializing_none]
20#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
21pub struct ClassifyRequest {
22    /// ID of the model to use
23    pub model: String,
24
25    /// Input can be a string, array of strings, or token IDs
26    /// - Single string: "text to classify"
27    /// - Array of strings: ["text1", "text2"]
28    /// - Token IDs: [1, 2, 3] (advanced usage)
29    pub input: Value,
30
31    /// Optional user identifier
32    pub user: Option<String>,
33
34    /// SGLang extension: request id for tracking
35    pub rid: Option<String>,
36
37    /// SGLang extension: request priority
38    pub priority: Option<i32>,
39
40    /// SGLang extension: enable/disable logging of metrics
41    pub log_metrics: Option<bool>,
42}
44impl GenerationRequest for ClassifyRequest {
45    fn is_stream(&self) -> bool {
46        false // Classification is always non-streaming
47    }
48
49    fn get_model(&self) -> Option<&str> {
50        Some(&self.model)
51    }
52
53    fn extract_text_for_routing(&self) -> String {
54        match &self.input {
55            Value::String(s) => s.clone(),
56            Value::Array(arr) => arr
57                .iter()
58                .filter_map(|v| v.as_str())
59                .collect::<Vec<_>>()
60                .join(" "),
61            _ => String::new(),
62        }
63    }
64}

// ============================================================================
// Classify Response
// ============================================================================

70/// Single classification result
71#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
72pub struct ClassifyData {
73    /// Index of this result (for batch requests)
74    pub index: u32,
75    /// Predicted class label (from id2label mapping)
76    pub label: String,
77    /// Probability distribution over all classes (softmax of logits)
78    pub probs: Vec<f32>,
79    /// Number of classes
80    pub num_classes: u32,
81}
83/// Classification response
84#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
85pub struct ClassifyResponse {
86    /// Unique request ID (format: "classify-{uuid}")
87    pub id: String,
88    /// Always "list"
89    pub object: String,
90    /// Unix timestamp (seconds since epoch)
91    pub created: u64,
92    /// Model name
93    pub model: String,
94    /// Classification results (one per input in batch)
95    pub data: Vec<ClassifyData>,
96    /// Token usage info
97    pub usage: UsageInfo,
98}
100impl ClassifyResponse {
101    /// Create a new ClassifyResponse with the given data
102    pub fn new(
103        id: String,
104        model: String,
105        created: u64,
106        data: Vec<ClassifyData>,
107        usage: UsageInfo,
108    ) -> Self {
109        Self {
110            id,
111            object: "list".to_string(),
112            created,
113            model,
114            data,
115            usage,
116        }
117    }
118}