use std::convert::TryFrom;
use thiserror::Error;
use crate::AddressSpace;
/// Numeric identifiers for the different kinds of types.
///
/// The enum is laid out as a `u64`. `Half` is pinned to `0` and every later
/// variant takes the next consecutive value in declaration order, so the
/// order of the variants below is significant.
///
/// Note that there is no identifier here that is specific to
/// `Type::OpaquePointer`; only a single `Pointer` identifier exists.
///
/// NOTE(review): the names and ordering look like they mirror LLVM's type ID
/// numbering — confirm against the LLVM version in use before reordering or
/// inserting variants.
#[repr(u64)]
pub enum TypeId {
Half = 0,
BFloat,
Float,
Double,
X86Fp80,
Fp128,
PpcFp128,
Void,
Label,
Metadata,
X86Mmx,
X86Amx,
Token,
Integer,
Function,
Pointer,
Struct,
Array,
FixedVector,
ScalableVector,
}
/// A type: either a simple kind with no parameters (the unit variants) or a
/// parameterised kind whose details live in a dedicated struct.
///
/// The `Type::new_*` constructors build the parameterised variants after
/// validating their arguments.
#[allow(missing_docs)]
#[derive(Clone, Debug, PartialEq)]
pub enum Type {
Half,
BFloat,
Float,
Double,
Metadata,
X86Fp80,
Fp128,
PpcFp128,
Void,
Label,
X86Mmx,
X86Amx,
Token,
/// An integer type; the bit width is held by the [`IntegerType`].
Integer(IntegerType),
/// A function type (return type, parameter types, vararg flag).
Function(FunctionType),
/// A pointer that records its pointee type and address space.
Pointer(PointerType),
/// A pointer that records only an address space and no pointee type.
OpaquePointer(AddressSpace),
/// A structure type (optional name, fields, packed flag).
Struct(StructType),
/// An array type (element count and element type).
Array(ArrayType),
/// A vector type built by `Type::new_vector`.
FixedVector(VectorType),
/// A vector type built by `Type::new_scalable_vector`.
ScalableVector(VectorType),
}
impl Type {
    /// Returns whether this type is one of the floating-point types.
    pub fn is_floating(&self) -> bool {
        matches!(
            self,
            Type::Half
                | Type::BFloat
                | Type::Float
                | Type::Double
                | Type::X86Fp80
                | Type::Fp128
                | Type::PpcFp128
        )
    }

    /// Returns whether this type may be used as the pointee of a [`PointerType`].
    ///
    /// Everything except `Void`, `Label`, `Metadata`, `Token`, and `X86Amx` qualifies.
    pub fn is_pointee(&self) -> bool {
        !matches!(
            self,
            Type::Void | Type::Label | Type::Metadata | Type::Token | Type::X86Amx
        )
    }

    /// Returns whether this type may be used as the element type of an [`ArrayType`].
    ///
    /// `Void`, `Label`, `Metadata`, `Token`, `X86Amx`, function types, and scalable
    /// vectors are rejected; everything else qualifies.
    pub fn is_array_element(&self) -> bool {
        !matches!(
            self,
            Type::Void
                | Type::Label
                | Type::Metadata
                | Type::Function(_)
                | Type::Token
                | Type::X86Amx
                | Type::ScalableVector(_)
        )
    }

    /// Returns whether this type may be used as a field of a [`StructType`].
    ///
    /// `Void`, `Label`, `Metadata`, `Token`, and function types are rejected;
    /// everything else qualifies.
    pub fn is_struct_element(&self) -> bool {
        !matches!(
            self,
            Type::Void | Type::Label | Type::Metadata | Type::Function(_) | Type::Token
        )
    }

    /// Returns whether this type may be used as the element type of a [`VectorType`].
    ///
    /// Floating-point types, integer types, and pointer types qualify. Both pointer
    /// representations count as pointers here: typed pointers (`Type::Pointer`) and
    /// opaque pointers (`Type::OpaquePointer`).
    pub fn is_vector_element(&self) -> bool {
        self.is_floating()
            || matches!(
                self,
                Type::Integer(_) | Type::Pointer(_) | Type::OpaquePointer(_)
            )
    }

    /// Returns whether this type is "first class", i.e. anything other than a
    /// function type or `Void`.
    fn is_first_class(&self) -> bool {
        !matches!(self, Type::Function(_) | Type::Void)
    }

    /// Returns whether this type may be used as a function parameter type.
    ///
    /// Exactly the first-class types qualify.
    pub fn is_argument(&self) -> bool {
        self.is_first_class()
    }

    /// Returns whether this type may be used as a function return type.
    ///
    /// Function types, `Label`, and `Metadata` are rejected; everything else
    /// (including `Void`) qualifies.
    pub fn is_return(&self) -> bool {
        !matches!(self, Type::Function(_) | Type::Label | Type::Metadata)
    }

