use serde::{Deserialize, Serialize};
use thiserror::Error;
use uuid::Uuid;
/// Upper bound applied by `validate_table_name`; compared against `str::len`, i.e. UTF-8 bytes.
pub const MAX_TABLE_NAME_LENGTH: usize = 255;
/// Upper bound applied by `validate_column_name`; compared against `str::len`, i.e. UTF-8 bytes.
pub const MAX_COLUMN_NAME_LENGTH: usize = 255;
/// Upper bound applied by `validate_data_type`; compared against `str::len`, i.e. UTF-8 bytes.
pub const MAX_IDENTIFIER_LENGTH: usize = 255;
/// Upper bound applied by `validate_description`; compared against `str::len`, i.e. UTF-8 bytes.
pub const MAX_DESCRIPTION_LENGTH: usize = 10000;
/// Largest size accepted by `validate_bpmn_dmn_file_size` (10 * 1024 * 1024).
// NOTE(review): the unit is presumably bytes — confirm against the callers that supply `file_size`.
pub const MAX_BPMN_DMN_FILE_SIZE: u64 = 10 * 1024 * 1024;
/// Largest size accepted by `validate_openapi_file_size` (5 * 1024 * 1024).
// NOTE(review): the unit is presumably bytes — confirm against the callers that supply `file_size`.
pub const MAX_OPENAPI_FILE_SIZE: u64 = 5 * 1024 * 1024;
/// Length at which `sanitize_model_name` stops consuming input (measured on the output's byte length).
pub const MAX_MODEL_NAME_LENGTH: usize = 255;
/// Error returned by the validation functions in this file.
///
/// The `&'static str` payloads carry the human-readable name of the value
/// being validated (e.g. `"table name"`), which is interpolated into the
/// `Display` message declared on each variant.
#[derive(Debug, Clone, Error, Serialize, Deserialize)]
pub enum ValidationError {
/// The named value was empty.
#[error("{0} cannot be empty")]
Empty(&'static str),
/// The named value is larger than allowed; `max` is the limit and
/// `actual` the observed size. The file-size validators reuse this
/// variant for sizes as well as string lengths.
#[error("{field} exceeds maximum length (max: {max}, got: {actual})")]
TooLong {
field: &'static str,
max: usize,
actual: usize,
},
/// The named value contains a disallowed character or sequence;
/// `reason` describes what was found.
#[error("{field} contains invalid characters: {reason}")]
InvalidCharacters { field: &'static str, reason: String },
/// The named value (first element) is malformed; the second element
/// explains why.
#[error("{0}: {1}")]
InvalidFormat(&'static str, String),
/// The named value equals a reserved word; `word` is the input as given.
#[error("{field} cannot be a reserved word: {word}")]
ReservedWord { field: &'static str, word: String },
}
/// Shorthand for a `Result` whose error type is [`ValidationError`].
pub type ValidationResult<T> = Result<T, ValidationError>;
/// Checks that `name` is acceptable as a table name.
///
/// The rules are applied in this order, and the first violation is reported:
///
/// 1. the name is not empty;
/// 2. its byte length does not exceed `MAX_TABLE_NAME_LENGTH`;
/// 3. the first character is alphabetic (`char::is_alphabetic`) or `_`;
/// 4. every character is alphanumeric, `_` or `-`;
/// 5. the name is not a reserved word according to `is_sql_reserved_word`.
///
/// # Errors
///
/// Returns the [`ValidationError`] variant matching the first failed rule:
/// `Empty`, `TooLong`, `InvalidFormat`, `InvalidCharacters` or `ReservedWord`.
pub fn validate_table_name(name: &str) -> ValidationResult<()> {
    const FIELD: &str = "table name";

    // An empty string has no first character, so this doubles as the emptiness check.
    let leading = name.chars().next().ok_or(ValidationError::Empty(FIELD))?;

    let byte_len = name.len();
    if byte_len > MAX_TABLE_NAME_LENGTH {
        return Err(ValidationError::TooLong {
            field: FIELD,
            max: MAX_TABLE_NAME_LENGTH,
            actual: byte_len,
        });
    }

    if !(leading.is_alphabetic() || leading == '_') {
        return Err(ValidationError::InvalidFormat(
            FIELD,
            "must start with a letter or underscore".to_string(),
        ));
    }

    // Report the first character outside the allowed set, if any.
    let offending = name
        .chars()
        .find(|&c| !(c.is_alphanumeric() || matches!(c, '_' | '-')));
    if let Some(bad) = offending {
        return Err(ValidationError::InvalidCharacters {
            field: FIELD,
            reason: format!("invalid character: '{}'", bad),
        });
    }

    if is_sql_reserved_word(name) {
        return Err(ValidationError::ReservedWord {
            field: FIELD,
            word: name.to_string(),
        });
    }

    Ok(())
}
/// Checks that `name` is acceptable as a column name.
///
/// The rules are applied in this order, and the first violation is reported:
///
/// 1. the name is not empty;
/// 2. its byte length does not exceed `MAX_COLUMN_NAME_LENGTH`;
/// 3. the first character is alphabetic (`char::is_alphabetic`) or `_`;
/// 4. every character is alphanumeric, `_`, `-` or `.`;
/// 5. if the name contains no `.`, it is not a reserved word according to
///    `is_sql_reserved_word` (names containing a dot skip this check).
///
/// # Errors
///
/// Returns the [`ValidationError`] variant matching the first failed rule:
/// `Empty`, `TooLong`, `InvalidFormat`, `InvalidCharacters` or `ReservedWord`.
pub fn validate_column_name(name: &str) -> ValidationResult<()> {
    const FIELD: &str = "column name";

    // An empty string has no first character, so this doubles as the emptiness check.
    let leading = name.chars().next().ok_or(ValidationError::Empty(FIELD))?;

    let byte_len = name.len();
    if byte_len > MAX_COLUMN_NAME_LENGTH {
        return Err(ValidationError::TooLong {
            field: FIELD,
            max: MAX_COLUMN_NAME_LENGTH,
            actual: byte_len,
        });
    }

    if !(leading.is_alphabetic() || leading == '_') {
        return Err(ValidationError::InvalidFormat(
            FIELD,
            "must start with a letter or underscore".to_string(),
        ));
    }

    // Report the first character outside the allowed set, if any.
    let offending = name
        .chars()
        .find(|&c| !(c.is_alphanumeric() || matches!(c, '_' | '-' | '.')));
    if let Some(bad) = offending {
        return Err(ValidationError::InvalidCharacters {
            field: FIELD,
            reason: format!("invalid character: '{}'", bad),
        });
    }

    // Only undotted names are compared against the reserved-word list.
    let dotted = name.contains('.');
    if !dotted && is_sql_reserved_word(name) {
        return Err(ValidationError::ReservedWord {
            field: FIELD,
            word: name.to_string(),
        });
    }

    Ok(())
}
/// Parses `id` as a UUID using `Uuid::parse_str`.
///
/// # Errors
///
/// Returns [`ValidationError::InvalidFormat`] for the field `"UUID"`, with the
/// parser's own error text embedded in the message, when parsing fails.
pub fn validate_uuid(id: &str) -> ValidationResult<Uuid> {
    match Uuid::parse_str(id) {
        Ok(parsed) => Ok(parsed),
        Err(cause) => Err(ValidationError::InvalidFormat(
            "UUID",
            format!("invalid UUID format: {}", cause),
        )),
    }
}
/// Checks that `data_type` looks like a plain type expression.
///
/// The rules are applied in this order, and the first violation is reported:
///
/// 1. the string is not empty;
/// 2. its byte length does not exceed `MAX_IDENTIFIER_LENGTH`;
/// 3. it contains none of `;`, `--` or `/*`;
/// 4. every character is alphanumeric or one of `( ) , space _ < > [ ]`.
///
/// # Errors
///
/// Returns [`ValidationError::Empty`], [`ValidationError::TooLong`] or
/// [`ValidationError::InvalidCharacters`] for the first failed rule.
pub fn validate_data_type(data_type: &str) -> ValidationResult<()> {
    const FIELD: &str = "data type";

    if data_type.is_empty() {
        return Err(ValidationError::Empty(FIELD));
    }

    let byte_len = data_type.len();
    if byte_len > MAX_IDENTIFIER_LENGTH {
        return Err(ValidationError::TooLong {
            field: FIELD,
            max: MAX_IDENTIFIER_LENGTH,
            actual: byte_len,
        });
    }

    // These markers are pure punctuation and have no case, so the input is
    // searched directly instead of allocating a lower-cased copy first.
    // They are tested before the allow-list so they get a dedicated message.
    if data_type.contains(';') || data_type.contains("--") || data_type.contains("/*") {
        return Err(ValidationError::InvalidCharacters {
            field: FIELD,
            reason: "contains SQL comment or statement separator".to_string(),
        });
    }

    let is_allowed = |c: char| {
        c.is_alphanumeric() || matches!(c, '(' | ')' | ',' | ' ' | '_' | '<' | '>' | '[' | ']')
    };
    if let Some(bad) = data_type.chars().find(|&c| !is_allowed(c)) {
        return Err(ValidationError::InvalidCharacters {
            field: FIELD,
            reason: format!("invalid character: '{}'", bad),
        });
    }

    Ok(())
}
/// Checks that `desc` is no longer than `MAX_DESCRIPTION_LENGTH` bytes.
///
/// An empty description is accepted.
///
/// # Errors
///
/// Returns [`ValidationError::TooLong`] when the byte length exceeds the limit.
pub fn validate_description(desc: &str) -> ValidationResult<()> {
    let byte_len = desc.len();
    if byte_len <= MAX_DESCRIPTION_LENGTH {
        Ok(())
    } else {
        Err(ValidationError::TooLong {
            field: "description",
            max: MAX_DESCRIPTION_LENGTH,
            actual: byte_len,
        })
    }
}
/// Wraps `name` in the identifier quotes used by `dialect`.
///
/// `dialect` is matched case-insensitively:
///
/// * `"mysql"` / `"mariadb"` use backticks;
/// * `"sqlserver"` / `"mssql"` use square brackets;
/// * anything else uses double quotes.
///
/// Every occurrence of the closing quote character inside `name` is doubled
/// so that it cannot terminate the quoted identifier early. For the bracket
/// style only `]` is doubled; `[` is left as is.
pub fn sanitize_sql_identifier(name: &str, dialect: &str) -> String {
    let (open, close) = match dialect.to_lowercase().as_str() {
        "mysql" | "mariadb" => ('`', '`'),
        "sqlserver" | "mssql" => ('[', ']'),
        _ => ('"', '"'),
    };

    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push(open);
    for ch in name.chars() {
        if ch == close {
            // Escape by doubling the closing quote.
            quoted.push(close);
        }
        quoted.push(ch);
    }
    quoted.push(close);
    quoted
}
/// Returns `desc` with control characters removed.
///
/// Newline, tab and carriage return are kept; every other character for which
/// `char::is_control` is true is dropped. All remaining characters are copied
/// unchanged and in order.
pub fn sanitize_description(desc: &str) -> String {
    let mut cleaned = String::with_capacity(desc.len());
    for ch in desc.chars() {
        let keep = matches!(ch, '\n' | '\t' | '\r') || !ch.is_control();
        if keep {
            cleaned.push(ch);
        }
    }
    cleaned
}
/// Reports whether `word`, lower-cased with `str::to_lowercase`, exactly
/// equals one of the entries in the built-in keyword/type-name list.
///
/// Only whole-string matches count: `"order"` matches, `"order_id"` does not.
fn is_sql_reserved_word(word: &str) -> bool {
    #[rustfmt::skip]
    const RESERVED_WORDS: &[&str] = &[
        // Statements and DDL/DCL/transaction keywords.
        "select", "from", "where", "insert", "update", "delete", "create", "drop", "alter",
        "table", "index", "view", "database", "schema", "grant", "revoke", "commit",
        "rollback", "begin", "end", "transaction",
        // Constraints.
        "primary", "foreign", "key", "references", "constraint", "unique", "check", "default",
        // Operators and expressions.
        "not", "null", "and", "or", "in", "between", "like", "is", "case", "when", "then",
        "else", "as", "on",
        // Joins.
        "join", "inner", "outer", "left", "right", "full", "cross", "natural", "using",
        // Grouping, ordering and set operations.
        "group", "by", "having", "order", "asc", "desc", "limit", "offset", "union",
        "intersect", "except", "all", "distinct", "top",
        // Data modification and routines.
        "values", "set", "into", "exec", "execute", "procedure", "function", "trigger",
        // Literals and type names.
        "true", "false", "int", "integer", "varchar", "char", "text", "boolean", "date",
        "time", "timestamp", "float", "double", "decimal", "numeric",
    ];

    let lowered = word.to_lowercase();
    RESERVED_WORDS.iter().any(|candidate| *candidate == lowered)
}
/// Produces a restricted-character version of `name`.
///
/// Characters are processed left to right:
///
/// * alphanumeric characters, `-` and `_` are copied through;
/// * a `.` is copied only if the previous input character was not handled as
///   a dot, so a run of dots of any length yields a single `.`;
/// * any other character becomes `_`, except directly after a dot, where it
///   is dropped.
///
/// Processing stops once the output has reached `MAX_MODEL_NAME_LENGTH`
/// bytes. Trailing `.` and `_` characters are then trimmed, and if nothing is
/// left the literal `"model"` is returned.
pub fn sanitize_model_name(name: &str) -> String {
    let mut sanitized = String::with_capacity(name.len());
    let mut last_was_dot = false;

    for ch in name.chars() {
        match ch {
            c if c.is_alphanumeric() || c == '-' || c == '_' => {
                sanitized.push(c);
                last_was_dot = false;
            }
            '.' => {
                // A repeated dot is dropped and the flag is left set.
                // Previously the repeat fell into the catch-all arm, which
                // cleared the flag and let every other dot of a longer run
                // through ("a...b" became "a..b").
                if !last_was_dot {
                    sanitized.push('.');
                    last_was_dot = true;
                }
            }
            _ => {
                if !last_was_dot {
                    sanitized.push('_');
                }
                last_was_dot = false;
            }
        }

        if sanitized.len() >= MAX_MODEL_NAME_LENGTH {
            break;
        }
    }

    let trimmed = sanitized.trim_end_matches(['.', '_']);
    if trimmed.is_empty() {
        "model".to_string()
    } else {
        trimmed.to_string()
    }
}
/// Checks that `file_size` does not exceed `MAX_BPMN_DMN_FILE_SIZE`.
///
/// # Errors
///
/// Returns [`ValidationError::TooLong`] (field `"BPMN/DMN file size"`) when
/// the size is above the limit. The reported numbers are clamped to
/// `usize::MAX` on targets where `usize` is narrower than `u64`.
pub fn validate_bpmn_dmn_file_size(file_size: u64) -> ValidationResult<()> {
    if file_size <= MAX_BPMN_DMN_FILE_SIZE {
        return Ok(());
    }

    // A plain `as usize` wraps on 32-bit targets, which could report a huge
    // file as a tiny one; saturate instead so the message stays meaningful.
    let saturate = |value: u64| value.min(usize::MAX as u64) as usize;

    Err(ValidationError::TooLong {
        field: "BPMN/DMN file size",
        max: saturate(MAX_BPMN_DMN_FILE_SIZE),
        actual: saturate(file_size),
    })
}
/// Upper bound applied by `validate_path` and `validate_glob_pattern`; compared against `str::len`, i.e. UTF-8 bytes.
pub const MAX_PATH_LENGTH: usize = 4096;
/// Checks that `path` is safe to use as a file path.
///
/// The rules are applied in this order, and the first violation is reported:
///
/// 1. the path is not empty;
/// 2. it contains no NUL character;
/// 3. its byte length does not exceed `MAX_PATH_LENGTH`;
/// 4. it does not contain the substring `..` anywhere;
/// 5. unless `allow_absolute` is true, it neither starts with `/` or `\`
///    nor begins with an ASCII letter followed by `:` (drive prefix).
///
/// # Errors
///
/// Returns [`ValidationError::Empty`], [`ValidationError::InvalidCharacters`],
/// [`ValidationError::TooLong`] or [`ValidationError::InvalidFormat`] for the
/// first failed rule.
pub fn validate_path(path: &str, allow_absolute: bool) -> ValidationResult<()> {
    const FIELD: &str = "path";

    if path.is_empty() {
        return Err(ValidationError::Empty(FIELD));
    }

    if path.contains('\0') {
        return Err(ValidationError::InvalidCharacters {
            field: FIELD,
            reason: "null bytes not allowed".to_string(),
        });
    }

    let byte_len = path.len();
    if byte_len > MAX_PATH_LENGTH {
        return Err(ValidationError::TooLong {
            field: FIELD,
            max: MAX_PATH_LENGTH,
            actual: byte_len,
        });
    }

    if path.contains("..") {
        return Err(ValidationError::InvalidCharacters {
            field: FIELD,
            reason: "path traversal (..) not allowed".to_string(),
        });
    }

    // The remaining checks only concern absolute paths.
    if allow_absolute {
        return Ok(());
    }

    let rooted = path.starts_with(|c: char| c == '/' || c == '\\');
    let drive_prefixed =
        matches!(path.as_bytes(), [drive, b':', ..] if drive.is_ascii_alphabetic());

    if rooted || drive_prefixed {
        Err(ValidationError::InvalidFormat(
            FIELD,
            "absolute paths not allowed".to_string(),
        ))
    } else {
        Ok(())
    }
}
/// Checks that `pattern` is safe to use as a glob pattern.
///
/// The rules are applied in this order, and the first violation is reported:
///
/// 1. the pattern is not empty;
/// 2. it contains no NUL character;
/// 3. its byte length does not exceed `MAX_PATH_LENGTH`;
/// 4. it does not contain the substring `..` anywhere.
///
/// # Errors
///
/// Returns [`ValidationError::Empty`], [`ValidationError::InvalidCharacters`]
/// or [`ValidationError::TooLong`] for the first failed rule.
pub fn validate_glob_pattern(pattern: &str) -> ValidationResult<()> {
    const FIELD: &str = "glob pattern";
    let byte_len = pattern.len();

    if pattern.is_empty() {
        Err(ValidationError::Empty(FIELD))
    } else if pattern.contains('\0') {
        Err(ValidationError::InvalidCharacters {
            field: FIELD,
            reason: "null bytes not allowed".to_string(),
        })
    } else if byte_len > MAX_PATH_LENGTH {
        Err(ValidationError::TooLong {
            field: FIELD,
            max: MAX_PATH_LENGTH,
            actual: byte_len,
        })
    } else if pattern.contains("..") {
        Err(ValidationError::InvalidCharacters {
            field: FIELD,
            reason: "path traversal (..) not allowed".to_string(),
        })
    } else {
        Ok(())
    }
}
/// Returns a cleaned-up, relative, forward-slash version of `path`.
///
/// Steps, in order:
///
/// 1. NUL characters are removed;
/// 2. every `..` substring is removed;
/// 3. backslashes become forward slashes;
/// 4. leading and trailing slashes are dropped and runs of slashes collapse
///    to one.
///
/// The result may be empty (for example when `path` consists only of slashes).
pub fn sanitize_path(path: &str) -> String {
    let normalized = path
        .replace('\0', "")
        .replace("..", "")
        .replace('\\', "/");

    // Dropping empty segments removes leading, trailing and doubled slashes
    // in a single pass.
    normalized
        .split('/')
        .filter(|segment| !segment.is_empty())
        .collect::<Vec<_>>()
        .join("/")
}
/// Checks that `url` is an `http://` or `https://` URL without embedded
/// credentials.
///
/// The scheme is matched case-insensitively. The authority component is the
/// text after the scheme up to the first `/`, `?` or `#` (or the end of the
/// string); the URL is rejected if that component contains `@`. An `@` in the
/// path, query or fragment is allowed.
///
/// # Errors
///
/// Returns [`ValidationError::Empty`] for an empty string and
/// [`ValidationError::InvalidFormat`] for a disallowed scheme or for
/// credentials in the authority.
pub fn validate_url(url: &str) -> ValidationResult<()> {
    if url.is_empty() {
        return Err(ValidationError::Empty("URL"));
    }

    // Compare only the prefix; `get` yields `None` when the cut would fall
    // inside a multi-byte character, which cannot be an ASCII scheme anyway.
    let after_scheme = ["http://", "https://"].iter().find_map(|&scheme| {
        let head = url.get(..scheme.len())?;
        if head.eq_ignore_ascii_case(scheme) {
            url.get(scheme.len()..)
        } else {
            None
        }
    });
    let after_scheme = match after_scheme {
        Some(rest) => rest,
        None => {
            return Err(ValidationError::InvalidFormat(
                "URL",
                "only http:// and https:// URLs are allowed".to_string(),
            ))
        }
    };

    // The authority ends at the first '/', '?' or '#'. Cutting only at '/'
    // wrongly treated an '@' in a query or fragment as userinfo when the URL
    // had no path (e.g. "https://example.com?email=a@b.com").
    let authority = after_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");
    if authority.contains('@') {
        return Err(ValidationError::InvalidFormat(
            "URL",
            "URLs with embedded credentials not allowed".to_string(),
        ));
    }

    Ok(())
}
/// Tests for the path, glob-pattern and URL helpers.
#[cfg(test)]
mod path_validation_tests {
    use super::*;

    /// Relative paths pass; traversal, NUL and (unless allowed) absolute
    /// paths are rejected.
    #[test]
    fn test_validate_path() {
        assert!(validate_path("data/file.json", false).is_ok());
        assert!(validate_path("nested/path/to/file.csv", false).is_ok());
        assert!(validate_path("file.txt", false).is_ok());
        assert!(validate_path("/absolute/path", true).is_ok());
        assert!(validate_path("../etc/passwd", false).is_err());
        assert!(validate_path("data/../secret", false).is_err());
        assert!(validate_path("/absolute/path", false).is_err());
        assert!(validate_path("", false).is_err());
        assert!(validate_path("path\0with\0null", false).is_err());
        // Windows-style absolute paths: drive prefix and backslash root.
        assert!(validate_path("C:\\Windows\\system32", false).is_err());
        assert!(validate_path("\\\\server\\share", false).is_err());
        assert!(validate_path("C:\\Windows", true).is_ok());
        // Traversal is rejected even when absolute paths are allowed.
        assert!(validate_path("/var/../etc", true).is_err());
    }

    /// Ordinary globs pass; empty, traversal and NUL-containing patterns fail.
    #[test]
    fn test_validate_glob_pattern() {
        assert!(validate_glob_pattern("**/*.json").is_ok());
        assert!(validate_glob_pattern("data/*.csv").is_ok());
        assert!(validate_glob_pattern("*.txt").is_ok());
        assert!(validate_glob_pattern("../secret/*.json").is_err());
        assert!(validate_glob_pattern("").is_err());
        assert!(validate_glob_pattern("data\0/*.csv").is_err());
    }

    /// Sanitizing strips traversal, roots, duplicate/trailing slashes and
    /// NULs, and normalizes backslashes.
    #[test]
    fn test_sanitize_path() {
        assert_eq!(sanitize_path("data/file.json"), "data/file.json");
        assert_eq!(sanitize_path("../data/file.json"), "data/file.json");
        assert_eq!(sanitize_path("/absolute/path"), "absolute/path");
        assert_eq!(sanitize_path("data//double//slash"), "data/double/slash");
        assert_eq!(
            sanitize_path("path\\with\\backslash"),
            "path/with/backslash"
        );
        assert_eq!(sanitize_path("data/dir/"), "data/dir");
        assert_eq!(sanitize_path("nul\0l"), "null");
        assert_eq!(sanitize_path("///"), "");
    }

    /// Only http(s) URLs without credentials in the authority are accepted.
    #[test]
    fn test_validate_url() {
        assert!(validate_url("https://api.example.com/data").is_ok());
        assert!(validate_url("http://localhost:8080/api").is_ok());
        assert!(validate_url("file:///etc/passwd").is_err());
        assert!(validate_url("javascript:alert(1)").is_err());
        assert!(validate_url("data:text/html,<script>").is_err());
        assert!(validate_url("https://user:pass@example.com").is_err());
        assert!(validate_url("").is_err());
        // The scheme check is case-insensitive.
        assert!(validate_url("HTTPS://API.EXAMPLE.COM/data").is_ok());
        assert!(validate_url("ftp://example.com").is_err());
        // An '@' after the first path slash is not userinfo.
        assert!(validate_url("https://example.com/path?email=a@b.com").is_ok());
    }
}
/// Checks that `file_size` does not exceed `MAX_OPENAPI_FILE_SIZE`.
///
/// # Errors
///
/// Returns [`ValidationError::TooLong`] (field `"OpenAPI file size"`) when
/// the size is above the limit. The reported numbers are clamped to
/// `usize::MAX` on targets where `usize` is narrower than `u64`.
pub fn validate_openapi_file_size(file_size: u64) -> ValidationResult<()> {
    if file_size <= MAX_OPENAPI_FILE_SIZE {
        return Ok(());
    }

    // A plain `as usize` wraps on 32-bit targets, which could report a huge
    // file as a tiny one; saturate instead so the message stays meaningful.
    let saturate = |value: u64| value.min(usize::MAX as u64) as usize;

    Err(ValidationError::TooLong {
        field: "OpenAPI file size",
        max: saturate(MAX_OPENAPI_FILE_SIZE),
        actual: saturate(file_size),
    })
}