use std::fmt;
use std::str::FromStr;
use serde::{Deserialize, Serialize};
use crate::error::{BackendError, DynamoError, ErrorType};
/// Disaggregation mode selector: aggregated (the default), prefill, or decode.
///
/// Textual forms (all lowercase):
/// - `Display` / `as_str`: `agg`, `prefill`, `decode`.
/// - CLI (`clap::ValueEnum`): `Aggregated` is `agg` with alias `aggregated`; `Prefill`
///   and `Decode` use clap's default names derived from the variant names.
/// - serde: names come from `rename_all = "lowercase"`, so `Aggregated` is serialized
///   as `"aggregated"`; `"agg"` and `"aggregated"` are both accepted when deserializing.
///   Unlike `FromStr`, serde matching is exact (no trimming, no case folding).
///
/// NOTE(review): the serialized spelling of `Aggregated` ("aggregated") differs from its
/// `Display`/CLI spelling ("agg"). Confirm this asymmetry is intentional before relying on
/// a single spelling across config files, CLI flags and wire formats.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, clap::ValueEnum)]
#[serde(rename_all = "lowercase")]
pub enum DisaggregationMode {
/// Aggregated mode; this is the `Default` variant.
#[default]
#[serde(alias = "agg", alias = "aggregated")]
#[value(name = "agg", alias = "aggregated")]
Aggregated,
/// Prefill mode (slug `prefill`).
Prefill,
/// Decode mode (slug `decode`).
Decode,
}
impl DisaggregationMode {
    /// Returns the canonical short slug for this mode: `"agg"`, `"prefill"` or `"decode"`.
    ///
    /// This is the text written by the [`fmt::Display`] impl and accepted by the
    /// [`FromStr`] impl. It is a `const fn`, so it can also be used in constant contexts.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Aggregated => "agg",
            Self::Prefill => "prefill",
            Self::Decode => "decode",
        }
    }

    /// Returns `true` only for [`DisaggregationMode::Aggregated`].
    #[must_use]
    pub const fn is_aggregated(&self) -> bool {
        matches!(self, Self::Aggregated)
    }

    /// Returns `true` only for [`DisaggregationMode::Prefill`].
    #[must_use]
    pub const fn is_prefill(&self) -> bool {
        matches!(self, Self::Prefill)
    }

    /// Returns `true` only for [`DisaggregationMode::Decode`].
    #[must_use]
    pub const fn is_decode(&self) -> bool {
        matches!(self, Self::Decode)
    }
}
impl fmt::Display for DisaggregationMode {
    /// Writes the canonical slug (`agg`, `prefill` or `decode`).
    ///
    /// Uses [`fmt::Formatter::pad`] rather than `write_str`, so width, fill,
    /// alignment and precision flags are honoured exactly as they are for a `&str`.
    /// For example, `format!("{mode:<8}")` produces aligned log or table columns.
    /// Without flags the output is identical to [`DisaggregationMode::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
impl FromStr for DisaggregationMode {
    type Err = DynamoError;

    /// Parses a disaggregation mode, ignoring surrounding whitespace and ASCII case.
    ///
    /// Accepted spellings: `agg` or `aggregated`, `prefill`, `decode`.
    ///
    /// # Errors
    ///
    /// Any other input yields a [`DynamoError`] with error type
    /// `ErrorType::Backend(BackendError::InvalidArgument)`. The message quotes the
    /// input as the caller wrote it, minus surrounding whitespace, rather than a
    /// lower-cased copy, so users can recognise what they typed.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let input = s.trim();
        // Compare in place with ASCII-only case folding (the same folding as
        // `to_ascii_lowercase`) instead of allocating a lower-cased `String` per parse.
        let matches_slug = |slug: &str| input.eq_ignore_ascii_case(slug);

