use std::fmt::{Arguments, Display};
use arrayvec::ArrayString;
use serde::{Deserialize, Serialize};
use sqlx_core::type_info::TypeInfo;
/// Type information describing a database data type.
///
/// Bundles the structured [`ExaDataType`] with its rendered name. Deserialization is
/// delegated to [`ExaDataType`] (`serde(from = ...)`), so the name is always derived from
/// the data type rather than read from the input.
#[derive(Clone, Copy, Debug, Deserialize)]
#[serde(from = "ExaDataType")]
pub struct ExaTypeInfo {
    /// Rendered type name, computed from `data_type` when the value is constructed.
    pub(crate) name: DataTypeName,
    /// The structured data type this value describes.
    pub(crate) data_type: ExaDataType,
}
impl ExaTypeInfo {
#[doc(hidden)]
#[allow(clippy::must_use_candidate)]
pub fn __type_feature_gate(&self) -> Option<&'static str> {
match self.data_type {
ExaDataType::Date
| ExaDataType::Timestamp
| ExaDataType::TimestampWithLocalTimeZone => Some("time"),
ExaDataType::Decimal(decimal)
if decimal.scale > 0 || decimal.precision > Some(Decimal::MAX_64BIT_PRECISION) =>
{
Some("bigdecimal")
}
_ => None,
}
}
}
impl Serialize for ExaTypeInfo {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.data_type.serialize(serializer)
}
}
/// Builds the type info from a data type, rendering its full name once up front.
impl From<ExaDataType> for ExaTypeInfo {
    fn from(data_type: ExaDataType) -> Self {
        Self {
            name: data_type.full_name(),
            data_type,
        }
    }
}
/// Equality is decided by the data type alone; the rendered name is not compared.
impl PartialEq for ExaTypeInfo {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.data_type, &other.data_type);
        lhs == rhs
    }
}
/// Displays the rendered type name.
impl Display for ExaTypeInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Written verbatim: formatter width/fill flags are not applied to the name.
        f.write_str(self.name.as_ref())
    }
}
impl TypeInfo for ExaTypeInfo {
fn is_null(&self) -> bool {
false
}
fn name(&self) -> &str {
self.name.as_ref()
}
fn type_compatible(&self, other: &Self) -> bool
where
Self: Sized,
{
self.data_type.compatible(&other.data_type)
}
}
/// A database data type together with its type-specific parameters.
///
/// (De)serialized as an internally tagged enum: the variant is selected by the `type`
/// field (upper-cased variant name unless renamed) and the parameters use camelCase keys.
#[derive(Clone, Copy, Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "UPPERCASE", tag = "type")]
pub enum ExaDataType {
    /// `BOOLEAN`
    Boolean,
    /// `CHAR` with a size and a character set.
    #[serde(rename_all = "camelCase")]
    Char { size: u32, character_set: Charset },
    /// `DATE`
    Date,
    /// `DECIMAL` with optional precision and a scale.
    Decimal(Decimal),
    /// `DOUBLE`, rendered as `DOUBLE PRECISION` in type names.
    Double,
    /// `GEOMETRY` with a spatial reference id.
    #[serde(rename_all = "camelCase")]
    Geometry { srid: u16 },
    /// `INTERVAL DAY TO SECOND` with day precision and seconds fraction.
    #[serde(rename = "INTERVAL DAY TO SECOND", rename_all = "camelCase")]
    IntervalDayToSecond { precision: u32, fraction: u32 },
    /// `INTERVAL YEAR TO MONTH` with year precision.
    #[serde(rename = "INTERVAL YEAR TO MONTH", rename_all = "camelCase")]
    IntervalYearToMonth { precision: u32 },
    /// `TIMESTAMP`
    Timestamp,
    /// `TIMESTAMP WITH LOCAL TIME ZONE`
    #[serde(rename = "TIMESTAMP WITH LOCAL TIME ZONE")]
    TimestampWithLocalTimeZone,
    /// `VARCHAR` with a size and a character set.
    #[serde(rename_all = "camelCase")]
    Varchar { size: u32, character_set: Charset },
    /// `HASHTYPE` with an optional size. The rendered name shows `size / 2` bytes.
    HashType { size: Option<u16> },
}
impl ExaDataType {
const BOOLEAN: &'static str = "BOOLEAN";
const CHAR: &'static str = "CHAR";
const DATE: &'static str = "DATE";
const DECIMAL: &'static str = "DECIMAL";
const DOUBLE: &'static str = "DOUBLE PRECISION";
const GEOMETRY: &'static str = "GEOMETRY";
const INTERVAL_DAY_TO_SECOND: &'static str = "INTERVAL DAY TO SECOND";
const INTERVAL_YEAR_TO_MONTH: &'static str = "INTERVAL YEAR TO MONTH";
const TIMESTAMP: &'static str = "TIMESTAMP";
const TIMESTAMP_WITH_LOCAL_TIME_ZONE: &'static str = "TIMESTAMP WITH LOCAL TIME ZONE";
const VARCHAR: &'static str = "VARCHAR";
const HASHTYPE: &'static str = "HASHTYPE";
#[allow(dead_code, reason = "used by optional dependency")]
pub(crate) const INTERVAL_DTS_MAX_FRACTION: u32 = 3;
#[allow(dead_code, reason = "used by optional dependency")]
pub(crate) const INTERVAL_DTS_MAX_PRECISION: u32 = 9;
pub(crate) const INTERVAL_YTM_MAX_PRECISION: u32 = 9;
pub(crate) const VARCHAR_MAX_LEN: u32 = 2_000_000;
#[cfg_attr(not(test), expect(dead_code))]
pub(crate) const CHAR_MAX_LEN: u32 = 2_000;
#[cfg_attr(not(test), expect(dead_code))]
pub(crate) const HASHTYPE_MAX_LEN: u16 = 2048;
pub fn compatible(&self, other: &Self) -> bool {
match (self, other) {
(Self::HashType { size: Some(s1) }, Self::HashType { size: Some(s2) }) => s1 == s2,
(Self::Boolean, Self::Boolean)
| (
Self::Char { .. } | Self::Varchar { .. },
Self::Char { .. } | Self::Varchar { .. },
)
| (Self::Date, Self::Date)
| (Self::Double, Self::Double)
| (Self::Geometry { .. }, Self::Geometry { .. })
| (Self::IntervalDayToSecond { .. }, Self::IntervalDayToSecond { .. })
| (Self::IntervalYearToMonth { .. }, Self::IntervalYearToMonth { .. })
| (Self::Timestamp, Self::Timestamp)
| (Self::TimestampWithLocalTimeZone, Self::TimestampWithLocalTimeZone)
| (Self::HashType { .. }, Self::HashType { .. }) => true,
(Self::Decimal(d1), Self::Decimal(d2)) => d1.compatible(*d2),
_ => false,
}
}
fn full_name(&self) -> DataTypeName {
match self {
Self::Boolean => Self::BOOLEAN.into(),
Self::Date => Self::DATE.into(),
Self::Double => Self::DOUBLE.into(),
Self::Timestamp => Self::TIMESTAMP.into(),
Self::TimestampWithLocalTimeZone => Self::TIMESTAMP_WITH_LOCAL_TIME_ZONE.into(),
Self::Char {
size,
character_set,
}
| Self::Varchar {
size,
character_set,
} => format_args!("{}({}) {}", self.as_ref(), size, character_set).into(),
Self::Decimal(d) => match d.precision {
Some(p) => format_args!("{}({}, {})", self.as_ref(), p, d.scale).into(),
None => format_args!("{}(*, {})", self.as_ref(), d.scale).into(),
},
Self::Geometry { srid } => format_args!("{}({srid})", self.as_ref()).into(),
Self::IntervalDayToSecond {
precision,
fraction,
} => format_args!("INTERVAL DAY({precision}) TO SECOND({fraction})").into(),
Self::IntervalYearToMonth { precision } => {
format_args!("INTERVAL YEAR({precision}) TO MONTH").into()
}
Self::HashType { size } => match size {
Some(s) => format_args!("{}({} BYTE)", self.as_ref(), s / 2).into(),
None => format_args!("{}", self.as_ref()).into(),
},
}
}
}
/// Yields the base name of the data type, without any parameters.
impl AsRef<str> for ExaDataType {
    fn as_ref(&self) -> &str {
        match *self {
            // Parameterless types.
            Self::Boolean => Self::BOOLEAN,
            Self::Date => Self::DATE,
            Self::Double => Self::DOUBLE,
            Self::Timestamp => Self::TIMESTAMP,
            Self::TimestampWithLocalTimeZone => Self::TIMESTAMP_WITH_LOCAL_TIME_ZONE,
            // Parameterized types; the parameters play no role here.
            Self::Char { .. } => Self::CHAR,
            Self::Varchar { .. } => Self::VARCHAR,
            Self::Decimal(_) => Self::DECIMAL,
            Self::Geometry { .. } => Self::GEOMETRY,
            Self::IntervalDayToSecond { .. } => Self::INTERVAL_DAY_TO_SECOND,
            Self::IntervalYearToMonth { .. } => Self::INTERVAL_YEAR_TO_MONTH,
            Self::HashType { .. } => Self::HASHTYPE,
        }
    }
}
/// A rendered data type name that is `Copy` and needs no heap allocation.
#[derive(Clone, Copy, Debug)]
pub enum DataTypeName {
    /// A name known at compile time.
    Static(&'static str),
    /// A formatted name stored inline, limited to 30 bytes.
    Inline(ArrayString<30>),
}
impl AsRef<str> for DataTypeName {
fn as_ref(&self) -> &str {
match self {
DataTypeName::Static(s) => s,
DataTypeName::Inline(s) => s.as_str(),
}
}
}
/// Displays the name exactly as stored.
impl Display for DataTypeName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Written verbatim: formatter width/fill flags are not applied to the name.
        f.write_str(self.as_ref())
    }
}
impl From<&'static str> for DataTypeName {
fn from(value: &'static str) -> Self {
Self::Static(value)
}
}
impl From<Arguments<'_>> for DataTypeName {
fn from(value: Arguments<'_>) -> Self {
Self::Inline(ArrayString::try_from(value).expect("inline data type name too large"))
}
}
/// Parameters of a `DECIMAL` data type; (de)serialized with camelCase keys.
#[derive(Clone, Copy, Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Decimal {
    /// Total number of digits; `None` when no precision is specified
    /// (rendered as `*` in the type name).
    pub(crate) precision: Option<u8>,
    /// Number of digits after the decimal point.
    pub(crate) scale: u8,
}
impl Decimal {
    // Precision thresholds for integer widths and the overall precision/scale limits.
    pub(crate) const MAX_8BIT_PRECISION: u8 = 3;
    pub(crate) const MAX_16BIT_PRECISION: u8 = 5;
    pub(crate) const MAX_32BIT_PRECISION: u8 = 10;
    pub(crate) const MAX_64BIT_PRECISION: u8 = 20;
    pub(crate) const MAX_PRECISION: u8 = 36;
    #[allow(dead_code)]
    pub(crate) const MAX_SCALE: u8 = 36;

