use crate::error::{Error, Result};
use std::alloc::Layout;
use std::any::TypeId;
use std::cell::{Cell, UnsafeCell};
use std::collections::HashMap;
use std::ptr::NonNull;
use std::{fmt, mem};
/// A unit of work in the graph whose ports are declared through a [`Register`].
pub trait Node {
    /// Declares this node's ports by passing references to its own `Input`/`Output`
    /// fields to [`Register::input`] / [`Register::output`].
    ///
    /// Ports must be fields stored inside the node itself; the register panics otherwise.
    fn register(&self, r: &mut Register);
    /// Invoked by `NodeWrapper::process`.
    fn process(&mut self);
    /// Invoked by `NodeWrapper::reset` after all of the node's outputs have been cleared.
    fn reset(&mut self);
}
/// Owns a boxed [`Node`] together with the ports it registered.
///
/// The `InputRef`/`OutputRef` values hold raw pointers into the boxed node, so they are
/// only meaningful while `node` is alive. Boxing keeps the node at a stable address when
/// the wrapper itself is moved.
pub struct NodeWrapper {
    // The wrapped node. It is boxed so its address stays fixed after registration.
    node: Box<dyn Node>,
    // Input ports, keyed by the name the node registered them under.
    inputs: HashMap<Name, InputRef>,
    // Output ports, keyed by the name the node registered them under.
    outputs: HashMap<Name, OutputRef>,
}
/// Key under which a node registers a port: a static label plus an index that
/// distinguishes entries sharing the same label.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
pub struct Name {
    pub name: &'static str,
    pub index: usize,
}
impl fmt::Display for Name {
    /// Renders the name as `label(index)`, e.g. `in(2)`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self { name, index } = *self;
        write!(f, "{}({})", name, index)
    }
}
impl From<&'static str> for Name {
    /// A bare label maps to index `0`.
    fn from(label: &'static str) -> Self {
        Self::from((label, 0))
    }
}
impl From<(&'static str, usize)> for Name {
    /// Builds a name from an explicit `(label, index)` pair.
    fn from((name, index): (&'static str, usize)) -> Self {
        Self { name, index }
    }
}
impl NodeWrapper {
pub fn new(node: impl Node + 'static) -> Self {
let boxed: Box<dyn Node> = Box::new(node);
let mut register = Register::new(boxed.as_ref());
boxed.register(&mut register);
Self {
node: boxed,
inputs: register.inputs,
outputs: register.outputs,
}
}
pub fn get_input(&self, name: Name) -> Result<&InputRef> {
self.inputs.get(&name).ok_or(Error::UnknownInput(name))
}
pub fn get_output(&self, name: Name) -> Result<&OutputRef> {
self.outputs.get(&name).ok_or(Error::UnknownOutput(name))
}
pub fn process(&mut self) {
self.node.as_mut().process();
}
pub fn reset(&mut self) {
for output in self.outputs.values() {
output.reset();
}
self.node.as_mut().reset();
}
}
/// Collects the inputs and outputs a node declares from [`Node::register`].
///
/// A register is created for one node and only accepts ports that lie inside that node's
/// memory (`node_start..node_stop`). `InputRef`/`OutputRef` keep raw pointers to the
/// registered ports, so this check keeps those pointers inside the node they belong to.
pub struct Register {
    // Address of the first byte of the node.
    node_start: usize,
    // Address one past the last byte of the node.
    node_stop: usize,
    // Registered inputs, by name.
    inputs: HashMap<Name, InputRef>,
    // Registered outputs, by name.
    outputs: HashMap<Name, OutputRef>,
}
impl Register {
    /// Creates an empty register covering the memory occupied by `for_node`.
    ///
    /// Ports are later accessed through raw pointers, so the node should not move after
    /// registration. `NodeWrapper::new` boxes the node before creating the register.
    pub fn new<N: ?Sized>(for_node: &N) -> Self {
        let node_start = for_node as *const N as *const () as usize;
        Self {
            node_start,
            node_stop: node_start + mem::size_of_val(for_node),
            inputs: HashMap::default(),
            outputs: HashMap::default(),
        }
    }

    /// Registers `input` under `name`.
    ///
    /// # Panics
    ///
    /// Panics if `input` does not lie inside the node this register was created for, or if
    /// an input with the same name is already registered. When it panics, the register is
    /// left unchanged.
    pub fn input<T: 'static>(&mut self, name: impl Into<Name>, input: &Input<T>) {
        let name = name.into();
        let pos = input as *const Input<T> as usize;
        self.assert_inside_node(pos, mem::size_of::<Input<T>>(), "input", name);
        assert!(
            !self.inputs.contains_key(&name),
            "The name {} is already taken",
            name
        );
        self.inputs.insert(name, input.into());
    }

    /// Registers `output` under `name`.
    ///
    /// # Panics
    ///
    /// Panics if `output` does not lie inside the node this register was created for, or if
    /// an output with the same name is already registered. When it panics, the register is
    /// left unchanged.
    pub fn output<T: 'static>(&mut self, name: impl Into<Name>, output: &Output<T>) {
        let name = name.into();
        let pos = output as *const Output<T> as usize;
        self.assert_inside_node(pos, mem::size_of::<Output<T>>(), "output", name);
        assert!(
            !self.outputs.contains_key(&name),
            "The name {} is already taken",
            name
        );
        self.outputs.insert(name, output.into());
    }

    /// Panics unless the `size` bytes starting at address `pos` lie within the node.
    ///
    /// `kind` ("input" / "output") and `name` are used only in the panic message.
    fn assert_inside_node(&self, pos: usize, size: usize, kind: &str, name: Name) {
        // `checked_add` rejects a range that would wrap past `usize::MAX`
        // instead of overflowing.
        let inside = pos >= self.node_start
            && pos
                .checked_add(size)
                .map_or(false, |end| end <= self.node_stop);
        assert!(inside, "The given {} {} is not part of the node", kind, name);
    }
}
/// Read side of a port: an optional pointer to the value slot this input reads from.
///
/// Until a slot is attached (through `InputRef::set_target`), the input is unconnected
/// and [`Input::get`] yields `None`.
pub struct Input<T> {
    ptr: Cell<Option<NonNull<UnsafeCell<Option<T>>>>>,
}
impl<T> Default for Input<T> {
    /// Creates an unconnected input.
    fn default() -> Self {
        Input {
            ptr: Cell::default(),
        }
    }
}
impl<T> Input<T> {
    /// Returns a clone of the value currently held in the attached slot.
    ///
    /// Yields `None` when no slot is attached or when the slot is empty.
    pub fn get(&self) -> Option<T>
    where
        T: Clone,
    {
        let slot = self.ptr.get()?;
        // SAFETY: the slot pointer is installed through the unsafe `InputRef::set_target`.
        // That function's caller is responsible for keeping it pointing at a live
        // `UnsafeCell<Option<T>>`. The shared borrow lasts only for the clone below.
        let current: &Option<T> = unsafe { &*slot.as_ref().get() };
        current.clone()
    }
}
/// Type-erased handle to an `Input`'s pointer cell, used to attach the input to a slot.
pub struct InputRef {
    // `TypeId` of the `T` in the originating `Input<T>`.
    type_id: TypeId,
    // Points at the input's `ptr` cell, with the slot pointer's type erased to `()`.
    ptr: NonNull<Cell<Option<NonNull<()>>>>,
}
impl InputRef {
    /// Returns the `TypeId` of the value type carried by the referenced input.
    pub fn type_id(&self) -> TypeId {
        self.type_id
    }

