use serde::de::DeserializeOwned;
#[cfg(feature = "simd-json")]
use std::cell::RefCell;
#[cfg(feature = "simd-json")]
thread_local! {
    /// Per-thread byte buffer that frame text is copied into before parsing: simd-json is
    /// handed `&mut` access to its input, so the borrowed `&str` cannot be passed directly.
    /// Cleared and refilled on every use, so its allocation is reused between frames.
    static FRAME_BYTES: RefCell<Vec<u8>> = const { RefCell::new(Vec::new()) };
    /// Per-thread simd-json parser scratch (initial capacity 8192), reused between frames.
    static SIMD_BUFFERS: RefCell<simd_json::Buffers> =
        RefCell::new(simd_json::Buffers::new(8192));
}
/// One decoded WebSocket text frame: either a single message or an array of messages.
///
/// The derives only apply when `T` supports the corresponding trait, so frames of
/// `Debug`/`PartialEq` messages can be printed and compared (e.g. in `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsFrame<T> {
    /// The frame carried exactly one message.
    Single(T),
    /// The frame carried a JSON array of messages, in array order.
    Array(Vec<T>),
}

impl<T> WsFrame<T> {
    /// Consumes the frame and calls `f` on every message it holds, in order.
    ///
    /// `f` runs exactly once for [`WsFrame::Single`] and once per element for
    /// [`WsFrame::Array`] (not at all for an empty array), letting callers handle both
    /// shapes uniformly.
    pub fn for_each<F: FnMut(T)>(self, mut f: F) {
        match self {
            Self::Single(item) => f(item),
            Self::Array(items) => items.into_iter().for_each(f),
        }
    }
}
/// Frames whose length in bytes is at least this value are parsed with simd-json; shorter
/// frames are parsed with `serde_json`.
#[cfg(feature = "simd-json")]
pub const SIMD_CROSSOVER_BYTES: usize = 512;

/// Decodes a text frame carrying either one `T` or a JSON array of `T`.
///
/// The shape is chosen from the first character after leading whitespace: `[` selects
/// [`WsFrame::Array`], anything else [`WsFrame::Single`]. Returns `None` when the text is not
/// valid JSON for the selected shape, including empty or whitespace-only input.
pub fn decode_frame<T: DeserializeOwned>(text: &str) -> Option<WsFrame<T>> {
    if text.trim_start().starts_with('[') {
        decode_json::<Vec<T>>(text).map(WsFrame::Array)
    } else {
        decode_json::<T>(text).map(WsFrame::Single)
    }
}

/// Decodes a text frame into an untyped [`serde_json::Value`], or `None` if it is not valid
/// JSON.
pub fn decode_value(text: &str) -> Option<serde_json::Value> {
    decode_json::<serde_json::Value>(text)
}

/// Shared parser behind `decode_frame` and `decode_value`.
///
/// With the `simd-json` feature, frames of at least `SIMD_CROSSOVER_BYTES` bytes go through
/// simd-json; everything else (and any frame that cannot get the SIMD scratch) goes through
/// `serde_json`.
fn decode_json<T: DeserializeOwned>(text: &str) -> Option<T> {
    #[cfg(feature = "simd-json")]
    if text.len() >= SIMD_CROSSOVER_BYTES {
        if let Some(parsed) = decode_json_simd::<T>(text) {
            return parsed;
        }
    }
    serde_json::from_str::<T>(text).ok()
}

/// Parses `text` with simd-json using this thread's reusable scratch buffers.
///
/// Returns `Some(parse_result)` when the scratch was free, or `None` when it is already
/// borrowed further up this thread's stack. That happens if a `Deserialize` impl running inside
/// an outer SIMD parse decodes another large frame. Previously that re-entrant call panicked on
/// `borrow_mut`; now the caller falls back to `serde_json` instead.
#[cfg(feature = "simd-json")]
fn decode_json_simd<T: DeserializeOwned>(text: &str) -> Option<Option<T>> {
    FRAME_BYTES.with(|cell_bytes| {
        SIMD_BUFFERS.with(|cell_buf| {
            let mut bytes = cell_bytes.try_borrow_mut().ok()?;
            let mut buffers = cell_buf.try_borrow_mut().ok()?;
            // simd-json is given `&mut` access to its input, so copy the borrowed text into
            // the reusable per-thread buffer first.
            bytes.clear();
            bytes.extend_from_slice(text.as_bytes());
            Some(simd_json::serde::from_slice_with_buffers::<T>(&mut bytes, &mut buffers).ok())
        })
    })
}
/// Caller-owned simd-json scratch space for parsing documents into
/// [`simd_json::BorrowedValue`]s while reusing the parser's buffers between calls.
#[cfg(feature = "simd-json")]
pub struct TapeScratch {
    buffers: simd_json::Buffers,
}

#[cfg(feature = "simd-json")]
impl TapeScratch {
    /// Capacity passed to `simd_json::Buffers::new` by [`TapeScratch::new`] and `Default`.
    const DEFAULT_CAPACITY: usize = 16 * 1024;

    /// Creates scratch whose buffers are built with `simd_json::Buffers::new(cap)`.
    pub fn with_capacity(cap: usize) -> Self {
        let buffers = simd_json::Buffers::new(cap);
        Self { buffers }
    }

