use std::{collections::HashSet, ops::DerefMut, sync::Arc};
use fxhash::FxHashMap;
use itertools::Itertools;
use quote::quote;
use syn::{parse_quote, Attribute};
use crate::{
db::RirDatabase,
rir::{Field, Item},
symbol::{DefId, EnumRepr},
ty::{self, Ty, Visitor},
Context, IdentName,
};
mod serde;
pub use serde::SerdePlugin;
/// Hook interface for code-generation plugins.
///
/// The default methods keep the traversal going: `on_item` delegates to [`walk_item`]
/// (which calls `on_field` for every field of a message item) and `on_field` delegates
/// to [`walk_filed`]. `on_emit` does nothing by default.
pub trait Plugin {
/// Called for an item identified by `def_id`. Overrides that still want the item's
/// fields visited should call [`walk_item`] themselves.
fn on_item(&mut self, cx: &mut Context, def_id: DefId, item: Arc<Item>) {
walk_item(self, cx, def_id, item)
}
/// Called for a field; when reached through [`walk_item`], `def_id` is the field's own id.
fn on_field(&mut self, cx: &mut Context, def_id: DefId, f: Arc<Field>) {
walk_filed(self, cx, def_id, f)
}
/// Final hook with mutable access to the context.
/// NOTE(review): presumably invoked once after all items were visited — confirm with the driver.
fn on_emit(&mut self, _cx: &mut Context) {}
}
/// A [`Plugin`] that can be duplicated behind a trait object.
///
/// Implemented automatically for every `Plugin + Clone + 'static` type.
pub trait ClonePlugin: Plugin {
/// Returns a boxed, type-erased copy of `self`.
fn clone_box(&self) -> Box<dyn ClonePlugin>;
}
/// Owned, clonable, type-erased plugin. `Clone` goes through [`ClonePlugin::clone_box`],
/// and every [`Plugin`] hook is forwarded to the wrapped plugin.
pub struct BoxClonePlugin(Box<dyn ClonePlugin>);
impl BoxClonePlugin {
    /// Type-erases `p` into a clonable boxed plugin.
    pub fn new<P: ClonePlugin + 'static>(p: P) -> Self {
        let erased: Box<dyn ClonePlugin> = Box::new(p);
        Self(erased)
    }
}
impl Clone for BoxClonePlugin {
fn clone(&self) -> Self {
Self(self.0.clone_box())
}
}
/// Every hook is forwarded unchanged to the wrapped plugin.
impl Plugin for BoxClonePlugin {
    fn on_item(&mut self, cx: &mut Context, def_id: DefId, item: Arc<Item>) {
        let Self(inner) = self;
        inner.on_item(cx, def_id, item)
    }

    fn on_field(&mut self, cx: &mut Context, def_id: DefId, f: Arc<Field>) {
        let Self(inner) = self;
        inner.on_field(cx, def_id, f)
    }

    fn on_emit(&mut self, cx: &mut Context) {
        let Self(inner) = self;
        inner.on_emit(cx)
    }
}
/// Any clonable, `'static` plugin can be boxed-and-cloned.
impl<T> ClonePlugin for T
where
    T: Plugin + Clone + 'static,
{
    fn clone_box(&self) -> Box<dyn ClonePlugin> {
        let duplicate: T = Clone::clone(self);
        Box::new(duplicate)
    }
}
/// A mutable borrow of a plugin is itself a plugin: every hook is forwarded to `T`.
impl<T> Plugin for &mut T
where
    T: Plugin,
{
    fn on_item(&mut self, cx: &mut Context, def_id: DefId, item: Arc<Item>) {
        T::on_item(&mut **self, cx, def_id, item)
    }

    fn on_field(&mut self, cx: &mut Context, def_id: DefId, f: Arc<Field>) {
        T::on_field(&mut **self, cx, def_id, f)
    }

    fn on_emit(&mut self, cx: &mut Context) {
        T::on_emit(&mut **self, cx)
    }
}
/// Default traversal behind [`Plugin::on_item`].
///
/// For a message item, calls `p.on_field` once per field, passing the field's own `did`
/// and a clone of its `Arc`. Other item kinds are not traversed.
pub fn walk_item<P: Plugin + ?Sized>(p: &mut P, cx: &mut Context, _def_id: DefId, item: Arc<Item>) {
    if let Item::Message(message) = &*item {
        for field in message.fields.iter() {
            p.on_field(cx, field.did, field.clone());
        }
    }
}
/// Default body of [`Plugin::on_field`]; currently a no-op, so fields are leaves of the walk.
///
/// The `filed` spelling is part of the public name and is kept for compatibility.
pub fn walk_filed<P: Plugin + ?Sized>(
_p: &mut P,
_cx: &mut Context,
_def_id: DefId,
_field: Arc<Field>,
) {
}
/// Marks message fields as boxed when the type graph reports the field's type as nested
/// in the containing message (`is_nested(field_type, message)`).
pub struct BoxedPlugin;

