use crate::types::{RequestError, ID_LEN_LIMIT, METHOD_LEN_LIMIT};
use bytes::Bytes;
use serde_json::value::RawValue;
use std::ops::Range;
/// Computes the byte range that `$rv` (a `&str` borrowed out of `$bytes`)
/// occupies inside `$bytes`, using the distance between the two pointers.
///
/// `$rv` must point into `$bytes`; this is only checked in debug builds.
macro_rules! find_range {
    ($bytes:expr, $rv:expr) => {{
        let sub: &[u8] = $rv.as_bytes();
        let offset = sub.as_ptr() as usize - $bytes.as_ptr() as usize;
        let range = offset..offset + sub.len();
        debug_assert_eq!(sub, &$bytes[range.clone()]);
        range
    }};
}
/// A JSON-RPC request backed by the buffer it was parsed from.
///
/// Only the byte ranges of the `id`, `method` and `params` members are
/// stored; the accessors slice them out of `bytes` on demand.
///
/// Invariant: every range is in bounds for `bytes` and delimits valid UTF-8.
/// The accessors use `get_unchecked`/`from_utf8_unchecked` and rely on this.
#[derive(Clone)]
pub struct Request {
    /// The complete request body handed to the `TryFrom` conversion.
    bytes: Bytes,
    /// Byte range of the raw JSON `id` value (quotes included for strings).
    id: Range<usize>,
    /// Byte range of the `method` string's contents, excluding the quotes.
    method: Range<usize>,
    /// Byte range of the raw JSON `params` value.
    params: Range<usize>,
}
impl core::fmt::Debug for Request {
    /// Formats a compact summary rather than dumping the request body. Only
    /// the buffer length and the byte range of the method name are shown.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("Request");
        out.field("bytes", &self.bytes.len());
        out.field("method", &self.method);
        out.finish_non_exhaustive()
    }
}
/// Borrowing view of the three JSON-RPC members that `Request` records.
///
/// Each member is captured as an unparsed [`RawValue`] borrowed from the
/// input buffer, which is what allows its byte offset in that buffer to be
/// recovered. All three members are required; any other members are ignored.
///
/// NOTE(review): serde's derived `Deserialize` also accepts the positional
/// array form `[id, method, params]`, so a successful parse into this type
/// does not by itself prove the input was a JSON object.
#[derive(serde::Deserialize)]
struct DeserHelper<'a> {
    /// Raw `id` value; its JSON type is not checked here.
    #[serde(borrow)]
    id: &'a RawValue,
    /// Raw `method` value; the request parser requires it to be a JSON string.
    #[serde(borrow)]
    method: &'a RawValue,
    /// Raw `params` value.
    #[serde(borrow)]
    params: &'a RawValue,
}
impl TryFrom<Bytes> for Request {
type Error = RequestError;
fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
let DeserHelper { id, method, params } = serde_json::from_slice(bytes.as_ref())?;
let id = find_range!(bytes, id.get());
let id_len = id.end - id.start;
if id_len > ID_LEN_LIMIT {
return Err(RequestError::IdTooLarge(id_len));
}
let method = method
.get()
.strip_prefix('"')
.and_then(|s| s.strip_suffix('"'))
.ok_or(RequestError::InvalidMethod)?;
let method = find_range!(bytes, method);
let method_len = method.end - method.start;
if method_len > METHOD_LEN_LIMIT {
return Err(RequestError::MethodTooLarge(method_len));
}
let params = find_range!(bytes, params.get());
Ok(Self {
bytes,
id,
method,
params,
})
}
}
#[cfg(feature = "ws")]
impl TryFrom<tokio_tungstenite::tungstenite::Utf8Bytes> for Request {
    type Error = RequestError;

    /// Parses a request from tungstenite's `Utf8Bytes`. The payload is
    /// converted into [`Bytes`] and parsed by the `TryFrom<Bytes>` impl.
    fn try_from(bytes: tokio_tungstenite::tungstenite::Utf8Bytes) -> Result<Self, Self::Error> {
        let raw: Bytes = bytes.into();
        Self::try_from(raw)
    }
}
impl Request {
pub fn id(&self) -> &str {
unsafe { core::str::from_utf8_unchecked(self.bytes.get_unchecked(self.id.clone())) }
}
pub fn id_owned(&self) -> Box<RawValue> {
RawValue::from_string(self.id().to_string()).expect("valid json")
}
pub fn method(&self) -> &str {
unsafe { core::str::from_utf8_unchecked(self.bytes.get_unchecked(self.method.clone())) }
}
pub fn params(&self) -> &str {
unsafe { core::str::from_utf8_unchecked(self.bytes.get_unchecked(self.params.clone())) }
}
pub fn deser_params<'a: 'de, 'de, T: serde::Deserialize<'de>>(
&'a self,
) -> serde_json::Result<T> {
serde_json::from_str(self.params())
}
}
#[cfg(test)]
mod test {
    // `use super::*` already brings in the parent's imports, including
    // `Bytes`, `RequestError`, `ID_LEN_LIMIT` and `METHOD_LEN_LIMIT`.
    use super::*;

