use regex::Regex;
use serde_json::Value;
use std::sync::LazyLock;
/// Pattern for collection names: 1-64 ASCII letters, digits, `_` or `-`.
///
/// Compiled lazily on first use and then shared for the life of the process.
static COLLECTION_NAME_REGEX: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"^[a-zA-Z0-9_-]{1,64}$")
        .expect("collection-name pattern is a hard-coded, valid regex")
});

/// Pattern for key names: 1-256 ASCII letters, digits, `_`, `.` or `-`.
static KEY_NAME_REGEX: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"^[a-zA-Z0-9_.-]{1,256}$").expect("key-name pattern is a hard-coded, valid regex")
});

/// Pattern for field names: 1-128 ASCII letters, digits, `_`, `.` or `-`.
///
/// Dots are accepted here; `validate_field_name` additionally rejects empty
/// dot-separated segments.
static FIELD_NAME_REGEX: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"^[a-zA-Z0-9_.-]{1,128}$")
        .expect("field-name pattern is a hard-coded, valid regex")
});
/// Reasons a request can be rejected by the validators in this file.
///
/// Variants carrying a `String` hold the offending name (sometimes with a
/// short suffix such as ` (reserved name)`), which the `Display` output echoes
/// back inside single quotes.
#[derive(Debug)]
pub enum ValidationError {
    /// Collection name is empty, has a disallowed character, or is reserved.
    InvalidCollectionName(String),
    /// Key name is empty or has a disallowed character.
    InvalidKeyName(String),
    /// Field name is empty, too long, has a disallowed character, or has an
    /// empty dot-separated segment.
    InvalidFieldName(String),
    /// Collection name is longer than 64 bytes.
    CollectionNameTooLong,
    /// Key name is longer than 256 bytes.
    KeyNameTooLong,
    /// Serialized payload exceeds the size limit supplied by the caller.
    PayloadTooLarge,
    /// JSON nesting exceeds the depth limit supplied by the caller.
    InvalidJsonDepth,
    /// A key collection holds more entries than the limit supplied by the caller.
    TooManyKeys,
    /// An object contains a property that is not in the allowed list.
    UnknownProperty(String),
}

impl std::fmt::Display for ValidationError {
    /// Writes the user-facing message for this error.
    ///
    /// NOTE: the limits quoted in the fixed messages (10MB, 32 levels, 1000
    /// keys) are literals; they are not derived from the limit the failing
    /// validator was actually called with.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidCollectionName(name) => write!(
                f,
                "Invalid collection name: '{name}'. Must be alphanumeric with _, - only (1-64 chars)"
            ),
            Self::InvalidKeyName(name) => write!(
                f,
                "Invalid key name: '{name}'. Must be alphanumeric with _, -, . only (1-256 chars)"
            ),
            Self::InvalidFieldName(name) => write!(
                f,
                "Invalid field name: '{name}'. Must be alphanumeric with _, -, . only (1-128 chars)"
            ),
            Self::UnknownProperty(name) => write!(
                f,
                "Unknown property: '{name}'. Check the API docs for the list of supported properties for this endpoint"
            ),
            Self::CollectionNameTooLong => {
                f.write_str("Collection name too long (max 64 characters)")
            }
            Self::KeyNameTooLong => f.write_str("Key name too long (max 256 characters)"),
            Self::PayloadTooLarge => f.write_str("Payload too large (max 10MB)"),
            Self::InvalidJsonDepth => f.write_str("JSON nesting too deep (max 32 levels)"),
            Self::TooManyKeys => f.write_str("Too many keys in single request (max 1000)"),
        }
    }
}

