use std::sync::atomic::{AtomicU8, Ordering};
/// Wire format recognised for a payload.
///
/// The explicit `u8` discriminants let a value be stored in an `AtomicU8`
/// and recovered later through the `From<u8>` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum DetectedFormat {
    /// No format has been recognised.
    Unknown = 0,
    /// JSON text.
    Json = 1,
    /// MessagePack binary data.
    MessagePack = 2,
}

impl From<u8> for DetectedFormat {
    /// Decodes a raw discriminant.
    ///
    /// Any value that is not the discriminant of `Json` or `MessagePack`
    /// (including `0`) maps to `Unknown`.
    fn from(v: u8) -> Self {
        // Derive the tags from the enum itself so they cannot drift apart.
        const JSON: u8 = DetectedFormat::Json as u8;
        const MESSAGE_PACK: u8 = DetectedFormat::MessagePack as u8;

        match v {
            JSON => Self::Json,
            MESSAGE_PACK => Self::MessagePack,
            _ => Self::Unknown,
        }
    }
}
/// Policy that decides which payload format a detector accepts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FormatMode {
    /// Sniff payloads and follow whichever format is observed.
    #[default]
    Auto,
    /// Accept only payloads that look like JSON.
    ForceJson,
    /// Accept only payloads that look like MessagePack.
    ForceMessagePack,
}

impl FormatMode {
    /// Parses a mode keyword, ignoring ASCII case.
    ///
    /// Recognised keywords are `"auto"`, `"json"`, `"messagepack"` and
    /// `"msgpack"`. Returns `None` for anything else; the input is not
    /// trimmed, so surrounding whitespace makes it unrecognised.
    ///
    /// The comparison is ASCII-only and allocation-free. Non-ASCII
    /// characters never match, even ones whose Unicode lowercase form is an
    /// ASCII letter (for example U+212A KELVIN SIGN).
    #[must_use]
    pub fn parse(s: &str) -> Option<Self> {
        if s.eq_ignore_ascii_case("auto") {
            Some(Self::Auto)
        } else if s.eq_ignore_ascii_case("json") {
            Some(Self::ForceJson)
        } else if s.eq_ignore_ascii_case("messagepack") || s.eq_ignore_ascii_case("msgpack") {
            Some(Self::ForceMessagePack)
        } else {
            None
        }
    }
}
/// Tracks which payload format is currently expected.
///
/// The mutable state is held in atomics, so it can be updated through a
/// shared reference.
#[derive(Debug)]
pub struct FormatDetector {
    /// Currently expected format, stored as the `u8` discriminant of
    /// `DetectedFormat`.
    detected_format: AtomicU8,
    /// Number of mismatching payloads seen since the counter was last
    /// cleared (used in auto mode only).
    mismatch_count: AtomicU8,
    /// Detection policy supplied at construction.
    mode: FormatMode,
}
impl FormatDetector {
    /// Mismatch count that must already have been reached before another
    /// mismatching payload makes an auto-mode detector switch formats.
    ///
    /// The counter is read before it is incremented, so the switch happens
    /// on mismatch number `MISMATCH_THRESHOLD + 1` since the counter was
    /// last cleared.
    const MISMATCH_THRESHOLD: u8 = 10;

