use syn::{
punctuated::{Iter, Punctuated},
spanned::Spanned,
token, ExprRange, Field, Ident, Token,
};
use quote::ToTokens;
use crate::{
dispatch::{Blueprint, BlueprintsMap},
error::Diagnostic,
utils::UniqueHashId,
};
use super::{ComparablePats, PredicateType, WhereClause, WherePredicate};
mod boilerplate;
mod parse;
mod to_tokens;
/// Comma-separated parameter list used by the `Named` and `Unnamed` variants of [`PatComposite`].
pub type PunctuatedParameters = Punctuated<PatFieldKind, Token![,]>;
/// A parsed penum expression: a list of pattern fragments plus an optional `where` clause.
///
/// `PenumExpr::pattern_to_string` renders the fragments joined with ` | `.
#[derive(Default, Debug)]
pub struct PenumExpr {
    /// Pattern fragments, rendered as alternatives separated by `|`.
    /// May be empty for a `Default`-constructed value.
    pub pattern: Vec<PatFrag>,
    /// Optional `where` clause. Its type predicates are scanned for dispatchable trait bounds
    /// by `PenumExpr::get_blueprints_map`.
    pub clause: Option<WhereClause>,
}
/// One fragment of a penum pattern: an optional leading identifier followed by a parameter group.
#[derive(Debug)]
pub struct PatFrag {
    /// Leading identifier, if one was written.
    /// NOTE(review): presumably the variant name; confirm against `parse.rs`.
    pub ident: Option<Ident>,
    /// The parameter group (braced, parenthesized, unit, or inferred).
    pub group: PatComposite,
}
/// Shape of a pattern fragment's parameter group.
#[derive(Debug)]
pub enum PatComposite {
    /// `{ ... }` group; `delimiter` keeps the brace token (and thus its span).
    Named {
        parameters: PunctuatedParameters,
        delimiter: token::Brace,
    },
    /// `( ... )` group; `delimiter` keeps the parenthesis token (and thus its span).
    Unnamed {
        parameters: PunctuatedParameters,
        delimiter: token::Paren,
    },
    /// A group with no parameters at all.
    Unit,
    /// A group with no parameter list of its own.
    /// NOTE(review): the exact source syntax that produces this is defined in `parse.rs`; confirm.
    Inferred,
}
/// A single entry inside a [`PatComposite`] parameter list.
#[derive(Debug)]
pub enum PatFieldKind {
    /// An inferred parameter. NOTE(review): presumably written as `_`; confirm in `parse.rs`.
    Infer,
    /// A concrete field (optional name plus type).
    Field(Field),
    /// A `..` token standing for any number of parameters.
    Variadic(Token![..]),
    /// A range expression.
    Range(ExprRange),
    /// An empty entry. NOTE(review): the exact syntax producing this is defined in `parse.rs`.
    Nothing,
}
impl PenumExpr {
    /// Renders every pattern fragment as tokens and joins them with ` | `.
    ///
    /// Returns an empty string when there are no fragments (for example, a `Default` value)
    /// instead of panicking. A separator is only inserted once some text has already been
    /// emitted, so leading empty fragments never produce a dangling ` | `.
    pub fn pattern_to_string(&self) -> String {
        let mut rendered = String::new();
        for frag in &self.pattern {
            let text = frag.to_token_stream().to_string();
            if !rendered.is_empty() {
                rendered.push_str(" | ");
            }
            rendered.push_str(&text);
        }
        rendered
    }

    /// Converts this expression into its comparable-pattern form.
    pub fn get_comparable_patterns(&self) -> ComparablePats {
        self.into()
    }

    /// Returns `true` when a `where` clause exists and contains at least one predicate.
    pub fn has_predicates(&self) -> bool {
        matches!(&self.clause, Some(wc) if !wc.predicates.is_empty())
    }

    /// Returns `true` when a `where` clause is present (it may still be empty).
    pub fn has_clause(&self) -> bool {
        self.clause.is_some()
    }

