#![allow(clippy::expect_used)]

use std::collections::HashMap;
use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};

use super::{Brick, BrickAssertion, BrickBudget, BrickError, BrickResult, BrickVerification};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WorkerId(pub u64);

impl WorkerId {
    #[must_use]
    pub const fn new(id: u64) -> Self {
        Self(id)
    }

    #[must_use]
    pub const fn value(&self) -> u64 {
        self.0
    }
}

impl fmt::Display for WorkerId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "worker-{}", self.0)
    }
}

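/// Execution backend for a brick workload.
///
/// A minimal usage sketch (illustrative only; `Cpu` and `Simd` report as always
/// available, `Gpu` depends on the `gpu` cargo feature, `Remote` is disabled):
///
/// ```ignore
/// let backend = Backend::default();
/// assert_eq!(backend, Backend::Cpu);
/// assert!(Backend::Cpu.is_available());
/// assert!(Backend::Gpu.performance_estimate() > Backend::Cpu.performance_estimate());
/// ```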
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Backend {
    Cpu,
    Gpu,
    Remote,
    Simd,
}

impl Backend {
    #[must_use]
    pub fn is_available(&self) -> bool {
        match self {
            Self::Cpu | Self::Simd => true,
            Self::Gpu => cfg!(feature = "gpu"),
            Self::Remote => false,
        }
    }

    #[must_use]
    pub const fn performance_estimate(&self) -> u32 {
        match self {
            Self::Gpu => 100,
            Self::Simd => 50,
            Self::Cpu => 10,
            Self::Remote => 5,
        }
    }
}

impl Default for Backend {
    fn default() -> Self {
        Self::Cpu
    }
}

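/// Flat `f32` payload plus shape and string metadata handed to a brick.
///
/// Illustrative sketch of constructing an input (the metadata key/value here
/// are made up for the example):
///
/// ```ignore
/// let input = BrickInput::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])
///     .with_metadata("source", "microphone");
/// assert_eq!(input.element_count(), 4);
/// assert_eq!(input.size_bytes(), 16);
/// ```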
#[derive(Debug, Clone, Default)]
pub struct BrickInput {
    pub data: Vec<f32>,
    pub shape: Vec<usize>,
    pub metadata: HashMap<String, String>,
}

impl BrickInput {
    #[must_use]
    pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
        Self {
            data,
            shape,
            metadata: HashMap::new(),
        }
    }

    #[must_use]
    pub fn size_bytes(&self) -> usize {
        self.data.len() * std::mem::size_of::<f32>()
    }

    #[must_use]
    pub fn element_count(&self) -> usize {
        self.data.len()
    }

    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }
}

#[derive(Debug, Clone, Default)]
pub struct BrickOutput {
    pub data: Vec<f32>,
    pub shape: Vec<usize>,
    pub metrics: ExecutionMetrics,
}

impl BrickOutput {
    #[must_use]
    pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
        Self {
            data,
            shape,
            metrics: ExecutionMetrics::default(),
        }
    }

    #[must_use]
    pub fn size_bytes(&self) -> usize {
        self.data.len() * std::mem::size_of::<f32>()
    }
}

#[derive(Debug, Clone, Default)]
pub struct ExecutionMetrics {
    pub execution_time: Duration,
    pub backend: Backend,
    pub worker_id: Option<WorkerId>,
    pub transfer_time: Option<Duration>,
}

impl ExecutionMetrics {
    #[must_use]
    pub fn new(execution_time: Duration, backend: Backend) -> Self {
        Self {
            execution_time,
            backend,
            worker_id: None,
            transfer_time: None,
        }
    }
}

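/// Wrapper that attaches distribution hints (backend, data dependencies,
/// preferred worker) to any `Brick` while still implementing `Brick` itself.
///
/// Sketch of wrapping a brick, assuming `MyBrick` is some type implementing
/// `Brick` (hypothetical, not defined in this module):
///
/// ```ignore
/// let distributed = DistributedBrick::new(MyBrick::default())
///     .with_backend(Backend::Gpu)
///     .with_data_dependencies(vec!["encoder_weights".into()])
///     .with_preferred_worker(WorkerId::new(3));
/// let spec = distributed.to_task_spec();
/// assert_eq!(spec.backend, Backend::Gpu);
/// ```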
#[derive(Debug)]
pub struct DistributedBrick<B: Brick> {
    inner: B,
    backend: Backend,
    data_dependencies: Vec<String>,
    preferred_worker: Option<WorkerId>,
}

impl<B: Brick> DistributedBrick<B> {
    #[must_use]
    pub fn new(inner: B) -> Self {
        Self {
            inner,
            backend: Backend::default(),
            data_dependencies: Vec::new(),
            preferred_worker: None,
        }
    }

    #[must_use]
    pub fn with_backend(mut self, backend: Backend) -> Self {
        self.backend = backend;
        self
    }

    #[must_use]
    pub fn with_data_dependencies(mut self, deps: Vec<String>) -> Self {
        self.data_dependencies = deps;
        self
    }

    #[must_use]
    pub fn with_preferred_worker(mut self, worker: WorkerId) -> Self {
        self.preferred_worker = Some(worker);
        self
    }

    #[must_use]
    pub fn inner(&self) -> &B {
        &self.inner
    }

    pub fn inner_mut(&mut self) -> &mut B {
        &mut self.inner
    }

    #[must_use]
    pub fn backend(&self) -> Backend {
        self.backend
    }

    #[must_use]
    pub fn data_dependencies(&self) -> &[String] {
        &self.data_dependencies
    }

    #[must_use]
    pub fn preferred_worker(&self) -> Option<WorkerId> {
        self.preferred_worker
    }

    #[must_use]
    pub fn to_task_spec(&self) -> TaskSpec {
        TaskSpec {
            brick_name: self.inner.brick_name().to_string(),
            backend: self.backend,
            data_dependencies: self.data_dependencies.clone(),
            preferred_worker: self.preferred_worker,
        }
    }
}

impl<B: Brick> Brick for DistributedBrick<B> {
    fn brick_name(&self) -> &'static str {
        self.inner.brick_name()
    }

    fn assertions(&self) -> &[BrickAssertion] {
        self.inner.assertions()
    }

    fn budget(&self) -> BrickBudget {
        self.inner.budget()
    }

    fn verify(&self) -> BrickVerification {
        self.inner.verify()
    }

    fn to_html(&self) -> String {
        self.inner.to_html()
    }

    fn to_css(&self) -> String {
        self.inner.to_css()
    }
}

#[derive(Debug, Clone)]
pub struct TaskSpec {
    pub brick_name: String,
    pub backend: Backend,
    pub data_dependencies: Vec<String>,
    pub preferred_worker: Option<WorkerId>,
}

#[derive(Debug, Clone)]
pub struct DataLocation {
    pub key: String,
    pub workers: Vec<WorkerId>,
    pub size_bytes: usize,
    pub last_access: Instant,
}

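/// Tracks which workers hold which named data blobs so that tasks can be
/// scheduled close to their inputs.
///
/// Minimal sketch of recording locations and computing an affinity score
/// (the key names are illustrative):
///
/// ```ignore
/// let tracker = BrickDataTracker::new();
/// tracker.track_data("encoder_weights", WorkerId::new(1), 4096);
/// tracker.track_data("encoder_weights", WorkerId::new(2), 4096);
/// let affinity = tracker.calculate_affinity(&["encoder_weights".into()]);
/// assert!(affinity.contains_key(&WorkerId::new(1)));
/// ```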
#[derive(Debug)]
pub struct BrickDataTracker {
    locations: RwLock<HashMap<String, DataLocation>>,
}

impl Default for BrickDataTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl BrickDataTracker {
    #[must_use]
    pub fn new() -> Self {
        Self {
            locations: RwLock::new(HashMap::new()),
        }
    }

    pub fn track_data(&self, key: &str, worker_id: WorkerId, size_bytes: usize) {
        let mut locations = self.locations.write().expect("lock poisoned");
        locations
            .entry(key.to_string())
            .and_modify(|loc| {
                if !loc.workers.contains(&worker_id) {
                    loc.workers.push(worker_id);
                }
                loc.last_access = Instant::now();
            })
            .or_insert_with(|| DataLocation {
                key: key.to_string(),
                workers: vec![worker_id],
                size_bytes,
                last_access: Instant::now(),
            });
    }

    pub fn track_weights(&self, brick_name: &str, worker_id: WorkerId) {
        let key = format!("{}_weights", brick_name);
        self.track_data(&key, worker_id, 0);
    }

    pub fn remove_data(&self, key: &str, worker_id: WorkerId) {
        let mut locations = self.locations.write().expect("lock poisoned");
        if let Some(loc) = locations.get_mut(key) {
            loc.workers.retain(|w| *w != worker_id);
        }
    }

    #[must_use]
    pub fn get_workers_for_data(&self, key: &str) -> Vec<WorkerId> {
        let locations = self.locations.read().expect("lock poisoned");
        locations
            .get(key)
            .map_or(Vec::new(), |loc| loc.workers.clone())
    }

    pub fn calculate_affinity(&self, dependencies: &[String]) -> HashMap<WorkerId, f64> {
        let locations = self.locations.read().expect("lock poisoned");
        let mut affinity: HashMap<WorkerId, f64> = HashMap::new();

        for dep in dependencies {
            if let Some(loc) = locations.get(dep) {
                let score_per_worker = 1.0 / loc.workers.len() as f64;
                for worker in &loc.workers {
                    *affinity.entry(*worker).or_insert(0.0) += score_per_worker;
                }
            }
        }

        if !affinity.is_empty() {
            let max_score = affinity.values().cloned().fold(0.0_f64, f64::max);
            if max_score > 0.0 {
                for score in affinity.values_mut() {
                    *score /= max_score;
                }
            }
        }

        affinity
    }

    #[must_use]
    pub fn find_best_worker(&self, brick: &dyn Brick) -> Option<WorkerId> {
        let weights_key = format!("{}_weights", brick.brick_name());
        let workers = self.get_workers_for_data(&weights_key);
        workers.first().copied()
    }

    #[must_use]
    pub fn find_best_worker_for_distributed<B: Brick>(
        &self,
        brick: &DistributedBrick<B>,
    ) -> Option<WorkerId> {
        if let Some(preferred) = brick.preferred_worker() {
            return Some(preferred);
        }

        let affinity = self.calculate_affinity(brick.data_dependencies());
        affinity
            .into_iter()
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(worker, _)| worker)
    }

    #[must_use]
    pub fn total_data_size(&self) -> usize {
        let locations = self.locations.read().expect("lock poisoned");
        locations.values().map(|loc| loc.size_bytes).sum()
    }
}

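/// Chooses a `Backend` from element-count thresholds.
///
/// Sketch of the threshold logic with custom limits (the values here are
/// arbitrary for the example; the defaults live in `new()`):
///
/// ```ignore
/// let selector = BackendSelector::new()
///     .with_simd_threshold(100)
///     .with_gpu_threshold(1_000);
/// assert_eq!(selector.select(50, true), Backend::Cpu);
/// assert_eq!(selector.select(500, true), Backend::Simd);
/// assert_eq!(selector.select(5_000, true), Backend::Gpu);
/// assert_eq!(selector.select(5_000, false), Backend::Simd);
/// ```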
#[derive(Debug)]
pub struct BackendSelector {
    gpu_threshold: usize,
    simd_threshold: usize,
    cpu_max_threshold: usize,
}

impl Default for BackendSelector {
    fn default() -> Self {
        Self::new()
    }
}

impl BackendSelector {
    #[must_use]
    pub fn new() -> Self {
        Self {
            gpu_threshold: 1_000_000,
            simd_threshold: 10_000,
            cpu_max_threshold: 100_000_000,
        }
    }

    #[must_use]
    pub fn with_gpu_threshold(mut self, threshold: usize) -> Self {
        self.gpu_threshold = threshold;
        self
    }

    #[must_use]
    pub fn with_simd_threshold(mut self, threshold: usize) -> Self {
        self.simd_threshold = threshold;
        self
    }

    #[must_use]
    pub fn with_cpu_max_threshold(mut self, threshold: usize) -> Self {
        self.cpu_max_threshold = threshold;
        self
    }

    #[must_use]
    pub fn select(&self, element_count: usize, gpu_available: bool) -> Backend {
        if element_count > self.cpu_max_threshold && Backend::Remote.is_available() {
            return Backend::Remote;
        }

        if element_count >= self.gpu_threshold && gpu_available {
            return Backend::Gpu;
        }

        if element_count >= self.simd_threshold {
            return Backend::Simd;
        }

        Backend::Cpu
    }

    #[must_use]
    pub fn select_for_brick(
        &self,
        _brick_complexity: u32,
        input_size: usize,
        gpu_available: bool,
    ) -> Backend {
        self.select(input_size, gpu_available)
    }
}

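/// Executes bricks on the backend chosen by a `BackendSelector`, recording
/// `ExecutionMetrics` on the output.
///
/// Minimal sketch, assuming `my_brick` is some value implementing `Brick`
/// (hypothetical, not defined here):
///
/// ```ignore
/// let tracker = Arc::new(BrickDataTracker::new());
/// let executor = MultiBrickExecutor::new(tracker);
/// let input = BrickInput::new(vec![1.0, 2.0, 3.0], vec![3]);
/// let output = executor.execute(&my_brick, input)?;
/// assert_eq!(output.data.len(), 3);
/// ```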
#[derive(Debug)]
pub struct MultiBrickExecutor {
    selector: BackendSelector,
    gpu_available: bool,
    data_tracker: Arc<BrickDataTracker>,
}

impl MultiBrickExecutor {
    #[must_use]
    pub fn new(data_tracker: Arc<BrickDataTracker>) -> Self {
        Self {
            selector: BackendSelector::new(),
            gpu_available: cfg!(feature = "gpu"),
            data_tracker,
        }
    }

    #[must_use]
    pub fn with_selector(mut self, selector: BackendSelector) -> Self {
        self.selector = selector;
        self
    }

    #[must_use]
    pub fn with_gpu_available(mut self, available: bool) -> Self {
        self.gpu_available = available;
        self
    }

    pub fn execute(&self, brick: &dyn Brick, input: BrickInput) -> BrickResult<BrickOutput> {
        let start = Instant::now();

        let backend = self
            .selector
            .select(input.element_count(), self.gpu_available);

        let (output_data, output_shape) = match backend {
            Backend::Cpu => self.execute_cpu(brick, &input)?,
            Backend::Simd => self.execute_simd(brick, &input)?,
            Backend::Gpu => self.execute_gpu(brick, &input)?,
            Backend::Remote => self.execute_remote(brick, &input)?,
        };

        let execution_time = start.elapsed();

        let mut output = BrickOutput::new(output_data, output_shape);
        output.metrics = ExecutionMetrics::new(execution_time, backend);

        Ok(output)
    }

    pub fn execute_distributed<B: Brick>(
        &self,
        brick: &DistributedBrick<B>,
        input: BrickInput,
    ) -> BrickResult<BrickOutput> {
        let start = Instant::now();

        let backend = brick.backend();

        let worker_id = self.data_tracker.find_best_worker_for_distributed(brick);

        let (output_data, output_shape) = match backend {
            Backend::Cpu => self.execute_cpu(brick.inner(), &input)?,
            Backend::Simd => self.execute_simd(brick.inner(), &input)?,
            Backend::Gpu => self.execute_gpu(brick.inner(), &input)?,
            Backend::Remote => self.execute_remote(brick.inner(), &input)?,
        };

        let execution_time = start.elapsed();

        let mut output = BrickOutput::new(output_data, output_shape);
        output.metrics = ExecutionMetrics {
            execution_time,
            backend,
            worker_id,
            transfer_time: None,
        };

        Ok(output)
    }

    fn execute_cpu(
        &self,
        _brick: &dyn Brick,
        input: &BrickInput,
    ) -> BrickResult<(Vec<f32>, Vec<usize>)> {
        Ok((input.data.clone(), input.shape.clone()))
    }

    fn execute_simd(
        &self,
        _brick: &dyn Brick,
        input: &BrickInput,
    ) -> BrickResult<(Vec<f32>, Vec<usize>)> {
        Ok((input.data.clone(), input.shape.clone()))
    }

    fn execute_gpu(
        &self,
        _brick: &dyn Brick,
        input: &BrickInput,
    ) -> BrickResult<(Vec<f32>, Vec<usize>)> {
        if !self.gpu_available {
            return Err(BrickError::HtmlGenerationFailed {
                reason: "GPU not available".into(),
            });
        }
        Ok((input.data.clone(), input.shape.clone()))
    }

    fn execute_remote(
        &self,
        _brick: &dyn Brick,
        input: &BrickInput,
    ) -> BrickResult<(Vec<f32>, Vec<usize>)> {
        if !Backend::Remote.is_available() {
            return Err(BrickError::HtmlGenerationFailed {
                reason: "Distributed execution not available".into(),
            });
        }
        Ok((input.data.clone(), input.shape.clone()))
    }

    #[must_use]
    pub fn data_tracker(&self) -> &Arc<BrickDataTracker> {
        &self.data_tracker
    }
}

#[derive(Debug, Clone)]
pub enum BrickMessage {
    WeightUpdate {
        brick_name: String,
        weights: Vec<u8>,
        version: u64,
    },
    StateChange {
        brick_name: String,
        event: String,
    },
    ExecutionRequest {
        brick_name: String,
        input_key: String,
        request_id: u64,
    },
    ExecutionResult {
        request_id: u64,
        output_key: String,
        success: bool,
    },
}

#[derive(Debug)]
pub struct Subscription {
    topic: String,
    messages: Arc<RwLock<Vec<BrickMessage>>>,
}

impl Subscription {
    #[must_use]
    pub fn drain(&self) -> Vec<BrickMessage> {
        let mut messages = self.messages.write().expect("lock poisoned");
        std::mem::take(&mut *messages)
    }

    #[must_use]
    pub fn has_messages(&self) -> bool {
        let messages = self.messages.read().expect("lock poisoned");
        !messages.is_empty()
    }

    #[must_use]
    pub fn topic(&self) -> &str {
        &self.topic
    }
}

#[derive(Debug, Clone)]
pub struct WorkStealingTask {
    pub id: u64,
    pub spec: TaskSpec,
    pub input_key: String,
    pub priority: u32,
    pub created_at: Instant,
}

impl WorkStealingTask {
    #[must_use]
    pub fn new(id: u64, spec: TaskSpec, input_key: String) -> Self {
        Self {
            id,
            spec,
            input_key,
            priority: 0,
            created_at: Instant::now(),
        }
    }

    #[must_use]
    pub fn with_priority(mut self, priority: u32) -> Self {
        self.priority = priority;
        self
    }

    #[must_use]
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }
}

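/// Per-worker task queue: the owning worker pops the highest-priority task
/// from the front, while idle workers steal the lowest-priority task from
/// the back.
///
/// Illustrative sketch, assuming `spec` is a `TaskSpec` built elsewhere:
///
/// ```ignore
/// let queue = WorkerQueue::new(WorkerId::new(1));
/// queue.push(WorkStealingTask::new(0, spec, "input".into()).with_priority(5));
/// let next = queue.pop();     // owner takes the highest-priority task
/// let stolen = queue.steal(); // a thief takes from the back of the queue
/// ```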
#[derive(Debug)]
pub struct WorkerQueue {
    worker_id: WorkerId,
    local_queue: RwLock<Vec<WorkStealingTask>>,
    completed_count: AtomicU64,
    stolen_count: AtomicU64,
}

impl WorkerQueue {
    #[must_use]
    pub fn new(worker_id: WorkerId) -> Self {
        Self {
            worker_id,
            local_queue: RwLock::new(Vec::new()),
            completed_count: AtomicU64::new(0),
            stolen_count: AtomicU64::new(0),
        }
    }

    pub fn push(&self, task: WorkStealingTask) {
        let mut queue = self.local_queue.write().expect("lock poisoned");
        queue.push(task);
        queue.sort_by(|a, b| b.priority.cmp(&a.priority));
    }

    pub fn pop(&self) -> Option<WorkStealingTask> {
        let mut queue = self.local_queue.write().expect("lock poisoned");
        if queue.is_empty() {
            return None;
        }
        Some(queue.remove(0))
    }

    pub fn steal(&self) -> Option<WorkStealingTask> {
        let mut queue = self.local_queue.write().expect("lock poisoned");
        if queue.is_empty() {
            return None;
        }
        self.stolen_count.fetch_add(1, Ordering::Relaxed);
        queue.pop()
    }

    #[must_use]
    pub fn is_empty(&self) -> bool {
        let queue = self.local_queue.read().expect("lock poisoned");
        queue.is_empty()
    }

    #[must_use]
    pub fn len(&self) -> usize {
        let queue = self.local_queue.read().expect("lock poisoned");
        queue.len()
    }

    pub fn mark_completed(&self) {
        self.completed_count.fetch_add(1, Ordering::Relaxed);
    }

    #[must_use]
    pub fn worker_id(&self) -> WorkerId {
        self.worker_id
    }

    #[must_use]
    pub fn completed_count(&self) -> u64 {
        self.completed_count.load(Ordering::Relaxed)
    }

    #[must_use]
    pub fn stolen_count(&self) -> u64 {
        self.stolen_count.load(Ordering::Relaxed)
    }
}

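/// Locality-aware work-stealing scheduler: tasks are routed to the worker
/// with the highest data affinity (or the preferred worker), and idle workers
/// steal from the longest queues.
///
/// Minimal sketch of the submit/poll loop, assuming `spec` is a `TaskSpec`
/// built elsewhere and the key names are illustrative:
///
/// ```ignore
/// let tracker = Arc::new(BrickDataTracker::new());
/// let scheduler = WorkStealingScheduler::new(Arc::clone(&tracker));
/// scheduler.register_worker(WorkerId::new(1));
/// scheduler.register_worker(WorkerId::new(2));
/// let task_id = scheduler.submit(spec, "input_key".into());
/// if let Some(task) = scheduler.get_work(WorkerId::new(2)) {
///     // execute the task, then call mark_completed() on the worker's queue
/// }
/// ```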
#[derive(Debug)]
pub struct WorkStealingScheduler {
    queues: RwLock<HashMap<WorkerId, Arc<WorkerQueue>>>,
    data_tracker: Arc<BrickDataTracker>,
    task_counter: AtomicU64,
    submitted_count: AtomicU64,
}

impl WorkStealingScheduler {
    #[must_use]
    pub fn new(data_tracker: Arc<BrickDataTracker>) -> Self {
        Self {
            queues: RwLock::new(HashMap::new()),
            data_tracker,
            task_counter: AtomicU64::new(0),
            submitted_count: AtomicU64::new(0),
        }
    }

    pub fn register_worker(&self, worker_id: WorkerId) -> Arc<WorkerQueue> {
        let queue = Arc::new(WorkerQueue::new(worker_id));
        let mut queues = self.queues.write().expect("lock poisoned");
        queues.insert(worker_id, Arc::clone(&queue));
        queue
    }

    pub fn unregister_worker(&self, worker_id: WorkerId) {
        let mut queues = self.queues.write().expect("lock poisoned");
        queues.remove(&worker_id);
    }

    pub fn submit(&self, spec: TaskSpec, input_key: String) -> u64 {
        let task_id = self.task_counter.fetch_add(1, Ordering::SeqCst);
        let task = WorkStealingTask::new(task_id, spec.clone(), input_key);

        let target_worker = self.find_best_worker_for_task(&spec);

        let queues = self.queues.read().expect("lock poisoned");
        if let Some(queue) = target_worker.and_then(|w| queues.get(&w)) {
            queue.push(task);
        } else if let Some((_, queue)) = queues.iter().next() {
            queue.push(task);
        }

        self.submitted_count.fetch_add(1, Ordering::Relaxed);
        task_id
    }

    pub fn submit_priority(&self, spec: TaskSpec, input_key: String, priority: u32) -> u64 {
        let task_id = self.task_counter.fetch_add(1, Ordering::SeqCst);
        let task = WorkStealingTask::new(task_id, spec.clone(), input_key).with_priority(priority);

        let target_worker = self.find_best_worker_for_task(&spec);

        let queues = self.queues.read().expect("lock poisoned");
        if let Some(queue) = target_worker.and_then(|w| queues.get(&w)) {
            queue.push(task);
        } else if let Some((_, queue)) = queues.iter().next() {
            queue.push(task);
        }

        self.submitted_count.fetch_add(1, Ordering::Relaxed);
        task_id
    }

    pub fn get_work(&self, worker_id: WorkerId) -> Option<WorkStealingTask> {
        let queues = self.queues.read().expect("lock poisoned");

        if let Some(queue) = queues.get(&worker_id) {
            if let Some(task) = queue.pop() {
                return Some(task);
            }
        }

        self.try_steal(worker_id, &queues)
    }

    fn try_steal(
        &self,
        stealer_id: WorkerId,
        queues: &HashMap<WorkerId, Arc<WorkerQueue>>,
    ) -> Option<WorkStealingTask> {
        let mut candidates: Vec<_> = queues
            .iter()
            .filter(|(id, q)| **id != stealer_id && !q.is_empty())
            .collect();

        if candidates.is_empty() {
            return None;
        }

        candidates.sort_by(|a, b| b.1.len().cmp(&a.1.len()));

        for (_, queue) in candidates {
            if let Some(task) = queue.steal() {
                return Some(task);
            }
        }

        None
    }

    fn find_best_worker_for_task(&self, spec: &TaskSpec) -> Option<WorkerId> {
        if let Some(preferred) = spec.preferred_worker {
            return Some(preferred);
        }

        let affinity = self
            .data_tracker
            .calculate_affinity(&spec.data_dependencies);
        affinity
            .into_iter()
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(worker, _)| worker)
    }

    #[must_use]
    pub fn stats(&self) -> SchedulerStats {
        let queues = self.queues.read().expect("lock poisoned");

        let worker_stats: Vec<_> = queues
            .values()
            .map(|q| WorkerStats {
                worker_id: q.worker_id(),
                queue_length: q.len(),
                completed: q.completed_count(),
                stolen_from: q.stolen_count(),
            })
            .collect();

        let total_pending: usize = worker_stats.iter().map(|s| s.queue_length).sum();
        let total_completed: u64 = worker_stats.iter().map(|s| s.completed).sum();
        let total_stolen: u64 = worker_stats.iter().map(|s| s.stolen_from).sum();

        SchedulerStats {
            worker_count: queues.len(),
            total_submitted: self.submitted_count.load(Ordering::Relaxed),
            total_pending,
            total_completed,
            total_stolen,
            workers: worker_stats,
        }
    }

    #[must_use]
    pub fn data_tracker(&self) -> &Arc<BrickDataTracker> {
        &self.data_tracker
    }
}

#[derive(Debug, Clone)]
pub struct WorkerStats {
    pub worker_id: WorkerId,
    pub queue_length: usize,
    pub completed: u64,
    pub stolen_from: u64,
}

#[derive(Debug, Clone)]
pub struct SchedulerStats {
    pub worker_count: usize,
    pub total_submitted: u64,
    pub total_pending: usize,
    pub total_completed: u64,
    pub total_stolen: u64,
    pub workers: Vec<WorkerStats>,
}

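/// In-process publish/subscribe hub used to fan out weight updates and state
/// changes between bricks.
///
/// Sketch of the topic conventions used below (`brick/<name>/events` and
/// `brick/<name>/weights`):
///
/// ```ignore
/// let coordinator = BrickCoordinator::new();
/// let sub = coordinator.subscribe_brick("Encoder");
/// coordinator.broadcast_state_change("Encoder", "loaded");
/// assert!(sub.has_messages());
/// for msg in sub.drain() {
///     // react to BrickMessage::StateChange, WeightUpdate, ...
/// }
/// ```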
#[derive(Debug)]
pub struct BrickCoordinator {
    subscriptions: RwLock<HashMap<String, Vec<Arc<RwLock<Vec<BrickMessage>>>>>>,
    message_counter: AtomicU64,
}

impl Default for BrickCoordinator {
    fn default() -> Self {
        Self::new()
    }
}

impl BrickCoordinator {
    #[must_use]
    pub fn new() -> Self {
        Self {
            subscriptions: RwLock::new(HashMap::new()),
            message_counter: AtomicU64::new(0),
        }
    }

    #[must_use]
    pub fn subscribe(&self, topic: &str) -> Subscription {
        let messages = Arc::new(RwLock::new(Vec::new()));
        {
            let mut subs = self.subscriptions.write().expect("lock poisoned");
            subs.entry(topic.to_string())
                .or_default()
                .push(Arc::clone(&messages));
        }
        Subscription {
            topic: topic.to_string(),
            messages,
        }
    }

    #[must_use]
    pub fn subscribe_brick(&self, brick_name: &str) -> Subscription {
        let topic = format!("brick/{}/events", brick_name);
        self.subscribe(&topic)
    }

    pub fn publish(&self, topic: &str, message: BrickMessage) {
        let subs = self.subscriptions.read().expect("lock poisoned");
        if let Some(subscribers) = subs.get(topic) {
            for sub in subscribers {
                let mut messages = sub.write().expect("lock poisoned");
                messages.push(message.clone());
            }
        }
    }

    pub fn broadcast_weights(&self, brick_name: &str, weights: Vec<u8>) {
        let topic = format!("brick/{}/weights", brick_name);
        let version = self.message_counter.fetch_add(1, Ordering::SeqCst);
        self.publish(
            &topic,
            BrickMessage::WeightUpdate {
                brick_name: brick_name.to_string(),
                weights,
                version,
            },
        );
    }

    pub fn broadcast_state_change(&self, brick_name: &str, event: &str) {
        let topic = format!("brick/{}/events", brick_name);
        self.publish(
            &topic,
            BrickMessage::StateChange {
                brick_name: brick_name.to_string(),
                event: event.to_string(),
            },
        );
    }

    #[must_use]
    pub fn next_request_id(&self) -> u64 {
        self.message_counter.fetch_add(1, Ordering::SeqCst)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    struct TestBrick {
        name: &'static str,
    }

    impl Brick for TestBrick {
        fn brick_name(&self) -> &'static str {
            self.name
        }

        fn assertions(&self) -> &[BrickAssertion] {
            &[BrickAssertion::TextVisible]
        }

        fn budget(&self) -> BrickBudget {
            BrickBudget::uniform(16)
        }

        fn verify(&self) -> BrickVerification {
            BrickVerification {
                passed: vec![BrickAssertion::TextVisible],
                failed: vec![],
                verification_time: Duration::from_micros(100),
            }
        }

        fn to_html(&self) -> String {
            format!("<div>{}</div>", self.name)
        }

        fn to_css(&self) -> String {
            ".test { }".into()
        }
    }

    #[test]
    fn test_worker_id() {
        let id = WorkerId::new(42);
        assert_eq!(id.value(), 42);
        assert_eq!(format!("{id}"), "worker-42");
    }

    #[test]
    fn test_backend_availability() {
        assert!(Backend::Cpu.is_available());
        assert!(Backend::Simd.is_available());
    }

    #[test]
    fn test_backend_performance() {
        assert!(Backend::Gpu.performance_estimate() > Backend::Simd.performance_estimate());
        assert!(Backend::Simd.performance_estimate() > Backend::Cpu.performance_estimate());
    }

    #[test]
    fn test_distributed_brick_creation() {
        let inner = TestBrick { name: "Test" };
        let distributed = DistributedBrick::new(inner)
            .with_backend(Backend::Gpu)
            .with_data_dependencies(vec!["weights".into(), "biases".into()])
            .with_preferred_worker(WorkerId::new(1));

        assert_eq!(distributed.backend(), Backend::Gpu);
        assert_eq!(distributed.data_dependencies().len(), 2);
        assert_eq!(distributed.preferred_worker(), Some(WorkerId::new(1)));
        assert_eq!(distributed.brick_name(), "Test");
    }

    #[test]
    fn test_distributed_brick_implements_brick() {
        let inner = TestBrick { name: "Test" };
        let distributed = DistributedBrick::new(inner);

        assert!(distributed.verify().is_valid());
        assert_eq!(distributed.budget().total_ms, 16);
    }

    #[test]
    fn test_task_spec() {
        let inner = TestBrick { name: "TestTask" };
        let distributed = DistributedBrick::new(inner)
            .with_backend(Backend::Simd)
            .with_data_dependencies(vec!["model".into()]);

        let spec = distributed.to_task_spec();
        assert_eq!(spec.brick_name, "TestTask");
        assert_eq!(spec.backend, Backend::Simd);
        assert_eq!(spec.data_dependencies, vec!["model"]);
    }

    #[test]
    fn test_brick_input_output() {
        let input = BrickInput::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        assert_eq!(input.element_count(), 4);
        assert_eq!(input.size_bytes(), 16);

        let output = BrickOutput::new(vec![5.0, 6.0], vec![2]);
        assert_eq!(output.size_bytes(), 8);
    }

    #[test]
    fn test_data_tracker() {
        let tracker = BrickDataTracker::new();

        tracker.track_data("model_weights", WorkerId::new(1), 1024);
        tracker.track_data("model_weights", WorkerId::new(2), 1024);
        tracker.track_data("biases", WorkerId::new(1), 256);

        let workers = tracker.get_workers_for_data("model_weights");
        assert_eq!(workers.len(), 2);

        let affinity = tracker.calculate_affinity(&["model_weights".into(), "biases".into()]);
        assert!(affinity.get(&WorkerId::new(1)).unwrap_or(&0.0) > &0.0);
    }

    #[test]
    fn test_data_tracker_find_best_worker() {
        let tracker = BrickDataTracker::new();

        let brick = TestBrick { name: "MelBrick" };
        tracker.track_weights("MelBrick", WorkerId::new(5));

        let best = tracker.find_best_worker(&brick);
        assert_eq!(best, Some(WorkerId::new(5)));
    }

    #[test]
    fn test_backend_selector() {
        let selector = BackendSelector::new()
            .with_gpu_threshold(1000)
            .with_simd_threshold(100);

        assert_eq!(selector.select(50, true), Backend::Cpu);

        assert_eq!(selector.select(500, true), Backend::Simd);

        assert_eq!(selector.select(5000, true), Backend::Gpu);

        assert_eq!(selector.select(5000, false), Backend::Simd);
    }

    #[test]
    fn test_multi_executor() {
        let tracker = Arc::new(BrickDataTracker::new());
        let executor = MultiBrickExecutor::new(tracker);

        let brick = TestBrick { name: "Test" };
        let input = BrickInput::new(vec![1.0, 2.0, 3.0], vec![3]);

        let result = executor.execute(&brick, input);
        assert!(result.is_ok());

        let output = result.expect("execution should succeed");
        assert_eq!(output.data.len(), 3);
        assert!(output.metrics.execution_time >= Duration::ZERO);
    }

    #[test]
    fn test_brick_coordinator() {
        let coordinator = BrickCoordinator::new();

        let sub = coordinator.subscribe_brick("MyBrick");

        coordinator.broadcast_state_change("MyBrick", "loaded");

        assert!(sub.has_messages());
        let messages = sub.drain();
        assert_eq!(messages.len(), 1);
        assert!(
            matches!(&messages[0], BrickMessage::StateChange { brick_name, .. } if brick_name == "MyBrick")
        );
    }

    #[test]
    fn test_coordinator_weight_broadcast() {
        let coordinator = BrickCoordinator::new();

        let sub = coordinator.subscribe("brick/Encoder/weights");
        coordinator.broadcast_weights("Encoder", vec![1, 2, 3, 4]);

        let messages = sub.drain();
        assert_eq!(messages.len(), 1);
        match &messages[0] {
            BrickMessage::WeightUpdate {
                brick_name,
                weights,
                version,
            } => {
                assert_eq!(brick_name, "Encoder");
                assert_eq!(weights, &vec![1, 2, 3, 4]);
                assert_eq!(*version, 0);
            }
            _ => panic!("Expected WeightUpdate message"),
        }
    }

    #[test]
    fn test_subscription_topic() {
        let coordinator = BrickCoordinator::new();
        let sub = coordinator.subscribe("my/topic");
        assert_eq!(sub.topic(), "my/topic");
    }

    #[test]
    fn test_execution_metrics() {
        let metrics = ExecutionMetrics::new(Duration::from_millis(50), Backend::Gpu);
        assert_eq!(metrics.execution_time, Duration::from_millis(50));
        assert_eq!(metrics.backend, Backend::Gpu);
        assert!(metrics.worker_id.is_none());
    }

    #[test]
    fn test_work_stealing_task() {
        let spec = TaskSpec {
            brick_name: "TestBrick".into(),
            backend: Backend::Cpu,
            data_dependencies: vec![],
            preferred_worker: None,
        };
        let task = WorkStealingTask::new(1, spec, "input_key".into()).with_priority(10);

        assert_eq!(task.id, 1);
        assert_eq!(task.priority, 10);
        assert_eq!(task.input_key, "input_key");
        assert!(task.age() >= Duration::ZERO);
    }

    #[test]
    fn test_worker_queue_basic() {
        let queue = WorkerQueue::new(WorkerId::new(1));

        assert!(queue.is_empty());
        assert_eq!(queue.len(), 0);

        let spec = TaskSpec {
            brick_name: "Test".into(),
            backend: Backend::Cpu,
            data_dependencies: vec![],
            preferred_worker: None,
        };
        let task = WorkStealingTask::new(1, spec, "key".into());
        queue.push(task);

        assert!(!queue.is_empty());
        assert_eq!(queue.len(), 1);

        let popped = queue.pop();
        assert!(popped.is_some());
        assert!(queue.is_empty());
    }

    #[test]
    fn test_worker_queue_priority_ordering() {
        let queue = WorkerQueue::new(WorkerId::new(1));

        for i in 0..5 {
            let spec = TaskSpec {
                brick_name: format!("Task{}", i),
                backend: Backend::Cpu,
                data_dependencies: vec![],
                preferred_worker: None,
            };
            let task = WorkStealingTask::new(i as u64, spec, "key".into()).with_priority(i);
            queue.push(task);
        }

        let task = queue.pop().unwrap();
        assert_eq!(task.priority, 4);

        let task = queue.pop().unwrap();
        assert_eq!(task.priority, 3);
    }

    #[test]
    fn test_worker_queue_steal() {
        let queue = WorkerQueue::new(WorkerId::new(1));

        for i in 0..3 {
            let spec = TaskSpec {
                brick_name: format!("Task{}", i),
                backend: Backend::Cpu,
                data_dependencies: vec![],
                preferred_worker: None,
            };
            let task = WorkStealingTask::new(i as u64, spec, "key".into()).with_priority(i);
            queue.push(task);
        }

        let stolen = queue.steal().unwrap();
        assert_eq!(stolen.priority, 0);
        assert_eq!(queue.stolen_count(), 1);

        assert_eq!(queue.len(), 2);
    }

    #[test]
    fn test_work_stealing_scheduler_basic() {
        let tracker = Arc::new(BrickDataTracker::new());
        let scheduler = WorkStealingScheduler::new(tracker);

        let _q1 = scheduler.register_worker(WorkerId::new(1));
        let _q2 = scheduler.register_worker(WorkerId::new(2));

        let stats = scheduler.stats();
        assert_eq!(stats.worker_count, 2);
        assert_eq!(stats.total_submitted, 0);
    }

    #[test]
    fn test_work_stealing_scheduler_submit() {
        let tracker = Arc::new(BrickDataTracker::new());
        let scheduler = WorkStealingScheduler::new(tracker);

        scheduler.register_worker(WorkerId::new(1));

        let spec = TaskSpec {
            brick_name: "Test".into(),
            backend: Backend::Cpu,
            data_dependencies: vec![],
            preferred_worker: None,
        };

        let task_id = scheduler.submit(spec, "input".into());
        assert_eq!(task_id, 0);

        let stats = scheduler.stats();
        assert_eq!(stats.total_submitted, 1);
        assert_eq!(stats.total_pending, 1);
    }

    #[test]
    fn test_work_stealing_scheduler_get_work() {
        let tracker = Arc::new(BrickDataTracker::new());
        let scheduler = WorkStealingScheduler::new(tracker);

        scheduler.register_worker(WorkerId::new(1));
        scheduler.register_worker(WorkerId::new(2));

        let spec = TaskSpec {
            brick_name: "Test".into(),
            backend: Backend::Cpu,
            data_dependencies: vec![],
            preferred_worker: Some(WorkerId::new(1)),
        };
        scheduler.submit(spec, "input".into());

        let task = scheduler.get_work(WorkerId::new(1));
        assert!(task.is_some());

        let task = scheduler.get_work(WorkerId::new(2));
        assert!(task.is_none());
    }

    #[test]
    fn test_work_stealing_scheduler_steal() {
        let tracker = Arc::new(BrickDataTracker::new());
        let scheduler = WorkStealingScheduler::new(tracker);

        scheduler.register_worker(WorkerId::new(1));
        scheduler.register_worker(WorkerId::new(2));

        for i in 0..3 {
            let spec = TaskSpec {
                brick_name: format!("Task{}", i),
                backend: Backend::Cpu,
                data_dependencies: vec![],
                preferred_worker: Some(WorkerId::new(1)),
            };
            scheduler.submit(spec, format!("input{}", i));
        }

        let stolen = scheduler.get_work(WorkerId::new(2));
        assert!(stolen.is_some());

        let stats = scheduler.stats();
        assert_eq!(stats.total_stolen, 1);
        assert_eq!(stats.total_pending, 2);
    }

    #[test]
    fn test_work_stealing_scheduler_locality() {
        let tracker = Arc::new(BrickDataTracker::new());

        tracker.track_data("model_weights", WorkerId::new(1), 1024);

        let scheduler = WorkStealingScheduler::new(Arc::clone(&tracker));
        scheduler.register_worker(WorkerId::new(1));
        scheduler.register_worker(WorkerId::new(2));

        let spec = TaskSpec {
            brick_name: "MelBrick".into(),
            backend: Backend::Cpu,
            data_dependencies: vec!["model_weights".into()],
            preferred_worker: None,
        };
        scheduler.submit(spec, "audio_input".into());

        let task = scheduler.get_work(WorkerId::new(1));
        assert!(task.is_some());
        assert_eq!(task.unwrap().spec.brick_name, "MelBrick");
    }

    #[test]
    fn test_scheduler_stats() {
        let tracker = Arc::new(BrickDataTracker::new());
        let scheduler = WorkStealingScheduler::new(tracker);

        scheduler.register_worker(WorkerId::new(1));
        scheduler.register_worker(WorkerId::new(2));

        for i in 0..5 {
            let spec = TaskSpec {
                brick_name: format!("Task{}", i),
                backend: Backend::Cpu,
                data_dependencies: vec![],
                preferred_worker: if i % 2 == 0 {
                    Some(WorkerId::new(1))
                } else {
                    Some(WorkerId::new(2))
                },
            };
            scheduler.submit(spec, format!("input{}", i));
        }

        let stats = scheduler.stats();
        assert_eq!(stats.worker_count, 2);
        assert_eq!(stats.total_submitted, 5);
        assert_eq!(stats.total_pending, 5);
        assert_eq!(stats.workers.len(), 2);
    }
}