use serde::{Deserialize, Serialize};
use std::fmt;
use std::str::FromStr;
/// Kind of worker, serialized and displayed by its lowercase name.
///
/// The serde representation comes from `rename_all = "lowercase"`, so the
/// wire values are `"prefill"`, `"decode"`, `"encode"` and `"aggregated"`.
/// The derived `Deserialize` matches those strings exactly (no trimming or
/// case folding). The `FromStr` impl in this module is more lenient and
/// accepts surrounding whitespace and any ASCII case.
///
/// NOTE: declaration order fixes each variant's discriminant. The derived
/// `Hash` and serde's variant index (used by non-self-describing formats)
/// depend on it, so append new variants rather than reordering existing ones.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WorkerType {
    /// Wire/display name `"prefill"`.
    Prefill,
    /// Wire/display name `"decode"`.
    Decode,
    /// Wire/display name `"encode"`.
    Encode,
    /// Wire/display name `"aggregated"`.
    Aggregated,
}
impl WorkerType {
    /// Returns the canonical lowercase name of this worker type.
    ///
    /// This is the spelling written by the [`fmt::Display`] impl. It is the
    /// same string serde produces via `#[serde(rename_all = "lowercase")]`.
    /// `FromStr` accepts it, case-insensitively.
    ///
    /// Declared `const` so the names can also be used to build compile-time
    /// constants, e.g. `const NAME: &str = WorkerType::Decode.as_str();`.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        // Exhaustive match: adding a variant forces a name to be chosen here.
        match *self {
            WorkerType::Prefill => "prefill",
            WorkerType::Decode => "decode",
            WorkerType::Encode => "encode",
            WorkerType::Aggregated => "aggregated",
        }
    }
}
impl fmt::Display for WorkerType {
    /// Writes the canonical lowercase name returned by `as_str`.
    ///
    /// Uses [`fmt::Formatter::pad`] rather than `write!` so that the caller's
    /// width, fill, alignment and precision flags are honoured, the same way
    /// `str`'s own `Display` works. For example, `format!("{:<8}|", WorkerType::Decode)`
    /// yields `"decode  |"` instead of silently ignoring the width.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
/// Error returned when a string does not name any `WorkerType`.
///
/// `token` carries the input as it was handed to the parser, before any
/// trimming, so diagnostics show exactly what the caller supplied.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParseWorkerTypeError {
    /// The unrecognized input, verbatim.
    pub token: String,
}

impl fmt::Display for ParseWorkerTypeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Debug-format the token so it is quoted and escaped. This keeps
        // empty or whitespace-only input visible in the message.
        let Self { token } = self;
        write!(f, "unrecognized worker_type: {:?}", token)
    }
}

/// Allows the error to be boxed as `dyn Error` and propagated with `?`.
impl std::error::Error for ParseWorkerTypeError {}
impl FromStr for WorkerType {
    type Err = ParseWorkerTypeError;

    /// Parses a worker type from its canonical name.
    ///
    /// Surrounding whitespace is ignored and matching is ASCII
    /// case-insensitive. `"prefill"`, `"PREFILL"` and `" Prefill "` all parse to
    /// `WorkerType::Prefill`. The accepted spellings are exactly those
    /// returned by `WorkerType::as_str`, which `Display` also writes, so
    /// displayed values always parse back to the same variant.
    ///
    /// # Errors
    ///
    /// Returns [`ParseWorkerTypeError`] when the input does not name a known
    /// worker type (including empty or whitespace-only input). The error
    /// carries the original, untrimmed input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Every variant. The names come from `as_str`, so parsing and display
        // cannot drift apart. Keep this list in sync when adding a variant.
        const ALL: [WorkerType; 4] = [
            WorkerType::Prefill,
            WorkerType::Decode,
            WorkerType::Encode,
            WorkerType::Aggregated,
        ];

        let wanted = s.trim();
        // `eq_ignore_ascii_case` compares in place, avoiding the String
        // allocation that `to_ascii_lowercase` required on every parse.
        ALL.iter()
            .copied()
            .find(|wt| wt.as_str().eq_ignore_ascii_case(wanted))
            .ok_or_else(|| ParseWorkerTypeError {
                token: s.to_string(),
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant paired with its canonical lowercase spelling. It is shared by
    /// the table-driven tests below, so a new variant only needs adding here.
    const CASES: [(WorkerType, &str); 4] = [
        (WorkerType::Prefill, "prefill"),
        (WorkerType::Decode, "decode"),
        (WorkerType::Encode, "encode"),
        (WorkerType::Aggregated, "aggregated"),
    ];

    #[test]
    fn display_canonical_lowercase() {
        for (wt, name) in CASES {
            assert_eq!(wt.to_string(), name);
            assert_eq!(wt.as_str(), name);
        }
    }

    #[test]
    fn from_str_accepts_canonical_names() {
        for (wt, name) in CASES {
            assert_eq!(name.parse::<WorkerType>().unwrap(), wt);
        }
    }

    #[test]
    fn from_str_case_insensitive_and_whitespace_tolerant() {
        assert_eq!(
            "PREFILL".parse::<WorkerType>().unwrap(),
            WorkerType::Prefill
        );
        assert_eq!(
            " Decode ".parse::<WorkerType>().unwrap(),
            WorkerType::Decode
        );
        assert_eq!(
            "EnCoDe".parse::<WorkerType>().unwrap(),
            WorkerType::Encode
        );
        assert_eq!(
            "\tAGGREGATED\n".parse::<WorkerType>().unwrap(),
            WorkerType::Aggregated
        );
    }

    #[test]
    fn from_str_rejects_unknown_and_empty() {
        for bad in ["wibble", "", "prefill|decode"] {
            assert!(
                bad.parse::<WorkerType>().is_err(),
                "{:?} should not parse",
                bad
            );
        }
    }

    /// The error must report the caller's input verbatim (untrimmed, original
    /// case) so diagnostics show exactly what was supplied.
    #[test]
    fn from_str_error_keeps_raw_token() {
        let err = " Wibble ".parse::<WorkerType>().unwrap_err();
        assert_eq!(err.token, " Wibble ");
        assert_eq!(err.to_string(), "unrecognized worker_type: \" Wibble \"");
    }

    #[test]
    fn display_from_str_round_trip() {
        for (wt, _) in CASES {
            assert_eq!(wt.to_string().parse::<WorkerType>().unwrap(), wt);
        }
    }

    #[test]
    fn serde_json_wire_format_is_canonical_lowercase() {
        for (wt, name) in CASES {
            assert_eq!(
                serde_json::to_string(&wt).unwrap(),
                format!("\"{}\"", name)
            );
        }
    }

    #[test]
    fn serde_json_round_trip() {
        for (wt, _) in CASES {
            let json = serde_json::to_string(&wt).unwrap();
            let back: WorkerType = serde_json::from_str(&json).unwrap();
            assert_eq!(back, wt);
        }
    }

    #[test]
    fn serde_json_rejects_unknown_value() {
        assert!(serde_json::from_str::<WorkerType>("\"wibble\"").is_err());
    }
}