#![allow(clippy::expect_used)]
use crate::error::TypeError;
use crate::value::SqlValue;
/// Conversion of a Rust value into a [`SqlValue`], together with the name of
/// its SQL type.
pub trait ToSql {
/// Converts `self` into a [`SqlValue`].
///
/// # Errors
///
/// Returns a [`TypeError`] when the implementation cannot represent the
/// value; the exact conditions are decided by each implementation.
fn to_sql(&self) -> Result<SqlValue, TypeError>;
/// Returns the SQL type name associated with this value (for example `"INT"`).
fn sql_type(&self) -> &'static str;
/// Returns precision/scale metadata for decimal values.
///
/// The default implementation returns `None`; implementations that carry an
/// explicit precision and scale override it.
fn decimal_param_info(&self) -> Option<DecimalParamInfo> {
None
}
}
/// Precision and scale metadata reported by [`ToSql::decimal_param_info`].
///
/// Derives `PartialEq`/`Eq` so the metadata can be compared directly (for
/// example in `assert_eq!`), in addition to being `Debug`, `Clone` and `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DecimalParamInfo {
    /// Declared precision of the decimal value.
    pub precision: u8,
    /// Declared scale of the decimal value.
    pub scale: u8,
}
/// Maps `bool` to the `BIT` SQL type name.
impl ToSql for bool {
    /// Wraps the flag in [`SqlValue::Bool`]; this conversion never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let flag = *self;
        Ok(SqlValue::Bool(flag))
    }

    /// Always returns `"BIT"`.
    fn sql_type(&self) -> &'static str {
        "BIT"
    }
}
/// Maps `u8` to the `TINYINT` SQL type name.
impl ToSql for u8 {
    /// Wraps the byte in [`SqlValue::TinyInt`]; this conversion never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let byte = *self;
        Ok(SqlValue::TinyInt(byte))
    }

    /// Always returns `"TINYINT"`.
    fn sql_type(&self) -> &'static str {
        "TINYINT"
    }
}
/// Maps `i16` to the `SMALLINT` SQL type name.
impl ToSql for i16 {
    /// Wraps the integer in [`SqlValue::SmallInt`]; this conversion never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let number = *self;
        Ok(SqlValue::SmallInt(number))
    }

    /// Always returns `"SMALLINT"`.
    fn sql_type(&self) -> &'static str {
        "SMALLINT"
    }
}
/// Maps `i32` to the `INT` SQL type name.
impl ToSql for i32 {
    /// Wraps the integer in [`SqlValue::Int`]; this conversion never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let number = *self;
        Ok(SqlValue::Int(number))
    }

    /// Always returns `"INT"`.
    fn sql_type(&self) -> &'static str {
        "INT"
    }
}
/// Maps `i64` to the `BIGINT` SQL type name.
impl ToSql for i64 {
    /// Wraps the integer in [`SqlValue::BigInt`]; this conversion never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let number = *self;
        Ok(SqlValue::BigInt(number))
    }

    /// Always returns `"BIGINT"`.
    fn sql_type(&self) -> &'static str {
        "BIGINT"
    }
}
/// Maps `f32` to the `REAL` SQL type name.
impl ToSql for f32 {
    /// Wraps the float in [`SqlValue::Float`]; this conversion never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let single = *self;
        Ok(SqlValue::Float(single))
    }

    /// Always returns `"REAL"`.
    fn sql_type(&self) -> &'static str {
        "REAL"
    }
}
/// Maps `f64` to the `FLOAT` SQL type name.
impl ToSql for f64 {
    /// Wraps the float in [`SqlValue::Double`]; this conversion never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let double = *self;
        Ok(SqlValue::Double(double))
    }

    /// Always returns `"FLOAT"`.
    fn sql_type(&self) -> &'static str {
        "FLOAT"
    }
}
/// Maps string slices to the `NVARCHAR` SQL type name.
impl ToSql for str {
    /// Copies the text into an owned [`SqlValue::String`]; never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let owned = String::from(self);
        Ok(SqlValue::String(owned))
    }

    /// Always returns `"NVARCHAR"`.
    fn sql_type(&self) -> &'static str {
        "NVARCHAR"
    }
}
/// Maps owned strings to the `NVARCHAR` SQL type name.
impl ToSql for String {
    /// Clones the text into a [`SqlValue::String`]; never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let owned: String = self.clone();
        Ok(SqlValue::String(owned))
    }

    /// Always returns `"NVARCHAR"`.
    fn sql_type(&self) -> &'static str {
        "NVARCHAR"
    }
}
/// Maps byte slices to the `VARBINARY` SQL type name.
impl ToSql for [u8] {
    /// Copies the bytes into a [`SqlValue::Binary`]; never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let payload = bytes::Bytes::copy_from_slice(self);
        Ok(SqlValue::Binary(payload))
    }

    /// Always returns `"VARBINARY"`.
    fn sql_type(&self) -> &'static str {
        "VARBINARY"
    }
}
/// Maps owned byte vectors to the `VARBINARY` SQL type name.
impl ToSql for Vec<u8> {
    /// Copies the vector's bytes into a [`SqlValue::Binary`]; never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let payload = bytes::Bytes::copy_from_slice(self.as_slice());
        Ok(SqlValue::Binary(payload))
    }

    /// Always returns `"VARBINARY"`.
    fn sql_type(&self) -> &'static str {
        "VARBINARY"
    }
}
/// Compile-time association between a Rust type and a SQL type name.
///
/// Unlike a method on a value, the name is available without an instance,
/// which is what [`null`] uses to label a typed NULL.
pub trait SqlTyped {
    /// SQL type name for the implementing type.
    const SQL_TYPE: &'static str;
}

