use proc_macro2::TokenStream;
use quote::{quote, quote_spanned};
use std::collections::HashSet;
use syn::{
braced,
parse::{Parse, ParseStream},
Error, File, FnArg, Ident, Item, LitBool, LitChar, Pat, ReturnType, Signature, Token,
TraitItem, Type,
};
/// Replacement value supplied for one `Dialect` method inside the
/// `overrides = { ... }` argument of the macro input.
pub(crate) enum Override {
/// A boolean literal (`true` or `false`).
Bool(LitBool),
/// A character literal; validation in this file only accepts it for the
/// `identifier_quote_style` key.
Char(LitChar),
/// The bare identifier `None`; validation in this file only accepts it for
/// the `identifier_quote_style` key.
None,
}
/// Parsed macro arguments of the form
/// `Name, BaseType [, preserve_type_id = <bool>] [, overrides = { key = value, ... }]`.
pub(crate) struct DeriveDialectInput {
/// Identifier of the struct that gets generated.
pub name: Ident,
/// Type of the wrapped `dialect` field; methods without an override are
/// forwarded to it.
pub base: Type,
/// When `true`, the generated `dialect()` method returns the wrapped
/// dialect's type id instead of the generated struct's own `TypeId`.
/// Defaults to `false` when the argument is absent.
pub preserve_type_id: bool,
/// `(method name, replacement value)` pairs in the order they were written.
pub overrides: Vec<(Ident, Override)>,
}
/// One method of the `Dialect` trait as read from the parsed source file.
struct DialectMethod {
/// Method identifier (a clone of `signature.ident`).
name: Ident,
/// Full signature copied from the trait item; re-emitted verbatim when the
/// method is delegated to the wrapped dialect.
signature: Signature,
}
impl Parse for DeriveDialectInput {
fn parse(input: ParseStream) -> syn::Result<Self> {
let name: Ident = input.parse()?;
input.parse::<Token![,]>()?;
let base: Type = input.parse()?;
let mut preserve_type_id = false;
let mut overrides = Vec::new();
while input.peek(Token![,]) {
input.parse::<Token![,]>()?;
if input.is_empty() {
break;
}
if input.peek(Ident) {
let ident: Ident = input.parse()?;
match ident.to_string().as_str() {
"preserve_type_id" => {
input.parse::<Token![=]>()?;
preserve_type_id = input.parse::<LitBool>()?.value();
}
"overrides" => {
input.parse::<Token![=]>()?;
let content;
braced!(content in input);
while !content.is_empty() {
let key: Ident = content.parse()?;
content.parse::<Token![=]>()?;
let value = if content.peek(LitBool) {
Override::Bool(content.parse()?)
} else if content.peek(LitChar) {
Override::Char(content.parse()?)
} else if content.peek(Ident) {
let ident: Ident = content.parse()?;
if ident == "None" {
Override::None
} else {
return Err(Error::new(
ident.span(),
format!("Expected `true`, `false`, a char, or `None`, found `{ident}`"),
));
}
} else {
return Err(
content.error("Expected `true`, `false`, a char, or `None`")
);
};
overrides.push((key, value));
if content.peek(Token![,]) {
content.parse::<Token![,]>()?;
}
}
}
other => {
return Err(Error::new(ident.span(), format!(
"Unknown argument `{other}`. Expected `preserve_type_id` or `overrides`."
)));
}
}
}
}
Ok(DeriveDialectInput {
name,
base,
preserve_type_id,
overrides,
})
}
}
pub(crate) fn derive_dialect(input: DeriveDialectInput) -> proc_macro::TokenStream {
let err = |msg: String| {
Error::new(proc_macro2::Span::call_site(), msg)
.to_compile_error()
.into()
};
let source = match read_dialect_mod_file() {
Ok(s) => s,
Err(e) => return err(format!("Failed to read dialect/mod.rs: {e}")),
};
let file: File = match syn::parse_str(&source) {
Ok(f) => f,
Err(e) => return err(format!("Failed to parse source: {e}")),
};
let methods = match extract_dialect_methods(&file) {
Ok(m) => m,
Err(e) => return e.to_compile_error().into(),
};
let bool_names: HashSet<_> = methods
.iter()
.filter(|m| is_bool_method(&m.signature))
.map(|m| m.name.to_string())
.collect();
for (key, value) in &input.overrides {
let key_str = key.to_string();
let err = |msg| Error::new(key.span(), msg).to_compile_error().into();
match value {
Override::Bool(_) if !bool_names.contains(&key_str) => {
return err(format!("Unknown boolean method `{key_str}`"));
}
Override::Char(_) | Override::None if key_str != "identifier_quote_style" => {
return err(format!(
"Char/None only valid for `identifier_quote_style`, not `{key_str}`"
));
}
_ => {}
}
}
generate_derived_dialect(&input, &methods).into()
}
fn generate_derived_dialect(input: &DeriveDialectInput, methods: &[DialectMethod]) -> TokenStream {
let name = &input.name;
let base = &input.base;
let find_override = |method_name: &str| {
input
.overrides
.iter()
.find(|(k, _)| k == method_name)
.map(|(_, v)| v)
};
let delegate = |method: &DialectMethod| {
let sig = &method.signature;
let method_name = &method.name;
let params = extract_param_names(sig);
quote_spanned! { method_name.span() => #sig { self.dialect.#method_name(#(#params),*) } }
};
let struct_def = quote_spanned! { name.span() =>
#[derive(Debug, Default)]
pub struct #name {
dialect: #base,
}
impl #name {
pub fn new() -> Self { Self::default() }
}
};
let type_id_body = if input.preserve_type_id {
quote! { Dialect::dialect(&self.dialect) }
} else {
quote! { ::core::any::TypeId::of::<#name>() }
};
let method_impls = methods.iter().map(|method| {
let method_name = &method.name;
match find_override(&method_name.to_string()) {
Some(Override::Bool(value)) => {
quote_spanned! { method_name.span() => fn #method_name(&self) -> bool { #value } }
}
Some(Override::Char(c)) => {
quote_spanned! { method_name.span() =>
fn identifier_quote_style(&self, _: &str) -> Option<char> { Some(#c) }
}
}
Some(Override::None) => {
quote_spanned! { method_name.span() =>
fn identifier_quote_style(&self, _: &str) -> Option<char> { None }
}
}
None => delegate(method),
}
});
quote! {
#struct_def
const _: () = {
use ::core::iter::Peekable;
use ::core::str::Chars;
use sqlparser::ast::{ColumnOption, Expr, GranteesType, Ident, ObjectNamePart, Statement};
use sqlparser::dialect::{Dialect, Precedence};
use sqlparser::keywords::Keyword;
use sqlparser::parser::{Parser, ParserError};
impl Dialect for #name {
fn dialect(&self) -> ::core::any::TypeId { #type_id_body }
#(#method_impls)*
}
};
}
}
/// Returns the identifiers of all typed parameters of `sig` whose pattern is a
/// plain identifier, in declaration order.
///
/// The receiver (`self`) and parameters bound with any other pattern kind are
/// skipped.
fn extract_param_names(sig: &Signature) -> Vec<&Ident> {
    let mut names = Vec::new();
    for arg in &sig.inputs {
        if let FnArg::Typed(typed) = arg {
            if let Pat::Ident(binding) = typed.pat.as_ref() {
                names.push(&binding.ident);
            }
        }
    }
    names
}
/// Reads `<CARGO_MANIFEST_DIR>/src/dialect/mod.rs` into a string.
///
/// # Errors
///
/// Returns a human-readable message when the `CARGO_MANIFEST_DIR` environment
/// variable cannot be read, or when the file cannot be read (the message then
/// includes the full path and the underlying I/O error).
fn read_dialect_mod_file() -> Result<String, String> {
    let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
        Ok(dir) => dir,
        Err(_) => return Err("CARGO_MANIFEST_DIR not set".into()),
    };

    let mut path = std::path::PathBuf::from(manifest_dir);
    path.push("src/dialect/mod.rs");

    match std::fs::read_to_string(&path) {
        Ok(contents) => Ok(contents),
        Err(e) => Err(format!("Failed to read {}: {e}", path.display())),
    }
}
fn extract_dialect_methods(file: &File) -> Result<Vec<DialectMethod>, Error> {
let dialect_trait = file
.items
.iter()
.find_map(|item| match item {
Item::Trait(t) if t.ident == "Dialect" => Some(t),
_ => None,
})
.ok_or_else(|| Error::new(proc_macro2::Span::call_site(), "Dialect trait not found"))?;
let mut methods: Vec<_> = dialect_trait
.items
.iter()
.filter_map(|item| match item {
TraitItem::Fn(m) if m.sig.ident != "dialect" => Some(DialectMethod {
name: m.sig.ident.clone(),
signature: m.sig.clone(),
}),
_ => None,
})
.collect();
methods.sort_by_key(|m| m.name.to_string());
Ok(methods)
}
/// Returns `true` when `sig` has exactly the shape `fn name(&self) -> bool`:
/// a single input that is a by-reference, non-`mut` receiver, and a return
/// type written as the plain path `bool`.
fn is_bool_method(sig: &Signature) -> bool {
    // Exactly one input, i.e. nothing besides the receiver.
    if sig.inputs.len() != 1 {
        return false;
    }

    // That input must be `&self` (a reference, without `mut`).
    let takes_shared_self = match sig.inputs.first() {
        Some(FnArg::Receiver(recv)) => recv.reference.is_some() && recv.mutability.is_none(),
        _ => false,
    };
    if !takes_shared_self {
        return false;
    }

    // The return type must be spelled as the bare identifier `bool`.
    match &sig.output {
        ReturnType::Type(_, ty) => match ty.as_ref() {
            Type::Path(type_path) => type_path.path.is_ident("bool"),
            _ => false,
        },
        _ => false,
    }
}