use std::collections::HashMap;
/// Owned snapshot of a file's format name, stream information and tags.
///
/// `Serialize` is derived only when the `serde` feature is enabled; optional
/// fields that are `None` are left out of the serialized output.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct TagSnapshot {
/// Format identifier, stored as free-form text.
pub format: String,
/// Optional file name; skipped during serialization when `None`.
#[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))]
pub filename: Option<String>,
/// Stream-level properties captured alongside the tags.
pub stream_info: StreamInfoSnapshot,
/// Tag values keyed by tag name; each key maps to a list of string values.
pub tags: HashMap<String, Vec<String>>,
/// Optional arbitrary JSON payload; skipped during serialization when `None`.
#[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))]
pub raw_tags: Option<serde_json::Value>,
}
#[cfg(feature = "serde")]
mod bounded_json_value {
    //! `deserialize_with` helper for an optional JSON value that is rejected
    //! when its raw text, nesting depth, node count or string content is too
    //! large.

    use serde::{Deserialize, Deserializer};

    /// Values found at this depth or deeper are rejected (the root is depth 0).
    const MAX_DEPTH: usize = 64;
    /// Upper bound on the number of JSON values visited during validation.
    const MAX_NODES: usize = 100_000;
    /// Upper bound on the summed byte length of string values and object keys.
    const MAX_STRING_BYTES: usize = 10 * 1024 * 1024;
    /// Upper bound on the byte length of the raw JSON text of the value.
    const MAX_RAW_INPUT_BYTES: usize = 16 * 1024 * 1024;

    /// Running totals accumulated while walking a value.
    struct Budget {
        node_count: usize,
        string_bytes: usize,
    }

    impl Budget {
        /// Counts one more visited value, failing once `MAX_NODES` is exceeded.
        fn enter_node(&mut self) -> Result<(), String> {
            self.node_count += 1;
            if self.node_count > MAX_NODES {
                Err(format!("raw_tags exceeds {} node limit", MAX_NODES))
            } else {
                Ok(())
            }
        }

        /// Adds `len` string bytes, failing once `MAX_STRING_BYTES` is exceeded.
        fn charge_string(&mut self, len: usize) -> Result<(), String> {
            self.string_bytes = self.string_bytes.saturating_add(len);
            if self.string_bytes > MAX_STRING_BYTES {
                Err(format!(
                    "raw_tags exceeds {} byte size limit for string content",
                    MAX_STRING_BYTES
                ))
            } else {
                Ok(())
            }
        }
    }

    /// Recursively walks `value`, charging every visited node, string value
    /// and object key against `budget`.
    ///
    /// Returns a human-readable message describing the first limit exceeded.
    /// The node limit is checked before the depth limit.
    fn validate(
        value: &serde_json::Value,
        depth: usize,
        budget: &mut Budget,
    ) -> Result<(), String> {
        budget.enter_node()?;
        if depth >= MAX_DEPTH {
            return Err(format!(
                "raw_tags exceeds {} nesting depth limit",
                MAX_DEPTH
            ));
        }

        match value {
            serde_json::Value::String(text) => budget.charge_string(text.len()),
            serde_json::Value::Array(items) => {
                for child in items {
                    validate(child, depth + 1, budget)?;
                }
                Ok(())
            }
            serde_json::Value::Object(entries) => {
                for (key, child) in entries {
                    // The key is charged before descending into its value.
                    budget.charge_string(key.len())?;
                    validate(child, depth + 1, budget)?;
                }
                Ok(())
            }
            // Null, booleans and numbers only count as nodes.
            _ => Ok(()),
        }
    }

    /// Deserializes an optional JSON value while enforcing the module limits.
    ///
    /// The value is first captured as raw JSON text so its byte length can be
    /// checked before it is parsed into a `serde_json::Value`; the parsed
    /// tree is then validated against the depth, node and string budgets.
    ///
    /// # Errors
    ///
    /// Returns a custom deserializer error when the raw text is longer than
    /// `MAX_RAW_INPUT_BYTES`, when it fails to parse, or when validation
    /// rejects the parsed value.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<serde_json::Value>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let captured: Option<Box<serde_json::value::RawValue>> =
            Option::deserialize(deserializer)?;
        let raw_json = match captured.as_deref() {
            None => return Ok(None),
            Some(raw) => raw.get(),
        };

        let raw_len = raw_json.len();
        if raw_len > MAX_RAW_INPUT_BYTES {
            return Err(serde::de::Error::custom(format!(
                "raw_tags JSON input ({} bytes) exceeds {} byte size limit",
                raw_len, MAX_RAW_INPUT_BYTES
            )));
        }

        let parsed: serde_json::Value =
            serde_json::from_str(raw_json).map_err(serde::de::Error::custom)?;

        let mut totals = Budget {
            node_count: 0,
            string_bytes: 0,
        };
        validate(&parsed, 0, &mut totals).map_err(serde::de::Error::custom)?;

