use std::hash::{Hash, Hasher};
mod list_fold;
use std::str::FromStr;
use itertools::Itertools;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use strum_macros::{EnumIter, EnumString, IntoStaticStr};
use crate::extension::prelude::{either_type, option_type, USIZE_T};
use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp};
use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc, PRELUDE};
use crate::ops::constant::{maybe_hash_values, TryHash, ValueName};
use crate::ops::{OpName, Value};
use crate::types::{TypeName, TypeRowRV};
use crate::{
extension::{
simple_op::{MakeExtensionOp, OpLoadError},
ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef, TypeDefBound,
},
ops::constant::CustomConst,
ops::{custom::ExtensionOp, NamedOp},
types::{
type_param::{TypeArg, TypeParam},
CustomCheckFailure, CustomType, FuncValueType, PolyFuncTypeRV, Type, TypeBound,
},
Extension,
};
/// Name of the generic list type defined by the collections extension.
pub const LIST_TYPENAME: TypeName = TypeName::new_inline("List");
/// Identifier of the collections extension.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections");
/// Version of the collections extension.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
/// A constant list value.
///
/// Field 0 holds the element values; field 1 is the element type. [`CustomConst::validate`]
/// checks that every element has exactly that type.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ListValue(Vec<Value>, Type);
impl ListValue {
    /// Creates a list constant with element type `typ` holding `contents`, in iteration order.
    ///
    /// The elements are not type-checked here; use `validate` for that.
    pub fn new(typ: Type, contents: impl IntoIterator<Item = Value>) -> Self {
        let elements = contents.into_iter().collect_vec();
        Self(elements, typ)
    }

    /// Creates a list constant with element type `typ` and no elements.
    pub fn new_empty(typ: Type) -> Self {
        Self(Vec::new(), typ)
    }

    /// Returns the `List` custom type instantiated at this value's element type.
    pub fn custom_type(&self) -> CustomType {
        let element_type = self.1.clone();
        list_custom_type(element_type)
    }
}
impl TryHash for ListValue {
    /// Feeds the elements and then the element type into `st`.
    ///
    /// Returns `false` as soon as `maybe_hash_values` reports that the elements cannot be
    /// hashed. In that case the element type is not hashed.
    fn try_hash(&self, mut st: &mut dyn Hasher) -> bool {
        if !maybe_hash_values(&self.0, &mut st) {
            return false;
        }
        self.1.hash(&mut st);
        true
    }
}
#[typetag::serde]
impl CustomConst for ListValue {
    fn name(&self) -> ValueName {
        ValueName::new_inline("list")
    }

    /// The `List` type instantiated at this value's element type.
    fn get_type(&self) -> Type {
        self.custom_type().into()
    }

    /// Checks two things:
    /// - the list type is a valid instantiation of the `List` type definition with a single
    ///   type argument;
    /// - every element's type equals that argument.
    fn validate(&self) -> Result<(), CustomCheckFailure> {
        let list_custom = self.custom_type();
        let fail = || CustomCheckFailure::Message("List type check fail.".to_string());

        EXTENSION
            .get_type(&LIST_TYPENAME)
            .unwrap()
            .check_custom(&list_custom)
            .map_err(|_| fail())?;

        match list_custom.args() {
            [TypeArg::Type { ty }] if self.0.iter().all(|elem| elem.get_type() == *ty) => Ok(()),
            _ => Err(fail()),
        }
    }

    fn equal_consts(&self, other: &dyn CustomConst) -> bool {
        crate::ops::constant::downcast_equal_consts(self, other)
    }