    /// Creates a new structure type.
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`StructType::new`] when a field type is not
    /// a valid structure element.
    pub fn new_struct(
        name: Option<String>,
        fields: Vec<Type>,
        is_packed: bool,
    ) -> Result<Self, StructTypeError> {
        let inner = StructType::new(name, fields, is_packed)?;

        Ok(Type::Struct(inner))
    }

    /// Creates a new integer type of the given bit width.
    ///
    /// # Errors
    ///
    /// Returns [`IntegerTypeError::BadWidth`] when `bit_width` lies outside
    /// `[IntegerType::MIN_INT_BITS, IntegerType::MAX_INT_BITS]`.
    pub fn new_integer(bit_width: u32) -> Result<Self, IntegerTypeError> {
        let inner = IntegerType::try_from(bit_width)?;

        Ok(Type::Integer(inner))
    }

    /// Creates a new typed pointer to `pointee` in `address_space`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`PointerType::new`] when `pointee` is not a
    /// valid pointee type.
    pub fn new_pointer(
        pointee: Type,
        address_space: AddressSpace,
    ) -> Result<Self, PointerTypeError> {
        let inner = PointerType::new(pointee, address_space)?;

        Ok(Type::Pointer(inner))
    }

    /// Creates a new array type with `num_elements` elements of `element_type`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`ArrayType::new`] when `element_type` is not
    /// a valid array element type.
    pub fn new_array(num_elements: u64, element_type: Type) -> Result<Self, ArrayTypeError> {
        let inner = ArrayType::new(num_elements, element_type)?;

        Ok(Type::Array(inner))
    }

    /// Creates a new scalable vector type (`Type::ScalableVector`).
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`VectorType::new`] when `element_type` is not
    /// a valid vector element type.
    pub fn new_scalable_vector(
        num_elements: u64,
        element_type: Type,
    ) -> Result<Self, VectorTypeError> {
        let inner = VectorType::new(num_elements, element_type)?;

        Ok(Type::ScalableVector(inner))
    }

    /// Creates a new fixed vector type (`Type::FixedVector`).
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`VectorType::new`] when `element_type` is not
    /// a valid vector element type.
    pub fn new_vector(num_elements: u64, element_type: Type) -> Result<Self, VectorTypeError> {
        let inner = VectorType::new(num_elements, element_type)?;

        Ok(Type::FixedVector(inner))
    }

    /// Creates a new function type.
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`FunctionType::new`] when the return type or
    /// any parameter type is invalid.
    pub fn new_function(
        return_type: Type,
        param_types: Vec<Type>,
        is_vararg: bool,
    ) -> Result<Self, FunctionTypeError> {
        let inner = FunctionType::new(return_type, param_types, is_vararg)?;

        Ok(Type::Function(inner))
    }
}
/// Errors that can occur when constructing a [`StructType`].
#[derive(Debug, Error)]
pub enum StructTypeError {
/// A field's type failed the `Type::is_struct_element` check; the offending
/// type is carried in the variant.
#[error("invalid structure element type: {0:?}")]
BadElement(Type),
}
/// A structure type: an optional name, a list of field types, and a packed flag.
///
/// `StructType::new` checks every field with `Type::is_struct_element`.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq)]
pub struct StructType {
/// The structure's name, if it has one.
pub name: Option<String>,
/// The types of the structure's fields, in order.
pub fields: Vec<Type>,
/// The packed flag supplied at construction.
is_packed: bool,
}
impl StructType {
    /// Builds a structure type from an optional name, its field types, and a
    /// packed flag.
    ///
    /// # Errors
    ///
    /// Returns [`StructTypeError::BadElement`] carrying a copy of the first field
    /// type for which `Type::is_struct_element` is false.
    pub fn new(
        name: Option<String>,
        fields: Vec<Type>,
        is_packed: bool,
    ) -> Result<Self, StructTypeError> {
        // Only the offending field (if any) is cloned; valid field lists are moved as-is.
        let first_invalid = fields
            .iter()
            .find(|field| !field.is_struct_element())
            .cloned();

        if let Some(invalid) = first_invalid {
            return Err(StructTypeError::BadElement(invalid));
        }

        Ok(Self {
            name,
            fields,
            is_packed,
        })
    }
}
/// Errors that can occur when constructing an [`IntegerType`].
#[derive(Debug, Error)]
pub enum IntegerTypeError {
/// The requested bit width lies outside
/// `[IntegerType::MIN_INT_BITS, IntegerType::MAX_INT_BITS]`.
#[error(
"specified bit width is invalid (not in [{}, {}])",
IntegerType::MIN_INT_BITS,
IntegerType::MAX_INT_BITS
)]
BadWidth,
}
/// An integer type of a fixed bit width.
///
/// Values are created through `TryFrom<u32>`, which only admits widths in
/// `[IntegerType::MIN_INT_BITS, IntegerType::MAX_INT_BITS]`.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq)]
pub struct IntegerType {
/// The width of the integer, in bits.
bit_width: u32,
}
impl IntegerType {
    /// The smallest bit width an integer type may have.
    pub const MIN_INT_BITS: u32 = 1;

