1use syn::parse::{Parse, ParseStream};
2use syn::punctuated::Punctuated;
3use syn::{Meta, Path, Token};
4
5use crate::cfg_attr::CfgAttr;
6
/// A single trait named in a `derive(...)` list, e.g. `sqlx::Type` or `Debug`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeriveTrait {
    /// The trait's own name: the last segment of the derive path
    /// (`Type` in `sqlx::Type`).
    pub name: String,
    /// The first segment of a qualified derive path (`sqlx` in `sqlx::Type`,
    /// `ormlite` in `ormlite::types::ManualType`); `None` for a bare name
    /// such as `Debug`.
    pub path: Option<String>,
}
14
15impl DeriveTrait {
16 pub fn has_derive(&self, pkg: &str, name: &str) -> bool {
17 if self.name != name {
18 return false;
19 }
20 match &self.path {
21 None => true,
22 Some(path) => path == pkg,
23 }
24 }
25
26 pub fn has_any_derive(&self, pkg: &[&str], name: &str) -> bool {
27 if self.name != name {
28 return false;
29 }
30 match &self.path {
31 None => true,
32 Some(path) => pkg.contains(&path.as_str()),
33 }
34 }
35}
36
37impl From<Path> for DeriveTrait {
38 fn from(value: Path) -> Self {
39 let name = value.segments.last().as_ref().unwrap().ident.to_string();
40 let mut path = None;
41 if value.segments.len() > 1 {
42 path = value.segments.first().map(|s| s.ident.to_string());
43 }
44 DeriveTrait { name, path }
45 }
46}
47
/// Collects the traits named in an item's `derive(...)` attributes so callers
/// can ask whether a particular derive is present.
#[derive(Debug, Default)]
pub struct DeriveParser {
    // Every derive path collected so far, in the order it was encountered.
    derives: Vec<DeriveTrait>,
}
55
56impl DeriveParser {
57 pub fn has_derive(&self, pkg: &str, name: &str) -> bool {
58 self.derives.iter().any(|d| d.has_derive(pkg, name))
59 }
60
61 pub fn has_any_derive(&self, pkg: &[&str], name: &str) -> bool {
62 self.derives.iter().any(|d| d.has_any_derive(pkg, name))
63 }
64
65 pub(crate) fn update(&mut self, other: Derive) {
66 for path in other.inner {
67 self.derives.push(path.into());
68 }
69 }
70}
71
72impl DeriveParser {
73 const ATTRIBUTE: &'static str = "derive";
74
75 pub fn from_attributes(attrs: &[syn::Attribute]) -> Self {
76 let mut result = Self::default();
77 for attr in attrs {
78 let Some(ident) = attr.path().get_ident() else {
79 continue;
80 };
81 if ident == Self::ATTRIBUTE {
82 result.update(attr.parse_args().unwrap());
83 } else if ident == "cfg_attr" {
84 let cfg: CfgAttr = attr.parse_args().unwrap();
85 for attr in cfg.attrs {
86 let Some(ident) = attr.path().get_ident() else {
87 continue;
88 };
89 if ident == Self::ATTRIBUTE {
90 let Meta::List(attrs) = attr else {
91 panic!("Expected a list of attributes")
92 };
93 result.update(attrs.parse_args().unwrap());
94 }
95 }
96 }
97 }
98 result
99 }
100}
101
/// Parsed contents of a `derive(...)` list: the comma-separated trait paths.
pub(crate) struct Derive {
    // Each path is parsed with `Path::parse_mod_style` (no generic arguments);
    // a trailing comma is accepted.
    inner: Punctuated<Path, Token![,]>,
}
106
107impl Parse for Derive {
108 fn parse(input: ParseStream) -> syn::Result<Self> {
109 Ok(Derive {
110 inner: input.parse_terminated(Path::parse_mod_style, Token![,])?,
111 })
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::repr::Repr;

    #[test]
    fn test_repr() {
        let q = quote::quote! {
            #[derive(sqlx::Type)]
            #[repr(u8)]
            pub enum Foo {
                Bar,
                Baz,
            }
        };
        let item = syn::parse2::<syn::ItemEnum>(q).unwrap();
        let derive = DeriveParser::from_attributes(&item.attrs);
        let repr = Repr::from_attributes(&item.attrs).unwrap();
        assert!(derive.has_any_derive(&["ormlite", "sqlx"], "Type"));
        assert_eq!(repr, "u8");
    }

    #[test]
    fn test_attributes() {
        let code = r#"/// Json-serializable representation of query results
#[derive(Debug, Serialize, Deserialize, Clone, sqlx::Type, ormlite::Model)]
#[repr(u8)]
#[ormlite(table = "result")]
#[deprecated]
pub struct QuerySet {
    pub headers: Vec<String>,
    pub rows: Vec<Vec<Value>>,
}"#;
        let file: syn::File = syn::parse_str(code).unwrap();
        let syn::Item::Struct(item) = file.items.first().unwrap() else {
            panic!("expected struct");
        };
        let attr = DeriveParser::from_attributes(&item.attrs);
        let repr = Repr::from_attributes(&item.attrs).unwrap();
        assert_eq!(repr, "u8");
        assert!(attr.has_derive("ormlite", "Model"));
        assert!(attr.has_any_derive(&["ormlite", "sqlx"], "Type"));
        assert!(!attr.has_derive("ormlite", "ManualType"));
    }

    #[test]
    fn test_cfg_attr() {
        let code = r#"
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
#[cfg_attr(
    target_arch = "wasm32",
    derive(tsify::Tsify),
    tsify(into_wasm_abi, from_wasm_abi)
)]
#[cfg_attr(
    not(target_arch = "wasm32"),
    derive(
        sqlx::Type,
        strum::IntoStaticStr,
        strum::EnumString,
    ),
    strum(serialize_all = "snake_case")
)]
#[serde(rename_all = "snake_case")]
pub enum Privacy {
    Private,
    Team,
    Public,
}
"#;
        let file: syn::File = syn::parse_str(code).unwrap();
        let syn::Item::Enum(item) = file.items.first().unwrap() else {
            panic!("expected enum");
        };
        let attr = DeriveParser::from_attributes(&item.attrs);
        assert!(attr.has_any_derive(&["ormlite", "sqlx"], "Type"));
        // The cfg predicate is not evaluated, so derives from both branches are seen.
        assert!(attr.has_derive("tsify", "Tsify"));
    }

    #[test]
    fn test_cfg_attr2() {
        let code = r#"
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
#[cfg_attr(
    target_arch = "wasm32",
    derive(tsify::Tsify),
    tsify(into_wasm_abi, from_wasm_abi)
)]
#[cfg_attr(
    not(target_arch = "wasm32"),
    derive(ormlite::types::ManualType, strum::IntoStaticStr, strum::EnumString),
    strum(serialize_all = "snake_case")
)]
#[serde(rename_all = "snake_case")]
pub enum Privacy {
    Private,
    Team,
    Public,
}
"#;
        let file: syn::File = syn::parse_str(code).unwrap();
        let syn::Item::Enum(item) = file.items.first().unwrap() else {
            panic!("expected enum");
        };
        let attr = DeriveParser::from_attributes(&item.attrs);
        assert!(attr.has_derive("ormlite", "ManualType"));
    }
}
225}