use crate::mlir::MlirError;
use crate::typechecker::Type;
#[cfg(feature = "mlir")]
use melior::{
ir::{
r#type::{FunctionType, IntegerType},
Type as MlirType,
},
Context,
};
/// Converts type-checker [`Type`]s into MLIR types.
///
/// All lowered types are created in the borrowed MLIR [`Context`], so they
/// cannot outlive `'ctx`.
#[cfg(feature = "mlir")]
pub struct TypeLowering<'ctx> {
/// MLIR context in which every lowered type is constructed.
context: &'ctx Context,
}
#[cfg(feature = "mlir")]
impl<'ctx> TypeLowering<'ctx> {
    /// Creates a lowering that builds every MLIR type inside `context`.
    pub fn new(context: &'ctx Context) -> Self {
        Self { context }
    }

    /// Lowers a fully-resolved type-checker [`Type`] to its MLIR counterpart.
    ///
    /// Mapping summary:
    /// - `Void` becomes the empty MLIR tuple.
    /// - `Bool` becomes a 1-bit integer. Signed and unsigned integers of equal
    ///   width lower to the same integer type, so signedness is not encoded.
    /// - `Float32` / `Float64` become the corresponding MLIR float types.
    /// - `Function` becomes an MLIR function type. A `Void` return produces a
    ///   function with no results.
    /// - `Tuple` becomes an MLIR tuple of the lowered element types.
    /// - `Generic` is delegated to [`Self::lower_generic`].
    ///
    /// # Errors
    ///
    /// Returns an [`MlirError`] in these cases:
    /// - the type has no concrete representation (`Var`, `Unknown`, `Any`, `Error`);
    /// - a generic type is unsupported or malformed;
    /// - any nested component type fails to lower.
    pub fn lower(&self, ty: &Type) -> Result<MlirType<'ctx>, MlirError> {
        match ty {
            Type::Void => Ok(MlirType::tuple(self.context, &[])),
            Type::Bool => Ok(IntegerType::new(self.context, 1).into()),
            // Signedness is not carried into the MLIR type: each signed/unsigned
            // pair of the same width lowers to one integer type.
            Type::Int8 | Type::UInt8 => Ok(IntegerType::new(self.context, 8).into()),
            Type::Int16 | Type::UInt16 => Ok(IntegerType::new(self.context, 16).into()),
            Type::Int32 | Type::UInt32 => Ok(IntegerType::new(self.context, 32).into()),
            Type::Int64 | Type::UInt64 => Ok(IntegerType::new(self.context, 64).into()),
            Type::Float32 => Ok(MlirType::float32(self.context)),
            Type::Float64 => Ok(MlirType::float64(self.context)),
            // NOTE(review): strings lower to `index`, presumably as an opaque
            // handle/placeholder until a real string representation exists.
            Type::String => Ok(MlirType::index(self.context)),
            Type::Function {
                params,
                return_type,
            } => {
                let param_types = params
                    .iter()
                    .map(|p| self.lower(p))
                    .collect::<Result<Vec<_>, _>>()?;
                // A `Void` return means "no results" in the MLIR signature, so it
                // is not lowered at all (rather than lowered and then discarded).
                let results = if matches!(**return_type, Type::Void) {
                    Vec::new()
                } else {
                    vec![self.lower(return_type)?]
                };
                Ok(FunctionType::new(self.context, &param_types, &results).into())
            }
            Type::Tuple(types) => {
                let element_types = types
                    .iter()
                    .map(|t| self.lower(t))
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(MlirType::tuple(self.context, &element_types))
            }
            Type::Generic { name, args } => self.lower_generic(name, args),
            Type::Var(id) => Err(MlirError::new(format!(
                "cannot lower unresolved type variable ?{}",
                id
            ))),
            Type::Unknown => Err(MlirError::new(
                "cannot lower unknown type (type inference incomplete)",
            )),
            Type::Any => Err(MlirError::new(
                "cannot lower 'Any' type (requires concrete type)",
            )),
            Type::Error => Err(MlirError::new("cannot lower error type")),
        }
    }

    /// Lowers an applied generic type constructor `name<args...>`.
    ///
    /// - `List` / `Array`: lowered to the element type (the first argument).
    ///   Any further arguments are ignored. NOTE(review): this drops the
    ///   container shape and looks like a placeholder; confirm the intent.
    /// - `Option` / `Maybe`: a tuple of a 1-bit flag and the lowered payload.
    ///   Any further arguments are ignored.
    /// - `Result`: a tuple of a 1-bit tag, the ok type and the error type.
    /// - `Quoted` / `TypeInfo`: lowered to `index`. Any arguments are ignored.
    ///
    /// # Errors
    ///
    /// Returns an [`MlirError`] in these cases:
    /// - a required type argument is missing;
    /// - `Result` does not receive exactly two arguments;
    /// - a type argument fails to lower;
    /// - `name` is not a supported constructor.
    fn lower_generic(&self, name: &str, args: &[Type]) -> Result<MlirType<'ctx>, MlirError> {
        match name {
            "List" | "Array" => {
                let elem = args
                    .first()
                    .ok_or_else(|| MlirError::new("List type requires type parameter"))?;
                self.lower(elem)
            }
            "Option" | "Maybe" => {
                let payload = args
                    .first()
                    .ok_or_else(|| MlirError::new("Option type requires type parameter"))?;
                let elem_type = self.lower(payload)?;
                // Presence flag followed by the payload value.
                let flag_type = IntegerType::new(self.context, 1).into();
                Ok(MlirType::tuple(self.context, &[flag_type, elem_type]))
            }
            "Result" => match args {
                [ok, err] => {
                    let ok_type = self.lower(ok)?;
                    let err_type = self.lower(err)?;
                    // Discriminant tag followed by both possible payloads.
                    let tag_type = IntegerType::new(self.context, 1).into();
                    Ok(MlirType::tuple(
                        self.context,
                        &[tag_type, ok_type, err_type],
                    ))
                }
                _ => Err(MlirError::new("Result type requires two type parameters")),
            },
            "Quoted" | "TypeInfo" => Ok(MlirType::index(self.context)),
            _ => Err(MlirError::new(format!(
                "unsupported generic type constructor: {}",
                name
            ))),
        }
    }
}
/// Stand-in for the MLIR type lowering when the `mlir` feature is disabled.
///
/// It keeps the same `'ctx` lifetime parameter as the MLIR-enabled struct but
/// holds no context.
#[cfg(not(feature = "mlir"))]
pub struct TypeLowering<'ctx> {
/// Zero-sized marker that ties the struct to the `'ctx` lifetime.
_phantom: std::marker::PhantomData<&'ctx ()>,
}
#[cfg(not(feature = "mlir"))]
impl<'ctx> TypeLowering<'ctx> {
    /// Builds the stub lowering.
    ///
    /// The unit `_context` argument only mirrors the shape of the MLIR-enabled
    /// constructor and is otherwise ignored.
    pub fn new(_context: &'ctx ()) -> Self {
        let marker = std::marker::PhantomData;
        Self { _phantom: marker }
    }