    /// Union of the elements' extension requirements plus this collections extension.
    fn extension_reqs(&self) -> ExtensionSet {
        let element_reqs = ExtensionSet::union_over(self.0.iter().map(Value::extension_reqs));
        element_reqs.union(EXTENSION_ID.into())
    }
}
/// Operations on `List<T>` provided by the collections extension.
///
/// Each op is polymorphic over a single element type `T`. The input/output rows below
/// follow `ListOp::compute_signature`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
// Variant names double as the op names: they are converted to/from strings via the strum
// `IntoStaticStr`/`EnumString` derives. That is why they are lowercase.
#[allow(non_camel_case_types)]
#[non_exhaustive]
pub enum ListOp {
    /// `List<T> -> (List<T>, Option<T>)`: removes the last element, if any.
    pop,
    /// `(List<T>, T) -> List<T>`: appends an element at the back.
    push,
    /// `(List<T>, usize) -> Option<T>`: looks up an element by index.
    /// The fold tests expect "none" for an out-of-bounds index.
    get,
    /// `(List<T>, usize, T) -> (List<T>, either(T, T))`: replaces the element at an index.
    /// The fold tests expect Ok(old element) on success and Err(new value) when the index
    /// is out of bounds.
    set,
    /// `(List<T>, usize, T) -> (List<T>, either(T, Unit))`: inserts an element at an index.
    /// The fold tests expect Ok(()) on success and Err(element) when the index is out of
    /// bounds.
    insert,
    /// `List<T> -> (List<T>, usize)`: the number of elements.
    length,
}
impl ListOp {
const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };
pub fn with_type(self, element_type: Type) -> ListOpInst {
ListOpInst {
elem_type: element_type,
op: self,
}
}
fn compute_signature(self, list_type_def: &TypeDef) -> SignatureFunc {
use ListOp::*;
let e = Type::new_var_use(0, TypeBound::Any);
let l = self.list_type(list_type_def, 0);
match self {
pop => self
.list_polytype(vec![l.clone()], vec![l, Type::from(option_type(e))])
.into(),
push => self.list_polytype(vec![l.clone(), e], vec![l]).into(),
get => self
.list_polytype(vec![l, USIZE_T], vec![Type::from(option_type(e))])
.into(),
set => self
.list_polytype(
vec![l.clone(), USIZE_T, e.clone()],
vec![l, Type::from(either_type(e.clone(), e))],
)
.into(),
insert => self
.list_polytype(
vec![l.clone(), USIZE_T, e.clone()],
vec![l, either_type(e, Type::UNIT).into()],
)
.into(),
length => self.list_polytype(vec![l.clone()], vec![l, USIZE_T]).into(),
}
}
fn list_polytype(
self,
input: impl Into<TypeRowRV>,
output: impl Into<TypeRowRV>,
) -> PolyFuncTypeRV {
PolyFuncTypeRV::new(vec![Self::TP], FuncValueType::new(input, output))
}
fn list_type(self, list_type_def: &TypeDef, idx: usize) -> Type {
Type::new_extension(
list_type_def
.instantiate(vec![TypeArg::new_var_use(idx, Self::TP)])
.unwrap(),
)
}
}
impl MakeOpDef for ListOp {
    /// Recovers the [`ListOp`] from its definition's name and extension.
    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
        crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension())
    }

    fn extension(&self) -> ExtensionId {
        EXTENSION_ID.to_owned()
    }

    /// Adds this op to `extension`, which must already define the `List` type.
    fn add_to_extension(&self, extension: &mut Extension) -> Result<(), ExtensionBuildError> {
        // This runs inside the initializer of `EXTENSION` (via `load_all_ops`), so the list
        // type must come from `extension` itself. Calling `list_type_def()` here would touch
        // `EXTENSION` while it is still being built.
        let sig = self.compute_signature(extension.get_type(&LIST_TYPENAME).unwrap());
        let def = extension.add_op(self.name(), self.description(), sig)?;
        self.post_opdef(def);
        Ok(())
    }

    fn signature(&self) -> SignatureFunc {
        self.compute_signature(list_type_def())
    }

    /// Human-readable description of the op.
    ///
    /// Each description must agree with the signature built by `compute_signature`.
    fn description(&self) -> String {
        use ListOp::*;

        match self {
            pop => "Pop from the back of list. Returns an optional value.",
            push => "Push to the back of list",
            // The output is an optional value, so an out-of-bounds index yields "none"
            // rather than a panic.
            get => "Lookup an element in a list by index. Returns an optional value, which is empty if the index is out of bounds.",
            set => "Replace the element at index `i` with value `v`.",
            // The output is a sum type carrying the element back on failure, so an
            // out-of-bounds index does not panic.
            insert => "Insert an element at index `i`. Elements at higher indices are shifted one position to the right. If the index is out of bounds, the element is returned in the error variant.",
            length => "Get the length of a list",
        }
        .into()
    }

    /// Attaches the constant-folding implementation for this op.
    fn post_opdef(&self, def: &mut OpDef) {
        list_fold::set_fold(self, def)
    }
}
lazy_static! {
    /// The collections extension: defines the `List` type and all [`ListOp`] operations.
    pub static ref EXTENSION: Extension = {
        let mut extension = Extension::new(EXTENSION_ID, VERSION);

        // `List` takes one type parameter (the element type). Its bound is declared via
        // `TypeDefBound::from_params(vec![0])`, i.e. derived from parameter 0.
        // Presumably this makes a list's bound follow its element type; confirm against
        // the `TypeDefBound` docs.
        extension.add_type(
            LIST_TYPENAME,
            vec![ListOp::TP],
            "Generic dynamically sized list of type T.".into(),
            TypeDefBound::from_params(vec![0]),
        )
        .unwrap();
        // The type must be added first: `add_to_extension` looks it up in `extension`.
        ListOp::load_all_ops(&mut extension).unwrap();

        extension
    };

    /// Registry containing the prelude and the collections extension.
    pub static ref COLLECTIONS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([
        PRELUDE.to_owned(),
        EXTENSION.to_owned(),
    ])
    .unwrap();
}
impl MakeRegisteredOp for ListOp {
    /// All list ops belong to the collections extension.
    fn extension_id(&self) -> ExtensionId {
        EXTENSION_ID
    }

