openai_protocol/
classify.rs

1//! Classify API protocol definitions.
2//!
3//! This module defines the request and response types for the `/v1/classify` API,
4//! which is compatible with vLLM's classification endpoint.
5//!
6//! Classification reuses the embedding backend - the scheduler returns logits as
7//! "embeddings", and the classify layer applies softmax + label mapping.
8
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11
12use super::common::{GenerationRequest, UsageInfo};
13
14// ============================================================================
15// Classify API
16// ============================================================================
17
/// Classification request - compatible with vLLM's /v1/classify API.
///
/// Deserialized from the JSON request body; field names are part of the
/// wire format and must not change.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ClassifyRequest {
    /// ID of the model to use
    pub model: String,

    /// Input can be a string, array of strings, or token IDs
    /// - Single string: "text to classify"
    /// - Array of strings: ["text1", "text2"]
    /// - Token IDs: [1, 2, 3] (advanced usage)
    ///
    /// Kept as a raw `serde_json::Value` so all three shapes deserialize;
    /// shape-specific handling happens downstream (see
    /// `extract_text_for_routing`, which yields an empty string for
    /// non-string inputs such as token IDs).
    pub input: Value,

    /// Optional user identifier
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,

    /// SGLang extension: request id for tracking
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rid: Option<String>,

    /// SGLang extension: request priority
    #[serde(skip_serializing_if = "Option::is_none")]
    pub priority: Option<i32>,

    /// SGLang extension: enable/disable logging of metrics
    #[serde(skip_serializing_if = "Option::is_none")]
    pub log_metrics: Option<bool>,
}
46
47impl GenerationRequest for ClassifyRequest {
48    fn is_stream(&self) -> bool {
49        false // Classification is always non-streaming
50    }
51
52    fn get_model(&self) -> Option<&str> {
53        Some(&self.model)
54    }
55
56    fn extract_text_for_routing(&self) -> String {
57        match &self.input {
58            Value::String(s) => s.clone(),
59            Value::Array(arr) => arr
60                .iter()
61                .filter_map(|v| v.as_str())
62                .collect::<Vec<_>>()
63                .join(" "),
64            _ => String::new(),
65        }
66    }
67}
68
69// ============================================================================
70// Classify Response
71// ============================================================================
72
/// Single classification result.
///
/// Serialized into the `data` array of [`ClassifyResponse`]; field names
/// are part of the wire format and must not change.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassifyData {
    /// Index of this result (for batch requests)
    pub index: u32,
    /// Predicted class label (from id2label mapping)
    pub label: String,
    /// Probability distribution over all classes (softmax of logits)
    pub probs: Vec<f32>,
    /// Number of classes
    pub num_classes: u32,
}
85
/// Classification response body for `/v1/classify`.
///
/// Mirrors the OpenAI-style list-response envelope; field names are part
/// of the wire format and must not change.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassifyResponse {
    /// Unique request ID (format: "classify-{uuid}")
    pub id: String,
    /// Always "list"
    pub object: String,
    /// Unix timestamp (seconds since epoch)
    pub created: u64,
    /// Model name
    pub model: String,
    /// Classification results (one per input in batch)
    pub data: Vec<ClassifyData>,
    /// Token usage info
    pub usage: UsageInfo,
}
102
103impl ClassifyResponse {
104    /// Create a new ClassifyResponse with the given data
105    pub fn new(
106        id: String,
107        model: String,
108        created: u64,
109        data: Vec<ClassifyData>,
110        usage: UsageInfo,
111    ) -> Self {
112        Self {
113            id,
114            object: "list".to_string(),
115            created,
116            model,
117            data,
118            usage,
119        }
120    }
121}