use serde::Deserialize;
/// Errors reported by the schema validation helpers in this file.
///
/// The `#[error(...)]` strings are the `Display` output of each variant.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum SchemaError {
/// The identifier was the empty string. The payload is the caller-supplied
/// kind label (for example `"column"`), used only in the message.
#[error("empty {0} name")]
EmptyIdentifier(&'static str),
/// The identifier's byte length (`len`) exceeds the configured maximum (`max`).
#[error("{kind} name too long: {len} > {max}")]
IdentifierTooLong {
kind: &'static str,
len: usize,
max: usize,
},
/// The name contains characters outside the accepted set.
///
/// NOTE: `DateTime64Spec::validate` also returns this variant with
/// `kind: "timezone"`; in that case the pattern quoted in the message does
/// not describe the timezone rule (which additionally allows `+`, `/`, `-`).
#[error("invalid {kind} name {name:?}: must match ^[A-Za-z_][A-Za-z0-9_]*$")]
InvalidIdentifier { kind: &'static str, name: String },
/// A column count of zero was supplied.
#[error("a table must declare at least one column")]
NoColumns,
/// The column count (`count`) exceeds the configured maximum (`max`).
#[error("too many columns: {count} > {max}")]
TooManyColumns { count: usize, max: usize },
/// The column name appears in the reserved list; the payload is the name.
#[error("column name {0:?} is reserved")]
ReservedColumn(String),
/// A column name was declared more than once; the payload is the name.
/// (Not constructed by any code visible in this part of the file.)
#[error("duplicate column name {0:?}")]
DuplicateColumn(String),
/// A `DateTime64` precision greater than 9 was supplied.
#[error("invalid DateTime64 precision: {precision} (must be 0..=9)")]
InvalidDateTime64Precision { precision: u8 },
}
/// Upper bounds applied while validating a schema.
#[derive(Debug, Clone, Copy)]
pub struct SchemaLimits {
    /// Largest column count accepted by `assert_column_count`.
    pub max_columns: usize,
    /// Largest identifier length, in bytes, accepted by `validate_identifier`.
    pub max_identifier_length: usize,
}

impl Default for SchemaLimits {
    /// Builds the stock limits: 1024 columns and 128-byte identifiers.
    fn default() -> Self {
        let max_columns = 1024;
        let max_identifier_length = 128;
        SchemaLimits {
            max_columns,
            max_identifier_length,
        }
    }
}
/// Default list of reserved column names, suitable as the `reserved` argument
/// of `assert_not_reserved`. Matching there is exact (case-sensitive).
pub const DEFAULT_RESERVED_COLUMNS: &[&str] = &["attrs", "raw"];
/// Reports whether `name` matches `^[A-Za-z_][A-Za-z0-9_]*$`.
///
/// The empty string is rejected. Working on bytes is equivalent to working on
/// chars here: every byte of a multi-byte UTF-8 sequence is non-ASCII and so
/// fails both character classes.
fn is_valid_identifier(name: &str) -> bool {
    match name.as_bytes().split_first() {
        None => false,
        Some((&head, tail)) => {
            let head_ok = head == b'_' || head.is_ascii_alphabetic();
            head_ok
                && tail
                    .iter()
                    .all(|&byte| byte == b'_' || byte.is_ascii_alphanumeric())
        }
    }
}
/// Reports whether `tz` is an acceptable timezone name.
///
/// Accepted names are 1 to 64 bytes long and consist solely of ASCII letters,
/// ASCII digits, `_`, `+`, `/` and `-`.
fn is_valid_timezone(tz: &str) -> bool {
    if tz.is_empty() || tz.len() > 64 {
        return false;
    }
    // Non-ASCII characters are encoded as bytes >= 0x80 and therefore never
    // match any of the ranges below.
    tz.bytes().all(|byte| {
        matches!(
            byte,
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'_' | b'+' | b'/' | b'-'
        )
    })
}
/// Checks that `name` is a safe identifier and hands it back unchanged.
///
/// `kind` is a label (for example `"column"`) that is only carried into the
/// error for reporting.
///
/// # Errors
///
/// Checks run in this order, and the first failure wins:
/// * `SchemaError::EmptyIdentifier` when `name` is empty;
/// * `SchemaError::IdentifierTooLong` when the byte length of `name` exceeds
///   `limits.max_identifier_length`;
/// * `SchemaError::InvalidIdentifier` when `name` does not match
///   `^[A-Za-z_][A-Za-z0-9_]*$`.
pub fn validate_identifier<'a>(
    name: &'a str,
    kind: &'static str,
    limits: &SchemaLimits,
) -> Result<&'a str, SchemaError> {
    let max = limits.max_identifier_length;
    match name.len() {
        0 => Err(SchemaError::EmptyIdentifier(kind)),
        len if len > max => Err(SchemaError::IdentifierTooLong { kind, len, max }),
        _ if is_valid_identifier(name) => Ok(name),
        _ => Err(SchemaError::InvalidIdentifier {
            kind,
            name: name.to_owned(),
        }),
    }
}
/// Wraps `name` in backticks so it can be embedded in SQL as a quoted identifier.
///
/// Embedded backticks are doubled. Backslashes are doubled as well: ClickHouse
/// treats `\` as an escape character inside back-quoted identifiers, so an
/// unescaped backslash could swallow the following backtick and let the rest
/// of the name escape the quoted token. Names without backticks or
/// backslashes are returned unchanged apart from the surrounding backticks.
pub fn quote_identifier(name: &str) -> String {
    // +2 for the surrounding backticks; escapes may grow the buffer further.
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('`');
    for ch in name.chars() {
        match ch {
            '`' => quoted.push_str("``"),
            '\\' => quoted.push_str("\\\\"),
            other => quoted.push(other),
        }
    }
    quoted.push('`');
    quoted
}
/// Checks that a table declares between one and `limits.max_columns` columns.
///
/// # Errors
///
/// * `SchemaError::NoColumns` when `count` is zero;
/// * `SchemaError::TooManyColumns` when `count` exceeds `limits.max_columns`.
pub fn assert_column_count(count: usize, limits: &SchemaLimits) -> Result<(), SchemaError> {
    let max = limits.max_columns;
    match count {
        0 => Err(SchemaError::NoColumns),
        n if n > max => Err(SchemaError::TooManyColumns { count: n, max }),
        _ => Ok(()),
    }
}
/// Checks that `name` is not one of the `reserved` column names.
///
/// The comparison is an exact, case-sensitive string match.
///
/// # Errors
///
/// Returns `SchemaError::ReservedColumn` carrying `name` when it is reserved.
pub fn assert_not_reserved(name: &str, reserved: &[&str]) -> Result<(), SchemaError> {
    let is_reserved = reserved.iter().any(|&entry| entry == name);
    if is_reserved {
        Err(SchemaError::ReservedColumn(name.to_owned()))
    } else {
        Ok(())
    }
}
/// Scalar column types that may be requested by name.
///
/// Each variant deserializes from its own name, except `Uuid` and `Json`,
/// which deserialize from `"UUID"` and `"JSON"` respectively. Any other name
/// fails to deserialize, so this enum acts as an allowlist.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
pub enum ScalarType {
String,
// Spelled upper-case on the wire.
#[serde(rename = "UUID")]
Uuid,
Bool,
Date,
DateTime,
// Bare form without parameters; `ch_type` renders it as `DateTime64(3)`.
DateTime64,
Int8,
Int16,
Int32,
Int64,
UInt8,
UInt16,
UInt32,
UInt64,
Float32,
Float64,
// Spelled upper-case on the wire.
#[serde(rename = "JSON")]
Json,
}
impl ScalarType {
fn ch_type(self) -> &'static str {
match self {
ScalarType::String => "String",
ScalarType::Uuid => "UUID",
ScalarType::Bool => "Bool",
ScalarType::Date => "Date",
ScalarType::DateTime => "DateTime",
ScalarType::DateTime64 => "DateTime64(3)",
ScalarType::Int8 => "Int8",
ScalarType::Int16 => "Int16",
ScalarType::Int32 => "Int32",
ScalarType::Int64 => "Int64",
ScalarType::UInt8 => "UInt8",
ScalarType::UInt16 => "UInt16",
ScalarType::UInt32 => "UInt32",
ScalarType::UInt64 => "UInt64",
ScalarType::Float32 => "Float32",
ScalarType::Float64 => "Float64",
ScalarType::Json => "JSON",
}
}
}
/// Single-variant enum that deserializes only from the name `"String"`.
///
/// `ColumnTypeSpec` uses it for the `array` element and both `map` positions,
/// so any other element type fails to deserialize.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
pub enum StringOnly {
String,
}
/// Serde default for `DateTime64Spec::precision`: 3 fractional digits.
fn default_dt64_precision() -> u8 {
    const DEFAULT_PRECISION: u8 = 3;
    DEFAULT_PRECISION
}
/// Parameters of a parametrised `DateTime64` column type.
///
/// Deserialization does not range-check the fields; call `validate` for that.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
pub struct DateTime64Spec {
/// Precision parameter; falls back to `default_dt64_precision()` (3) when
/// the field is absent. `validate` rejects values above 9.
#[serde(default = "default_dt64_precision")]
pub precision: u8,
/// Optional timezone name; `None` when the field is absent.
#[serde(default)]
pub timezone: Option<String>,
}
impl DateTime64Spec {
    /// Checks the precision range and, when present, the timezone name.
    ///
    /// # Errors
    ///
    /// * `SchemaError::InvalidDateTime64Precision` when `precision` exceeds 9;
    /// * `SchemaError::InvalidIdentifier` with `kind: "timezone"` when the
    ///   timezone is set but fails `is_valid_timezone`.
    ///
    /// The precision is checked first.
    pub fn validate(&self) -> Result<(), SchemaError> {
        let precision = self.precision;
        if precision > 9 {
            return Err(SchemaError::InvalidDateTime64Precision { precision });
        }
        match self.timezone.as_deref() {
            Some(zone) if !is_valid_timezone(zone) => Err(SchemaError::InvalidIdentifier {
                kind: "timezone",
                name: zone.to_owned(),
            }),
            _ => Ok(()),
        }
    }
}
/// Allowlisted description of a column type.
///
/// The enum is `untagged`: serde tries the variants in declaration order and
/// keeps the first one whose shape matches the input. Input that matches no
/// variant fails to deserialize.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(untagged)]
pub enum ColumnTypeSpec {
/// A bare scalar name, e.g. `"String"`.
Scalar(ScalarType),
/// `{"datetime64": {...}}` — a `DateTime64` with explicit parameters.
DateTime64 {
datetime64: DateTime64Spec,
},
/// `{"nullable": <spec>}` — wraps another spec.
Nullable {
nullable: Box<ColumnTypeSpec>,
},
/// `{"lowCardinality": <spec>}` — wraps another spec.
LowCardinality {
#[serde(rename = "lowCardinality")]
low_cardinality: Box<ColumnTypeSpec>,
},
/// `{"array": "String"}` — only string elements are accepted.
Array {
array: StringOnly,
},
/// `{"map": ["String", "String"]}` — only string keys and values are accepted.
Map {
map: (StringOnly, StringOnly),
},
}
impl ColumnTypeSpec {
    /// Renders the spec as a type string, recursing through wrapper variants.
    ///
    /// The timezone of a parametrised `DateTime64` is interpolated between
    /// single quotes as-is and is not checked here; call [`Self::validate`]
    /// first when the spec comes from untrusted input.
    pub fn to_ch_type(&self) -> String {
        match self {
            Self::Scalar(scalar) => scalar.ch_type().to_owned(),
            Self::DateTime64 { datetime64: spec } => {
                let precision = spec.precision;
                if let Some(zone) = &spec.timezone {
                    format!("DateTime64({}, '{}')", precision, zone)
                } else {
                    format!("DateTime64({})", precision)
                }
            }
            Self::Nullable { nullable: inner } => {
                let rendered = inner.to_ch_type();
                format!("Nullable({})", rendered)
            }
            Self::LowCardinality {
                low_cardinality: inner,
            } => {
                let rendered = inner.to_ch_type();
                format!("LowCardinality({})", rendered)
            }
            Self::Array { .. } => String::from("Array(String)"),
            Self::Map { .. } => String::from("Map(String, String)"),
        }
    }