    /// Returns whether a value of decimal type `dec` fits into `self`.
    ///
    /// - If `dec` has no precision, the check is skipped and `true` is returned.
    /// - Otherwise `self` must have at least as many fractional digits (`scale`) and at
    ///   least as many integer digits (`precision - scale`) as `dec`.
    /// - When `self` has no precision, its integer digit count is taken to be
    ///   [`Decimal::MAX_PRECISION`].
    ///
    /// The digit differences use saturating subtraction, so a malformed decimal whose
    /// scale exceeds its precision counts as having zero integer digits instead of
    /// overflowing (a panic in debug builds, wrap-around in release builds).
    #[rustfmt::skip]
    fn compatible(self, dec: Decimal) -> bool {
        let Some(precision) = dec.precision else {
            return true;
        };
        let scale = dec.scale;

        let self_diff = self
            .precision
            .map_or(Decimal::MAX_PRECISION, |p| p.saturating_sub(self.scale));
        let other_diff = precision.saturating_sub(scale);

        self.scale >= scale && self_diff >= other_diff
    }
}
/// Character set of a `CHAR`/`VARCHAR` type; (de)serialized as `UTF8` / `ASCII`.
#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum Charset {
    /// Serialized and rendered as `UTF8`.
    Utf8,
    /// Serialized and rendered as `ASCII`.
    Ascii,
}
impl AsRef<str> for Charset {
fn as_ref(&self) -> &str {
match self {
Charset::Utf8 => "UTF8",
Charset::Ascii => "ASCII",
}
}
}
/// Displays the upper-case name of the character set.
impl Display for Charset {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Written verbatim: formatter width/fill flags are not applied to the name.
        f.write_str(self.as_ref())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders the full name of `data_type` into an owned string for comparison.
    ///
    /// Rendering the largest parameter values also checks that the name fits into the
    /// inline buffer, since the conversion panics otherwise.
    fn rendered(data_type: ExaDataType) -> String {
        let name = data_type.full_name();
        let text: &str = name.as_ref();
        text.to_owned()
    }

