use crate::{error::DynamicError, scheme_value::SchemeValue};
use alloc::{boxed::Box, string::String, vec, vec::Vec};
use any_fn::AnyFn;
use bitvec::bitvec;
use core::any::TypeId;
use stak_vm::{Cons, Error, Heap, Memory, Number, PrimitiveSet, Type, Value};
use winter_maybe_async::maybe_async;
/// Capacity of [`ArgumentVec`], i.e. the largest number of arguments that can
/// be buffered for a single native function call.
const MAXIMUM_ARGUMENT_COUNT: usize = 16;
/// A fixed-capacity vector used to hold the arguments of one primitive call.
type ArgumentVec<T> = heapless::Vec<T, MAXIMUM_ARGUMENT_COUNT>;
/// Conversion entry for one native type that has a Scheme representation.
type SchemeType<H> = (
// Identifier of the native type this entry converts.
TypeId,
// Converts a Scheme value into a dynamically typed native value. It may
// yield `Ok(None)`; the caller in `operate` then falls back to treating the
// argument as a foreign value.
Box<dyn Fn(&Memory<H>, Value) -> Result<Option<any_fn::Value>, DynamicError>>,
// Converts a dynamically typed native value back into a Scheme value.
Box<dyn Fn(&mut Memory<H>, any_fn::Value) -> Result<Value, DynamicError>>,
);
/// A primitive set that exposes dynamically typed native functions to the
/// virtual machine.
///
/// Values of registered types are converted to and from Scheme values. Any
/// other native value is kept in an internal slot table and referenced from
/// the heap by a cons cell tagged `Type::Foreign`.
pub struct DynamicPrimitiveSet<'a, 'b, H> {
// Named native functions. Primitive `n + 1` calls the function at index `n`.
functions: &'a mut [(&'a str, AnyFn<'b>)],
// Converters of the registered types, searched linearly by `TypeId`.
types: Vec<SchemeType<H>>,
// Slot table of foreign values. `None` marks a free slot.
values: Vec<Option<any_fn::Value>>,
}
impl<'a, 'b, H: Heap> DynamicPrimitiveSet<'a, 'b, H> {
/// Creates a primitive set over the given named functions.
///
/// Converters for `bool`, the built-in integer and floating point types, and
/// `String` are registered up front.
pub fn new(functions: &'a mut [(&'a str, AnyFn<'b>)]) -> Self {
    // Expands to one `register_type` call per listed type, in the given order.
    macro_rules! register_types {
        ($set:ident; $($ty:ty),* $(,)?) => {
            $($set.register_type::<$ty>();)*
        };
    }

    let mut primitive_set = Self {
        functions,
        types: Vec::new(),
        values: Vec::new(),
    };

    register_types!(
        primitive_set;
        bool,
        i8,
        u8,
        i16,
        u16,
        i32,
        u32,
        i64,
        u64,
        f32,
        f64,
        isize,
        usize,
        String,
    );

    primitive_set
}
pub fn register_type<T: SchemeValue<H> + 'static>(&mut self) {
self.types.push((
TypeId::of::<T>(),
Box::new(|memory, value| Ok(T::from_scheme(memory, value)?.map(any_fn::value))),
Box::new(|memory, value| T::into_scheme(value.downcast()?, memory)),
));
}
/// Frees every slot of the foreign value table that no heap cell refers to.
///
/// The cons cells between `memory.allocation_start()` and
/// `memory.allocation_index()` are visited. Each cell whose `cdr` carries the
/// `Type::Foreign` tag marks the slot named by the number in its `car`.
/// Unmarked slots are then reset to `None`.
///
/// # Errors
///
/// Returns an error if reading a cell from `memory` fails.
fn collect_garbages(&mut self, memory: &Memory<H>) -> Result<(), Error> {
    let mut marks = bitvec![0; self.values.len()];

    for cell in 0..(memory.allocation_index() / 2) {
        let cons = Cons::new((memory.allocation_start() + 2 * cell) as _);

        if memory.cdr(cons)?.tag() != Type::Foreign as _ {
            continue;
        }

        let slot = memory.car(cons)?.assume_number().to_i64() as usize;

        // Be conservative: a cell may carry an index that does not name a
        // slot of this table (e.g. a negative or stale number). Setting a bit
        // out of range would panic, so such cells are ignored.
        if slot < marks.len() {
            marks.set(slot, true);
        }
    }

    for (value, marked) in self.values.iter_mut().zip(marks) {
        if !marked {
            *value = None;
        }
    }

    Ok(())
}
/// Returns the index of the first free slot in the foreign value table, if
/// there is one.
fn find_free(&self) -> Option<usize> {
    self.values.iter().position(|slot| slot.is_none())
}
/// Returns the index of a free slot in the foreign value table.
///
/// An existing free slot is reused when available. Otherwise garbage is
/// collected once, and only if that frees nothing is the table grown by one
/// slot.
///
/// # Errors
///
/// Returns an error if garbage collection fails.
fn allocate(&mut self, memory: &Memory<H>) -> Result<usize, Error> {
    if let Some(index) = self.find_free() {
        return Ok(index);
    }

    self.collect_garbages(memory)?;

    if let Some(index) = self.find_free() {
        return Ok(index);
    }

    // Nothing could be reclaimed, so append a fresh empty slot.
    let index = self.values.len();
    self.values.push(None);

    Ok(index)
}
/// Converts a Scheme value into a native value of the type identified by
/// `type_id`.
///
/// Returns `Ok(None)` when no converter is registered for `type_id`;
/// otherwise returns whatever the registered converter produces.
fn convert_from_scheme(
    &self,
    memory: &Memory<H>,
    value: Value,
    type_id: TypeId,
) -> Result<Option<any_fn::Value>, DynamicError> {
    let converter = self
        .types
        .iter()
        .find(|(registered_id, _, _)| *registered_id == type_id);

    match converter {
        Some((_, from_scheme, _)) => from_scheme(memory, value),
        None => Ok(None),
    }
}
/// Converts a native value into a Scheme value.
///
/// If the value's type has a registered converter, that converter is used.
/// Otherwise the value is stored in a slot of the foreign value table and a
/// cons cell is allocated whose `car` is the slot index and whose `cdr` is
/// tagged `Type::Foreign`.
///
/// # Errors
///
/// Returns an error if the value's type cannot be queried, if the converter
/// fails, or if allocation in the slot table or in `memory` fails.
fn convert_into_scheme(
    &mut self,
    memory: &mut Memory<H>,
    value: any_fn::Value,
) -> Result<Value, DynamicError> {
    // The type of the value does not change during the search, so query it
    // once rather than once per registered type.
    let type_id = value.type_id()?;

    if let Some((_, _, into_scheme)) = self.types.iter().find(|(id, _, _)| *id == type_id) {
        return into_scheme(memory, value);
    }

    // No converter is registered: keep the value on the native side and hand
    // Scheme a foreign reference to its slot.
    let index = self.allocate(memory)?;
    self.values[index] = Some(value);

    Ok(memory
        .allocate(
            Number::from_i64(index as _).into(),
            memory.null()?.set_tag(Type::Foreign as _).into(),
        )?
        .into())
}
}
impl<H: Heap> PrimitiveSet<H> for DynamicPrimitiveSet<'_, '_, H> {
type Error = DynamicError;
#[maybe_async]
fn operate(&mut self, memory: &mut Memory<H>, primitive: usize) -> Result<(), Self::Error> {
if primitive == 0 {
memory.set_register(memory.null()?);
for (name, _) in self.functions.iter().rev() {
let list = memory.cons(memory.null()?.into(), memory.register())?;
memory.set_register(list);
let string = memory.build_raw_string(name)?;
memory.set_car(memory.register(), string.into())?;
}
memory.push(memory.register().into())?;
Ok(())
} else {
let primitive = primitive - 1;
let (_, function) = self
.functions
.get(primitive)
.ok_or(Error::IllegalPrimitive)?;
let mut arguments = (0..function.arity())
.map(|_| memory.pop())
.collect::<Result<ArgumentVec<_>, _>>()?;
arguments.reverse();
let cloned_arguments = {
arguments
.iter()
.enumerate()
.map(|(index, &value)| {
self.convert_from_scheme(memory, value, function.parameter_types()[index])
})
.collect::<Result<ArgumentVec<_>, _>>()?
};
let mut copied_arguments = ArgumentVec::new();
for &value in &arguments {
let value =
if value.is_cons() && memory.cdr_value(value)?.tag() == Type::Foreign as _ {
Some(
self.values
.get(memory.car_value(value)?.assume_number().to_i64() as usize)
.ok_or(DynamicError::ValueIndex)?
.as_ref()
.ok_or(DynamicError::ValueIndex)?,
)
} else {
None
};
copied_arguments
.push(value)
.map_err(|_| Error::ArgumentCount)?;
}
let value = self
.functions
.get_mut(primitive)
.ok_or(Error::IllegalPrimitive)?
.1
.call(
copied_arguments
.into_iter()
.enumerate()
.map(|(index, value)| {
cloned_arguments[index]
.as_ref()
.map_or_else(|| value.ok_or(DynamicError::ForeignValueExpected), Ok)
})
.collect::<Result<ArgumentVec<_>, DynamicError>>()?
.as_slice(),
)?;
let value = self.convert_into_scheme(memory, value)?;
memory.push(value)?;
Ok(())
}
}
}
// Unit tests for slot allocation and garbage collection of foreign values.
#[cfg(test)]
mod tests {
use super::*;
use any_fn::{Ref, r#fn, value};
use winter_maybe_async::maybe_await;
// Length of the value array backing the memory in these tests.
const HEAP_SIZE: usize = 1 << 8;
// A native type with no registered converter, so its values are stored as
// foreign values.
struct Foo {
bar: usize,
}
impl Foo {
const fn new(bar: usize) -> Self {
Self { bar }
}
const fn bar(&self) -> usize {
self.bar
}
fn baz(&mut self, value: usize) {
self.bar += value;
}
}
// Constructing a primitive set from functions taking no receiver, a shared
// reference and a mutable reference must succeed.
#[test]
fn create() {
let mut functions = [
("make-foo", r#fn(Foo::new)),
("foo-bar", r#fn::<(Ref<_>,), _>(Foo::bar)),
("foo-baz", r#fn(Foo::baz)),
];
DynamicPrimitiveSet::<[Value; 0]>::new(&mut functions);
}
// Two allocations in a row must yield slots 0 and 1 while slot 0 is still
// referenced from the heap.
#[test]
fn allocate_two() {
let mut primitive_set = DynamicPrimitiveSet::new(&mut []);
let mut memory = Memory::new([Default::default(); HEAP_SIZE]).unwrap();
let index = primitive_set.allocate(&memory).unwrap();
primitive_set.values[index] = Some(value(42usize));
assert_eq!(index, 0);
assert_eq!(primitive_set.find_free(), None);
// Create a foreign reference to slot 0 so that the garbage collection run
// by the next `allocate` call keeps the slot occupied.
let cons = memory
.allocate(
Number::from_i64(index as _).into(),
memory.null().unwrap().set_tag(Type::Foreign as _).into(),
)
.unwrap();
memory.push(cons.into()).unwrap();
let index = primitive_set.allocate(&memory).unwrap();
primitive_set.values[index] = Some(value(42usize));
assert_eq!(index, 1);
assert_eq!(primitive_set.find_free(), None);
}
mod garbage_collection {
use super::*;
// Collecting with an empty slot table and a fresh memory must not fail.
#[test]
fn collect_none() {
let mut primitive_set = DynamicPrimitiveSet::new(&mut []);
primitive_set
.collect_garbages(&Memory::new([Default::default(); HEAP_SIZE]).unwrap())
.unwrap();
}
// A foreign value whose reference was popped and collected from the memory
// must have its slot freed.
#[tokio::test]
async fn collect_one() {
let mut functions = [("make-foo", r#fn(|| Foo { bar: 42 }))];
let mut primitive_set = DynamicPrimitiveSet::new(&mut functions);
let mut memory = Memory::new([Default::default(); HEAP_SIZE]).unwrap();
// Primitive 1 is the first function, `make-foo`.
maybe_await!(primitive_set.operate(&mut memory, 1)).unwrap();
assert_eq!(primitive_set.find_free(), None);
memory.pop().unwrap();
memory.collect_garbages(None).unwrap();
primitive_set.collect_garbages(&memory).unwrap();
assert_eq!(primitive_set.find_free(), Some(0));
}
// A foreign value whose reference is still in the memory must keep its slot.
#[tokio::test]
async fn keep_one() {
let mut functions = [("make-foo", r#fn(|| Foo { bar: 42 }))];
let mut primitive_set = DynamicPrimitiveSet::new(&mut functions);
let mut memory = Memory::new([Default::default(); HEAP_SIZE]).unwrap();
maybe_await!(primitive_set.operate(&mut memory, 1)).unwrap();
assert_eq!(primitive_set.find_free(), None);
primitive_set.collect_garbages(&memory).unwrap();
assert_eq!(primitive_set.find_free(), None);
}
}
}