1use crate::codec::{Codec, CodecName};
8use crate::erasure::payload::PayloadErased;
9use crate::erasure::reactant::{ErrorReactantErased, ReactantErased};
10use crate::ganglion::{Ganglion, GanglionError, GanglionInternal};
11use crate::neuron::Neuron;
12use crate::utils::struct_name_of_type;
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::Arc;
16use tokio::sync::Mutex;
17use uuid::Uuid;
18
19pub struct Thalamus<G>
23where
24 G: GanglionInternal + Ganglion + Send + Sync + 'static,
25{
26 id: Uuid,
27 peers: Vec<Arc<Mutex<G>>>,
28 rr_idx: usize,
29}
30
31impl<G> Thalamus<G>
32where
33 G: GanglionInternal + Ganglion + Send + Sync + 'static,
34{
35 pub fn new(peers: Vec<Arc<Mutex<G>>>) -> Self {
36 Self {
37 id: Uuid::now_v7(),
38 peers,
39 rr_idx: 0,
40 }
41 }
42
43 fn next_index(&mut self) -> usize {
44 if self.peers.is_empty() {
45 return 0;
46 }
47 let idx = self.rr_idx % self.peers.len();
48 self.rr_idx = (idx + 1) % self.peers.len();
49 idx
50 }
51}
52
53impl<G> Ganglion for Thalamus<G>
54where
55 G: GanglionInternal + Ganglion + Send + Sync + 'static,
56{
57 fn capable<T, C>(&mut self, neuron: Arc<dyn Neuron<T, C> + Send + Sync>) -> bool
58 where
59 C: Codec<T> + CodecName + Send + Sync + 'static,
60 T: Send + Sync + 'static,
61 {
62 for peer in &self.peers {
63 if let Ok(mut guard) = peer.try_lock() {
65 if guard.capable::<T, C>(neuron.clone()) {
66 return true;
67 }
68 }
69 }
70 false
71 }
72
73 fn adapt<T, C>(
74 &mut self,
75 neuron: Arc<dyn Neuron<T, C> + Send + Sync>,
76 ) -> Pin<Box<dyn Future<Output = Result<(), GanglionError>> + Send + 'static>>
77 where
78 C: Codec<T> + CodecName + Send + Sync + 'static,
79 T: Send + Sync + 'static,
80 {
81 if !self.capable::<T, C>(neuron.clone()) {
83 return Box::pin(async move { Ok(()) });
84 }
85
86 let peers = self.peers.clone();
87 Box::pin(async move {
88 for peer in peers.iter() {
90 let mut p = peer.lock().await;
91 if p.capable::<T, C>(neuron.clone()) {
92 p.adapt::<T, C>(neuron.clone()).await?;
93 }
94 }
95 Ok(())
96 })
97 }
98}
99
100impl<G> GanglionInternal for Thalamus<G>
101where
102 G: GanglionInternal + Ganglion + Send + Sync + 'static,
103{
104 fn transmit(
105 &mut self,
106 payload: Arc<dyn PayloadErased + Send + Sync + 'static>,
107 ) -> Pin<Box<dyn Future<Output = Result<Vec<()>, GanglionError>> + Send + 'static>> {
108 let peers = self.peers.clone();
109 let start_idx = self.next_index();
110 let neuron_name = payload.get_neuron_name();
111 let ganglion_id = self.id;
112 let ganglion_name = struct_name_of_type::<Self>().to_string();
113
114 Box::pin(async move {
115 if peers.is_empty() {
116 return Err(GanglionError::SynapseNotFound {
117 neuron_name,
118 ganglion_name,
119 ganglion_id,
120 });
121 }
122
123 let mut last_err: Option<GanglionError> = None;
124 for off in 0..peers.len() {
126 let idx = (start_idx + off) % peers.len();
127 let peer = &peers[idx];
128 let future = {
129 let mut p = peer.lock().await;
130 p.transmit(payload.clone())
131 };
132 match future.await {
133 Ok(acks) => {
134 return Ok(acks);
136 }
137 Err(e) => {
138 last_err = Some(e);
140 continue;
141 }
142 }
143 }
144 Err(last_err.unwrap_or(GanglionError::Transmit {
145 neuron_name,
146 ganglion_name,
147 ganglion_id,
148 message: "No peers available or all transmissions failed".to_string(),
149 }))
150 })
151 }
152
153 fn react(
154 &mut self,
155 neuron_name: String,
156 reactants: Vec<Arc<dyn ReactantErased + Send + Sync + 'static>>,
157 error_reactants: Vec<Arc<dyn ErrorReactantErased + Send + Sync>>,
158 ) -> Pin<Box<dyn Future<Output = Result<(), GanglionError>> + Send + 'static>> {
159 let peers = self.peers.clone();
160 Box::pin(async move {
161 if peers.is_empty() {
162 return Err(GanglionError::SynapseNotFound {
163 neuron_name,
164 ganglion_name: struct_name_of_type::<Self>().to_string(),
165 ganglion_id: Uuid::nil(),
166 });
167 }
168
169 let mut at_least_one_ok = false;
170 let mut last_err: Option<GanglionError> = None;
171 for peer in peers.iter() {
172 let future = {
173 let mut p = peer.lock().await;
174 p.react(
175 neuron_name.clone(),
176 reactants.clone(),
177 error_reactants.clone(),
178 )
179 };
180 match future.await {
181 Ok(()) => at_least_one_ok = true,
182 Err(e) => last_err = Some(e),
183 }
184 }
185 if at_least_one_ok {
186 Ok(())
187 } else {
188 Err(last_err.unwrap_or(GanglionError::SynapseNotFound {
189 neuron_name,
190 ganglion_name: struct_name_of_type::<Self>().to_string(),
191 ganglion_id: Uuid::nil(),
192 }))
193 }
194 })
195 }
196
197 fn unique_id(&self) -> Uuid {
198 self.id
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::erasure::payload::erase_payload;
    use crate::erasure::reactant::erase_reactant;
    use crate::ganglion::GanglionInprocess;
    use crate::logging::TraceContext;
    use crate::neuron::NeuronImpl;
    use crate::payload::Payload;
    use crate::reactant::Reactant;
    use crate::test_utils::{
        DebugCodec, DebugStruct, ResponseCodec, ResponseStruct, TokioMpscReactant, test_namespace,
    };
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::Mutex;
    use tokio::sync::mpsc::{Receiver, channel};
    use uuid::Uuid;

    /// Upper bound on how long a test waits for any single message.
    ///
    /// The tests wait for concrete messages instead of sleeping a fixed time
    /// and then draining. That keeps them fast on quick machines and not
    /// flaky on slow ones.
    const RECV_TIMEOUT: Duration = Duration::from_secs(5);

    /// Receives exactly `expected` messages from `rx`.
    ///
    /// # Panics
    /// If the channel closes, or if any single message takes longer than
    /// `RECV_TIMEOUT` to arrive. `what` labels the panic message.
    async fn recv_exactly<T>(rx: &mut Receiver<T>, expected: usize, what: &str) -> Vec<T> {
        let mut received = Vec::with_capacity(expected);
        for n in 1..=expected {
            let message = tokio::time::timeout(RECV_TIMEOUT, rx.recv())
                .await
                .unwrap_or_else(|_| {
                    panic!("{what}: timed out waiting for message {n} of {expected}")
                })
                .unwrap_or_else(|| {
                    panic!("{what}: channel closed after {} of {expected} messages", n - 1)
                });
            received.push(message);
        }
        received
    }

    /// Panics if `rx` already holds a message beyond the expected ones.
    fn assert_no_extra<T>(rx: &mut Receiver<T>, what: &str) {
        assert!(
            rx.try_recv().is_err(),
            "{what}: received more messages than expected"
        );
    }

    /// Two payloads sent through a two-peer thalamus both reach the reactant
    /// that `Thalamus::react` registered on every peer.
    #[tokio::test]
    async fn test_thalamus_round_robin_basic() {
        let ns = test_namespace();
        let neuron: NeuronImpl<DebugStruct, DebugCodec> = NeuronImpl::new(ns.clone());
        let neuron_name = neuron.name();
        let neuron_arc: Arc<dyn Neuron<DebugStruct, DebugCodec> + Send + Sync> = Arc::new(neuron);

        let g1 = Arc::new(Mutex::new(GanglionInprocess::new()));
        let g2 = Arc::new(Mutex::new(GanglionInprocess::new()));

        let mut thalamus = Thalamus::new(vec![g1.clone(), g2.clone()]);
        thalamus
            .adapt::<DebugStruct, DebugCodec>(neuron_arc.clone())
            .await
            .unwrap();

        let (tx, mut rx) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(10);
        let reactants = vec![erase_reactant::<DebugStruct, DebugCodec, _>(Box::new(
            TokioMpscReactant::new(tx),
        ))];

        thalamus
            .react(neuron_name.clone(), reactants, vec![])
            .await
            .unwrap();

        let correlation_id = Uuid::now_v7();
        // Deliberately keeps only the low 64 bits of the correlation id.
        let span_id = correlation_id.as_u128() as u64;
        let payload1 = Arc::new(Payload::from_parts(
            Arc::new(DebugStruct {
                foo: 1,
                bar: "a".to_string(),
            }),
            neuron_arc.clone(),
            TraceContext::from_parts(correlation_id, span_id, None),
        ));
        let payload2 = Payload::new(
            DebugStruct {
                foo: 2,
                bar: "b".to_string(),
            },
            neuron_arc.clone(),
        );

        thalamus.transmit(erase_payload(payload1)).await.unwrap();
        thalamus.transmit(erase_payload(payload2)).await.unwrap();

        // The two payloads go through different peers, so arrival order is
        // not guaranteed; compare the delivered values sorted.
        let received = recv_exactly(&mut rx, 2, "round robin").await;
        let mut foos: Vec<_> = received.iter().map(|p| p.value.foo).collect();
        foos.sort_unstable();
        assert_eq!(foos, vec![1, 2]);
        assert_no_extra(&mut rx, "round robin");
    }

    /// Six transmissions over three peers land exactly twice on each peer.
    #[tokio::test]
    async fn test_thalamus_even_work_distribution() {
        let ns = test_namespace();
        let neuron: NeuronImpl<DebugStruct, DebugCodec> = NeuronImpl::new(ns.clone());
        let neuron_name = neuron.name();
        let neuron_arc: Arc<dyn Neuron<DebugStruct, DebugCodec> + Send + Sync> = Arc::new(neuron);

        let g1 = Arc::new(Mutex::new(GanglionInprocess::new()));
        let g2 = Arc::new(Mutex::new(GanglionInprocess::new()));
        let g3 = Arc::new(Mutex::new(GanglionInprocess::new()));

        let mut thalamus = Thalamus::new(vec![g1.clone(), g2.clone(), g3.clone()]);
        thalamus
            .adapt::<DebugStruct, DebugCodec>(neuron_arc.clone())
            .await
            .unwrap();

        let (tx1, mut rx1) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(10);
        let (tx2, mut rx2) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(10);
        let (tx3, mut rx3) = channel::<Arc<Payload<DebugStruct, DebugCodec>>>(10);

        // Register a distinct reactant directly on each peer so deliveries
        // can be attributed to the peer that handled them.
        {
            let mut g1_guard = g1.lock().await;
            let reactants1 = vec![erase_reactant::<DebugStruct, DebugCodec, _>(Box::new(
                TokioMpscReactant::new(tx1),
            ))];
            g1_guard
                .react(neuron_name.clone(), reactants1, vec![])
                .await
                .unwrap();
        }
        {
            let mut g2_guard = g2.lock().await;
            let reactants2 = vec![erase_reactant::<DebugStruct, DebugCodec, _>(Box::new(
                TokioMpscReactant::new(tx2),
            ))];
            g2_guard
                .react(neuron_name.clone(), reactants2, vec![])
                .await
                .unwrap();
        }
        {
            let mut g3_guard = g3.lock().await;
            let reactants3 = vec![erase_reactant::<DebugStruct, DebugCodec, _>(Box::new(
                TokioMpscReactant::new(tx3),
            ))];
            g3_guard
                .react(neuron_name.clone(), reactants3, vec![])
                .await
                .unwrap();
        }

        for i in 0..6 {
            let test_data = DebugStruct {
                foo: i,
                bar: format!("msg{i}"),
            };

            thalamus
                .transmit(erase_payload(Payload::new(test_data, neuron_arc.clone())))
                .await
                .expect("Failed to transmit");
        }

        // Exactly six messages were sent. If any peer got more than two,
        // another got fewer, and its `recv_exactly` times out.
        recv_exactly(&mut rx1, 2, "Ganglion 1 should receive 2 messages").await;
        recv_exactly(&mut rx2, 2, "Ganglion 2 should receive 2 messages").await;
        recv_exactly(&mut rx3, 2, "Ganglion 3 should receive 2 messages").await;
        assert_no_extra(&mut rx1, "Ganglion 1");
        assert_no_extra(&mut rx2, "Ganglion 2");
        assert_no_extra(&mut rx3, "Ganglion 3");
    }

    /// Requests balanced by the thalamus each produce one response from the
    /// peer that handled them, and the responses show an even split.
    #[tokio::test]
    async fn test_thalamus_work_distribution_with_responses() {
        let ns = test_namespace();

        let request_neuron: NeuronImpl<DebugStruct, DebugCodec> = NeuronImpl::new(ns.clone());
        let request_neuron_name = request_neuron.name();
        let request_neuron_arc: Arc<dyn Neuron<DebugStruct, DebugCodec> + Send + Sync> =
            Arc::new(request_neuron);

        let g1 = Arc::new(Mutex::new(GanglionInprocess::new()));
        let g2 = Arc::new(Mutex::new(GanglionInprocess::new()));
        let g3 = Arc::new(Mutex::new(GanglionInprocess::new()));

        let mut thalamus = Thalamus::new(vec![g1.clone(), g2.clone(), g3.clone()]);

        let response_neuron: NeuronImpl<ResponseStruct, ResponseCodec> =
            NeuronImpl::new(ns.clone());
        let response_neuron_name = response_neuron.name();
        let response_neuron_arc: Arc<dyn Neuron<ResponseStruct, ResponseCodec> + Send + Sync> =
            Arc::new(response_neuron);

        thalamus
            .adapt::<DebugStruct, DebugCodec>(request_neuron_arc.clone())
            .await
            .unwrap();
        thalamus
            .adapt::<ResponseStruct, ResponseCodec>(response_neuron_arc.clone())
            .await
            .unwrap();

        let (response_tx, mut response_rx) =
            channel::<Arc<Payload<ResponseStruct, ResponseCodec>>>(20);

        /// Forwards every response payload it sees into a test channel.
        #[derive(Clone)]
        struct ResponseCaptureReactant {
            sender: tokio::sync::mpsc::Sender<Arc<Payload<ResponseStruct, ResponseCodec>>>,
        }

        impl ResponseCaptureReactant {
            fn new(
                sender: tokio::sync::mpsc::Sender<Arc<Payload<ResponseStruct, ResponseCodec>>>,
            ) -> Self {
                Self { sender }
            }
        }

        impl Reactant<ResponseStruct, ResponseCodec> for ResponseCaptureReactant {
            fn react(
                &self,
                payload: Arc<Payload<ResponseStruct, ResponseCodec>>,
            ) -> Pin<
                Box<
                    dyn Future<Output = Result<(), crate::reactant::ReactantError>>
                        + Send
                        + 'static,
                >,
            > {
                let sender = self.sender.clone();
                let payload_clone = payload.clone();

                Box::pin(async move {
                    // Best-effort: a full channel shows up as a missing response.
                    let _ = sender.try_send(payload_clone);
                    Ok(())
                })
            }

            fn erase(self: Box<Self>) -> Arc<dyn ReactantErased + Send + Sync + 'static> {
                erase_reactant(self)
            }
        }

        let response_capture_reactant = ResponseCaptureReactant::new(response_tx.clone());

        /// Turns each request into a response tagged with this reactant's
        /// ganglion number and queues it for that ganglion to transmit.
        #[derive(Clone)]
        struct ResponseGeneratingReactant {
            ganglion_id: u32,
            response_neuron: Arc<dyn Neuron<ResponseStruct, ResponseCodec> + Send + Sync>,
            queue_sender: tokio::sync::mpsc::Sender<Arc<Payload<ResponseStruct, ResponseCodec>>>,
        }

        impl ResponseGeneratingReactant {
            fn new(
                ganglion_id: u32,
                response_neuron: Arc<dyn Neuron<ResponseStruct, ResponseCodec> + Send + Sync>,
                queue_sender: tokio::sync::mpsc::Sender<
                    Arc<Payload<ResponseStruct, ResponseCodec>>,
                >,
            ) -> Self {
                Self {
                    ganglion_id,
                    response_neuron,
                    queue_sender,
                }
            }
        }

        impl Reactant<DebugStruct, DebugCodec> for ResponseGeneratingReactant {
            fn react(
                &self,
                payload: Arc<Payload<DebugStruct, DebugCodec>>,
            ) -> Pin<
                Box<
                    dyn Future<Output = Result<(), crate::reactant::ReactantError>>
                        + Send
                        + 'static,
                >,
            > {
                let ganglion_id = self.ganglion_id;
                let response_neuron = self.response_neuron.clone();
                let queue_sender = self.queue_sender.clone();
                let original_value = payload.value.clone();

                Box::pin(async move {
                    let response_payload = Payload::new(
                        ResponseStruct {
                            ganglion_id,
                            response_message: format!(
                                "response_from_ganglion_{}_for_{}",
                                ganglion_id, original_value.bar
                            ),
                        },
                        response_neuron,
                    );

                    // Best-effort: a full queue shows up as a missing response.
                    let _ = queue_sender.try_send(response_payload);
                    Ok(())
                })
            }

            fn erase(self: Box<Self>) -> Arc<dyn ReactantErased + Send + Sync + 'static> {
                erase_reactant(self)
            }
        }

        let (queue1_tx, mut queue1_rx) = channel::<Arc<Payload<ResponseStruct, ResponseCodec>>>(10);
        let (queue2_tx, mut queue2_rx) = channel::<Arc<Payload<ResponseStruct, ResponseCodec>>>(10);
        let (queue3_tx, mut queue3_rx) = channel::<Arc<Payload<ResponseStruct, ResponseCodec>>>(10);

        let thalamus_arc = Arc::new(Mutex::new(thalamus));

        // Register the response capture on every peer through the thalamus.
        {
            let mut thalamus_guard = thalamus_arc.lock().await;
            let response_reactants = vec![erase_reactant::<ResponseStruct, ResponseCodec, _>(
                Box::new(response_capture_reactant),
            )];
            let future =
                thalamus_guard.react(response_neuron_name.clone(), response_reactants, vec![]);
            drop(thalamus_guard);
            future.await.unwrap();
        }

        // One forwarder per ganglion: responses generated on ganglion N are
        // transmitted back out through ganglion N itself.
        let g1_clone = g1.clone();
        tokio::spawn(async move {
            while let Some(payload) = queue1_rx.recv().await {
                let future = {
                    let mut ganglion_guard = g1_clone.lock().await;
                    ganglion_guard.transmit(erase_payload(payload))
                };
                let _ = future.await;
            }
        });

        let g2_clone = g2.clone();
        tokio::spawn(async move {
            while let Some(payload) = queue2_rx.recv().await {
                let future = {
                    let mut ganglion_guard = g2_clone.lock().await;
                    ganglion_guard.transmit(erase_payload(payload))
                };
                let _ = future.await;
            }
        });

        let g3_clone = g3.clone();
        tokio::spawn(async move {
            while let Some(payload) = queue3_rx.recv().await {
                let future = {
                    let mut ganglion_guard = g3_clone.lock().await;
                    ganglion_guard.transmit(erase_payload(payload))
                };
                let _ = future.await;
            }
        });

        {
            let mut g1_guard = g1.lock().await;
            let reactants1 = vec![erase_reactant::<DebugStruct, DebugCodec, _>(Box::new(
                ResponseGeneratingReactant::new(1, response_neuron_arc.clone(), queue1_tx),
            ))];
            let future = g1_guard.react(request_neuron_name.clone(), reactants1, vec![]);
            drop(g1_guard);
            future.await.unwrap();
        }
        {
            let mut g2_guard = g2.lock().await;
            let reactants2 = vec![erase_reactant::<DebugStruct, DebugCodec, _>(Box::new(
                ResponseGeneratingReactant::new(2, response_neuron_arc.clone(), queue2_tx),
            ))];
            let future = g2_guard.react(request_neuron_name.clone(), reactants2, vec![]);
            drop(g2_guard);
            future.await.unwrap();
        }
        {
            let mut g3_guard = g3.lock().await;
            let reactants3 = vec![erase_reactant::<DebugStruct, DebugCodec, _>(Box::new(
                ResponseGeneratingReactant::new(3, response_neuron_arc.clone(), queue3_tx),
            ))];
            let future = g3_guard.react(request_neuron_name.clone(), reactants3, vec![]);
            drop(g3_guard);
            future.await.unwrap();
        }

        for i in 0..6 {
            let payload = Payload::new(
                DebugStruct {
                    foo: i,
                    bar: format!("request_{i}"),
                },
                request_neuron_arc.clone(),
            );

            let future = {
                let mut thalamus_guard = thalamus_arc.lock().await;
                thalamus_guard.transmit(erase_payload(payload))
            };
            future.await.unwrap();
        }

        // Wait for the six expected responses instead of sleeping.
        let responses = recv_exactly(&mut response_rx, 6, "Total responses should be 6").await;
        assert_no_extra(&mut response_rx, "responses");

        let (mut count_g1, mut count_g2, mut count_g3) = (0, 0, 0);
        for payload in &responses {
            match payload.value.ganglion_id {
                1 => count_g1 += 1,
                2 => count_g2 += 1,
                3 => count_g3 += 1,
                other => panic!("Unexpected ganglion ID in response: {}", other),
            }
        }

        assert_eq!(count_g1, 2, "Should receive 2 responses from ganglion 1");
        assert_eq!(count_g2, 2, "Should receive 2 responses from ganglion 2");
        assert_eq!(count_g3, 2, "Should receive 2 responses from ganglion 3");
        assert_eq!(
            count_g1 + count_g2 + count_g3,
            6,
            "Total responses should be 6"
        );
    }
}