Skip to main content

plexor_core/erasure/
neuron.rs

1// Copyright 2025 Alecks Gates
2//
3// This Source Code Form is subject to the terms of the Mozilla Public
4// License, v. 2.0. If a copy of the MPL was not distributed with this
5// file, You can obtain one at http://mozilla.org/MPL/2.0/.
6
7//! Type-erased versions of neuron traits.
8
9use crate::codec::{Codec, CodecName};
10use crate::neuron::{Neuron, NeuronError};
11use std::any::{Any, TypeId};
12use std::marker::PhantomData;
13use std::sync::Arc;
14
15/// Type-erased neuron that can be stored in collections with other type-erased neurons
16pub trait NeuronErased: Send + Sync + 'static {
17    fn name(&self) -> String;
18    fn name_without_codec(&self) -> String;
19    fn schema(&self) -> String;
20    fn payload_type_id(&self) -> TypeId;
21    fn codec_type_id(&self) -> TypeId;
22    fn clone_to_box(&self) -> Box<dyn NeuronErased + Send + Sync + 'static>;
23    fn clone_to_arc(&self) -> Arc<dyn NeuronErased + Send + Sync + 'static>;
24    fn encode_any(&self, data: &dyn std::any::Any) -> Result<Vec<u8>, NeuronError>;
25    fn decode_any(&self, data: &[u8]) -> Result<Box<dyn std::any::Any + Send>, NeuronError>;
26    fn as_any(&self) -> &dyn Any;
27}
28
29/// Wrapper that implements NeuronErased for any concrete Neuron
30pub struct NeuronErasedWrapper<T: 'static, C: 'static> {
31    neuron: Arc<dyn Neuron<T, C> + Send + Sync + 'static>,
32    _phantom: PhantomData<(T, C)>,
33}
34
35impl<T, C> NeuronErasedWrapper<T, C>
36where
37    T: Send + Sync + 'static,
38    C: Codec<T> + CodecName + Send + Sync + 'static,
39{
40    pub fn new(neuron: Arc<dyn Neuron<T, C> + Send + Sync + 'static>) -> Self {
41        Self {
42            neuron,
43            _phantom: PhantomData,
44        }
45    }
46
47    /// Create a type-erased neuron from a correctly typed neuron
48    pub fn from_typed_neuron(
49        neuron: Arc<dyn Neuron<T, C> + Send + Sync + 'static>,
50    ) -> Arc<dyn NeuronErased + Send + Sync + 'static> {
51        Arc::new(Self::new(neuron))
52    }
53
54    /// Get the underlying typed neuron
55    pub fn get_typed_neuron(&self) -> Arc<dyn Neuron<T, C> + Send + Sync + 'static> {
56        self.neuron.clone()
57    }
58}
59
60impl<T, C> NeuronErased for NeuronErasedWrapper<T, C>
61where
62    T: Send + Sync + 'static,
63    C: Codec<T> + CodecName + Send + Sync + 'static,
64{
65    fn name(&self) -> String {
66        self.neuron.name()
67    }
68
69    fn name_without_codec(&self) -> String {
70        self.neuron.name_without_codec()
71    }
72
73    fn schema(&self) -> String {
74        self.neuron.schema()
75    }
76
77    fn payload_type_id(&self) -> TypeId {
78        TypeId::of::<T>()
79    }
80
81    fn codec_type_id(&self) -> TypeId {
82        TypeId::of::<C>()
83    }
84
85    fn clone_to_box(&self) -> Box<dyn NeuronErased + Send + Sync + 'static> {
86        Box::new(NeuronErasedWrapper {
87            neuron: self.neuron.clone(),
88            _phantom: PhantomData,
89        })
90    }
91
92    fn clone_to_arc(&self) -> Arc<dyn NeuronErased + Send + Sync + 'static> {
93        Arc::new(NeuronErasedWrapper {
94            neuron: self.neuron.clone(),
95            _phantom: PhantomData,
96        })
97    }
98
99    fn encode_any(&self, data: &dyn std::any::Any) -> Result<Vec<u8>, NeuronError> {
100        if let Some(typed_data) = data.downcast_ref::<T>() {
101            self.neuron.encode(typed_data)
102        } else {
103            Err(NeuronError::Encode {
104                neuron_name: self.name(),
105                message: "Type mismatch in encode_any".to_string(),
106            })
107        }
108    }
109
110    fn decode_any(&self, data: &[u8]) -> Result<Box<dyn std::any::Any + Send>, NeuronError> {
111        let decoded = self.neuron.decode(data)?;
112        Ok(Box::new(decoded) as Box<dyn std::any::Any + Send>)
113    }
114
115    fn as_any(&self) -> &dyn Any {
116        self
117    }
118}
119
120impl<T, C> Clone for NeuronErasedWrapper<T, C>
121where
122    T: Send + Sync + 'static,
123    C: Codec<T> + CodecName + Send + Sync + 'static,
124{
125    fn clone(&self) -> Self {
126        Self {
127            neuron: self.neuron.clone(),
128            _phantom: PhantomData,
129        }
130    }
131}
132
133/// Convenience function to create a type-erased neuron from a correctly typed neuron
134pub fn erase_neuron<T, C>(
135    neuron: Arc<dyn Neuron<T, C> + Send + Sync + 'static>,
136) -> Arc<dyn NeuronErased + Send + Sync + 'static>
137where
138    T: Send + Sync + 'static,
139    C: Codec<T> + CodecName + Send + Sync + 'static,
140{
141    NeuronErasedWrapper::from_typed_neuron(neuron)
142}