    /// Creates scratch with the default capacity of 16 KiB.
    pub fn new() -> Self {
        Self::with_capacity(Self::DEFAULT_CAPACITY)
    }

    /// Parses `bytes` into a [`simd_json::BorrowedValue`] using this scratch's buffers.
    ///
    /// The returned value borrows from `bytes` (lifetime `'a`), and `bytes` is taken mutably
    /// because simd-json parses its input in place.
    ///
    /// # Errors
    /// Returns the [`simd_json::Error`] reported by simd-json when `bytes` is not valid JSON.
    pub fn parse_value<'a>(
        &mut self,
        bytes: &'a mut [u8],
    ) -> Result<simd_json::BorrowedValue<'a>, simd_json::Error> {
        let scratch = &mut self.buffers;
        simd_json::to_borrowed_value_with_buffers(bytes, scratch)
    }
}

#[cfg(feature = "simd-json")]
impl Default for TapeScratch {
    /// Same as [`TapeScratch::new`].
    fn default() -> Self {
        Self::with_capacity(Self::DEFAULT_CAPACITY)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, PartialEq)]
    struct Msg {
        event: String,
        seq: u64,
    }

    /// Builds a JSON array of `count` `{"event":"tick","seq":i}` objects, `i` in `0..count`.
    fn tick_array(count: u64) -> String {
        let items: Vec<String> = (0..count)
            .map(|i| format!(r#"{{"event":"tick","seq":{i}}}"#))
            .collect();
        format!("[{}]", items.join(","))
    }

    #[test]
    fn single_object() {
        let text = r#"{"event":"book","seq":42}"#;
        match decode_frame::<Msg>(text).unwrap() {
            WsFrame::Single(m) => assert_eq!(
                m,
                Msg {
                    event: "book".into(),
                    seq: 42
                }
            ),
            WsFrame::Array(_) => panic!("expected single"),
        }
    }

    #[test]
    fn array_of_objects() {
        let text = r#"[{"event":"book","seq":1},{"event":"trade","seq":2}]"#;
        match decode_frame::<Msg>(text).unwrap() {
            WsFrame::Array(items) => {
                assert_eq!(items.len(), 2);
                assert_eq!(
                    items[0],
                    Msg {
                        event: "book".into(),
                        seq: 1
                    }
                );
                assert_eq!(
                    items[1],
                    Msg {
                        event: "trade".into(),
                        seq: 2
                    }
                );
            }
            WsFrame::Single(_) => panic!("expected array"),
        }
    }

    #[test]
    fn whitespace_prefix() {
        let text = " \n [{\"event\":\"book\",\"seq\":1}]";
        assert!(matches!(decode_frame::<Msg>(text), Some(WsFrame::Array(_))));
        let single = "\r\n\t {\"event\":\"book\",\"seq\":7}";
        assert!(matches!(
            decode_frame::<Msg>(single),
            Some(WsFrame::Single(Msg { seq: 7, .. }))
        ));
    }

    #[test]
    fn empty_array() {
        assert!(matches!(
            decode_frame::<Msg>("[]"),
            Some(WsFrame::Array(items)) if items.is_empty()
        ));
    }

    #[test]
    fn malformed_returns_none() {
        assert!(decode_frame::<Msg>("{not json").is_none());
        assert!(decode_frame::<Msg>("").is_none());
        assert!(decode_frame::<Msg>("  \n\t ").is_none());
    }

    #[test]
    fn large_frame_uses_simd() {
        let text = tick_array(100);
        // Guard the fixture: if it ever shrinks below the crossover, this test would silently
        // stop exercising the SIMD path.
        #[cfg(feature = "simd-json")]
        assert!(
            text.len() >= SIMD_CROSSOVER_BYTES,
            "fixture must reach the SIMD crossover"
        );
        match decode_frame::<Msg>(&text).unwrap() {
            WsFrame::Array(items) => {
                assert_eq!(items.len(), 100);
                assert!(items
                    .iter()
                    .zip(0u64..)
                    .all(|(m, i)| m.seq == i && m.event == "tick"));
            }
            WsFrame::Single(_) => panic!("expected array"),
        }
    }

    #[test]
    fn decode_value_handles_both_sizes() {
        let small = r#"{"msgType":"ping","seq":1}"#;
        #[cfg(feature = "simd-json")]
        assert!(small.len() < SIMD_CROSSOVER_BYTES, "small fixture must stay on serde_json");
        let v = decode_value(small).unwrap();
        assert_eq!(v.get("msgType").and_then(|v| v.as_str()), Some("ping"));

        let fields: Vec<String> = (0..200)
            .map(|i| format!(r#""k{i}":"value_{i}""#))
            .collect();
        let large = format!("{{{}}}", fields.join(","));
        #[cfg(feature = "simd-json")]
        assert!(
            large.len() >= SIMD_CROSSOVER_BYTES,
            "large fixture must reach the SIMD crossover"
        );
        let v = decode_value(&large).unwrap();
        assert_eq!(v.get("k0").and_then(|v| v.as_str()), Some("value_0"));
        assert_eq!(v.get("k199").and_then(|v| v.as_str()), Some("value_199"));
    }
}