use std::{ffi::c_void, fmt::Display};
use mlir_sys::{
    MlirOperation, MlirValue, MlirWalkOrder_MlirWalkPostOrder, MlirWalkOrder_MlirWalkPreOrder,
    MlirWalkResult, MlirWalkResult_MlirWalkResultAdvance, MlirWalkResult_MlirWalkResultInterrupt,
    MlirWalkResult_MlirWalkResultSkip, mlirBlockGetFirstOperation, mlirOperationDump,
    mlirOperationGetAttribute, mlirOperationGetAttributeByName, mlirOperationGetBlock,
    mlirOperationGetContext, mlirOperationGetDiscardableAttribute,
    mlirOperationGetDiscardableAttributeByName, mlirOperationGetFirstRegion,
    mlirOperationGetInherentAttributeByName, mlirOperationGetLocation, mlirOperationGetName,
    mlirOperationGetNextInBlock, mlirOperationGetNumAttributes,
    mlirOperationGetNumDiscardableAttributes, mlirOperationGetNumOperands,
    mlirOperationGetNumRegions, mlirOperationGetNumResults, mlirOperationGetNumSuccessors,
    mlirOperationGetOperand, mlirOperationGetParentOperation, mlirOperationGetRegion,
    mlirOperationGetResult, mlirOperationGetSuccessor, mlirOperationGetTypeID,
    mlirOperationHasInherentAttributeByName, mlirOperationHashValue,
    mlirOperationImplementsInterface, mlirOperationImplementsInterfaceStatic,
    mlirOperationIsBeforeInBlock, mlirOperationMoveAfter, mlirOperationMoveBefore,
    mlirOperationPrintWithFlags, mlirOperationRemoveAttributeByName,
    mlirOperationRemoveDiscardableAttributeByName, mlirOperationRemoveFromParent,
    mlirOperationReplaceUsesOfWith, mlirOperationSetAttributeByName,
    mlirOperationSetDiscardableAttributeByName, mlirOperationSetInherentAttributeByName,
    mlirOperationSetLocation, mlirOperationSetOperand, mlirOperationSetOperands,
    mlirOperationSetSuccessor, mlirOperationVerify, mlirOperationWalk, mlirOperationWriteBytecode,
    mlirOperationWriteBytecodeWithConfig,
};
use crate::{
    Context, ContextRef, Error, StringRef,
    ir::{
        Attribute, AttributeLike, Block, BlockRef, Identifier, Location, RegionRef, Value,
        bytecode_writer_config::BytecodeWriterConfig, r#type::TypeId, value::ValueLike,
    },
    logical_result::LogicalResult,
};
use super::{
    OperationPrintingFlags, OperationRef, OperationRefMut, OperationResult, collect_bytes_callback,
    print_string_callback,
};
/// Traversal order used when walking an operation and everything nested in it.
///
/// The discriminants are the MLIR C API `MlirWalkOrder` values, so a variant can
/// be handed to `mlirOperationWalk` with a plain `as` cast.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum WalkOrder {
    /// Visit an operation before the operations nested inside it.
    PreOrder = MlirWalkOrder_MlirWalkPreOrder,
    /// Visit an operation after the operations nested inside it.
    PostOrder = MlirWalkOrder_MlirWalkPostOrder,
}
/// Value returned by a walk callback to steer the remaining traversal.
///
/// The discriminants are the MLIR C API `MlirWalkResult` values. The walk
/// trampolines convert a callback's return value with an `as` cast.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum WalkResult {
    /// Continue the walk normally.
    Advance = MlirWalkResult_MlirWalkResultAdvance,
    /// Stop the walk.
    Interrupt = MlirWalkResult_MlirWalkResultInterrupt,
    /// Skip the operations nested in the current one and continue the walk
    /// (MLIR `WalkResult::skip` semantics).
    Skip = MlirWalkResult_MlirWalkResultSkip,
}
/// Read-only interface shared by MLIR operation handles.
///
/// Implementors provide only [`OperationLike::to_raw`]. Every other method is a
/// default wrapper around an MLIR C API call on that raw handle.
///
/// Every `unsafe` block below relies on `to_raw` returning a handle to a live
/// operation.
pub trait OperationLike<'c: 'a, 'a>: Display + 'a {
    /// Returns the raw MLIR operation handle wrapped by this value.
    fn to_raw(&self) -> MlirOperation;

