use super::types::*;
use crate::error::{Error, Result};
use bytes::Bytes;
use std::collections::HashSet;
/// Parses a request body according to its declared media type.
///
/// An empty body yields `Ok(None)` regardless of `content_type`.
///
/// * `MediaType::ApplicationJson` is decoded as JSON into `Payload::ProcessedJson`.
/// * `MediaType::UrlEncoded` is decoded into `Payload::ProcessedUrlEncoded`.
/// * `MediaType::TextCsv`, `MediaType::OctetStream`, `MediaType::TextPlain` and
///   `MediaType::TextXml` are passed through unparsed as `Payload::RawPayload`.
/// * Any other media type falls back to JSON decoding.
///
/// # Errors
///
/// Propagates `Error::InvalidBody` from the JSON and form decoders when the body
/// cannot be decoded.
pub fn parse_payload(body: Bytes, content_type: &MediaType) -> Result<Option<Payload>> {
    if body.is_empty() {
        return Ok(None);
    }
    match content_type {
        MediaType::ApplicationJson => parse_json_payload(body),
        MediaType::UrlEncoded => parse_urlencoded_payload(body),
        // CSV text is not JSON. Tagging it `RawJson` would make downstream code treat
        // comma-separated rows as a JSON document. Hand it over as an opaque raw
        // payload, the same as the other non-JSON text types.
        MediaType::TextCsv
        | MediaType::OctetStream
        | MediaType::TextPlain
        | MediaType::TextXml => Ok(Some(Payload::RawPayload(body))),
        _ => parse_json_payload(body),
    }
}
fn parse_json_payload(body: Bytes) -> Result<Option<Payload>> {
let value: serde_json::Value =
serde_json::from_slice(&body).map_err(|e| Error::InvalidBody(e.to_string()))?;
let keys = extract_json_keys(&value);
Ok(Some(Payload::ProcessedJson { raw: body, keys }))
}
/// Collects JSON key names for column validation.
///
/// * For an object, this is the set of its keys.
/// * For an array, this is the union of keys across the array's object elements.
///   Non-object elements contribute nothing.
/// * Any other value (string, number, bool, null) yields an empty set.
fn extract_json_keys(value: &serde_json::Value) -> HashSet<String> {
    let mut names = HashSet::new();
    if let Some(object) = value.as_object() {
        names.extend(object.keys().cloned());
    } else if let Some(elements) = value.as_array() {
        for object in elements.iter().filter_map(serde_json::Value::as_object) {
            names.extend(object.keys().cloned());
        }
    }
    names
}
fn parse_urlencoded_payload(body: Bytes) -> Result<Option<Payload>> {
let body_str =
std::str::from_utf8(&body).map_err(|_| Error::InvalidBody("Invalid UTF-8".into()))?;
let data: Vec<(String, String)> = url::form_urlencoded::parse(body_str.as_bytes())
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect();
let keys: HashSet<String> = data.iter().map(|(k, _)| k.clone()).collect();
Ok(Some(Payload::ProcessedUrlEncoded { data, keys }))
}
pub fn validate_payload_columns(
payload: &Payload,
expected: &HashSet<String>,
) -> Result<()> {
let keys = match payload {
Payload::ProcessedJson { keys, .. } => keys,
Payload::ProcessedUrlEncoded { keys, .. } => keys,
_ => return Ok(()),
};
for key in keys {
if !expected.contains(key) {
return Err(Error::UnknownColumn(key.clone()));
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `parse_payload` and unwraps both the `Result` and the `Option`.
    fn parse_ok(body: Bytes, media: &MediaType) -> Payload {
        parse_payload(body, media).unwrap().unwrap()
    }

    /// Asserts that every name in `columns` is present in `keys`.
    fn assert_has_keys(keys: &HashSet<String>, columns: &[&str]) {
        for column in columns {
            assert!(keys.contains(*column));
        }
    }

    #[test]
    fn test_parse_json_object() {
        let payload = parse_ok(
            Bytes::from(r#"{"name": "John", "age": 30}"#),
            &MediaType::ApplicationJson,
        );
        if let Payload::ProcessedJson { keys, .. } = payload {
            assert_has_keys(&keys, &["name", "age"]);
        } else {
            panic!("Expected ProcessedJson");
        }
    }

    #[test]
    fn test_parse_json_array() {
        let payload = parse_ok(
            Bytes::from(r#"[{"id": 1}, {"id": 2, "name": "test"}]"#),
            &MediaType::ApplicationJson,
        );
        if let Payload::ProcessedJson { keys, .. } = payload {
            assert_has_keys(&keys, &["id", "name"]);
        } else {
            panic!("Expected ProcessedJson");
        }
    }

    #[test]
    fn test_parse_urlencoded() {
        let payload = parse_ok(Bytes::from("name=John&age=30"), &MediaType::UrlEncoded);
        if let Payload::ProcessedUrlEncoded { data, keys } = payload {
            assert_eq!(data.len(), 2);
            assert_has_keys(&keys, &["name", "age"]);
        } else {
            panic!("Expected ProcessedUrlEncoded");
        }
    }

    #[test]
    fn test_parse_empty_body() {
        let parsed = parse_payload(Bytes::new(), &MediaType::ApplicationJson).unwrap();
        assert!(parsed.is_none());
    }

    #[test]
    fn test_parse_octet_stream() {
        let original = Bytes::from(vec![0u8, 1, 2, 3]);
        if let Payload::RawPayload(data) = parse_ok(original.clone(), &MediaType::OctetStream) {
            assert_eq!(data, original);
        } else {
            panic!("Expected RawPayload");
        }
    }
}