ormlite_macro/
lib.rs

1#![allow(unused)]
2#![allow(non_snake_case)]
3
4use codegen::insert::impl_Insert;
5use convert_case::{Case, Casing};
6use ormlite_attr::InsertMeta;
7use proc_macro::TokenStream;
8use std::borrow::Borrow;
9use std::cell::OnceCell;
10use std::collections::HashMap;
11use std::env;
12use std::env::var;
13use std::ops::Deref;
14use std::sync::OnceLock;
15use syn::DataEnum;
16
17use quote::quote;
18use syn::{Data, DeriveInput, parse_macro_input};
19
20use codegen::into_arguments::impl_IntoArguments;
21use ormlite_attr::DeriveInputExt;
22use ormlite_attr::ModelMeta;
23use ormlite_attr::TableMeta;
24use ormlite_attr::schema_from_filepaths;
25use ormlite_core::config::get_var_model_folders;
26
27use crate::codegen::common::OrmliteCodegen;
28use crate::codegen::from_row::{impl_FromRow, impl_from_row_using_aliases};
29use crate::codegen::insert::impl_InsertModel;
30use crate::codegen::insert_model::struct_InsertModel;
31use crate::codegen::join_description::static_join_descriptions;
32use crate::codegen::meta::{impl_JoinMeta, impl_TableMeta};
33use crate::codegen::model::impl_Model;
34use crate::codegen::model_builder::{impl_ModelBuilder, struct_ModelBuilder};
35
36mod codegen;
37mod util;
38
39/// Mapping from StructName -> ModelMeta
40pub(crate) type MetadataCache = HashMap<String, ModelMeta>;
41
42static TABLES: OnceLock<MetadataCache> = OnceLock::new();
43
44fn get_tables() -> &'static MetadataCache {
45    TABLES.get_or_init(|| load_metadata_cache())
46}
47
48fn load_metadata_cache() -> MetadataCache {
49    let mut tables = HashMap::new();
50    let paths = get_var_model_folders();
51    let paths = paths.iter().map(|p| p.as_path()).collect::<Vec<_>>();
52    let schema = schema_from_filepaths(&paths).expect("Failed to preload models");
53    for meta in schema.tables {
54        let name = meta.ident.to_string();
55        tables.insert(name, meta);
56    }
57    tables
58}
59
60/// For a given struct, determine what codegen to use.
61fn get_databases(table_meta: &TableMeta) -> Vec<Box<dyn OrmliteCodegen>> {
62    let mut databases: Vec<Box<dyn OrmliteCodegen>> = Vec::new();
63    let dbs = table_meta.databases.clone();
64    if dbs.is_empty() {
65        #[cfg(feature = "default-sqlite")]
66        databases.push(Box::new(codegen::sqlite::SqliteBackend {}));
67        #[cfg(feature = "default-postgres")]
68        databases.push(Box::new(codegen::postgres::PostgresBackend));
69        #[cfg(feature = "default-mysql")]
70        databases.push(Box::new(codegen::mysql::MysqlBackend {}));
71    } else {
72        for db in dbs {
73            match db.as_str() {
74                #[cfg(feature = "sqlite")]
75                "sqlite" => databases.push(Box::new(codegen::sqlite::SqliteBackend {})),
76                #[cfg(feature = "postgres")]
77                "postgres" => databases.push(Box::new(codegen::postgres::PostgresBackend)),
78                #[cfg(feature = "mysql")]
79                "mysql" => databases.push(Box::new(codegen::mysql::MysqlBackend {})),
80                "sqlite" | "postgres" | "mysql" => {
81                    panic!("Database {} is not enabled. Enable it with features = [\"{}\"]", db, db)
82                }
83                _ => panic!("Unknown database: {}", db),
84            }
85        }
86    }
87    if databases.is_empty() {
88        let mut count = 0;
89        #[cfg(feature = "sqlite")]
90        {
91            count += 1;
92        }
93        #[cfg(feature = "postgres")]
94        {
95            count += 1;
96        }
97        #[cfg(feature = "mysql")]
98        {
99            count += 1;
100        }
101        if count > 1 {
102            panic!(
103                "You have more than one database configured using features, but no database is specified for this model. \
104            Specify a database for the model like this:\n\n#[ormlite(database = \"<db>\")]\n\nOr you can enable \
105            a default database feature:\n\n # Cargo.toml\normlite = {{ features = [\"default-<db>\"] }}"
106            );
107        }
108    }
109    if databases.is_empty() {
110        #[cfg(feature = "sqlite")]
111        databases.push(Box::new(codegen::sqlite::SqliteBackend {}));
112        #[cfg(feature = "postgres")]
113        databases.push(Box::new(codegen::postgres::PostgresBackend));
114        #[cfg(feature = "mysql")]
115        databases.push(Box::new(codegen::mysql::MysqlBackend {}));
116    }
117    if databases.is_empty() {
118        panic!(
119            r#"No database is enabled. Enable one of these features for the ormlite crate: postgres, mysql, sqlite"#
120        );
121    }
122    databases
123}
124
125/// Derive macro for `#[derive(Model)]` It additionally generates FromRow for the struct, since
126/// Model requires FromRow.
127#[proc_macro_derive(Model, attributes(ormlite))]
128pub fn expand_ormlite_model(input: TokenStream) -> TokenStream {
129    let ast = parse_macro_input!(input as DeriveInput);
130    let meta = ModelMeta::from_derive(&ast);
131    let mut databases = get_databases(&meta.table);
132    let tables = get_tables();
133    let first = databases.remove(0);
134
135    let primary = {
136        let db = first.as_ref();
137        let impl_TableMeta = impl_TableMeta(&meta.table, Some(meta.pkey.name.as_str()));
138        let impl_JoinMeta = impl_JoinMeta(&meta);
139        let static_join_descriptions = static_join_descriptions(&meta.table, &tables);
140        let impl_Model = impl_Model(db, &meta, tables);
141        let impl_FromRow = impl_FromRow(db, &meta.table, &tables);
142        let impl_from_row_using_aliases = impl_from_row_using_aliases(db, &meta.table, &tables);
143
144        let struct_ModelBuilder = struct_ModelBuilder(&ast, &meta);
145        let impl_ModelBuilder = impl_ModelBuilder(db, &meta);
146
147        let struct_InsertModel = struct_InsertModel(&ast, &meta);
148        let impl_InsertModel = impl_InsertModel(db, &meta);
149
150        quote! {
151            #impl_TableMeta
152            #impl_JoinMeta
153
154            #static_join_descriptions
155            #impl_Model
156            #impl_FromRow
157            #impl_from_row_using_aliases
158
159            #struct_ModelBuilder
160            #impl_ModelBuilder
161
162            #struct_InsertModel
163            #impl_InsertModel
164        }
165    };
166
167    let rest = databases.iter().map(|db| {
168        let impl_Model = impl_Model(db.as_ref(), &meta, tables);
169        quote! {
170            #impl_Model
171        }
172    });
173
174    TokenStream::from(quote! {
175        #primary
176        #(#rest)*
177    })
178}
179
180#[proc_macro_derive(Insert, attributes(ormlite))]
181pub fn expand_ormlite_insert(input: TokenStream) -> TokenStream {
182    let ast = parse_macro_input!(input as DeriveInput);
183    let mut meta = InsertMeta::from_derive(&ast);
184    let mut databases = get_databases(&meta.table);
185    let tables = get_tables();
186    if meta.name.is_none() {
187        if let Some(m) = tables.get(meta.returns.as_ref()) {
188            meta.table.name = m.name.clone();
189        }
190    }
191    let first = databases.remove(0);
192    TokenStream::from(impl_Insert(first.as_ref(), &meta.table, &meta.ident, &meta.returns))
193}
194
195#[proc_macro_derive(FromRow, attributes(ormlite))]
196pub fn expand_derive_fromrow(input: TokenStream) -> TokenStream {
197    let ast = parse_macro_input!(input as DeriveInput);
198    let meta = TableMeta::from_derive(&ast);
199
200    let databases = get_databases(&meta);
201    let tables = get_tables();
202
203    let expanded = databases.iter().map(|db| {
204        let db = db.as_ref();
205        let impl_FromRow = impl_FromRow(db, &meta, &tables);
206        let impl_from_row_using_aliases = impl_from_row_using_aliases(db, &meta, &tables);
207        quote! {
208            #impl_FromRow
209            #impl_from_row_using_aliases
210        }
211    });
212
213    TokenStream::from(quote! {
214        #(#expanded)*
215    })
216}
217
218#[proc_macro_derive(TableMeta, attributes(ormlite))]
219pub fn expand_derive_table_meta(input: TokenStream) -> TokenStream {
220    let ast = parse_macro_input!(input as DeriveInput);
221    let Data::Struct(data) = &ast.data else {
222        panic!("Only structs can derive Model");
223    };
224
225    let table_meta = TableMeta::from_derive(&ast);
226    let databases = get_databases(&table_meta);
227    let impl_TableMeta = impl_TableMeta(&table_meta, table_meta.pkey.as_deref());
228    TokenStream::from(impl_TableMeta)
229}
230
231#[proc_macro_derive(IntoArguments, attributes(ormlite))]
232pub fn expand_derive_into_arguments(input: TokenStream) -> TokenStream {
233    let ast = parse_macro_input!(input as DeriveInput);
234    let Data::Struct(data) = &ast.data else {
235        panic!("Only structs can derive Model");
236    };
237
238    let meta = TableMeta::from_derive(&ast);
239    let databases = get_databases(&meta);
240
241    let expanded = databases.iter().map(|db| {
242        let impl_IntoArguments = impl_IntoArguments(db.as_ref(), &meta);
243        impl_IntoArguments
244    });
245    TokenStream::from(quote! {
246        #(#expanded)*
247    })
248}
249
250/// This is a no-op marker trait that allows the migration tool to know when a user has
251/// manually implemented a type.
252///
253/// This is useful for having data that's a string in the database, but a strum::EnumString in code.
254#[proc_macro_derive(ManualType)]
255pub fn expand_derive_manual_type(input: TokenStream) -> TokenStream {
256    TokenStream::new()
257}
258
259#[proc_macro_derive(Enum)]
260pub fn derive_ormlite_enum(input: TokenStream) -> TokenStream {
261    let input = parse_macro_input!(input as DeriveInput);
262
263    let enum_name = input.ident;
264
265    let variants = match input.data {
266        Data::Enum(DataEnum { variants, .. }) => variants,
267        _ => panic!("#[derive(OrmliteEnum)] is only supported on enums"),
268    };
269
270    // Collect variant names and strings into vectors
271    let variant_names: Vec<_> = variants.iter().map(|v| &v.ident).collect();
272    let variant_strings: Vec<_> = variant_names
273        .iter()
274        .map(|v| v.to_string().to_case(Case::Snake))
275        .collect();
276
277    let placeholder = quote! {
278        impl std::fmt::Display for #enum_name {
279            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280                match self {
281                    #(Self::#variant_names => write!(f, "{}", #variant_strings)),*
282                }
283            }
284        }
285
286        impl std::str::FromStr for #enum_name {
287            type Err = String;
288            fn from_str(s: &str) -> Result<Self, <Self as std::str::FromStr>::Err> {
289                match s {
290                    #(#variant_strings => Ok(Self::#variant_names)),*,
291                    _ => Err(format!("Invalid {} value: {}", stringify!(#enum_name), s))
292                }
293            }
294        }
295
296        impl std::convert::TryFrom<&str> for #enum_name {
297            type Error = String;
298            fn try_from(value: &str) -> Result<Self, Self::Error> {
299                <Self as std::str::FromStr>::from_str(value)
300            }
301        }
302
303        impl sqlx::Decode<'_, sqlx::Postgres> for #enum_name {
304            fn decode(
305                value: sqlx::postgres::PgValueRef<'_>,
306            ) -> Result<Self, sqlx::error::BoxDynError> {
307                let s = value.as_str()?;
308                <Self as std::str::FromStr>::from_str(s).map_err(|e| sqlx::error::BoxDynError::from(
309                    std::io::Error::new(std::io::ErrorKind::InvalidData, e)
310                ))
311            }
312        }
313
314        impl sqlx::Encode<'_, sqlx::Postgres> for #enum_name {
315            fn encode_by_ref(
316                &self,
317                buf: &mut sqlx::postgres::PgArgumentBuffer
318            ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
319                let s = self.to_string();
320                <String as sqlx::Encode<sqlx::Postgres>>::encode(s, buf)
321            }
322        }
323
324        impl sqlx::Type<sqlx::Postgres> for #enum_name {
325            fn type_info() -> <sqlx::Postgres as sqlx::Database>::TypeInfo {
326                sqlx::postgres::PgTypeInfo::with_name("VARCHAR")
327            }
328
329            fn compatible(ty: &<sqlx::Postgres as sqlx::Database>::TypeInfo) -> bool {
330                ty.to_string() == "VARCHAR"
331            }
332        }
333    };
334
335    placeholder.into()
336}