onebot-api-macros 0.1.2

Proc macros for onebot-api
Documentation
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{parse_macro_input, Data, DeriveInput, Fields, Ident, Meta, MetaNameValue, Type, TypePath};

fn to_snake_case(s: &str) -> String {
	let mut result = String::new();
	for (i, c) in s.chars().enumerate() {
		if c.is_uppercase() {
			if i > 0 {
				result.push('_');
			}
			result.push(c.to_lowercase().next().unwrap());
		} else {
			result.push(c);
		}
	}
	result
}

fn is_primitive_type(ty: &Type) -> bool {
	if let Type::Path(TypePath { qself: None, path }) = ty {
		if let Some(ident) = path.get_ident() {
			let name = ident.to_string();
			return matches!(
				name.as_str(),
				"i8" | "i16" | "i32" | "i64" | "i128" | "u8" | "u16" | "u32" | "u64"
					| "u128" | "f32" | "f64" | "bool" | "char"
			);
		}
	}
	false
}

fn is_string_type(ty: &Type) -> bool {
	if let Type::Path(TypePath { qself: None, path }) = ty {
		if let Some(ident) = path.get_ident() {
			return ident == "String";
		}
	}
	false
}

fn extract_box_inner(ty: &Type) -> Option<&Type> {
	if let Type::Path(TypePath { qself: None, path }) = ty {
		if let Some(segment) = path.segments.first() {
			if segment.ident == "Box" {
				if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
					if args.args.len() == 1 {
						if let syn::GenericArgument::Type(inner) = &args.args[0] {
							return Some(inner);
						}
					}
				}
			}
		}
	}
	None
}

