use crate::{
attributes::array::ArrayAttribute,
error::Error,
symbol_lookup::SymbolLookupResult,
symbol_ref::{SymbolRefAttrLike, SymbolRefAttribute},
utils::{FromRaw, IsA},
};
use llzk_sys::{
llzkStruct_StructTypeGetNameRef, llzkStruct_StructTypeGetParams,
llzkStruct_StructTypeGetWithArrayAttr, llzkTypeIsA_Struct_StructType,
};
use melior::{
Context,
ir::{
Attribute, AttributeLike as _, Module, Type, TypeLike, attribute::FlatSymbolRefAttribute,
operation::OperationLike,
},
};
use mlir_sys::{MlirLogicalResult, MlirType};
/// Typed wrapper around an MLIR `Type` that is an LLZK struct type.
///
/// Values are produced by the constructors in the inherent impl, by the
/// kind-checked `TryFrom<Type>` conversion, or by the unchecked
/// `FromRaw::from_raw`.
#[derive(Copy, Clone, Debug)]
pub struct StructType<'c> {
/// The wrapped MLIR type handle.
t: Type<'c>,
}
impl<'c> StructType<'c> {
pub fn new(name: impl SymbolRefAttrLike<'c>, params: &[Attribute<'c>]) -> Self {
unsafe {
Self::from_raw(llzkStruct_StructTypeGetWithArrayAttr(
name.to_raw(),
ArrayAttribute::new(name.context().to_ref(), params).to_raw(),
))
}
}
pub fn from_str(context: &'c Context, name: &str) -> Self {
Self::new(FlatSymbolRefAttribute::new(context, name), &[])
}
pub fn from_str_params(context: &'c Context, name: &str, params: &[&str]) -> Self {
let params: Vec<Attribute> = params
.iter()
.map(|param| FlatSymbolRefAttribute::new(context, param).into())
.collect();
Self::new(FlatSymbolRefAttribute::new(context, name), ¶ms)
}
pub fn name(&self) -> SymbolRefAttribute<'c> {
SymbolRefAttribute::try_from(unsafe {
Attribute::from_raw(llzkStruct_StructTypeGetNameRef(self.to_raw()))
})
.expect("struct type must be constructed from SymbolRefAttribute")
}
pub fn params(&self) -> Option<ArrayAttribute<'c>> {
unsafe { Attribute::from_option_raw(llzkStruct_StructTypeGetParams(self.to_raw())) }.map(
|a| {
ArrayAttribute::try_from(a)
.expect("struct type's params must be an array attribute")
},
)
}
pub fn params_vec(&self) -> Vec<Attribute<'c>> {
self.params().into_iter().flatten().collect()
}
fn lookup_definition_impl<O>(
&self,
o: O,
f: unsafe extern "C" fn(
MlirType,
O,
*mut llzk_sys::LlzkSymbolLookupResult,
) -> MlirLogicalResult,
) -> Result<SymbolLookupResult<'c>, Error> {
let mut lookup = SymbolLookupResult::new();
let result = unsafe { f(self.to_raw(), o, lookup.as_raw_mut()) };
(result.value != 0)
.then_some(lookup)
.ok_or_else(|| Error::SymbolNotFound(self.name().to_string()))
}
pub fn lookup_definition<'o>(
&self,
root: &impl OperationLike<'c, 'o>,
) -> Result<SymbolLookupResult<'c>, Error>
where
'c: 'o,
{
self.lookup_definition_impl(root.to_raw(), llzk_sys::llzkStructStructTypeGetDefinition)
}
pub fn lookup_definition_from_module(
&self,
root: &Module<'c>,
) -> Result<SymbolLookupResult<'c>, Error> {
self.lookup_definition_impl(
root.to_raw(),
llzk_sys::llzkStructStructTypeGetDefinitionFromModule,
)
}
}
impl<'c> FromRaw<MlirType> for StructType<'c> {
    /// Wraps a raw MLIR type handle without checking its kind.
    ///
    /// # Safety
    ///
    /// `raw` must be a valid type handle for lifetime `'c`. Other methods (e.g.
    /// `name`) expect it to be an LLZK struct type.
    unsafe fn from_raw(raw: MlirType) -> Self {
        // SAFETY: upheld by the caller per this function's contract.
        let inner = unsafe { Type::from_raw(raw) };
        Self { t: inner }
    }
}
impl<'c> TypeLike<'c> for StructType<'c> {
    /// Returns the raw handle of the wrapped MLIR type.
    fn to_raw(&self) -> MlirType {
        let Self { t } = self;
        t.to_raw()
    }
}
impl<'c> TryFrom<Type<'c>> for StructType<'c> {
    type Error = melior::Error;

    /// Narrows a generic MLIR type to a `StructType`.
    ///
    /// # Errors
    ///
    /// Returns `TypeExpected` when `t` is not an LLZK struct type.
    fn try_from(t: Type<'c>) -> Result<Self, Self::Error> {
        // SAFETY: `t` is a live `Type<'c>`, so its raw handle is valid for the
        // kind predicate.
        let is_struct = unsafe { llzkTypeIsA_Struct_StructType(t.to_raw()) };
        if !is_struct {
            return Err(Self::Error::TypeExpected("llzk struct", t.to_string()));
        }
        // The kind is confirmed, so the owned handle can be wrapped as-is.
        Ok(Self { t })
    }
}
impl<'c> std::fmt::Display for StructType<'c> {
fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
std::fmt::Display::fmt(&self.t, formatter)
}
}
impl<'c> From<StructType<'c>> for Type<'c> {
fn from(s: StructType<'c>) -> Type<'c> {
s.t
}
}
/// Reports whether `t` is an LLZK struct type.
#[inline]
pub fn is_struct_type(t: Type) -> bool {
    let candidate = t;
    candidate.isa::<StructType>()
}