1use serde::de::DeserializeOwned;
14#[cfg(feature = "simd-json")]
15use std::cell::RefCell;
16
#[cfg(feature = "simd-json")]
thread_local! {
    /// Per-thread copy of the frame currently being decoded.
    ///
    /// The simd-json entry points take `&mut [u8]`, so each frame's text is copied
    /// here first; the decoders `clear()` and refill it, reusing the allocation
    /// across calls on the same thread.
    static FRAME_BYTES: RefCell<Vec<u8>> = const { RefCell::new(Vec::new()) };

    /// Per-thread simd-json scratch buffers shared by the decoders, created with
    /// `Buffers::new(8 KiB)`.
    static SIMD_BUFFERS: RefCell<simd_json::Buffers> =
        RefCell::new(simd_json::Buffers::new(8_192));
}
23
/// One decoded WebSocket text frame.
///
/// [`decode_frame`] produces [`WsFrame::Array`] when the frame's first
/// non-whitespace character is `[`, and [`WsFrame::Single`] otherwise.
///
/// The derives are conditional on `T` (e.g. `WsFrame<T>: Debug` only when
/// `T: Debug`), so they impose no bounds on users that do not need them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsFrame<T> {
    /// A frame carrying exactly one message.
    Single(T),
    /// A frame carrying a JSON array of messages, in wire order.
    Array(Vec<T>),
}
33
34impl<T> WsFrame<T> {
35 pub fn for_each<F: FnMut(T)>(self, mut f: F) {
37 match self {
38 Self::Single(item) => f(item),
39 Self::Array(items) => items.into_iter().for_each(f),
40 }
41 }
42}
43
/// Frame length, in bytes, at or above which `decode_frame` and `decode_value`
/// parse with simd-json; shorter frames go straight to `serde_json`.
///
/// NOTE(review): the value presumably reflects where simd-json's throughput
/// outweighs the cost of copying the frame into scratch buffers — confirm
/// against benchmarks before tuning.
#[cfg(feature = "simd-json")]
pub const SIMD_CROSSOVER_BYTES: usize = 512;
55
56pub fn decode_frame<T: DeserializeOwned>(text: &str) -> Option<WsFrame<T>> {
69 #[cfg(feature = "simd-json")]
70 if text.len() >= SIMD_CROSSOVER_BYTES {
71 return FRAME_BYTES.with(|cell_bytes| {
72 SIMD_BUFFERS.with(|cell_buf| {
73 let mut bytes = cell_bytes.borrow_mut();
74 let mut buffers = cell_buf.borrow_mut();
75 bytes.clear();
76 bytes.extend_from_slice(text.as_bytes());
77 let head = bytes.iter().find(|&&b| !b.is_ascii_whitespace()).copied()?;
78 if head == b'[' {
79 simd_json::serde::from_slice_with_buffers::<Vec<T>>(&mut bytes, &mut buffers)
80 .ok()
81 .map(WsFrame::Array)
82 } else {
83 simd_json::serde::from_slice_with_buffers::<T>(&mut bytes, &mut buffers)
84 .ok()
85 .map(WsFrame::Single)
86 }
87 })
88 });
89 }
90
91 let trimmed = text.trim_start();
92 if trimmed.starts_with('[') {
93 serde_json::from_str::<Vec<T>>(text)
94 .ok()
95 .map(WsFrame::Array)
96 } else {
97 serde_json::from_str::<T>(text).ok().map(WsFrame::Single)
98 }
99}
100
101pub fn decode_value(text: &str) -> Option<serde_json::Value> {
106 #[cfg(feature = "simd-json")]
107 if text.len() >= SIMD_CROSSOVER_BYTES {
108 return FRAME_BYTES.with(|cell_bytes| {
109 SIMD_BUFFERS.with(|cell_buf| {
110 let mut bytes = cell_bytes.borrow_mut();
111 let mut buffers = cell_buf.borrow_mut();
112 bytes.clear();
113 bytes.extend_from_slice(text.as_bytes());
114 simd_json::serde::from_slice_with_buffers::<serde_json::Value>(
115 &mut bytes,
116 &mut buffers,
117 )
118 .ok()
119 })
120 });
121 }
122 serde_json::from_str::<serde_json::Value>(text).ok()
123}
124
/// Caller-owned simd-json scratch space for parsing into borrowed values.
///
/// Unlike the thread-local buffers used by `decode_frame` / `decode_value`, a
/// `TapeScratch` is held explicitly by the caller and reused across calls to
/// [`TapeScratch::parse_value`].
#[cfg(feature = "simd-json")]
pub struct TapeScratch {
    // Reusable simd-json buffers, handed to
    // `simd_json::to_borrowed_value_with_buffers` on every parse.
    buffers: simd_json::Buffers,
}
144
#[cfg(feature = "simd-json")]
impl TapeScratch {
    /// Capacity hint used by [`TapeScratch::new`]: 16 KiB.
    const DEFAULT_CAPACITY: usize = 16 * 1024;

    /// Creates scratch space using the default 16 KiB capacity hint.
    pub fn new() -> Self {
        Self::with_capacity(Self::DEFAULT_CAPACITY)
    }

