use core::ptr::{DynMetadata, Pointee};
use smallvec::SmallVec;
use crate::{
Attribute, AttributeRef, AttributeRegistration, Op, OpFoldResult, OpOperand, Operation,
OperationRef, UnsafeIntrusiveEntityRef, ValueRef, attributes::AttributeValue,
};
/// A pattern over values of type `T`.
///
/// A successful match produces a [`Matcher::Matched`] value, which is whatever data the
/// matcher chooses to bind from the subject. `None` means the subject did not match.
pub trait Matcher<T: ?Sized> {
    /// The data bound by a successful match.
    type Matched;

    /// Tries to match `entity`, returning the bound data if it matches and `None` otherwise.
    fn matches(&self, entity: &T) -> Option<<Self as Matcher<T>>::Matched>;
}
#[repr(transparent)]
pub struct MatchWith<F>(pub F);
impl<F, T: ?Sized, U> Matcher<T> for MatchWith<F>
where
F: Fn(&T) -> Option<U>,
{
type Matched = U;
#[inline(always)]
fn matches(&self, entity: &T) -> Option<Self::Matched> {
(self.0)(entity)
}
}
/// A matcher that succeeds only when both `a` and `b` match the same subject.
///
/// `a` is tried first, and `b` is only tried if `a` matched, so `a` can serve as a cheap guard
/// for a more expensive `b`. The value bound by `a` is discarded; the result is `b`'s binding.
pub struct AndMatcher<A, B> {
    a: A,
    b: B,
}

impl<A, B> AndMatcher<A, B> {
    /// Creates a matcher requiring both `a` and `b` to match.
    pub const fn new(a: A, b: B) -> Self {
        Self { a, b }
    }
}

// `T: ?Sized` mirrors the bound on `Matcher` itself, so this combinator also works for
// unsized subjects such as `dyn Attribute` (previously `T` was implicitly `Sized`).
impl<T, A, B> Matcher<T> for AndMatcher<A, B>
where
    T: ?Sized,
    A: Matcher<T>,
    B: Matcher<T>,
{
    type Matched = <B as Matcher<T>>::Matched;

    #[inline]
    fn matches(&self, entity: &T) -> Option<Self::Matched> {
        // Only the success of `a` matters; its bound value is dropped.
        self.a.matches(entity)?;
        self.b.matches(entity)
    }
}
/// A matcher that feeds the value bound by `a` into `b`.
///
/// `b` runs only if `a` matched. The chain's result is whatever `b` binds from `a`'s output.
pub struct ChainMatcher<A, B> {
    a: A,
    b: B,
}

impl<A, B> ChainMatcher<A, B> {
    /// Creates a matcher that runs `b` on the result of `a`.
    pub const fn new(a: A, b: B) -> Self {
        Self { a, b }
    }
}

// `T: ?Sized` mirrors the bound on `Matcher`, allowing chains rooted at unsized subjects such
// as `dyn Attribute`. `U` stays `Sized` because it is returned inside an `Option`.
impl<T, U, A, B> Matcher<T> for ChainMatcher<A, B>
where
    T: ?Sized,
    A: Matcher<T, Matched = U>,
    B: Matcher<U>,
{
    type Matched = <B as Matcher<U>>::Matched;

    #[inline]
    fn matches(&self, entity: &T) -> Option<Self::Matched> {
        let intermediate = self.a.matches(entity)?;
        self.b.matches(&intermediate)
    }
}
/// Matches operations that can be viewed as the trait object type `Trait`, binding a
/// reference to the operation through that trait.
///
/// `Trait` must be a `dyn Trait`-style type: its pointer metadata is a `DynMetadata`.
pub struct OpTraitMatcher<Trait: ?Sized> {
    _marker: core::marker::PhantomData<Trait>,
}

impl<Trait> OpTraitMatcher<Trait>
where
    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
{
    /// Creates a matcher for operations implementing `Trait`.
    pub const fn new() -> Self {
        Self {
            _marker: core::marker::PhantomData,
        }
    }
}

impl<Trait> Default for OpTraitMatcher<Trait>
where
    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
{
    #[inline]
    fn default() -> Self {
        Self {
            _marker: core::marker::PhantomData,
        }
    }
}

