use std::{
marker::PhantomData,
ops::{Add, Mul, Neg, Sub},
};
use melior::{Context, ContextRef, ir::Attribute};
use mlir_sys::{
MlirAffineExpr, MlirAffineMap, mlirAffineAddExprGet, mlirAffineConstantExprGet,
mlirAffineDimExprGet, mlirAffineExprGetContext, mlirAffineMapAttrGet, mlirAffineMapGet,
mlirAffineMulExprGet, mlirAffineSymbolExprGet,
};
/// An MLIR affine map whose lifetime is tied to the [`Context`] it was created in.
///
/// Only the raw C-API handle is stored; the `'ctx` borrow is carried by `PhantomData`,
/// so the type is a cheap `Copy` handle.
#[derive(Clone, Copy, Debug)]
pub struct AffineMap<'ctx> {
    /// Raw handle returned by the MLIR C API.
    raw: MlirAffineMap,
    /// Ties this handle to the lifetime of the owning context.
    _context: PhantomData<&'ctx Context>,
}
impl<'ctx> AffineMap<'ctx> {
pub fn new(
context: &'ctx Context,
dims: usize,
symbols: usize,
exprs: &[AffineExpr<'ctx>],
) -> Self {
let mut exprs = exprs
.iter()
.map(|expr| unsafe { expr.to_raw() })
.collect::<Vec<_>>();
Self {
raw: unsafe {
mlirAffineMapGet(
context.to_raw(),
dims as isize,
symbols as isize,
exprs.len() as isize,
exprs.as_mut_ptr(),
)
},
_context: PhantomData,
}
}
pub unsafe fn to_raw(&self) -> MlirAffineMap {
self.raw
}
}
impl<'ctx> From<AffineMap<'ctx>> for Attribute<'ctx> {
    /// Wraps the map in an MLIR affine-map attribute created by `mlirAffineMapAttrGet`.
    fn from(map: AffineMap<'ctx>) -> Self {
        let raw_attribute = unsafe { mlirAffineMapAttrGet(map.raw) };
        unsafe { Attribute::from_raw(raw_attribute) }
    }
}
impl<'ctx> From<AffineMap<'ctx>> for melior::ir::AffineMap<'ctx> {
    /// Re-wraps the same raw MLIR handle as `melior`'s own affine-map type.
    fn from(map: AffineMap<'ctx>) -> Self {
        let raw_map = map.raw;
        unsafe { melior::ir::AffineMap::from_raw(raw_map) }
    }
}
/// An MLIR affine expression whose lifetime is tied to the [`Context`] it was created in.
///
/// Only the raw C-API handle is stored; the `'ctx` borrow is carried by `PhantomData`,
/// so the type is a cheap `Copy` handle that can be combined with `+`, `-`, `*` and
/// unary `-`.
#[derive(Clone, Copy, Debug)]
pub struct AffineExpr<'ctx> {
    /// Raw handle returned by the MLIR C API.
    raw: MlirAffineExpr,
    /// Ties this handle to the lifetime of the owning context.
    _context: PhantomData<&'ctx Context>,
}
impl<'ctx> AffineExpr<'ctx> {
    /// Returns the constant affine expression with value `constant`, created in `context`.
    pub fn constant(context: &'ctx Context, constant: i64) -> Self {
        let raw = unsafe { mlirAffineConstantExprGet(context.to_raw(), constant) };
        Self {
            raw,
            _context: PhantomData,
        }
    }

    /// Returns the affine expression for the symbol at index `position`, created in `context`.
    ///
    /// `position` is converted to `isize` with `as` before being passed to MLIR.
    pub fn symbol(context: &'ctx Context, position: usize) -> Self {
        let index = position as isize;
        let raw = unsafe { mlirAffineSymbolExprGet(context.to_raw(), index) };
        Self {
            raw,
            _context: PhantomData,
        }
    }

    /// Returns the affine expression for the dimension at index `position`, created in
    /// `context`.
    ///
    /// `position` is converted to `isize` with `as` before being passed to MLIR.
    pub fn dimension(context: &'ctx Context, position: usize) -> Self {
        let index = position as isize;
        let raw = unsafe { mlirAffineDimExprGet(context.to_raw(), index) };
        Self {
            raw,
            _context: PhantomData,
        }
    }

    /// Returns the underlying MLIR C-API handle.
    ///
    /// # Safety
    ///
    /// The returned `MlirAffineExpr` carries no lifetime, so the type system no longer ties
    /// it to `'ctx`. The caller must not use it after the owning context has been destroyed.
    pub unsafe fn to_raw(&self) -> MlirAffineExpr {
        self.raw
    }

    /// Returns the context this expression was created in, as a [`ContextRef`].
    pub fn context(&self) -> ContextRef<'ctx> {
        let raw_context = unsafe { mlirAffineExprGetContext(self.raw) };
        unsafe { ContextRef::from_raw(raw_context) }
    }
}
impl Add for AffineExpr<'_> {
    type Output = Self;

    /// Builds the affine sum `self + rhs` with `mlirAffineAddExprGet`.
    fn add(self, rhs: Self) -> Self::Output {
        let sum = unsafe { mlirAffineAddExprGet(self.raw, rhs.raw) };
        Self {
            raw: sum,
            _context: PhantomData,
        }
    }
}
impl Sub for AffineExpr<'_> {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
self + (-rhs)
}
}
impl Mul for AffineExpr<'_> {
    type Output = Self;

    /// Builds the affine product `self * rhs` with `mlirAffineMulExprGet`.
    fn mul(self, rhs: Self) -> Self::Output {
        let product = unsafe { mlirAffineMulExprGet(self.raw, rhs.raw) };
        Self {
            raw: product,
            _context: PhantomData,
        }
    }
}
impl Neg for AffineExpr<'_> {
type Output = Self;
fn neg(self) -> Self::Output {
let context = self.context();
self * Self::constant(unsafe { context.to_ref() }, -1)
}
}