    /// Attaches the referenced input to the slot at `target`.
    ///
    /// # Safety
    ///
    /// The `Input` this handle was created from must still be alive. `target` must point to
    /// a live `UnsafeCell<Option<T>>`, where `T` matches [`InputRef::type_id`], for as long
    /// as the input may read from it.
    pub unsafe fn set_target(&self, target: NonNull<()>) {
        // SAFETY: the referenced `Input` is alive per this function's contract.
        let cell: &Cell<Option<NonNull<()>>> = unsafe { self.ptr.as_ref() };
        cell.set(Some(target));
    }
}
impl PartialEq for InputRef {
    /// Two handles are equal when they point at the same input cell.
    fn eq(&self, other: &Self) -> bool {
        self.ptr == other.ptr
    }
}
impl Eq for InputRef {}
impl<T: 'static> From<&Input<T>> for InputRef {
fn from(i: &Input<T>) -> Self {
Self {
type_id: TypeId::of::<T>(),
ptr: NonNull::from(&i.ptr).cast(),
}
}
}
/// Write side of a port: an optional pointer to the value slot this output writes into.
///
/// While no slot is attached, [`Output::set`] discards its argument and
/// [`Output::is_used`] returns `false`.
pub struct Output<T> {
    ptr: Cell<Option<NonNull<UnsafeCell<Option<T>>>>>,
}
impl<T> Default for Output<T> {
    /// Creates an output with no slot attached.
    fn default() -> Self {
        Self {
            ptr: Cell::new(None),
        }
    }
}
impl<T> Output<T> {
    /// Stores `value` in the attached slot, dropping whatever the slot held before.
    ///
    /// Does nothing, and does not call `into`, when no slot is attached.
    pub fn set(&self, value: impl Into<Option<T>>) {
        if let Some(slot) = self.ptr.get() {
            // Run the caller-provided conversion before touching the slot. Otherwise a
            // `&mut` to the slot would be alive while arbitrary code executes.
            let value = value.into();
            // SAFETY: the slot pointer is installed through the unsafe
            // `OutputRef::set_target`. That function's caller is responsible for keeping it
            // pointing at a live `UnsafeCell<Option<T>>`. Assigning through the raw pointer
            // drops the old value without creating a long-lived `&mut`.
            unsafe { *slot.as_ref().get() = value };
        }
    }