impl<Trait> Matcher<Operation> for OpTraitMatcher<Trait>
where
    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
{
    type Matched = UnsafeIntrusiveEntityRef<Trait>;

    fn matches(&self, entity: &Operation) -> Option<Self::Matched> {
        // NOTE(review): presumably `as_trait_ref` yields `None` when the op does not
        // implement `Trait`; confirm against its definition.
        let op = entity.as_operation_ref();
        op.as_trait_ref::<Trait>()
    }
}
pub struct HasTraitMatcher<Trait: ?Sized> {
_marker: core::marker::PhantomData<Trait>,
}
impl<Trait> Default for HasTraitMatcher<Trait>
where
Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
{
#[inline]
fn default() -> Self {
Self::new()
}
}
impl<Trait> HasTraitMatcher<Trait>
where
Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
{
pub const fn new() -> Self {
Self {
_marker: core::marker::PhantomData,
}
}
}
impl<Trait> Matcher<Operation> for HasTraitMatcher<Trait>
where
Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
{
type Matched = OperationRef;
fn matches(&self, entity: &Operation) -> Option<Self::Matched> {
if !entity.implements::<Trait>() {
return None;
}
Some(entity.as_operation_ref())
}
}
/// Matches operations that carry an attribute called `name`, then applies `matcher` to that
/// attribute.
///
/// An op without the attribute does not match. Otherwise the result is whatever `matcher`
/// binds from the attribute.
pub struct OpAttrMatcher<M> {
    name: &'static str,
    matcher: M,
}

impl<M> OpAttrMatcher<M>
where
    M: Matcher<dyn Attribute>,
{
    /// Creates a matcher that checks attribute `name` using `matcher`.
    pub const fn new(name: &'static str, matcher: M) -> Self {
        Self { name, matcher }
    }
}

impl<M> Matcher<Operation> for OpAttrMatcher<M>
where
    M: Matcher<dyn Attribute>,
{
    type Matched = <M as Matcher<dyn Attribute>>::Matched;

    fn matches(&self, entity: &Operation) -> Option<Self::Matched> {
        let attr = entity.get_attribute(self.name)?;
        // Keep the borrow in a named local so it outlives the call to the inner matcher.
        let attr = attr.borrow();
        self.matcher.matches(&*attr)
    }
}
/// An [`OpAttrMatcher`] whose attribute must have concrete type `A`.
pub type TypedOpAttrMatcher<A> = OpAttrMatcher<TypedAttrMatcher<A>>;

/// Matches an attribute whose concrete type is `A`, binding a clone of it.
pub struct TypedAttrMatcher<A>(core::marker::PhantomData<A>);

impl<A: AttributeRegistration + Clone> TypedAttrMatcher<A> {
    /// Creates a matcher for attributes of type `A`.
    ///
    /// Unlike [`Default::default`], this is a `const fn`, consistent with the other matchers in
    /// this module. Together with the const [`OpAttrMatcher::new`], it allows a
    /// [`TypedOpAttrMatcher`] to be built in a `const` context. That was impossible before,
    /// because the tuple field is private.
    pub const fn new() -> Self {
        Self(core::marker::PhantomData)
    }
}

impl<A: AttributeRegistration + Clone> Default for TypedAttrMatcher<A> {
    #[inline(always)]
    fn default() -> Self {
        Self::new()
    }
}

impl<A: AttributeRegistration + Clone> Matcher<dyn Attribute> for TypedAttrMatcher<A> {
    type Matched = A;

    #[inline]
    fn matches(&self, entity: &dyn Attribute) -> Option<Self::Matched> {
        // A failed downcast (attribute of another type) is a non-match.
        entity.downcast_ref::<A>().cloned()
    }
}
/// Accepts every operation unconditionally and binds an [`OperationRef`] to it.
struct AnyOpMatcher;

impl Matcher<Operation> for AnyOpMatcher {
    type Matched = OperationRef;

    #[inline(always)]
    fn matches(&self, entity: &Operation) -> Option<OperationRef> {
        let op = entity.as_operation_ref();
        Some(op)
    }
}
struct OneOpMatcher<T>(core::marker::PhantomData<T>);
impl<T: Op> OneOpMatcher<T> {
pub const fn new() -> Self {
Self(core::marker::PhantomData)
}
}
impl<T: Op> Matcher<Operation> for OneOpMatcher<T> {
type Matched = UnsafeIntrusiveEntityRef<T>;
#[inline(always)]
fn matches(&self, entity: &Operation) -> Option<Self::Matched> {
entity.as_operation_ref().try_downcast_op::<T>().ok()
}
}
/// Accepts every value unconditionally and binds a copy of it.
struct AnyValueMatcher;

