use super::types::{SqlTypeName, TypeModifier, Value};
/// What [`validate_varchar`] does with a string longer than the declared `VARCHAR(n)` length.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VarcharMode {
    /// Fail with `ParametricError::VarcharOverflow`.
    Reject,
    /// Keep only the first `n` characters.
    Truncate,
}
/// Failures raised while parsing or enforcing parametric SQL type modifiers
/// (`VARCHAR(n)`, `DECIMAL(p)`, `DECIMAL(p, s)`).
#[derive(Debug, Clone)]
pub enum ParametricError {
    /// A string of `actual` characters exceeded the `VARCHAR(max)` limit.
    VarcharOverflow { actual: usize, max: u32 },
    /// A decimal exceeded the declared `precision`; `actual_digits` is the digit count
    /// reported by the validator that raised it.
    DecimalPrecisionOverflow { precision: u8, actual_digits: usize },
    /// A decimal had `actual_fraction_digits` digits after the point, more than `scale`.
    DecimalScaleOverflow {
        scale: u8,
        actual_fraction_digits: usize,
    },
    /// The carried text could not be parsed as a decimal literal.
    NotADecimal(String),
    /// A type modifier was malformed; the string describes the problem.
    BadModifier(String),
}

impl std::fmt::Display for ParametricError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::VarcharOverflow { actual: got, max: limit } => {
                write!(f, "string of length {} exceeds VARCHAR({})", got, limit)
            }
            Self::DecimalPrecisionOverflow {
                precision: p,
                actual_digits: digits,
            } => write!(
                f,
                "decimal with {} digits exceeds DECIMAL precision {}",
                digits, p
            ),
            Self::DecimalScaleOverflow {
                scale: s,
                actual_fraction_digits: digits,
            } => write!(
                f,
                "decimal with {} fractional digits exceeds DECIMAL scale {}",
                digits, s
            ),
            Self::NotADecimal(text) => write!(f, "`{}` is not a valid decimal literal", text),
            Self::BadModifier(why) => write!(f, "bad parametric modifier: {}", why),
        }
    }
}

