openai_protocol/tokenize.rs

1//! Tokenize and Detokenize API protocol types
2//!
3//! These types mirror the SGLang Python implementation for compatibility.
4//! See: python/sglang/srt/entrypoints/openai/protocol.py
5
6use serde::{Deserialize, Serialize};
7
8use super::UNKNOWN_MODEL_ID;
9
10// ============================================================================
11// Tokenize API
12// ============================================================================
13
/// Request schema for the /v1/tokenize endpoint
///
/// Supports both single string and batch tokenization.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TokenizeRequest {
    /// Model name for tokenizer selection; falls back to
    /// [`UNKNOWN_MODEL_ID`] when the field is absent from the request body.
    #[serde(default = "default_model_name")]
    pub model: String,

    /// Text(s) to tokenize - can be a single string or array of strings
    pub prompt: StringOrArray,
}
26
/// Response schema for the /v1/tokenize endpoint
///
/// NOTE(review): the single/batch shape of `tokens`, `count`, and
/// `char_count` is not enforced to be consistent by the types; the
/// producer must keep all three in the same (single or batch) form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizeResponse {
    /// Token IDs - single list for single input, nested list for batch
    pub tokens: TokensResult,

    /// Token count(s) - single int for single input, list for batch
    pub count: CountResult,

    /// Character count(s) of input - single int for single input, list for batch
    pub char_count: CountResult,
}
39
/// Token IDs result - either single or batch
///
/// `untagged`: serialized without a variant tag; on deserialization serde
/// tries variants in declaration order, so `Single` wins for ambiguous
/// inputs (e.g. an empty JSON array). Do not reorder the variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum TokensResult {
    /// Token IDs for one input
    Single(Vec<u32>),
    /// One token-ID list per input in a batch
    Batch(Vec<Vec<u32>>),
}
47
/// Count result - either single or batch
///
/// `i32` mirrors the Python reference implementation (see module docs).
/// `untagged`: variants are tried in declaration order on deserialization;
/// do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CountResult {
    /// Count for one input
    Single(i32),
    /// One count per input in a batch
    Batch(Vec<i32>),
}
55
56// ============================================================================
57// Detokenize API
58// ============================================================================
59
/// Request schema for the /v1/detokenize endpoint
///
/// Supports both single sequence and batch detokenization.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DetokenizeRequest {
    /// Model name for tokenizer selection; falls back to
    /// [`UNKNOWN_MODEL_ID`] when the field is absent from the request body.
    #[serde(default = "default_model_name")]
    pub model: String,

    /// Token IDs to detokenize - single list or batch (list of lists)
    pub tokens: TokensInput,

    /// Whether to skip special tokens (e.g., padding or EOS) during decoding.
    /// Defaults to `true` when omitted from the request.
    #[serde(default = "default_true")]
    pub skip_special_tokens: bool,
}
76
/// Token input - either single sequence or batch
///
/// `untagged`: serde tries variants in declaration order on
/// deserialization, so `Single` wins for ambiguous inputs (e.g. an empty
/// JSON array). Do not reorder the variants.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum TokensInput {
    /// Single sequence of token IDs
    Single(Vec<u32>),
    /// Batch of token sequences
    Batch(Vec<Vec<u32>>),
}
86
87impl TokensInput {
88    /// Check if this is a batch input
89    pub fn is_batch(&self) -> bool {
90        matches!(self, TokensInput::Batch(_))
91    }
92
93    /// Get the sequences (always returns a vec of vecs for uniform processing)
94    pub fn sequences(&self) -> Vec<&[u32]> {
95        match self {
96            TokensInput::Single(seq) => vec![seq.as_slice()],
97            TokensInput::Batch(seqs) => seqs.iter().map(|s| s.as_slice()).collect(),
98        }
99    }
100}
101
/// Response schema for the /v1/detokenize endpoint
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetokenizeResponse {
    /// Decoded text - single string for single input, list for batch
    pub text: TextResult,
}
108
/// Text result - either single or batch
///
/// `untagged`: variants are tried in declaration order on deserialization;
/// do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum TextResult {
    /// Decoded text for one input
    Single(String),
    /// One decoded string per input in a batch
    Batch(Vec<String>),
}
116
117// ============================================================================
118// Tokenizer Management API
119// ============================================================================
120
/// Request schema for adding a tokenizer
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AddTokenizerRequest {
    /// Name to register the tokenizer under
    pub name: String,

    /// Source: either a local path or HuggingFace model ID
    pub source: String,

    /// Optional path to chat template file; omitted from serialized output
    /// when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub chat_template_path: Option<String>,
}
134
/// Response schema for adding a tokenizer (async)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AddTokenizerResponse {
    /// Unique identifier for the tokenizer (UUID)
    pub id: String,
    /// Status of the request: "pending", "processing", "completed", "failed"
    /// NOTE(review): stringly-typed, presumably to mirror the Python wire
    /// format; an enum would be safer if the protocol permits it.
    pub status: String,
    /// Free-form status message accompanying the current state
    pub message: String,
    /// Vocabulary size of the loaded tokenizer (only set on completion);
    /// omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vocab_size: Option<usize>,
}
147
/// Response schema for listing tokenizers
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListTokenizersResponse {
    /// All currently registered tokenizers
    pub tokenizers: Vec<TokenizerInfo>,
}
153
/// Information about a registered tokenizer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerInfo {
    /// Unique identifier (UUID)
    pub id: String,
    /// User-provided name
    pub name: String,
    /// Source path or HuggingFace model ID
    pub source: String,
    /// Vocabulary size of the loaded tokenizer
    pub vocab_size: usize,
}
165
/// Request schema for removing a tokenizer
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RemoveTokenizerRequest {
    /// Name of the tokenizer to remove
    pub name: String,
}
172
/// Response schema for removing a tokenizer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RemoveTokenizerResponse {
    /// Whether the removal succeeded
    pub success: bool,
    /// Free-form detail message about the outcome
    pub message: String,
}
179
180// ============================================================================
181// Helper Types
182// ============================================================================
183
/// String or array of strings (for flexible input)
///
/// `untagged`: variants are tried in declaration order on deserialization;
/// do not reorder.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum StringOrArray {
    /// A single input string
    Single(String),
    /// A batch of input strings
    Array(Vec<String>),
}
191
192impl StringOrArray {
193    /// Check if this is a batch (array) input
194    pub fn is_batch(&self) -> bool {
195        matches!(self, StringOrArray::Array(_))
196    }
197
198    /// Get all strings as a slice (converts single to vec)
199    pub fn as_strings(&self) -> Vec<&str> {
200        match self {
201            StringOrArray::Single(s) => vec![s.as_str()],
202            StringOrArray::Array(arr) => arr.iter().map(|s| s.as_str()).collect(),
203        }
204    }
205}
206
207// ============================================================================
208// Default Functions
209// ============================================================================
210
211fn default_model_name() -> String {
212    UNKNOWN_MODEL_ID.to_string()
213}
214
/// Serde default for `skip_special_tokens` (`true`).
fn default_true() -> bool { true }
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenize_request_single() {
        let json = r#"{"prompt": "Hello world"}"#;
        let req: TokenizeRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.model, "unknown");
        assert!(matches!(req.prompt, StringOrArray::Single(_)));
    }

    #[test]
    fn test_tokenize_request_batch() {
        let json = r#"{"model": "llama", "prompt": ["Hello", "World"]}"#;
        let req: TokenizeRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.model, "llama");
        assert!(matches!(req.prompt, StringOrArray::Array(_)));
    }

    #[test]
    fn test_detokenize_request_single() {
        let json = r#"{"tokens": [1, 2, 3]}"#;
        let req: DetokenizeRequest = serde_json::from_str(json).unwrap();
        assert!(matches!(req.tokens, TokensInput::Single(_)));
        assert!(req.skip_special_tokens);
    }

    #[test]
    fn test_detokenize_request_batch() {
        let json = r#"{"tokens": [[1, 2], [3, 4, 5]], "skip_special_tokens": false}"#;
        let req: DetokenizeRequest = serde_json::from_str(json).unwrap();
        assert!(matches!(req.tokens, TokensInput::Batch(_)));
        assert!(!req.skip_special_tokens);
    }

    #[test]
    fn test_tokenize_response_single() {
        let resp = TokenizeResponse {
            tokens: TokensResult::Single(vec![1, 2, 3]),
            count: CountResult::Single(3),
            char_count: CountResult::Single(11),
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("[1,2,3]"));
        assert!(json.contains("\"count\":3"));
        assert!(json.contains("\"char_count\":11"));
    }

    #[test]
    fn test_tokenize_response_batch() {
        let resp = TokenizeResponse {
            tokens: TokensResult::Batch(vec![vec![1, 2], vec![3, 4, 5]]),
            count: CountResult::Batch(vec![2, 3]),
            char_count: CountResult::Batch(vec![5, 5]),
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("[[1,2],[3,4,5]]"));
        assert!(json.contains("[2,3]"));
    }

    // Untagged serialization of TextResult: single input yields a bare string.
    #[test]
    fn test_detokenize_response_single() {
        let resp = DetokenizeResponse {
            text: TextResult::Single("Hello".to_string()),
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert_eq!(json, r#"{"text":"Hello"}"#);
    }

    // Untagged serialization of TextResult: batch yields an array of strings.
    #[test]
    fn test_detokenize_response_batch() {
        let resp = DetokenizeResponse {
            text: TextResult::Batch(vec!["a".to_string(), "b".to_string()]),
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert_eq!(json, r#"{"text":["a","b"]}"#);
    }

    // Helper methods normalize single and batch forms uniformly.
    #[test]
    fn test_string_or_array_helpers() {
        let single = StringOrArray::Single("Hi".to_string());
        assert!(!single.is_batch());
        assert_eq!(single.as_strings(), vec!["Hi"]);

        let batch = StringOrArray::Array(vec!["a".to_string(), "b".to_string()]);
        assert!(batch.is_batch());
        assert_eq!(batch.as_strings(), vec!["a", "b"]);
    }

    #[test]
    fn test_tokens_input_sequences() {
        let single = TokensInput::Single(vec![1, 2]);
        assert!(!single.is_batch());
        assert_eq!(single.sequences(), vec![&[1u32, 2][..]]);

        let batch = TokensInput::Batch(vec![vec![1], vec![2, 3]]);
        assert!(batch.is_batch());
        assert_eq!(batch.sequences(), vec![&[1u32][..], &[2u32, 3][..]]);
    }
}
279}