hyperi_rustlib/transport/
detect.rs1use std::sync::atomic::{AtomicU8, Ordering};
44
/// Wire format identified from a payload's leading bytes.
///
/// The enum is `#[repr(u8)]` with explicit discriminants so that it can be kept in an
/// `AtomicU8` and decoded again with `DetectedFormat::from(u8)`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum DetectedFormat {
    /// No format has been identified (yet).
    Unknown = 0,
    /// JSON object or array: the payload starts with `{` or `[`, optionally after
    /// leading whitespace.
    Json = 1,
    /// MessagePack map or array: the payload starts with a map/array marker byte.
    MessagePack = 2,
}
58
59impl From<u8> for DetectedFormat {
60 fn from(v: u8) -> Self {
61 match v {
62 1 => DetectedFormat::Json,
63 2 => DetectedFormat::MessagePack,
64 _ => DetectedFormat::Unknown,
65 }
66 }
67}
68
/// How a [`FormatDetector`] decides which wire format to expect.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FormatMode {
    /// Learn the format from the first recognisable payload, and re-learn it after a
    /// run of consecutive mismatching payloads.
    #[default]
    Auto,
    /// Accept only JSON payloads; the expected format never changes.
    ForceJson,
    /// Accept only MessagePack payloads; the expected format never changes.
    ForceMessagePack,
}
80
81impl FormatMode {
82 #[must_use]
84 pub fn parse(s: &str) -> Option<Self> {
85 match s.to_lowercase().as_str() {
86 "auto" => Some(FormatMode::Auto),
87 "json" => Some(FormatMode::ForceJson),
88 "messagepack" | "msgpack" => Some(FormatMode::ForceMessagePack),
89 _ => None,
90 }
91 }
92}
93
94pub struct FormatDetector {
100 detected_format: AtomicU8,
101 mismatch_count: AtomicU8,
102 mode: FormatMode,
103}
104
105impl FormatDetector {
106 const MISMATCH_THRESHOLD: u8 = 10;
108
109 #[must_use]
111 pub const fn new() -> Self {
112 Self {
113 detected_format: AtomicU8::new(DetectedFormat::Unknown as u8),
114 mismatch_count: AtomicU8::new(0),
115 mode: FormatMode::Auto,
116 }
117 }
118
119 #[must_use]
121 pub fn with_mode(mode: FormatMode) -> Self {
122 let initial_format = match mode {
123 FormatMode::Auto => DetectedFormat::Unknown,
124 FormatMode::ForceJson => DetectedFormat::Json,
125 FormatMode::ForceMessagePack => DetectedFormat::MessagePack,
126 };
127 Self {
128 detected_format: AtomicU8::new(initial_format as u8),
129 mismatch_count: AtomicU8::new(0),
130 mode,
131 }
132 }
133
134 #[must_use]
136 pub fn mode(&self) -> FormatMode {
137 self.mode
138 }
139
140 #[must_use]
142 pub fn format(&self) -> DetectedFormat {
143 DetectedFormat::from(self.detected_format.load(Ordering::Relaxed))
144 }
145
146 #[inline]
151 pub fn check_and_detect(&self, payload: &[u8]) -> Result<DetectedFormat, DetectedFormat> {
152 let detected = detect_format_bytes(payload);
153
154 match self.mode {
156 FormatMode::ForceJson => {
157 return match detected {
158 Some(DetectedFormat::Json) => Ok(DetectedFormat::Json),
159 _ => Err(DetectedFormat::Json), };
161 }
162 FormatMode::ForceMessagePack => {
163 return match detected {
164 Some(DetectedFormat::MessagePack) => Ok(DetectedFormat::MessagePack),
165 _ => Err(DetectedFormat::MessagePack), };
167 }
168 FormatMode::Auto => {} }
170
171 let current = self.format();
173
174 match (current, detected) {
175 (DetectedFormat::Unknown, Some(fmt)) => {
177 self.detected_format.store(fmt as u8, Ordering::Relaxed);
178 self.mismatch_count.store(0, Ordering::Relaxed);
179 Ok(fmt)
180 }
181
182 (_, None) => Err(DetectedFormat::Unknown),
184
185 (expected, Some(actual)) if expected == actual => {
187 self.mismatch_count.store(0, Ordering::Relaxed);
188 Ok(actual)
189 }
190
191 (expected, Some(actual)) => {
193 let count = self.mismatch_count.fetch_add(1, Ordering::Relaxed);
194 if count >= Self::MISMATCH_THRESHOLD {
195 self.detected_format.store(actual as u8, Ordering::Relaxed);
197 self.mismatch_count.store(0, Ordering::Relaxed);
198 #[cfg(feature = "logger")]
199 tracing::warn!(
200 old = ?expected,
201 new = ?actual,
202 "Format changed after {} mismatches, resetting",
203 count
204 );
205 Ok(actual)
206 } else {
207 Err(expected)
209 }
210 }
211 }
212 }
213
214 pub fn reset(&self) {
216 if self.mode == FormatMode::Auto {
217 self.detected_format
218 .store(DetectedFormat::Unknown as u8, Ordering::Relaxed);
219 self.mismatch_count.store(0, Ordering::Relaxed);
220 }
221 }
222}
223
224impl Default for FormatDetector {
225 fn default() -> Self {
226 Self::new()
227 }
228}
229
230#[inline]
234fn detect_format_bytes(payload: &[u8]) -> Option<DetectedFormat> {
235 let first_byte = *payload.first()?;
237
238 if first_byte == b'{' || first_byte == b'[' {
240 return Some(DetectedFormat::Json);
241 }
242
243 if matches!(first_byte, 0x80..=0x8F | 0xDE | 0xDF | 0x90..=0x9F | 0xDC | 0xDD) {
247 return Some(DetectedFormat::MessagePack);
248 }
249
250 if first_byte.is_ascii_whitespace() {
252 for &b in payload.iter().skip(1) {
253 if !b.is_ascii_whitespace() {
254 return match b {
255 b'{' | b'[' => Some(DetectedFormat::Json),
256 _ => None,
257 };
258 }
259 }
260 return None; }
262
263 None
264}
265
266#[inline]
270#[must_use]
271pub fn detect_format(payload: &[u8]) -> Option<DetectedFormat> {
272 detect_format_bytes(payload)
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_json_object() {
        assert_eq!(
            detect_format(b"{\"key\": \"value\"}"),
            Some(DetectedFormat::Json)
        );
    }

    #[test]
    fn test_detect_json_array() {
        assert_eq!(detect_format(b"[1, 2, 3]"), Some(DetectedFormat::Json));
    }

    #[test]
    fn test_detect_json_with_whitespace() {
        assert_eq!(
            detect_format(b" \n\t{\"key\": 1}"),
            Some(DetectedFormat::Json)
        );
    }

    #[test]
    fn test_detect_msgpack_fixmap() {
        assert_eq!(
            detect_format(&[0x81, 0xA3, b'k', b'e', b'y']),
            Some(DetectedFormat::MessagePack)
        );
    }

    #[test]
    fn test_detect_msgpack_map16() {
        assert_eq!(
            detect_format(&[0xDE, 0x00, 0x01]),
            Some(DetectedFormat::MessagePack)
        );
    }

    #[test]
    fn test_detect_msgpack_array_and_map32_markers() {
        // fixarray (both ends), array16, array32 and map32 markers.
        for marker in [0x90_u8, 0x9F, 0xDC, 0xDD, 0xDF] {
            assert_eq!(
                detect_format(&[marker, 0x00]),
                Some(DetectedFormat::MessagePack),
                "marker {marker:#04x}"
            );
        }
    }

    #[test]
    fn test_detect_empty() {
        assert_eq!(detect_format(b""), None);
    }

    #[test]
    fn test_detect_whitespace_only() {
        assert_eq!(detect_format(b" \n\t "), None);
    }

    #[test]
    fn test_detect_whitespace_before_non_json() {
        // Only JSON may be preceded by whitespace; a MessagePack marker after it is rejected.
        assert_eq!(detect_format(b"  hello"), None);
        assert_eq!(detect_format(&[b' ', 0x81, 0xA1, b'k']), None);
    }

    #[test]
    fn test_detect_unknown() {
        assert_eq!(detect_format(b"hello"), None);
    }

    #[test]
    fn test_detected_format_from_u8_round_trip() {
        for format in [
            DetectedFormat::Unknown,
            DetectedFormat::Json,
            DetectedFormat::MessagePack,
        ] {
            assert_eq!(DetectedFormat::from(format as u8), format);
        }
        assert_eq!(DetectedFormat::from(3), DetectedFormat::Unknown);
        assert_eq!(DetectedFormat::from(u8::MAX), DetectedFormat::Unknown);
    }

    #[test]
    fn test_format_detector_auto_detect() {
        let detector = FormatDetector::new();
        assert_eq!(detector.format(), DetectedFormat::Unknown);

        let result = detector.check_and_detect(b"{\"key\": 1}");
        assert_eq!(result, Ok(DetectedFormat::Json));
        assert_eq!(detector.format(), DetectedFormat::Json);

        assert_eq!(
            detector.check_and_detect(b"{\"key\": 2}"),
            Ok(DetectedFormat::Json)
        );

        assert_eq!(
            detector.check_and_detect(&[0x81, 0xA1, b'k']),
            Err(DetectedFormat::Json)
        );
    }

    #[test]
    fn test_format_detector_mismatch_reset() {
        let detector = FormatDetector::new();

        detector.check_and_detect(b"{\"key\": 1}").unwrap();

        for _ in 0..11 {
            let _ = detector.check_and_detect(&[0x81, 0xA1, b'k']);
        }

        assert_eq!(detector.format(), DetectedFormat::MessagePack);
    }

    #[test]
    fn test_mismatch_threshold_boundary() {
        let detector = FormatDetector::new();
        let msgpack = [0x81, 0xA1, b'k'];
        detector.check_and_detect(b"{}").unwrap();

        // The first ten consecutive mismatches are rejected and the format is kept.
        for _ in 0..10 {
            assert_eq!(
                detector.check_and_detect(&msgpack),
                Err(DetectedFormat::Json)
            );
        }
        assert_eq!(detector.format(), DetectedFormat::Json);

        // The eleventh consecutive mismatch switches the detector.
        assert_eq!(
            detector.check_and_detect(&msgpack),
            Ok(DetectedFormat::MessagePack)
        );
        assert_eq!(detector.format(), DetectedFormat::MessagePack);

        // After switching, the old format is the mismatch.
        assert_eq!(
            detector.check_and_detect(&msgpack),
            Ok(DetectedFormat::MessagePack)
        );
        assert_eq!(
            detector.check_and_detect(b"{}"),
            Err(DetectedFormat::MessagePack)
        );
    }

    #[test]
    fn test_matching_payload_clears_mismatch_count() {
        let detector = FormatDetector::new();
        let msgpack = [0x81, 0xA1, b'k'];
        detector.check_and_detect(b"{}").unwrap();

        for _ in 0..9 {
            let _ = detector.check_and_detect(&msgpack);
        }
        // A matching payload restarts the run, so ten more mismatches still do not switch.
        assert_eq!(detector.check_and_detect(b"[]"), Ok(DetectedFormat::Json));
        for _ in 0..10 {
            assert_eq!(
                detector.check_and_detect(&msgpack),
                Err(DetectedFormat::Json)
            );
        }
        assert_eq!(detector.format(), DetectedFormat::Json);
    }

    #[test]
    fn test_unrecognised_payload_in_auto_mode() {
        let detector = FormatDetector::new();
        assert_eq!(detector.check_and_detect(b""), Err(DetectedFormat::Unknown));
        assert_eq!(
            detector.check_and_detect(b"hello"),
            Err(DetectedFormat::Unknown)
        );
        assert_eq!(detector.format(), DetectedFormat::Unknown);

        detector.check_and_detect(b"{}").unwrap();
        assert_eq!(
            detector.check_and_detect(b"hello"),
            Err(DetectedFormat::Unknown)
        );
        assert_eq!(detector.format(), DetectedFormat::Json);
    }

    #[test]
    fn test_reset_in_auto_mode() {
        let detector = FormatDetector::new();
        detector.check_and_detect(b"{}").unwrap();

        detector.reset();
        assert_eq!(detector.format(), DetectedFormat::Unknown);

        // The next recognisable payload is learned immediately.
        assert_eq!(
            detector.check_and_detect(&[0x81, 0xA1, b'k']),
            Ok(DetectedFormat::MessagePack)
        );
        assert_eq!(detector.format(), DetectedFormat::MessagePack);
    }

    #[test]
    fn test_reset_in_forced_mode_is_noop() {
        let detector = FormatDetector::with_mode(FormatMode::ForceMessagePack);
        detector.reset();
        assert_eq!(detector.format(), DetectedFormat::MessagePack);
    }

    #[test]
    fn test_default_detector_is_auto() {
        let detector = FormatDetector::default();
        assert_eq!(detector.mode(), FormatMode::Auto);
        assert_eq!(detector.format(), DetectedFormat::Unknown);
        assert_eq!(FormatMode::default(), FormatMode::Auto);
    }

    #[test]
    fn test_force_json_mode() {
        let detector = FormatDetector::with_mode(FormatMode::ForceJson);
        assert_eq!(detector.mode(), FormatMode::ForceJson);
        assert_eq!(detector.format(), DetectedFormat::Json);

        assert_eq!(
            detector.check_and_detect(b"{\"key\": 1}"),
            Ok(DetectedFormat::Json)
        );

        assert_eq!(
            detector.check_and_detect(&[0x81, 0xA1, b'k']),
            Err(DetectedFormat::Json)
        );

        assert_eq!(
            detector.check_and_detect(b"hello"),
            Err(DetectedFormat::Json)
        );

        assert_eq!(detector.format(), DetectedFormat::Json);
    }

    #[test]
    fn test_force_json_accepts_leading_whitespace() {
        let detector = FormatDetector::with_mode(FormatMode::ForceJson);
        assert_eq!(
            detector.check_and_detect(b"\n  {\"key\": 1}"),
            Ok(DetectedFormat::Json)
        );
        assert_eq!(detector.check_and_detect(b""), Err(DetectedFormat::Json));
    }

    #[test]
    fn test_force_msgpack_mode() {
        let detector = FormatDetector::with_mode(FormatMode::ForceMessagePack);
        assert_eq!(detector.mode(), FormatMode::ForceMessagePack);
        assert_eq!(detector.format(), DetectedFormat::MessagePack);

        assert_eq!(
            detector.check_and_detect(&[0x81, 0xA1, b'k']),
            Ok(DetectedFormat::MessagePack)
        );

        assert_eq!(
            detector.check_and_detect(b"{\"key\": 1}"),
            Err(DetectedFormat::MessagePack)
        );

        assert_eq!(detector.format(), DetectedFormat::MessagePack);
    }

    #[test]
    fn test_force_mode_no_reset() {
        let detector = FormatDetector::with_mode(FormatMode::ForceJson);

        for _ in 0..20 {
            let _ = detector.check_and_detect(&[0x81, 0xA1, b'k']);
        }

        assert_eq!(detector.format(), DetectedFormat::Json);
    }

    #[test]
    fn test_format_mode_from_str() {
        assert_eq!(FormatMode::parse("auto"), Some(FormatMode::Auto));
        assert_eq!(FormatMode::parse("AUTO"), Some(FormatMode::Auto));
        assert_eq!(FormatMode::parse("json"), Some(FormatMode::ForceJson));
        assert_eq!(FormatMode::parse("JSON"), Some(FormatMode::ForceJson));
        assert_eq!(
            FormatMode::parse("messagepack"),
            Some(FormatMode::ForceMessagePack)
        );
        assert_eq!(
            FormatMode::parse("msgpack"),
            Some(FormatMode::ForceMessagePack)
        );
        assert_eq!(FormatMode::parse("invalid"), None);
    }
}