use serde::{Deserialize, Serialize};
use super::UNKNOWN_MODEL_ID;
/// Request body for the tokenize endpoint.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TokenizeRequest {
    /// Target model id; when omitted, falls back to the shared
    /// unknown-model sentinel (see `default_model_name`).
    #[serde(default = "default_model_name")]
    pub model: String,
    /// Text to tokenize: a single string or a batch of strings.
    pub prompt: StringOrArray,
}
/// Response body for the tokenize endpoint. Each field mirrors the
/// request shape: `Single` variants for a single prompt, `Batch` for a
/// batched prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizeResponse {
    /// Token ids produced for the prompt(s).
    pub tokens: TokensResult,
    /// Token count per prompt.
    pub count: CountResult,
    /// Character count per prompt.
    pub char_count: CountResult,
}
/// Token ids for one prompt (`Single`) or several (`Batch`).
///
/// `untagged`: serializes as a bare array (`[1,2,3]`) or an array of
/// arrays (`[[1,2],[3]]`) with no variant wrapper.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum TokensResult {
    Single(Vec<u32>),
    Batch(Vec<Vec<u32>>),
}
/// A count for one prompt (`Single`) or per-prompt counts (`Batch`).
///
/// `untagged`: serializes as a bare number or a bare array with no
/// variant wrapper.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CountResult {
    Single(i32),
    Batch(Vec<i32>),
}
/// Request body for the detokenize endpoint.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DetokenizeRequest {
    /// Target model id; when omitted, falls back to the shared
    /// unknown-model sentinel (see `default_model_name`).
    #[serde(default = "default_model_name")]
    pub model: String,
    /// Token ids to turn back into text: one sequence or a batch.
    pub tokens: TokensInput,
    /// Whether special tokens are dropped from the output; defaults
    /// to `true` when the field is omitted.
    #[serde(default = "default_true")]
    pub skip_special_tokens: bool,
}
/// Token-id input for detokenization: one sequence or a batch of
/// sequences.
///
/// `untagged`: a flat JSON array of integers parses as `Single`, an
/// array of arrays parses as `Batch`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum TokensInput {
    Single(Vec<u32>),
    Batch(Vec<Vec<u32>>),
}
impl TokensInput {
    /// Whether this input carries multiple sequences.
    pub fn is_batch(&self) -> bool {
        match self {
            TokensInput::Batch(_) => true,
            TokensInput::Single(_) => false,
        }
    }
    /// Normalizes the input to a list of borrowed slices: one element
    /// for `Single`, one per inner vector for `Batch`.
    pub fn sequences(&self) -> Vec<&[u32]> {
        let mut out = Vec::new();
        match self {
            TokensInput::Single(seq) => out.push(&seq[..]),
            TokensInput::Batch(seqs) => out.extend(seqs.iter().map(Vec::as_slice)),
        }
        out
    }
}
/// Response body for the detokenize endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetokenizeResponse {
    /// Decoded text: one string or a batch, matching the request shape.
    pub text: TextResult,
}
/// Decoded text for one sequence (`Single`) or several (`Batch`).
///
/// `untagged`: serializes as a bare string or a bare array of strings
/// with no variant wrapper.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum TextResult {
    Single(String),
    Batch(Vec<String>),
}
/// Request body for registering a new tokenizer.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AddTokenizerRequest {
    /// Name to register the tokenizer under.
    pub name: String,
    /// Where to load the tokenizer from — NOTE(review): presumably a
    /// local path or hub id; confirm against the handler that consumes it.
    pub source: String,
    /// Optional chat template path; omitted from serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub chat_template_path: Option<String>,
}
/// Response body for a tokenizer registration attempt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AddTokenizerResponse {
    /// Identifier assigned to the registered tokenizer.
    pub id: String,
    /// Outcome status string.
    pub status: String,
    /// Human-readable detail about the outcome.
    pub message: String,
    /// Vocabulary size of the loaded tokenizer, when known; omitted
    /// from serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vocab_size: Option<usize>,
}
/// Response body listing all registered tokenizers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListTokenizersResponse {
    /// One entry per registered tokenizer.
    pub tokenizers: Vec<TokenizerInfo>,
}
/// Summary of one registered tokenizer, as returned by the list endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerInfo {
    /// Identifier assigned at registration.
    pub id: String,
    /// Registered name.
    pub name: String,
    /// Source the tokenizer was loaded from.
    pub source: String,
    /// Vocabulary size of the tokenizer.
    pub vocab_size: usize,
}
/// Request body for removing a registered tokenizer by name.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RemoveTokenizerRequest {
    /// Name the tokenizer was registered under.
    pub name: String,
}
/// Response body for a tokenizer removal attempt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RemoveTokenizerResponse {
    /// Whether the removal succeeded.
    pub success: bool,
    /// Human-readable detail about the outcome.
    pub message: String,
}
/// Prompt input: one string or a batch of strings.
///
/// `untagged`: a bare JSON string parses as `Single`, an array of
/// strings parses as `Array`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum StringOrArray {
    Single(String),
    Array(Vec<String>),
}
impl StringOrArray {
    /// Whether this prompt carries multiple strings.
    pub fn is_batch(&self) -> bool {
        match self {
            StringOrArray::Array(_) => true,
            StringOrArray::Single(_) => false,
        }
    }
    /// Normalizes the prompt to a list of borrowed `&str` values: one
    /// element for `Single`, one per entry for `Array`.
    pub fn as_strings(&self) -> Vec<&str> {
        let mut out = Vec::new();
        match self {
            StringOrArray::Single(text) => out.push(text.as_str()),
            StringOrArray::Array(items) => out.extend(items.iter().map(String::as_str)),
        }
        out
    }
}
/// Serde default for `model` fields when a request omits the model id.
fn default_model_name() -> String {
    String::from(UNKNOWN_MODEL_ID)
}
/// Serde default for `skip_special_tokens`: special tokens are skipped
/// unless the request says otherwise.
fn default_true() -> bool {
    true
}
/// Serialization round-trip tests for the tokenize/detokenize wire types.
#[cfg(test)]
mod tests {
    use super::*;

