1//! Tokenize and Detokenize API protocol types
2//!
3//! These types mirror the SGLang Python implementation for compatibility.
4//! See: python/sglang/srt/entrypoints/openai/protocol.py
5
6use serde::{Deserialize, Serialize};
7
8// ============================================================================
9// Tokenize API
10// ============================================================================
11
/// Request schema for the /v1/tokenize endpoint
///
/// Supports both single string and batch tokenization: `prompt` accepts
/// either one string or an array of strings.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct TokenizeRequest {
    /// Model name for tokenizer selection
    pub model: String,

    /// Text(s) to tokenize - can be a single string or array of strings
    pub prompt: StringOrArray,
}
23
/// Response schema for the /v1/tokenize endpoint
///
/// Each field mirrors the request shape: scalar-style variants for a
/// single prompt, list variants for a batch.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct TokenizeResponse {
    /// Token IDs - single list for single input, nested list for batch
    pub tokens: TokensResult,

    /// Token count(s) - single int for single input, list for batch
    pub count: CountResult,

    /// Character count(s) of input - single int for single input, list for batch
    pub char_count: CountResult,
}
36
/// Token IDs result - either single or batch
///
/// `#[serde(untagged)]` serializes the payload bare: `[1, 2]` for a single
/// input, `[[1], [2, 3]]` for a batch. Untagged deserialization tries the
/// variants in declaration order, so `Single` must stay first.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum TokensResult {
    /// Token IDs for one input
    Single(Vec<u32>),
    /// Token IDs for each input of a batch
    Batch(Vec<Vec<u32>>),
}
44
/// Count result - either single or batch
///
/// Untagged: serializes as a bare number or a bare list of numbers.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum CountResult {
    /// Count for one input
    Single(i32),
    /// Counts for each input of a batch
    Batch(Vec<i32>),
}
52
53// ============================================================================
54// Detokenize API
55// ============================================================================
56
/// Request schema for the /v1/detokenize endpoint
///
/// Supports both single sequence and batch detokenization.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct DetokenizeRequest {
    /// Model name for tokenizer selection
    pub model: String,

    /// Token IDs to detokenize - single list or batch (list of lists)
    pub tokens: TokensInput,

    /// Whether to skip special tokens (e.g., padding or EOS) during decoding
    /// Defaults to true when the field is absent from the request body.
    #[serde(default = "default_true")]
    pub skip_special_tokens: bool,
}
72
/// Token input - either single sequence or batch
///
/// Untagged: serde distinguishes the variants purely by JSON shape
/// (`[1, 2]` vs `[[1], [2]]`), trying `Single` first.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum TokensInput {
    /// Single sequence of token IDs
    Single(Vec<u32>),
    /// Batch of token sequences
    Batch(Vec<Vec<u32>>),
}
82
83impl TokensInput {
84    /// Check if this is a batch input
85    pub fn is_batch(&self) -> bool {
86        matches!(self, TokensInput::Batch(_))
87    }
88
89    /// Get the sequences (always returns a vec of vecs for uniform processing)
90    pub fn sequences(&self) -> Vec<&[u32]> {
91        match self {
92            TokensInput::Single(seq) => vec![seq.as_slice()],
93            TokensInput::Batch(seqs) => seqs.iter().map(|s| s.as_slice()).collect(),
94        }
95    }
96}
97
/// Response schema for the /v1/detokenize endpoint
///
/// Mirrors the request shape: one decoded string per input sequence.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct DetokenizeResponse {
    /// Decoded text - single string for single input, list for batch
    pub text: TextResult,
}
104
/// Text result - either single or batch
///
/// Untagged: serializes as a bare string or a bare list of strings.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(untagged)]
pub enum TextResult {
    /// Decoded text for one input
    Single(String),
    /// Decoded text for each input of a batch
    Batch(Vec<String>),
}
112
113// ============================================================================
114// Tokenizer Management API
115// ============================================================================
116
/// Request schema for adding a tokenizer
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct AddTokenizerRequest {
    /// Name to register the tokenizer under
    pub name: String,

    /// Source: either a local path or HuggingFace model ID
    pub source: String,

    /// Optional path to chat template file
    /// Omitted from serialized output when not provided.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub chat_template_path: Option<String>,
}
130
/// Response schema for adding a tokenizer (async)
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct AddTokenizerResponse {
    /// Unique identifier for the tokenizer (UUID)
    pub id: String,
    /// Status of the request: "pending", "processing", "completed", "failed"
    pub status: String,
    /// Free-form message accompanying the current status
    pub message: String,
    /// Vocabulary size of the loaded tokenizer (only set on completion)
    /// Omitted from serialized output while still `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vocab_size: Option<usize>,
}
143
/// Response schema for listing tokenizers
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct ListTokenizersResponse {
    /// All currently registered tokenizers
    pub tokenizers: Vec<TokenizerInfo>,
}
149
/// Information about a registered tokenizer
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct TokenizerInfo {
    /// Unique identifier (UUID)
    pub id: String,
    /// User-provided name
    pub name: String,
    /// Source path or HuggingFace model ID
    pub source: String,
    /// Vocabulary size of the loaded tokenizer
    pub vocab_size: usize,
}
161
/// Request schema for removing a tokenizer
// NOTE(review): removal is keyed by the user-provided name rather than the
// UUID `id` returned on creation - confirm this is intended.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
pub struct RemoveTokenizerRequest {
    /// Name of the tokenizer to remove
    pub name: String,
}
168
/// Response schema for removing a tokenizer
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct RemoveTokenizerResponse {
    /// True if the removal succeeded
    pub success: bool,
    /// Free-form message describing the outcome
    pub message: String,
}
175
176// ============================================================================
177// Helper Types
178// ============================================================================
179
/// String or array of strings (for flexible input)
///
/// Untagged, so the JSON is a bare string or a bare array. The schemars
/// rename presumably avoids a schema-name clash with another
/// `StringOrArray` type - verify against the generated OpenAPI schema.
#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)]
#[schemars(rename = "TokenizeStringOrArray")]
#[serde(untagged)]
pub enum StringOrArray {
    /// One input string
    Single(String),
    /// Multiple input strings
    Array(Vec<String>),
}
188
189impl StringOrArray {
190    /// Check if this is a batch (array) input
191    pub fn is_batch(&self) -> bool {
192        matches!(self, StringOrArray::Array(_))
193    }
194
195    /// Get all strings as a slice (converts single to vec)
196    pub fn as_strings(&self) -> Vec<&str> {
197        match self {
198            StringOrArray::Single(s) => vec![s.as_str()],
199            StringOrArray::Array(arr) => arr.iter().map(|s| s.as_str()).collect(),
200        }
201    }
202}
203
204// ============================================================================
205// Default Functions
206// ============================================================================
207
/// Returns `true`; used via `#[serde(default = "default_true")]` because
/// serde's `default` attribute takes a function path, not a literal value.
fn default_true() -> bool {
    true
}
211
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenize_request_requires_model() {
        // `model` has no serde default, so it is mandatory.
        let payload = r#"{"prompt": "Hello world"}"#;
        let parsed = serde_json::from_str::<TokenizeRequest>(payload);
        assert!(parsed.is_err(), "Should fail without model field");
    }

    #[test]
    fn test_tokenize_request_batch() {
        let payload = r#"{"model": "llama", "prompt": ["Hello", "World"]}"#;
        let request: TokenizeRequest = serde_json::from_str(payload).unwrap();
        assert_eq!(request.model, "llama");
        assert!(request.prompt.is_batch());
    }

    #[test]
    fn test_detokenize_request_requires_model() {
        let payload = r#"{"tokens": [1, 2, 3]}"#;
        let parsed = serde_json::from_str::<DetokenizeRequest>(payload);
        assert!(parsed.is_err(), "Should fail without model field");
    }

    #[test]
    fn test_detokenize_request_single() {
        let payload = r#"{"model": "test-model", "tokens": [1, 2, 3]}"#;
        let request: DetokenizeRequest = serde_json::from_str(payload).unwrap();
        assert!(!request.tokens.is_batch());
        // Omitted field falls back to the `default_true` serde default.
        assert!(request.skip_special_tokens);
    }

    #[test]
    fn test_detokenize_request_batch() {
        let payload =
            r#"{"model": "test-model", "tokens": [[1, 2], [3, 4, 5]], "skip_special_tokens": false}"#;
        let request: DetokenizeRequest = serde_json::from_str(payload).unwrap();
        assert!(request.tokens.is_batch());
        assert!(!request.skip_special_tokens);
    }

    #[test]
    fn test_tokenize_response_single() {
        let response = TokenizeResponse {
            tokens: TokensResult::Single(vec![1, 2, 3]),
            count: CountResult::Single(3),
            char_count: CountResult::Single(11),
        };
        let serialized = serde_json::to_string(&response).unwrap();
        // Untagged variants must flatten to bare values with no wrapper object.
        for fragment in ["[1,2,3]", "\"count\":3", "\"char_count\":11"] {
            assert!(serialized.contains(fragment), "missing {fragment}");
        }
    }

    #[test]
    fn test_tokenize_response_batch() {
        let response = TokenizeResponse {
            tokens: TokensResult::Batch(vec![vec![1, 2], vec![3, 4, 5]]),
            count: CountResult::Batch(vec![2, 3]),
            char_count: CountResult::Batch(vec![5, 5]),
        };
        let serialized = serde_json::to_string(&response).unwrap();
        for fragment in ["[[1,2],[3,4,5]]", "[2,3]"] {
            assert!(serialized.contains(fragment), "missing {fragment}");
        }
    }
}