use crate::codec::{Codec, CodecName};
use crate::erasure::payload::PayloadErased;
use crate::erasure::reactant::{ErrorReactantErased, ReactantErased};
use crate::ganglion::{Ganglion, GanglionError, GanglionInternal};
use crate::neuron::Neuron;
use crate::utils::struct_name_of_type;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;
use uuid::Uuid;
pub struct Thalamus<G>
where
G: GanglionInternal + Ganglion + Send + Sync + 'static,
{
id: Uuid,
peers: Vec<Arc<Mutex<G>>>,
}
impl<G> Thalamus<G>
where
G: GanglionInternal + Ganglion + Send + Sync + 'static,
{
pub fn new(peers: Vec<Arc<Mutex<G>>>) -> Self {
Self {
id: Uuid::now_v7(),
peers,
}
}
}
impl<G> Ganglion for Thalamus<G>
where
    G: GanglionInternal + Ganglion + Send + Sync + 'static,
{
    /// Returns `true` if any peer reports itself capable of handling `neuron`.
    ///
    /// This method is synchronous, so it cannot await a peer's async mutex. A peer whose
    /// lock is currently held elsewhere is skipped and counts as not capable. A `false`
    /// result therefore means "no currently-available peer is capable". Peers are checked
    /// in order, and checking stops at the first capable one.
    fn capable<T, C>(&mut self, neuron: Arc<dyn Neuron<T, C> + Send + Sync>) -> bool
    where
        C: Codec<T> + CodecName + Send + Sync + 'static,
        T: Send + Sync + 'static,
    {
        self.peers.iter().any(|peer| match peer.try_lock() {
            Ok(mut p) => p.capable(Arc::clone(&neuron)),
            // Busy peer: blocking here is not possible, so treat it as not capable.
            Err(_) => false,
        })
    }

    /// Adapts `neuron` on every peer, one peer at a time and in order.
    ///
    /// Unlike `capable`, each peer's lock is awaited rather than skipped. The lock is held
    /// only while the peer's adapt future is created, and is released before that future is
    /// awaited. `transmit`/`react` treat peer locks the same way.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first [`GanglionError`] produced by a peer. Peers after
    /// the failing one are not adapted.
    fn adapt<T, C>(
        &mut self,
        neuron: Arc<dyn Neuron<T, C> + Send + Sync>,
    ) -> Pin<Box<dyn Future<Output = Result<(), GanglionError>> + Send + 'static>>
    where
        C: Codec<T> + CodecName + Send + Sync + 'static,
        T: Send + Sync + 'static,
    {
        let peers = self.peers.clone();
        Box::pin(async move {
            for peer in &peers {
                // The adapt future is 'static, so the guard can be dropped before awaiting it.
                let future = {
                    let mut p = peer.lock().await;
                    p.adapt(Arc::clone(&neuron))
                };
                future.await?;
            }
            Ok(())
        })
    }
}
impl<G> GanglionInternal for Thalamus<G>
where
    G: GanglionInternal + Ganglion + Send + Sync + 'static,
{
    /// Broadcasts `payload` to every peer, one after another, in order.
    ///
    /// Each peer's lock is held only while its transmit future is created, and is released
    /// before that future is awaited.
    ///
    /// # Errors
    ///
    /// Returns `SynapseNotFound` (attributed to this thalamus's id) when there are no peers.
    /// Every peer is always tried. If at least one succeeds, the concatenated results of the
    /// successful peers are returned and the other peers' errors are discarded. If all fail,
    /// the error from the last failing peer is returned.
    fn transmit(
        &mut self,
        payload: Arc<dyn PayloadErased + Send + Sync + 'static>,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<()>, GanglionError>> + Send + 'static>> {
        let id = self.id;
        let peers = self.peers.clone();
        let neuron_name = payload.get_neuron_name();
        Box::pin(async move {
            if peers.is_empty() {
                return Err(Self::no_synapse_error(id, neuron_name));
            }
            let mut all_results = Vec::new();
            let mut last_err = None;
            let mut at_least_one_ok = false;
            for peer in &peers {
                let future = {
                    let mut p = peer.lock().await;
                    p.transmit(Arc::clone(&payload))
                };
                match future.await {
                    Ok(mut results) => {
                        all_results.append(&mut results);
                        at_least_one_ok = true;
                    }
                    Err(e) => last_err = Some(e),
                }
            }
            if at_least_one_ok {
                Ok(all_results)
            } else {
                // With a non-empty peer list every failure sets `last_err`; the fallback is
                // purely defensive and is only built if actually needed.
                Err(last_err.unwrap_or_else(|| Self::no_synapse_error(id, neuron_name)))
            }
        })
    }

    /// Registers `reactants` / `error_reactants` for `neuron_name` on every peer, in order.
    ///
    /// Each peer receives its own clone of both vectors. The lock and error rules are the
    /// same as for `transmit`: the call succeeds if at least one peer accepted the
    /// registration. Otherwise it returns the last peer's error, or `SynapseNotFound` when
    /// there are no peers.
    fn react(
        &mut self,
        neuron_name: String,
        reactants: Vec<Arc<dyn ReactantErased + Send + Sync + 'static>>,
        error_reactants: Vec<Arc<dyn ErrorReactantErased + Send + Sync>>,
    ) -> Pin<Box<dyn Future<Output = Result<(), GanglionError>> + Send + 'static>> {
        let id = self.id;
        let peers = self.peers.clone();
        Box::pin(async move {
            if peers.is_empty() {
                return Err(Self::no_synapse_error(id, neuron_name));
            }
            let mut at_least_one_ok = false;
            let mut last_err: Option<GanglionError> = None;
            for peer in &peers {
                let future = {
                    let mut p = peer.lock().await;
                    p.react(neuron_name.clone(), reactants.clone(), error_reactants.clone())
                };
                match future.await {
                    Ok(()) => at_least_one_ok = true,
                    Err(e) => last_err = Some(e),
                }
            }
            if at_least_one_ok {
                Ok(())
            } else {
                Err(last_err.unwrap_or_else(|| Self::no_synapse_error(id, neuron_name)))
            }
        })
    }

    /// Forwards the whole `reactions` map to every peer's `react_many`, in order.
    ///
    /// The success rule is the same as for `react`. Because this call has no single neuron
    /// name, any `SynapseNotFound` produced here uses the placeholder name `"batch"`.
    fn react_many(
        &mut self,
        reactions: HashMap<
            String,
            (
                Vec<Arc<dyn ReactantErased + Send + Sync + 'static>>,
                Vec<Arc<dyn ErrorReactantErased + Send + Sync>>,
            ),
        >,
    ) -> Pin<Box<dyn Future<Output = Result<(), GanglionError>> + Send + 'static>> {
        let id = self.id;
        let peers = self.peers.clone();
        Box::pin(async move {
            if peers.is_empty() {
                return Err(Self::no_synapse_error(id, "batch".to_string()));
            }
            let mut at_least_one_ok = false;
            let mut last_err: Option<GanglionError> = None;
            for peer in &peers {
                let future = {
                    let mut p = peer.lock().await;
                    p.react_many(reactions.clone())
                };
                match future.await {
                    Ok(()) => at_least_one_ok = true,
                    Err(e) => last_err = Some(e),
                }
            }
            if at_least_one_ok {
                Ok(())
            } else {
                Err(last_err
                    .unwrap_or_else(|| Self::no_synapse_error(id, "batch".to_string())))
            }
        })
    }

    /// Returns the id generated for this thalamus at construction.
    fn unique_id(&self) -> Uuid {
        self.id
    }
}

