use proc_macro2::TokenStream;
use quote::quote;
use syn::{GenericArgument, PathArguments, Type};
pub fn infer_sql_type(ty: &Type) -> TokenStream {
let inner_ty = unwrap_option_type(ty);
let type_str = type_to_string(inner_ty);
match type_str.as_str() {
"bool" => quote! { sqlmodel_core::SqlType::Boolean },
"i8" => quote! { sqlmodel_core::SqlType::TinyInt },
"i16" => quote! { sqlmodel_core::SqlType::SmallInt },
"i32" => quote! { sqlmodel_core::SqlType::Integer },
"i64" => quote! { sqlmodel_core::SqlType::BigInt },
"u8" => quote! { sqlmodel_core::SqlType::SmallInt },
"u16" => quote! { sqlmodel_core::SqlType::Integer },
"u32" => quote! { sqlmodel_core::SqlType::BigInt },
"u64" => quote! { sqlmodel_core::SqlType::BigInt },
"f32" => quote! { sqlmodel_core::SqlType::Real },
"f64" => quote! { sqlmodel_core::SqlType::Double },
"String" | "&str" | "str" => quote! { sqlmodel_core::SqlType::Text },
"char" => quote! { sqlmodel_core::SqlType::Char(1) },
"Vec<u8>" | "&[u8]" | "[u8]" => quote! { sqlmodel_core::SqlType::Blob },
"Uuid" | "uuid::Uuid" => quote! { sqlmodel_core::SqlType::Uuid },
"NaiveDate" | "chrono::NaiveDate" => quote! { sqlmodel_core::SqlType::Date },
"NaiveTime" | "chrono::NaiveTime" => quote! { sqlmodel_core::SqlType::Time },
"NaiveDateTime" | "chrono::NaiveDateTime" => quote! { sqlmodel_core::SqlType::DateTime },
"DateTime<Utc>" | "chrono::DateTime<Utc>" | "DateTime<chrono::Utc>" => {
quote! { sqlmodel_core::SqlType::TimestampTz }
}
"DateTime<Local>" | "chrono::DateTime<Local>" | "DateTime<chrono::Local>" => {
quote! { sqlmodel_core::SqlType::TimestampTz }
}
"time::Date" => quote! { sqlmodel_core::SqlType::Date },
"time::Time" => quote! { sqlmodel_core::SqlType::Time },
"time::PrimitiveDateTime" | "PrimitiveDateTime" => {
quote! { sqlmodel_core::SqlType::DateTime }
}
"time::OffsetDateTime" | "OffsetDateTime" => {
quote! { sqlmodel_core::SqlType::TimestampTz }
}
"serde_json::Value" | "Value" => quote! { sqlmodel_core::SqlType::Json },
"rust_decimal::Decimal" | "Decimal" => {
quote! { sqlmodel_core::SqlType::Numeric { precision: 38, scale: 18 } }
}
"bytes::Bytes" | "Bytes" | "bytes::BytesMut" | "BytesMut" => {
quote! { sqlmodel_core::SqlType::Blob }
}
_ => quote! { sqlmodel_core::SqlType::Text },
}
}
pub fn parse_sql_type_attr(sql_type: &str) -> TokenStream {
let sql_type_upper = sql_type.to_uppercase();
let trimmed = sql_type_upper.trim();
if let Some(rest) = trimmed.strip_prefix("VARCHAR(") {
if let Some(len_str) = rest.strip_suffix(')') {
if let Ok(len) = len_str.trim().parse::<u32>() {
return quote! { sqlmodel_core::SqlType::VarChar(#len) };
}
}
}
if let Some(rest) = trimmed.strip_prefix("CHAR(") {
if let Some(len_str) = rest.strip_suffix(')') {
if let Ok(len) = len_str.trim().parse::<u32>() {
return quote! { sqlmodel_core::SqlType::Char(#len) };
}
}
}
if let Some(rest) = trimmed.strip_prefix("NUMERIC(") {
if let Some(params_str) = rest.strip_suffix(')') {
if let Some((p_str, s_str)) = params_str.split_once(',') {
if let (Ok(p), Ok(s)) = (p_str.trim().parse::<u8>(), s_str.trim().parse::<u8>()) {
return quote! { sqlmodel_core::SqlType::Numeric { precision: #p, scale: #s } };
}
}
}
}
if let Some(rest) = trimmed.strip_prefix("DECIMAL(") {
if let Some(params_str) = rest.strip_suffix(')') {
if let Some((p_str, s_str)) = params_str.split_once(',') {
if let (Ok(p), Ok(s)) = (p_str.trim().parse::<u8>(), s_str.trim().parse::<u8>()) {
return quote! { sqlmodel_core::SqlType::Decimal { precision: #p, scale: #s } };
}
}
}
}
if let Some(rest) = trimmed.strip_prefix("BINARY(") {
if let Some(len_str) = rest.strip_suffix(')') {
if let Ok(len) = len_str.trim().parse::<u32>() {
return quote! { sqlmodel_core::SqlType::Binary(#len) };
}
}
}
if let Some(rest) = trimmed.strip_prefix("VARBINARY(") {
if let Some(len_str) = rest.strip_suffix(')') {
if let Ok(len) = len_str.trim().parse::<u32>() {
return quote! { sqlmodel_core::SqlType::VarBinary(#len) };
}
}
}
match trimmed {
"TINYINT" => quote! { sqlmodel_core::SqlType::TinyInt },
"SMALLINT" | "INT2" => quote! { sqlmodel_core::SqlType::SmallInt },
"INTEGER" | "INT" | "INT4" => quote! { sqlmodel_core::SqlType::Integer },
"BIGINT" | "INT8" => quote! { sqlmodel_core::SqlType::BigInt },
"REAL" | "FLOAT4" => quote! { sqlmodel_core::SqlType::Real },
"DOUBLE" | "DOUBLE PRECISION" | "FLOAT8" | "FLOAT" => {
quote! { sqlmodel_core::SqlType::Double }
}
"NUMERIC" => quote! { sqlmodel_core::SqlType::Numeric { precision: 38, scale: 18 } },
"DECIMAL" => quote! { sqlmodel_core::SqlType::Decimal { precision: 38, scale: 18 } },
"BOOLEAN" | "BOOL" => quote! { sqlmodel_core::SqlType::Boolean },
"TEXT" => quote! { sqlmodel_core::SqlType::Text },
"VARCHAR" => quote! { sqlmodel_core::SqlType::VarChar(255) }, "CHAR" => quote! { sqlmodel_core::SqlType::Char(1) },
"BLOB" | "BYTEA" => quote! { sqlmodel_core::SqlType::Blob },
"BINARY" => quote! { sqlmodel_core::SqlType::Binary(255) },
"VARBINARY" => quote! { sqlmodel_core::SqlType::VarBinary(255) },
"DATE" => quote! { sqlmodel_core::SqlType::Date },
"TIME" => quote! { sqlmodel_core::SqlType::Time },
"DATETIME" => quote! { sqlmodel_core::SqlType::DateTime },
"TIMESTAMP" => quote! { sqlmodel_core::SqlType::Timestamp },
"TIMESTAMPTZ" | "TIMESTAMP WITH TIME ZONE" => {
quote! { sqlmodel_core::SqlType::TimestampTz }
}
"UUID" => quote! { sqlmodel_core::SqlType::Uuid },
"JSON" => quote! { sqlmodel_core::SqlType::Json },
"JSONB" => quote! { sqlmodel_core::SqlType::JsonB },
_ => {
let custom = sql_type; quote! { sqlmodel_core::SqlType::Custom(#custom) }
}
}
}
fn unwrap_option_type(ty: &Type) -> &Type {
if let Type::Path(type_path) = ty {
if let Some(segment) = type_path.path.segments.last() {
if segment.ident == "Option" {
if let PathArguments::AngleBracketed(args) = &segment.arguments {
if let Some(GenericArgument::Type(inner)) = args.args.first() {
return inner;
}
}
}
}
}
ty
}
fn type_to_string(ty: &Type) -> String {
use quote::ToTokens;
ty.to_token_stream().to_string().replace(' ', "")
}
#[cfg(test)]
mod tests {
use super::*;
use syn::parse_quote;
#[test]
fn test_infer_primitives() {
let ty: Type = parse_quote!(i32);
let result = infer_sql_type(&ty).to_string();
assert!(result.contains("Integer"));
let ty: Type = parse_quote!(i64);
let result = infer_sql_type(&ty).to_string();
assert!(result.contains("BigInt"));
let ty: Type = parse_quote!(bool);
let result = infer_sql_type(&ty).to_string();
assert!(result.contains("Boolean"));
}
#[test]
fn test_infer_string() {
let ty: Type = parse_quote!(String);
let result = infer_sql_type(&ty).to_string();
assert!(result.contains("Text"));
}
#[test]
fn test_infer_option() {
let ty: Type = parse_quote!(Option<i32>);
let result = infer_sql_type(&ty).to_string();
assert!(result.contains("Integer"));
}
#[test]
fn test_parse_sql_type_varchar() {
let result = parse_sql_type_attr("VARCHAR(100)").to_string();
assert!(result.contains("VarChar"));
assert!(result.contains("100"));
}
#[test]
fn test_parse_sql_type_numeric() {
let result = parse_sql_type_attr("NUMERIC(10, 2)").to_string();
assert!(result.contains("Numeric"));
assert!(result.contains("10"));
assert!(result.contains('2'));
}
}