/// Generates `SqlTyped` impls from a `RustType => "SQL_NAME"` table.
macro_rules! impl_sql_typed {
    ($($rust_ty:ty => $sql_name:literal),+ $(,)?) => {
        $(
            impl SqlTyped for $rust_ty {
                const SQL_TYPE: &'static str = $sql_name;
            }
        )+
    };
}

impl_sql_typed! {
    bool => "BIT",
    u8 => "TINYINT",
    i16 => "SMALLINT",
    i32 => "INT",
    i64 => "BIGINT",
    f32 => "REAL",
    f64 => "FLOAT",
    String => "NVARCHAR",
    Vec<u8> => "VARBINARY",
}
/// Associates `uuid::Uuid` with the `UNIQUEIDENTIFIER` SQL type name.
/// Compiled only when the `uuid` feature is enabled.
#[cfg(feature = "uuid")]
impl SqlTyped for uuid::Uuid {
const SQL_TYPE: &'static str = "UNIQUEIDENTIFIER";
}
/// Associates `chrono::NaiveDate` with the `DATE` SQL type name.
/// Compiled only when the `chrono` feature is enabled.
#[cfg(feature = "chrono")]
impl SqlTyped for chrono::NaiveDate {
const SQL_TYPE: &'static str = "DATE";
}
/// A NULL value that still remembers a SQL type name.
///
/// Its `ToSql` implementation yields `SqlValue::Null` while reporting the
/// stored type name instead of a generic one.
#[derive(Clone, Copy, Debug)]
pub struct TypedNull {
    /// SQL type name reported for this NULL.
    sql_type: &'static str,
}
/// A `TypedNull` converts to NULL but reports the type name it was built with.
impl ToSql for TypedNull {
    /// Always yields [`SqlValue::Null`]; never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let value = SqlValue::Null;
        Ok(value)
    }

    /// Returns the SQL type name stored in this `TypedNull`.
    fn sql_type(&self) -> &'static str {
        let Self { sql_type } = *self;
        sql_type
    }
}
/// Creates a [`TypedNull`] whose SQL type name is `T::SQL_TYPE`.
///
/// For instance, `null::<i64>()` reports `"BIGINT"` from `sql_type()` while
/// still converting to `SqlValue::Null`.
#[must_use]
pub fn null<T: SqlTyped>() -> TypedNull {
    let sql_type = <T as SqlTyped>::SQL_TYPE;
    TypedNull { sql_type }
}
/// `Option<T>` forwards to the inner value when present and becomes NULL
/// otherwise.
impl<T: ToSql> ToSql for Option<T> {
    /// Converts the inner value, or yields [`SqlValue::Null`] for `None`.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the inner value's conversion returns.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        self.as_ref()
            .map_or_else(|| Ok(SqlValue::Null), ToSql::to_sql)
    }

    /// Returns the inner value's type name, or the literal `"NULL"` for `None`
    /// (no type information is available without a value).
    fn sql_type(&self) -> &'static str {
        self.as_ref().map_or("NULL", ToSql::sql_type)
    }

    /// Forwards to the inner value; `None` carries no decimal metadata.
    fn decimal_param_info(&self) -> Option<DecimalParamInfo> {
        self.as_ref()?.decimal_param_info()
    }
}
/// References delegate every method to the referenced value, so `&T`
/// (including `&str` and `&[u8]`) can be used wherever `T: ToSql` is accepted.
impl<T: ToSql + ?Sized> ToSql for &T {
    /// Delegates to `T::to_sql`.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the referenced value's conversion.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        <T as ToSql>::to_sql(*self)
    }

    /// Delegates to `T::sql_type`.
    fn sql_type(&self) -> &'static str {
        <T as ToSql>::sql_type(*self)
    }

    /// Delegates to `T::decimal_param_info`.
    fn decimal_param_info(&self) -> Option<DecimalParamInfo> {
        <T as ToSql>::decimal_param_info(*self)
    }
}
/// Maps `uuid::Uuid` to the `UNIQUEIDENTIFIER` SQL type name
/// (requires the `uuid` feature).
#[cfg(feature = "uuid")]
impl ToSql for uuid::Uuid {
    /// Wraps the identifier in [`SqlValue::Uuid`]; never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let id = *self;
        Ok(SqlValue::Uuid(id))
    }

    /// Always returns `"UNIQUEIDENTIFIER"`.
    fn sql_type(&self) -> &'static str {
        "UNIQUEIDENTIFIER"
    }
}
/// Maps `rust_decimal::Decimal` to the `DECIMAL` SQL type name
/// (requires the `decimal` feature).
///
/// This impl keeps the trait's default `decimal_param_info`, so no explicit
/// precision/scale is reported for a bare `Decimal`.
#[cfg(feature = "decimal")]
impl ToSql for rust_decimal::Decimal {
    /// Wraps the decimal unchanged in [`SqlValue::Decimal`]; never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let amount = *self;
        Ok(SqlValue::Decimal(amount))
    }

    /// Always returns `"DECIMAL"`.
    fn sql_type(&self) -> &'static str {
        "DECIMAL"
    }
}
/// A decimal value paired with an explicitly declared precision and scale
/// (requires the `decimal` feature).
///
/// Instances are built with [`numeric`]; the declared precision is enforced
/// when the value is converted via `ToSql::to_sql`.
#[cfg(feature = "decimal")]
#[derive(Clone, Copy, Debug)]
pub struct Numeric {
    /// The decimal value as supplied by the caller.
    value: rust_decimal::Decimal,
    /// Declared precision.
    precision: u8,
    /// Declared scale.
    scale: u8,
}
/// Bundles `value` with an explicit `precision` and `scale` into a [`Numeric`].
///
/// No validation happens here; the value is rescaled and checked against the
/// declared precision only when `to_sql` is called on the result.
#[cfg(feature = "decimal")]
#[must_use]
pub fn numeric(value: rust_decimal::Decimal, precision: u8, scale: u8) -> Numeric {
    Numeric { value, precision, scale }
}
/// Converts a [`Numeric`] by rescaling it to the declared scale and checking
/// it against the declared precision (requires the `decimal` feature).
#[cfg(feature = "decimal")]
impl ToSql for Numeric {
    /// Rescales the value to the declared scale (which may round it) and
    /// returns it as [`SqlValue::Decimal`].
    ///
    /// # Errors
    ///
    /// Returns [`TypeError::InvalidDecimal`] when, after rescaling, the
    /// mantissa has more decimal digits than the declared precision.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let declared_precision = u32::from(self.precision);