    /// Returns `true` if a slot is attached to this output.
    pub fn is_used(&self) -> bool {
        self.ptr.get().is_some()
    }
}
/// Type-erased handle to an `Output`'s pointer cell.
///
/// It also carries the type-specific information needed to manage the value slot the
/// output writes into.
pub struct OutputRef {
    // `TypeId` of the `T` in the originating `Output<T>`.
    type_id: TypeId,
    // Layout of the `UnsafeCell<Option<T>>` slot (presumably used by the owner of the slot
    // to allocate it; not visible here).
    layout: Layout,
    // Clears the slot behind the given pointer back to `None`.
    reset: fn(*mut ()),
    // Drops the slot behind the given pointer in place; `None` when the slot needs no drop.
    drop_fn: Option<fn(*mut ())>,
    // Points at the output's `ptr` cell, with the slot pointer's type erased to `()`.
    ptr: NonNull<Cell<Option<NonNull<()>>>>,
}
impl OutputRef {
    /// Returns the `TypeId` of the value type carried by the referenced output.
    pub fn type_id(&self) -> TypeId {
        self.type_id
    }

    /// Returns the in-place destructor for the slot, if the slot type needs one.
    pub fn drop_fn(&self) -> Option<fn(*mut ())> {
        self.drop_fn
    }

    /// Returns the memory layout of the slot this output writes into.
    pub fn layout(&self) -> Layout {
        self.layout
    }

    /// Attaches the referenced output to the slot at `target`.
    ///
    /// # Safety
    ///
    /// The `Output` this handle was created from must still be alive. `target` must point to
    /// a live slot of the type described by [`OutputRef::type_id`] / [`OutputRef::layout`]
    /// for as long as the output may write to it.
    pub unsafe fn set_target(&self, target: NonNull<()>) {
        // SAFETY: the referenced `Output` is alive per this function's contract.
        let cell: &Cell<Option<NonNull<()>>> = unsafe { self.ptr.as_ref() };
        cell.set(Some(target));
    }

    /// Returns the slot the referenced output is currently attached to, if any.
    ///
    /// NOTE(review): this is a safe fn that dereferences `self.ptr`. An `OutputRef` built
    /// with `From<&Output<T>>` can outlive that `Output`, which would make this call
    /// dereference a dangling pointer. Consider whether it should be `unsafe`.
    pub fn get_target(&self) -> Option<NonNull<()>> {
        // SAFETY: assumes the referenced `Output` is still alive (see the note above).
        let cell: &Cell<Option<NonNull<()>>> = unsafe { self.ptr.as_ref() };
        cell.get()
    }

    /// Clears the attached slot back to `None`; does nothing when no slot is attached.
    ///
    /// NOTE(review): same liveness caveat as [`OutputRef::get_target`].
    pub fn reset(&self) {
        if let Some(target) = self.get_target() {
            (self.reset)(target.as_ptr());
        }
    }
}
impl PartialEq for OutputRef {
    /// Two handles are equal when they point at the same output cell.
    fn eq(&self, other: &Self) -> bool {
        self.ptr == other.ptr
    }
}
impl Eq for OutputRef {}
impl<T: 'static> From<&Output<T>> for OutputRef {
fn from(o: &Output<T>) -> Self {
Self {
type_id: TypeId::of::<T>(),
layout: Layout::new::<UnsafeCell<Option<T>>>(),
drop_fn: if std::mem::needs_drop::<UnsafeCell<Option<T>>>() {
Some(|ptr| unsafe { std::ptr::drop_in_place(ptr as *mut UnsafeCell<Option<T>>) })
} else {
None
},
reset: |ptr| unsafe {
*(&*(ptr as *mut UnsafeCell<Option<T>>)).get() = None;
},
ptr: NonNull::from(&o.ptr).cast(),
}
}
}