    /// Returns [`COLLECTIONS_REGISTRY`] (prelude + collections extension).
    fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry {
        let registry: &'static ExtensionRegistry = &COLLECTIONS_REGISTRY;
        registry
    }
}
/// The `List` type definition registered in the collections extension.
///
/// # Panics
/// Panics if the extension has no type named [`LIST_TYPENAME`].
pub fn list_type_def() -> &'static TypeDef {
    let def = EXTENSION.get_type(&LIST_TYPENAME);
    def.unwrap()
}
/// The `List` custom type instantiated with element type `elem_type`.
///
/// # Panics
/// Panics if the instantiation is rejected by the type definition.
pub fn list_custom_type(elem_type: Type) -> CustomType {
    let args = vec![TypeArg::Type { ty: elem_type }];
    list_type_def().instantiate(args).unwrap()
}
pub fn list_type(elem_type: Type) -> Type {
list_custom_type(elem_type).into()
}
/// A [`ListOp`] instantiated at a concrete element type.
#[derive(Debug, Clone, PartialEq)]
pub struct ListOpInst {
    /// The operation being instantiated.
    op: ListOp,
    /// Element type of the list; emitted as the op's single type argument.
    elem_type: Type,
}
impl NamedOp for ListOpInst {
    /// The op name is the static string of the wrapped [`ListOp`] variant.
    fn name(&self) -> OpName {
        <&'static str>::from(self.op).into()
    }
}
impl MakeExtensionOp for ListOpInst {
    /// Rebuilds a [`ListOpInst`] from a generic extension op.
    ///
    /// # Errors
    /// - [`SignatureError::InvalidTypeArgs`] if the op does not have exactly one type argument.
    /// - `OpLoadError::NotMember` if the definition's name is not a [`ListOp`].
    fn from_extension_op(
        ext_op: &ExtensionOp,
    ) -> Result<Self, crate::extension::simple_op::OpLoadError> {
        let elem_type = match ext_op.args() {
            [TypeArg::Type { ty }] => ty.clone(),
            _ => return Err(SignatureError::InvalidTypeArgs.into()),
        };

        let op_name = ext_op.def().name();
        let op = ListOp::from_str(op_name)
            .map_err(|_| OpLoadError::NotMember(op_name.to_string()))?;

        Ok(Self { op, elem_type })
    }