        let mut rescaled = self.value;
        rescaled.rescale(u32::from(self.scale));
        // NOTE(review): nothing here checks that `scale <= precision`, or that
        // `rescale` actually reached the requested scale — confirm whether
        // downstream encoding relies on either.

        // Digit count of the absolute mantissa; a zero mantissa counts as 0
        // digits so that zero fits any declared precision.
        let digits = rescaled
            .mantissa()
            .unsigned_abs()
            .checked_ilog10()
            .map_or(0, |exponent| exponent + 1);

        if digits <= declared_precision {
            return Ok(SqlValue::Decimal(rescaled));
        }
        Err(TypeError::InvalidDecimal(format!(
            "value has {digits} significant digits, which exceeds the declared precision {}",
            self.precision
        )))
    }

    /// Always returns `"DECIMAL"`.
    fn sql_type(&self) -> &'static str {
        "DECIMAL"
    }

    /// Reports the precision and scale this `Numeric` was declared with.
    fn decimal_param_info(&self) -> Option<DecimalParamInfo> {
        let Self {
            precision, scale, ..
        } = *self;
        Some(DecimalParamInfo { precision, scale })
    }
}
/// Maps `crate::value::Money` to the `MONEY` SQL type name
/// (requires the `decimal` feature).
#[cfg(feature = "decimal")]
impl ToSql for crate::value::Money {
    /// Passes the wrapper's first field to [`SqlValue::Money`]; never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let amount = self.0;
        Ok(SqlValue::Money(amount))
    }

    /// Always returns `"MONEY"`.
    fn sql_type(&self) -> &'static str {
        "MONEY"
    }
}
/// Maps `crate::value::SmallMoney` to the `SMALLMONEY` SQL type name
/// (requires the `decimal` feature).
#[cfg(feature = "decimal")]
impl ToSql for crate::value::SmallMoney {
    /// Passes the wrapper's first field to [`SqlValue::SmallMoney`]; never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let amount = self.0;
        Ok(SqlValue::SmallMoney(amount))
    }

    /// Always returns `"SMALLMONEY"`.
    fn sql_type(&self) -> &'static str {
        "SMALLMONEY"
    }
}
/// Maps `chrono::NaiveDate` to the `DATE` SQL type name
/// (requires the `chrono` feature).
#[cfg(feature = "chrono")]
impl ToSql for chrono::NaiveDate {
    /// Wraps the date in [`SqlValue::Date`]; never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let date = *self;
        Ok(SqlValue::Date(date))
    }

    /// Always returns `"DATE"`.
    fn sql_type(&self) -> &'static str {
        "DATE"
    }
}
/// Maps `chrono::NaiveTime` to the `TIME` SQL type name
/// (requires the `chrono` feature).
#[cfg(feature = "chrono")]
impl ToSql for chrono::NaiveTime {
    /// Wraps the time in [`SqlValue::Time`]; never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let time = *self;
        Ok(SqlValue::Time(time))
    }

    /// Always returns `"TIME"`.
    fn sql_type(&self) -> &'static str {
        "TIME"
    }
}
/// Maps `chrono::NaiveDateTime` to the `DATETIME2` SQL type name
/// (requires the `chrono` feature).
#[cfg(feature = "chrono")]
impl ToSql for chrono::NaiveDateTime {
    /// Wraps the timestamp in [`SqlValue::DateTime`]; never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let timestamp = *self;
        Ok(SqlValue::DateTime(timestamp))
    }

    /// Always returns `"DATETIME2"`.
    fn sql_type(&self) -> &'static str {
        "DATETIME2"
    }
}
/// Maps `crate::value::SmallDateTime` to the `SMALLDATETIME` SQL type name
/// (requires the `chrono` feature).
#[cfg(feature = "chrono")]
impl ToSql for crate::value::SmallDateTime {
    /// Passes the wrapper's first field to [`SqlValue::SmallDateTime`]; never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let timestamp = self.0;
        Ok(SqlValue::SmallDateTime(timestamp))
    }

    /// Always returns `"SMALLDATETIME"`.
    fn sql_type(&self) -> &'static str {
        "SMALLDATETIME"
    }
}
/// Maps `chrono::DateTime<chrono::FixedOffset>` to the `DATETIMEOFFSET` SQL
/// type name (requires the `chrono` feature).
#[cfg(feature = "chrono")]
impl ToSql for chrono::DateTime<chrono::FixedOffset> {
    /// Wraps the timestamp, offset included, in [`SqlValue::DateTimeOffset`];
    /// never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let timestamp = *self;
        Ok(SqlValue::DateTimeOffset(timestamp))
    }

    /// Always returns `"DATETIMEOFFSET"`.
    fn sql_type(&self) -> &'static str {
        "DATETIMEOFFSET"
    }
}
/// Maps `chrono::DateTime<chrono::Utc>` to the `DATETIMEOFFSET` SQL type name
/// (requires the `chrono` feature).
#[cfg(feature = "chrono")]
impl ToSql for chrono::DateTime<chrono::Utc> {
    /// Re-expresses the UTC timestamp with a fixed zero offset and wraps it in
    /// [`SqlValue::DateTimeOffset`].
    ///
    /// # Panics
    ///
    /// Only if chrono rejects a zero-second offset, which the `expect` treats
    /// as an invariant violation.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let zero_offset = chrono::FixedOffset::east_opt(0).expect("valid offset");
        let fixed = self.with_timezone(&zero_offset);
        Ok(SqlValue::DateTimeOffset(fixed))
    }

    /// Always returns `"DATETIMEOFFSET"`.
    fn sql_type(&self) -> &'static str {
        "DATETIMEOFFSET"
    }
}
/// Maps `serde_json::Value` to the `NVARCHAR(MAX)` SQL type name
/// (requires the `json` feature).
#[cfg(feature = "json")]
impl ToSql for serde_json::Value {
    /// Clones the JSON document into a [`SqlValue::Json`]; never fails.
    fn to_sql(&self) -> Result<SqlValue, TypeError> {
        let document = self.clone();
        Ok(SqlValue::Json(document))
    }

