use async_trait::async_trait;
use bytes::Bytes;
use http::HeaderMap;
use serde_json::Value;
use super::parser::{ParseError, ParseResult, ParsedData, Parser};
/// Request-body parser for JSON payloads.
///
/// Build one with [`JSONParser::new`] (or [`Default`]) and adjust it with the
/// chained setters, e.g. `JSONParser::new().allow_empty(true).strict(false)`.
#[derive(Debug, Clone)]
pub struct JSONParser {
    /// When `true`, an empty body parses to JSON `null`; when `false`, an
    /// empty body is reported as a parse error.
    pub allow_empty: bool,
    /// When `true`, every successfully parsed document is additionally walked
    /// and rejected if it contains a number whose `f64` view is not finite.
    pub strict: bool,
}

impl Default for JSONParser {
    /// Returns the default configuration: empty bodies are rejected and
    /// strict validation is enabled.
    fn default() -> Self {
        Self {
            allow_empty: false,
            strict: true,
        }
    }
}

impl JSONParser {
    /// Creates a parser with the default configuration.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the parser with `allow_empty` set to `allow`.
    ///
    /// The setter consumes `self`; the returned value must be used, otherwise
    /// the change is lost.
    #[must_use]
    pub fn allow_empty(self, allow: bool) -> Self {
        Self {
            allow_empty: allow,
            ..self
        }
    }

    /// Returns the parser with `strict` set to the given value.
    ///
    /// The setter consumes `self`; the returned value must be used, otherwise
    /// the change is lost.
    #[must_use]
    pub fn strict(self, strict: bool) -> Self {
        Self { strict, ..self }
    }
}
#[async_trait]
impl Parser for JSONParser {
    /// Lists the media-type patterns this parser advertises:
    /// `application/json` and `application/*+json`.
    fn media_types(&self) -> Vec<String> {
        ["application/json", "application/*+json"]
            .into_iter()
            .map(String::from)
            .collect()
    }