fn get_filter_param_type(ty: &Type) -> TokenStream2 {
	if let Some(inner) = extract_box_inner(ty) {
		return quote!(&#inner);
	}
	if is_string_type(ty) {
		quote!(&str)
	} else if is_primitive_type(ty) {
		quote!(#ty)
	} else {
		quote!(&#ty)
	}
}

fn get_field_access(ty: &Type, field_name: &Ident) -> TokenStream2 {
	if extract_box_inner(ty).is_some() {
		quote!(&*data.#field_name)
	} else if is_string_type(ty) {
		quote!(&data.#field_name)
	} else if is_primitive_type(ty) {
		quote!(data.#field_name)
	} else {
		quote!(&data.#field_name)
	}
}

fn parse_variants(attrs: &[syn::Attribute]) -> Option<Vec<Ident>> {
	for attr in attrs {
		if attr.path().is_ident("selector") {
			let nested = match attr
				.parse_args_with(syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated)
			{
				Ok(n) => n,
				Err(_) => continue,
			};
			for meta in nested {
				if let Meta::List(list) = &meta {
					if list.path.is_ident("variants") {
						let idents = match list.parse_args_with(
							syn::punctuated::Punctuated::<Ident, syn::Token![,]>::parse_terminated,
						) {
							Ok(i) => i.into_iter().collect(),
							Err(_) => continue,
						};
						return Some(idents);
					}
				}
			}
		}
	}
	None
}

fn parse_through(attrs: &[syn::Attribute]) -> Option<Ident> {
	for attr in attrs {
		if attr.path().is_ident("selector") {
			let nested = match attr
				.parse_args_with(syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated)
			{
				Ok(n) => n,
				Err(_) => continue,
			};
			for meta in nested {
				if let Meta::NameValue(MetaNameValue { path, value, .. }) = &meta {
					if path.is_ident("through") {
						if let syn::Expr::Lit(syn::ExprLit {
							lit: syn::Lit::Str(lit_str),
							..
						}) = value
						{
							let method_name = Ident::new(&lit_str.value(), lit_str.span());
							return Some(method_name);
						}
					}
				}
			}
		}
	}
	None
}

fn get_through_return_type(field_ty: &Type) -> Type {
	if let Some(inner) = extract_box_inner(field_ty) {
		inner.clone()
	} else {
		field_ty.clone()
	}
}

fn generate_struct_selector(
	name: &Ident,
	generics: &syn::Generics,
	data: &syn::DataStruct,
) -> TokenStream {
	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

	let selector_impl = quote! {
		#[cfg(feature = "selector")]
		impl #impl_generics #name #ty_generics #where_clause {
			pub const fn selector(&'_ self) -> crate::selector::Selector<'_, Self> {
				crate::selector::Selector { data: Some(self) }
			}
		}
	};

	let mut filter_methods = Vec::new();

	for field in &data.fields {
		let field_name = field.ident.as_ref().unwrap();
		let field_ty = &field.ty;
		let param_ty = get_filter_param_type(field_ty);
		let field_access = get_field_access(field_ty, field_name);

		let filter_name = format_ident!("filter_{}", field_name);
		let and_filter_name = format_ident!("and_filter_{}", field_name);
		let filter_async_name = format_ident!("filter_{}_async", field_name);
		let and_filter_async_name = format_ident!("and_filter_{}_async", field_name);

		filter_methods.push(quote! {
			pub fn #filter_name(&mut self, f: impl FnOnce(#param_ty) -> bool) {
				if let Some(data) = self.data
					&& !f(#field_access)
				{
					self.data = None
				}
			}

			pub fn #and_filter_name(mut self, f: impl FnOnce(#param_ty) -> bool) -> Self {
				self.#filter_name(f);
				self
			}

			pub async fn #filter_async_name(&mut self, f: impl AsyncFnOnce(#param_ty) -> bool) {
				if let Some(data) = self.data
					&& !f(#field_access).await
				{
					self.data = None
				}
			}

			pub async fn #and_filter_async_name(mut self, f: impl AsyncFnOnce(#param_ty) -> bool) -> Self {
				self.#filter_async_name(f).await;
				self
			}
		});

		if let Some(variants) = parse_variants(&field.attrs) {
			for variant in variants {
				let is_method = Ident::new(
					&format!("is_{}", variant),
					variant.span(),
				);
				let and_variant = format_ident!("and_{}", variant);
				let not_variant = format_ident!("not_{}", variant);
				let and_not_variant = format_ident!("and_not_{}", variant);

				filter_methods.push(quote! {
					pub const fn #variant(&mut self) {
						if let Some(data) = self.data
							&& !data.#field_name.#is_method()
						{
							self.data = None
						}
					}

					pub const fn #and_variant(mut self) -> Self {
						self.#variant();
						self
					}

					pub const fn #not_variant(&mut self) {
						if let Some(data) = self.data
							&& data.#field_name.#is_method()
						{
							self.data = None
						}
					}

					pub const fn #and_not_variant(mut self) -> Self {
						self.#not_variant();
						self
					}
				});
			}
		}

		if let Some(method_name) = parse_through(&field.attrs) {
			let return_ty = get_through_return_type(field_ty);
			let access = if extract_box_inner(field_ty).is_some() {
				quote!(self.data.map(|d| &*d.#field_name))
			} else {
				quote!(self.data.map(|d| &d.#field_name))
			};

			filter_methods.push(quote! {
				pub fn #method_name(&self) -> crate::selector::Selector<'a, #return_ty> {
					crate::selector::Selector {
						data: #access,
					}
				}
			});
		}
	}

	let selector_methods = quote! {
		#[cfg(feature = "selector")]
		impl<'a> crate::selector::Selector<'a, #name> {
			#(#filter_methods)*
		}
	};

	let expanded = quote! {
		#selector_impl
		#selector_methods
	};

	expanded.into()
}

fn generate_enum_selector(
	name: &Ident,
	generics: &syn::Generics,
	data: &syn::DataEnum,
) -> TokenStream {
	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

	let mut match_methods = Vec::new();
	let mut selector_variant_methods = Vec::new();

	for variant in &data.variants {
		let variant_name = &variant.ident;
		let snake_name_str = to_snake_case(&variant_name.to_string());
		let snake_name = Ident::new(&snake_name_str, variant_name.span());

		if let Fields::Unnamed(fields) = &variant.fields {
			if fields.unnamed.len() == 1 {
				let field_ty = &fields.unnamed[0].ty;

				let match_name = format_ident!("match_{}", snake_name);
				let on_name = format_ident!("on_{}", snake_name);
				let on_async_name = format_ident!("on_{}_async", snake_name);

				match_methods.push(quote! {
					pub const fn #match_name(&self) -> Option<&#field_ty> {
						if let Self::#variant_name(data) = self {
							Some(data)
						} else {
							None
						}
					}

					pub fn #on_name<T>(&self, handler: impl FnOnce(&#field_ty) -> T) -> Option<T> {
						if let Self::#variant_name(data) = self {
							Some(handler(data))
						} else {
							None
						}
					}

					pub async fn #on_async_name<T>(&self, handler: impl AsyncFnOnce(&#field_ty) -> T) -> Option<T> {
						if let Self::#variant_name(data) = self {
							Some(handler(data).await)
						} else {
							None
						}
					}
				});

				selector_variant_methods.push(quote! {
					pub fn #snake_name(&self) -> crate::selector::Selector<'a, #field_ty> {
						crate::selector::Selector {
							data: self.data.and_then(|d| d.#match_name()),
						}
					}
				});
			}
		}
	}

	let expanded = quote! {
		#[cfg(feature = "selector")]
		impl #impl_generics #name #ty_generics #where_clause {
			pub const fn selector(&'_ self) -> crate::selector::Selector<'_, Self> {
				crate::selector::Selector { data: Some(self) }
			}

			#(#match_methods)*
		}

		#[cfg(feature = "selector")]
		impl<'a> crate::selector::Selector<'a, #name> {
			#(#selector_variant_methods)*
		}
	};

	expanded.into()
}

pub fn derive_selector(input: TokenStream) -> TokenStream {
	let input = parse_macro_input!(input as DeriveInput);
	let name = &input.ident;
	let generics = &input.generics;

	match &input.data {
		Data::Struct(data) => generate_struct_selector(name, generics, data),
		Data::Enum(data) => generate_enum_selector(name, generics, data),
		Data::Union(_) => panic!("Selector cannot be derived for unions"),
	}
}