Skip to main content

ormlite_macro/
lib.rs

1#![allow(unused)]
2#![allow(non_snake_case)]
3
4use codegen::insert::impl_Insert;
5use convert_case::{Case, Casing};
6use ormlite_attr::InsertMeta;
7use proc_macro::TokenStream;
8use std::borrow::Borrow;
9use std::cell::OnceCell;
10use std::collections::HashMap;
11use std::env;
12use std::env::var;
13use std::ops::Deref;
14use std::path::PathBuf;
15use std::str::FromStr;
16use std::sync::OnceLock;
17use syn::DataEnum;
18
19use quote::quote;
20use syn::{Data, DeriveInput, parse_macro_input};
21
22use codegen::into_arguments::impl_IntoArguments;
23use ormlite_attr::DeriveInputExt;
24use ormlite_attr::ModelMeta;
25use ormlite_attr::TableMeta;
26use ormlite_attr::schema_from_filepaths;
27
28use crate::codegen::common::OrmliteCodegen;
29use crate::codegen::from_row::{impl_FromRow, impl_from_row_using_aliases};
30use crate::codegen::insert::impl_InsertModel;
31use crate::codegen::insert_model::struct_InsertModel;
32use crate::codegen::join_description::static_join_descriptions;
33use crate::codegen::meta::{impl_JoinMeta, impl_TableMeta};
34use crate::codegen::model::impl_Model;
35use crate::codegen::model_builder::{impl_ModelBuilder, struct_ModelBuilder};
36
37mod codegen;
38mod placeholder;
39mod util;
40
/// Model metadata keyed by struct name (the struct's identifier rendered with `to_string`),
/// i.e. a mapping from StructName -> ModelMeta.
pub(crate) type MetadataCache = HashMap<String, ModelMeta>;

/// Lazily-initialised cache of every model found in the configured model folders.
/// It is filled on first access through `get_tables` and is read-only afterwards
/// (a `OnceLock` only hands out shared references once set).
static TABLES: OnceLock<MetadataCache> = OnceLock::new();
45
/// Default for the `MODEL_FOLDERS` environment variable: the directory `"."`.
const MODEL_FOLDERS: &str = ".";

/// Folders scanned for model definitions when building the metadata cache.
///
/// Reads the comma-separated `MODEL_FOLDERS` environment variable. It falls back to
/// [`MODEL_FOLDERS`] in two cases:
/// - the variable is unset or not valid Unicode;
/// - the variable contains no non-empty entries.
///
/// Surrounding whitespace is trimmed from each entry, and empty entries (e.g. from a trailing
/// comma) are skipped, so `MODEL_FOLDERS="src, models,"` yields `["src", "models"]`.
fn get_var_model_folders() -> Vec<PathBuf> {
    let folders = var("MODEL_FOLDERS").unwrap_or_else(|_| MODEL_FOLDERS.to_string());
    let parsed = parse_model_folders(&folders);
    if parsed.is_empty() {
        // An empty/blank value behaves like an unset variable instead of scanning "".
        parse_model_folders(MODEL_FOLDERS)
    } else {
        parsed
    }
}

