use std::fmt;

use melior::{
    Context, StringRef,
    ir::{Attribute, AttributeLike, attribute::FlatSymbolRefAttribute},
};
use mlir_sys::{
    MlirAttribute, mlirAttributeIsASymbolRef, mlirFlatSymbolRefAttrGet, mlirSymbolRefAttrGet,
    mlirSymbolRefAttrGetLeafReference, mlirSymbolRefAttrGetNestedReference,
    mlirSymbolRefAttrGetNumNestedReferences, mlirSymbolRefAttrGetRootReference,
};
/// Typed wrapper around an MLIR symbol reference attribute
/// (printed like `@root::@nested`).
///
/// Build one with [`SymbolRefAttribute::new`] or
/// [`SymbolRefAttribute::new_from_str`], or downcast a generic [`Attribute`]
/// through `TryFrom`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SymbolRefAttribute<'c> {
    /// Untyped attribute handle this wrapper refines.
    inner: Attribute<'c>,
}
impl<'c> SymbolRefAttribute<'c> {
pub fn new(ctx: &'c Context, root: StringRef, nested: &[impl SymbolRefAttrLike<'c>]) -> Self {
let raw_refs: Vec<_> = nested.iter().map(|r| r.to_raw()).collect();
Self {
inner: unsafe {
Attribute::from_raw(mlirSymbolRefAttrGet(
ctx.to_raw(),
root.to_raw(),
isize::try_from(raw_refs.len()).expect("nested reference count too large"),
raw_refs.as_ptr(),
))
},
}
}
pub fn new_from_str(ctx: &'c Context, root: &str, nested: &[&str]) -> Self {
let refs: Vec<_> = nested
.iter()
.map(|r| FlatSymbolRefAttribute::new(ctx, r))
.collect();
Self::new(ctx, StringRef::new(root), &refs)
}
pub fn root(&self) -> StringRef<'c> {
unsafe { StringRef::from_raw(mlirSymbolRefAttrGetRootReference(self.to_raw())) }
}
pub fn leaf(&self) -> StringRef<'c> {
unsafe { StringRef::from_raw(mlirSymbolRefAttrGetLeafReference(self.to_raw())) }
}
pub fn nested(&self) -> Vec<FlatSymbolRefAttribute<'c>> {
let nested_count = unsafe { mlirSymbolRefAttrGetNumNestedReferences(self.to_raw()) };
(0..nested_count)
.map(|i| {
unsafe {
Attribute::from_raw(mlirSymbolRefAttrGetNestedReference(self.to_raw(), i))
}
.try_into()
.expect("expected FlatSymbolRefAttribute")
})
.collect()
}
}
/// Exposes the raw MLIR handle so the wrapper plugs into generic melior APIs.
impl<'c> AttributeLike<'c> for SymbolRefAttribute<'c> {
    fn to_raw(&self) -> MlirAttribute {
        let Self { inner } = self;
        inner.to_raw()
    }
}
impl<'c> TryFrom<Attribute<'c>> for SymbolRefAttribute<'c> {
type Error = melior::Error;
fn try_from(value: Attribute<'c>) -> Result<Self, Self::Error> {
if unsafe { mlirAttributeIsASymbolRef(value.to_raw()) } {
Ok(Self { inner: value })
} else {
Err(Self::Error::AttributeExpected(
"symbol ref attr",
value.to_string(),
))
}
}
}
impl<'c> From<SymbolRefAttribute<'c>> for Attribute<'c> {
fn from(sym: SymbolRefAttribute<'c>) -> Attribute<'c> {
sym.inner
}
}
/// Formats exactly as the wrapped [`Attribute`] does.
impl<'c> fmt::Display for SymbolRefAttribute<'c> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{}", self.inner))
    }
}
/// Symbol reference attributes that may be passed as nested components to
/// [`SymbolRefAttribute::new`].
///
/// The trait is sealed through `private::Sealed`. Only
/// [`FlatSymbolRefAttribute`] and [`SymbolRefAttribute`] implement it.
pub trait SymbolRefAttrLike<'c>: AttributeLike<'c> + fmt::Display + private::Sealed {}

impl<'c> SymbolRefAttrLike<'c> for FlatSymbolRefAttribute<'c> {}
impl<'c> SymbolRefAttrLike<'c> for SymbolRefAttribute<'c> {}
mod private {
use crate::symbol_ref::SymbolRefAttribute;
use melior::ir::attribute::FlatSymbolRefAttribute;
pub trait Sealed {}
impl<'c> Sealed for SymbolRefAttribute<'c> {}
impl<'c> Sealed for FlatSymbolRefAttribute<'c> {}
}