use tvm_macros::Object;
use tvm_rt::{array::Array, DataType};
use crate::ir::relay::Constructor;
use crate::ir::span::Span;
use crate::ir::PrimExpr;
use crate::runtime::{string::String as TString, IsObject, IsObjectRef, Object, ObjectPtr};
/// Common header embedded as the `base` field of every other type node in this file.
///
/// `#[repr(C)]` fixes the field order and layout. Presumably this must mirror the
/// runtime's C++ `TypeNode`, so do not reorder or insert fields (TODO confirm).
/// `#[type_key]` is the runtime type key; `#[ref_name]` names the generated
/// reference handle type (`Type`).
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "Type"]
#[type_key = "Type"]
pub struct TypeNode {
    /// Generic runtime object header; must stay the first field.
    pub base: Object,
    /// Source span attached to this type (constructors in this file may pass `Span::null()`).
    pub span: Span,
}
impl TypeNode {
    /// Builds the shared `TypeNode` header for the concrete node type `T`.
    ///
    /// The embedded object header is created with `Object::base::<T>()`. It is
    /// therefore parameterized by the concrete node type (e.g. `TupleTypeNode`),
    /// not by `TypeNode`. Presumably this is so the runtime type index identifies
    /// `T` (TODO confirm).
    fn base<T: IsObject>(span: Span) -> Self {
        let header = Object::base::<T>();
        Self { base: header, span }
    }
}
/// Primitive (scalar) type node, registered under type key `PrimType`.
///
/// `#[repr(C)]` layout: presumably mirrors the runtime object, so field order matters.
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "PrimType"]
#[type_key = "PrimType"]
pub struct PrimTypeNode {
    /// Shared type header.
    pub base: TypeNode,
    /// Data type of the primitive value.
    pub dtype: DataType,
}
/// Pointer type node, registered under type key `PointerType`.
///
/// `#[repr(C)]` layout: presumably mirrors the runtime object, so field order matters.
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "PointerType"]
#[type_key = "PointerType"]
pub struct PointerTypeNode {
    /// Shared type header.
    pub base: TypeNode,
    /// Type of the pointed-to element.
    pub element_type: Type,
}
/// Kind tag stored in the `kind` field of `TypeVarNode`, `GlobalTypeVarNode`
/// and `IncompleteTypeNode`.
///
/// Those nodes are `#[repr(C)]` structs read directly from runtime memory, so this
/// enum's size and discriminants must match TVM's C++ `TypeKind`. Upstream declares
/// it as `enum TypeKind : int` in `include/tvm/ir/type.h`. Without `#[repr(i32)]`,
/// Rust is free to store a fieldless enum in a single byte, which would give every
/// node containing it a different layout from the C++ object.
///
/// NOTE(review): upstream C++ also defines `kBaseType = 2`, which has no variant here.
/// A node carrying that value could not be soundly read as `TypeKind`. Consider
/// adding it (confirm against the TVM version in use).
#[repr(i32)]
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum TypeKind {
    Type = 0,
    ShapeVar = 1,
    Constraint = 4,
    AdtHandle = 5,
    TypeData = 6,
}
/// Type variable node, registered under type key `TypeVar`.
///
/// `#[repr(C)]` layout: presumably mirrors the runtime object, so field order matters.
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "TypeVar"]
#[type_key = "TypeVar"]
pub struct TypeVarNode {
    /// Shared type header.
    pub base: TypeNode,
    /// Human-readable name of the variable (a hint; uniqueness not implied).
    pub name_hint: TString,
    /// Kind tag of the variable.
    pub kind: TypeKind,
}
/// Global type variable node, registered under type key `GlobalTypeVar`.
///
/// Same fields as `TypeVarNode`, but a distinct runtime type. It is used here as the
/// name of a `TypeDataNode`. `#[repr(C)]` layout: field order matters.
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "GlobalTypeVar"]
#[type_key = "GlobalTypeVar"]
pub struct GlobalTypeVarNode {
    /// Shared type header.
    pub base: TypeNode,
    /// Human-readable name of the variable.
    pub name_hint: TString,
    /// Kind tag of the variable.
    pub kind: TypeKind,
}
impl GlobalTypeVar {
    /// Creates a new global type variable.
    ///
    /// * `name_hint` - name of the variable; anything convertible into a TVM string.
    /// * `kind` - kind tag recorded on the node.
    /// * `span` - source span attached to the node.
    pub fn new<S>(name_hint: S, kind: TypeKind, span: Span) -> GlobalTypeVar
    where
        S: Into<TString>,
    {
        let node = GlobalTypeVarNode {
            base: TypeNode::base::<GlobalTypeVarNode>(span),
            name_hint: name_hint.into(),
            // Field-init shorthand (clippy::redundant_field_names).
            kind,
        };
        ObjectPtr::new(node).into()
    }
}
/// Tuple type node, registered under type key `TupleType`.
///
/// `#[repr(C)]` layout: presumably mirrors the runtime object, so field order matters.
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "TupleType"]
#[type_key = "TupleType"]
pub struct TupleTypeNode {
    /// Shared type header.
    pub base: TypeNode,
    /// Element types, in tuple order.
    pub fields: Array<Type>,
}
impl TupleType {
    /// Creates a tuple type whose element types are `fields`, in order.
    ///
    /// # Panics
    /// Panics if `Array::from_vec` returns an error for `fields`.
    pub fn new(fields: Vec<Type>, span: Span) -> Self {
        let base = TypeNode::base::<TupleTypeNode>(span);
        let fields = Array::from_vec(fields).unwrap();
        ObjectPtr::new(TupleTypeNode { base, fields }).into()
    }