    /// Omitted `model` falls back to the "unknown" sentinel; a bare JSON
    /// string becomes `StringOrArray::Single`.
    #[test]
    fn test_tokenize_request_single() {
        let payload = r#"{"prompt": "Hello world"}"#;
        let parsed: TokenizeRequest = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.model, "unknown");
        assert!(matches!(parsed.prompt, StringOrArray::Single(_)));
    }

    /// An explicit model plus an array prompt becomes `StringOrArray::Array`.
    #[test]
    fn test_tokenize_request_batch() {
        let payload = r#"{"model": "llama", "prompt": ["Hello", "World"]}"#;
        let parsed: TokenizeRequest = serde_json::from_str(payload).unwrap();
        assert_eq!(parsed.model, "llama");
        assert!(matches!(parsed.prompt, StringOrArray::Array(_)));
    }

    /// A flat token array parses as `TokensInput::Single`;
    /// `skip_special_tokens` defaults to true.
    #[test]
    fn test_detokenize_request_single() {
        let payload = r#"{"tokens": [1, 2, 3]}"#;
        let parsed: DetokenizeRequest = serde_json::from_str(payload).unwrap();
        assert!(matches!(parsed.tokens, TokensInput::Single(_)));
        assert!(parsed.skip_special_tokens);
    }

    /// A nested token array parses as `TokensInput::Batch`; an explicit
    /// flag overrides the default.
    #[test]
    fn test_detokenize_request_batch() {
        let payload = r#"{"tokens": [[1, 2], [3, 4, 5]], "skip_special_tokens": false}"#;
        let parsed: DetokenizeRequest = serde_json::from_str(payload).unwrap();
        assert!(matches!(parsed.tokens, TokensInput::Batch(_)));
        assert!(!parsed.skip_special_tokens);
    }

    /// Single-prompt responses serialize the untagged variants as bare
    /// values, not wrapped objects.
    #[test]
    fn test_tokenize_response_single() {
        let response = TokenizeResponse {
            tokens: TokensResult::Single(vec![1, 2, 3]),
            count: CountResult::Single(3),
            char_count: CountResult::Single(11),
        };
        let encoded = serde_json::to_string(&response).unwrap();
        assert!(encoded.contains("[1,2,3]"));
        assert!(encoded.contains("\"count\":3"));
        assert!(encoded.contains("\"char_count\":11"));
    }

    /// Batched responses serialize as arrays of arrays / plain arrays.
    #[test]
    fn test_tokenize_response_batch() {
        let response = TokenizeResponse {
            tokens: TokensResult::Batch(vec![vec![1, 2], vec![3, 4, 5]]),
            count: CountResult::Batch(vec![2, 3]),
            char_count: CountResult::Batch(vec![5, 5]),
        };
        let encoded = serde_json::to_string(&response).unwrap();
        assert!(encoded.contains("[[1,2],[3,4,5]]"));
        assert!(encoded.contains("[2,3]"));
    }
}