    /// Reports whether the spec is a `DateTime64`, bare or parametrised,
    /// looking through any `Nullable` / `LowCardinality` wrappers.
    pub fn is_datetime64(&self) -> bool {
        match self {
            Self::Scalar(scalar) => matches!(scalar, ScalarType::DateTime64),
            Self::DateTime64 { .. } => true,
            Self::Nullable { nullable: inner }
            | Self::LowCardinality {
                low_cardinality: inner,
            } => inner.is_datetime64(),
            Self::Array { .. } | Self::Map { .. } => false,
        }
    }

    /// Validates the parameters of a parametrised `DateTime64`, looking
    /// through any `Nullable` / `LowCardinality` wrappers.
    ///
    /// # Errors
    ///
    /// Propagates whatever `DateTime64Spec::validate` reports. Scalars,
    /// arrays and maps carry no parameters and always pass.
    pub fn validate(&self) -> Result<(), SchemaError> {
        match self {
            Self::DateTime64 { datetime64: spec } => spec.validate(),
            Self::Nullable { nullable: inner }
            | Self::LowCardinality {
                low_cardinality: inner,
            } => inner.validate(),
            Self::Scalar(_) | Self::Array { .. } | Self::Map { .. } => Ok(()),
        }
    }
}
// Unit tests for identifier validation, quoting, column-count and
// reserved-name checks, and (de)serialization/rendering of `ColumnTypeSpec`.
#[cfg(test)]
mod tests {
use super::*;
// Shorthand for the default limits used throughout the tests.
fn limits() -> SchemaLimits {
SchemaLimits::default()
}
// Well-formed identifiers are accepted and returned unchanged.
#[test]
fn accepts_safe_identifiers() {
for ok in ["a", "A", "_x", "org_id", "col1", "X_2_y"] {
assert_eq!(validate_identifier(ok, "column", &limits()).unwrap(), ok);
}
}
// Every entry must be rejected: SQL metacharacters, whitespace, a leading
// digit, the empty string, and non-ASCII letters.
#[test]
fn rejects_injection_and_metacharacters() {
let attacks = [
"a; DROP TABLE x",
"a`,`b",
"a) ENGINE=Memory AS SELECT * FROM secrets --",
"a' OR '1'='1",
"a b",
"a.b",
"a-b",
"1col",
"",
"a\"b",
"a\nb",
"таблица",
"a/*x*/",
];
for bad in attacks {
assert!(
validate_identifier(bad, "column", &limits()).is_err(),
"should reject {bad:?}"
);
}
}
// The length bound is inclusive: exactly `max_identifier_length` passes,
// one more fails.
#[test]
fn enforces_length_bound() {
let lim = limits();
let too_long = "a".repeat(lim.max_identifier_length + 1);
assert!(validate_identifier(&too_long, "column", &lim).is_err());
let ok = "a".repeat(lim.max_identifier_length);
assert!(validate_identifier(&ok, "column", &lim).is_ok());
}
// Quoting wraps in backticks and doubles embedded backticks.
#[test]
fn quotes_and_escapes() {
assert_eq!(quote_identifier("org_id"), "`org_id`");
assert_eq!(quote_identifier("a`b"), "`a``b`");
}
// Column-count bounds and the default reserved-name list.
#[test]
fn bounds_and_reserved() {
assert!(assert_column_count(0, &limits()).is_err());
assert!(assert_column_count(limits().max_columns + 1, &limits()).is_err());
assert!(assert_column_count(10, &limits()).is_ok());
assert!(assert_not_reserved("attrs", DEFAULT_RESERVED_COLUMNS).is_err());
assert!(assert_not_reserved("raw", DEFAULT_RESERVED_COLUMNS).is_err());
assert!(assert_not_reserved("user_col", DEFAULT_RESERVED_COLUMNS).is_ok());
}
// Each allowed JSON shape deserializes and renders the expected type string.
#[test]
fn allowlist_builds_allowed_types() {
let s: ColumnTypeSpec = serde_json::from_str("\"DateTime64\"").unwrap();
assert_eq!(s.to_ch_type(), "DateTime64(3)");
assert!(s.is_datetime64());
let n: ColumnTypeSpec = serde_json::from_str(r#"{"nullable":"String"}"#).unwrap();
assert_eq!(n.to_ch_type(), "Nullable(String)");
let lc: ColumnTypeSpec =
serde_json::from_str(r#"{"lowCardinality":{"nullable":"String"}}"#).unwrap();
assert_eq!(lc.to_ch_type(), "LowCardinality(Nullable(String))");
let lcd: ColumnTypeSpec =
serde_json::from_str(r#"{"lowCardinality":"DateTime64"}"#).unwrap();
assert!(lcd.is_datetime64());
let a: ColumnTypeSpec = serde_json::from_str(r#"{"array":"String"}"#).unwrap();
assert_eq!(a.to_ch_type(), "Array(String)");
let m: ColumnTypeSpec = serde_json::from_str(r#"{"map":["String","String"]}"#).unwrap();
assert_eq!(m.to_ch_type(), "Map(String, String)");
}
// Type names and shapes outside the allowlist must fail to deserialize.
#[test]
fn allowlist_rejects_disallowed_types() {
let bad = [
"\"Decimal(38, 10)\"",
"\"FixedString(16)\"",
"\"Enum8\"",
"\"Tuple\"",
"\"Nested\"",
r#"{"map":["String","Int32"]}"#,
r#"{"array":"Int32"}"#,
r#"{"array":{"nullable":"String"}}"#,
r#"{"wat":"String"}"#,
"42",
];
for b in bad {
assert!(
serde_json::from_str::<ColumnTypeSpec>(b).is_err(),
"should reject {b}"
);
}
}
// Parametrised DateTime64: explicit precision/timezone, defaults, and the
// bare string form staying a `Scalar`.
#[test]
fn parametrised_datetime64_renders_and_validates() {
let utc: ColumnTypeSpec =
serde_json::from_str(r#"{"datetime64":{"precision":3,"timezone":"UTC"}}"#).unwrap();
assert_eq!(utc.to_ch_type(), "DateTime64(3, 'UTC')");
assert!(utc.is_datetime64());
assert!(utc.validate().is_ok());
let p6: ColumnTypeSpec = serde_json::from_str(r#"{"datetime64":{"precision":6}}"#).unwrap();
assert_eq!(p6.to_ch_type(), "DateTime64(6)");
assert!(p6.validate().is_ok());
// An empty object falls back to the default precision and no timezone.
let def: ColumnTypeSpec = serde_json::from_str(r#"{"datetime64":{}}"#).unwrap();
assert_eq!(def.to_ch_type(), "DateTime64(3)");
assert!(def.is_datetime64());
assert!(def.validate().is_ok());
let bare: ColumnTypeSpec = serde_json::from_str("\"DateTime64\"").unwrap();
assert!(matches!(
bare,
ColumnTypeSpec::Scalar(ScalarType::DateTime64)
));
let tz: ColumnTypeSpec =
serde_json::from_str(r#"{"datetime64":{"precision":9,"timezone":"America/New_York"}}"#)
.unwrap();
assert_eq!(tz.to_ch_type(), "DateTime64(9, 'America/New_York')");
assert!(tz.validate().is_ok());
}
// Out-of-range parameters deserialize successfully but fail `validate`.
#[test]
fn parametrised_datetime64_rejects_bad_params() {
let bad_tz: ColumnTypeSpec =
serde_json::from_str(r#"{"datetime64":{"precision":3,"timezone":"UTC'; DROP"}}"#)
.unwrap();
assert!(matches!(
bad_tz.validate(),
Err(SchemaError::InvalidIdentifier {
kind: "timezone",
..
})
));
let bad_p: ColumnTypeSpec =
serde_json::from_str(r#"{"datetime64":{"precision":12}}"#).unwrap();
assert!(matches!(
bad_p.validate(),
Err(SchemaError::InvalidDateTime64Precision { precision: 12 })
));
}
// `is_datetime64`, rendering and validation all look through `nullable`.
#[test]
fn parametrised_datetime64_is_datetime64_through_nullable() {
let n: ColumnTypeSpec =
serde_json::from_str(r#"{"nullable":{"datetime64":{"precision":3,"timezone":"UTC"}}}"#)
.unwrap();
assert!(n.is_datetime64());
assert_eq!(n.to_ch_type(), "Nullable(DateTime64(3, 'UTC'))");
assert!(n.validate().is_ok());
let bad: ColumnTypeSpec =
serde_json::from_str(r#"{"nullable":{"datetime64":{"precision":12}}}"#).unwrap();
assert!(bad.validate().is_err());
}
}