    #[test]
    fn test_boolean_name() {
        assert_eq!(rendered(ExaDataType::Boolean), "BOOLEAN");
    }

    #[test]
    fn test_max_char_name() {
        let size = ExaDataType::CHAR_MAX_LEN;
        let character_set = Charset::Ascii;
        let expected = format!("CHAR({size}) ASCII");
        assert_eq!(
            rendered(ExaDataType::Char {
                size,
                character_set
            }),
            expected
        );
    }

    #[test]
    fn test_date_name() {
        assert_eq!(rendered(ExaDataType::Date), "DATE");
    }

    #[test]
    fn test_max_decimal_name() {
        let (precision, scale) = (Decimal::MAX_PRECISION, Decimal::MAX_SCALE);
        let data_type = ExaDataType::Decimal(Decimal {
            precision: Some(precision),
            scale,
        });
        assert_eq!(rendered(data_type), format!("DECIMAL({precision}, {scale})"));
    }

    #[test]
    fn test_double_name() {
        assert_eq!(rendered(ExaDataType::Double), "DOUBLE PRECISION");
    }

    #[test]
    fn test_max_geometry_name() {
        let srid = u16::MAX;
        assert_eq!(
            rendered(ExaDataType::Geometry { srid }),
            format!("GEOMETRY({srid})")
        );
    }

