use crate::types::{SqlTypeName, TypeModifier, Value};
/// Policy selected by callers of `validate_varchar` for strings that are longer than the
/// declared `VARCHAR(n)` length.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VarcharMode {
    /// Over-long strings are refused with an error.
    Reject,
    /// Over-long strings are cut down to the first `n` characters.
    Truncate,
}
/// Failures reported while parsing or enforcing parametric SQL type modifiers
/// (`VARCHAR(n)`, `DECIMAL(p, s)`).
#[derive(Clone, Debug)]
pub enum ParametricError {
    /// A string of `actual` characters does not fit a `VARCHAR(max)`.
    VarcharOverflow { actual: usize, max: u32 },
    /// A decimal has `actual_digits` digits, more than the declared `precision`.
    DecimalPrecisionOverflow { precision: u8, actual_digits: usize },
    /// A decimal has `actual_fraction_digits` digits after the point, more than the
    /// declared `scale`.
    DecimalScaleOverflow {
        scale: u8,
        actual_fraction_digits: usize,
    },
    /// The carried text could not be read as a decimal literal.
    NotADecimal(String),
    /// A type modifier list was malformed; the payload is the human-readable reason.
    BadModifier(String),
}

impl std::fmt::Display for ParametricError {
    /// Renders the user-facing message for each variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ParametricError::VarcharOverflow { actual, max } => f.write_fmt(format_args!(
                "string of length {actual} exceeds VARCHAR({max})"
            )),
            ParametricError::DecimalPrecisionOverflow {
                precision,
                actual_digits,
            } => f.write_fmt(format_args!(
                "decimal with {actual_digits} digits exceeds DECIMAL precision {precision}"
            )),
            ParametricError::DecimalScaleOverflow {
                scale,
                actual_fraction_digits,
            } => f.write_fmt(format_args!(
                "decimal with {actual_fraction_digits} fractional digits exceeds DECIMAL scale {scale}"
            )),
            ParametricError::NotADecimal(input) => {
                f.write_fmt(format_args!("`{input}` is not a valid decimal literal"))
            }
            ParametricError::BadModifier(reason) => {
                f.write_fmt(format_args!("bad parametric modifier: {reason}"))
            }
        }
    }
}