    /// The largest bit width an integer type may have.
    pub const MAX_INT_BITS: u32 = (1 << 24) - 1;

    /// Returns the width of this integer type in bits.
    pub fn bit_width(&self) -> u32 {
        self.bit_width
    }

    /// Returns the number of whole bytes needed to hold this integer type,
    /// i.e. the bit width divided by eight and rounded up.
    pub fn byte_width(&self) -> u32 {
        const BITS_PER_BYTE: u32 = 8;

        // Bias by one less than a byte so the integer division rounds up.
        let biased = self.bit_width + (BITS_PER_BYTE - 1);
        biased / BITS_PER_BYTE
    }
}
impl TryFrom<u32> for IntegerType {
type Error = IntegerTypeError;
fn try_from(value: u32) -> Result<Self, Self::Error> {
if (IntegerType::MIN_INT_BITS..=IntegerType::MAX_INT_BITS).contains(&value) {
Ok(Self { bit_width: value })
} else {
Err(Self::Error::BadWidth)
}
}
}
/// Errors that can occur when constructing a [`PointerType`].
#[derive(Debug, Error)]
pub enum PointerTypeError {
/// The requested pointee type failed the `Type::is_pointee` check; the
/// offending type is carried in the variant.
#[error("invalid pointee type: {0:?}")]
BadPointee(Type),
}
/// A pointer to a specific pointee type within an address space.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq)]
pub struct PointerType {
/// The type being pointed to; validated with `Type::is_pointee` by
/// `PointerType::new`.
pointee: Box<Type>,
/// The address space supplied at construction.
address_space: AddressSpace,
}
impl PointerType {
    /// Builds a pointer to `pointee` in the given address space.
    ///
    /// # Errors
    ///
    /// Returns [`PointerTypeError::BadPointee`], handing the type back, when
    /// `Type::is_pointee` is false for `pointee`.
    pub fn new(pointee: Type, address_space: AddressSpace) -> Result<Self, PointerTypeError> {
        if !pointee.is_pointee() {
            return Err(PointerTypeError::BadPointee(pointee));
        }

        Ok(Self {
            pointee: Box::new(pointee),
            address_space,
        })
    }

    /// Returns a reference to the type this pointer points to.
    pub fn pointee(&self) -> &Type {
        &self.pointee
    }
}
/// Errors that can occur when constructing an [`ArrayType`].
#[derive(Debug, Error)]
pub enum ArrayTypeError {
/// The requested element type failed the `Type::is_array_element` check; the
/// offending type is carried in the variant.
#[error("invalid array element type: {0:?}")]
BadElement(Type),
}
/// An array type: an element count together with an element type.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq)]
pub struct ArrayType {
/// The number of elements supplied at construction.
num_elements: u64,
/// The type of each element; validated with `Type::is_array_element` by
/// `ArrayType::new`.
element_type: Box<Type>,
}
impl ArrayType {
    /// Builds an array type of `num_elements` elements of `element_type`.
    ///
    /// # Errors
    ///
    /// Returns [`ArrayTypeError::BadElement`], handing the type back, when
    /// `Type::is_array_element` is false for `element_type`.
    pub fn new(num_elements: u64, element_type: Type) -> Result<Self, ArrayTypeError> {
        if !element_type.is_array_element() {
            return Err(ArrayTypeError::BadElement(element_type));
        }

        Ok(Self {
            num_elements,
            element_type: Box::new(element_type),
        })
    }

    /// Returns a reference to the array's element type.
    pub fn element(&self) -> &Type {
        &self.element_type
    }
}
/// Errors that can occur when constructing a [`VectorType`].
#[derive(Debug, Error)]
pub enum VectorTypeError {
/// The requested element type failed the `Type::is_vector_element` check; the
/// offending type is carried in the variant.
#[error("invalid vector element type: {0:?}")]
BadElement(Type),
}
/// A vector type: an element count together with an element type.
///
/// The same struct backs both `Type::FixedVector` and `Type::ScalableVector`.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq)]
pub struct VectorType {
/// The number of elements supplied at construction.
num_elements: u64,
/// The type of each element; validated with `Type::is_vector_element` by
/// `VectorType::new`.
element_type: Box<Type>,
}
impl VectorType {
    /// Builds a vector type of `num_elements` elements of `element_type`.
    ///
    /// # Errors
    ///
    /// Returns [`VectorTypeError::BadElement`], handing the type back, when
    /// `Type::is_vector_element` is false for `element_type`.
    pub fn new(num_elements: u64, element_type: Type) -> Result<Self, VectorTypeError> {
        if !element_type.is_vector_element() {
            return Err(VectorTypeError::BadElement(element_type));
        }

        Ok(Self {
            num_elements,
            element_type: Box::new(element_type),
        })
    }