impl Matcher<ValueRef> for AnyValueMatcher {
    type Matched = ValueRef;

    #[inline(always)]
    fn matches(&self, entity: &ValueRef) -> Option<ValueRef> {
        Some(entity).copied()
    }
}
/// Matches only the one value it was constructed with.
///
/// Equality is pointer identity (`ValueRef::ptr_eq`).
struct ExactValueMatcher(ValueRef);

impl Matcher<ValueRef> for ExactValueMatcher {
    type Matched = ValueRef;

    #[inline(always)]
    fn matches(&self, entity: &ValueRef) -> Option<ValueRef> {
        let Self(expected) = self;
        ValueRef::ptr_eq(expected, entity).then_some(*entity)
    }
}
/// Matches any op implementing `ConstantLike`, binding the op itself.
type ConstantOpMatcher = HasTraitMatcher<dyn crate::traits::ConstantLike>;

/// Matches a `ConstantLike` op and binds the attribute produced by folding it.
///
/// The op is folded, and the last fold result is bound only if it is an attribute.
///
/// # Panics
///
/// A constant-like op whose fold fails is treated as a broken invariant: the `unwrap_or_else`
/// on the fold result panics, naming the op.
#[derive(Default)]
struct ConstantOpBinder;

impl Matcher<Operation> for ConstantOpBinder {
    type Matched = AttributeRef;

    fn matches(&self, entity: &Operation) -> Option<AttributeRef> {
        use crate::traits::Foldable;

        if entity.implements::<dyn crate::traits::ConstantLike>() {
            let mut results = SmallVec::default();
            entity.fold(&mut results).unwrap_or_else(|| {
                panic!("expected constant-like op '{}' to be foldable", entity.name())
            });
            match results.pop() {
                Some(OpFoldResult::Attribute(value)) => Some(value),
                _ => None,
            }
        } else {
            None
        }
    }
}
/// Binds the attribute folded from a constant-like op, viewed through the trait object type
/// `Trait`.
///
/// The op does not match if it is not constant-like, or if the attribute cannot be viewed as
/// `Trait`.
struct ImplementsConstantOpBinder<Trait: ?Sized + 'static>(
    core::marker::PhantomData<&'static Trait>,
);

impl<Trait: ?Sized + 'static> ImplementsConstantOpBinder<Trait> {
    /// Creates a binder for attributes implementing `Trait`.
    pub const fn new() -> Self {
        Self(core::marker::PhantomData)
    }
}

impl<Trait> Matcher<Operation> for ImplementsConstantOpBinder<Trait>
where
    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
{
    type Matched = UnsafeIntrusiveEntityRef<Trait>;

    fn matches(&self, entity: &Operation) -> Option<Self::Matched> {
        let folded = ConstantOpBinder.matches(entity)?;
        folded.as_trait_ref::<Trait>()
    }
}
struct TypedConstantOpBinder<T>(core::marker::PhantomData<T>);
impl<T> TypedConstantOpBinder<T> {
pub const fn new() -> Self {
Self(core::marker::PhantomData)
}
}
impl<T: AttributeValue + Clone> Matcher<Operation> for TypedConstantOpBinder<T> {
type Matched = T;
fn matches(&self, entity: &Operation) -> Option<Self::Matched> {
ConstantOpBinder.matches(entity).and_then(|attr| {
let attr = attr.borrow();
attr.value().downcast_ref::<T>().cloned()
})
}
}
/// Matches an op implementing `UnaryOp` and binds its first (sole) operand.
#[derive(Default)]
struct UnaryOpBinder;

impl Matcher<Operation> for UnaryOpBinder {
    type Matched = OpOperand;

    fn matches(&self, entity: &Operation) -> Option<OpOperand> {
        let is_unary = entity.implements::<dyn crate::traits::UnaryOp>();
        // NOTE(review): indexing assumes a `UnaryOp` always has at least one operand
        // (presumably enforced by the trait); it panics otherwise.
        is_unary.then(|| entity.operands()[0].borrow().as_operand_ref())
    }
}
/// Matches an op implementing `BinaryOp` and binds its first two operands as `[lhs, rhs]`.
#[derive(Default)]
struct BinaryOpBinder;

impl Matcher<Operation> for BinaryOpBinder {
    type Matched = [OpOperand; 2];