/// Split a comma-separated folder list into paths, trimming whitespace around each entry
/// and dropping entries that are empty after trimming.
fn parse_model_folders(folders: &str) -> Vec<PathBuf> {
    folders
        .split(',')
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
        .map(PathBuf::from)
        .collect()
}
52
53fn get_tables() -> &'static MetadataCache {
54    TABLES.get_or_init(|| load_metadata_cache())
55}
56
57fn load_metadata_cache() -> MetadataCache {
58    let mut tables = HashMap::new();
59    let paths = get_var_model_folders();
60    let paths = paths.iter().map(|p| p.as_path()).collect::<Vec<_>>();
61    let schema = schema_from_filepaths(&paths).expect("Failed to preload models");
62    for meta in schema.tables {
63        let name = meta.ident.to_string();
64        tables.insert(name, meta);
65    }
66    tables
67}
68
69/// For a given struct, determine what codegen to use.
70fn get_databases(table_meta: &TableMeta) -> Vec<Box<dyn OrmliteCodegen>> {
71    let mut databases: Vec<Box<dyn OrmliteCodegen>> = Vec::new();
72    let dbs = table_meta.databases.clone();
73    if dbs.is_empty() {
74        #[cfg(feature = "default-sqlite")]
75        databases.push(Box::new(codegen::sqlite::SqliteBackend {}));
76        #[cfg(feature = "default-postgres")]
77        databases.push(Box::new(codegen::postgres::PostgresBackend));
78        #[cfg(feature = "default-mysql")]
79        databases.push(Box::new(codegen::mysql::MysqlBackend {}));
80    } else {
81        for db in dbs {
82            match db.as_str() {
83                #[cfg(feature = "sqlite")]
84                "sqlite" => databases.push(Box::new(codegen::sqlite::SqliteBackend {})),
85                #[cfg(feature = "postgres")]
86                "postgres" => databases.push(Box::new(codegen::postgres::PostgresBackend)),
87                #[cfg(feature = "mysql")]
88                "mysql" => databases.push(Box::new(codegen::mysql::MysqlBackend {})),
89                "sqlite" | "postgres" | "mysql" => {
90                    panic!("Database {} is not enabled. Enable it with features = [\"{}\"]", db, db)
91                }
92                _ => panic!("Unknown database: {}", db),
93            }
94        }
95    }
96    if databases.is_empty() {
97        let mut count = 0;
98        #[cfg(feature = "sqlite")]
99        {
100            count += 1;
101        }
102        #[cfg(feature = "postgres")]
103        {
104            count += 1;
105        }
106        #[cfg(feature = "mysql")]
107        {
108            count += 1;
109        }
110        if count > 1 {
111            panic!(
112                "You have more than one database configured using features, but no database is specified for this model. \
113            Specify a database for the model like this:\n\n#[ormlite(database = \"<db>\")]\n\nOr you can enable \
114            a default database feature:\n\n # Cargo.toml\normlite = {{ features = [\"default-<db>\"] }}"
115            );
116        }
117    }
118    if databases.is_empty() {
119        #[cfg(feature = "sqlite")]
120        databases.push(Box::new(codegen::sqlite::SqliteBackend {}));
121        #[cfg(feature = "postgres")]
122        databases.push(Box::new(codegen::postgres::PostgresBackend));
123        #[cfg(feature = "mysql")]
124        databases.push(Box::new(codegen::mysql::MysqlBackend {}));
125    }
126    if databases.is_empty() {
127        panic!(
128            r#"No database is enabled. Enable one of these features for the ormlite crate: postgres, mysql, sqlite"#
129        );
130    }
131    databases
132}
133
134/// Derive macro for `#[derive(Model)]` It additionally generates FromRow for the struct, since
135/// Model requires FromRow.
136#[proc_macro_derive(Model, attributes(ormlite))]
137pub fn expand_ormlite_model(input: TokenStream) -> TokenStream {
138    let ast = parse_macro_input!(input as DeriveInput);
139    let meta = ModelMeta::from_derive(&ast);
140    let mut databases = get_databases(&meta.table);
141    let tables = get_tables();
142    let first = databases.remove(0);
143
144    let primary = {
145        let db = first.as_ref();
146        let impl_TableMeta = impl_TableMeta(&meta.table, Some(meta.pkey.name.as_str()));
147        let impl_JoinMeta = impl_JoinMeta(&meta);
148        let static_join_descriptions = static_join_descriptions(&meta.table, &tables);
149        let impl_Model = impl_Model(db, &meta, tables);
150        let impl_FromRow = impl_FromRow(db, &meta.table, &tables);
151        let impl_from_row_using_aliases = impl_from_row_using_aliases(db, &meta.table, &tables);
152
153        let struct_ModelBuilder = struct_ModelBuilder(&ast, &meta);
154        let impl_ModelBuilder = impl_ModelBuilder(db, &meta);
155
156        let struct_InsertModel = struct_InsertModel(&ast, &meta);
157        let impl_InsertModel = impl_InsertModel(db, &meta);
158
159        quote! {
160            #impl_TableMeta
161            #impl_JoinMeta
162
163            #static_join_descriptions
164            #impl_Model
165            #impl_FromRow
166            #impl_from_row_using_aliases
167
168            #struct_ModelBuilder
169            #impl_ModelBuilder
170
171            #struct_InsertModel
172            #impl_InsertModel
173        }
174    };
175
176    let rest = databases.iter().map(|db| {
177        let impl_Model = impl_Model(db.as_ref(), &meta, tables);
178        quote! {
179            #impl_Model
180        }
181    });
182
183    TokenStream::from(quote! {
184        #primary
185        #(#rest)*
186    })
187}
188
189#[proc_macro_derive(Insert, attributes(ormlite))]
190pub fn expand_ormlite_insert(input: TokenStream) -> TokenStream {
191    let ast = parse_macro_input!(input as DeriveInput);
192    let mut meta = InsertMeta::from_derive(&ast);
193    let mut databases = get_databases(&meta.table);
194    let tables = get_tables();
195    if meta.name.is_none() {
196        if let Some(m) = tables.get(meta.returns.as_ref()) {
197            meta.table.name = m.name.clone();
198        }
199    }
200    let first = databases.remove(0);
201    TokenStream::from(impl_Insert(first.as_ref(), &meta.table, &meta.ident, &meta.returns))
202}
203
204#[proc_macro_derive(FromRow, attributes(ormlite))]
205pub fn expand_derive_fromrow(input: TokenStream) -> TokenStream {
206    let ast = parse_macro_input!(input as DeriveInput);
207    let meta = TableMeta::from_derive(&ast);
208
209    let databases = get_databases(&meta);
210    let tables = get_tables();
211
212    let expanded = databases.iter().map(|db| {
213        let db = db.as_ref();
214        let impl_FromRow = impl_FromRow(db, &meta, &tables);
215        let impl_from_row_using_aliases = impl_from_row_using_aliases(db, &meta, &tables);
216        quote! {
217            #impl_FromRow
218            #impl_from_row_using_aliases
219        }
220    });
221
222    TokenStream::from(quote! {
223        #(#expanded)*
224    })
225}
226
227#[proc_macro_derive(TableMeta, attributes(ormlite))]
228pub fn expand_derive_table_meta(input: TokenStream) -> TokenStream {
229    let ast = parse_macro_input!(input as DeriveInput);
230    let Data::Struct(data) = &ast.data else {
231        panic!("Only structs can derive Model");
232    };
233
234    let table_meta = TableMeta::from_derive(&ast);
235    let databases = get_databases(&table_meta);
236    let impl_TableMeta = impl_TableMeta(&table_meta, table_meta.pkey.as_deref());
237    TokenStream::from(impl_TableMeta)
238}
239
240#[proc_macro_derive(IntoArguments, attributes(ormlite))]
241pub fn expand_derive_into_arguments(input: TokenStream) -> TokenStream {
242    let ast = parse_macro_input!(input as DeriveInput);
243    let Data::Struct(data) = &ast.data else {
244        panic!("Only structs can derive Model");
245    };
246
247    let meta = TableMeta::from_derive(&ast);
248    let databases = get_databases(&meta);
249
250    let expanded = databases.iter().map(|db| {
251        let impl_IntoArguments = impl_IntoArguments(db.as_ref(), &meta);
252        impl_IntoArguments
253    });
254    TokenStream::from(quote! {
255        #(#expanded)*
256    })
257}
258
259/// This is a no-op marker trait that allows the migration tool to know when a user has
260/// manually implemented a type.
261///
262/// This is useful for having data that's a string in the database, but a strum::EnumString in code.
263#[proc_macro_derive(ManualType)]
264pub fn expand_derive_manual_type(input: TokenStream) -> TokenStream {
265    TokenStream::new()
266}
267
268#[proc_macro_derive(Enum)]
269pub fn derive_ormlite_enum(input: TokenStream) -> TokenStream {
270    let input = parse_macro_input!(input as DeriveInput);
271
272    let enum_name = input.ident;
273
274    let variants = match input.data {
275        Data::Enum(DataEnum { variants, .. }) => variants,
276        _ => panic!("#[derive(OrmliteEnum)] is only supported on enums"),
277    };
278
279    // Collect variant names and strings into vectors
280    let variant_names: Vec<_> = variants.iter().map(|v| &v.ident).collect();
281    let variant_strings: Vec<_> = variant_names
282        .iter()
283        .map(|v| v.to_string().to_case(Case::Snake))
284        .collect();
285
286    let placeholder = quote! {
287        impl std::fmt::Display for #enum_name {
288            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289                match self {
290                    #(Self::#variant_names => write!(f, "{}", #variant_strings)),*
291                }
292            }
293        }
294
295        impl std::str::FromStr for #enum_name {
296            type Err = String;
297            fn from_str(s: &str) -> Result<Self, <Self as std::str::FromStr>::Err> {
298                match s {
299                    #(#variant_strings => Ok(Self::#variant_names)),*,
300                    _ => Err(format!("Invalid {} value: {}", stringify!(#enum_name), s))
301                }
302            }
303        }
304
305        impl std::convert::TryFrom<&str> for #enum_name {
306            type Error = String;
307            fn try_from(value: &str) -> Result<Self, Self::Error> {
308                <Self as std::str::FromStr>::from_str(value)
309            }
310        }
311
312        impl sqlx::Decode<'_, sqlx::Postgres> for #enum_name {
313            fn decode(
314                value: sqlx::postgres::PgValueRef<'_>,
315            ) -> Result<Self, sqlx::error::BoxDynError> {
316                let s = value.as_str()?;
317                <Self as std::str::FromStr>::from_str(s).map_err(|e| sqlx::error::BoxDynError::from(
318                    std::io::Error::new(std::io::ErrorKind::InvalidData, e)
319                ))
320            }
321        }
322
323        impl sqlx::Encode<'_, sqlx::Postgres> for #enum_name {
324            fn encode_by_ref(
325                &self,
326                buf: &mut sqlx::postgres::PgArgumentBuffer
327            ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
328                let s = self.to_string();
329                <String as sqlx::Encode<sqlx::Postgres>>::encode(s, buf)
330            }
331        }
332
333        impl sqlx::Type<sqlx::Postgres> for #enum_name {
334            fn type_info() -> <sqlx::Postgres as sqlx::Database>::TypeInfo {
335                sqlx::postgres::PgTypeInfo::with_name("VARCHAR")
336            }
337
338            fn compatible(ty: &<sqlx::Postgres as sqlx::Database>::TypeInfo) -> bool {
339                ty.to_string() == "VARCHAR"
340            }
341        }
342    };
343
344    placeholder.into()
345}