        Ok(Some(parsed))
    }
}
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for TagSnapshot {
    /// Deserializes a `TagSnapshot` and rejects inputs whose tag data is
    /// unreasonably large.
    ///
    /// The fields are first deserialized into a private mirror struct
    /// (`raw_tags` goes through `bounded_json_value::deserialize`), then two
    /// limits are enforced on the result:
    ///
    /// * the `tags` map may hold at most `MAX_TAG_ENTRIES` keys;
    /// * the combined byte length of `format`, `filename`, every tag key and
    ///   every tag value may not exceed `MAX_TAG_STRING_BYTES`.
    ///
    /// # Errors
    ///
    /// Returns a custom deserializer error when either limit is exceeded, or
    /// any error produced while deserializing the individual fields.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        const MAX_TAG_ENTRIES: usize = 10_000;
        const MAX_TAG_STRING_BYTES: usize = 10 * 1024 * 1024;

        #[derive(serde::Deserialize)]
        struct Inner {
            format: String,
            #[serde(default)]
            filename: Option<String>,
            stream_info: StreamInfoSnapshot,
            tags: HashMap<String, Vec<String>>,
            #[serde(deserialize_with = "bounded_json_value::deserialize", default)]
            raw_tags: Option<serde_json::Value>,
        }

        // Shared check so the limit is enforced identically at every point
        // where the running total is inspected.
        let ensure_within_budget = |string_bytes: usize| -> Result<(), D::Error> {
            if string_bytes > MAX_TAG_STRING_BYTES {
                return Err(serde::de::Error::custom(format!(
                    "cumulative tag string content ({} bytes) exceeds {} byte limit",
                    string_bytes, MAX_TAG_STRING_BYTES,
                )));
            }
            Ok(())
        };

        let inner = Inner::deserialize(deserializer)?;

        // NOTE(review): this bounds the number of keys only; the number of
        // values per key is limited just indirectly via the byte budget.
        if inner.tags.len() > MAX_TAG_ENTRIES {
            return Err(serde::de::Error::custom(format!(
                "tags map contains {} entries, exceeding the {} entry limit",
                inner.tags.len(),
                MAX_TAG_ENTRIES,
            )));
        }

        let mut string_bytes = inner
            .format
            .len()
            .saturating_add(inner.filename.as_deref().map_or(0, str::len));

        // `format` and `filename` count toward the budget, so check them up
        // front: with an empty `tags` map the loop below never runs and an
        // oversized value would otherwise be accepted.
        ensure_within_budget(string_bytes)?;

        for (key, values) in &inner.tags {
            string_bytes = string_bytes.saturating_add(key.len());
            for value in values {
                string_bytes = string_bytes.saturating_add(value.len());
            }
            // Checked once per key so an oversized map is rejected early.
            ensure_within_budget(string_bytes)?;
        }

        Ok(TagSnapshot {
            format: inner.format,
            filename: inner.filename,
            stream_info: inner.stream_info,
            tags: inner.tags,
            raw_tags: inner.raw_tags,
        })
    }
}
/// Largest JSON document, in bytes (16 MiB), that `TagSnapshot::from_json_str`
/// accepts. Has the same value as the constant of the same name inside
/// `bounded_json_value`.
#[cfg(feature = "serde")]
const MAX_RAW_INPUT_BYTES: usize = 16 * 1024 * 1024;
#[cfg(feature = "serde")]
impl TagSnapshot {
    /// Parses a `TagSnapshot` from JSON text, refusing oversized input.
    ///
    /// The byte length of `input` is compared against `MAX_RAW_INPUT_BYTES`
    /// before any parsing is attempted.
    ///
    /// # Errors
    ///
    /// Returns a `serde_json::Error` when `input` is longer than the limit or
    /// when `serde_json::from_str` fails.
    pub fn from_json_str(input: &str) -> Result<Self, serde_json::Error> {
        let input_len = input.len();
        if input_len <= MAX_RAW_INPUT_BYTES {
            return serde_json::from_str(input);
        }

        Err(<serde_json::Error as serde::de::Error>::custom(format!(
            "JSON input ({} bytes) exceeds {} byte size limit",
            input_len, MAX_RAW_INPUT_BYTES,
        )))
    }
}
/// Plain-data copy of stream properties; every field is optional.
///
/// With the `serde` feature enabled the type derives both `Serialize` and
/// `Deserialize`, and fields that are `None` are omitted when serializing.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StreamInfoSnapshot {
/// Stream length in seconds, as a floating-point value.
#[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))]
pub length_secs: Option<f64>,
/// Bitrate as reported by the source stream info.
/// NOTE(review): the unit (bps vs kbps) is not visible here; confirm.
#[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))]
pub bitrate: Option<u32>,
/// Sample rate as reported by the source stream info (presumably Hz; confirm).
#[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))]
pub sample_rate: Option<u32>,
/// Channel count as reported by the source stream info.
#[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))]
pub channels: Option<u16>,
/// Bits per sample as reported by the source stream info.
#[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))]
pub bits_per_sample: Option<u16>,
}
impl StreamInfoSnapshot {
    /// Builds a snapshot by copying each property out of `info`.
    ///
    /// The length is converted to fractional seconds via `as_secs_f64`; the
    /// remaining values are stored exactly as the accessors return them.
    pub fn from_dynamic(info: &crate::file::DynamicStreamInfo) -> Self {
        // Brings the accessor methods called below into scope.
        use crate::StreamInfo;

        let length_secs = info.length().map(|duration| duration.as_secs_f64());
        let bitrate = info.bitrate();
        let sample_rate = info.sample_rate();
        let channels = info.channels();
        let bits_per_sample = info.bits_per_sample();

        Self {
            length_secs,
            bitrate,
            sample_rate,
            channels,
            bits_per_sample,
        }
    }
}