    fn matches(&self, entity: &Operation) -> Option<[OpOperand; 2]> {
        if !entity.implements::<dyn crate::traits::BinaryOp>() {
            return None;
        }
        let operands = entity.operands();
        // NOTE(review): assumes a `BinaryOp` has at least two operands; indexing panics
        // otherwise.
        let pair = [&operands[0], &operands[1]];
        let bound = pair.map(|operand| operand.borrow().as_operand_ref());
        Some(bound)
    }
}
/// Binds an operand as an [`OpFoldResult`], and always matches.
///
/// If the operand is defined by an op that [`constant`] matches, the folded attribute is bound
/// as `OpFoldResult::Attribute`. Otherwise, including when the operand has no defining op, the
/// operand's value is bound as `OpFoldResult::Value`.
struct FoldResultBinder;

impl Matcher<OpOperand> for FoldResultBinder {
    type Matched = OpFoldResult;

    fn matches(&self, operand: &OpOperand) -> Option<OpFoldResult> {
        let operand = operand.borrow();
        let defining_op = operand.value().get_defining_op();
        let constant_value = defining_op.and_then(|op| constant().matches(&op.borrow()));
        Some(match constant_value {
            Some(attr) => OpFoldResult::Attribute(attr),
            None => OpFoldResult::Value(operand.as_value_ref()),
        })
    }
}
/// Binds both operands of a binary op as [`OpFoldResult`]s, in `[lhs, rhs]` order.
///
/// Each operand is bound independently by [`FoldResultBinder`]. An operand defined by a
/// constant becomes `OpFoldResult::Attribute`; any other operand becomes
/// `OpFoldResult::Value`.
struct BinaryFoldResultBinder;

impl Matcher<[OpOperand; 2]> for BinaryFoldResultBinder {
    type Matched = [OpFoldResult; 2];

    fn matches(&self, operands: &[OpOperand; 2]) -> Option<Self::Matched> {
        let [lhs, rhs] = operands;
        // `FoldResultBinder` currently always succeeds. Propagating with `?` instead of
        // `unwrap()` (as `FoldableBinaryOpBinder` does) removes the panic path: if that ever
        // changes, this binder reports a non-match instead of aborting.
        let lhs = FoldResultBinder.matches(lhs)?;
        let rhs = FoldResultBinder.matches(rhs)?;
        Some([lhs, rhs])
    }
}
/// Matches an operand whose defining op is matched by [`constant`], binding the folded
/// attribute.
///
/// An operand with no defining op (e.g. a block argument) does not match.
struct FoldableOperandBinder;

impl Matcher<OpOperand> for FoldableOperandBinder {
    type Matched = AttributeRef;

    fn matches(&self, operand: &OpOperand) -> Option<AttributeRef> {
        let operand = operand.borrow();
        let Some(defining_op) = operand.value().get_defining_op() else {
            return None;
        };
        let defining_op = defining_op.borrow();
        constant().matches(&defining_op)
    }
}
struct TypedFoldableOperandBinder<T>(core::marker::PhantomData<T>);
impl<T> Default for TypedFoldableOperandBinder<T> {
fn default() -> Self {
Self(core::marker::PhantomData)
}
}
impl<T: AttributeRegistration> Matcher<OpOperand> for TypedFoldableOperandBinder<T> {
type Matched = UnsafeIntrusiveEntityRef<T>;
fn matches(&self, operand: &OpOperand) -> Option<Self::Matched> {
FoldableOperandBinder
.matches(operand)
.and_then(|value| value.try_downcast_attr::<T>().ok())
}
}
struct ImplementsFoldableOperandBinder<Trait: ?Sized + 'static>(
core::marker::PhantomData<&'static Trait>,
);
impl<Trait: ?Sized + 'static> Default for ImplementsFoldableOperandBinder<Trait> {
fn default() -> Self {
Self(core::marker::PhantomData)
}
}
impl<Trait> Matcher<OpOperand> for ImplementsFoldableOperandBinder<Trait>
where
Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
{
type Matched = UnsafeIntrusiveEntityRef<Trait>;
fn matches(&self, operand: &OpOperand) -> Option<Self::Matched> {
FoldableOperandBinder
.matches(operand)
.and_then(|attr| attr.as_trait_ref::<Trait>())
}
}
/// Matches a pair of operands only when both are foldable constants, binding both attributes
/// as `[lhs, rhs]`.
///
/// `lhs` is checked first; `rhs` is not examined if `lhs` fails to match.
struct FoldableBinaryOpBinder;