    /// Returns a reference to the vector's element type.
    pub fn element(&self) -> &Type {
        &self.element_type
    }
}
/// Errors that can occur when constructing a [`FunctionType`].
#[derive(Debug, Error)]
pub enum FunctionTypeError {
/// The requested return type failed the `Type::is_return` check; the
/// offending type is carried in the variant.
#[error("invalid function return type: {0:?}")]
BadReturn(Type),
/// A parameter type failed the `Type::is_argument` check; the offending type
/// is carried in the variant.
#[error("invalid function parameter type: {0:?}")]
BadParameter(Type),
}
/// A function type: a return type, a list of parameter types, and a vararg flag.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq)]
pub struct FunctionType {
/// The function's return type; validated with `Type::is_return` by
/// `FunctionType::new`.
return_type: Box<Type>,
/// The function's parameter types, in order; each is validated with
/// `Type::is_argument` by `FunctionType::new`.
param_types: Vec<Type>,
/// The vararg flag supplied at construction.
is_vararg: bool,
}
impl FunctionType {
    /// Builds a function type from its return type, parameter types, and vararg flag.
    ///
    /// The return type is checked first, then the parameters in order.
    ///
    /// # Errors
    ///
    /// * [`FunctionTypeError::BadReturn`] when `Type::is_return` is false for
    ///   `return_type`.
    /// * [`FunctionTypeError::BadParameter`] carrying a copy of the first
    ///   parameter type for which `Type::is_argument` is false.
    pub fn new(
        return_type: Type,
        param_types: Vec<Type>,
        is_vararg: bool,
    ) -> Result<Self, FunctionTypeError> {
        if !return_type.is_return() {
            return Err(FunctionTypeError::BadReturn(return_type));
        }

        // Only the offending parameter (if any) is cloned.
        let first_bad_param = param_types
            .iter()
            .find(|param| !param.is_argument())
            .cloned();

        match first_bad_param {
            Some(param) => Err(FunctionTypeError::BadParameter(param)),
            None => Ok(Self {
                return_type: Box::new(return_type),
                param_types,
                is_vararg,
            }),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer types refuse out-of-range widths and report the expected bit and
    /// byte widths for valid ones.
    #[test]
    fn test_integer_type() {
        // Widths just outside either end of the valid range must be refused.
        let rejected_widths = [0, IntegerType::MAX_INT_BITS + 1];
        for &width in rejected_widths.iter() {
            assert!(IntegerType::try_from(width).is_err());
        }

        // Each row is (requested width, expected bit width, expected byte width).
        let accepted = [
            (IntegerType::MIN_INT_BITS, 1, 1),
            (
                IntegerType::MAX_INT_BITS,
                IntegerType::MAX_INT_BITS,
                2097152,
            ),
            (31, 31, 4),
            (32, 32, 4),
        ];
        for &(width, bits, bytes) in accepted.iter() {
            let ty = IntegerType::try_from(width).unwrap();
            assert_eq!(ty.bit_width(), bits);
            assert_eq!(ty.byte_width(), bytes);
        }

        // Every width from a single bit up to a full byte occupies exactly one byte.
        for width in 1..=8 {
            let ty = IntegerType::try_from(width).unwrap();
            assert_eq!(ty.bit_width(), width);
            assert_eq!(ty.byte_width(), 1);
        }
    }

    /// Pointer types refuse invalid pointee types and expose the pointee of
    /// valid ones.
    #[test]
    fn test_pointer_type() {
        // None of these may be pointed to.
        let rejected = [
            Type::Void,
            Type::Label,
            Type::Metadata,
            Type::Token,
            Type::X86Amx,
        ];
        for pointee in rejected.iter().cloned() {
            assert!(PointerType::new(pointee, AddressSpace::default()).is_err());
        }

        let double_ptr = PointerType::new(Type::Double, AddressSpace::default()).unwrap();
        assert_eq!(double_ptr.pointee(), &Type::Double);

        let int_pointee = Type::new_integer(32).unwrap();
        let int_ptr = PointerType::new(int_pointee, AddressSpace::default()).unwrap();
        assert_eq!(int_ptr.pointee(), &Type::new_integer(32).unwrap());
    }
}