impl Plugin for BoxedPlugin {
    fn on_item(&mut self, cx: &mut Context, def_id: DefId, item: Arc<Item>) {
        if let Item::Message(message) = &*item {
            for field in message.fields.iter() {
                // Only fields whose type kind is directly a path are considered.
                if let ty::Path(path) = &field.ty.kind {
                    let nested = cx.type_graph().is_nested(path.did, def_id);
                    if nested {
                        cx.with_adjust(field.did, |adj| adj.set_boxed());
                    }
                }
            }
        }
        walk_item(self, cx, def_id, item)
    }
}
/// Adds `attrs` to every item that the dependency walk does not mark `CanDerive::No`.
///
/// An item is rejected when `predicate` returns `PredicateResult::No` for one of its
/// field / variant / wrapped types, or when an item referenced from those types is rejected.
pub struct AutoDerivePlugin<F> {
/// Memoized verdict per item; filled by `can_derive`, read by `on_emit`.
can_derive: FxHashMap<DefId, CanDerive>,
/// Applied to each direct dependency type of an item; `PredicateResult::No` vetoes.
predicate: F,
/// Attributes added in `on_emit` to every memoized item whose verdict is not `No`.
attrs: Vec<Attribute>,
}
impl<F> AutoDerivePlugin<F>
where
    F: Fn(&Ty) -> PredicateResult,
{
    /// Creates a plugin that adds `attrs` to items whose dependency types are accepted by `f`.
    pub fn new(attrs: Vec<Attribute>, f: F) -> Self {
        let can_derive = FxHashMap::default();
        Self {
            attrs,
            predicate: f,
            can_derive,
        }
    }
}
/// Verdict of [`AutoDerivePlugin`] for a single item.
///
/// `Ord` is derived, so ordering follows declaration order (`Yes < No < Delay`);
/// do not reorder the variants without checking users of that ordering.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum CanDerive {
/// All dependencies allow the attributes.
Yes,
/// A dependency was rejected, directly by the predicate or transitively.
No,
/// Depends on an item that was still being evaluated (a recursive type). `on_emit`
/// treats this as derivable unless it is later downgraded to `No`.
Delay, }
/// Answer of the user predicate for one dependency type of an item.
///
/// `No` vetoes the attributes for the item owning that type. `GoOn` lets
/// [`AutoDerivePlugin`] continue by checking the items referenced from inside the type.
pub enum PredicateResult {
No, GoOn, }
/// Type visitor that records a clone of every path handed to `visit_path`.
///
/// NOTE(review): overriding `visit_path` may replace the default traversal into a path's
/// children (e.g. generic arguments) — confirm nested paths are still reached.
#[derive(Default)]
pub struct PathCollector {
    paths: Vec<crate::rir::Path>,
}