    /// Creates scratch space whose buffers are built with
    /// `simd_json::Buffers::new(cap)`.
    pub fn with_capacity(cap: usize) -> Self {
        let buffers = simd_json::Buffers::new(cap);
        Self { buffers }
    }

    /// Parses `bytes` into a [`simd_json::BorrowedValue`] that borrows from
    /// `bytes`, reusing this scratch's buffers.
    ///
    /// `bytes` is taken mutably because simd-json requires a mutable input;
    /// presumably its contents are rewritten during parsing — do not rely on
    /// them afterwards.
    ///
    /// # Errors
    ///
    /// Returns the `simd_json::Error` produced when `bytes` is not valid JSON.
    pub fn parse_value<'a>(
        &mut self,
        bytes: &'a mut [u8],
    ) -> Result<simd_json::BorrowedValue<'a>, simd_json::Error> {
        let scratch = &mut self.buffers;
        simd_json::to_borrowed_value_with_buffers(bytes, scratch)
    }
}
168
#[cfg(feature = "simd-json")]
impl Default for TapeScratch {
    /// Same as [`TapeScratch::new`]: scratch space with the default capacity.
    fn default() -> TapeScratch {
        TapeScratch::new()
    }
}
175
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    #[derive(Debug, Deserialize, PartialEq)]
    struct Msg {
        event: String,
        seq: u64,
    }

    /// Builds a JSON array of `count` `Msg` objects with `seq` running `0..count`.
    ///
    /// With `count = 100` the text is a few KiB, well above
    /// `SIMD_CROSSOVER_BYTES`, so it exercises the simd-json path when that
    /// feature is enabled (and the serde_json path otherwise).
    fn tick_array(count: u64) -> String {
        let items: Vec<String> = (0..count)
            .map(|i| format!(r#"{{"event":"tick","seq":{i}}}"#))
            .collect();
        format!("[{}]", items.join(","))
    }

    /// Builds a JSON object with `count` string fields `"k{i}": "value_{i}"`.
    fn wide_object(count: u32) -> String {
        let fields: Vec<String> = (0..count)
            .map(|i| format!(r#""k{i}":"value_{i}""#))
            .collect();
        format!("{{{}}}", fields.join(","))
    }

    #[test]
    fn single_object() {
        let text = r#"{"event":"book","seq":42}"#;
        match decode_frame::<Msg>(text).unwrap() {
            WsFrame::Single(m) => assert_eq!(
                m,
                Msg {
                    event: "book".into(),
                    seq: 42
                }
            ),
            WsFrame::Array(_) => panic!("expected single"),
        }
    }

    #[test]
    fn array_of_objects() {
        let text = r#"[{"event":"book","seq":1},{"event":"trade","seq":2}]"#;
        match decode_frame::<Msg>(text).unwrap() {
            WsFrame::Array(items) => assert_eq!(items.len(), 2),
            WsFrame::Single(_) => panic!("expected array"),
        }
    }

    #[test]
    fn whitespace_prefix() {
        let text = " \n [{\"event\":\"book\",\"seq\":1}]";
        assert!(matches!(decode_frame::<Msg>(text), Some(WsFrame::Array(_))));
    }

    #[test]
    fn malformed_returns_none() {
        assert!(decode_frame::<Msg>("{not json").is_none());
        assert!(decode_frame::<Msg>("").is_none());
    }

    #[test]
    fn large_frame_uses_simd() {
        let text = tick_array(100);
        match decode_frame::<Msg>(&text).unwrap() {
            WsFrame::Array(items) => {
                assert_eq!(items.len(), 100);
                assert_eq!(items.last().map(|m| m.seq), Some(99));
            }
            WsFrame::Single(_) => panic!("expected array"),
        }
    }

    #[test]
    fn large_whitespace_prefixed_single() {
        // Leading padding pushes the frame over the simd crossover, so the
        // simd path's first-non-whitespace detection is exercised.
        let padding = " ".repeat(1024);
        let text = format!(r#"{padding}{{"event":"book","seq":7}}"#);
        match decode_frame::<Msg>(&text) {
            Some(WsFrame::Single(m)) => assert_eq!(
                m,
                Msg {
                    event: "book".into(),
                    seq: 7
                }
            ),
            _ => panic!("expected single"),
        }
    }

    #[test]
    fn large_whitespace_only_returns_none() {
        assert!(decode_frame::<Msg>(&" ".repeat(1024)).is_none());
        assert!(decode_value(&"\n".repeat(1024)).is_none());
    }

    #[test]
    fn decode_value_handles_both_sizes() {
        let small = r#"{"msgType":"ping","seq":1}"#;
        let v = decode_value(small).unwrap();
        assert_eq!(v.get("msgType").and_then(|v| v.as_str()), Some("ping"));

        let large = wide_object(200);
        let v = decode_value(&large).unwrap();
        assert_eq!(v.get("k0").and_then(|v| v.as_str()), Some("value_0"));
        assert_eq!(v.get("k199").and_then(|v| v.as_str()), Some("value_199"));
    }
}