impl Matcher<[OpOperand; 2]> for FoldableBinaryOpBinder {
    type Matched = [AttributeRef; 2];

    fn matches(&self, operands: &[OpOperand; 2]) -> Option<[AttributeRef; 2]> {
        let [lhs, rhs] = operands;
        Some([
            FoldableOperandBinder.matches(lhs)?,
            FoldableOperandBinder.matches(rhs)?,
        ])
    }
}
pub const fn match_both<A, B>(
a: A,
b: B,
) -> impl Matcher<Operation, Matched = <B as Matcher<Operation>>::Matched>
where
A: Matcher<Operation>,
B: Matcher<Operation>,
{
AndMatcher::new(a, b)
}
pub const fn match_chain<T, A, B>(
a: A,
b: B,
) -> impl Matcher<Operation, Matched = <B as Matcher<T>>::Matched>
where
A: Matcher<Operation, Matched = T>,
B: Matcher<T>,
{
ChainMatcher::new(a, b)
}
pub const fn match_any() -> impl Matcher<Operation, Matched = OperationRef> {
AnyOpMatcher
}
pub const fn match_op<T: Op>() -> impl Matcher<Operation, Matched = UnsafeIntrusiveEntityRef<T>> {
OneOpMatcher::<T>::new()
}
pub const fn constant_like() -> impl Matcher<Operation, Matched = OperationRef> {
ConstantOpMatcher::new()
}
pub const fn constant() -> impl Matcher<Operation, Matched = AttributeRef> {
ConstantOpBinder
}
pub const fn constant_of<T: AttributeValue + Clone>() -> impl Matcher<Operation, Matched = T> {
TypedConstantOpBinder::<T>::new()
}
pub const fn constant_of_trait<Trait>()
-> impl Matcher<Operation, Matched = UnsafeIntrusiveEntityRef<Trait>>
where
Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
{
ImplementsConstantOpBinder::<Trait>::new()
}
/// Returns a matcher for `UnaryOp` ops that binds their operand.
pub const fn unary() -> impl Matcher<Operation, Matched = OpOperand> {
    UnaryOpBinder
}

/// Returns a matcher for `UnaryOp` ops that binds their operand as an [`OpFoldResult`]:
/// an attribute if the operand is a foldable constant, otherwise the operand's value.
pub const fn unary_fold_result() -> impl Matcher<Operation, Matched = OpFoldResult> {
    ChainMatcher::new(UnaryOpBinder, FoldResultBinder)
}

/// Returns a matcher for `UnaryOp` ops whose operand is a foldable constant, binding the
/// folded attribute.
pub const fn unary_foldable() -> impl Matcher<Operation, Matched = AttributeRef> {
    ChainMatcher::new(UnaryOpBinder, FoldableOperandBinder)
}

/// Returns a matcher for `BinaryOp` ops that binds both operands as `[lhs, rhs]`.
pub const fn binary() -> impl Matcher<Operation, Matched = [OpOperand; 2]> {
    BinaryOpBinder
}

/// Returns a matcher for `BinaryOp` ops that binds both operands as [`OpFoldResult`]s.
pub const fn binary_fold_results() -> impl Matcher<Operation, Matched = [OpFoldResult; 2]> {
    ChainMatcher::new(BinaryOpBinder, BinaryFoldResultBinder)
}

/// Returns a matcher for `BinaryOp` ops whose operands are both foldable constants, binding
/// both folded attributes.
pub const fn binary_foldable() -> impl Matcher<Operation, Matched = [AttributeRef; 2]> {
    ChainMatcher::new(BinaryOpBinder, FoldableBinaryOpBinder)
}
/// Returns a matcher that accepts any value and binds it.
pub const fn match_any_value() -> impl Matcher<ValueRef, Matched = ValueRef> {
    AnyValueMatcher
}

/// Returns a matcher that accepts only `value`, compared by pointer identity.
pub const fn match_value(value: ValueRef) -> impl Matcher<ValueRef, Matched = ValueRef> {
    let expected = value;
    ExactValueMatcher(expected)
}
/// Returns a matcher for operands defined by a foldable constant op, binding the folded
/// attribute.
pub const fn foldable_operand() -> impl Matcher<OpOperand, Matched = AttributeRef> {
    FoldableOperandBinder
}