    /// Collects dispatch blueprints from the type predicates of the `where` clause, keyed by
    /// the predicate's bounded type.
    ///
    /// Each dispatchable trait bound is converted with `Blueprint::try_from`. A conversion
    /// failure is recorded in `error` at the bound's span and then skipped. Blueprints for the
    /// same bounded type coming from several predicates are merged into one entry.
    ///
    /// A predicate that yields no blueprint is ignored. It does not discard the blueprints
    /// gathered from the other predicates.
    ///
    /// Returns `None` when there is no clause or no blueprint was collected at all.
    pub fn get_blueprints_map(&self, error: &mut Diagnostic) -> Option<BlueprintsMap> {
        let clause = self.clause.as_ref()?;
        let mut polymap = BlueprintsMap::default();

        for pred in clause.predicates.iter() {
            let WherePredicate::Type(pred_ty) = pred else {
                continue;
            };

            let mut blueprints = Vec::<Blueprint>::new();
            for param_bound in pred_ty.bounds.iter() {
                if let Some(trait_bound) = param_bound.get_dispatchable_trait_bound() {
                    match Blueprint::try_from(trait_bound) {
                        Ok(blueprint) => blueprints.push(blueprint),
                        Err(err) => error.extend(trait_bound.span(), err),
                    }
                }
            }

            // Nothing dispatchable here: move on instead of abandoning the whole map.
            if blueprints.is_empty() {
                continue;
            }

            let ty = UniqueHashId(pred_ty.bounded_ty.clone());
            if let Some(entry) = polymap.get_mut(&ty) {
                entry.append(&mut blueprints);
            } else {
                polymap.insert(ty, blueprints);
            }
        }

        (!polymap.is_empty()).then_some(polymap)
    }

    /// Returns the first type predicate of the `where` clause for which `f` returns `Some`.
    ///
    /// Non-type predicates are skipped. Returns `None` when there is no clause, the clause is
    /// empty, or `f` never matches.
    pub fn find_predicate(
        &self,
        f: impl Fn(&PredicateType) -> Option<&PredicateType>,
    ) -> Option<&PredicateType> {
        self.clause
            .as_ref()?
            .predicates
            .iter()
            .find_map(|pred| match pred {
                WherePredicate::Type(pred_ty) => f(pred_ty),
                _ => None,
            })
    }
}
impl PatFieldKind {
pub fn is_field(&self) -> bool {
matches!(self, PatFieldKind::Field(_))
}
pub fn is_variadic(&self) -> bool {
matches!(self, PatFieldKind::Variadic(_))
}
pub fn is_range(&self) -> bool {
matches!(self, PatFieldKind::Range(_))
}
pub fn is_infer(&self) -> bool {
matches!(self, PatFieldKind::Infer)
}
pub fn get_field(&self) -> Option<&Field> {
match self {
PatFieldKind::Field(field) => Some(field),
_ => None,
}
}
}
impl PatComposite {
pub fn len(&self) -> usize {
match self {
PatComposite::Named { parameters, .. } => parameters.len(),
PatComposite::Unnamed { parameters, .. } => parameters.len(),
_ => 0,
}
}
pub fn iter(&self) -> Iter<'_, PatFieldKind> {
thread_local! {static EMPTY_SLICE_ITER: Punctuated<PatFieldKind, ()> = Punctuated::new();}
match self {
PatComposite::Named { parameters, .. } => parameters.iter(),
PatComposite::Unnamed { parameters, .. } => parameters.iter(),
_ => EMPTY_SLICE_ITER.with(|f| unsafe { std::mem::transmute(f.iter()) }),
}
}
pub fn is_unit(&self) -> bool {
matches!(self, PatComposite::Unit)
}
pub fn has_variadic(&self) -> bool {
match self {
PatComposite::Named { parameters, .. } => parameters.iter().any(|fk| fk.is_variadic()),
PatComposite::Unnamed { parameters, .. } => {
parameters.iter().any(|fk| fk.is_variadic())
}
_ => false,
}
}
pub fn get_variadic_position(&self) -> Option<usize> {
match self {
PatComposite::Named { parameters, .. } => parameters
.iter()
.enumerate()
.find_map(|(pos, fk)| fk.is_variadic().then_some(pos)),
PatComposite::Unnamed { parameters, .. } => parameters
.iter()
.enumerate()
.find_map(|(pos, fk)| fk.is_variadic().then_some(pos)),
_ => None,
}
}
pub fn has_last_variadic(&self) -> bool {
match self {
PatComposite::Named { parameters, .. } => {
matches!(parameters.iter().last().take(), Some(val) if val.is_variadic())
}
PatComposite::Unnamed { parameters, .. } => {
matches!(parameters.iter().last().take(), Some(val) if val.is_variadic())
}
_ => false,
}
}
pub fn count_with(&self, mut f: impl FnMut(&PatFieldKind) -> bool) -> usize {
match self {
PatComposite::Named { parameters, .. } => {
parameters
.iter()
.fold(0, |acc, fk| if f(fk) { acc + 1 } else { acc })
}
PatComposite::Unnamed { parameters, .. } => {
parameters
.iter()
.fold(0, |acc, fk| if f(fk) { acc + 1 } else { acc })
}
_ => 0,
}
}
}