    /// Always fails, because there is no MLIR backend without the `mlir` feature.
    ///
    /// # Errors
    ///
    /// Always returns an [`MlirError`] that names the type which could not be lowered.
    pub fn lower(&self, ty: &Type) -> Result<(), MlirError> {
        let message = format!("MLIR feature not enabled, cannot lower type: {}", ty);
        Err(MlirError::new(message))
    }
}
#[cfg(all(test, feature = "mlir"))]
mod tests {
    use super::*;
    use crate::mlir::MlirContext;

    /// Runs `check` against a lowering backed by a fresh MLIR context.
    fn with_lowering(check: impl FnOnce(&TypeLowering<'_>)) {
        let ctx = MlirContext::new();
        let lowering = TypeLowering::new(ctx.context());
        check(&lowering);
    }

    /// Builds the generic type-checker type `name<args...>`.
    fn generic(name: &str, args: Vec<Type>) -> Type {
        Type::Generic {
            name: name.to_string(),
            args,
        }
    }

    #[test]
    fn test_lower_primitives() {
        with_lowering(|lowering| {
            let primitives = [
                Type::Bool,
                Type::Int8,
                Type::Int16,
                Type::Int32,
                Type::Int64,
                Type::UInt8,
                Type::UInt16,
                Type::UInt32,
                Type::UInt64,
                Type::Float32,
                Type::Float64,
                Type::String,
            ];
            for ty in &primitives {
                assert!(lowering.lower(ty).is_ok(), "failed to lower {}", ty);
            }
        });
    }

    #[test]
    fn test_lower_void() {
        with_lowering(|lowering| {
            assert!(lowering.lower(&Type::Void).is_ok());
        });
    }

    #[test]
    fn test_lower_function() {
        with_lowering(|lowering| {
            let func_type = Type::Function {
                params: vec![Type::Int32, Type::Bool],
                return_type: Box::new(Type::Float64),
            };
            assert!(lowering.lower(&func_type).is_ok());
        });
    }

    #[test]
    fn test_lower_function_void_return() {
        with_lowering(|lowering| {
            let func_type = Type::Function {
                params: vec![Type::Int64],
                return_type: Box::new(Type::Void),
            };
            assert!(lowering.lower(&func_type).is_ok());
        });
    }

    #[test]
    fn test_lower_function_propagates_component_errors() {
        with_lowering(|lowering| {
            let bad_param = Type::Function {
                params: vec![Type::Int32, Type::Unknown],
                return_type: Box::new(Type::Bool),
            };
            assert!(lowering.lower(&bad_param).is_err());

            let bad_return = Type::Function {
                params: vec![Type::Int32],
                return_type: Box::new(Type::Any),
            };
            assert!(lowering.lower(&bad_return).is_err());
        });
    }

    #[test]
    fn test_lower_tuple() {
        with_lowering(|lowering| {
            let tuple_type = Type::Tuple(vec![Type::Int32, Type::Bool, Type::Float64]);
            assert!(lowering.lower(&tuple_type).is_ok());
            assert!(lowering.lower(&Type::Tuple(vec![])).is_ok());
            assert!(lowering
                .lower(&Type::Tuple(vec![Type::Int32, Type::Var(7)]))
                .is_err());
        });
    }

    #[test]
    fn test_lower_generic_option() {
        with_lowering(|lowering| {
            assert!(lowering.lower(&generic("Option", vec![Type::Int32])).is_ok());
            assert!(lowering.lower(&generic("Maybe", vec![Type::Int32])).is_ok());
            assert!(lowering.lower(&generic("Option", vec![])).is_err());
            assert!(lowering
                .lower(&generic("Option", vec![Type::Unknown]))
                .is_err());
        });
    }

    #[test]
    fn test_lower_generic_list() {
        with_lowering(|lowering| {
            assert!(lowering.lower(&generic("List", vec![Type::String])).is_ok());
            assert!(lowering.lower(&generic("Array", vec![Type::Int8])).is_ok());
            assert!(lowering.lower(&generic("List", vec![])).is_err());
            assert!(lowering.lower(&generic("Array", vec![])).is_err());
        });
    }

    #[test]
    fn test_lower_generic_result() {
        with_lowering(|lowering| {
            assert!(lowering
                .lower(&generic("Result", vec![Type::Int32, Type::String]))
                .is_ok());
            assert!(lowering.lower(&generic("Result", vec![Type::Int32])).is_err());
            assert!(lowering
                .lower(&generic(
                    "Result",
                    vec![Type::Int32, Type::String, Type::Bool]
                ))
                .is_err());
            assert!(lowering
                .lower(&generic("Result", vec![Type::Int32, Type::Error]))
                .is_err());
        });
    }

    #[test]
    fn test_lower_generic_opaque() {
        with_lowering(|lowering| {
            assert!(lowering.lower(&generic("Quoted", vec![Type::Int32])).is_ok());
            assert!(lowering.lower(&generic("TypeInfo", vec![])).is_ok());
        });
    }

    #[test]
    fn test_lower_generic_unsupported() {
        with_lowering(|lowering| {
            assert!(lowering
                .lower(&generic("HashMap", vec![Type::String, Type::Int32]))
                .is_err());
        });
    }

    #[test]
    fn test_lower_error_unknown() {
        with_lowering(|lowering| {
            assert!(lowering.lower(&Type::Unknown).is_err());
            assert!(lowering.lower(&Type::Error).is_err());
            assert!(lowering.lower(&Type::Any).is_err());
        });
    }

    #[test]
    fn test_lower_error_unresolved_var() {
        with_lowering(|lowering| {
            assert!(lowering.lower(&Type::Var(42)).is_err());
        });
    }
}