elephantry_derive/
lib.rs

1#![warn(warnings)]
2
3mod composite;
4mod entity;
5mod r#enum;
6mod params;
7
8/**
9 * Impl [`FromSql`]/[`ToSql`] traits for [composite
10 * type](https://www.postgresql.org/docs/current/rowtypes.html).
11 *
12 * [`FromSql`]: trait.FromSql.html
13 * [`ToSql`]: trait.ToSql.html
14 */
15#[proc_macro_derive(Composite, attributes(elephantry))]
16pub fn composite_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
17    let ast = syn::parse(input).unwrap();
18
19    composite::impl_macro(&ast)
20        .unwrap_or_else(syn::Error::into_compile_error)
21        .into()
22}
23
24/**
25 * Impl [`Entity`] trait.
26 *
27 * [`Entity`]: trait.Entity.html
28 */
29#[proc_macro_derive(Entity, attributes(elephantry))]
30pub fn entity_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
31    let ast: syn::DeriveInput = syn::parse(input).unwrap();
32
33    entity::impl_macro(&ast)
34        .unwrap_or_else(syn::Error::into_compile_error)
35        .into()
36}
37
38/**
39 * Impl [`FromSql`]/[`ToSql`] traits for [enum
40 * type](https://www.postgresql.org/docs/current/datatype-enum.html).
41 *
42 * [`FromSql`]: trait.FromSql.html
43 * [`ToSql`]: trait.ToSql.html
44 */
45#[proc_macro_derive(Enum, attributes(elephantry))]
46pub fn enum_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
47    let ast = syn::parse(input).unwrap();
48
49    r#enum::impl_macro(&ast)
50        .unwrap_or_else(syn::Error::into_compile_error)
51        .into()
52}
53
54pub(crate) fn check_type(ty: &syn::Type) -> syn::Result<()> {
55    let features = [
56        #[cfg(feature = "bit")]
57        "bit",
58        #[cfg(feature = "chrono")]
59        "chrono",
60        #[cfg(feature = "date")]
61        "date",
62        #[cfg(feature = "geo")]
63        "geo",
64        #[cfg(feature = "ltree")]
65        "ltree",
66        #[cfg(feature = "jiff")]
67        "jiff",
68        #[cfg(feature = "json")]
69        "json",
70        #[cfg(feature = "multirange")]
71        "multirange",
72        #[cfg(feature = "net")]
73        "net",
74        #[cfg(feature = "numeric")]
75        "numeric",
76        #[cfg(feature = "time")]
77        "time",
78        #[cfg(feature = "uuid")]
79        "uuid",
80        #[cfg(feature = "xml")]
81        "xml",
82    ];
83
84    let types = [
85        ("bit", "bit_vec::BitVec"),
86        ("bit", "elephantry::Bits"),
87        ("bit", "u8"),
88        ("chrono", "chrono::DateTime"),
89        ("chrono", "chrono::NaiveDate"),
90        ("chrono", "chrono::NaiveDateTime"),
91        ("chrono", "chrono::NaiveTime"),
92        ("date", "elephantry::Date"),
93        ("date", "elephantry::TimestampTz"),
94        ("date", "elephantry::Timestamp"),
95        ("date", "elephantry::Interval"),
96        ("geo", "elephantry::Box"),
97        ("geo", "elephantry::Circle"),
98        ("geo", "elephantry::Line"),
99        ("geo", "elephantry::Path"),
100        ("geo", "elephantry::Point"),
101        ("geo", "elephantry::Polygon"),
102        ("geo", "elephantry::Segment"),
103        ("jiff", "jiff::civil::Date"),
104        ("jiff", "jiff::civil::DateTime"),
105        ("jiff", "jiff::civil::Time"),
106        ("jiff", "jiff::Zoned"),
107        ("json", "serde_json::value::Value"),
108        ("json", "elephantry::Json"),
109        ("ltree", "elephantry::Lquery"),
110        ("ltree", "elephantry::Ltree"),
111        ("ltree", "elephantry::Ltxtquery"),
112        ("multirange", "elephantry::Multirange"),
113        ("net", "ipnetwork::IpNetwork"),
114        ("net", "elephantry::Cidr"),
115        ("net", "macaddr::MacAddr6"),
116        ("net", "elephantry::MacAddr"),
117        ("net", "macaddr::MacAddr8"),
118        ("net", "elephantry::MacAddr8"),
119        ("net", "std::net::IpAddr"),
120        ("numeric", "bigdecimal::BigDecimal"),
121        ("numeric", "elephantry::Numeric"),
122        ("time", "elephantry::Time"),
123        ("time", "elephantry::TimeTz"),
124        ("uuid", "uuid::Uuid"),
125        ("uuid", "elephantry::Uuid"),
126        ("xml", "xmltree::Element"),
127        ("xml", "elephantry::Xml"),
128    ];
129
130    for (feature, feature_ty) in &types {
131        if !features.contains(feature) && ty == &syn::parse_str(feature_ty).unwrap() {
132            return error(
133                ty,
134                &format!(
135                    "Enable '{feature}' feature to use the type `{feature_ty}` in this entity"
136                ),
137            );
138        }
139    }
140
141    check_u8_array(ty)
142}
143
144#[cfg(not(feature = "bit"))]
145fn check_u8_array(ty: &syn::Type) -> syn::Result<()> {
146    if let syn::Type::Array(array) = ty
147        && array.elem == syn::parse_str("u8")?
148    {
149        return error(
150            ty,
151            "Enable 'bit' feature to use the type `[u8]` in this entity",
152        );
153    }
154
155    Ok(())
156}
157
/**
 * Variant compiled when the `bit` feature is enabled: accepts every type.
 */
#[cfg(feature = "bit")]
fn check_u8_array(_ty: &syn::Type) -> syn::Result<()> {
    Ok(())
}
162
163pub(crate) fn error<R>(ast: &dyn quote::ToTokens, message: &str) -> syn::Result<R> {
164    Err(syn::Error::new_spanned(ast, message))
165}
166
167pub(crate) fn elephantry() -> proc_macro2::TokenStream {
168    match (
169        proc_macro_crate::crate_name("elephantry"),
170        std::env::var("CARGO_CRATE_NAME").as_deref(),
171    ) {
172        (Ok(proc_macro_crate::FoundCrate::Itself), Ok("elephantry")) => quote::quote!(crate),
173        (Ok(proc_macro_crate::FoundCrate::Name(name)), _) => {
174            let ident = proc_macro2::Ident::new(&name, proc_macro2::Span::call_site());
175            quote::quote!(::#ident)
176        }
177        _ => quote::quote!(::elephantry),
178    }
179}