stargate_grpc_derive/
lib.rs1use proc_macro::TokenStream;
139
140use darling::util::Override;
141use darling::{ast, util, FromDeriveInput, FromField};
142use quote::quote;
143use syn::__private::TokenStream2;
144
/// One field of a struct deriving a stargate trait, as parsed by `darling`
/// together with the options given in its `#[stargate(...)]` attribute.
#[derive(Debug, FromField)]
#[darling(attributes(stargate))]
struct UdtField {
    /// Field identifier; `None` for tuple-struct fields. The derives in this
    /// crate unwrap it, so only structs with named fields are supported.
    ident: Option<syn::Ident>,
    /// Declared Rust type of the field; `TryFromValue` converts the UDT entry
    /// into `Option<ty>`.
    ty: syn::Type,
    /// Fallback used by `TryFromValue` when the entry is absent from the UDT map
    /// or converts to `None`:
    /// `#[stargate(default)]` (`Override::Inherit`) uses `Default::default()`,
    /// `#[stargate(default = "expr")]` (`Override::Explicit`) uses `expr`.
    /// Without it, a `field_not_found` conversion error is produced.
    #[darling(default)]
    default: Option<Override<String>>,
    /// `#[stargate(cql_type = "...")]`: tokens passed as the first argument of
    /// `Value::of_type` when converting the field; when absent, `Value::from` is used.
    #[darling(default)]
    cql_type: Option<String>,
    /// `#[stargate(skip)]`: the field is left out by `IntoValue` and `IntoValues`.
    /// NOTE(review): `TryFromValue` and `TryFromRow` do not filter on this flag.
    #[darling(default)]
    skip: bool,
    /// `#[stargate(name = "...")]`: remote UDT field / column name; defaults to
    /// the Rust field identifier.
    #[darling(default)]
    name: Option<String>,
}
159
/// The type a stargate trait is derived for, as parsed by `darling`.
/// Enum variants are ignored (`util::Ignored`); only struct fields are collected.
#[derive(Debug, FromDeriveInput)]
struct Udt {
    /// Name of the type the trait impls are generated for.
    ident: syn::Ident,
    /// Body of the deriving type; struct fields are parsed as `UdtField`s.
    data: ast::Data<util::Ignored, UdtField>,
}
165
166fn get_fields(udt: ast::Data<util::Ignored, UdtField>) -> Vec<UdtField> {
167 match udt {
168 ast::Data::Struct(s) => s.fields,
169 _ => panic!("Deriving IntoValue allowed only on structs"),
170 }
171}
172
173fn field_idents(fields: &[UdtField]) -> Vec<&syn::Ident> {
174 fields.iter().map(|f| f.ident.as_ref().unwrap()).collect()
175}
176
177fn field_names(fields: &[UdtField]) -> Vec<String> {
179 fields
180 .iter()
181 .map(|f| {
182 f.name
183 .clone()
184 .unwrap_or_else(|| f.ident.as_ref().unwrap().to_string())
185 })
186 .collect()
187}
188
189fn token_stream(s: &str) -> proc_macro2::TokenStream {
190 s.parse().unwrap()
191}
192
193fn convert_to_value(obj: &syn::Ident, field: &UdtField) -> TokenStream2 {
195 let field_ident = field.ident.as_ref().unwrap();
196 match &field.cql_type {
197 Some(t) => {
198 let cql_type = token_stream(t.as_str());
199 quote! { stargate_grpc::Value::of_type(#cql_type, #obj.#field_ident) }
200 }
201 None => {
202 quote! { stargate_grpc::Value::from(#obj.#field_ident) }
203 }
204 }
205}
206
207fn convert_to_values(obj: &syn::Ident, fields: &[UdtField]) -> Vec<TokenStream2> {
209 fields.iter().map(|f| convert_to_value(obj, f)).collect()
210}
211
212#[proc_macro_derive(IntoValue, attributes(stargate))]
214pub fn derive_into_value(tokens: TokenStream) -> TokenStream {
215 let parsed = syn::parse(tokens).unwrap();
216 let udt: Udt = Udt::from_derive_input(&parsed).unwrap();
217 let udt_type = udt.ident;
218
219 let obj = syn::Ident::new("obj", proc_macro2::Span::mixed_site());
220 let fields: Vec<_> = get_fields(udt.data)
221 .into_iter()
222 .filter(|f| !f.skip)
223 .collect();
224 let remote_field_names = field_names(&fields);
225 let field_values: Vec<_> = convert_to_values(&obj, &fields);
226
227 let result = quote! {
228 impl stargate_grpc::into_value::IntoValue<stargate_grpc::types::Udt> for #udt_type {
229 fn into_value(self) -> stargate_grpc::Value {
230 let #obj = self;
231 let mut fields = std::collections::HashMap::new();
232 #(fields.insert(#remote_field_names.to_string(), #field_values));*;
233 stargate_grpc::Value::raw_udt(fields)
234 }
235 }
236 impl stargate_grpc::into_value::DefaultCqlType for #udt_type {
237 type C = stargate_grpc::types::Udt;
238 }
239 };
240 result.into()
241}
242
243#[proc_macro_derive(IntoValues, attributes(stargate))]
245pub fn derive_into_values(tokens: TokenStream) -> TokenStream {
246 let parsed = syn::parse(tokens).unwrap();
247 let udt: Udt = Udt::from_derive_input(&parsed).unwrap();
248 let udt_type = udt.ident;
249
250 let obj = syn::Ident::new("obj", proc_macro2::Span::mixed_site());
251 let fields: Vec<_> = get_fields(udt.data)
252 .into_iter()
253 .filter(|f| !f.skip)
254 .collect();
255 let field_names = field_names(&fields);
256 let field_values: Vec<_> = convert_to_values(&obj, &fields);
257
258 let result = quote! {
259 impl std::convert::From<#udt_type> for stargate_grpc::proto::Values {
260 fn from(#obj: #udt_type) -> Self {
261 stargate_grpc::proto::Values {
262 value_names: vec![#(#field_names.to_string()),*],
263 values: vec![#(#field_values),*]
264 }
265 }
266 }
267 };
268 result.into()
269}
270
271fn convert_from_hashmap_value(hashmap: &syn::Ident, field: &UdtField) -> TokenStream2 {
275 let field_name = field
276 .name
277 .clone()
278 .unwrap_or_else(|| field.ident.as_ref().unwrap().to_string());
279 let field_type = &field.ty;
280
281 let default_expr = match &field.default {
282 None => quote! { Err(ConversionError::field_not_found::<_, Self>(&#hashmap, #field_name)) },
283 Some(Override::Inherit) => quote! { Ok(std::default::Default::default()) },
284 Some(Override::Explicit(s)) => {
285 let path = token_stream(s);
286 quote! { Ok(#path) }
287 }
288 };
289
290 quote! {
291 match #hashmap.remove(#field_name) {
292 Some(value) => {
293 let maybe_value: Option<#field_type> = value.try_into()?;
294 match maybe_value {
295 Some(v) => Ok(v),
296 None => #default_expr
297 }
298 }
299 None => #default_expr
300 }
301 }
302}
303
304#[proc_macro_derive(TryFromValue, attributes(stargate))]
306pub fn derive_try_from_value(tokens: TokenStream) -> TokenStream {
307 let parsed = syn::parse(tokens).unwrap();
308 let udt: Udt = Udt::from_derive_input(&parsed).unwrap();
309 let ident = udt.ident;
310 let fields = get_fields(udt.data);
311 let field_idents = field_idents(&fields);
312 let udt_hashmap = syn::Ident::new("fields", proc_macro2::Span::mixed_site());
313 let field_values = fields
314 .iter()
315 .map(|field| convert_from_hashmap_value(&udt_hashmap, field));
316
317 let result = quote! {
318
319 impl stargate_grpc::from_value::TryFromValue for #ident {
320 fn try_from(value: stargate_grpc::Value) ->
321 Result<Self, stargate_grpc::error::ConversionError>
322 {
323 use stargate_grpc::Value;
324 use stargate_grpc::error::ConversionError;
325 use stargate_grpc::proto::*;
326 match value.inner {
327 Some(value::Inner::Udt(UdtValue { mut #udt_hashmap })) => {
328 Ok(#ident {
329 #(#field_idents: #field_values?),*
330 })
331 }
332 other => Err(ConversionError::incompatible::<_, Self>(other))
333 }
334 }
335 }
336
337 impl std::convert::TryFrom<stargate_grpc::Value> for #ident {
338 type Error = stargate_grpc::error::ConversionError;
339 fn try_from(value: stargate_grpc::Value) ->
340 Result<Self, stargate_grpc::error::ConversionError>
341 {
342 <#ident as stargate_grpc::from_value::TryFromValue>::try_from(value)
343 }
344 }
345 };
346
347 result.into()
348}
349
350#[proc_macro_derive(TryFromRow, attributes(stargate))]
352pub fn derive_try_from_typed_row(tokens: TokenStream) -> TokenStream {
353 let parsed = syn::parse(tokens).unwrap();
354 let udt: Udt = Udt::from_derive_input(&parsed).unwrap();
355 let ident = udt.ident;
356 let fields = get_fields(udt.data);
357 let field_idents = field_idents(&fields);
358 let field_names = field_names(&fields);
359 let indexes = 0..field_idents.len();
360
361 let result = quote! {
362 impl stargate_grpc::result::ColumnPositions for #ident {
363 fn field_to_column_pos(
364 column_positions: std::collections::HashMap<String, usize>
365 ) -> Result<Vec<usize>, stargate_grpc::result::MapperError>
366 {
367 use stargate_grpc::result::MapperError;
368 let mut result = Vec::new();
369 #(
370 result.push(
371 *column_positions
372 .get(#field_names)
373 .ok_or_else(|| MapperError::ColumnNotFound(#field_names))?
374 );
375 )*
376 Ok(result)
377 }
378 }
379
380 impl stargate_grpc::result::TryFromRow for #ident {
381 fn try_unpack(
382 mut row: stargate_grpc::Row,
383 column_positions: &[usize]
384 ) -> Result<Self, stargate_grpc::error::ConversionError>
385 {
386 Ok(#ident {
387 #(#field_idents: row.values[column_positions[#indexes]].take().try_into()?),*
388 })
389 }
390 }
391 };
392
393 result.into()
394}