impl<G> Thalamus<G>
where
    G: GanglionInternal + Ganglion + Send + Sync + 'static,
{
    /// Builds the `SynapseNotFound` error this thalamus reports when no peer could service
    /// `neuron_name`, attributed to the thalamus identified by `ganglion_id`.
    ///
    /// This is an associated function rather than a method so it can be called from the
    /// `'static` futures above, which capture the id by value instead of borrowing `self`.
    fn no_synapse_error(ganglion_id: Uuid, neuron_name: String) -> GanglionError {
        GanglionError::SynapseNotFound {
            neuron_name,
            ganglion_name: struct_name_of_type::<Self>().to_string(),
            ganglion_id,
        }
    }
}
// Tests for `Thalamus`. Each scenario places real `GanglionInprocess` peers behind a
// thalamus and checks that transmissions reach every peer.
#[cfg(test)]
mod tests {
use super::*;
use crate::erasure::payload::erase_payload;
use crate::erasure::reactant::erase_reactant;
use crate::ganglion::GanglionInprocess;
use crate::logging::TraceContext;
use crate::neuron::NeuronImpl;
use crate::payload::Payload;
use crate::reactant::Reactant;
use crate::test_utils::{
DebugCodec, DebugStruct, ResponseCodec, ResponseStruct, TokioMpscReactant, test_namespace,
};
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::mpsc::channel;
use uuid::Uuid;
/// Two peers, one reactant registered through the thalamus (so both peers receive it) and
/// two transmissions: each payload is delivered once per peer, i.e. four messages in total.
#[tokio::test]
async fn test_thalamus_broadcast_basic() {
let ns = test_namespace();
let neuron: NeuronImpl<DebugStruct, DebugCodec> = NeuronImpl::new(ns.clone());
let neuron_name = neuron.name();
let neuron_arc: Arc<dyn Neuron<DebugStruct, DebugCodec> + Send + Sync> = Arc::new(neuron);
let g1 = Arc::new(Mutex::new(GanglionInprocess::new()));
let g2 = Arc::new(Mutex::new(GanglionInprocess::new()));
let mut thalamus = Thalamus::new(vec![g1.clone(), g2.clone()]);
thalamus
.adapt::<DebugStruct, DebugCodec>(neuron_arc.clone())
.await
.unwrap();
// A single channel shared by both peers' copies of the reactant.
let (tx, mut rx) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(10);
let reactants = vec![erase_reactant::<DebugStruct, DebugCodec, _>(Box::new(
TokioMpscReactant::new(tx),
))];
thalamus
.react(neuron_name.clone(), reactants, vec![])
.await
.unwrap();
// The first payload is built from explicit parts (with a trace context), the second
// via `Payload::new`, so both constructors go through the thalamus.
let correlation_id = Uuid::now_v7();
// Reuse the low 64 bits of the correlation id as the span id.
let span_id = correlation_id.as_u128() as u64;
let payload1 = Arc::new(Payload::from_parts(
Arc::new(DebugStruct {
foo: 1,
bar: "a".to_string(),
}),
neuron_arc.clone(),
TraceContext::from_parts(correlation_id, span_id, None),
));
let payload2 = Payload::new(
DebugStruct {
foo: 2,
bar: "b".to_string(),
},
neuron_arc.clone(),
);
thalamus.transmit(erase_payload(payload1)).await.unwrap();
thalamus.transmit(erase_payload(payload2)).await.unwrap();
// 2 payloads x 2 peers = 4 deliveries; only arrival is checked, not content.
let _m1 = rx.recv().await.expect("expected first message");
let _m2 = rx.recv().await.expect("expected second message");
let _m3 = rx.recv().await.expect("expected third message");
let _m4 = rx.recv().await.expect("expected fourth message");
}
/// Despite its name, this test checks broadcast rather than load-balancing. Each of the
/// three peers gets its own reactant, registered directly on the peer (not via the
/// thalamus), and every peer is expected to see both transmitted messages.
#[tokio::test]
async fn test_thalamus_broadcast_work_distribution() {
let ns = test_namespace();
let neuron: NeuronImpl<DebugStruct, DebugCodec> = NeuronImpl::new(ns.clone());
let neuron_name = neuron.name();
let neuron_arc: Arc<dyn Neuron<DebugStruct, DebugCodec> + Send + Sync> = Arc::new(neuron);
let g1 = Arc::new(Mutex::new(GanglionInprocess::new()));
let g2 = Arc::new(Mutex::new(GanglionInprocess::new()));
let g3 = Arc::new(Mutex::new(GanglionInprocess::new()));
let mut thalamus = Thalamus::new(vec![g1.clone(), g2.clone(), g3.clone()]);
thalamus
.adapt::<DebugStruct, DebugCodec>(neuron_arc.clone())
.await
.unwrap();
// One channel per peer so deliveries can be counted per peer.
let (tx1, rx1) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(10);
let (tx2, rx2) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(10);
let (tx3, rx3) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(10);
{
let mut g1_guard = g1.lock().await;
let reactants1 = vec![erase_reactant::<DebugStruct, DebugCodec, _>(Box::new(
TokioMpscReactant::new(tx1),
))];
g1_guard
.react(neuron_name.clone(), reactants1, vec![])
.await
.unwrap();
}
{
let mut g2_guard = g2.lock().await;
let reactants2 = vec![erase_reactant::<DebugStruct, DebugCodec, _>(Box::new(
TokioMpscReactant::new(tx2),
))];
g2_guard
.react(neuron_name.clone(), reactants2, vec![])
.await
.unwrap();
}
{
let mut g3_guard = g3.lock().await;
let reactants3 = vec![erase_reactant::<DebugStruct, DebugCodec, _>(Box::new(
TokioMpscReactant::new(tx3),
))];
g3_guard
.react(neuron_name.clone(), reactants3, vec![])
.await
.unwrap();
}
for i in 0..2 {
let test_data = DebugStruct {
foo: i,
bar: format!("msg{i}"),
};
thalamus
.transmit(erase_payload(Payload::new(test_data, neuron_arc.clone())))
.await
.expect("Failed to transmit");
}
// NOTE(review): a fixed 50ms sleep is assumed to be enough for every delivery to land;
// this may be flaky on a loaded machine.
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
assert_eq!(rx1.len(), 2, "Ganglion 1 should receive 2 messages");
assert_eq!(rx2.len(), 2, "Ganglion 2 should receive 2 messages");
assert_eq!(rx3.len(), 2, "Ganglion 3 should receive 2 messages");
}
/// Request/response round trip across three peers.
///
/// Each peer answers requests by pushing a response payload (tagged with that peer's
/// number) onto its own queue. A spawned task per peer transmits that queue back into the
/// same peer, where a capturing reactant (registered on all peers via the thalamus)
/// forwards it to `response_rx`. With 2 broadcast requests and 3 peers, every peer is
/// expected to contribute exactly 2 responses.
#[tokio::test]
async fn test_thalamus_work_distribution_with_responses() {
let ns = test_namespace();
let request_neuron: NeuronImpl<DebugStruct, DebugCodec> = NeuronImpl::new(ns.clone());
let request_neuron_name = request_neuron.name();
let request_neuron_arc: Arc<dyn Neuron<DebugStruct, DebugCodec> + Send + Sync> =
Arc::new(request_neuron);
let g1 = Arc::new(Mutex::new(GanglionInprocess::new()));
let g2 = Arc::new(Mutex::new(GanglionInprocess::new()));
let g3 = Arc::new(Mutex::new(GanglionInprocess::new()));
let mut thalamus = Thalamus::new(vec![g1.clone(), g2.clone(), g3.clone()]);
let response_neuron: NeuronImpl<ResponseStruct, ResponseCodec> =
NeuronImpl::new(ns.clone());
let response_neuron_name = response_neuron.name();
let response_neuron_arc: Arc<dyn Neuron<ResponseStruct, ResponseCodec> + Send + Sync> =
Arc::new(response_neuron);
thalamus
.adapt::<DebugStruct, DebugCodec>(request_neuron_arc.clone())
.await
.unwrap();
thalamus
.adapt::<ResponseStruct, ResponseCodec>(response_neuron_arc.clone())
.await
.unwrap();
let (response_tx, mut response_rx) =
channel::<Arc<Payload<ResponseStruct, ResponseCodec>>>(20);
// Forwards every response payload into `response_tx`. Send failures (e.g. a full
// channel) are deliberately ignored and the reaction still reports success.
#[derive(Clone)]
struct ResponseCaptureReactant {
sender: tokio::sync::mpsc::Sender<Arc<Payload<ResponseStruct, ResponseCodec>>>,
}
impl ResponseCaptureReactant {
fn new(
sender: tokio::sync::mpsc::Sender<Arc<Payload<ResponseStruct, ResponseCodec>>>,
) -> Self {
Self { sender }
}
}
impl Reactant<ResponseStruct, ResponseCodec> for ResponseCaptureReactant {
fn react(
&self,
payload: Arc<Payload<ResponseStruct, ResponseCodec>>,
) -> Pin<
Box<
dyn Future<Output = Result<(), crate::reactant::ReactantError>>
+ Send
+ 'static,
>,
> {
let sender = self.sender.clone();
let payload_clone = payload.clone();
Box::pin(async move {
let _ = sender.try_send(payload_clone);
Ok(())
})
}
fn erase(self: Box<Self>) -> Arc<dyn ReactantErased + Send + Sync + 'static> {
erase_reactant(self)
}
}
let response_capture_reactant = ResponseCaptureReactant::new(response_tx.clone());
// Answers each request with a `ResponseStruct` carrying this reactant's `ganglion_id`.
// The response is pushed onto a per-peer queue (send failures ignored) rather than
// transmitted directly.
#[derive(Clone)]
struct ResponseGeneratingReactant {
ganglion_id: u32,
response_neuron: Arc<dyn Neuron<ResponseStruct, ResponseCodec> + Send + Sync>,
queue_sender: tokio::sync::mpsc::Sender<Arc<Payload<ResponseStruct, ResponseCodec>>>,
}
impl ResponseGeneratingReactant {
fn new(
ganglion_id: u32,
response_neuron: Arc<dyn Neuron<ResponseStruct, ResponseCodec> + Send + Sync>,
queue_sender: tokio::sync::mpsc::Sender<
Arc<Payload<ResponseStruct, ResponseCodec>>,
>,
) -> Self {
Self {
ganglion_id,
response_neuron,
queue_sender,
}
}
}
impl Reactant<DebugStruct, DebugCodec> for ResponseGeneratingReactant {
fn react(
&self,
payload: Arc<Payload<DebugStruct, DebugCodec>>,
) -> Pin<
Box<
dyn Future<Output = Result<(), crate::reactant::ReactantError>>
+ Send
+ 'static,
>,
> {
let ganglion_id = self.ganglion_id;
let response_neuron = self.response_neuron.clone();
let queue_sender = self.queue_sender.clone();
let original_value = payload.value.clone();
Box::pin(async move {
let response_payload = Payload::new(
ResponseStruct {
ganglion_id,
response_message: format!(
"response_from_ganglion_{}_for_{}",
ganglion_id, original_value.bar
),
},
response_neuron,
);
let _ = queue_sender.try_send(response_payload);
Ok(())
})
}
fn erase(self: Box<Self>) -> Arc<dyn ReactantErased + Send + Sync + 'static> {
erase_reactant(self)
}
}
// One response queue per peer.
let (queue1_tx, mut queue1_rx) = channel::<Arc<Payload<ResponseStruct, ResponseCodec>>>(10);
let (queue2_tx, mut queue2_rx) = channel::<Arc<Payload<ResponseStruct, ResponseCodec>>>(10);
let (queue3_tx, mut queue3_rx) = channel::<Arc<Payload<ResponseStruct, ResponseCodec>>>(10);
// From here on the thalamus is accessed through a mutex. Each guard is dropped before
// the future it produced is awaited.
let thalamus_arc = Arc::new(Mutex::new(thalamus));
{
let mut thalamus_guard = thalamus_arc.lock().await;
let response_reactants = vec![erase_reactant::<ResponseStruct, ResponseCodec, _>(
Box::new(response_capture_reactant),
)];
let future =
thalamus_guard.react(response_neuron_name.clone(), response_reactants, vec![]);
drop(thalamus_guard);
future.await.unwrap();
}
// One forwarding task per peer: responses queued by that peer's reactant are
// transmitted back into the same peer. Transmit errors are ignored.
let g1_clone = g1.clone();
tokio::spawn(async move {
while let Some(payload) = queue1_rx.recv().await {
let future = {
let mut ganglion_guard = g1_clone.lock().await;
ganglion_guard.transmit(erase_payload(payload))
};
let _ = future.await;
}
});
let g2_clone = g2.clone();
tokio::spawn(async move {
while let Some(payload) = queue2_rx.recv().await {
let future = {
let mut ganglion_guard = g2_clone.lock().await;
ganglion_guard.transmit(erase_payload(payload))
};
let _ = future.await;
}
});
let g3_clone = g3.clone();
tokio::spawn(async move {
while let Some(payload) = queue3_rx.recv().await {
let future = {
let mut ganglion_guard = g3_clone.lock().await;
ganglion_guard.transmit(erase_payload(payload))
};
let _ = future.await;
}
});
// Register a response-generating reactant directly on each peer (not via the
// thalamus), each with its own peer number and queue.
{
let mut g1_guard = g1.lock().await;
let reactants1 = vec![erase_reactant::<DebugStruct, DebugCodec, _>(Box::new(
ResponseGeneratingReactant::new(1, response_neuron_arc.clone(), queue1_tx),
))];
let future = g1_guard.react(request_neuron_name.clone(), reactants1, vec![]);
drop(g1_guard);
future.await.unwrap();
}
{
let mut g2_guard = g2.lock().await;
let reactants2 = vec![erase_reactant::<DebugStruct, DebugCodec, _>(Box::new(
ResponseGeneratingReactant::new(2, response_neuron_arc.clone(), queue2_tx),
))];
let future = g2_guard.react(request_neuron_name.clone(), reactants2, vec![]);
drop(g2_guard);
future.await.unwrap();
}
{
let mut g3_guard = g3.lock().await;
let reactants3 = vec![erase_reactant::<DebugStruct, DebugCodec, _>(Box::new(
ResponseGeneratingReactant::new(3, response_neuron_arc.clone(), queue3_tx),
))];
let future = g3_guard.react(request_neuron_name.clone(), reactants3, vec![]);
drop(g3_guard);
future.await.unwrap();
}
// Broadcast two requests through the thalamus.
{
for i in 0..2 {
let payload = Payload::new(
DebugStruct {
foo: i,
bar: format!("request_{i}"),
},
request_neuron_arc.clone(),
);
let future = {
let mut thalamus_guard = thalamus_arc.lock().await;
thalamus_guard.transmit(erase_payload(payload))
};
future.await.unwrap();
}
}
let mut count_g1 = 0;
let mut count_g2 = 0;
let mut count_g3 = 0;
// NOTE(review): a fixed 100ms sleep is assumed to be enough for the spawned forwarding
// tasks to finish; this may be flaky on a loaded machine.
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Drain whatever has already arrived (at most six responses) and tally it per peer.
let mut total_received = 0;
while total_received < 6 && !response_rx.is_empty() {
if let Ok(payload) = response_rx.try_recv() {
match payload.value.ganglion_id {
1 => count_g1 += 1,
2 => count_g2 += 1,
3 => count_g3 += 1,
_ => panic!(
"Unexpected ganglion ID in response: {}",
payload.value.ganglion_id
),
}
total_received += 1;
} else {
break;
}
}
assert_eq!(count_g1, 2, "Should receive 2 responses from ganglion 1");
assert_eq!(count_g2, 2, "Should receive 2 responses from ganglion 2");
assert_eq!(count_g3, 2, "Should receive 2 responses from ganglion 3");
assert_eq!(
count_g1 + count_g2 + count_g3,
6,
"Total responses should be 6"
);
}
}