    /// Always returns `"NVARCHAR(MAX)"`.
    fn sql_type(&self) -> &'static str {
        "NVARCHAR(MAX)"
    }
}
/// Unit tests for the `ToSql` conversions defined in this module.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// `i32` reports `INT` and converts to `SqlValue::Int`.
    #[test]
    fn test_to_sql_i32() {
        let answer = 42_i32;
        assert_eq!(answer.sql_type(), "INT");
        assert_eq!(answer.to_sql().unwrap(), SqlValue::Int(42));
    }

    /// A typed NULL converts to `SqlValue::Null` but keeps its SQL type name.
    #[test]
    fn test_typed_null_carries_type() {
        let int_null = null::<i32>();
        assert_eq!(int_null.to_sql().unwrap(), SqlValue::Null);
        assert_eq!(int_null.sql_type(), 42i32.sql_type());

        assert_eq!(null::<i64>().sql_type(), "BIGINT");
        assert_eq!(null::<Vec<u8>>().sql_type(), "VARBINARY");
        assert_eq!(null::<String>().sql_type(), "NVARCHAR");
    }

    /// `String` reports `NVARCHAR` and converts to `SqlValue::String`.
    #[test]
    fn test_to_sql_string() {
        let greeting = String::from("hello");
        let expected = SqlValue::String(String::from("hello"));
        assert_eq!(greeting.to_sql().unwrap(), expected);
        assert_eq!(greeting.sql_type(), "NVARCHAR");
    }

    /// `Some` forwards to the inner value; `None` becomes `SqlValue::Null`.
    #[test]
    fn test_to_sql_option() {
        let present: Option<i32> = Some(42);
        let absent: Option<i32> = None;
        assert_eq!(present.to_sql().unwrap(), SqlValue::Int(42));
        assert_eq!(absent.to_sql().unwrap(), SqlValue::Null);
    }

    /// `Numeric` enforces the declared precision and rescales to the declared
    /// scale.
    #[cfg(feature = "decimal")]
    #[test]
    fn test_numeric_precision_validation() {
        use rust_decimal::Decimal;

        // Fits comfortably within DECIMAL(18, 4).
        let fits = numeric(Decimal::new(1_234_567, 2), 18, 4);
        assert!(fits.to_sql().is_ok());

        // Six digits cannot fit a declared precision of four.
        let too_wide = numeric(Decimal::new(123_456, 0), 4, 0);
        assert!(
            too_wide.to_sql().is_err(),
            "value exceeding the declared precision must error"
        );

        // 12.999 at scale 2 is rounded to 13.00.
        let rounded = numeric(Decimal::new(12_999, 3), 18, 2).to_sql().unwrap();
        assert_eq!(rounded, SqlValue::Decimal(Decimal::new(1_300, 2)));

        // Zero fits even the smallest precision.
        let zero = numeric(Decimal::ZERO, 1, 0);
        assert!(zero.to_sql().is_ok());
    }
}