    #[test]
    fn test_max_interval_day_name() {
        let precision = ExaDataType::INTERVAL_DTS_MAX_PRECISION;
        let fraction = ExaDataType::INTERVAL_DTS_MAX_FRACTION;
        let data_type = ExaDataType::IntervalDayToSecond {
            precision,
            fraction,
        };
        assert_eq!(
            rendered(data_type),
            format!("INTERVAL DAY({precision}) TO SECOND({fraction})")
        );
    }

    #[test]
    fn test_max_interval_year_name() {
        let precision = ExaDataType::INTERVAL_YTM_MAX_PRECISION;
        assert_eq!(
            rendered(ExaDataType::IntervalYearToMonth { precision }),
            format!("INTERVAL YEAR({precision}) TO MONTH")
        );
    }

    #[test]
    fn test_timestamp_name() {
        assert_eq!(rendered(ExaDataType::Timestamp), "TIMESTAMP");
    }

    #[test]
    fn test_timestamp_with_tz_name() {
        assert_eq!(
            rendered(ExaDataType::TimestampWithLocalTimeZone),
            "TIMESTAMP WITH LOCAL TIME ZONE"
        );
    }

    #[test]
    fn test_max_varchar_name() {
        let size = ExaDataType::VARCHAR_MAX_LEN;
        let character_set = Charset::Ascii;
        let expected = format!("VARCHAR({size}) ASCII");
        assert_eq!(
            rendered(ExaDataType::Varchar {
                size,
                character_set
            }),
            expected
        );
    }

    #[test]
    fn test_max_hashbyte_name() {
        let size = ExaDataType::HASHTYPE_MAX_LEN;
        // The rendered name shows half of the stored size.
        let expected = format!("HASHTYPE({} BYTE)", size / 2);
        assert_eq!(rendered(ExaDataType::HashType { size: Some(size) }), expected);
    }
}