impl super::ty::Visitor for PathCollector {
    fn visit_path(&mut self, path: &crate::rir::Path) {
        let owned = path.clone();
        self.paths.push(owned);
    }
}
impl<F> AutoDerivePlugin<F>
where
    F: Fn(&Ty) -> PredicateResult,
{
    /// Decides whether `attrs` may be attached to `def_id`, memoizing the verdict in
    /// `self.can_derive`.
    ///
    /// This is a depth-first walk over every type path reachable from the item's fields,
    /// variant fields, or wrapped (new-type) type:
    ///
    /// * `visiting` holds the items whose dependency scan is in progress (the DFS stack).
    ///   Re-entering one of them — a recursive type — yields [`CanDerive::Delay`] instead
    ///   of recursing forever.
    /// * `delayed` collects items whose verdict depended on such a back-edge. When a `No`
    ///   is found, every item in `delayed` is downgraded to `No`.
    ///
    /// Services, consts and modules never take the attributes and are always `No`.
    fn can_derive(
        &mut self,
        cx: &Context,
        def_id: DefId,
        visiting: &mut HashSet<DefId>,
        delayed: &mut HashSet<DefId>,
    ) -> CanDerive {
        if let Some(b) = self.can_derive.get(&def_id) {
            return *b;
        }
        if visiting.contains(&def_id) {
            // Back-edge into an item still on the DFS stack: its verdict is not known yet.
            return CanDerive::Delay;
        }

        let item = cx.expect_item(def_id);
        let deps = match &*item {
            Item::Message(s) => s.fields.iter().map(|f| &f.ty).collect::<Vec<_>>(),
            Item::Enum(e) => e
                .variants
                .iter()
                .flat_map(|v| &v.fields)
                .collect::<Vec<_>>(),
            Item::NewType(t) => vec![&t.ty],
            Item::Service(_) | Item::Const(_) | Item::Mod(_) => {
                // Memoize and return *before* pushing onto `visiting`, so the DFS stack only
                // ever holds items whose dependency scan is actually in progress.
                self.can_derive.insert(def_id, CanDerive::No);
                return CanDerive::No;
            }
        };

        visiting.insert(def_id);

        let can_derive = if deps
            .iter()
            .any(|t| matches!((self.predicate)(t), PredicateResult::No))
        {
            CanDerive::No
        } else {
            // Evaluate every referenced item (no short-circuit) so each one gets memoized.
            let dep_verdicts = deps
                .iter()
                .flat_map(|t| {
                    let mut visitor = PathCollector::default();
                    visitor.visit(t);
                    visitor.paths
                })
                .map(|p| self.can_derive(cx, p.did, visiting, delayed))
                .collect::<Vec<_>>();

            if dep_verdicts.contains(&CanDerive::No) {
                // NOTE(review): conservative — this downgrades every item deferred so far in
                // this walk, including ones that do not depend on `def_id`. A precise answer
                // would need strongly-connected-component tracking.
                for delayed_id in delayed.iter() {
                    self.can_derive.insert(*delayed_id, CanDerive::No);
                }
                CanDerive::No
            } else if dep_verdicts.contains(&CanDerive::Delay) {
                delayed.insert(def_id);
                CanDerive::Delay
            } else {
                CanDerive::Yes
            }
        };

        self.can_derive.insert(def_id, can_derive);
        visiting.remove(&def_id);
        can_derive
    }
}
impl<F> Plugin for AutoDerivePlugin<F>
where
    F: Fn(&Ty) -> PredicateResult,
{
    /// Computes (and memoizes) this item's verdict with fresh DFS state, then continues
    /// the default walk.
    fn on_item(&mut self, cx: &mut Context, def_id: DefId, item: Arc<Item>) {
        let mut visiting = HashSet::default();
        let mut delayed = HashSet::default();
        self.can_derive(cx, def_id, &mut visiting, &mut delayed);
        walk_item(self, cx, def_id, item)
    }

    /// Adds `attrs` to every memoized item whose verdict is not `No`; `Yes` and `Delay`
    /// both count as derivable.
    fn on_emit(&mut self, cx: &mut Context) {
        for (&def_id, &verdict) in self.can_derive.iter() {
            if verdict != CanDerive::No {
                cx.with_adjust(def_id, |adj| adj.add_attrs(&self.attrs));
            }
        }
    }
}
/// A boxed plugin (including `Box<dyn Plugin>`) forwards every hook to its contents.
impl<T> Plugin for Box<T>
where
    T: Plugin + ?Sized,
{
    fn on_item(&mut self, cx: &mut Context, def_id: DefId, item: Arc<Item>) {
        T::on_item(self.deref_mut(), cx, def_id, item)
    }

    fn on_field(&mut self, cx: &mut Context, def_id: DefId, f: Arc<Field>) {
        T::on_field(self.deref_mut(), cx, def_id, f)
    }

    fn on_emit(&mut self, cx: &mut Context) {
        T::on_emit(self.deref_mut(), cx)
    }
}
/// Adds a fixed list of attributes to every message, enum and new-type item.
pub struct WithAttrsPlugin(pub Vec<syn::Attribute>);

