1use proc_macro::TokenStream;
6use quote::quote;
7use syn::{parse_macro_input, Data, DeriveInput, Field, Fields, FieldsNamed};
8
9#[proc_macro_derive(WebSocketEvent)]
10pub fn websocket_event_macro_derive(input: TokenStream) -> TokenStream {
11 let ast: syn::DeriveInput = syn::parse(input).unwrap();
12
13 let name = &ast.ident;
14
15 quote! {
16 impl WebSocketEvent for #name {}
17 }
18 .into()
19}
20
21#[proc_macro_derive(Updateable)]
22pub fn updateable_macro_derive(input: TokenStream) -> TokenStream {
23 let ast: syn::DeriveInput = syn::parse(input).unwrap();
24
25 let name = &ast.ident;
26 quote! {
28 impl Updateable for #name {
29 fn id(&self) -> Snowflake {
30 self.id
31 }
32 }
33 }
34 .into()
35}
36
37#[proc_macro_derive(JsonField)]
38pub fn jsonfield_macro_derive(input: TokenStream) -> TokenStream {
39 let ast: syn::DeriveInput = syn::parse(input).unwrap();
40
41 let name = &ast.ident;
42 quote! {
44 impl JsonField for #name {
45 fn get_json(&self) -> String {
46 self.json.clone()
47 }
48 fn set_json(&mut self, json: String) {
49 self.json = json;
50 }
51 }
52 }
53 .into()
54}
55
56#[proc_macro_derive(SourceUrlField)]
57pub fn source_url_macro_derive(input: TokenStream) -> TokenStream {
58 let ast: syn::DeriveInput = syn::parse(input).unwrap();
59
60 let name = &ast.ident;
61 quote! {
63 impl SourceUrlField for #name {
64 fn get_source_url(&self) -> String {
65 self.source_url.clone()
66 }
67 fn set_source_url(&mut self, url: String) {
68 self.source_url = url;
69 }
70 }
71 }
72 .into()
73}
74
75#[proc_macro_attribute]
76pub fn observe_option(_args: TokenStream, input: TokenStream) -> TokenStream {
77 input
78}
79
80#[proc_macro_attribute]
81pub fn observe_option_vec(_args: TokenStream, input: TokenStream) -> TokenStream {
82 input
83}
84
85#[proc_macro_attribute]
86pub fn observe(_args: TokenStream, input: TokenStream) -> TokenStream {
87 input
88}
89
90#[proc_macro_attribute]
91pub fn observe_vec(_args: TokenStream, input: TokenStream) -> TokenStream {
92 input
93}
94
95#[proc_macro_derive(
96 Composite,
97 attributes(observe_option_vec, observe_option, observe, observe_vec)
98)]
99pub fn composite_derive(input: TokenStream) -> TokenStream {
100 let input = parse_macro_input!(input as DeriveInput);
101
102 let process_field = |field: &Field| {
103 let field_name = &field.ident;
104 let attrs = &field.attrs;
105
106 let observe_option = attrs
107 .iter()
108 .any(|attr| attr.path().is_ident("observe_option"));
109 let observe_option_vec = attrs
110 .iter()
111 .any(|attr| attr.path().is_ident("observe_option_vec"));
112 let observe = attrs.iter().any(|attr| attr.path().is_ident("observe"));
113 let observe_vec = attrs.iter().any(|attr| attr.path().is_ident("observe_vec"));
114
115 match (observe_option, observe_option_vec, observe, observe_vec) {
116 (true, _, _, _) => quote! {
117 #field_name: Self::option_observe_fn(self.#field_name, gateway).await
118 },
119 (_, true, _, _) => quote! {
120 #field_name: Self::option_vec_observe_fn(self.#field_name, gateway).await
121 },
122 (_, _, true, _) => quote! {
123 #field_name: Self::value_observe_fn(self.#field_name, gateway).await
124 },
125 (_, _, _, true) => quote! {
126 #field_name: Self::vec_observe_fn(self.#field_name, gateway).await
127 },
128 _ => quote! {
129 #field_name: self.#field_name
130 },
131 }
132 };
133
134 match &input.data {
135 Data::Struct(data) => match &data.fields {
136 Fields::Named(FieldsNamed { named, .. }) => {
137 let field_exprs = named.iter().map(process_field);
138
139 let ident = &input.ident;
140 let expanded = quote! {
141 #[async_trait::async_trait(?Send)]
142 impl<T: Updateable + Clone + Debug> Composite<T> for #ident {
143 async fn watch_whole(self, gateway: &GatewayHandle) -> Self {
144 Self {
145 #(#field_exprs,)*
146 }
147 }
148 }
149 };
150
151 TokenStream::from(expanded)
152 }
153 _ => panic!("Composite derive macro only supports named fields"),
154 },
155 _ => panic!("Composite derive macro only supports structs"),
156 }
157}
158
159#[proc_macro_derive(SqlxBitFlags)]
160pub fn sqlx_bitflag_derive(input: TokenStream) -> TokenStream {
161 let ast: syn::DeriveInput = syn::parse(input).unwrap();
162
163 let name = &ast.ident;
164
165 quote!{
166 #[cfg(feature = "sqlx")]
167 impl sqlx::Type<sqlx::Postgres> for #name {
168 fn type_info() -> sqlx::postgres::PgTypeInfo {
169 <sqlx_pg_uint::PgU64 as sqlx::Type<sqlx::Postgres>>::type_info()
170 }
171 }
172
173 #[cfg(feature = "sqlx")]
174 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for #name {
175 fn encode_by_ref(&self, buf: &mut <sqlx::Postgres as sqlx::Database>::ArgumentBuffer<'q>) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
176 <sqlx_pg_uint::PgU64 as sqlx::Encode<sqlx::Postgres>>::encode_by_ref(&self.bits().into(), buf)
177 }
178 }
179
180 #[cfg(feature = "sqlx")]
181 impl<'q> sqlx::Decode<'q, sqlx::Postgres> for #name {
182 fn decode(value: <sqlx::Postgres as sqlx::Database>::ValueRef<'q>) -> Result<Self, sqlx::error::BoxDynError> {
183 <sqlx_pg_uint::PgU64 as sqlx::Decode<sqlx::Postgres>>::decode(value).map(|v| Self::from_bits_truncate(v.to_uint()))
184 }
185 }
186 }
187 .into()
188}
189
190#[proc_macro_derive(SerdeBitFlags)]
191pub fn serde_bitflag_derive(input: TokenStream) -> TokenStream {
192 let ast: syn::DeriveInput = syn::parse(input).unwrap();
193
194 let name = &ast.ident;
195
196 quote! {
197 impl std::str::FromStr for #name {
198 type Err = std::num::ParseIntError;
199
200 fn from_str(s: &str) -> Result<#name, Self::Err> {
201 s.parse::<u64>().map(#name::from_bits).map(|f| f.unwrap_or(#name::empty()))
202 }
203 }
204
205 impl serde::Serialize for #name {
206 fn serialize<S: serde::ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
207 serializer.serialize_str(&self.bits().to_string())
208 }
209 }
210
211 impl<'de> serde::Deserialize<'de> for #name {
212 fn deserialize<D>(deserializer: D) -> Result<#name, D::Error> where D: serde::de::Deserializer<'de> + Sized {
213 let s = crate::types::serde::string_or_u64(deserializer)?;
215
216 Ok(Self::from_bits_truncate(s))
219 }
220 }
221 }
222 .into()
223}