Skip to main content

plexor_core/erasure/
neuron.rs

1// Copyright 2025 Alecks Gates
2//
3// This Source Code Form is subject to the terms of the Mozilla Public
4// License, v. 2.0. If a copy of the MPL was not distributed with this
5// file, You can obtain one at http://mozilla.org/MPL/2.0/.
6
7//! Type-erased versions of neuron traits.
8
9use crate::codec::{Codec, CodecName};
10use crate::erasure::error::{ErasureError, ErasureResult};
11use crate::neuron::{Neuron, NeuronError};
12use std::any::{Any, TypeId};
13use std::marker::PhantomData;
14use std::sync::Arc;
15
16/// Type-erased neuron that can be stored in collections with other type-erased neurons
17pub trait NeuronErased: Send + Sync + 'static {
18    fn name(&self) -> String;
19    fn name_without_codec(&self) -> String;
20    fn schema(&self) -> String;
21    fn payload_type_id(&self) -> TypeId;
22    fn codec_type_id(&self) -> TypeId;
23    fn clone_to_box(&self) -> Box<dyn NeuronErased + Send + Sync + 'static>;
24    fn clone_to_arc(&self) -> Arc<dyn NeuronErased + Send + Sync + 'static>;
25    fn encode_any(&self, data: &dyn std::any::Any) -> Result<Vec<u8>, NeuronError>;
26    fn decode_any(&self, data: &[u8]) -> Result<Box<dyn std::any::Any>, NeuronError>;
27    fn as_any(&self) -> &dyn Any;
28}
29
30/// Wrapper that implements NeuronErased for any concrete Neuron
31pub struct NeuronErasedWrapper<T, C> {
32    neuron: Arc<dyn Neuron<T, C> + Send + Sync + 'static>,
33    _phantom: PhantomData<(T, C)>,
34}
35
36impl<T, C> NeuronErasedWrapper<T, C>
37where
38    T: Send + Sync + 'static,
39    C: Codec<T> + CodecName + Send + Sync + 'static,
40{
41    pub fn new(neuron: Arc<dyn Neuron<T, C> + Send + Sync + 'static>) -> Self {
42        Self {
43            neuron,
44            _phantom: PhantomData,
45        }
46    }
47
48    /// Create a type-erased neuron from a correctly typed neuron
49    pub fn from_typed_neuron(
50        neuron: Arc<dyn Neuron<T, C> + Send + Sync + 'static>,
51    ) -> Arc<dyn NeuronErased + Send + Sync + 'static> {
52        Arc::new(Self::new(neuron))
53    }
54
55    /// Get the underlying typed neuron if the types match
56    pub fn get_typed_neuron(&self) -> Arc<dyn Neuron<T, C> + Send + Sync + 'static> {
57        self.neuron.clone()
58    }
59}
60
61impl<T, C> NeuronErased for NeuronErasedWrapper<T, C>
62where
63    T: Send + Sync + 'static,
64    C: Codec<T> + CodecName + Send + Sync + 'static,
65{
66    fn name(&self) -> String {
67        self.neuron.name()
68    }
69
70    fn name_without_codec(&self) -> String {
71        self.neuron.name_without_codec()
72    }
73
74    fn schema(&self) -> String {
75        self.neuron.schema()
76    }
77
78    fn payload_type_id(&self) -> TypeId {
79        TypeId::of::<T>()
80    }
81
82    fn codec_type_id(&self) -> TypeId {
83        TypeId::of::<C>()
84    }
85
86    fn clone_to_box(&self) -> Box<dyn NeuronErased + Send + Sync + 'static> {
87        Box::new(NeuronErasedWrapper {
88            neuron: self.neuron.clone_to_arc(),
89            _phantom: PhantomData,
90        })
91    }
92
93    fn clone_to_arc(&self) -> Arc<dyn NeuronErased + Send + Sync + 'static> {
94        Arc::new(NeuronErasedWrapper {
95            neuron: self.neuron.clone_to_arc(),
96            _phantom: PhantomData,
97        })
98    }
99
100    fn encode_any(&self, data: &dyn std::any::Any) -> Result<Vec<u8>, NeuronError> {
101        // Try to downcast the Any to our specific type T
102        if let Some(typed_data) = data.downcast_ref::<T>() {
103            self.neuron.encode(typed_data)
104        } else {
105            Err(NeuronError::Encode {
106                neuron_name: self.neuron.name(),
107                message: format!(
108                    "Type mismatch: expected type with TypeId {:?}",
109                    TypeId::of::<T>()
110                ),
111            })
112        }
113    }
114
115    fn decode_any(&self, data: &[u8]) -> Result<Box<dyn std::any::Any>, NeuronError> {
116        let decoded = self.neuron.decode(data)?;
117        Ok(Box::new(decoded) as Box<dyn std::any::Any>)
118    }
119
120    fn as_any(&self) -> &dyn Any {
121        self
122    }
123}
124
125impl<T, C> NeuronErasedWrapper<T, C>
126where
127    T: Send + Sync + 'static,
128    C: Codec<T> + CodecName + Send + Sync + 'static,
129{
130    /// Convert this wrapper to a correctly typed neuron with different type parameters
131    /// Returns an error if the type parameters don't match
132    pub fn to_typed_neuron<U, D>(
133        &self,
134    ) -> ErasureResult<Arc<dyn Neuron<U, D> + Send + Sync + 'static>>
135    where
136        U: Send + Sync + 'static,
137        D: Codec<U> + CodecName + Send + Sync + 'static,
138    {
139        // Check if the type parameters match using safe downcasting
140        if let Some(wrapper) = (self as &dyn Any).downcast_ref::<NeuronErasedWrapper<U, D>>() {
141            Ok(wrapper.neuron.clone())
142        } else {
143            Err(ErasureError::NeuronTypeMismatch {
144                expected_payload_type: TypeId::of::<U>(),
145                expected_codec_type: TypeId::of::<D>(),
146                actual_payload_type: TypeId::of::<T>(),
147                actual_codec_type: TypeId::of::<C>(),
148            })
149        }
150    }
151}
152
153/// Convenience function to create a type-erased neuron from a correctly typed neuron
154pub fn erase_neuron<T, C>(
155    neuron: Arc<dyn Neuron<T, C> + Send + Sync + 'static>,
156) -> Arc<dyn NeuronErased + Send + Sync + 'static>
157where
158    T: Send + Sync + 'static,
159    C: Codec<T> + CodecName + Send + Sync + 'static,
160{
161    NeuronErasedWrapper::from_typed_neuron(neuron)
162}
163
164/// Convenience function to convert a type-erased neuron wrapper back to a correctly typed neuron
165/// This function works directly with the concrete wrapper type
166pub fn unerase_neuron<T, C>(
167    wrapper: &NeuronErasedWrapper<T, C>,
168) -> Arc<dyn Neuron<T, C> + Send + Sync + 'static>
169where
170    T: Send + Sync + 'static,
171    C: Codec<T> + CodecName + Send + Sync + 'static,
172{
173    wrapper.get_typed_neuron()
174}