openai_protocol/classify.rs
1//! Classify API protocol definitions.
2//!
3//! This module defines the request and response types for the `/v1/classify` API,
4//! which is compatible with vLLM's classification endpoint.
5//!
6//! Classification reuses the embedding backend - the scheduler returns logits as
7//! "embeddings", and the classify layer applies softmax + label mapping.
8
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11
12use super::common::{GenerationRequest, UsageInfo};
13
14// ============================================================================
15// Classify API
16// ============================================================================
17
/// Classification request - compatible with vLLM's /v1/classify API
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct ClassifyRequest {
    /// ID of the model to use
    pub model: String,

    /// Input can be a string, array of strings, or token IDs
    /// - Single string: "text to classify"
    /// - Array of strings: ["text1", "text2"]
    /// - Token IDs: [1, 2, 3] (advanced usage)
    ///
    /// Kept as a raw `serde_json::Value` so all three shapes deserialize
    /// without a custom untagged enum; text is recovered later via
    /// `extract_text_for_routing`.
    pub input: Value,

    /// Optional user identifier (mirrors the OpenAI `user` field)
    pub user: Option<String>,

    /// SGLang extension: request id for tracking
    pub rid: Option<String>,

    /// SGLang extension: request priority
    pub priority: Option<i32>,
}
40
41impl GenerationRequest for ClassifyRequest {
42 fn is_stream(&self) -> bool {
43 false // Classification is always non-streaming
44 }
45
46 fn get_model(&self) -> Option<&str> {
47 Some(&self.model)
48 }
49
50 fn extract_text_for_routing(&self) -> String {
51 match &self.input {
52 Value::String(s) => s.clone(),
53 Value::Array(arr) => arr
54 .iter()
55 .filter_map(|v| v.as_str())
56 .collect::<Vec<_>>()
57 .join(" "),
58 _ => String::new(),
59 }
60 }
61}
62
63// ============================================================================
64// Classify Response
65// ============================================================================
66
67/// Single classification result
68#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
69pub struct ClassifyData {
70 /// Index of this result (for batch requests)
71 pub index: u32,
72 /// Predicted class label (from id2label mapping)
73 pub label: String,
74 /// Probability distribution over all classes (softmax of logits)
75 pub probs: Vec<f32>,
76 /// Number of classes
77 pub num_classes: u32,
78}
79
/// Classification response
///
/// Top-level body returned by `/v1/classify`, following the OpenAI-style
/// "list" envelope (`object`, `data`, `usage`).
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct ClassifyResponse {
    /// Unique request ID (format: "classify-{uuid}")
    pub id: String,
    /// Always "list"
    pub object: String,
    /// Unix timestamp (seconds since epoch)
    pub created: u64,
    /// Model name
    pub model: String,
    /// Classification results (one per input in batch)
    pub data: Vec<ClassifyData>,
    /// Token usage info
    pub usage: UsageInfo,
}
96
97impl ClassifyResponse {
98 /// Create a new ClassifyResponse with the given data
99 pub fn new(
100 id: String,
101 model: String,
102 created: u64,
103 data: Vec<ClassifyData>,
104 usage: UsageInfo,
105 ) -> Self {
106 Self {
107 id,
108 object: "list".to_string(),
109 created,
110 model,
111 data,
112 usage,
113 }
114 }
115}