    /// The single type argument: the element type.
    fn type_args(&self) -> Vec<TypeArg> {
        let ty = self.elem_type.clone();
        vec![TypeArg::Type { ty }]
    }
}
impl ListOpInst {
    /// Converts this instance into an [`ExtensionOp`].
    ///
    /// The op is resolved against a registry built from `elem_type_registry` plus this
    /// collections extension. Any extension in `elem_type_registry` with the same name
    /// is replaced by the crate's own [`EXTENSION`].
    ///
    /// Returns `None` in these cases:
    /// - `ExtensionRegistry::try_new` rejects the combined set of extensions;
    /// - the op is not found;
    /// - `ExtensionOp::new` fails.
    pub fn to_extension_op(self, elem_type_registry: &ExtensionRegistry) -> Option<ExtensionOp> {
        let registry = ExtensionRegistry::try_new(
            elem_type_registry
                .clone()
                .into_iter()
                .filter_map(|(_, ext)| (ext.name() != EXTENSION.name()).then_some(ext))
                .chain(std::iter::once(EXTENSION.to_owned())),
        )
        // The function already reports failure via `Option`, so an invalid combined
        // registry becomes `None` rather than a panic.
        .ok()?;

        ExtensionOp::new(
            registry.get(&EXTENSION_ID)?.get_op(&self.name())?.clone(),
            self.type_args(),
            &registry,
        )
        .ok()
    }
}
#[cfg(test)]
mod test {
    use rstest::rstest;

    use crate::extension::prelude::{
        const_fail_tuple, const_none, const_ok_tuple, const_some_tuple,
    };
    use crate::ops::OpTrait;
    use crate::PortIndex;
    use crate::{
        extension::{
            prelude::{ConstUsize, QB_T, USIZE_T},
            PRELUDE,
        },
        std_extensions::arithmetic::float_types::{self, ConstF64, FLOAT64_TYPE},
        types::TypeRow,
    };

    use super::*;

    #[test]
    fn test_extension() {
        assert_eq!(&ListOp::push.extension_id(), EXTENSION.name());
        assert_eq!(&ListOp::push.extension(), EXTENSION.name());
        assert!(ListOp::pop.registry().contains(EXTENSION.name()));
        for (_, op_def) in EXTENSION.operations() {
            assert_eq!(op_def.extension(), &EXTENSION_ID);
        }
    }

    #[test]
    fn test_list() {
        let list_def = list_type_def();

        let list_type = list_def
            .instantiate([TypeArg::Type { ty: USIZE_T }])
            .unwrap();

        assert!(list_def
            .instantiate([TypeArg::BoundedNat { n: 3 }])
            .is_err());

        list_def.check_custom(&list_type).unwrap();
        let list_value = ListValue(vec![ConstUsize::new(3).into()], USIZE_T);

        list_value.validate().unwrap();

        let wrong_list_value = ListValue(vec![ConstF64::new(1.2).into()], USIZE_T);
        assert!(wrong_list_value.validate().is_err());
    }