impl Plugin for WithAttrsPlugin {
    fn on_item(&mut self, cx: &mut Context, def_id: DefId, item: Arc<Item>) {
        let takes_attrs = matches!(
            &*item,
            Item::Message(_) | Item::Enum(_) | Item::NewType(_)
        );
        if takes_attrs {
            cx.with_adjust(def_id, |adj| adj.add_attrs(&self.0));
        }
        walk_item(self, cx, def_id, item)
    }
}
/// Makes generated types implement `Default`.
///
/// * Messages and new-types get `#[derive(Default)]`.
/// * Non-empty enums get `#[derive(::pilota::derivative::Derivative)]` plus
///   `#[derivative(Default)]`, and their first variant is tagged `#[derivative(Default)]`.
/// * Empty enums and other item kinds are left alone.
pub struct ImplDefaultPlugin;

impl Plugin for ImplDefaultPlugin {
    fn on_item(&mut self, cx: &mut Context, def_id: DefId, item: Arc<Item>) {
        match &*item {
            Item::Message(_) | Item::NewType(_) => {
                cx.with_adjust(def_id, |adj| {
                    adj.add_attrs(&[parse_quote!(#[derive(Default)])])
                });
            }
            Item::Enum(e) => {
                // `first()` is `Some` exactly when the enum has at least one variant.
                if let Some(first_variant) = e.variants.first() {
                    cx.with_adjust(def_id, |adj| {
                        adj.add_attrs(&[
                            parse_quote!(#[derive(::pilota::derivative::Derivative)]),
                            parse_quote!(#[derivative(Default)]),
                        ]);
                    });
                    cx.with_adjust(first_variant.did, |adj| {
                        adj.add_attrs(&[parse_quote!(#[derivative(Default)])]);
                    });
                }
            }
            _ => {}
        }
        walk_item(self, cx, def_id, item)
    }
}
/// Generates integer conversions for enums that declare a `repr`.
pub struct EnumNumPlugin;

impl Plugin for EnumNumPlugin {
    /// For an enum with `repr: Some(..)`, adds to the enum:
    ///
    /// * `impl From<Enum> for <int>` (an `as` cast), and
    /// * `impl TryFrom<<int>> for Enum`, which maps each variant's discriminant back to the
    ///   variant and returns `EnumConvertError::InvalidNum(value, "<Enum>")` otherwise.
    ///
    /// Every item, including enums, then continues through [`walk_item`].
    fn on_item(&mut self, cx: &mut Context, def_id: DefId, item: Arc<Item>) {
        if let Item::Enum(e) = &*item {
            if let Some(repr) = &e.repr {
                // Integer type named by the repr. It is the `From` target *and* the `TryFrom`
                // input, so the generated `try_from` parameter must use it too (it used to
                // be hard-coded to `i32`).
                let num_ty = match repr {
                    EnumRepr::I32 => quote!(i32),
                };
                let name_str = &*cx.rust_name(def_id);
                let name = name_str.as_syn_ident();
                // One match arm per variant: `Variant => Ok(Enum::Variant),`. The pattern
                // refers to the same-named local const emitted from `nums` below.
                let variants = e
                    .variants
                    .iter()
                    .map(|v| {
                        let variant_name_str = cx.rust_name(v.did);
                        let variant_name = variant_name_str.as_syn_ident();
                        quote!(
                            #variant_name => ::std::result::Result::Ok(#name::#variant_name),
                        )
                    })
                    .collect_vec();
                // Local consts holding each discriminant, so they can appear as match
                // patterns (a cast expression is not allowed directly in a pattern).
                let nums = e
                    .variants
                    .iter()
                    .map(|v| {
                        let variant_name_str = cx.rust_name(v.did);
                        let variant_name = variant_name_str.as_syn_ident();
                        quote!(const #variant_name: #num_ty = #name::#variant_name as #num_ty;)
                    })
                    .collect_vec();
                cx.with_adjust(def_id, |adj| {
                    adj.add_impl(quote! {
                        impl ::std::convert::From<#name> for #num_ty {
                            fn from(e: #name) -> Self {
                                e as _
                            }
                        }
                        impl ::std::convert::TryFrom<#num_ty> for #name {
                            type Error = ::pilota::EnumConvertError<#num_ty>;
                            #[allow(non_upper_case_globals)]
                            fn try_from(v: #num_ty) -> ::std::result::Result<Self, ::pilota::EnumConvertError<#num_ty>> {
                                #(#nums)*
                                match v {
                                    #(
                                        #variants
                                    )*
                                    _ => ::std::result::Result::Err(::pilota::EnumConvertError::InvalidNum(v, #name_str)),
                                }
                            }
                        }
                    })
                });
            }
        }
        walk_item(self, cx, def_id, item)
    }
}