use std::marker::PhantomData;
use std::str::FromStr;
use std::sync::{Arc, Weak};
use crate::Extension;
use crate::extension::simple_op::{
HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
};
use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef};
use crate::ops::{ExtensionOp, NamedOp, OpName};
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound};
use super::array_kind::ArrayKind;
/// Direction of an array conversion, used as a const generic parameter on the
/// conversion op types in this module.
///
/// See [`INTO`] and [`FROM`] for the two possible values.
pub type Direction = bool;
/// The op takes the array kind `AK` as input and produces `OtherAK`.
pub const INTO: Direction = true;
/// The op takes the array kind `OtherAK` as input and produces `AK`.
pub const FROM: Direction = false;
/// Definition of the array conversion operation.
///
/// Generic over:
/// - `AK`: the array kind whose extension hosts the operation,
/// - `DIR`: the conversion [`Direction`] ([`INTO`] or [`FROM`]),
/// - `OtherAK`: the array kind on the other side of the conversion.
///
/// The struct carries no runtime data; all information lives in its type
/// parameters.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct GenericArrayConvertDef<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>(
// Zero-sized marker for the hosting array kind.
PhantomData<AK>,
// Zero-sized marker for the array kind on the other side of the conversion.
PhantomData<OtherAK>,
);
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>
    GenericArrayConvertDef<AK, DIR, OtherAK>
{
    /// Creates a new array conversion definition.
    ///
    /// The value holds no data: the array kinds and the direction are fixed
    /// entirely by the type parameters.
    #[must_use]
    pub fn new() -> Self {
        Self(PhantomData, PhantomData)
    }
}
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> Default
for GenericArrayConvertDef<AK, DIR, OtherAK>
{
fn default() -> Self {
Self::new()
}
}
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> FromStr
    for GenericArrayConvertDef<AK, DIR, OtherAK>
{
    type Err = ();

    /// Parses an operation name into this definition.
    ///
    /// Succeeds only when `s` equals the name reported by
    /// [`MakeOpDef::opdef_id`] for this combination of type parameters;
    /// any other string yields `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let candidate = Self::new();
        if s != candidate.opdef_id() {
            return Err(());
        }
        Ok(candidate)
    }
}
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>
GenericArrayConvertDef<AK, DIR, OtherAK>
{
fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc {
let params = vec![TypeParam::max_nat_type(), TypeBound::Linear.into()];
let size = TypeArg::new_var_use(0, TypeParam::max_nat_type());
let element_ty = Type::new_var_use(1, TypeBound::Linear);
let this_ty = AK::instantiate_ty(array_def, size.clone(), element_ty.clone())
.expect("Array type instantiation failed");
let other_ty =
OtherAK::ty_parametric(size, element_ty).expect("Array type instantiation failed");
let sig = match DIR {
INTO => FuncValueType::new([this_ty], [other_ty]),
FROM => FuncValueType::new([other_ty], [this_ty]),
};
PolyFuncTypeRV::new(params, sig).into()
}
}
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> MakeOpDef
    for GenericArrayConvertDef<AK, DIR, OtherAK>
{
    /// Name of the operation: `to_<OtherAK>` for [`INTO`], `from_<OtherAK>`
    /// for [`FROM`].
    fn opdef_id(&self) -> OpName {
        match DIR {
            INTO => format!("to_{}", OtherAK::TYPE_NAME).into(),
            FROM => format!("from_{}", OtherAK::TYPE_NAME).into(),
        }
    }

    /// Recovers the definition from an [`OpDef`] by delegating to
    /// `simple_op::try_from_name` with the def's name and extension id.
    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
    where
        Self: Sized,
    {
        crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
    }

    /// Computes the op signature using the type definition reported by
    /// `AK::type_def()`. The extension reference is not needed.
    fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
        self.signature_from_def(AK::type_def())
    }

    /// Weak reference to the extension of the hosting array kind `AK`.
    fn extension_ref(&self) -> Weak<Extension> {
        Arc::downgrade(AK::extension())
    }

    /// Identifier of the extension of the hosting array kind `AK`.
    fn extension(&self) -> ExtensionId {
        AK::EXTENSION_ID
    }

    /// Human-readable description naming the source and target array types.
    fn description(&self) -> String {
        match DIR {
            INTO => format!("Turns `{}` into `{}`", AK::TYPE_NAME, OtherAK::TYPE_NAME),
            FROM => format!("Turns `{}` into `{}`", OtherAK::TYPE_NAME, AK::TYPE_NAME),
        }
    }

    /// Adds this operation to `extension`.
    ///
    /// Unlike [`MakeOpDef::init_signature`], the signature is computed from the
    /// array type definition looked up in the extension being populated rather
    /// than from `AK::type_def()`.
    // NOTE(review): presumably this avoids touching the finished extension
    // while it is still under construction -- confirm.
    ///
    /// # Panics
    ///
    /// Panics if `extension` does not already contain a type named
    /// `AK::TYPE_NAME`.
    fn add_to_extension(
        &self,
        extension: &mut Extension,
        extension_ref: &Weak<Extension>,
    ) -> Result<(), crate::extension::ExtensionBuildError> {
        let sig = self.signature_from_def(
            extension
                .get_type(&AK::TYPE_NAME)
                .expect("array type must be added to the extension before its conversion ops"),
        );
        let def = extension.add_op(self.opdef_id(), self.description(), sig, extension_ref)?;
        self.post_opdef(def);
        Ok(())
    }
}
/// Concrete (fully instantiated) array conversion operation.
///
/// `AK` is the array kind hosting the op, `DIR` the conversion [`Direction`],
/// and `OtherAK` the array kind on the other side of the conversion.
#[derive(Clone, Debug, PartialEq)]
pub struct GenericArrayConvert<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> {
/// Element type of the arrays; passed as the second type argument of the op.
pub elem_ty: Type,
/// Size of the arrays; passed as the first (bounded-nat) type argument of the op.
pub size: u64,
// Zero-sized marker for the hosting array kind.
_kind: PhantomData<AK>,
// Zero-sized marker for the array kind on the other side of the conversion.
_other_kind: PhantomData<OtherAK>,
}
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>
    GenericArrayConvert<AK, DIR, OtherAK>
{
    /// Creates a conversion op for arrays holding `size` elements of type
    /// `elem_ty`.
    #[must_use]
    pub fn new(elem_ty: Type, size: u64) -> Self {
        Self {
            _kind: PhantomData,
            _other_kind: PhantomData,
            elem_ty,
            size,
        }
    }
}
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> NamedOp
for GenericArrayConvert<AK, DIR, OtherAK>
{
fn name(&self) -> OpName {
match DIR {
INTO => format!("to_{}", OtherAK::TYPE_NAME).into(),
FROM => format!("from_{}", OtherAK::TYPE_NAME).into(),
}
}
}
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> MakeExtensionOp
    for GenericArrayConvert<AK, DIR, OtherAK>
{
    /// Identifier of the operation, as reported by the matching definition.
    fn op_id(&self) -> OpName {
        let definition = GenericArrayConvertDef::<AK, DIR, OtherAK>::new();
        definition.opdef_id()
    }

    /// Rebuilds the concrete op from an [`ExtensionOp`]: first recover the
    /// definition from the op's `OpDef`, then instantiate it with the op's
    /// type arguments.
    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
    where
        Self: Sized,
    {
        GenericArrayConvertDef::<AK, DIR, OtherAK>::from_def(ext_op.def())
            .and_then(|definition| definition.instantiate(ext_op.args()))
    }

    /// Type arguments of the op: the array size followed by the element type.
    fn type_args(&self) -> Vec<TypeArg> {
        let size_arg = TypeArg::BoundedNat(self.size);
        let elem_arg: TypeArg = self.elem_ty.clone().into();
        vec![size_arg, elem_arg]
    }
}
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> MakeRegisteredOp
    for GenericArrayConvert<AK, DIR, OtherAK>
{
    /// Identifier of the extension of the hosting array kind `AK`.
    fn extension_id(&self) -> ExtensionId {
        AK::EXTENSION_ID
    }

    /// Shared handle to the extension of the hosting array kind `AK`.
    fn extension_ref(&self) -> Arc<Extension> {
        Arc::clone(AK::extension())
    }
}
// Links the concrete op to its definition type with the same type parameters.
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> HasDef
for GenericArrayConvert<AK, DIR, OtherAK>
{
/// The definition type this concrete op is an instance of.
type Def = GenericArrayConvertDef<AK, DIR, OtherAK>;
}
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> HasConcrete
    for GenericArrayConvertDef<AK, DIR, OtherAK>
{
    type Concrete = GenericArrayConvert<AK, DIR, OtherAK>;

    /// Instantiates the concrete op from its type arguments.
    ///
    /// Expects exactly two arguments: a bounded nat (the array size) followed
    /// by a runtime type (the element type).
    ///
    /// # Errors
    ///
    /// Returns [`SignatureError::InvalidTypeArgs`] (converted into
    /// [`OpLoadError`]) for any other argument list.
    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
        if let [TypeArg::BoundedNat(size), TypeArg::Runtime(elem_ty)] = type_args {
            Ok(GenericArrayConvert::new(elem_ty.clone(), *size))
        } else {
            Err(SignatureError::InvalidTypeArgs.into())
        }
    }
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use crate::extension::prelude::bool_t;
    use crate::ops::{OpTrait, OpType};
    use crate::std_extensions::collections::array::Array;
    use crate::std_extensions::collections::borrow_array::BorrowArray;

    use super::*;

    /// A `FROM` conversion survives a round trip through [`OpType`].
    #[rstest]
    #[case(BorrowArray, Array)]
    fn test_convert_from_def<AK: ArrayKind, OtherAK: ArrayKind>(
        #[case] _kind: AK,
        #[case] _other_kind: OtherAK,
    ) {
        let original = GenericArrayConvert::<AK, FROM, OtherAK>::new(bool_t(), 2);
        let as_optype: OpType = original.clone().into();
        let recovered: GenericArrayConvert<AK, FROM, OtherAK> = as_optype.cast().unwrap();
        assert_eq!(recovered, original);
    }

    /// An `INTO` conversion survives a round trip through [`OpType`].
    #[rstest]
    #[case(BorrowArray, Array)]
    fn test_convert_into_def<AK: ArrayKind, OtherAK: ArrayKind>(
        #[case] _kind: AK,
        #[case] _other_kind: OtherAK,
    ) {
        let original = GenericArrayConvert::<AK, INTO, OtherAK>::new(bool_t(), 2);
        let as_optype: OpType = original.clone().into();
        let recovered: GenericArrayConvert<AK, INTO, OtherAK> = as_optype.cast().unwrap();
        assert_eq!(recovered, original);
    }

    /// A `FROM` conversion consumes the other array kind and produces `AK`.
    #[rstest]
    #[case(BorrowArray, Array)]
    fn test_convert_from<AK: ArrayKind, OtherAK: ArrayKind>(
        #[case] _kind: AK,
        #[case] _other_kind: OtherAK,
    ) {
        let len = 2;
        let elem = bool_t();
        let as_optype: OpType =
            GenericArrayConvert::<AK, FROM, OtherAK>::new(elem.clone(), len).into();
        let signature = as_optype.dataflow_signature().unwrap();
        assert_eq!(
            signature.io(),
            (
                &vec![OtherAK::ty(len, elem.clone())].into(),
                &vec![AK::ty(len, elem.clone())].into(),
            )
        );
    }

    /// An `INTO` conversion consumes `AK` and produces the other array kind.
    #[rstest]
    #[case(BorrowArray, Array)]
    fn test_convert_into<AK: ArrayKind, OtherAK: ArrayKind>(
        #[case] _kind: AK,
        #[case] _other_kind: OtherAK,
    ) {
        let len = 2;
        let elem = bool_t();
        let as_optype: OpType =
            GenericArrayConvert::<AK, INTO, OtherAK>::new(elem.clone(), len).into();
        let signature = as_optype.dataflow_signature().unwrap();
        assert_eq!(
            signature.io(),
            (
                &vec![AK::ty(len, elem.clone())].into(),
                &vec![OtherAK::ty(len, elem.clone())].into(),
            )
        );
    }
}