use crate::{
NodeTypeMismatchError, NodesNotFoundError, RegistryError,
node::{AnyNode, Node},
stage::{Stage, StageShape, ValueStage, ValueWrapper},
};
use std::{any::TypeId, marker::PhantomData};
#[derive(Debug, Clone, Copy, Eq, Hash, Ord)]
#[repr(transparent)]
pub struct NodeId<S: Stage>(pub(crate) usize, pub(crate) PhantomData<S>);
impl<S: Stage> From<NodeId<S>> for usize {
fn from(value: NodeId<S>) -> Self {
value.0
}
}
impl<S0: Stage, S1: Stage> PartialEq<NodeId<S1>> for NodeId<S0> {
    /// Handles are equal when they refer to the same registry index, whatever stage type
    /// parameter each of them carries.
    fn eq(&self, rhs: &NodeId<S1>) -> bool {
        let (lhs_index, rhs_index) = (self.0, rhs.0);
        lhs_index == rhs_index
    }
}
impl<S0: Stage, S1: Stage> PartialOrd<NodeId<S1>> for NodeId<S0> {
    /// Orders handles by registry index. The stage type parameters are ignored.
    fn partial_cmp(&self, rhs: &NodeId<S1>) -> Option<std::cmp::Ordering> {
        // `usize` is totally ordered, so a comparison always exists.
        Some(Ord::cmp(&self.0, &rhs.0))
    }
}
impl<S: Stage> NodeId<S> {
pub fn stage_shape(&self) -> &'static StageShape {
Into::<NodeReflection>::into(self).shape
}
}
/// Type-erased counterpart of [`NodeId`]: a node's registry index plus the static shape of
/// its stage type.
///
/// NOTE(review): `PartialEq`/`Eq`/`Hash` are derived, so they consider both `id` and `shape`.
/// The hand-written `PartialOrd`/`Ord` impls for this type compare `id` only. The two agree
/// only while each index always maps to a single stage shape. Confirm that invariant before
/// mixing `Ord`-based containers (e.g. `BTreeSet`) with `Eq`/`Hash`-based ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeReflection {
/// Index of the node's slot in the registry.
pub(crate) id: usize,
/// Static shape of the node's stage type (taken from `S::SHAPE` by the `From<NodeId<S>>` impls).
pub(crate) shape: &'static StageShape,
}
impl<S: Stage + 'static> From<NodeId<S>> for NodeReflection {
    /// Erases the stage type of an owned handle by delegating to the by-reference conversion.
    fn from(handle: NodeId<S>) -> Self {
        Self::from(&handle)
    }
}
impl<S: Stage + 'static> From<&NodeId<S>> for NodeReflection {
fn from(value: &NodeId<S>) -> Self {
Self {
id: value.0,
shape: &S::SHAPE,
}
}
}
impl PartialOrd for NodeReflection {
    /// Canonical partial order for a totally ordered type. It delegates to [`Ord::cmp`], so
    /// the two orderings cannot drift apart if one of them is edited later.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for NodeReflection {
    /// Orders reflections by registry index only. `shape` does not take part.
    fn cmp(&self, rhs: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.id, &rhs.id)
    }
}
/// Storage for the nodes of a graph, addressed by the `usize` index carried by
/// [`NodeId`] and [`NodeReflection`].
///
/// With the `tokio` feature, the registration methods push to both vectors together, so index
/// `i` refers to the same node in each.
pub struct Registry(
/// One slot per registered node. A slot becomes `None` once its node has been taken out
/// (e.g. by `take_node` or `unregister`). Registration always appends, so it does not refill
/// an emptied slot.
pub(super) Vec<Option<Box<dyn AnyNode>>>,
/// Per-slot availability flag, created as `true` at registration. `take_node` sends `false`
/// and `replace_node` sends `true`. The registry keeps its own `Receiver` next to each `Sender`.
#[cfg(feature = "tokio")]
pub(super) Vec<(
tokio::sync::watch::Sender<bool>,
tokio::sync::watch::Receiver<bool>,
)>,
);
impl Registry {
pub fn new() -> Self {
Self(
Vec::new(),
#[cfg(feature = "tokio")]
Vec::new(),
)
}
pub fn state<S: Stage + 'static>(&self, id: NodeId<S>) -> Result<&S::State, RegistryError> {
self.validate_node_type::<S>(&id)?;
match self.0.get(id.0) {
Some(Some(any_node)) => match any_node.as_any().downcast_ref::<Node<S>>() {
Some(node) => Ok(&node.state),
None => Err(NodeTypeMismatchError {
got: any_node.type_id(),
expected: TypeId::of::<Node<S>>(),
}
.into()),
},
None | Some(None) => {
Err(NodesNotFoundError::from(&[id.into()] as &[NodeReflection]).into())
}
}
}
pub fn state_mut<S: Stage + 'static>(
&mut self,
id: NodeId<S>,
) -> Result<&mut S::State, RegistryError> {
self.validate_node_type::<S>(&id)?;
match self.0.get_mut(id.0) {
Some(Some(any_node)) => {
let node_type_id = any_node.type_id();
match any_node.as_any_mut().downcast_mut::<Node<S>>() {
Some(node) => Ok(&mut node.state),
None => Err(NodeTypeMismatchError {
got: node_type_id,
expected: TypeId::of::<Node<S>>(),
}
.into()),
}
}
None | Some(None) => {
Err(NodesNotFoundError::from(&[id.into()] as &[NodeReflection]).into())
}
}
}
pub fn value<T: Send + Sync + Clone + 'static>(&mut self, value: T) -> NodeId<ValueStage<T>> {
let next = self.0.len();
self.0.push(Some(Box::new(Node::new(
ValueStage::<T>::new(),
ValueWrapper(Some(value)),
))));
#[cfg(feature = "tokio")]
self.1.push(tokio::sync::watch::channel(true));
NodeId(next, PhantomData)
}
pub fn register<S: Stage + Send + Sync + 'static>(&mut self, stage: S) -> NodeId<S>
where
S::State: Default,
{
let next = self.0.len();
self.0
.push(Some(Box::new(Node::new(stage, S::State::default()))));
#[cfg(feature = "tokio")]
self.1.push(tokio::sync::watch::channel(true));
NodeId(next, PhantomData)
}
pub fn register_with_state<S: Stage + Send + Sync + 'static>(
&mut self,
stage: S,
state: S::State,
) -> NodeId<S> {
let next = self.0.len();
self.0.push(Some(Box::new(Node::new(stage, state))));
#[cfg(feature = "tokio")]
self.1.push(tokio::sync::watch::channel(true));
NodeId(next, PhantomData)
}
pub fn validate_node_type<S: Stage + 'static>(
&self,
id: impl Into<NodeReflection>,
) -> Result<(), RegistryError> {
let id = id.into();
match self.0.get(id.id) {
Some(Some(node)) => match node.as_any().downcast_ref::<Node<S>>() {
Some(_) => Ok(()),
None => Err(NodeTypeMismatchError {
got: TypeId::of::<Node<S>>(),
expected: node.as_any().type_id(),
}
.into()),
},
None | Some(None) => {
Err(NodesNotFoundError::from(&[id.clone().into()] as &[NodeReflection]).into())
}
}
}
pub fn unregister<S: Stage + 'static>(
&mut self,
id: NodeReflection,
) -> Result<Option<Node<S>>, RegistryError> {
self.validate_node_type::<S>(id)?;
match self.0.get_mut(id.id) {
Some(maybe_node) => maybe_node
.take()
.map(|node| match node.into_any().downcast() {
Ok(node) => Ok(Some(*node)),
Err(node) => Err(NodeTypeMismatchError {
got: TypeId::of::<Node<S>>(),
expected: node.type_id(),
}
.into()),
})
.unwrap_or(Ok(None)),
None => Ok(None),
}
}
pub fn unregister_and_drop(
&mut self,
id: impl Into<NodeReflection>,
) -> Result<(), RegistryError> {
let id = id.into();
match self.0.get_mut(id.id).take().map(drop) {
Some(_) => Ok(()),
None => Err(NodesNotFoundError::from(&[id.into()] as &[NodeReflection]).into()),
}
}
pub fn get_node<S: Stage + 'static>(&self, id: NodeId<S>) -> Option<&Node<S>> {
match self.0.get(id.0) {
Some(Some(node)) => node.as_any().downcast_ref(),
Some(None) => None,
None => None,
}
}
pub fn get_node_any(&self, id: impl Into<NodeReflection>) -> Option<&Box<dyn AnyNode>> {
match self.0.get(id.into().id) {
Some(Some(node)) => Some(node),
Some(None) => None,
None => None,
}
}
pub fn get_node_mut<S: Stage + 'static>(&mut self, id: NodeId<S>) -> Option<&mut Node<S>> {
match self.0.get_mut(id.0) {
Some(Some(node)) => node.as_any_mut().downcast_mut(),
Some(None) => None,
None => None,
}
}
pub fn get_node_any_mut(
&mut self,
id: impl Into<NodeReflection>,
) -> Option<&mut Box<dyn AnyNode>> {
match self.0.get_mut(id.into().id) {
Some(Some(node)) => Some(node),
Some(None) => None,
None => None,
}
}
pub fn get2_nodes_any_mut(
&mut self,
id0: NodeReflection,
id1: NodeReflection,
) -> Result<(&mut Box<dyn AnyNode>, &mut Box<dyn AnyNode>), NodesNotFoundError> {
if id0.id == id1.id {
panic!(
"Attempted to borrow node id {:?} twice",
Into::<NodeReflection>::into(id0)
)
}
let first_id = std::cmp::min(id0.id, id1.id);
let second_id = std::cmp::max(id0.id, id1.id);
match (self.0.len() < id0.id, self.0.len() < id1.id) {
(true, true) => {
return Err(NodesNotFoundError::from(
&[id0.into(), id1.into()] as &[NodeReflection]
));
}
(true, false) => {
return Err(NodesNotFoundError::from(&[id0.into()] as &[NodeReflection]));
}
(false, true) => {
return Err(NodesNotFoundError::from(&[id1.into()] as &[NodeReflection]));
}
_ => (),
}
if let [first, .., second] = &mut self.0[first_id..=second_id] {
match (first, second) {
(None, None) => Err(NodesNotFoundError::from(
&[id0.into(), id1.into()] as &[NodeReflection]
)),
(None, Some(_)) => {
Err(NodesNotFoundError::from(&[id0.into()] as &[NodeReflection]))
}
(Some(_), None) => {
Err(NodesNotFoundError::from(&[id1.into()] as &[NodeReflection]))
}
(Some(first), Some(second)) => {
if first_id == id1.id {
Ok((first, second))
} else {
Ok((second, first))
}
}
}
} else {
unreachable!()
}
}
#[cfg(feature = "tokio")]
pub fn node_availability(
&self,
id: NodeReflection,
) -> Option<tokio::sync::watch::Receiver<bool>> {
match self.1.get(id.id) {
Some((_, rx)) => Some(rx.clone()),
None => None,
}
}
#[cfg(feature = "tokio")]
pub async fn take_node(&mut self, id: NodeReflection) -> Option<Box<dyn AnyNode>> {
match self.0.get_mut(id.id) {
Some(maybe_node) => {
match maybe_node.take() {
Some(node) => {
self.1.get_mut(id.id).unwrap().0.send(false).unwrap();
Some(node)
}
None => {
panic!("TODO: Handle missing node")
}
}
}
None => None,
}
}
pub fn replace_node(&mut self, id: NodeReflection, node: Box<dyn AnyNode>) {
*self.0.get_mut(id.id).unwrap() = Some(node);
#[cfg(feature = "tokio")]
self.1.get_mut(id.id).unwrap().0.send(true).unwrap();
}
pub fn get_inputs<S: Stage + 'static>(&self, node_id: NodeId<S>) -> Option<&S::Input> {
self.get_node(node_id).map(|node| &node.inputs)
}
pub fn get_inputs_mut<S: Stage + 'static>(
&mut self,
node_id: NodeId<S>,
) -> Option<&mut S::Input> {
self.get_node_mut(node_id).map(|node| &mut node.inputs)
}
pub fn get_outputs<S: Stage + 'static>(&self, node_id: NodeId<S>) -> Option<&S::Output> {
self.get_node(node_id).map(|node| &node.outputs)
}
pub fn get_outputs_mut<S: Stage + 'static>(
&mut self,
node_id: NodeId<S>,
) -> Option<&mut S::Output> {
self.get_node_mut(node_id).map(|node| &mut node.outputs)
}
}