chorus_macros/
lib.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at https://mozilla.org/MPL/2.0/.
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::{parse_macro_input, Data, DeriveInput, Field, Fields, FieldsNamed};
8
9#[proc_macro_derive(WebSocketEvent)]
10pub fn websocket_event_macro_derive(input: TokenStream) -> TokenStream {
11    let ast: syn::DeriveInput = syn::parse(input).unwrap();
12
13    let name = &ast.ident;
14
15    quote! {
16        impl WebSocketEvent for #name {}
17    }
18    .into()
19}
20
21#[proc_macro_derive(Updateable)]
22pub fn updateable_macro_derive(input: TokenStream) -> TokenStream {
23    let ast: syn::DeriveInput = syn::parse(input).unwrap();
24
25    let name = &ast.ident;
26    // No need for macro hygiene, we're only using this in chorus
27    quote! {
28        impl Updateable for #name {
29            fn id(&self) -> Snowflake {
30                self.id
31            }
32        }
33    }
34    .into()
35}
36
37#[proc_macro_derive(JsonField)]
38pub fn jsonfield_macro_derive(input: TokenStream) -> TokenStream {
39    let ast: syn::DeriveInput = syn::parse(input).unwrap();
40
41    let name = &ast.ident;
42    // No need for macro hygiene, we're only using this in chorus
43    quote! {
44        impl JsonField for #name {
45            fn get_json(&self) -> String {
46                self.json.clone()
47            }
48            fn set_json(&mut self, json: String) {
49                self.json = json;
50            }
51        }
52    }
53    .into()
54}
55
56#[proc_macro_derive(SourceUrlField)]
57pub fn source_url_macro_derive(input: TokenStream) -> TokenStream {
58    let ast: syn::DeriveInput = syn::parse(input).unwrap();
59
60    let name = &ast.ident;
61    // No need for macro hygiene, we're only using this in chorus
62    quote! {
63        impl SourceUrlField for #name {
64            fn get_source_url(&self) -> String {
65                self.source_url.clone()
66            }
67            fn set_source_url(&mut self, url: String) {
68                self.source_url = url;
69            }
70        }
71    }
72    .into()
73}
74
75#[proc_macro_attribute]
76pub fn observe_option(_args: TokenStream, input: TokenStream) -> TokenStream {
77    input
78}
79
80#[proc_macro_attribute]
81pub fn observe_option_vec(_args: TokenStream, input: TokenStream) -> TokenStream {
82    input
83}
84
85#[proc_macro_attribute]
86pub fn observe(_args: TokenStream, input: TokenStream) -> TokenStream {
87    input
88}
89
90#[proc_macro_attribute]
91pub fn observe_vec(_args: TokenStream, input: TokenStream) -> TokenStream {
92    input
93}
94
95#[proc_macro_derive(
96    Composite,
97    attributes(observe_option_vec, observe_option, observe, observe_vec)
98)]
99pub fn composite_derive(input: TokenStream) -> TokenStream {
100    let input = parse_macro_input!(input as DeriveInput);
101
102    let process_field = |field: &Field| {
103        let field_name = &field.ident;
104        let attrs = &field.attrs;
105
106        let observe_option = attrs
107            .iter()
108            .any(|attr| attr.path().is_ident("observe_option"));
109        let observe_option_vec = attrs
110            .iter()
111            .any(|attr| attr.path().is_ident("observe_option_vec"));
112        let observe = attrs.iter().any(|attr| attr.path().is_ident("observe"));
113        let observe_vec = attrs.iter().any(|attr| attr.path().is_ident("observe_vec"));
114
115        match (observe_option, observe_option_vec, observe, observe_vec) {
116            (true, _, _, _) => quote! {
117                #field_name: Self::option_observe_fn(self.#field_name, gateway).await
118            },
119            (_, true, _, _) => quote! {
120                #field_name: Self::option_vec_observe_fn(self.#field_name, gateway).await
121            },
122            (_, _, true, _) => quote! {
123                #field_name: Self::value_observe_fn(self.#field_name, gateway).await
124            },
125            (_, _, _, true) => quote! {
126                #field_name: Self::vec_observe_fn(self.#field_name, gateway).await
127            },
128            _ => quote! {
129                #field_name: self.#field_name
130            },
131        }
132    };
133
134    match &input.data {
135        Data::Struct(data) => match &data.fields {
136            Fields::Named(FieldsNamed { named, .. }) => {
137                let field_exprs = named.iter().map(process_field);
138
139                let ident = &input.ident;
140                let expanded = quote! {
141                    #[async_trait::async_trait(?Send)]
142                    impl<T: Updateable + Clone + Debug> Composite<T> for #ident {
143                        async fn watch_whole(self, gateway: &GatewayHandle) -> Self {
144                            Self {
145                                #(#field_exprs,)*
146                            }
147                        }
148                    }
149                };
150
151                TokenStream::from(expanded)
152            }
153            _ => panic!("Composite derive macro only supports named fields"),
154        },
155        _ => panic!("Composite derive macro only supports structs"),
156    }
157}
158
159#[proc_macro_derive(SqlxBitFlags)]
160pub fn sqlx_bitflag_derive(input: TokenStream) -> TokenStream {
161    let ast: syn::DeriveInput = syn::parse(input).unwrap();
162
163    let name = &ast.ident;
164
165    quote!{
166        #[cfg(feature = "sqlx")]
167        impl sqlx::Type<sqlx::Postgres> for #name {
168            fn type_info() -> sqlx::postgres::PgTypeInfo {
169                <sqlx_pg_uint::PgU64 as sqlx::Type<sqlx::Postgres>>::type_info()
170            }
171        }
172
173        #[cfg(feature = "sqlx")]
174        impl<'q> sqlx::Encode<'q, sqlx::Postgres> for #name {
175            fn encode_by_ref(&self, buf: &mut <sqlx::Postgres as sqlx::Database>::ArgumentBuffer<'q>) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
176                <sqlx_pg_uint::PgU64 as sqlx::Encode<sqlx::Postgres>>::encode_by_ref(&self.bits().into(), buf)
177            }
178        }
179
180        #[cfg(feature = "sqlx")]
181        impl<'q> sqlx::Decode<'q, sqlx::Postgres> for #name {
182            fn decode(value: <sqlx::Postgres as sqlx::Database>::ValueRef<'q>) -> Result<Self, sqlx::error::BoxDynError> {
183                <sqlx_pg_uint::PgU64 as sqlx::Decode<sqlx::Postgres>>::decode(value).map(|v| Self::from_bits_truncate(v.to_uint()))
184            }
185        }
186    }
187    .into()
188}
189
190#[proc_macro_derive(SerdeBitFlags)]
191pub fn serde_bitflag_derive(input: TokenStream) -> TokenStream {
192    let ast: syn::DeriveInput = syn::parse(input).unwrap();
193
194    let name = &ast.ident;
195
196    quote! {
197        impl std::str::FromStr for #name {
198            type Err = std::num::ParseIntError;
199
200            fn from_str(s: &str) -> Result<#name, Self::Err> {
201                s.parse::<u64>().map(#name::from_bits).map(|f| f.unwrap_or(#name::empty()))
202            }
203        }
204
205        impl serde::Serialize for #name {
206            fn serialize<S: serde::ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
207                serializer.serialize_str(&self.bits().to_string())
208            }
209        }
210
211        impl<'de> serde::Deserialize<'de> for #name {
212            fn deserialize<D>(deserializer: D) -> Result<#name, D::Error> where D: serde::de::Deserializer<'de> + Sized {
213                // let s = String::deserialize(deserializer)?.parse::<u64>().map_err(serde::de::Error::custom)?;
214                let s = crate::types::serde::string_or_u64(deserializer)?;
215
216                // Note: while truncating may not be ideal, it's better than a panic if there are
217                // extra flags
218                Ok(Self::from_bits_truncate(s))
219            }
220        }
221    }
222    .into()
223}