    /// Decodes `body` as a JSON document.
    ///
    /// The content type and the headers are accepted for the trait's sake but
    /// are not consulted.
    ///
    /// # Errors
    ///
    /// Returns `ParseError::ParseError` when
    /// * the body is empty and `allow_empty` is `false`,
    /// * `serde_json` cannot decode the body, or
    /// * `strict` is set and the strict validation pass rejects the value.
    async fn parse(
        &self,
        _content_type: Option<&str>,
        body: Bytes,
        _headers: &HeaderMap,
    ) -> ParseResult<ParsedData> {
        // Guard clause: an empty body never reaches the JSON decoder.
        if body.is_empty() {
            return if self.allow_empty {
                Ok(ParsedData::Json(Value::Null))
            } else {
                Err(ParseError::ParseError("Empty request body".to_string()))
            };
        }

        let document: Value = serde_json::from_slice(&body)
            .map_err(|e| ParseError::ParseError(format!("Invalid JSON: {}", e)))?;

        if self.strict {
            Self::validate_strict_json(&document)?;
        }

        Ok(ParsedData::Json(document))
    }
}
impl JSONParser {
    /// Recursively checks `value` for numbers that are not finite.
    ///
    /// Arrays and objects are traversed depth-first; every other non-number
    /// variant is accepted as is. A number is rejected only when `as_f64()`
    /// yields a value and that value is infinite or NaN.
    ///
    /// # Errors
    ///
    /// Returns `ParseError::ParseError` for the first offending number found.
    fn validate_strict_json(value: &Value) -> ParseResult<()> {
        match value {
            Value::Number(number) => {
                let non_finite = number.as_f64().is_some_and(|float| !float.is_finite());
                if non_finite {
                    Err(ParseError::ParseError(
                        "Non-finite float values (Infinity, -Infinity, NaN) are not allowed in strict mode".to_string()
                    ))
                } else {
                    Ok(())
                }
            }
            // Containers: stop at the first element that fails validation.
            Value::Array(items) => items.iter().try_for_each(Self::validate_strict_json),
            Value::Object(members) => members.values().try_for_each(Self::validate_strict_json),
            _ => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `body` to `parser` as an `application/json` request with an
    /// empty header map and returns whatever the parser produced.
    async fn parse_bytes(parser: &JSONParser, body: Bytes) -> ParseResult<ParsedData> {
        parser
            .parse(Some("application/json"), body, &HeaderMap::new())
            .await
    }

    /// A well-formed object is decoded and its members are reachable.
    #[tokio::test]
    async fn test_json_parser_valid() {
        let parser = JSONParser::new();
        let parsed = parse_bytes(&parser, Bytes::from(r#"{"name": "test", "value": 123}"#))
            .await
            .unwrap();
        let ParsedData::Json(value) = parsed else {
            panic!("Expected JSON data");
        };
        assert_eq!(value["name"], "test");
        assert_eq!(value["value"], 123);
    }

    /// Syntactically broken JSON is reported as an error.
    #[tokio::test]
    async fn test_json_parser_invalid() {
        let parser = JSONParser::new();
        let outcome = parse_bytes(&parser, Bytes::from(r#"{"invalid": json}"#)).await;
        assert!(outcome.is_err());
    }

    /// With the default configuration an empty body is an error.
    #[tokio::test]
    async fn test_json_parser_empty_not_allowed() {
        let parser = JSONParser::new();
        let outcome = parse_bytes(&parser, Bytes::new()).await;
        assert!(outcome.is_err());
    }

    /// With `allow_empty(true)` an empty body decodes to JSON `null`.
    #[tokio::test]
    async fn test_json_parser_empty_allowed() {
        let parser = JSONParser::new().allow_empty(true);
        let parsed = parse_bytes(&parser, Bytes::new()).await.unwrap();
        if !matches!(parsed, ParsedData::Json(Value::Null)) {
            panic!("Expected null JSON value");
        }
    }

    /// Both advertised media-type patterns are present.
    #[test]
    fn test_json_parser_media_types() {
        let advertised = JSONParser::new().media_types();
        for expected in ["application/json", "application/*+json"] {
            assert!(advertised.iter().any(|entry| entry.as_str() == expected));
        }
    }

    /// Bare `Infinity` / `-Infinity` / `NaN` tokens are rejected, while an
    /// ordinary float still parses when strict mode is off.
    #[tokio::test]
    async fn test_json_float_strictness() {
        let strict = JSONParser::new();
        for literal in ["Infinity", "-Infinity", "NaN"] {
            let outcome = parse_bytes(&strict, Bytes::from(literal)).await;
            assert!(
                outcome.is_err(),
                "Expected error for {} (invalid JSON literal)",
                literal
            );
        }

        let lenient = JSONParser::new().strict(false);
        let outcome = parse_bytes(&lenient, Bytes::from(r#"{"value": 1.0}"#)).await;
        assert!(outcome.is_ok(), "Valid JSON should parse in non-strict mode");
    }

    /// Finite numbers near the top of the `f64` range are accepted.
    #[tokio::test]
    async fn test_json_edge_case_large_numbers() {
        let parser = JSONParser::new();
        let cases = [
            (r#"{"value": 1e308}"#, "Should parse very large finite numbers"),
            (
                r#"{"value": -1e308}"#,
                "Should parse very large negative numbers",
            ),
        ];
        for (body, message) in cases {
            let outcome = parse_bytes(&parser, Bytes::from(body)).await;
            assert!(outcome.is_ok(), "{}", message);
        }
    }

    /// Finite numbers of very small magnitude are accepted.
    #[tokio::test]
    async fn test_json_edge_case_small_numbers() {
        let parser = JSONParser::new();
        let cases = [
            (
                r#"{"value": 2.2250738585072014e-308}"#,
                "Should parse very small finite numbers",
            ),
            (
                r#"{"value": -2.2250738585072014e-308}"#,
                "Should parse very small negative numbers",
            ),
        ];
        for (body, message) in cases {
            let outcome = parse_bytes(&parser, Bytes::from(body)).await;
            assert!(outcome.is_ok(), "{}", message);
        }
    }

    /// The various spellings of exponent notation are accepted.
    #[tokio::test]
    async fn test_json_scientific_notation() {
        let parser = JSONParser::new();
        let documents = [
            r#"{"value": 1.5e10}"#,
            r#"{"value": 1.5E10}"#,
            r#"{"value": 1.5e+10}"#,
            r#"{"value": 1.5e-10}"#,
            r#"{"array": [1e5, 2e-5]}"#,
        ];
        for document in documents {
            let outcome = parse_bytes(&parser, Bytes::from(document)).await;
            assert!(
                outcome.is_ok(),
                "Should parse scientific notation: {}",
                document
            );
        }
    }

    /// A nested `Infinity` token is rejected; a nested ordinary float is not.
    #[tokio::test]
    async fn test_json_nested_float_validation() {
        let strict = JSONParser::new();

        let rejected = parse_bytes(&strict, Bytes::from(r#"{"outer": {"inner": Infinity}}"#)).await;
        assert!(
            rejected.is_err(),
            "Nested Infinity literal should be rejected by serde_json"
        );

        let accepted = parse_bytes(&strict, Bytes::from(r#"{"outer": {"inner": 123.456}}"#)).await;
        assert!(accepted.is_ok(), "Valid nested floats should be accepted");
    }

    /// Arrays of ordinary floats pass in both modes; an `Infinity` token
    /// inside an array fails in both modes.
    #[tokio::test]
    async fn test_json_array_float_validation() {
        let strict = JSONParser::new();
        let lenient = JSONParser::new().strict(false);

        let finite_floats = r#"[1.0, 2.5, 3.14159, 100.0]"#;
        let outcome = parse_bytes(&strict, Bytes::from(finite_floats)).await;
        assert!(outcome.is_ok(), "Valid float arrays should be accepted");
        let outcome = parse_bytes(&lenient, Bytes::from(finite_floats)).await;
        assert!(
            outcome.is_ok(),
            "Valid float arrays should be accepted in non-strict mode"
        );

        let with_infinity = r#"[1.0, Infinity, 3.0]"#;
        let outcome = parse_bytes(&strict, Bytes::from(with_infinity)).await;
        assert!(
            outcome.is_err(),
            "Infinity literal in array should be rejected"
        );
        let outcome = parse_bytes(&lenient, Bytes::from(with_infinity)).await;
        assert!(
            outcome.is_err(),
            "Infinity literal in array should be rejected even in non-strict mode"
        );
    }
}