impl std::error::Error for ParametricError {}
/// Enforces a `VARCHAR(max_len)` length limit on `value`.
///
/// Non-text values are first converted with `Value::display_string` and then validated as
/// text, so a successful result is always a text value. Length is measured in `char`s
/// (Unicode scalar values), not bytes.
///
/// # Errors
///
/// Returns [`ParametricError::VarcharOverflow`] when the text is longer than `max_len`
/// characters and `mode` is [`VarcharMode::Reject`]. With [`VarcharMode::Truncate`] the text
/// is cut to its first `max_len` characters instead.
pub fn validate_varchar(
    value: &Value,
    max_len: u32,
    mode: VarcharMode,
) -> Result<Value, ParametricError> {
    let s = match value {
        Value::Text(s) => s,
        other => {
            // Coerce to text so every input type goes through the same length rules.
            return validate_varchar(&Value::text(other.display_string()), max_len, mode);
        }
    };
    // Compare in `usize`. The previous `len as u32` could wrap for huge strings and let them
    // slip past the limit. If `max_len` does not fit in `usize`, no string can exceed it.
    let limit = usize::try_from(max_len).unwrap_or(usize::MAX);
    // Fast path: a UTF-8 string never has more chars than bytes, so skip the O(n) count.
    if s.len() <= limit {
        return Ok(value.clone());
    }
    let len = s.chars().count();
    if len <= limit {
        return Ok(value.clone());
    }
    match mode {
        VarcharMode::Reject => Err(ParametricError::VarcharOverflow {
            actual: len,
            max: max_len,
        }),
        VarcharMode::Truncate => {
            let truncated: String = s.chars().take(limit).collect();
            Ok(Value::text(truncated))
        }
    }
}
/// Checks that `value` fits a `DECIMAL(precision, scale)` column.
///
/// The text form of `value` comes from `Value::display_string`; surrounding whitespace is
/// ignored. It must be an optional leading `-` followed by ASCII digits with at most one `.`,
/// and at least one digit must be present. A `+` sign, exponent or separator makes it invalid.
/// On success the original value is returned unchanged: nothing is rounded or normalised.
///
/// `DECIMAL(p, s)` reserves `s` digits for the fraction and so allows at most `p - s`
/// integer digits. Leading zeros in the integer part are not significant (`0.5` fits
/// `DECIMAL(1, 1)`). Trailing zeros in the fraction do count toward `scale`.
///
/// # Errors
///
/// * [`ParametricError::NotADecimal`] if the text is not a plain decimal literal.
/// * [`ParametricError::DecimalPrecisionOverflow`] if the integer part has more than
///   `precision - scale` significant digits. `actual_digits` is the number of digits the
///   value would need when stored at the declared scale.
/// * [`ParametricError::DecimalScaleOverflow`] if there are more than `scale` fractional
///   digits.
pub fn validate_decimal(value: &Value, precision: u8, scale: u8) -> Result<Value, ParametricError> {
    let s = value.display_string();
    let trimmed = s.trim();
    let body = trimmed.strip_prefix('-').unwrap_or(trimmed);
    let (whole, frac) = body.split_once('.').unwrap_or((body, ""));
    let all_digits = |part: &str| part.bytes().all(|b| b.is_ascii_digit());
    if (whole.is_empty() && frac.is_empty()) || !all_digits(whole) || !all_digits(frac) {
        return Err(ParametricError::NotADecimal(s));
    }
    // "007.5" and "0.5" have 1 and 0 significant integer digits respectively.
    let int_digits = whole.trim_start_matches('0').len();
    let frac_digits = frac.len();
    // `saturating_sub` keeps this panic-free even if a caller passes scale > precision.
    let max_int_digits = usize::from(precision.saturating_sub(scale));
    if int_digits > max_int_digits {
        return Err(ParametricError::DecimalPrecisionOverflow {
            precision,
            // Always > precision here, because int_digits > precision - scale.
            actual_digits: int_digits + usize::from(scale),
        });
    }
    if frac_digits > usize::from(scale) {
        return Err(ParametricError::DecimalScaleOverflow {
            scale,
            actual_fraction_digits: frac_digits,
        });
    }
    Ok(value.clone())
}
/// Extracts the maximum length from a `VARCHAR` / `VARCHAR(n)` type name.
///
/// A bare `VARCHAR` (no modifiers) is treated as unbounded and yields `u32::MAX`.
///
/// # Errors
///
/// Returns [`ParametricError::BadModifier`] when more than one modifier is present or when
/// the single modifier is not a number.
pub fn parse_varchar_modifier(sql_type: &SqlTypeName) -> Result<u32, ParametricError> {
    match &sql_type.modifiers[..] {
        [] => Ok(u32::MAX),
        [TypeModifier::Number(length)] => Ok(*length),
        [other] => Err(ParametricError::BadModifier(format!(
            "VARCHAR length must be a number, got {other:?}"
        ))),
        several => Err(ParametricError::BadModifier(format!(
            "VARCHAR expects 1 modifier, got {}",
            several.len()
        ))),
    }
}
pub fn parse_decimal_modifier(sql_type: &SqlTypeName) -> Result<(u8, u8), ParametricError> {
let mods = &sql_type.modifiers;
if mods.is_empty() {
return Ok((38, 0));
}
if mods.len() > 2 {
return Err(ParametricError::BadModifier(format!(
"DECIMAL expects (p) or (p,s), got {} modifiers",
mods.len()
)));
}
let precision = match &mods[0] {
TypeModifier::Number(n) => u8::try_from(*n).map_err(|_| {
ParametricError::BadModifier(format!("DECIMAL precision {n} out of u8 range"))
})?,
other => {
return Err(ParametricError::BadModifier(format!(
"DECIMAL precision must be a number, got {other:?}"
)))
}
};
let scale = if let Some(s_mod) = mods.get(1) {
match s_mod {
TypeModifier::Number(n) => u8::try_from(*n).map_err(|_| {
ParametricError::BadModifier(format!("DECIMAL scale {n} out of u8 range"))
})?,
other => {
return Err(ParametricError::BadModifier(format!(
"DECIMAL scale must be a number, got {other:?}"
)))
}
}
} else {
0
};
if scale > precision {
return Err(ParametricError::BadModifier(format!(
"DECIMAL scale {scale} cannot exceed precision {precision}"
)));
}
Ok((precision, scale))
}