/// Like [`foldable_operand`], but the folded attribute must have concrete type `T`.
pub const fn foldable_operand_of<T>()
-> impl Matcher<OpOperand, Matched = UnsafeIntrusiveEntityRef<T>>
where
    T: AttributeRegistration,
{
    TypedFoldableOperandBinder::<T>(core::marker::PhantomData)
}

/// Like [`foldable_operand`], but binds the folded attribute viewed through the trait object
/// type `Trait`.
pub const fn foldable_operand_of_trait<Trait>()
-> impl Matcher<OpOperand, Matched = UnsafeIntrusiveEntityRef<Trait>>
where
    Trait: ?Sized + Pointee<Metadata = DynMetadata<Trait>> + 'static,
{
    ImplementsFoldableOperandBinder::<Trait>(core::marker::PhantomData)
}
// Unit tests for the matchers in this module.
//
// Most tests share the IR built by `setup` (at the bottom of this module):
//   - `lhs`: the result of a `u32` constant op with value 1,
//   - `rhs`: the function's first block argument (it has no defining op),
//   - `sum`: the result of `add(lhs, rhs)`, which the function returns.
#[cfg(test)]
mod tests {
use super::*;
use crate::{
dialects::{builtin::*, test::*},
testing::Test,
*,
};
/// `match_any_value` binds every value it is given, unchanged.
#[test]
fn matcher_match_any_value() {
let mut test = Test::default();
let (lhs, rhs, sum) = setup("matcher_match_any_value", &mut test);
for value in [&lhs, &rhs, &sum] {
assert_eq!(match_any_value().matches(value).as_ref(), Some(value));
}
}
/// `match_value(v)` accepts `v` itself and binds it (only the positive case is checked).
#[test]
fn matcher_match_value() {
let mut test = Test::default();
let (lhs, rhs, sum) = setup("matcher_match_value", &mut test);
for value in [&lhs, &rhs, &sum] {
assert_eq!(match_value(*value).matches(value).as_ref(), Some(value));
}
}
/// `match_any` accepts both the constant op and the add op, binding each op itself.
#[test]
fn matcher_match_any() {
let mut test = Test::default();
let (lhs, _rhs, sum) = setup("matcher_match_any", &mut test);
let lhs_op = lhs.borrow().get_defining_op().unwrap();
let sum_op = sum.borrow().get_defining_op().unwrap();
for op in [&lhs_op, &sum_op] {
assert_eq!(match_any().matches(&op.borrow()).as_ref(), Some(op));
}
}
/// `match_op::<T>` accepts only ops whose concrete type is `T`.
#[test]
fn matcher_match_op() {
let mut test = Test::default();
let (lhs, rhs, sum) = setup("matcher_match_op", &mut test);
let lhs_op = lhs.borrow().get_defining_op().unwrap();
let sum_op = sum.borrow().get_defining_op().unwrap();
// `rhs` is a block argument, so it has no defining op.
assert!(rhs.borrow().get_defining_op().is_none());
assert!(match_op::<Constant>().matches(&lhs_op.borrow()).is_some());
assert!(match_op::<Constant>().matches(&sum_op.borrow()).is_none());
assert!(match_op::<Add>().matches(&lhs_op.borrow()).is_none());
assert!(match_op::<Add>().matches(&sum_op.borrow()).is_some());
}
/// `match_both` matches only when both sub-matchers accept the same op.
#[test]
fn matcher_match_both() {
let mut test = Test::default();
let (lhs, _rhs, _sum) = setup("matcher_match_both", &mut test);
let lhs_op = lhs.borrow().get_defining_op().unwrap();
// The first matcher fails: the constant op is not an `Add`.
assert!(
match_both(match_op::<Add>(), constant_of::<Immediate>())
.matches(&lhs_op.borrow())
.is_none()
);
// The second matcher fails: the constant's value is not a `bool`.
assert!(
match_both(constant_like(), constant_of::<bool>())
.matches(&lhs_op.borrow())
.is_none()
);
// Both matchers succeed.
assert!(
match_both(constant_like(), constant_of::<Immediate>())
.matches(&lhs_op.borrow())
.is_some()
);
}
/// `binary_fold_results` (built with `match_chain`) binds both operands of the add:
/// the constant `lhs` as a folded attribute, and the block-argument `rhs` as a plain value.
#[test]
fn matcher_match_chain() {
let mut test = Test::default();
let (_, rhs, sum) = setup("matcher_match_chain", &mut test);
let sum_op = sum.borrow().get_defining_op().unwrap();
let [lhs_fr, rhs_fr] = binary_fold_results()
.matches(&sum_op.borrow())
.expect("expected to bind both operands of 'add'");
let lhs = match lhs_fr {
OpFoldResult::Attribute(attr) => attr.borrow().as_immediate().unwrap(),
OpFoldResult::Value(v) => panic!("expected immediate, got {v}"),
};
assert_eq!(lhs, Immediate::U32(1));
assert_eq!(rhs_fr, OpFoldResult::Value(rhs));
}
/// `constant_like` accepts the constant op and rejects the add op.
#[test]
fn matcher_constant_like() {
let mut test = Test::default();
let (lhs, _rhs, sum) = setup("matcher_constant_like", &mut test);
let lhs_op = lhs.borrow().get_defining_op().unwrap();
let sum_op = sum.borrow().get_defining_op().unwrap();
assert!(constant_like().matches(&lhs_op.borrow()).is_some());
assert!(constant_like().matches(&sum_op.borrow()).is_none());
}
/// `constant` binds an attribute for the constant op and rejects the add op.
#[test]
fn matcher_constant() {
let mut test = Test::default();
let (lhs, _rhs, sum) = setup("matcher_constant", &mut test);
let lhs_op = lhs.borrow().get_defining_op().unwrap();
let sum_op = sum.borrow().get_defining_op().unwrap();
assert!(constant().matches(&lhs_op.borrow()).is_some());
assert!(constant().matches(&sum_op.borrow()).is_none());
}
/// `constant_of::<Immediate>` binds the constant's value `U32(1)` and rejects the add op.
#[test]
fn matcher_constant_of() {
let mut test = Test::default();
let (lhs, _rhs, sum) = setup("matcher_constant_of", &mut test);
let lhs_op = lhs.borrow().get_defining_op().unwrap();
let sum_op = sum.borrow().get_defining_op().unwrap();
assert_eq!(constant_of::<Immediate>().matches(&lhs_op.borrow()), Some(Immediate::U32(1)));
assert!(constant_of::<Immediate>().matches(&sum_op.borrow()).is_none());
}
/// Builds `shl(arg0, u32 1)` and checks that the constant shift operand can be bound through
/// the `IntegerLikeAttr` trait, yielding `U32(1)`.
#[test]
fn matcher_foldable_operand_of_trait() {
let mut test = Test::new("matcher_foldable_operand_of_trait", &[Type::U32], &[Type::U32]);
let mut builder = test.function_builder();
let shift = builder.u32(1, SourceSpan::default()).unwrap();
let lhs = builder.current_block().borrow().arguments()[0].upcast();
let result = builder.shl(lhs, shift, SourceSpan::default()).unwrap();
builder.ret(Some(result), SourceSpan::default()).unwrap();
let shl_op = result.borrow().get_defining_op().unwrap();
// Extract the shift operand inside a scope so the op borrow ends before matching.
let operand = {
let shl = shl_op.borrow();
let shl = shl.downcast_ref::<Shl>().unwrap();
shl.shift().as_operand_ref()
};
let matched =
foldable_operand_of_trait::<dyn crate::attributes::IntegerLikeAttr>().matches(&operand);
let matched = matched.expect("expected the shift operand to match as an integer immediate");
assert_eq!(matched.borrow().as_immediate(), Immediate::U32(1));
}
/// Builds function `name` via `test.with_function` with `&[Type::U32]` for both type lists
/// (presumably params then results; confirm against `Test::with_function`). The body is
/// `sum = add(u32 1, arg0); ret sum`.
///
/// Returns `(lhs, rhs, sum)`: the constant's result, the first block argument, and the sum.
fn setup(name: &'static str, test: &mut Test) -> (ValueRef, ValueRef, ValueRef) {
test.with_function(name, &[Type::U32], &[Type::U32]);
let mut builder = test.function_builder();
let lhs = builder.u32(1, SourceSpan::default()).unwrap();
let block = builder.current_block();
let rhs = block.borrow().arguments()[0].upcast();
let sum = builder.add(lhs, rhs, SourceSpan::default()).unwrap();
builder.ret(Some(sum), SourceSpan::default()).unwrap();
(lhs, rhs, sum)
}
}