impl std::error::Error for ValidationError {}
/// Checks that `name` is usable as a collection name.
///
/// # Errors
///
/// * [`ValidationError::InvalidCollectionName`] if `name` is empty, does not
///   match `COLLECTION_NAME_REGEX`, or is one of the reserved names (matched
///   case-sensitively). For a reserved name the payload is the name followed
///   by ` (reserved name)`.
/// * [`ValidationError::CollectionNameTooLong`] if `name` is longer than 64
///   bytes.
pub fn validate_collection_name(name: &str) -> Result<(), ValidationError> {
    // Arms are tried top to bottom, so the order below is the check order:
    // empty, length, character set, reserved words.
    match name {
        "" => Err(ValidationError::InvalidCollectionName(name.to_string())),
        _ if name.len() > 64 => Err(ValidationError::CollectionNameTooLong),
        _ if !COLLECTION_NAME_REGEX.is_match(name) => {
            Err(ValidationError::InvalidCollectionName(name.to_string()))
        }
        "admin" | "system" | "config" | "internal" | "__proto__" => Err(
            ValidationError::InvalidCollectionName(format!("{name} (reserved name)")),
        ),
        _ => Ok(()),
    }
}
/// Checks that `key` is usable as a key name.
///
/// # Errors
///
/// * [`ValidationError::KeyNameTooLong`] if `key` is longer than 256 bytes.
/// * [`ValidationError::InvalidKeyName`] (carrying `key`) if it is empty or
///   does not match `KEY_NAME_REGEX`.
pub fn validate_key_name(key: &str) -> Result<(), ValidationError> {
    // An empty key can never exceed the length limit, so testing the length
    // first yields the same error for every input.
    if key.len() > 256 {
        return Err(ValidationError::KeyNameTooLong);
    }

    let well_formed = !key.is_empty() && KEY_NAME_REGEX.is_match(key);
    if well_formed {
        Ok(())
    } else {
        Err(ValidationError::InvalidKeyName(key.to_string()))
    }
}
/// Checks that `field` is usable as a (possibly dotted) field name.
///
/// A valid field name is 1-128 bytes, matches `FIELD_NAME_REGEX`, and has no
/// empty segment when split on `.` (so `.a`, `a.` and `a..b` are rejected).
///
/// # Errors
///
/// Returns [`ValidationError::InvalidFieldName`] for every failure. For an
/// over-long name the payload is only the first 128 characters followed by
/// `... (too long)`, so an oversized input is never copied wholesale into
/// error messages or logs.
pub fn validate_field_name(field: &str) -> Result<(), ValidationError> {
    const MAX_FIELD_NAME_LEN: usize = 128;

    if field.is_empty() {
        return Err(ValidationError::InvalidFieldName(field.to_string()));
    }
    if field.len() > MAX_FIELD_NAME_LEN {
        // Truncate by chars, not bytes, so a multi-byte character is never split.
        let preview: String = field.chars().take(MAX_FIELD_NAME_LEN).collect();
        return Err(ValidationError::InvalidFieldName(format!(
            "{preview}... (too long)"
        )));
    }
    if !FIELD_NAME_REGEX.is_match(field) {
        return Err(ValidationError::InvalidFieldName(field.to_string()));
    }
    // The regex allows dots anywhere; a path with an empty segment is still invalid.
    if field.split('.').any(str::is_empty) {
        return Err(ValidationError::InvalidFieldName(field.to_string()));
    }
    Ok(())
}
/// Checks that no value inside `value` sits deeper than `max_depth`.
///
/// The root is at depth 0 and each enclosing object or array adds one level;
/// scalars count as a level of their own.
///
/// # Errors
///
/// Returns [`ValidationError::InvalidJsonDepth`] as soon as a value deeper
/// than `max_depth` is found.
pub fn validate_json_depth(value: &Value, max_depth: usize) -> Result<(), ValidationError> {
    /// Walks `node` depth-first, stopping at the first value that is too deep.
    fn descend(node: &Value, depth: usize, limit: usize) -> Result<(), ValidationError> {
        // Checked before looking at children, so recursion never goes more
        // than one level past `limit`.
        if depth > limit {
            return Err(ValidationError::InvalidJsonDepth);
        }
        match node {
            Value::Object(members) => members
                .values()
                .try_for_each(|child| descend(child, depth + 1, limit)),
            Value::Array(items) => items
                .iter()
                .try_for_each(|child| descend(child, depth + 1, limit)),
            _ => Ok(()),
        }
    }

    descend(value, 0, max_depth)
}
/// Checks that the compact JSON serialization of `payload` is at most
/// `max_size_bytes` bytes long.
///
/// The payload is serialized into a counting sink instead of a `String`, so
/// no buffer is allocated and serialization stops as soon as the budget is
/// exceeded.
///
/// # Errors
///
/// Returns [`ValidationError::PayloadTooLarge`] if the serialized form is
/// longer than `max_size_bytes`. Any other serialization failure is also
/// reported as `PayloadTooLarge`, so a payload whose size cannot be
/// established is rejected rather than silently accepted.
pub fn validate_payload_size(payload: &Value, max_size_bytes: usize) -> Result<(), ValidationError> {
    /// `io::Write` sink that discards data and fails once the budget is spent.
    struct ByteBudget {
        remaining: usize,
    }

    impl std::io::Write for ByteBudget {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            match self.remaining.checked_sub(buf.len()) {
                Some(left) => {
                    self.remaining = left;
                    Ok(buf.len())
                }
                None => Err(std::io::Error::other("payload size limit exceeded")),
            }
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    let mut budget = ByteBudget {
        remaining: max_size_bytes,
    };
    serde_json::to_writer(&mut budget, payload).map_err(|_| ValidationError::PayloadTooLarge)
}
/// Checks that `count` does not exceed `max_keys`.
///
/// # Errors
///
/// Returns [`ValidationError::TooManyKeys`] when `count` is strictly greater
/// than `max_keys`.
pub fn validate_key_count(count: usize, max_keys: usize) -> Result<(), ValidationError> {
    if count <= max_keys {
        Ok(())
    } else {
        Err(ValidationError::TooManyKeys)
    }
}
/// Checks that every top-level key of `payload` appears in `allowed`.
///
/// A `payload` that is not a JSON object has no keys to check and is accepted.
///
/// # Errors
///
/// Returns [`ValidationError::UnknownProperty`] carrying the first key (in the
/// map's iteration order) that is not listed in `allowed`.
pub fn validate_allowed_properties(payload: &Value, allowed: &[&str]) -> Result<(), ValidationError> {
    let Some(properties) = payload.as_object() else {
        return Ok(());
    };

    let first_unknown = properties
        .keys()
        .find(|name| !allowed.contains(&name.as_str()));

    match first_unknown {
        Some(name) => Err(ValidationError::UnknownProperty(name.clone())),
        None => Ok(()),
    }
}
pub fn validate_request(payload: &Value, max_body_size: usize) -> Result<(), ValidationError> {
validate_payload_size(payload, max_body_size)?;
validate_json_depth(payload, 32)?;
if let Some(collection) = payload.get("collection").and_then(|v| v.as_str()) {
validate_collection_name(collection)?;
}
if let Some(keys) = payload.get("keys") {
match keys {
Value::String(key) => validate_key_name(key)?,
Value::Array(arr) => {
validate_key_count(arr.len(), 1000)?;
for key in arr {
if let Some(key_str) = key.as_str() {
validate_key_name(key_str)?;
}
}
}
_ => {}
}
}
if let Some(data) = payload.get("data") {
if let Value::Object(map) = data {
validate_key_count(map.len(), 1000)?;
for key in map.keys() {
validate_key_name(key)?;
}
}
}
if let Some(fields) = payload.get("fields").and_then(|v| v.as_array()) {
for field in fields {
if let Some(field_str) = field.as_str() {
validate_field_name(field_str)?;
}
}
}
if let Some(joins) = payload.get("joins").and_then(|v| v.as_array()) {
for join in joins {
if let Some(join_collection) = join.get("collection").and_then(|v| v.as_str()) {
validate_collection_name(join_collection)?;
}
if let Some(alias) = join.get("alias").and_then(|v| v.as_str()) {
validate_key_name(alias)?;
}
if let Some(foreign_key) = join.get("foreign_key").and_then(|v| v.as_str()) {
validate_field_name(foreign_key)?;
}
if let Some(join_fields) = join.get("fields").and_then(|v| v.as_array()) {
for field in join_fields {
if let Some(field_str) = field.as_str() {
validate_field_name(field_str)?;
}
}
}
}
}
if let Some(where_clause) = payload.get("where").and_then(|v| v.as_object()) {
for key in where_clause.keys() {
if !key.starts_with('$') {
validate_field_name(key)?;
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_valid_collection_names() {
        assert!(validate_collection_name("users").is_ok());
        assert!(validate_collection_name("user_data").is_ok());
        assert!(validate_collection_name("data-2024").is_ok());
        assert!(validate_collection_name("test123").is_ok());
    }

    #[test]
    fn test_invalid_collection_names() {
        assert!(validate_collection_name("").is_err());
        assert!(validate_collection_name("user$data").is_err());
        assert!(validate_collection_name("../etc/passwd").is_err());
        assert!(validate_collection_name("admin").is_err());
    }

    #[test]
    fn test_collection_name_length_boundary() {
        assert!(validate_collection_name(&"a".repeat(64)).is_ok());
        assert!(matches!(
            validate_collection_name(&"a".repeat(65)),
            Err(ValidationError::CollectionNameTooLong)
        ));
    }

    #[test]
    fn test_key_names() {
        assert!(validate_key_name("a.b-c_1").is_ok());
        assert!(validate_key_name(&"k".repeat(256)).is_ok());
        assert!(matches!(
            validate_key_name(""),
            Err(ValidationError::InvalidKeyName(_))
        ));
        assert!(matches!(
            validate_key_name("a b"),
            Err(ValidationError::InvalidKeyName(_))
        ));
        assert!(matches!(
            validate_key_name(&"k".repeat(257)),
            Err(ValidationError::KeyNameTooLong)
        ));
    }

    #[test]
    fn test_field_names() {
        assert!(validate_field_name("user.name").is_ok());
        assert!(validate_field_name(&"f".repeat(128)).is_ok());
        // Empty input, empty path segments, bad characters and over-long names
        // all map to the same variant.
        for bad in ["", ".a", "a.", "a..b", "a b"] {
            assert!(matches!(
                validate_field_name(bad),
                Err(ValidationError::InvalidFieldName(_))
            ));
        }
        assert!(matches!(
            validate_field_name(&"f".repeat(129)),
            Err(ValidationError::InvalidFieldName(_))
        ));
    }

    #[test]
    fn test_json_depth() {
        let shallow = json!({"a": {"b": "c"}});
        assert!(validate_json_depth(&shallow, 10).is_ok());
        // The innermost scalar sits at depth 2.
        assert!(validate_json_depth(&shallow, 2).is_ok());
        assert!(validate_json_depth(&shallow, 1).is_err());

        let mut deep = json!({});
        let mut current = &mut deep;
        for _ in 0..50 {
            *current = json!({"nested": {}});
            current = current.get_mut("nested").unwrap();
        }
        assert!(validate_json_depth(&deep, 32).is_err());
    }

    #[test]
    fn test_payload_size_boundary() {
        // `"abcd"` serializes to six bytes including the quotes.
        let payload = json!("abcd");
        assert!(validate_payload_size(&payload, 6).is_ok());
        assert!(matches!(
            validate_payload_size(&payload, 5),
            Err(ValidationError::PayloadTooLarge)
        ));
    }

    #[test]
    fn test_key_count() {
        assert!(validate_key_count(0, 0).is_ok());
        assert!(validate_key_count(1000, 1000).is_ok());
        assert!(matches!(
            validate_key_count(1001, 1000),
            Err(ValidationError::TooManyKeys)
        ));
    }

    #[test]
    fn test_allowed_properties() {
        let payload = json!({"a": 1, "b": 2});
        assert!(validate_allowed_properties(&payload, &["a", "b"]).is_ok());
        assert!(matches!(
            validate_allowed_properties(&payload, &["a"]),
            Err(ValidationError::UnknownProperty(name)) if name == "b"
        ));
        // Non-object payloads have no properties to reject.
        assert!(validate_allowed_properties(&json!([1, 2]), &[]).is_ok());
    }

    #[test]
    fn test_validate_request() {
        let good = json!({
            "collection": "users",
            "keys": ["k1", "k2"],
            "data": {"name": "x"},
            "fields": ["profile.name"],
            "where": {"$or": [], "age": 1}
        });
        assert!(validate_request(&good, 1024).is_ok());

        assert!(matches!(
            validate_request(&json!({"collection": "admin"}), 1024),
            Err(ValidationError::InvalidCollectionName(_))
        ));
        assert!(matches!(
            validate_request(
                &json!({"joins": [{"collection": "orders", "alias": "bad alias"}]}),
                1024
            ),
            Err(ValidationError::InvalidKeyName(_))
        ));
        assert!(matches!(
            validate_request(&good, 4),
            Err(ValidationError::PayloadTooLarge)
        ));
    }
}