    /// Returns the context the operation belongs to.
    fn context(&self) -> ContextRef<'c> {
        unsafe { ContextRef::from_raw(mlirOperationGetContext(self.to_raw())) }
    }

    /// Returns the operation's name as an identifier.
    fn name(&self) -> Identifier<'c> {
        unsafe { Identifier::from_raw(mlirOperationGetName(self.to_raw())) }
    }

    /// Returns the block containing this operation, or `None` if it is not in a
    /// block.
    fn block(&self) -> Option<BlockRef<'c, 'a>> {
        unsafe { BlockRef::from_option_raw(mlirOperationGetBlock(self.to_raw())) }
    }

    /// Returns the number of operands.
    fn operand_count(&self) -> usize {
        unsafe { mlirOperationGetNumOperands(self.to_raw()) as usize }
    }

    /// Returns the operand at `index`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::PositionOutOfBounds`] if `index >= self.operand_count()`.
    fn operand(&self, index: usize) -> Result<Value<'c, 'a>, Error> {
        if index < self.operand_count() {
            Ok(unsafe { Value::from_raw(mlirOperationGetOperand(self.to_raw(), index as isize)) })
        } else {
            Err(Error::PositionOutOfBounds {
                name: "operation operand",
                value: self.to_string(),
                index,
            })
        }
    }

    /// Iterates over all operands in order.
    fn operands(&self) -> impl Iterator<Item = Value<'c, 'a>> {
        (0..self.operand_count()).map(|index| self.operand(index).expect("valid operand index"))
    }

    /// Returns the number of results.
    fn result_count(&self) -> usize {
        unsafe { mlirOperationGetNumResults(self.to_raw()) as usize }
    }

    /// Returns the result at `index`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::PositionOutOfBounds`] if `index >= self.result_count()`.
    fn result(&self, index: usize) -> Result<OperationResult<'c, 'a>, Error> {
        if index < self.result_count() {
            Ok(unsafe {
                OperationResult::from_raw(mlirOperationGetResult(self.to_raw(), index as isize))
            })
        } else {
            Err(Error::PositionOutOfBounds {
                name: "operation result",
                value: self.to_string(),
                index,
            })
        }
    }

    /// Iterates over all results in order.
    fn results(&self) -> impl Iterator<Item = OperationResult<'c, 'a>> {
        (0..self.result_count()).map(|index| self.result(index).expect("valid result index"))
    }

    /// Returns the number of regions.
    fn region_count(&self) -> usize {
        unsafe { mlirOperationGetNumRegions(self.to_raw()) as usize }
    }

    /// Returns the region at `index`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::PositionOutOfBounds`] if `index >= self.region_count()`.
    fn region(&self, index: usize) -> Result<RegionRef<'c, 'a>, Error> {
        if index < self.region_count() {
            Ok(unsafe {
                RegionRef::from_raw(mlirOperationGetRegion(self.to_raw(), index as isize))
            })
        } else {
            Err(Error::PositionOutOfBounds {
                name: "region",
                value: self.to_string(),
                index,
            })
        }
    }

    /// Iterates over all regions in order.
    fn regions(&self) -> impl Iterator<Item = RegionRef<'c, 'a>> {
        (0..self.region_count()).map(move |index| self.region(index).expect("valid region index"))
    }

    /// Returns the operation's source location.
    fn location(&self) -> Location<'c> {
        unsafe { Location::from_raw(mlirOperationGetLocation(self.to_raw())) }
    }

    /// Returns the number of successor blocks.
    fn successor_count(&self) -> usize {
        unsafe { mlirOperationGetNumSuccessors(self.to_raw()) as usize }
    }

    /// Returns the successor block at `index`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::PositionOutOfBounds`] if `index >= self.successor_count()`.
    fn successor(&self, index: usize) -> Result<BlockRef<'c, 'a>, Error> {
        if index < self.successor_count() {
            Ok(unsafe {
                BlockRef::from_raw(mlirOperationGetSuccessor(self.to_raw(), index as isize))
            })
        } else {
            Err(Error::PositionOutOfBounds {
                name: "successor",
                value: self.to_string(),
                index,
            })
        }
    }

    /// Iterates over all successor blocks in order.
    fn successors(&self) -> impl Iterator<Item = BlockRef<'c, 'a>> {
        (0..self.successor_count())
            .map(|index| self.successor(index).expect("valid successor index"))
    }

    /// Returns the number of attributes attached to the operation.
    fn attribute_count(&self) -> usize {
        unsafe { mlirOperationGetNumAttributes(self.to_raw()) as usize }
    }

    /// Returns the `(name, attribute)` pair at `index`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::PositionOutOfBounds`] if `index >= self.attribute_count()`.
    fn attribute_at(&self, index: usize) -> Result<(Identifier<'c>, Attribute<'c>), Error> {
        if index < self.attribute_count() {
            let named_attribute =
                unsafe { mlirOperationGetAttribute(self.to_raw(), index as isize) };
            Ok((
                unsafe { Identifier::from_raw(named_attribute.name) },
                unsafe { Attribute::from_raw(named_attribute.attribute) },
            ))
        } else {
            Err(Error::PositionOutOfBounds {
                name: "attribute",
                value: self.to_string(),
                index,
            })
        }
    }

    /// Iterates over all `(name, attribute)` pairs in index order.
    fn attributes(&self) -> impl Iterator<Item = (Identifier<'c>, Attribute<'c>)> + '_ {
        (0..self.attribute_count())
            .map(|index| self.attribute_at(index).expect("valid attribute index"))
    }

    /// Looks up an attribute by name.
    ///
    /// # Errors
    ///
    /// Returns [`Error::AttributeNotFound`] if no attribute named `name` exists.
    fn attribute(&self, name: &str) -> Result<Attribute<'c>, Error> {
        unsafe {
            Attribute::from_option_raw(mlirOperationGetAttributeByName(
                self.to_raw(),
                StringRef::new(name).to_raw(),
            ))
        }
        .ok_or_else(|| Error::AttributeNotFound(name.into()))
    }

    /// Returns `true` if an attribute named `name` exists.
    fn has_attribute(&self, name: &str) -> bool {
        self.attribute(name).is_ok()
    }

    /// Returns the next operation in the same block, or `None` if this is the
    /// last one (or the operation is not in a block).
    fn next_in_block(&self) -> Option<OperationRef<'c, 'a>> {
        unsafe { OperationRef::from_option_raw(mlirOperationGetNextInBlock(self.to_raw())) }
    }

    /// Mutable counterpart of [`OperationLike::next_in_block`].
    fn next_in_block_mut(&self) -> Option<OperationRefMut<'c, 'a>> {
        unsafe { OperationRefMut::from_option_raw(mlirOperationGetNextInBlock(self.to_raw())) }
    }

    /// Returns the operation immediately before this one in its block.
    ///
    /// Returns `None` if the operation is not in a block or is the first
    /// operation in it. The MLIR C API has no "previous operation" accessor, so
    /// this scans the block forward from its first operation. The cost is
    /// linear in this operation's position within the block.
    fn previous_in_block(&self) -> Option<OperationRef<'c, 'a>> {
        let this = self.to_raw();
        let block = unsafe { mlirOperationGetBlock(this) };

        if block.ptr.is_null() {
            return None;
        }

        let mut previous: Option<MlirOperation> = None;
        let mut current = unsafe { mlirBlockGetFirstOperation(block) };

        while !current.ptr.is_null() {
            if current.ptr == this.ptr {
                // `previous` is either `None` (this is the first operation) or a
                // non-null handle seen earlier in this scan.
                return previous.map(|operation| unsafe { OperationRef::from_raw(operation) });
            }

            previous = Some(current);
            current = unsafe { mlirOperationGetNextInBlock(current) };
        }

        None
    }

    /// Returns the closest enclosing operation, or `None` if there is none.
    fn parent_operation(&self) -> Option<OperationRef<'c, 'a>> {
        unsafe { OperationRef::from_option_raw(mlirOperationGetParentOperation(self.to_raw())) }
    }

    /// Mutable counterpart of [`OperationLike::parent_operation`].
    fn parent_operation_mut(&mut self) -> Option<OperationRefMut<'c, 'a>> {
        unsafe { OperationRefMut::from_option_raw(mlirOperationGetParentOperation(self.to_raw())) }
    }

    /// Runs the MLIR verifier and returns `true` if the operation is valid.
    fn verify(&self) -> bool {
        unsafe { mlirOperationVerify(self.to_raw()) }
    }

    /// Dumps the operation through `mlirOperationDump`, which MLIR uses for
    /// debug output.
    fn dump(&self) {
        unsafe { mlirOperationDump(self.to_raw()) }
    }

    /// Prints the operation into a `String` using the given printing flags.
    ///
    /// # Errors
    ///
    /// Returns any error recorded by `print_string_callback` while MLIR streams
    /// the text (presumably invalid UTF-8; confirm against that callback).
    fn to_string_with_flags(&self, flags: OperationPrintingFlags) -> Result<String, Error> {
        // The callback appends to `data.0` and records a failure in `data.1`.
        let mut data = (String::new(), Ok::<_, Error>(()));
        // SAFETY: `data` outlives the synchronous print call that receives it
        // as the callback's user-data pointer.
        unsafe {
            mlirOperationPrintWithFlags(
                self.to_raw(),
                flags.to_raw(),
                Some(print_string_callback),
                &mut data as *mut _ as *mut _,
            );
        }
        data.1?;
        Ok(data.0)
    }

    /// Serializes the operation to MLIR bytecode.
    fn write_bytecode(&self) -> Vec<u8> {
        let mut bytes = Vec::new();
        // SAFETY: `bytes` outlives the synchronous write call that receives it
        // as the callback's user-data pointer.
        unsafe {
            mlirOperationWriteBytecode(
                self.to_raw(),
                Some(collect_bytes_callback),
                &mut bytes as *mut _ as *mut _,
            );
        }
        bytes
    }

    /// Serializes the operation to MLIR bytecode using `config`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::WriteBytecode`] if MLIR reports failure.
    fn write_bytecode_with_config(&self, config: &BytecodeWriterConfig) -> Result<Vec<u8>, Error> {
        let mut bytes = Vec::new();
        let result = LogicalResult::from_raw(unsafe {
            mlirOperationWriteBytecodeWithConfig(
                self.to_raw(),
                config.to_raw(),
                Some(collect_bytes_callback),
                &mut bytes as *mut _ as *mut _,
            )
        });
        if result.is_success() {
            Ok(bytes)
        } else {
            Err(Error::WriteBytecode)
        }
    }

    /// Walks this operation and every operation nested in it, in `order`.
    ///
    /// The callback's [`WalkResult`] controls whether the walk continues.
    /// The callback runs behind an `extern "C"` trampoline, so a panic inside
    /// it cannot unwind back into Rust. Depending on the Rust version, such a
    /// panic aborts the process.
    fn walk<F>(&self, order: WalkOrder, mut callback: F)
    where
        F: for<'x, 'y> FnMut(OperationRef<'x, 'y>) -> WalkResult,
    {
        unsafe extern "C" fn tramp<'c: 'a, 'a, F: FnMut(OperationRef<'c, 'a>) -> WalkResult>(
            operation: MlirOperation,
            data: *mut c_void,
        ) -> MlirWalkResult {
            // SAFETY: `data` is the `&mut F` passed to `mlirOperationWalk` below
            // and stays valid for the whole (synchronous) walk.
            let callback: &mut F = unsafe { &mut *(data as *mut F) };
            let operation = unsafe { OperationRef::from_raw(operation) };
            (callback)(operation) as _
        }
        unsafe {
            mlirOperationWalk(
                self.to_raw(),
                Some(tramp::<'c, 'a, F>),
                &mut callback as *mut _ as *mut _,
                order as _,
            );
        }
    }

    /// Returns the operation's type ID, or `None` if MLIR returns a null ID
    /// (presumably for unregistered operations; TODO confirm).
    fn type_id(&self) -> Option<TypeId<'c>> {
        let raw = unsafe { mlirOperationGetTypeID(self.to_raw()) };
        if raw.ptr.is_null() {
            None
        } else {
            Some(unsafe { TypeId::from_raw(raw) })
        }
    }

    /// Returns the first region, or `None` if the operation has no regions.
    fn first_region(&self) -> Option<RegionRef<'c, 'a>> {
        unsafe { RegionRef::from_option_raw(mlirOperationGetFirstRegion(self.to_raw())) }
    }

    /// Returns the number of discardable attributes.
    fn discardable_attribute_count(&self) -> usize {
        unsafe { mlirOperationGetNumDiscardableAttributes(self.to_raw()) as usize }
    }

    /// Returns the discardable `(name, attribute)` pair at `index`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::PositionOutOfBounds`] if
    /// `index >= self.discardable_attribute_count()`.
    fn discardable_attribute_at(
        &self,
        index: usize,
    ) -> Result<(Identifier<'c>, Attribute<'c>), Error> {
        if index < self.discardable_attribute_count() {
            let named =
                unsafe { mlirOperationGetDiscardableAttribute(self.to_raw(), index as isize) };
            Ok((unsafe { Identifier::from_raw(named.name) }, unsafe {
                Attribute::from_raw(named.attribute)
            }))
        } else {
            Err(Error::PositionOutOfBounds {
                name: "discardable attribute",
                value: self.to_string(),
                index,
            })
        }
    }

    /// Looks up a discardable attribute by name.
    ///
    /// # Errors
    ///
    /// Returns [`Error::AttributeNotFound`] if no discardable attribute named
    /// `name` exists.
    fn discardable_attribute(&self, name: &str) -> Result<Attribute<'c>, Error> {
        unsafe {
            Attribute::from_option_raw(mlirOperationGetDiscardableAttributeByName(
                self.to_raw(),
                StringRef::new(name).to_raw(),
            ))
        }
        .ok_or_else(|| Error::AttributeNotFound(name.into()))
    }

    /// Looks up an inherent attribute by name.
    ///
    /// # Errors
    ///
    /// Returns [`Error::AttributeNotFound`] if no inherent attribute named
    /// `name` exists.
    fn inherent_attribute(&self, name: &str) -> Result<Attribute<'c>, Error> {
        unsafe {
            Attribute::from_option_raw(mlirOperationGetInherentAttributeByName(
                self.to_raw(),
                StringRef::new(name).to_raw(),
            ))
        }
        .ok_or_else(|| Error::AttributeNotFound(name.into()))
    }

    /// Returns `true` if the operation has an inherent attribute named `name`.
    fn has_inherent_attribute(&self, name: &str) -> bool {
        unsafe {
            mlirOperationHasInherentAttributeByName(self.to_raw(), StringRef::new(name).to_raw())
        }
    }

    /// Returns MLIR's hash of the operation handle.
    fn hash_value(&self) -> usize {
        unsafe { mlirOperationHashValue(self.to_raw()) }
    }

    /// Returns `true` if this operation comes before `other` in their block.
    ///
    /// NOTE(review): MLIR expects both operations to be in the same block. That
    /// precondition is not checked here.
    fn is_before_in_block(&self, other: OperationRef<'c, 'a>) -> bool {
        unsafe { mlirOperationIsBeforeInBlock(self.to_raw(), other.to_raw()) }
    }

    /// Returns `true` if this operation implements the interface identified by
    /// `interface_type_id`.
    fn implements_interface(&self, interface_type_id: TypeId<'c>) -> bool {
        unsafe { mlirOperationImplementsInterface(self.to_raw(), interface_type_id.to_raw()) }
    }

    /// Returns `true` if the operation named `name`, registered in `context`,
    /// implements the interface identified by `interface_type_id`.
    fn implements_interface_static(
        name: &str,
        context: &'c Context,
        interface_type_id: TypeId<'c>,
    ) -> bool {
        unsafe {
            mlirOperationImplementsInterfaceStatic(
                StringRef::new(name).to_raw(),
                context.to_raw(),
                interface_type_id.to_raw(),
            )
        }
    }
}
/// Mutating interface for MLIR operation handles.
///
/// Every method is a default wrapper around an MLIR C API call on
/// [`OperationLike::to_raw`]. The `unsafe` blocks rely on that handle
/// referring to a live operation.
pub trait OperationMutLike<'c: 'a, 'a>: OperationLike<'c, 'a> {
    /// Sets the attribute named `name`, replacing any existing value.
    fn set_attribute(&mut self, name: &str, attribute: Attribute<'c>) {
        unsafe {
            mlirOperationSetAttributeByName(
                self.to_raw(),
                StringRef::new(name).to_raw(),
                attribute.to_raw(),
            )
        }
    }

    /// Removes the attribute named `name`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::AttributeNotFound`] if MLIR reports that nothing was
    /// removed.
    fn remove_attribute(&mut self, name: &str) -> Result<(), Error> {
        unsafe { mlirOperationRemoveAttributeByName(self.to_raw(), StringRef::new(name).to_raw()) }
            .then_some(())
            .ok_or_else(|| Error::AttributeNotFound(name.into()))
    }

    /// Detaches the operation from its parent block.
    ///
    /// NOTE(review): per the MLIR C API the detached operation is not
    /// destroyed, and ownership passes to the caller. Nothing here takes that
    /// ownership, so confirm callers account for it.
    fn remove_from_parent(&mut self) {
        unsafe { mlirOperationRemoveFromParent(self.to_raw()) }
    }

    /// Moves this operation to immediately after `other`.
    fn move_after(&mut self, other: OperationRef<'c, 'a>) {
        unsafe { mlirOperationMoveAfter(self.to_raw(), other.to_raw()) }
    }

    /// Moves this operation to immediately before `other`.
    fn move_before(&mut self, other: OperationRef<'c, 'a>) {
        unsafe { mlirOperationMoveBefore(self.to_raw(), other.to_raw()) }
    }

    /// Replaces the operand at `index` with `value`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.operand_count()`. The index is checked on the
    /// Rust side because the C API call takes it unchecked.
    fn set_operand(&mut self, index: usize, value: Value<'c, 'a>) {
        let count = self.operand_count();
        assert!(
            index < count,
            "operand index {index} out of bounds (operation has {count} operands)"
        );
        unsafe { mlirOperationSetOperand(self.to_raw(), index as isize, value.to_raw()) }
    }

    /// Replaces all operands with `values`.
    ///
    /// NOTE(review): the slice is reinterpreted as `*const MlirValue`. This
    /// assumes `Value` has the same layout as `MlirValue` (e.g.
    /// `#[repr(transparent)]`); confirm.
    fn set_operands(&mut self, values: &[Value<'c, 'a>]) {
        unsafe {
            mlirOperationSetOperands(
                self.to_raw(),
                values.len() as isize,
                values.as_ptr() as *const MlirValue,
            )
        }
    }

    /// Replaces the successor block at `index` with `block`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.successor_count()`. The index is checked on the
    /// Rust side because the C API call takes it unchecked.
    fn set_successor(&mut self, index: usize, block: &Block<'c>) {
        let count = self.successor_count();
        assert!(
            index < count,
            "successor index {index} out of bounds (operation has {count} successors)"
        );
        unsafe { mlirOperationSetSuccessor(self.to_raw(), index as isize, block.to_raw()) }
    }

    /// Sets the discardable attribute named `name`.
    fn set_discardable_attribute(&mut self, name: &str, attribute: Attribute<'c>) {
        unsafe {
            mlirOperationSetDiscardableAttributeByName(
                self.to_raw(),
                StringRef::new(name).to_raw(),
                attribute.to_raw(),
            )
        }
    }

    /// Removes the discardable attribute named `name`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::AttributeNotFound`] if MLIR reports that nothing was
    /// removed.
    fn remove_discardable_attribute(&mut self, name: &str) -> Result<(), Error> {
        unsafe {
            mlirOperationRemoveDiscardableAttributeByName(
                self.to_raw(),
                StringRef::new(name).to_raw(),
            )
        }
        .then_some(())
        .ok_or_else(|| Error::AttributeNotFound(name.into()))
    }

    /// Sets the operation's source location.
    fn set_location(&mut self, location: Location<'c>) {
        unsafe { mlirOperationSetLocation(self.to_raw(), location.to_raw()) }
    }

    /// Replaces uses of the value `of` with `with` inside this operation.
    fn replace_uses_of_with(&mut self, of: Value<'c, 'a>, with: Value<'c, 'a>) {
        unsafe { mlirOperationReplaceUsesOfWith(self.to_raw(), of.to_raw(), with.to_raw()) }
    }

    /// Sets the inherent attribute named `name`.
    fn set_inherent_attribute(&mut self, name: &str, attribute: Attribute<'c>) {
        unsafe {
            mlirOperationSetInherentAttributeByName(
                self.to_raw(),
                StringRef::new(name).to_raw(),
                attribute.to_raw(),
            )
        }
    }

    /// Walks this operation and every nested operation in `order`, passing
    /// each one to `callback` as a mutable handle.
    ///
    /// NOTE(review): structural changes made from the callback (e.g. erasing
    /// the visited operation) must respect MLIR's walk rules; nothing here
    /// enforces them. A panic in `callback` cannot unwind through the
    /// `extern "C"` trampoline.
    fn walk_mut<F>(&mut self, order: WalkOrder, mut callback: F)
    where
        F: for<'x, 'y> FnMut(OperationRefMut<'x, 'y>) -> WalkResult,
    {
        unsafe extern "C" fn tramp<'c: 'a, 'a, F: FnMut(OperationRefMut<'c, 'a>) -> WalkResult>(
            operation: MlirOperation,
            data: *mut c_void,
        ) -> MlirWalkResult {
            // SAFETY: `data` is the `&mut F` passed to `mlirOperationWalk` below
            // and stays valid for the whole (synchronous) walk.
            let callback: &mut F = unsafe { &mut *(data as *mut F) };
            let operation = unsafe { OperationRefMut::from_raw(operation) };
            (callback)(operation) as _
        }
        unsafe {
            mlirOperationWalk(
                self.to_raw(),
                Some(tramp::<'c, 'a, F>),
                &mut callback as *mut _ as *mut _,
                order as _,
            );
        }
    }
}