        if matches_slug("agg") || matches_slug("aggregated") {
            Ok(Self::Aggregated)
        } else if matches_slug("prefill") {
            Ok(Self::Prefill)
        } else if matches_slug("decode") {
            Ok(Self::Decode)
        } else {
            Err(DynamoError::builder()
                .error_type(ErrorType::Backend(BackendError::InvalidArgument))
                .message(format!(
                    "unknown disaggregation mode '{input}' (expected one of: agg, prefill, decode)"
                ))
                .build())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, for table-driven checks.
    const ALL_MODES: [DisaggregationMode; 3] = [
        DisaggregationMode::Aggregated,
        DisaggregationMode::Prefill,
        DisaggregationMode::Decode,
    ];

    /// Parses `s` via `FromStr`, panicking with the offending input on failure.
    fn parse(s: &str) -> DisaggregationMode {
        s.parse::<DisaggregationMode>()
            .unwrap_or_else(|e| panic!("failed to parse {s:?}: {e:?}"))
    }

    #[test]
    fn parses_canonical_slugs() {
        assert_eq!(parse("agg"), DisaggregationMode::Aggregated);
        assert_eq!(parse("prefill"), DisaggregationMode::Prefill);
        assert_eq!(parse("decode"), DisaggregationMode::Decode);
    }

    #[test]
    fn parses_aggregated_alias_and_is_case_insensitive() {
        assert_eq!(parse("AGGREGATED"), DisaggregationMode::Aggregated);
        assert_eq!(parse(" Prefill "), DisaggregationMode::Prefill);
        assert_eq!(parse("\tDeCoDe\n"), DisaggregationMode::Decode);
    }

    #[test]
    fn rejects_unknown() {
        // Includes empty / whitespace-only input and near-misses of valid slugs.
        for input in ["encode", "", "   ", "aggregate", "pre fill"] {
            match input.parse::<DisaggregationMode>() {
                Ok(mode) => panic!("{input:?} unexpectedly parsed as {mode:?}"),
                Err(e) => assert_eq!(
                    e.error_type(),
                    ErrorType::Backend(BackendError::InvalidArgument),
                    "input {input:?}"
                ),
            }
        }
    }

    #[test]
    fn display_round_trips_through_from_str() {
        for mode in ALL_MODES {
            assert_eq!(parse(&mode.to_string()), mode);
        }
    }

    #[test]
    fn predicates_match_variants() {
        // Each predicate must be true for exactly its own variant.
        for mode in ALL_MODES {
            assert_eq!(
                mode.is_prefill(),
                mode == DisaggregationMode::Prefill,
                "{mode:?}"
            );
            assert_eq!(
                mode.is_decode(),
                mode == DisaggregationMode::Decode,
                "{mode:?}"
            );
        }
    }

    #[test]
    fn default_is_aggregated() {
        assert_eq!(
            DisaggregationMode::default(),
            DisaggregationMode::Aggregated
        );
    }

    #[test]
    fn serde_round_trip_uses_lowercase() {
        let json = serde_json::to_string(&DisaggregationMode::Prefill).unwrap();
        assert_eq!(json, "\"prefill\"");
        let back: DisaggregationMode = serde_json::from_str(&json).unwrap();
        assert_eq!(back, DisaggregationMode::Prefill);
    }

    #[test]
    fn serde_round_trips_every_mode() {
        for mode in ALL_MODES {
            let json = serde_json::to_string(&mode).unwrap();
            let back: DisaggregationMode = serde_json::from_str(&json).unwrap();
            assert_eq!(back, mode, "json {json}");
        }
    }

    #[test]
    fn serde_accepts_both_aggregated_spellings() {
        // Pin the wire form: `rename_all = "lowercase"` writes "aggregated",
        // which intentionally or not differs from the Display slug "agg".
        assert_eq!(
            serde_json::to_string(&DisaggregationMode::Aggregated).unwrap(),
            "\"aggregated\""
        );
        for json in ["\"agg\"", "\"aggregated\""] {
            let mode: DisaggregationMode = serde_json::from_str(json).unwrap();
            assert_eq!(mode, DisaggregationMode::Aggregated, "json {json}");
        }
    }

    #[test]
    fn clap_value_names_match_display() {
        for mode in ALL_MODES {
            let value = clap::ValueEnum::to_possible_value(&mode)
                .expect("no variant is skipped for clap");
            assert_eq!(value.get_name(), mode.as_str());
        }
        assert_eq!(
            <DisaggregationMode as clap::ValueEnum>::from_str("aggregated", false).unwrap(),
            DisaggregationMode::Aggregated
        );
    }
}