    #[test]
    fn test_list_ops() {
        let reg =
            ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()])
                .unwrap();
        let pop_op = ListOp::pop.with_type(QB_T);
        let pop_ext = pop_op.clone().to_extension_op(&reg).unwrap();
        assert_eq!(ListOpInst::from_extension_op(&pop_ext).unwrap(), pop_op);
        let pop_sig = pop_ext.dataflow_signature().unwrap();

        let list_t = list_type(QB_T);

        let both_row: TypeRow = vec![list_t.clone(), option_type(QB_T).into()].into();
        let just_list_row: TypeRow = vec![list_t].into();
        assert_eq!(pop_sig.input(), &just_list_row);
        assert_eq!(pop_sig.output(), &both_row);

        let push_op = ListOp::push.with_type(FLOAT64_TYPE);
        let push_ext = push_op.clone().to_extension_op(&reg).unwrap();
        assert_eq!(ListOpInst::from_extension_op(&push_ext).unwrap(), push_op);
        let push_sig = push_ext.dataflow_signature().unwrap();

        let list_t = list_type(FLOAT64_TYPE);

        let both_row: TypeRow = vec![list_t.clone(), FLOAT64_TYPE].into();
        let just_list_row: TypeRow = vec![list_t].into();

        assert_eq!(push_sig.input(), &both_row);
        assert_eq!(push_sig.output(), &just_list_row);
    }

    /// Compact description of a fold input/output, converted to a [`Value`] by `to_value`.
    #[derive(Debug, Clone, PartialEq, Eq)]
    enum TestVal {
        Idx(usize),
        List(Vec<usize>),
        Elem(usize),
        Some(Vec<TestVal>),
        None(TypeRow),
        Ok(Vec<TestVal>, TypeRow),
        Err(TypeRow, Vec<TestVal>),
    }

    impl TestVal {
        fn to_value(&self) -> Value {
            match self {
                TestVal::Idx(i) => Value::extension(ConstUsize::new(*i as u64)),
                TestVal::Elem(e) => Value::extension(ConstUsize::new(*e as u64)),
                TestVal::List(l) => {
                    let elems = l
                        .iter()
                        .map(|&i| Value::extension(ConstUsize::new(i as u64)))
                        .collect();
                    Value::extension(ListValue(elems, USIZE_T))
                }
                TestVal::Some(l) => {
                    let elems = l.iter().map(TestVal::to_value);
                    const_some_tuple(elems)
                }
                TestVal::None(tr) => const_none(tr.clone()),
                TestVal::Ok(l, tr) => {
                    let elems = l.iter().map(TestVal::to_value);
                    const_ok_tuple(elems, tr.clone())
                }
                TestVal::Err(tr, l) => {
                    let elems = l.iter().map(TestVal::to_value);
                    const_fail_tuple(elems, tr.clone())
                }
            }
        }
    }

    #[rstest]
    #[case::pop(ListOp::pop, &[TestVal::List(vec![77,88, 42])], &[TestVal::List(vec![77,88]), TestVal::Some(vec![TestVal::Elem(42)])])]
    #[case::pop_empty(ListOp::pop, &[TestVal::List(vec![])], &[TestVal::List(vec![]), TestVal::None(vec![USIZE_T].into())])]
    #[case::push(ListOp::push, &[TestVal::List(vec![77,88]), TestVal::Elem(42)], &[TestVal::List(vec![77,88,42])])]
    #[case::set(ListOp::set, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1), TestVal::Elem(99)], &[TestVal::List(vec![77,99,42]), TestVal::Ok(vec![TestVal::Elem(88)], vec![USIZE_T].into())])]
    #[case::set_invalid(ListOp::set, &[TestVal::List(vec![77,88,42]), TestVal::Idx(123), TestVal::Elem(99)], &[TestVal::List(vec![77,88,42]), TestVal::Err(vec![USIZE_T].into(), vec![TestVal::Elem(99)])])]
    #[case::get(ListOp::get, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1)], &[TestVal::Some(vec![TestVal::Elem(88)])])]
    #[case::get_invalid(ListOp::get, &[TestVal::List(vec![77,88,42]), TestVal::Idx(99)], &[TestVal::None(vec![USIZE_T].into())])]
    #[case::insert(ListOp::insert, &[TestVal::List(vec![77,88,42]), TestVal::Idx(1), TestVal::Elem(99)], &[TestVal::List(vec![77,99,88,42]), TestVal::Ok(vec![], vec![USIZE_T].into())])]
    #[case::insert_invalid(ListOp::insert, &[TestVal::List(vec![77,88,42]), TestVal::Idx(52), TestVal::Elem(99)], &[TestVal::List(vec![77,88,42]), TestVal::Err(Type::UNIT.into(), vec![TestVal::Elem(99)])])]
    // NOTE(review): `compute_signature` declares `length` with outputs `(List<T>, usize)`,
    // yet this case checks port 0 against the length. Confirm against `list_fold` which
    // port carries the length, and whether the returned list should be checked too.
    #[case::length(ListOp::length, &[TestVal::List(vec![77,88,42])], &[TestVal::Elem(3)])]
    fn list_fold(#[case] op: ListOp, #[case] inputs: &[TestVal], #[case] outputs: &[TestVal]) {
        let consts: Vec<_> = inputs
            .iter()
            .enumerate()
            .map(|(i, x)| (i.into(), x.to_value()))
            .collect();

        let res = op
            .with_type(USIZE_T)
            .to_extension_op(&COLLECTIONS_REGISTRY)
            .unwrap()
            .constant_fold(&consts)
            .unwrap();

        for (i, expected) in outputs.iter().enumerate() {
            let expected = expected.to_value();
            let res_val = res
                .iter()
                .find(|(port, _)| port.index() == i)
                .unwrap()
                .1
                .clone();

            assert_eq!(res_val, expected);
        }
    }
}