    #[test]
    fn test_request() {
        let bytes = Bytes::from_static(b"{\"id\":1,\"method\":\"foo\",\"params\":[]}");
        let req = Request::try_from(bytes).unwrap();
        assert_eq!(req.id(), "1");
        assert_eq!(req.method(), r#"foo"#);
        assert_eq!(req.params(), r#"[]"#);
    }

    #[test]
    fn non_utf8() {
        let bytes = Bytes::from_static(b"{\"id\xFF\xFF\":1,\"method\":\"foo\",\"params\":[]}");
        let err = Request::try_from(bytes).unwrap_err();
        assert!(matches!(err, RequestError::InvalidJson(_)));
        assert!(err.to_string().contains("invalid unicode code point"));
    }

    #[test]
    fn too_large_id() {
        let id = "a".repeat(ID_LEN_LIMIT + 1);
        let bytes = Bytes::from(format!(r#"{{"id":"{}","method":"foo","params":[]}}"#, id));
        let RequestError::IdTooLarge(size) = Request::try_from(bytes).unwrap_err() else {
            panic!("Expected RequestError::IdTooLarge")
        };
        // The limit applies to the raw JSON text, so both quotes are counted.
        assert_eq!(size, ID_LEN_LIMIT + 3);
    }

    #[test]
    fn too_large_method() {
        let method = "a".repeat(METHOD_LEN_LIMIT + 1);
        let bytes = Bytes::from(format!(r#"{{"id":1,"method":"{}","params":[]}}"#, method));
        let RequestError::MethodTooLarge(size) = Request::try_from(bytes).unwrap_err() else {
            panic!("Expected RequestError::MethodTooLarge")
        };
        assert_eq!(size, METHOD_LEN_LIMIT + 1);
    }

    /// A method name exactly at the limit is accepted.
    #[test]
    fn method_at_limit() {
        let method = "a".repeat(METHOD_LEN_LIMIT);
        let bytes = Bytes::from(format!(r#"{{"id":1,"method":"{}","params":[]}}"#, method));
        let req = Request::try_from(bytes).unwrap();
        assert_eq!(req.method(), method);
    }

    /// Any non-string `method` is rejected with `InvalidMethod`.
    #[test]
    fn non_string_method() {
        for body in [
            r#"{"id":1,"method":5,"params":[]}"#,
            r#"{"id":1,"method":null,"params":[]}"#,
            r#"{"id":1,"method":["foo"],"params":[]}"#,
        ] {
            let err = Request::try_from(Bytes::from_static(body.as_bytes())).unwrap_err();
            assert!(matches!(err, RequestError::InvalidMethod), "body: {body}");
        }
    }

    /// Whitespace around members and inside values does not shift the
    /// recorded ranges.
    #[test]
    fn tolerates_whitespace() {
        let bytes = Bytes::from_static(
            br#"{ "id" : 1 ,
                 "method" : "foo" , "params" : { "a" : 1 } }"#,
        );
        let req = Request::try_from(bytes).unwrap();
        assert_eq!(req.id(), "1");
        assert_eq!(req.method(), "foo");
        assert_eq!(req.params(), r#"{ "a" : 1 }"#);
    }

    /// `id_owned` copies the raw id, and `deser_params` can borrow from the request.
    #[test]
    fn owned_id_and_borrowed_params() {
        let bytes = Bytes::from_static(br#"{"id":1,"method":"foo","params":["a","b"]}"#);
        let req = Request::try_from(bytes).unwrap();
        assert_eq!(req.id_owned().get(), "1");
        let borrowed: Vec<&str> = req.deser_params().unwrap();
        assert_eq!(borrowed, ["a", "b"]);
        let owned: Vec<String> = req.deser_params().unwrap();
        assert_eq!(owned, ["a", "b"]);
    }
}