    /// The tuple type with no fields, carrying a null span.
    pub fn empty() -> TupleType {
        Self::new(Vec::new(), Span::null())
    }
}
/// Type constraint node, registered under type key `TypeConstraint`.
///
/// Holds no fields beyond the shared header. It is referenced by
/// `FuncTypeNode::type_constraints`.
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "TypeConstraint"]
#[type_key = "TypeConstraint"]
pub struct TypeConstraintNode {
    /// Shared type header.
    pub base: TypeNode,
}
/// Function type node, registered under type key `FuncType`.
///
/// `#[repr(C)]` layout: presumably mirrors the runtime object, so field order matters.
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "FuncType"]
#[type_key = "FuncType"]
pub struct FuncTypeNode {
    /// Shared type header.
    pub base: TypeNode,
    /// Parameter types, in argument order.
    pub arg_types: Array<Type>,
    /// Result type.
    pub ret_type: Type,
    /// Type parameters the function type is generic over.
    pub type_params: Array<TypeVar>,
    /// Constraints attached to the function type.
    pub type_constraints: Array<TypeConstraint>,
}
/// Incomplete type node, registered under type key `IncompleteType`.
///
/// Carries only a kind tag. Presumably it is a placeholder for a type not yet
/// determined by inference (TODO confirm).
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "IncompleteType"]
#[type_key = "IncompleteType"]
pub struct IncompleteTypeNode {
    /// Shared type header.
    pub base: TypeNode,
    /// Kind tag of the placeholder.
    pub kind: TypeKind,
}
/// Relay reference type node, registered under type key `relay.RefType`.
///
/// Note that the generated handle type is named `RefType`, not `RelayRefType`.
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "RefType"]
#[type_key = "relay.RefType"]
pub struct RelayRefTypeNode {
    /// Shared type header.
    pub base: TypeNode,
    /// Type of the referenced value.
    pub value: Type,
}
/// Base tensor type node, registered under type key `relay.BaseTensorType`.
///
/// Holds no fields beyond the shared header.
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "BaseTensorType"]
#[type_key = "relay.BaseTensorType"]
pub struct BaseTensorTypeNode {
    /// Shared type header.
    pub base: TypeNode,
}
/// Tensor type node, registered under type key `relay.TensorType`.
///
/// NOTE(review): this embeds `TypeNode` directly rather than `BaseTensorTypeNode`.
/// The layout is identical only because `BaseTensorTypeNode` declares no fields of
/// its own. Revisit if that struct ever gains fields.
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "TensorType"]
#[type_key = "relay.TensorType"]
pub struct TensorTypeNode {
    /// Shared type header.
    pub base: TypeNode,
    /// Extent of each dimension, as (possibly symbolic) expressions.
    pub shape: Array<PrimExpr>,
    /// Element data type.
    pub dtype: DataType,
}
impl TensorType {
pub fn new(shape: Array<PrimExpr>, dtype: DataType, span: Span) -> TensorType {
let node = TensorTypeNode {
base: TypeNode::base::<TensorTypeNode>(span),
shape,
dtype,
};
ObjectPtr::new(node).into()
}
pub fn static_sh(shape: Vec<i32>, dtype: DataType, span: Span) -> TensorType {
let sh = Array::from_vec(shape.into_iter().map(Into::into).collect()).unwrap();
Self::new(sh, dtype, span)
}
}
/// Relay algebraic-data-type definition node, registered under type key `relay.TypeData`.
///
/// `#[repr(C)]` layout: presumably mirrors the runtime object, so field order matters.
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "TypeData"]
#[type_key = "relay.TypeData"]
pub struct TypeDataNode {
    /// Shared type header.
    pub base: TypeNode,
    /// Global name of the data type.
    pub type_name: GlobalTypeVar,
    /// Type parameters of the data type.
    pub type_vars: Array<TypeVar>,
    /// Constructors of the data type.
    pub constructors: Array<Constructor>,
}
impl TypeData {
pub fn new<TypeVars, Ctors>(
type_name: GlobalTypeVar,
type_vars: TypeVars,
constructors: Ctors,
span: Span,
) -> TypeData
where
TypeVars: IntoIterator<Item = TypeVar>,
Ctors: IntoIterator<Item = Constructor>,
{
use std::iter::FromIterator;
let type_data = TypeDataNode {
base: TypeNode::base::<TypeDataNode>(span),
type_name,
type_vars: Array::from_iter(type_vars),
constructors: Array::from_iter(constructors),
};
TypeData(Some(ObjectPtr::new(type_data)))
}
}