use crate::backpressure::{BackpressureConfig, BackpressureQueue};
use crate::codec::{Codec, CodecName};
use crate::logging::LogTrace;
use crate::neuron::Neuron;
use crate::payload::{Payload, PayloadRaw};
use crate::reactant::{ErrorReactant, Reactant, ReactantRaw};
use crate::synapse::SynapseError;
use futures_util::future::join_all;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use parking_lot::RwLock;
use thiserror::Error;
use tracing::Instrument;
/// Errors returned by [`Dendrite`] and [`DendriteDecoder`] operations.
///
/// NOTE(review): the `*Lock` variants describe lock-acquisition failures, but the
/// locks used in this module are `parking_lot::RwLock`, whose `read()`/`write()`
/// return guards directly rather than `Result`s. Confirm which callers actually
/// construct these variants before relying on their meaning.
#[derive(Error, Debug)]
pub enum DendriteError {
/// The read lock guarding the decoded-payload reactant list could not be taken.
#[error("Failed to acquire read lock on reactants for neuron '{neuron_name}'")]
ReactantsReadLock { neuron_name: String },
/// The write lock guarding the decoded-payload reactant list could not be taken.
#[error("Failed to acquire write lock on reactants for neuron '{neuron_name}'")]
ReactantsWriteLock { neuron_name: String },
/// The read lock guarding the raw-payload reactant list could not be taken.
#[error("Failed to acquire read lock on raw reactants for neuron '{neuron_name}'")]
RawReactantsReadLock { neuron_name: String },
/// The write lock guarding the raw-payload reactant list could not be taken.
#[error("Failed to acquire write lock on raw reactants for neuron '{neuron_name}'")]
RawReactantsWriteLock { neuron_name: String },
/// The read lock guarding the error-reactant list could not be taken.
#[error("Failed to acquire read lock on error reactants for neuron '{neuron_name}'")]
ErrorReactantsReadLock { neuron_name: String },
/// The write lock guarding the error-reactant list could not be taken.
#[error("Failed to acquire write lock on error reactants for neuron '{neuron_name}'")]
ErrorReactantsWriteLock { neuron_name: String },
/// Catch-all for failures that fit no more specific variant; carries a
/// human-readable description.
#[error("Internal error: {0}")]
Other(String),
}
/// Fans a decoded [`Payload`] out to a set of [`Reactant`]s.
///
/// Reactant failures are forwarded to every registered [`ErrorReactant`]. When no
/// error reactants are registered, each failure is written to the `tracing` error
/// log instead.
///
/// The reactant lists sit behind `RwLock`s so they can be extended through `&self`
/// (see `add_reactants` / `add_error_reactants`) while the dendrite is shared.
pub struct Dendrite<T, C>
where
C: Codec<T> + CodecName + Send + Sync + 'static,
T: Sync + Send + 'static,
{
// Owning neuron; stored at construction but not read by any method in this file.
_neuron: Arc<dyn Neuron<T, C> + Send + Sync>,
// Reactants invoked concurrently for every transduced payload.
reactants: RwLock<Vec<Arc<dyn Reactant<T, C> + Send + Sync>>>,
// Handlers called once per (reactant error, error reactant) pair.
error_reactants: RwLock<Vec<Arc<dyn ErrorReactant<T, C> + Send + Sync>>>,
// NOTE(review): this marker mentions neither `C` nor `T`; both are already used by
// the field types above, so it looks vestigial. Confirm before removing.
_codec_marker: PhantomData<fn() -> &'static ()>,
_phantom_t: PhantomData<T>,
}
impl<T, C> Dendrite<T, C>
where
C: Codec<T> + CodecName + Send + Sync + 'static,
T: Sync + Send + 'static,
{
#[must_use]
pub fn new(
neuron: Arc<dyn Neuron<T, C> + Send + Sync>,
reactants: Vec<Arc<dyn Reactant<T, C> + Send + Sync>>,
error_reactants: Vec<Arc<dyn ErrorReactant<T, C> + Send + Sync>>,
) -> Self {
Self {
_neuron: neuron,
reactants: RwLock::new(reactants),
error_reactants: RwLock::new(error_reactants),
_codec_marker: PhantomData,
_phantom_t: PhantomData,
}
}
pub fn add_reactants(
&self,
reactants: Vec<Arc<dyn Reactant<T, C> + Send + Sync>>,
) -> Result<(), DendriteError> {
if !reactants.is_empty() {
let mut write_guard = self.reactants.write();
write_guard.extend(reactants);
}
Ok(())
}
pub fn add_error_reactants(
&self,
error_reactants: Vec<Arc<dyn ErrorReactant<T, C> + Send + Sync>>,
) -> Result<(), DendriteError> {
if !error_reactants.is_empty() {
let mut write_guard = self.error_reactants.write();
write_guard.extend(error_reactants);
}
Ok(())
}
pub fn transduce(
&self,
payload: Arc<Payload<T, C>>,
) -> Pin<Box<dyn Future<Output = Result<Vec<()>, DendriteError>> + Send + 'static>> {
tracing::debug!("Dendrite::transduce called");
let reactants_clone = {
let guard = self.reactants.read();
if guard.is_empty() {
tracing::debug!("Dendrite::transduce no reactants, returning empty vec");
return Box::pin(async move { Ok(vec![]) });
}
guard.clone()
};
let error_reactants_clone = {
let guard = self.error_reactants.read();
guard.clone()
};
let payload_clone = payload.clone();
tracing::debug!(
"Dendrite::transduce - Cloned {} reactants",
reactants_clone.len()
);
Box::pin(
async move {
tracing::debug!(
"Dendrite::transduce creating futures for {} reactants",
reactants_clone.len()
);
let futures = reactants_clone
.iter()
.map(|reactant| reactant.react(payload_clone.clone()))
.collect::<Vec<_>>();
tracing::debug!(
"Dendrite::transduce awaiting join_all of {} futures",
futures.len()
);
let results = join_all(futures).await;
tracing::debug!("Dendrite::transduce join_all completed");
let mut errors = Vec::new();
let successes: Vec<()> = results
.into_iter()
.filter_map(|r| match r {
Ok(_) => Some(()),
Err(e) => {
errors.push(e);
None
}
})
.collect();
if !errors.is_empty() && !error_reactants_clone.is_empty() {
let error_futures = errors.into_iter().flat_map(|err| {
let err_arc = Arc::new(err);
let p = payload_clone.clone();
error_reactants_clone
.iter()
.map(move |er| er.react_error(err_arc.clone(), p.clone()))
});
join_all(error_futures).await;
} else if !errors.is_empty() {
for e in errors {
tracing::error!("Reactant error: {e}");
}
}
Ok(successes)
}
.instrument(payload.span_debug("Dendrite::transduce")),
)
}
}
/// Decodes raw payloads and fans them out to decoded-payload reactants
/// ([`Reactant`]) and raw-payload reactants ([`ReactantRaw`]).
///
/// Ingress goes through a `BackpressureQueue`. The reactant lists are wrapped in
/// `Arc<RwLock<..>>` so the queue's handler closure (built in `new`) and the
/// `add_*` methods operate on the same lists.
// The nested Arc/RwLock/Vec/dyn field types trip clippy::type_complexity.
#[allow(clippy::type_complexity)]
pub struct DendriteDecoder<T, C>
where
C: Codec<T> + CodecName + Send + Sync + 'static,
T: Send + Sync + 'static,
{
// Reactants that receive the decoded payload.
reactants: Arc<RwLock<Vec<Arc<dyn Reactant<T, C> + Send + Sync>>>>,
// Reactants that receive the original raw payload.
raw_reactants: Arc<RwLock<Vec<Arc<dyn ReactantRaw<T, C> + Send + Sync>>>>,
// Handlers notified about failures from either reactant list.
error_reactants: Arc<RwLock<Vec<Arc<dyn ErrorReactant<T, C> + Send + Sync>>>>,
// Queue that feeds the ingress processor. When `None`, `transduce` drops payloads.
// NOTE(review): nothing in this file sets this to `None`; presumably a shutdown
// path elsewhere does. Confirm.
ingress_queue: RwLock<Option<Arc<BackpressureQueue<Arc<PayloadRaw<T, C>>>>>>,
// NOTE(review): marker mentions neither `C` nor `T`; looks vestigial. Confirm.
_codec_marker: PhantomData<fn() -> &'static ()>,
_phantom_t: PhantomData<T>,
}
impl<T, C> DendriteDecoder<T, C>
where
    C: Codec<T> + CodecName + Send + Sync + 'static,
    T: Send + Sync + 'static,
{
    /// Builds a decoder dendrite whose ingress is driven by a backpressure queue.
    ///
    /// The queue is named after `neuron` and configured with `backpressure`, or
    /// `BackpressureConfig::default()` when `None`. Its handler closure captures
    /// shared handles to the three reactant lists and forwards each payload to
    /// `process_ingress`. Reactants registered later through the `add_*` methods
    /// are therefore seen by payloads processed after the registration.
    #[must_use]
    pub fn new(
        neuron: Arc<dyn Neuron<T, C> + Send + Sync>,
        reactants: Vec<Arc<dyn Reactant<T, C> + Send + Sync>>,
        raw_reactants: Vec<Arc<dyn ReactantRaw<T, C> + Send + Sync>>,
        error_reactants: Vec<Arc<dyn ErrorReactant<T, C> + Send + Sync>>,
        backpressure: Option<BackpressureConfig>,
    ) -> Self {
        let reactants = Arc::new(RwLock::new(reactants));
        let raw_reactants = Arc::new(RwLock::new(raw_reactants));
        let error_reactants = Arc::new(RwLock::new(error_reactants));
        let reactants_clone = reactants.clone();
        let raw_reactants_clone = raw_reactants.clone();
        let error_reactants_clone = error_reactants.clone();
        let ingress_queue = BackpressureQueue::new(
            neuron.name(),
            backpressure.unwrap_or_default(),
            move |payload: Arc<PayloadRaw<T, C>>| {
                Self::process_ingress(
                    reactants_clone.clone(),
                    raw_reactants_clone.clone(),
                    error_reactants_clone.clone(),
                    payload,
                )
            },
        );
        Self {
            reactants,
            raw_reactants,
            error_reactants,
            ingress_queue: RwLock::new(Some(Arc::new(ingress_queue))),
            _codec_marker: PhantomData,
            _phantom_t: PhantomData,
        }
    }

    /// Decodes one queued raw payload and delivers it to all reactants.
    ///
    /// 1. Snapshot the reactant lists. Each lock is released before any `.await`.
    /// 2. Return early when there are neither decoded nor raw reactants, so no
    ///    decoding work is spent on payloads nobody consumes.
    /// 3. Decode the payload. On failure, log a warning and drop the payload;
    ///    no reactant, raw or decoded, is invoked.
    /// 4. Run decoded and raw reactants concurrently.
    /// 5. Pass every reactant error to each error reactant, or log it when no
    ///    error reactants are registered.
    #[allow(clippy::type_complexity)]
    async fn process_ingress(
        rs: Arc<RwLock<Vec<Arc<dyn Reactant<T, C> + Send + Sync>>>>,
        rrs: Arc<RwLock<Vec<Arc<dyn ReactantRaw<T, C> + Send + Sync>>>>,
        ers: Arc<RwLock<Vec<Arc<dyn ErrorReactant<T, C> + Send + Sync>>>>,
        payload: Arc<PayloadRaw<T, C>>,
    ) {
        let span = payload.span_debug("DendriteDecoder::process_ingress");
        async move {
            // Block-scoped guards make it explicit that no lock outlives its
            // snapshot (the guards must not be held across an await).
            let reactants_vec: Vec<_> = {
                let guard = rs.read();
                guard.iter().cloned().collect()
            };
            let raw_reactants_vec: Vec<_> = {
                let guard = rrs.read();
                guard.iter().cloned().collect()
            };
            if reactants_vec.is_empty() && raw_reactants_vec.is_empty() {
                return;
            }
            let error_reactants_vec: Vec<_> = {
                let guard = ers.read();
                guard.iter().cloned().collect()
            };

            let neuron = payload.neuron.clone();
            let decoded_value = match neuron.decode(&payload.value) {
                Ok(value) => value,
                Err(_) => {
                    // Make malformed input observable instead of dropping it silently.
                    tracing::warn!(
                        "DendriteDecoder::process_ingress failed to decode payload; dropping it"
                    );
                    return;
                }
            };
            let decoded_payload = Arc::new(Payload::from_parts(
                Arc::new(decoded_value),
                neuron,
                payload.trace,
            ));

            let decoded_futures = reactants_vec
                .iter()
                .map(|reactant| reactant.react(decoded_payload.clone()));
            let raw_futures = raw_reactants_vec
                .iter()
                .map(|raw_reactant| raw_reactant.react(payload.clone()));
            let (decoded_results, raw_results) =
                futures_util::future::join(join_all(decoded_futures), join_all(raw_futures)).await;

            let errors: Vec<_> = decoded_results
                .into_iter()
                .filter_map(Result::err)
                .chain(raw_results.into_iter().filter_map(Result::err))
                .collect();

            if !errors.is_empty() && !error_reactants_vec.is_empty() {
                let error_futures = errors.into_iter().flat_map(|err| {
                    let err_arc = Arc::new(err);
                    let p = decoded_payload.clone();
                    error_reactants_vec
                        .iter()
                        .map(move |er| er.react_error(err_arc.clone(), p.clone()))
                });
                join_all(error_futures).await;
            } else if !errors.is_empty() {
                for e in errors {
                    tracing::error!("Reactant error: {e}");
                }
            }
        }
        .instrument(span)
        .await
    }

    /// Appends decoded-payload reactants. An empty vector is a no-op.
    ///
    /// # Errors
    ///
    /// Never returns `Err` in this implementation.
    pub fn add_reactants(
        &self,
        reactants: Vec<Arc<dyn Reactant<T, C> + Send + Sync>>,
    ) -> Result<(), DendriteError> {
        if !reactants.is_empty() {
            self.reactants.write().extend(reactants);
        }
        Ok(())
    }

    /// Appends raw-payload reactants. An empty vector is a no-op.
    ///
    /// # Errors
    ///
    /// Never returns `Err` in this implementation.
    pub fn add_raw_reactants(
        &self,
        raw_reactants: Vec<Arc<dyn ReactantRaw<T, C> + Send + Sync>>,
    ) -> Result<(), DendriteError> {
        if !raw_reactants.is_empty() {
            self.raw_reactants.write().extend(raw_reactants);
        }
        Ok(())
    }

    /// Appends error reactants. An empty vector is a no-op.
    ///
    /// # Errors
    ///
    /// Never returns `Err` in this implementation.
    pub fn add_error_reactants(
        &self,
        error_reactants: Vec<Arc<dyn ErrorReactant<T, C> + Send + Sync>>,
    ) -> Result<(), DendriteError> {
        if !error_reactants.is_empty() {
            self.error_reactants.write().extend(error_reactants);
        }
        Ok(())
    }

    /// Enqueues `payload` on the ingress queue for asynchronous processing.
    ///
    /// Reactants run later, inside the queue's handler, so both returned vectors
    /// are always empty. If the ingress queue has been removed (`None`), the
    /// payload is dropped and `Ok` with empty vectors is returned.
    ///
    /// # Errors
    ///
    /// Returns [`DendriteError::Other`] when the ingress queue rejects the
    /// payload, for example because it is full.
    #[allow(clippy::type_complexity)]
    pub fn transduce(
        &self,
        payload: Arc<PayloadRaw<T, C>>,
    ) -> Pin<Box<dyn Future<Output = Result<(Vec<()>, Vec<()>), DendriteError>> + Send + 'static>>
    {
        // Clone the handle out so the read guard is released at the end of this
        // statement and never captured by the returned future.
        let maybe_queue = self.ingress_queue.read().clone();
        let queue = match maybe_queue {
            Some(q) => q,
            None => return Box::pin(async move { Ok((vec![], vec![])) }),
        };
        Box::pin(async move {
            // Report queue rejections accurately rather than as lock failures.
            queue.push(payload).await.map_err(|e| match e {
                SynapseError::QueueFull { neuron_name } => DendriteError::Other(format!(
                    "Ingress queue full for neuron '{neuron_name}'"
                )),
                _ => DendriteError::Other(
                    "Failed to enqueue payload on ingress queue".to_string(),
                ),
            })?;
            Ok((vec![], vec![]))
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::neuron::NeuronImpl;
    use crate::test_utils::{
        DebugCodec, DebugStruct, TokioMpscReactant, TokioMpscReactantRaw, test_namespace,
    };
    use std::thread;
    use std::time::Duration;
    use tokio::sync::mpsc::channel;
    use uuid::Uuid;

    /// Upper bound when waiting for asynchronously delivered payloads.
    ///
    /// It is generous so slow CI machines do not cause spurious failures. Passing
    /// runs still finish as soon as each message arrives.
    const RECV_TIMEOUT: Duration = Duration::from_secs(2);

    #[tokio::test]
    async fn test_dendrite_transduce() {
        let ns = test_namespace();
        let neuron_impl: NeuronImpl<DebugStruct, DebugCodec> = NeuronImpl::new(ns.clone());
        let neuron_arc: Arc<dyn Neuron<DebugStruct, DebugCodec> + Send + Sync + '_> =
            Arc::new(neuron_impl);
        let (tx, mut rx) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(1);
        let reactants: Vec<Arc<dyn Reactant<DebugStruct, DebugCodec> + Send + Sync>> =
            vec![Arc::new(TokioMpscReactant { sender: tx })];
        let dendrite = Dendrite::new(neuron_arc.clone(), reactants, vec![]);
        let debug_struct_val = DebugStruct {
            foo: 42,
            bar: "test_value".to_owned(),
        };
        let uuid = Uuid::now_v7();
        let successes = dendrite
            .transduce(Payload::with_correlation(
                debug_struct_val.clone(),
                neuron_arc.clone(),
                Some(uuid),
            ))
            .await
            .expect("transduce should succeed");
        assert_eq!(successes.len(), 1);
        assert_eq!(rx.len(), 1);
        let p = rx.recv().await.unwrap();
        assert_eq!(*p.value, debug_struct_val);
        assert_eq!(p.correlation_id(), uuid);
    }

    #[tokio::test]
    async fn test_dendrite_multiple_reactants() {
        let ns = test_namespace();
        let neuron_impl: NeuronImpl<DebugStruct, DebugCodec> = NeuronImpl::new(ns.clone());
        let neuron_arc: Arc<dyn Neuron<DebugStruct, DebugCodec> + Send + Sync + '_> =
            Arc::new(neuron_impl);
        let (tx1, mut rx1) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(1);
        let (tx2, mut rx2) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(1);
        let reactants: Vec<Arc<dyn Reactant<DebugStruct, DebugCodec> + Send + Sync>> = vec![
            Arc::new(TokioMpscReactant { sender: tx1 }),
            Arc::new(TokioMpscReactant { sender: tx2 }),
        ];
        let dendrite = Dendrite::new(neuron_arc.clone(), reactants, vec![]);
        let debug_struct_val = DebugStruct {
            foo: 100,
            bar: "multi_test".to_owned(),
        };
        let uuid = Uuid::now_v7();
        let payload_value = Arc::new(debug_struct_val.clone());
        let successes = dendrite
            .transduce(Payload::with_correlation(
                debug_struct_val.clone(),
                neuron_arc.clone(),
                Some(uuid),
            ))
            .await
            .expect("transduce should succeed");
        assert_eq!(successes.len(), 2);
        assert_eq!(rx1.len(), 1);
        let p1 = rx1.recv().await.unwrap();
        assert_eq!(p1.value, payload_value);
        assert_eq!(p1.correlation_id(), uuid);
        assert_eq!(rx2.len(), 1);
        let p2 = rx2.recv().await.unwrap();
        assert_eq!(p2.value, payload_value);
        assert_eq!(p2.correlation_id(), uuid);
    }

    #[tokio::test]
    async fn test_decoder_dendrite_transduce() {
        let ns = test_namespace();
        let neuron_impl_for_encoding: NeuronImpl<DebugStruct, DebugCodec> =
            NeuronImpl::new(ns.clone());
        let neuron_arc_for_dendrite: Arc<dyn Neuron<DebugStruct, DebugCodec> + Send + Sync + '_> =
            Arc::new(NeuronImpl::new(ns.clone()));
        let (tx, mut rx) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(1);
        let (tx_raw, mut rx_raw) = channel::<Arc<PayloadRaw<DebugStruct, DebugCodec>>>(1);
        let reactants: Vec<Arc<dyn Reactant<DebugStruct, DebugCodec> + Send + Sync>> =
            vec![Arc::new(TokioMpscReactant { sender: tx })];
        let raw_reactants: Vec<Arc<dyn ReactantRaw<DebugStruct, DebugCodec> + Send + Sync>> =
            vec![Arc::new(TokioMpscReactantRaw { sender: tx_raw })];
        let dendrite_decoder = DendriteDecoder::new(
            neuron_arc_for_dendrite.clone(),
            reactants,
            raw_reactants,
            vec![],
            None,
        );
        let uuid = Uuid::now_v7();
        let debug_struct_val = DebugStruct {
            foo: 49,
            bar: "foo_bar".to_owned(),
        };
        let encoded = neuron_impl_for_encoding
            .encode(&debug_struct_val)
            .expect("Encoding should succeed in test");
        dendrite_decoder
            .transduce(PayloadRaw::with_correlation(
                encoded.clone(),
                neuron_arc_for_dendrite.clone(),
                Some(uuid),
            ))
            .await
            .expect("enqueue should succeed");
        let p = tokio::time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("Timeout waiting for decoded message")
            .expect("Channel closed");
        assert_eq!(*p.value, debug_struct_val);
        assert_eq!(p.correlation_id(), uuid);
        let p2 = tokio::time::timeout(RECV_TIMEOUT, rx_raw.recv())
            .await
            .expect("Timeout waiting for raw message")
            .expect("Channel closed");
        assert_eq!(p2.value.as_slice(), encoded.as_slice());
        assert_eq!(p2.correlation_id(), uuid);
    }

    #[tokio::test]
    async fn test_dendrite_add_reactants() {
        let ns = test_namespace();
        let neuron_impl: NeuronImpl<DebugStruct, DebugCodec> = NeuronImpl::new(ns.clone());
        let neuron_arc: Arc<dyn Neuron<DebugStruct, DebugCodec> + Send + Sync + '_> =
            Arc::new(neuron_impl);
        let (tx1, mut rx1) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(1);
        let initial_reactants: Vec<Arc<dyn Reactant<DebugStruct, DebugCodec> + Send + Sync>> =
            vec![Arc::new(TokioMpscReactant { sender: tx1 })];
        let dendrite = Dendrite::new(neuron_arc.clone(), initial_reactants, vec![]);
        let (tx2, mut rx2) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(1);
        let additional_reactants: Vec<Arc<dyn Reactant<DebugStruct, DebugCodec> + Send + Sync>> =
            vec![Arc::new(TokioMpscReactant { sender: tx2 })];
        dendrite
            .add_reactants(additional_reactants)
            .expect("adding reactants should succeed");
        let debug_struct_val = DebugStruct {
            foo: 42,
            bar: "test_add_reactants".to_owned(),
        };
        let uuid = Uuid::now_v7();
        let payload =
            Payload::with_correlation(debug_struct_val.clone(), neuron_arc.clone(), Some(uuid));
        let successes = dendrite
            .transduce(payload.clone())
            .await
            .expect("transduce should succeed");
        assert_eq!(successes.len(), 2);
        assert_eq!(rx1.len(), 1);
        let p1 = rx1.recv().await.unwrap();
        assert_eq!(*p1.value, debug_struct_val);
        assert_eq!(p1.correlation_id(), uuid);
        assert_eq!(rx2.len(), 1);
        let p2 = rx2.recv().await.unwrap();
        assert_eq!(*p2.value, debug_struct_val);
        assert_eq!(p2.correlation_id(), uuid);
    }

    #[tokio::test]
    async fn test_dendrite_decoder_add_reactants() {
        let ns = test_namespace();
        let neuron_impl: NeuronImpl<DebugStruct, DebugCodec> = NeuronImpl::new(ns.clone());
        let neuron_arc: Arc<dyn Neuron<DebugStruct, DebugCodec> + Send + Sync + '_> =
            Arc::new(neuron_impl);
        let (tx1, mut rx1) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(1);
        let initial_reactants: Vec<Arc<dyn Reactant<DebugStruct, DebugCodec> + Send + Sync>> =
            vec![Arc::new(TokioMpscReactant { sender: tx1 })];
        let (tx_raw1, mut rx_raw1) = channel::<Arc<PayloadRaw<DebugStruct, DebugCodec>>>(1);
        let initial_raw_reactants: Vec<Arc<dyn ReactantRaw<DebugStruct, DebugCodec> + Send + Sync>> =
            vec![Arc::new(TokioMpscReactantRaw { sender: tx_raw1 })];
        let dendrite_decoder = DendriteDecoder::new(
            neuron_arc.clone(),
            initial_reactants,
            initial_raw_reactants,
            vec![],
            None,
        );
        let (tx2, mut rx2) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(1);
        let additional_reactants: Vec<Arc<dyn Reactant<DebugStruct, DebugCodec> + Send + Sync>> =
            vec![Arc::new(TokioMpscReactant { sender: tx2 })];
        let (tx_raw2, mut rx_raw2) = channel::<Arc<PayloadRaw<DebugStruct, DebugCodec>>>(1);
        let additional_raw_reactants: Vec<Arc<dyn ReactantRaw<DebugStruct, DebugCodec> + Send + Sync>> =
            vec![Arc::new(TokioMpscReactantRaw { sender: tx_raw2 })];
        dendrite_decoder
            .add_reactants(additional_reactants)
            .expect("adding reactants should succeed");
        dendrite_decoder
            .add_raw_reactants(additional_raw_reactants)
            .expect("adding raw reactants should succeed");
        let debug_struct_val = DebugStruct {
            foo: 42,
            bar: "test_add_reactants_decoder".to_owned(),
        };
        let uuid = Uuid::now_v7();
        let encoded = neuron_arc
            .encode(&debug_struct_val)
            .expect("Encoding should succeed in test");
        let payload_raw =
            PayloadRaw::with_correlation(encoded.clone(), neuron_arc.clone(), Some(uuid));
        dendrite_decoder
            .transduce(payload_raw.clone())
            .await
            .expect("enqueue should succeed");
        let p1 = tokio::time::timeout(RECV_TIMEOUT, rx1.recv())
            .await
            .expect("Timeout rx1")
            .expect("Closed rx1");
        assert_eq!(*p1.value, debug_struct_val);
        assert_eq!(p1.correlation_id(), uuid);
        let p2 = tokio::time::timeout(RECV_TIMEOUT, rx2.recv())
            .await
            .expect("Timeout rx2")
            .expect("Closed rx2");
        assert_eq!(*p2.value, debug_struct_val);
        assert_eq!(p2.correlation_id(), uuid);
        let p_raw1 = tokio::time::timeout(RECV_TIMEOUT, rx_raw1.recv())
            .await
            .expect("Timeout raw_rx1")
            .expect("Closed raw_rx1");
        assert_eq!(p_raw1.value.as_slice(), encoded.as_slice());
        assert_eq!(p_raw1.correlation_id(), uuid);
        let p_raw2 = tokio::time::timeout(RECV_TIMEOUT, rx_raw2.recv())
            .await
            .expect("Timeout raw_rx2")
            .expect("Closed raw_rx2");
        assert_eq!(p_raw2.value.as_slice(), encoded.as_slice());
        assert_eq!(p_raw2.correlation_id(), uuid);
    }

    #[tokio::test]
    async fn test_dendrite_concurrent_readers() {
        let ns = test_namespace();
        let neuron_impl: NeuronImpl<DebugStruct, DebugCodec> = NeuronImpl::new(ns.clone());
        let neuron_arc: Arc<dyn Neuron<DebugStruct, DebugCodec> + Send + Sync + '_> =
            Arc::new(neuron_impl);
        let (tx1, mut rx1) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(10);
        let initial_reactants: Vec<Arc<dyn Reactant<DebugStruct, DebugCodec> + Send + Sync>> =
            vec![Arc::new(TokioMpscReactant { sender: tx1 })];
        let dendrite = Arc::new(Dendrite::new(neuron_arc.clone(), initial_reactants, vec![]));
        let debug_struct_val = DebugStruct {
            foo: 42,
            bar: "test_concurrent_readers".to_owned(),
        };
        let uuid = Uuid::now_v7();
        let payload =
            Payload::with_correlation(debug_struct_val.clone(), neuron_arc.clone(), Some(uuid));
        let num_threads = 5;
        let mut handles = Vec::with_capacity(num_threads);
        for _ in 0..num_threads {
            let dendrite_clone = Arc::clone(&dendrite);
            let payload_clone = payload.clone();
            handles.push(thread::spawn(move || {
                let rt = tokio::runtime::Runtime::new().expect("failed to build runtime");
                rt.block_on(async {
                    dendrite_clone
                        .transduce(payload_clone)
                        .await
                        .expect("transduce should succeed");
                });
            }));
        }
        // Join on the blocking pool so the test's own async runtime is not blocked.
        tokio::task::spawn_blocking(move || {
            for handle in handles {
                handle.join().expect("transducer thread panicked");
            }
        })
        .await
        .expect("join task failed");
        for _ in 0..num_threads {
            let p = rx1.recv().await.unwrap();
            assert_eq!(*p.value, debug_struct_val);
            assert_eq!(p.correlation_id(), uuid);
        }
    }
}