llzk 0.5.0

Rust bindings to the LLZK C API.
//! Implementation of `!array.type` type.

use crate::utils::IsA;
use llzk_sys::{
    llzkArray_ArrayTypeGetDimensionSizesAt, llzkArray_ArrayTypeGetDimensionSizesCount,
    llzkArray_ArrayTypeGetElementType, llzkArray_ArrayTypeGetWithDims,
    llzkArray_ArrayTypeGetWithShape, llzkTypeIsA_Array_ArrayType,
};
use melior::ir::{Attribute, Type, TypeLike};
use mlir_sys::MlirType;

/// Represents the `!array.type` type.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub struct ArrayType<'c> {
    r#type: Type<'c>,
}

impl<'c> ArrayType<'c> {
    unsafe fn from_raw(raw: MlirType) -> Self {
        Self {
            r#type: unsafe { Type::from_raw(raw) },
        }
    }

    /// Creates a new type with the given element type and dimensions.
    ///
    /// It can only be instantiated with the following element types:
    ///   - Any LLZK type other than an ArrayType (i.e., arrays cannot be nested)
    ///   - IndexType
    ///   - Unsigned integers of 1 bit (aka booleans)
    ///
    /// The dimensions of the array are specified using a list of attributes, one
    /// per dimension. Each attribute must be one of the following:
    ///   - IntegerAttribute (with IndexType), specifying a fixed dimension size
    ///   - SymbolRefAttribute, specifying a dimension size defined by a struct parameter or
    ///     global constant
    ///   - AffineMapAttribute, specifying a dimension size computed from surrounding loop
    ///     induction variables
    pub fn new(element_type: Type<'c>, dims: &[Attribute<'c>]) -> Self {
        unsafe {
            Self::from_raw(llzkArray_ArrayTypeGetWithDims(
                element_type.to_raw(),
                dims.len() as _,
                dims.as_ptr() as *const _,
            ))
        }
    }

    /// Creates a new type with the given element type and dimensions as integers.
    pub fn new_with_dims(element_type: Type<'c>, dims: &[i64]) -> Self {
        unsafe {
            Self::from_raw(llzkArray_ArrayTypeGetWithShape(
                element_type.to_raw(),
                dims.len() as _,
                dims.as_ptr() as *const _,
            ))
        }
    }

    /// Returns the element type of the array.
    pub fn element_type(&self) -> Type<'c> {
        unsafe { Type::from_raw(llzkArray_ArrayTypeGetElementType(self.to_raw())) }
    }

    /// Returns the number of dimensions of the array.
    pub fn num_dims(&self) -> isize {
        unsafe { llzkArray_ArrayTypeGetDimensionSizesCount(self.to_raw()) }
    }

    /// Returns the Attribute specifying the size of dimension `idx`.
    pub fn dim(&self, idx: isize) -> Attribute<'c> {
        unsafe { Attribute::from_raw(llzkArray_ArrayTypeGetDimensionSizesAt(self.to_raw(), idx)) }
    }

    /// Returns the Attributes specifying the sizes of all dimensions.
    #[inline]
    pub fn dims(&self) -> Vec<Attribute<'c>> {
        (0..self.num_dims()).map(|idx| self.dim(idx)).collect()
    }
}

impl<'c> TypeLike<'c> for ArrayType<'c> {
    fn to_raw(&self) -> MlirType {
        self.r#type.to_raw()
    }
}

impl<'c> TryFrom<Type<'c>> for ArrayType<'c> {
    type Error = melior::Error;

    fn try_from(t: Type<'c>) -> Result<Self, Self::Error> {
        if unsafe { llzkTypeIsA_Array_ArrayType(t.to_raw()) } {
            Ok(unsafe { Self::from_raw(t.to_raw()) })
        } else {
            Err(Self::Error::TypeExpected("llzk array", t.to_string()))
        }
    }
}

impl<'c> std::fmt::Display for ArrayType<'c> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.r#type, formatter)
    }
}

impl<'c> From<ArrayType<'c>> for Type<'c> {
    fn from(t: ArrayType<'c>) -> Type<'c> {
        t.r#type
    }
}

/// Return `true` iff the given [Type] is an [ArrayType].
#[inline]
pub fn is_array_type(t: Type) -> bool {
    t.isa::<ArrayType>()
}