ormlite_attr/
derive.rs

1use syn::parse::{Parse, ParseStream};
2use syn::punctuated::Punctuated;
3use syn::{Meta, Path, Token};
4
5use crate::cfg_attr::CfgAttr;
6
/// One trait named in a `derive(...)` list, e.g. `sqlx::Type` or `Debug`.
#[derive(Debug)]
pub struct DeriveTrait {
    /// The trait's own name: the last segment of the derive path (`Type` in `sqlx::Type`).
    pub name: String,
    /// The crate qualifier the derive was written with (`sqlx` in `sqlx::Type`), or `None`
    /// when the derive was written as a bare name such as `Debug`.
    pub path: Option<String>,
}

impl DeriveTrait {
    /// Returns `true` when this is the trait `name` from crate `pkg`.
    ///
    /// A derive without a crate qualifier is accepted for any `pkg`, because the attribute
    /// alone does not say which crate the trait was imported from.
    pub fn has_derive(&self, pkg: &str, name: &str) -> bool {
        self.has_any_derive(&[pkg], name)
    }

    /// Returns `true` when this is the trait `name` from any crate listed in `pkg`.
    ///
    /// As with [`DeriveTrait::has_derive`], a derive without a crate qualifier is accepted
    /// regardless of `pkg`.
    pub fn has_any_derive(&self, pkg: &[&str], name: &str) -> bool {
        if self.name != name {
            return false;
        }
        // `Option::iter` yields nothing for `None`, so an unqualified derive passes `all`.
        self.path.iter().all(|qualifier| pkg.contains(&qualifier.as_str()))
    }
}
36
37impl From<Path> for DeriveTrait {
38    fn from(value: Path) -> Self {
39        let name = value.segments.last().as_ref().unwrap().ident.to_string();
40        let mut path = None;
41        if value.segments.len() > 1 {
42            path = value.segments.first().map(|s| s.ident.to_string());
43        }
44        DeriveTrait { name, path }
45    }
46}
47
48/// Existing libraries like `structmeta` and `darling` do not parse derives, as they are
49/// built assuming the data comes from proc-macro, whereas in ormlite we do both proc-macro
50/// as well as static codebase analysis
51#[derive(Debug, Default)]
52pub struct DeriveParser {
53    derives: Vec<DeriveTrait>,
54}
55
56impl DeriveParser {
57    pub fn has_derive(&self, pkg: &str, name: &str) -> bool {
58        self.derives.iter().any(|d| d.has_derive(pkg, name))
59    }
60
61    pub fn has_any_derive(&self, pkg: &[&str], name: &str) -> bool {
62        self.derives.iter().any(|d| d.has_any_derive(pkg, name))
63    }
64
65    pub(crate) fn update(&mut self, other: Derive) {
66        for path in other.inner {
67            self.derives.push(path.into());
68        }
69    }
70}
71
72impl DeriveParser {
73    const ATTRIBUTE: &'static str = "derive";
74
75    pub fn from_attributes(attrs: &[syn::Attribute]) -> Self {
76        let mut result = Self::default();
77        for attr in attrs {
78            let Some(ident) = attr.path().get_ident() else {
79                continue;
80            };
81            if ident == Self::ATTRIBUTE {
82                result.update(attr.parse_args().unwrap());
83            } else if ident == "cfg_attr" {
84                let cfg: CfgAttr = attr.parse_args().unwrap();
85                for attr in cfg.attrs {
86                    let Some(ident) = attr.path().get_ident() else {
87                        continue;
88                    };
89                    if ident == Self::ATTRIBUTE {
90                        let Meta::List(attrs) = attr else {
91                            panic!("Expected a list of attributes")
92                        };
93                        result.update(attrs.parse_args().unwrap());
94                    }
95                }
96            }
97        }
98        result
99    }
100}
101
102/// Parses `#[derive(...)]`
103pub(crate) struct Derive {
104    inner: Punctuated<Path, Token![,]>,
105}
106
107impl Parse for Derive {
108    fn parse(input: ParseStream) -> syn::Result<Self> {
109        Ok(Derive {
110            inner: input.parse_terminated(Path::parse_mod_style, Token![,])?,
111        })
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::repr::Repr;

    #[test]
    fn test_repr() {
        let q = quote::quote! {
            #[derive(sqlx::Type)]
            #[repr(u8)]
            pub enum Foo {
                Bar,
                Baz,
            }
        };
        let item = syn::parse2::<syn::ItemEnum>(q).unwrap();
        let derive = DeriveParser::from_attributes(&item.attrs);
        let repr = Repr::from_attributes(&item.attrs).unwrap();
        assert!(derive.has_any_derive(&["ormlite", "sqlx"], "Type"));
        assert_eq!(repr, "u8");
    }

    /// The attributes on this are sort of nonsense, but we want to test the dynamic attribute
    /// parsing in ormlite_attr::Attribute.
    #[test]
    fn test_attributes() {
        // The leading `///` doc comment is the regression case.
        let code = r#"/// Json-serializable representation of query results
#[derive(Debug, Serialize, Deserialize, Clone, sqlx::Type, ormlite::Model)]
#[repr(u8)]
#[ormlite(table = "result")]
#[deprecated]
pub struct QuerySet {
    pub headers: Vec<String>,
    pub rows: Vec<Vec<Value>>,
}"#;
        let file: syn::File = syn::parse_str(code).unwrap();
        let syn::Item::Struct(item) = file.items.first().unwrap() else {
            panic!("expected struct");
        };
        let attr = DeriveParser::from_attributes(&item.attrs);
        let repr = Repr::from_attributes(&item.attrs).unwrap();
        assert_eq!(repr, "u8");
        assert!(attr.has_derive("ormlite", "Model"));
        assert!(attr.has_any_derive(&["ormlite", "sqlx"], "Type"));
        assert!(!attr.has_derive("ormlite", "ManualType"));
        // A qualified derive only matches its own crate.
        assert!(!attr.has_derive("sqlx", "Model"));
    }

    #[test]
    fn test_cfg_attr() {
        // Derives inside `cfg_attr(...)`, including a multi-line list with a trailing comma,
        // must be collected.
        let code = r#"
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
#[cfg_attr(
    target_arch = "wasm32",
    derive(tsify::Tsify),
    tsify(into_wasm_abi, from_wasm_abi)
)]
#[cfg_attr(
    not(target_arch = "wasm32"),
    derive(
        sqlx::Type,
        strum::IntoStaticStr,
        strum::EnumString,
    ),
    strum(serialize_all = "snake_case")
)]
#[serde(rename_all = "snake_case")]
pub enum Privacy {
    Private,
    Team,
    Public,
}
"#;
        let file: syn::File = syn::parse_str(code).unwrap();
        let syn::Item::Enum(item) = file.items.first().unwrap() else {
            panic!("expected enum");
        };
        let attr = DeriveParser::from_attributes(&item.attrs);
        assert!(attr.has_any_derive(&["ormlite", "sqlx"], "Type"));
    }

    #[test]
    fn test_cfg_attr2() {
        // A multi-segment path keeps only its first segment as the crate, so
        // `ormlite::types::ManualType` matches `("ormlite", "ManualType")`.
        let code = r#"
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
#[cfg_attr(
    target_arch = "wasm32",
    derive(tsify::Tsify),
    tsify(into_wasm_abi, from_wasm_abi)
)]
#[cfg_attr(
    not(target_arch = "wasm32"),
    derive(ormlite::types::ManualType, strum::IntoStaticStr, strum::EnumString),
    strum(serialize_all = "snake_case")
)]
#[serde(rename_all = "snake_case")]
pub enum Privacy {
    Private,
    Team,
    Public,
}
"#;
        let file: syn::File = syn::parse_str(code).unwrap();
        let syn::Item::Enum(item) = file.items.first().unwrap() else {
            panic!("expected enum");
        };
        let attr = DeriveParser::from_attributes(&item.attrs);
        assert!(attr.has_derive("ormlite", "ManualType"));
    }
}
225}