// No underlying cause is carried, so the default `Error` methods are sufficient.
impl std::error::Error for ParametricError {}
/// Enforces a `VARCHAR(max_len)` length limit on `value`.
///
/// Length is measured in `char`s (Unicode scalar values), not bytes. A value that is not
/// `Value::Text` is first rendered with `display_string()`, wrapped with `Value::text`, and
/// then checked like any other string, so the result for such inputs is the wrapped text.
///
/// When the string fits, a clone of the (possibly converted) value is returned. Otherwise
/// `mode` decides: [`VarcharMode::Truncate`] keeps the first `max_len` characters,
/// [`VarcharMode::Reject`] fails.
///
/// # Errors
///
/// Returns [`ParametricError::VarcharOverflow`] when the string is longer than `max_len`
/// characters and `mode` is [`VarcharMode::Reject`].
pub fn validate_varchar(
    value: &Value,
    max_len: u32,
    mode: VarcharMode,
) -> Result<Value, ParametricError> {
    let s = match value {
        Value::Text(s) => s,
        other => {
            return validate_varchar(&Value::text(other.display_string()), max_len, mode);
        }
    };

    // Compare in `usize`. Narrowing the character count with `as u32` would wrap for strings
    // longer than `u32::MAX` characters and let them pass the check. If `u32` does not fit in
    // `usize` on the target, no in-memory string can exceed the limit anyway.
    let limit = usize::try_from(max_len).unwrap_or(usize::MAX);
    let len = s.chars().count();
    if len <= limit {
        return Ok(value.clone());
    }

    match mode {
        VarcharMode::Reject => Err(ParametricError::VarcharOverflow {
            actual: len,
            max: max_len,
        }),
        VarcharMode::Truncate => {
            // Truncate on `char` boundaries so the result is always valid UTF-8.
            let truncated: String = s.chars().take(limit).collect();
            Ok(Value::text(truncated))
        }
    }
}
/// Checks that `value` is a plain decimal literal that fits `DECIMAL(precision, scale)`.
///
/// The value is rendered with `display_string()` and surrounding whitespace is ignored.
/// Accepted syntax is an optional leading `-`, ASCII digits, and at most one `.`; at least
/// one side of the point must be non-empty.
///
/// Digit counting:
/// * leading zeros of the integer part are not significant and are not counted, so `0.5`
///   fits `DECIMAL(1, 1)` and `007` fits `DECIMAL(1)`;
/// * every digit written after the point counts, including trailing zeros.
///
/// On success the original `value` is returned unchanged (cloned).
///
/// # Errors
///
/// * [`ParametricError::NotADecimal`] if the text is not a decimal literal as described above.
/// * [`ParametricError::DecimalPrecisionOverflow`] if the counted digits exceed `precision`.
/// * [`ParametricError::DecimalScaleOverflow`] if the fractional digits exceed `scale`.
///
/// The precision check runs before the scale check.
pub fn validate_decimal(value: &Value, precision: u8, scale: u8) -> Result<Value, ParametricError> {
    let s = value.display_string();
    let trimmed = s.trim();
    let body = trimmed.strip_prefix('-').unwrap_or(trimmed);
    let (whole, frac) = body.split_once('.').unwrap_or((body, ""));

    let all_digits = |part: &str| part.bytes().all(|b| b.is_ascii_digit());
    if (whole.is_empty() && frac.is_empty()) || !all_digits(whole) || !all_digits(frac) {
        return Err(ParametricError::NotADecimal(s));
    }

    // Leading zeros carry no magnitude; counting them would reject values such as `0.5`
    // for DECIMAL(1, 1) even though they are representable.
    let whole_digits = whole.trim_start_matches('0').len();
    let frac_digits = frac.len();
    let total_digits = whole_digits + frac_digits;

    if total_digits > usize::from(precision) {
        return Err(ParametricError::DecimalPrecisionOverflow {
            precision,
            actual_digits: total_digits,
        });
    }
    if frac_digits > usize::from(scale) {
        return Err(ParametricError::DecimalScaleOverflow {
            scale,
            actual_fraction_digits: frac_digits,
        });
    }
    Ok(value.clone())
}
/// Extracts the maximum length from a `VARCHAR` type name.
///
/// * no modifier: `u32::MAX`;
/// * exactly one numeric modifier: that number;
/// * anything else: [`ParametricError::BadModifier`].
pub fn parse_varchar_modifier(sql_type: &SqlTypeName) -> Result<u32, ParametricError> {
    match &sql_type.modifiers[..] {
        [] => Ok(u32::MAX),
        [TypeModifier::Number(limit)] => Ok(*limit),
        [other] => Err(ParametricError::BadModifier(format!(
            "VARCHAR length must be a number, got {other:?}"
        ))),
        // Two or more modifiers: the arity error is reported regardless of their kinds.
        surplus => Err(ParametricError::BadModifier(format!(
            "VARCHAR expects 1 modifier, got {}",
            surplus.len()
        ))),
    }
}
/// Extracts `(precision, scale)` from a `DECIMAL` type name.
///
/// * no modifiers: `(38, 0)`;
/// * `(p)`: `(p, 0)`;
/// * `(p, s)`: `(p, s)`.
///
/// # Errors
///
/// Returns [`ParametricError::BadModifier`] when there are more than two modifiers, when a
/// modifier is not a number, when a number does not fit in `u8`, or when the scale is larger
/// than the precision. Checks run in that order, precision before scale.
pub fn parse_decimal_modifier(sql_type: &SqlTypeName) -> Result<(u8, u8), ParametricError> {
    let (precision_mod, scale_mod) = match &sql_type.modifiers[..] {
        [] => return Ok((38, 0)),
        [p] => (p, None),
        [p, s] => (p, Some(s)),
        surplus => {
            return Err(ParametricError::BadModifier(format!(
                "DECIMAL expects (p) or (p,s), got {} modifiers",
                surplus.len()
            )));
        }
    };

    let precision = match precision_mod {
        TypeModifier::Number(n) => match u8::try_from(*n) {
            Ok(narrowed) => narrowed,
            Err(_) => {
                return Err(ParametricError::BadModifier(format!(
                    "DECIMAL precision {n} out of u8 range"
                )));
            }
        },
        other => {
            return Err(ParametricError::BadModifier(format!(
                "DECIMAL precision must be a number, got {other:?}"
            )));
        }
    };

    let scale = match scale_mod {
        // `DECIMAL(p)` has an implicit scale of zero.
        None => 0,
        Some(TypeModifier::Number(n)) => match u8::try_from(*n) {
            Ok(narrowed) => narrowed,
            Err(_) => {
                return Err(ParametricError::BadModifier(format!(
                    "DECIMAL scale {n} out of u8 range"
                )));
            }
        },
        Some(other) => {
            return Err(ParametricError::BadModifier(format!(
                "DECIMAL scale must be a number, got {other:?}"
            )));
        }
    };

    if precision < scale {
        return Err(ParametricError::BadModifier(format!(
            "DECIMAL scale {scale} cannot exceed precision {precision}"
        )));
    }
    Ok((precision, scale))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `varchar` type name carrying `modifiers`.
    fn varchar(modifiers: Vec<TypeModifier>) -> SqlTypeName {
        SqlTypeName::new("varchar").with_modifiers(modifiers)
    }

    /// Builds a `decimal` type name carrying `modifiers`.
    fn decimal(modifiers: Vec<TypeModifier>) -> SqlTypeName {
        SqlTypeName::new("decimal").with_modifiers(modifiers)
    }

    /// True when `result` failed with `ParametricError::BadModifier`.
    fn is_bad_modifier<T>(result: Result<T, ParametricError>) -> bool {
        matches!(result, Err(ParametricError::BadModifier(_)))
    }

    #[test]
    fn validate_varchar_rejects_or_truncates_by_mode() {
        let six_chars = Value::text("abcdef");

        let rejected = validate_varchar(&six_chars, 3, VarcharMode::Reject)
            .expect_err("six characters must not fit VARCHAR(3) in Reject mode");
        assert_eq!(
            rejected.to_string(),
            "string of length 6 exceeds VARCHAR(3)"
        );

        let shortened = validate_varchar(&six_chars, 3, VarcharMode::Truncate)
            .expect("Truncate mode should succeed");
        assert_eq!(shortened, Value::text("abc"));

        // Non-text input is stringified before the length check.
        let stringified = validate_varchar(&Value::Integer(123), 3, VarcharMode::Reject)
            .expect("three digits should fit VARCHAR(3)");
        assert_eq!(stringified, Value::text("123"));
    }

    #[test]
    fn validate_decimal_checks_precision_scale_and_syntax() {
        let amount = Value::Decimal(12_3400);

        let accepted = validate_decimal(&amount, 6, 4).expect("value should fit DECIMAL(6, 4)");
        assert_eq!(accepted, Value::Decimal(12_3400));

        let too_precise = validate_decimal(&amount, 5, 4);
        assert!(matches!(
            too_precise,
            Err(ParametricError::DecimalPrecisionOverflow { .. })
        ));

        let too_fine = validate_decimal(&amount, 6, 2);
        assert!(matches!(
            too_fine,
            Err(ParametricError::DecimalScaleOverflow { .. })
        ));

        let garbage = validate_decimal(&Value::text("not decimal"), 10, 2);
        assert!(matches!(garbage, Err(ParametricError::NotADecimal(_))));
    }

    #[test]
    fn varchar_modifier_parses_defaults_and_errors() {
        let unbounded = parse_varchar_modifier(&SqlTypeName::new("varchar"));
        assert_eq!(unbounded.unwrap(), u32::MAX);

        let bounded = parse_varchar_modifier(&varchar(vec![TypeModifier::Number(42)]));
        assert_eq!(bounded.unwrap(), 42);

        let rejected = [
            vec![TypeModifier::Number(1), TypeModifier::Number(2)],
            vec![TypeModifier::Ident("x".to_string())],
        ];
        for modifiers in rejected {
            assert!(is_bad_modifier(parse_varchar_modifier(&varchar(modifiers))));
        }
    }

    #[test]
    fn decimal_modifier_parses_defaults_and_errors() {
        let defaults = parse_decimal_modifier(&SqlTypeName::new("decimal"));
        assert_eq!(defaults.unwrap(), (38, 0));

        let accepted = [
            (vec![TypeModifier::Number(10)], (10, 0)),
            (
                vec![TypeModifier::Number(10), TypeModifier::Number(2)],
                (10, 2),
            ),
        ];
        for (modifiers, expected) in accepted {
            assert_eq!(
                parse_decimal_modifier(&decimal(modifiers)).unwrap(),
                expected
            );
        }

        let rejected = [
            // Too many modifiers.
            vec![
                TypeModifier::Number(1),
                TypeModifier::Number(2),
                TypeModifier::Number(3),
            ],
            // Precision is not a number.
            vec![TypeModifier::Ident("p".to_string())],
            // Precision does not fit in u8.
            vec![TypeModifier::Number(300)],
            // Scale is not a number.
            vec![
                TypeModifier::Number(10),
                TypeModifier::Ident("s".to_string()),
            ],
            // Scale does not fit in u8.
            vec![TypeModifier::Number(10), TypeModifier::Number(300)],
            // Scale larger than precision.
            vec![TypeModifier::Number(2), TypeModifier::Number(3)],
        ];
        for modifiers in rejected {
            assert!(is_bad_modifier(parse_decimal_modifier(&decimal(modifiers))));
        }
    }
}