    /// Creates an auto-mode detector with no format recognised yet.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            mode: FormatMode::Auto,
            detected_format: AtomicU8::new(DetectedFormat::Unknown as u8),
            mismatch_count: AtomicU8::new(0),
        }
    }

    /// Creates a detector using `mode`.
    ///
    /// The forced modes start out already reporting their format from
    /// [`format`](Self::format); auto mode starts as `Unknown`.
    #[must_use]
    pub fn with_mode(mode: FormatMode) -> Self {
        let seed = match mode {
            FormatMode::ForceMessagePack => DetectedFormat::MessagePack,
            FormatMode::ForceJson => DetectedFormat::Json,
            FormatMode::Auto => DetectedFormat::Unknown,
        };
        Self {
            mode,
            detected_format: AtomicU8::new(seed as u8),
            mismatch_count: AtomicU8::new(0),
        }
    }

    /// Returns the detection policy this detector was built with.
    #[must_use]
    pub fn mode(&self) -> FormatMode {
        self.mode
    }

    /// Returns the format the detector currently expects.
    #[must_use]
    pub fn format(&self) -> DetectedFormat {
        let raw = self.detected_format.load(Ordering::Relaxed);
        raw.into()
    }

    /// Sniffs `payload` and checks it against the expected format.
    ///
    /// # Returns
    ///
    /// * `Ok(format)` - the payload was accepted as `format`.
    /// * `Err(expected)` - the payload did not look like `expected`. In the
    ///   forced modes this covers unrecognisable payloads as well.
    /// * `Err(DetectedFormat::Unknown)` - auto mode only: the payload was not
    ///   recognised as any supported format.
    ///
    /// In auto mode the first recognised payload fixes the expected format.
    /// A payload in the other format is rejected until the mismatch counter
    /// has reached `MISMATCH_THRESHOLD`; the next such payload switches the
    /// expected format and is accepted. A matching payload clears the counter.
    #[inline]
    pub fn check_and_detect(&self, payload: &[u8]) -> Result<DetectedFormat, DetectedFormat> {
        let sniffed = detect_format_bytes(payload);

        // Forced modes never touch the stored state: the payload either
        // looks like the required format or it is rejected.
        let required = match self.mode {
            FormatMode::Auto => None,
            FormatMode::ForceJson => Some(DetectedFormat::Json),
            FormatMode::ForceMessagePack => Some(DetectedFormat::MessagePack),
        };
        if let Some(required) = required {
            return if sniffed == Some(required) {
                Ok(required)
            } else {
                Err(required)
            };
        }

        // Auto mode from here on. An unrecognisable payload changes nothing.
        let seen = match sniffed {
            Some(format) => format,
            None => return Err(DetectedFormat::Unknown),
        };

        let locked = self.format();

        // First recognised payload: adopt its format.
        if locked == DetectedFormat::Unknown {
            self.detected_format.store(seen as u8, Ordering::Relaxed);
            self.mismatch_count.store(0, Ordering::Relaxed);
            return Ok(seen);
        }

        // Payload agrees with the expected format: clear the counter.
        if locked == seen {
            self.mismatch_count.store(0, Ordering::Relaxed);
            return Ok(seen);
        }

        // Mismatch. `fetch_add` yields the count *before* this payload.
        let prior = self.mismatch_count.fetch_add(1, Ordering::Relaxed);
        if prior < Self::MISMATCH_THRESHOLD {
            return Err(locked);
        }

        // Too many mismatches: follow the stream to its new format.
        self.detected_format.store(seen as u8, Ordering::Relaxed);
        self.mismatch_count.store(0, Ordering::Relaxed);
        #[cfg(feature = "logger")]
        tracing::warn!(
            old = ?locked,
            new = ?seen,
            "Format changed after {} mismatches, resetting",
            prior
        );
        Ok(seen)
    }

    /// Forgets the detected format and clears the mismatch counter.
    ///
    /// Does nothing unless the detector is in auto mode.
    pub fn reset(&self) {
        if !matches!(self.mode, FormatMode::Auto) {
            return;
        }
        self.detected_format
            .store(DetectedFormat::Unknown as u8, Ordering::Relaxed);
        self.mismatch_count.store(0, Ordering::Relaxed);
    }
}
/// `Default` yields the same detector as `FormatDetector::new`.
impl Default for FormatDetector {
fn default() -> Self {
// Delegate so there is a single definition of the initial state.
Self::new()
}
}
/// Guesses the payload format from its leading bytes.
///
/// * A first byte of `{` or `[` is reported as JSON.
/// * A first byte that is a MessagePack map or array marker
///   (`0x80..=0x9F`, `0xDC..=0xDF`) is reported as MessagePack.
/// * If the first byte is ASCII whitespace, the whitespace run is skipped
///   and only a following `{` or `[` is accepted (as JSON).
///
/// Returns `None` for an empty payload and for every other input.
#[inline]
fn detect_format_bytes(payload: &[u8]) -> Option<DetectedFormat> {
    let lead = *payload.first()?;

    match lead {
        b'{' | b'[' => Some(DetectedFormat::Json),
        // fixmap/fixarray (0x80..=0x9F), array16/array32/map16/map32 (0xDC..=0xDF).
        0x80..=0x9F | 0xDC..=0xDF => Some(DetectedFormat::MessagePack),
        ws if ws.is_ascii_whitespace() => {
            // Only JSON may be preceded by whitespace; binary markers found
            // after the whitespace run are deliberately not considered.
            for &byte in &payload[1..] {
                if byte.is_ascii_whitespace() {
                    continue;
                }
                return if byte == b'{' || byte == b'[' {
                    Some(DetectedFormat::Json)
                } else {
                    None
                };
            }
            None
        }
        _ => None,
    }
}
/// Guesses the format of `payload` from its leading bytes.
///
/// Forwards to `detect_format_bytes` and returns its result unchanged:
/// `Some(DetectedFormat::Json)` or `Some(DetectedFormat::MessagePack)` when
/// the payload is recognised, `None` otherwise (including for an empty slice).
#[inline]
#[must_use]
pub fn detect_format(payload: &[u8]) -> Option<DetectedFormat> {
detect_format_bytes(payload)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal JSON object payload.
    const JSON_OBJECT: &[u8] = b"{\"key\": 1}";
    /// Start of a MessagePack fixmap with one entry.
    const MSGPACK_MAP: &[u8] = &[0x81, 0xA1, b'k'];

    #[test]
    fn test_detect_json_object() {
        assert_eq!(
            detect_format(b"{\"key\": \"value\"}"),
            Some(DetectedFormat::Json)
        );
    }

    #[test]
    fn test_detect_json_array() {
        assert_eq!(detect_format(b"[1, 2, 3]"), Some(DetectedFormat::Json));
    }

    #[test]
    fn test_detect_json_with_whitespace() {
        assert_eq!(
            detect_format(b" \n\t{\"key\": 1}"),
            Some(DetectedFormat::Json)
        );
    }

    #[test]
    fn test_detect_whitespace_then_non_json() {
        assert_eq!(detect_format(b"  hello"), None);
    }

    #[test]
    fn test_detect_msgpack_fixmap() {
        assert_eq!(
            detect_format(&[0x81, 0xA3, b'k', b'e', b'y']),
            Some(DetectedFormat::MessagePack)
        );
    }

    #[test]
    fn test_detect_msgpack_map16() {
        assert_eq!(
            detect_format(&[0xDE, 0x00, 0x01]),
            Some(DetectedFormat::MessagePack)
        );
    }

    #[test]
    fn test_detect_msgpack_arrays() {
        // fixarray and array16 markers.
        assert_eq!(
            detect_format(&[0x93, 0x01, 0x02, 0x03]),
            Some(DetectedFormat::MessagePack)
        );
        assert_eq!(
            detect_format(&[0xDC, 0x00, 0x01]),
            Some(DetectedFormat::MessagePack)
        );
    }

    #[test]
    fn test_detect_empty() {
        assert_eq!(detect_format(b""), None);
    }

    #[test]
    fn test_detect_whitespace_only() {
        assert_eq!(detect_format(b" \n\t "), None);
    }

    #[test]
    fn test_detect_unknown() {
        assert_eq!(detect_format(b"hello"), None);
    }

    #[test]
    fn test_format_detector_auto_detect() {
        let detector = FormatDetector::new();
        assert_eq!(detector.format(), DetectedFormat::Unknown);

        assert_eq!(
            detector.check_and_detect(JSON_OBJECT),
            Ok(DetectedFormat::Json)
        );
        assert_eq!(detector.format(), DetectedFormat::Json);

        assert_eq!(
            detector.check_and_detect(b"{\"key\": 2}"),
            Ok(DetectedFormat::Json)
        );
        assert_eq!(
            detector.check_and_detect(MSGPACK_MAP),
            Err(DetectedFormat::Json)
        );
    }

    #[test]
    fn test_format_detector_undetectable_payload() {
        let detector = FormatDetector::new();

        // Nothing recognised yet: the detector stays undecided.
        assert_eq!(
            detector.check_and_detect(b"hello"),
            Err(DetectedFormat::Unknown)
        );
        assert_eq!(detector.format(), DetectedFormat::Unknown);

        // Once a format is fixed, garbage is rejected without changing it.
        detector.check_and_detect(JSON_OBJECT).unwrap();
        assert_eq!(
            detector.check_and_detect(b"hello"),
            Err(DetectedFormat::Unknown)
        );
        assert_eq!(detector.format(), DetectedFormat::Json);
    }

    #[test]
    fn test_format_detector_mismatch_reset() {
        let detector = FormatDetector::new();
        detector.check_and_detect(JSON_OBJECT).unwrap();

        // The first ten mismatches are rejected and leave the format alone.
        for _ in 0..10 {
            assert_eq!(
                detector.check_and_detect(MSGPACK_MAP),
                Err(DetectedFormat::Json)
            );
            assert_eq!(detector.format(), DetectedFormat::Json);
        }

        // The eleventh switches the detector over and is accepted.
        assert_eq!(
            detector.check_and_detect(MSGPACK_MAP),
            Ok(DetectedFormat::MessagePack)
        );
        assert_eq!(detector.format(), DetectedFormat::MessagePack);
    }

    #[test]
    fn test_format_detector_match_clears_mismatch_count() {
        let detector = FormatDetector::new();
        detector.check_and_detect(JSON_OBJECT).unwrap();

        for _ in 0..5 {
            let _ = detector.check_and_detect(MSGPACK_MAP);
        }
        // A matching payload starts the count over.
        assert_eq!(
            detector.check_and_detect(JSON_OBJECT),
            Ok(DetectedFormat::Json)
        );

        for _ in 0..10 {
            assert_eq!(
                detector.check_and_detect(MSGPACK_MAP),
                Err(DetectedFormat::Json)
            );
        }
        assert_eq!(detector.format(), DetectedFormat::Json);
    }

    #[test]
    fn test_reset_auto_mode() {
        let detector = FormatDetector::new();
        detector.check_and_detect(JSON_OBJECT).unwrap();
        assert_eq!(detector.format(), DetectedFormat::Json);

        detector.reset();
        assert_eq!(detector.format(), DetectedFormat::Unknown);

        // After a reset the next recognised payload is adopted immediately.
        assert_eq!(
            detector.check_and_detect(MSGPACK_MAP),
            Ok(DetectedFormat::MessagePack)
        );
        assert_eq!(detector.format(), DetectedFormat::MessagePack);
    }

    #[test]
    fn test_reset_is_noop_in_force_mode() {
        let detector = FormatDetector::with_mode(FormatMode::ForceMessagePack);
        detector.reset();
        assert_eq!(detector.format(), DetectedFormat::MessagePack);
    }

    #[test]
    fn test_force_json_mode() {
        let detector = FormatDetector::with_mode(FormatMode::ForceJson);
        assert_eq!(detector.mode(), FormatMode::ForceJson);
        assert_eq!(detector.format(), DetectedFormat::Json);

        assert_eq!(
            detector.check_and_detect(JSON_OBJECT),
            Ok(DetectedFormat::Json)
        );
        assert_eq!(
            detector.check_and_detect(MSGPACK_MAP),
            Err(DetectedFormat::Json)
        );
        assert_eq!(
            detector.check_and_detect(b"hello"),
            Err(DetectedFormat::Json)
        );
        assert_eq!(detector.format(), DetectedFormat::Json);
    }

    #[test]
    fn test_force_msgpack_mode() {
        let detector = FormatDetector::with_mode(FormatMode::ForceMessagePack);
        assert_eq!(detector.mode(), FormatMode::ForceMessagePack);
        assert_eq!(detector.format(), DetectedFormat::MessagePack);

        assert_eq!(
            detector.check_and_detect(MSGPACK_MAP),
            Ok(DetectedFormat::MessagePack)
        );
        assert_eq!(
            detector.check_and_detect(JSON_OBJECT),
            Err(DetectedFormat::MessagePack)
        );
        assert_eq!(detector.format(), DetectedFormat::MessagePack);
    }

    #[test]
    fn test_force_mode_no_reset() {
        let detector = FormatDetector::with_mode(FormatMode::ForceJson);
        for _ in 0..20 {
            let _ = detector.check_and_detect(MSGPACK_MAP);
        }
        assert_eq!(detector.format(), DetectedFormat::Json);
    }

    #[test]
    fn test_format_mode_parse() {
        assert_eq!(FormatMode::parse("auto"), Some(FormatMode::Auto));
        assert_eq!(FormatMode::parse("AUTO"), Some(FormatMode::Auto));
        assert_eq!(FormatMode::parse("json"), Some(FormatMode::ForceJson));
        assert_eq!(FormatMode::parse("JSON"), Some(FormatMode::ForceJson));
        assert_eq!(
            FormatMode::parse("messagepack"),
            Some(FormatMode::ForceMessagePack)
        );
        assert_eq!(
            FormatMode::parse("msgpack"),
            Some(FormatMode::ForceMessagePack)
        );
        assert_eq!(FormatMode::parse("invalid"), None);
        assert_eq!(FormatMode::parse(""), None);
    }
}