use crate::types::{Request, RequestError};
use bytes::Bytes;
use serde::Deserialize;
use serde_json::value::RawValue;
use std::ops::Range;
use tracing::{debug, enabled, instrument, span::Span, Level};
/// Inbound JSON payload, split into the byte ranges of its individual requests.
///
/// The whole payload is kept in one `Bytes` buffer; each request is only
/// materialised when [`InboundData::iter`] slices that buffer with one of the
/// stored ranges.
#[derive(Default)]
pub struct InboundData {
    /// The complete raw payload as received.
    bytes: Bytes,
    /// One byte range within `bytes` per request, in payload order.
    reqs: Vec<Range<usize>>,
    /// `true` when the payload was a lone request rather than a batch.
    single: bool,
}

impl InboundData {
    /// Number of requests (byte ranges) found in the payload.
    pub(crate) fn len(&self) -> usize {
        let Self { reqs, .. } = self;
        reqs.len()
    }

    /// Lazily converts each recorded range into a `Request`, in payload order.
    ///
    /// Every item is an independent `Result`, so one malformed entry yields an
    /// `Err` item without stopping iteration over the rest.
    pub(crate) fn iter(&self) -> impl Iterator<Item = Result<Request, RequestError>> + '_ {
        let bytes = &self.bytes;
        self.reqs
            .iter()
            .cloned()
            .map(move |range| bytes.slice(range))
            .map(Request::try_from)
    }

    /// Whether the payload was a single request rather than a batch.
    pub const fn single(&self) -> bool {
        self.single
    }
}
impl core::fmt::Debug for InboundData {
    /// Summarises the payload by size instead of dumping the raw bytes.
    ///
    /// Reports the payload length, the number of requests, and whether the
    /// payload was a single request.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Use the actual type name so debug/log output can be traced back to
        // this type (it previously printed a name that matched no type here).
        f.debug_struct("InboundData")
            .field("bytes", &self.bytes.len())
            .field("reqs", &self.reqs.len())
            .field("single", &self.single)
            .finish()
    }
}
impl TryFrom<Bytes> for InboundData {
    type Error = RequestError;

    /// Parses a raw JSON payload as either a batch (array) or a single request (object).
    ///
    /// Only the byte range of each request within `bytes` is recorded here;
    /// the requests themselves are converted later by [`InboundData::iter`].
    ///
    /// # Errors
    ///
    /// - Any `serde_json` error (malformed JSON, or trailing data after the
    ///   top-level value) is propagated via `?` into [`RequestError`].
    /// - [`RequestError::UnexpectedJsonType`] if the payload is valid JSON but
    ///   is neither an array nor an object.
    #[instrument(level = "debug", skip(bytes), fields(buf_len = bytes.len(), bytes = tracing::field::Empty))]
    fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
        if enabled!(Level::TRACE) {
            Span::current().record("bytes", format!("0x{:x}", bytes));
        }
        debug!("Parsing inbound data");

        // First attempt: read the payload as a batch of raw values.
        let mut deserializer = serde_json::Deserializer::from_slice(&bytes);
        if let Ok(reqs) = Vec::<&RawValue>::deserialize(&mut deserializer) {
            deserializer.end()?;
            let reqs = reqs
                .into_iter()
                .map(|raw| find_range!(bytes, raw.get()))
                .collect();
            return Ok(Self {
                bytes,
                reqs,
                single: false,
            });
        }

        // A failed batch parse can leave the deserializer positioned part-way
        // through the input (e.g. `[{}{"id":1}` stops just before the second
        // `{`). Continuing with it would accept the tail of a broken batch as a
        // single request, so restart from the first byte instead.
        let mut deserializer = serde_json::Deserializer::from_slice(&bytes);
        let rv = <&RawValue>::deserialize(&mut deserializer)?;
        deserializer.end()?;
        if !rv.get().starts_with('{') {
            return Err(RequestError::UnexpectedJsonType);
        }
        let range = find_range!(bytes, rv.get());
        Ok(Self {
            bytes,
            reqs: vec![range],
            single: true,
        })
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// A well-formed two-request batch shared by several tests.
    const VALID_BATCH: &str = r#"[
{"id": 1, "method": "foo", "params": [1, 2, 3]},
{"id": 2, "method": "bar", "params": [4, 5, 6]}
]"#;

    /// Asserts that parsing `batch` fails with `RequestError::InvalidJson`.
    fn assert_invalid_json(batch: impl Into<Bytes>) {
        let err = InboundData::try_from(batch.into()).unwrap_err();
        assert!(matches!(err, RequestError::InvalidJson(_)));
    }

    #[test]
    fn test_deser_batch() {
        let bytes = Bytes::from(VALID_BATCH);
        let batch = InboundData::try_from(bytes).unwrap();
        assert_eq!(batch.len(), 2);
        assert!(!batch.single());
    }

    #[test]
    fn test_deser_single() {
        let single = r#"{"id": 1, "method": "foo", "params": [1, 2, 3]}"#;
        let bytes = Bytes::from(single);
        let batch = InboundData::try_from(bytes).unwrap();
        assert_eq!(batch.len(), 1);
        assert!(batch.single());
    }

    #[test]
    fn test_deser_single_with_whitespace() {
        let single = r#"
{"id": 1, "method": "foo", "params": [1, 2, 3]}
"#;
        let bytes = Bytes::from(single);
        let batch = InboundData::try_from(bytes).unwrap();
        assert_eq!(batch.len(), 1);
        assert!(batch.single());
    }

    #[test]
    fn test_single_non_object() {
        let err = InboundData::try_from(Bytes::from("42")).unwrap_err();
        assert!(matches!(err, RequestError::UnexpectedJsonType));
    }

    #[test]
    fn test_empty_input() {
        assert_invalid_json("");
    }

    #[test]
    fn test_broken_batch() {
        let batch = r#"[
{"id": 1, "method": "foo", "params": [1, 2, 3]},
{"id": 2, "method": "bar", "params": [4, 5, 6]
]"#;
        assert_invalid_json(batch);
    }

    #[test]
    fn test_junk_prefix() {
        assert_invalid_json(format!("JUNK{}", VALID_BATCH));
    }

    #[test]
    fn test_junk_suffix() {
        assert_invalid_json(format!("{}JUNK", VALID_BATCH));
    }

    #[test]
    fn test_invalid_utf8_prefix() {
        // Build the payload from raw bytes. Inside a `&str` literal (raw or
        // not), `\xF1\x80` can never be a genuinely non-UTF-8 byte pair.
        let mut batch = b"\xF1\x80".to_vec();
        batch.extend_from_slice(VALID_BATCH.as_bytes());
        assert!(std::str::from_utf8(&batch).is_err());
        assert_invalid_json(batch);
    }

    #[test]
    fn test_invalid_utf8_suffix() {
        let mut batch = VALID_BATCH.as_bytes().to_vec();
        batch.extend_from_slice(b"\xF1\x80");
        assert!(std::str::from_utf8(&batch).is_err());
        assert_invalid_json(batch);
    }
}