1#![allow(clippy::expect_used)]
29
30use std::collections::HashMap;
31use std::fmt;
32use std::sync::atomic::{AtomicU64, Ordering};
33use std::sync::{Arc, RwLock};
34use std::time::{Duration, Instant};
35
36use super::{Brick, BrickAssertion, BrickBudget, BrickError, BrickResult, BrickVerification};
37
/// Identifier of a worker that can hold data and receive tasks.
///
/// A thin newtype over `u64`. Its [`fmt::Display`] form is `worker-<id>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WorkerId(pub u64);

impl WorkerId {
    /// Wraps a raw numeric id.
    #[must_use]
    pub const fn new(id: u64) -> Self {
        WorkerId(id)
    }

    /// Returns the raw numeric id.
    #[must_use]
    pub const fn value(&self) -> u64 {
        let WorkerId(raw) = *self;
        raw
    }
}

impl fmt::Display for WorkerId {
    /// Writes the id as `worker-<id>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let raw = self.0;
        write!(f, "worker-{}", raw)
    }
}
61
/// Execution backend a brick can be dispatched to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Backend {
    /// Plain CPU execution; this is the [`Default`] backend.
    Cpu,
    /// GPU execution; reported available only when the `gpu` feature is enabled.
    Gpu,
    /// Execution on a remote worker; currently never reported as available.
    Remote,
    /// SIMD-accelerated CPU execution.
    Simd,
}

impl Backend {
    /// Reports whether this backend can be used in the current build.
    ///
    /// CPU and SIMD are always available, GPU depends on the `gpu` cargo
    /// feature, and remote execution is always reported as unavailable.
    #[must_use]
    pub fn is_available(&self) -> bool {
        match *self {
            Self::Remote => false,
            Self::Gpu => cfg!(feature = "gpu"),
            Self::Simd | Self::Cpu => true,
        }
    }

    /// Returns a fixed relative score for this backend.
    ///
    /// The scores are GPU 100, SIMD 50, CPU 10 and remote 5.
    #[must_use]
    pub const fn performance_estimate(&self) -> u32 {
        match *self {
            Self::Remote => 5,
            Self::Cpu => 10,
            Self::Simd => 50,
            Self::Gpu => 100,
        }
    }
}

impl Default for Backend {
    /// The default backend is [`Backend::Cpu`].
    fn default() -> Self {
        Backend::Cpu
    }
}
104
/// Input handed to a brick for execution.
#[derive(Debug, Clone, Default)]
pub struct BrickInput {
    /// Flat `f32` payload.
    pub data: Vec<f32>,
    /// Dimensions supplied by the caller; not validated against `data` here.
    pub shape: Vec<usize>,
    /// Free-form string key/value annotations.
    pub metadata: HashMap<String, String>,
}

impl BrickInput {
    /// Creates an input from `data` and `shape` with no metadata.
    #[must_use]
    pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
        let metadata = HashMap::new();
        Self {
            data,
            shape,
            metadata,
        }
    }

    /// Size of the `data` payload in bytes (element count times `size_of::<f32>()`).
    #[must_use]
    pub fn size_bytes(&self) -> usize {
        std::mem::size_of_val(self.data.as_slice())
    }

    /// Number of `f32` elements in `data`.
    #[must_use]
    pub fn element_count(&self) -> usize {
        self.data.len()
    }

    /// Adds (or overwrites) one metadata entry and returns the updated input.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, value) = (key.into(), value.into());
        self.metadata.insert(key, value);
        self
    }
}
145
146#[derive(Debug, Clone, Default)]
148pub struct BrickOutput {
149 pub data: Vec<f32>,
151 pub shape: Vec<usize>,
153 pub metrics: ExecutionMetrics,
155}
156
157impl BrickOutput {
158 #[must_use]
160 pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
161 Self {
162 data,
163 shape,
164 metrics: ExecutionMetrics::default(),
165 }
166 }
167
168 #[must_use]
170 pub fn size_bytes(&self) -> usize {
171 self.data.len() * std::mem::size_of::<f32>()
172 }
173}
174
175#[derive(Debug, Clone, Default)]
177pub struct ExecutionMetrics {
178 pub execution_time: Duration,
180 pub backend: Backend,
182 pub worker_id: Option<WorkerId>,
184 pub transfer_time: Option<Duration>,
186}
187
188impl ExecutionMetrics {
189 #[must_use]
191 pub fn new(execution_time: Duration, backend: Backend) -> Self {
192 Self {
193 execution_time,
194 backend,
195 worker_id: None,
196 transfer_time: None,
197 }
198 }
199}
200
201#[derive(Debug)]
208pub struct DistributedBrick<B: Brick> {
209 inner: B,
210 backend: Backend,
211 data_dependencies: Vec<String>,
212 preferred_worker: Option<WorkerId>,
213}
214
215impl<B: Brick> DistributedBrick<B> {
216 #[must_use]
218 pub fn new(inner: B) -> Self {
219 Self {
220 inner,
221 backend: Backend::default(),
222 data_dependencies: Vec::new(),
223 preferred_worker: None,
224 }
225 }
226
227 #[must_use]
229 pub fn with_backend(mut self, backend: Backend) -> Self {
230 self.backend = backend;
231 self
232 }
233
234 #[must_use]
236 pub fn with_data_dependencies(mut self, deps: Vec<String>) -> Self {
237 self.data_dependencies = deps;
238 self
239 }
240
241 #[must_use]
243 pub fn with_preferred_worker(mut self, worker: WorkerId) -> Self {
244 self.preferred_worker = Some(worker);
245 self
246 }
247
248 #[must_use]
250 pub fn inner(&self) -> &B {
251 &self.inner
252 }
253
254 pub fn inner_mut(&mut self) -> &mut B {
256 &mut self.inner
257 }
258
259 #[must_use]
261 pub fn backend(&self) -> Backend {
262 self.backend
263 }
264
265 #[must_use]
267 pub fn data_dependencies(&self) -> &[String] {
268 &self.data_dependencies
269 }
270
271 #[must_use]
273 pub fn preferred_worker(&self) -> Option<WorkerId> {
274 self.preferred_worker
275 }
276
277 #[must_use]
279 pub fn to_task_spec(&self) -> TaskSpec {
280 TaskSpec {
281 brick_name: self.inner.brick_name().to_string(),
282 backend: self.backend,
283 data_dependencies: self.data_dependencies.clone(),
284 preferred_worker: self.preferred_worker,
285 }
286 }
287}
288
289impl<B: Brick> Brick for DistributedBrick<B> {
290 fn brick_name(&self) -> &'static str {
291 self.inner.brick_name()
292 }
293
294 fn assertions(&self) -> &[BrickAssertion] {
295 self.inner.assertions()
296 }
297
298 fn budget(&self) -> BrickBudget {
299 self.inner.budget()
300 }
301
302 fn verify(&self) -> BrickVerification {
303 self.inner.verify()
304 }
305
306 fn to_html(&self) -> String {
307 self.inner.to_html()
308 }
309
310 fn to_css(&self) -> String {
311 self.inner.to_css()
312 }
313}
314
315#[derive(Debug, Clone)]
317pub struct TaskSpec {
318 pub brick_name: String,
320 pub backend: Backend,
322 pub data_dependencies: Vec<String>,
324 pub preferred_worker: Option<WorkerId>,
326}
327
328#[derive(Debug, Clone)]
330pub struct DataLocation {
331 pub key: String,
333 pub workers: Vec<WorkerId>,
335 pub size_bytes: usize,
337 pub last_access: Instant,
339}
340
341#[derive(Debug)]
345pub struct BrickDataTracker {
346 locations: RwLock<HashMap<String, DataLocation>>,
348}
349
350impl Default for BrickDataTracker {
351 fn default() -> Self {
352 Self::new()
353 }
354}
355
356impl BrickDataTracker {
357 #[must_use]
359 pub fn new() -> Self {
360 Self {
361 locations: RwLock::new(HashMap::new()),
362 }
363 }
364
365 pub fn track_data(&self, key: &str, worker_id: WorkerId, size_bytes: usize) {
367 let mut locations = self.locations.write().expect("lock poisoned");
368 locations
369 .entry(key.to_string())
370 .and_modify(|loc| {
371 if !loc.workers.contains(&worker_id) {
372 loc.workers.push(worker_id);
373 }
374 loc.last_access = Instant::now();
375 })
376 .or_insert_with(|| DataLocation {
377 key: key.to_string(),
378 workers: vec![worker_id],
379 size_bytes,
380 last_access: Instant::now(),
381 });
382 }
383
384 pub fn track_weights(&self, brick_name: &str, worker_id: WorkerId) {
386 let key = format!("{}_weights", brick_name);
387 self.track_data(&key, worker_id, 0);
388 }
389
390 pub fn remove_data(&self, key: &str, worker_id: WorkerId) {
392 let mut locations = self.locations.write().expect("lock poisoned");
393 if let Some(loc) = locations.get_mut(key) {
394 loc.workers.retain(|w| *w != worker_id);
395 }
396 }
397
398 #[must_use]
400 pub fn get_workers_for_data(&self, key: &str) -> Vec<WorkerId> {
401 let locations = self.locations.read().expect("lock poisoned");
402 locations
403 .get(key)
404 .map_or(Vec::new(), |loc| loc.workers.clone())
405 }
406
407 pub fn calculate_affinity(&self, dependencies: &[String]) -> HashMap<WorkerId, f64> {
409 let locations = self.locations.read().expect("lock poisoned");
410 let mut affinity: HashMap<WorkerId, f64> = HashMap::new();
411
412 for dep in dependencies {
413 if let Some(loc) = locations.get(dep) {
414 let score_per_worker = 1.0 / loc.workers.len() as f64;
415 for worker in &loc.workers {
416 *affinity.entry(*worker).or_insert(0.0) += score_per_worker;
417 }
418 }
419 }
420
421 if !affinity.is_empty() {
423 let max_score = affinity.values().cloned().fold(0.0_f64, f64::max);
424 if max_score > 0.0 {
425 for score in affinity.values_mut() {
426 *score /= max_score;
427 }
428 }
429 }
430
431 affinity
432 }
433
434 #[must_use]
436 pub fn find_best_worker(&self, brick: &dyn Brick) -> Option<WorkerId> {
437 let weights_key = format!("{}_weights", brick.brick_name());
439 let workers = self.get_workers_for_data(&weights_key);
440 workers.first().copied()
441 }
442
443 #[must_use]
445 pub fn find_best_worker_for_distributed<B: Brick>(
446 &self,
447 brick: &DistributedBrick<B>,
448 ) -> Option<WorkerId> {
449 if let Some(preferred) = brick.preferred_worker() {
451 return Some(preferred);
452 }
453
454 let affinity = self.calculate_affinity(brick.data_dependencies());
456 affinity
457 .into_iter()
458 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
459 .map(|(worker, _)| worker)
460 }
461
462 #[must_use]
464 pub fn total_data_size(&self) -> usize {
465 let locations = self.locations.read().expect("lock poisoned");
466 locations.values().map(|loc| loc.size_bytes).sum()
467 }
468}
469
470#[derive(Debug)]
472pub struct BackendSelector {
473 gpu_threshold: usize,
475 simd_threshold: usize,
477 cpu_max_threshold: usize,
479}
480
481impl Default for BackendSelector {
482 fn default() -> Self {
483 Self::new()
484 }
485}
486
487impl BackendSelector {
488 #[must_use]
490 pub fn new() -> Self {
491 Self {
492 gpu_threshold: 1_000_000, simd_threshold: 10_000, cpu_max_threshold: 100_000_000, }
496 }
497
498 #[must_use]
500 pub fn with_gpu_threshold(mut self, threshold: usize) -> Self {
501 self.gpu_threshold = threshold;
502 self
503 }
504
505 #[must_use]
507 pub fn with_simd_threshold(mut self, threshold: usize) -> Self {
508 self.simd_threshold = threshold;
509 self
510 }
511
512 #[must_use]
514 pub fn with_cpu_max_threshold(mut self, threshold: usize) -> Self {
515 self.cpu_max_threshold = threshold;
516 self
517 }
518
519 #[must_use]
521 pub fn select(&self, element_count: usize, gpu_available: bool) -> Backend {
522 if element_count > self.cpu_max_threshold && Backend::Remote.is_available() {
524 return Backend::Remote;
525 }
526
527 if element_count >= self.gpu_threshold && gpu_available {
529 return Backend::Gpu;
530 }
531
532 if element_count >= self.simd_threshold {
534 return Backend::Simd;
535 }
536
537 Backend::Cpu
539 }
540
541 #[must_use]
543 pub fn select_for_brick(
544 &self,
545 _brick_complexity: u32,
546 input_size: usize,
547 gpu_available: bool,
548 ) -> Backend {
549 self.select(input_size, gpu_available)
551 }
552}
553
554#[derive(Debug)]
558pub struct MultiBrickExecutor {
559 selector: BackendSelector,
560 gpu_available: bool,
561 data_tracker: Arc<BrickDataTracker>,
562}
563
564impl MultiBrickExecutor {
565 #[must_use]
567 pub fn new(data_tracker: Arc<BrickDataTracker>) -> Self {
568 Self {
569 selector: BackendSelector::new(),
570 gpu_available: cfg!(feature = "gpu"),
571 data_tracker,
572 }
573 }
574
575 #[must_use]
577 pub fn with_selector(mut self, selector: BackendSelector) -> Self {
578 self.selector = selector;
579 self
580 }
581
582 #[must_use]
584 pub fn with_gpu_available(mut self, available: bool) -> Self {
585 self.gpu_available = available;
586 self
587 }
588
589 pub fn execute(&self, brick: &dyn Brick, input: BrickInput) -> BrickResult<BrickOutput> {
591 let start = Instant::now();
592
593 let backend = self
595 .selector
596 .select(input.element_count(), self.gpu_available);
597
598 let (output_data, output_shape) = match backend {
600 Backend::Cpu => self.execute_cpu(brick, &input)?,
601 Backend::Simd => self.execute_simd(brick, &input)?,
602 Backend::Gpu => self.execute_gpu(brick, &input)?,
603 Backend::Remote => self.execute_remote(brick, &input)?,
604 };
605
606 let execution_time = start.elapsed();
607
608 let mut output = BrickOutput::new(output_data, output_shape);
610 output.metrics = ExecutionMetrics::new(execution_time, backend);
611
612 Ok(output)
613 }
614
615 pub fn execute_distributed<B: Brick>(
617 &self,
618 brick: &DistributedBrick<B>,
619 input: BrickInput,
620 ) -> BrickResult<BrickOutput> {
621 let start = Instant::now();
622
623 let backend = brick.backend();
625
626 let worker_id = self.data_tracker.find_best_worker_for_distributed(brick);
628
629 let (output_data, output_shape) = match backend {
631 Backend::Cpu => self.execute_cpu(brick.inner(), &input)?,
632 Backend::Simd => self.execute_simd(brick.inner(), &input)?,
633 Backend::Gpu => self.execute_gpu(brick.inner(), &input)?,
634 Backend::Remote => self.execute_remote(brick.inner(), &input)?,
635 };
636
637 let execution_time = start.elapsed();
638
639 let mut output = BrickOutput::new(output_data, output_shape);
641 output.metrics = ExecutionMetrics {
642 execution_time,
643 backend,
644 worker_id,
645 transfer_time: None,
646 };
647
648 Ok(output)
649 }
650
651 fn execute_cpu(
652 &self,
653 _brick: &dyn Brick,
654 input: &BrickInput,
655 ) -> BrickResult<(Vec<f32>, Vec<usize>)> {
656 Ok((input.data.clone(), input.shape.clone()))
658 }
659
660 fn execute_simd(
661 &self,
662 _brick: &dyn Brick,
663 input: &BrickInput,
664 ) -> BrickResult<(Vec<f32>, Vec<usize>)> {
665 Ok((input.data.clone(), input.shape.clone()))
667 }
668
669 fn execute_gpu(
670 &self,
671 _brick: &dyn Brick,
672 input: &BrickInput,
673 ) -> BrickResult<(Vec<f32>, Vec<usize>)> {
674 if !self.gpu_available {
676 return Err(BrickError::HtmlGenerationFailed {
677 reason: "GPU not available".into(),
678 });
679 }
680 Ok((input.data.clone(), input.shape.clone()))
681 }
682
683 fn execute_remote(
684 &self,
685 _brick: &dyn Brick,
686 input: &BrickInput,
687 ) -> BrickResult<(Vec<f32>, Vec<usize>)> {
688 if !Backend::Remote.is_available() {
690 return Err(BrickError::HtmlGenerationFailed {
691 reason: "Distributed execution not available".into(),
692 });
693 }
694 Ok((input.data.clone(), input.shape.clone()))
695 }
696
697 #[must_use]
699 pub fn data_tracker(&self) -> &Arc<BrickDataTracker> {
700 &self.data_tracker
701 }
702}
703
/// Message exchanged through the coordinator's publish/subscribe topics.
#[derive(Debug, Clone)]
pub enum BrickMessage {
    /// New weights for a brick.
    WeightUpdate {
        /// Brick the weights belong to.
        brick_name: String,
        /// Raw weight bytes.
        weights: Vec<u8>,
        /// Version number assigned by the publisher.
        version: u64,
    },

    /// A named event concerning a brick.
    StateChange {
        /// Brick the event concerns.
        brick_name: String,
        /// Event name.
        event: String,
    },

    /// Request to execute a brick.
    ExecutionRequest {
        /// Brick to execute.
        brick_name: String,
        /// Key of the input data.
        input_key: String,
        /// Identifier used to correlate the result.
        request_id: u64,
    },

    /// Outcome of an execution request.
    ExecutionResult {
        /// Identifier of the request this result answers.
        request_id: u64,
        /// Key of the output data.
        output_key: String,
        /// Whether the execution succeeded.
        success: bool,
    },
}
742
743#[derive(Debug)]
745pub struct Subscription {
746 topic: String,
747 messages: Arc<RwLock<Vec<BrickMessage>>>,
748}
749
750impl Subscription {
751 #[must_use]
753 pub fn drain(&self) -> Vec<BrickMessage> {
754 let mut messages = self.messages.write().expect("lock poisoned");
755 std::mem::take(&mut *messages)
756 }
757
758 #[must_use]
760 pub fn has_messages(&self) -> bool {
761 let messages = self.messages.read().expect("lock poisoned");
762 !messages.is_empty()
763 }
764
765 #[must_use]
767 pub fn topic(&self) -> &str {
768 &self.topic
769 }
770}
771
772#[derive(Debug, Clone)]
778pub struct WorkStealingTask {
779 pub id: u64,
781 pub spec: TaskSpec,
783 pub input_key: String,
785 pub priority: u32,
787 pub created_at: Instant,
789}
790
791impl WorkStealingTask {
792 #[must_use]
794 pub fn new(id: u64, spec: TaskSpec, input_key: String) -> Self {
795 Self {
796 id,
797 spec,
798 input_key,
799 priority: 0,
800 created_at: Instant::now(),
801 }
802 }
803
804 #[must_use]
806 pub fn with_priority(mut self, priority: u32) -> Self {
807 self.priority = priority;
808 self
809 }
810
811 #[must_use]
813 pub fn age(&self) -> Duration {
814 self.created_at.elapsed()
815 }
816}
817
818#[derive(Debug)]
820pub struct WorkerQueue {
821 worker_id: WorkerId,
823 local_queue: RwLock<Vec<WorkStealingTask>>,
825 completed_count: AtomicU64,
827 stolen_count: AtomicU64,
829}
830
831impl WorkerQueue {
832 #[must_use]
834 pub fn new(worker_id: WorkerId) -> Self {
835 Self {
836 worker_id,
837 local_queue: RwLock::new(Vec::new()),
838 completed_count: AtomicU64::new(0),
839 stolen_count: AtomicU64::new(0),
840 }
841 }
842
843 pub fn push(&self, task: WorkStealingTask) {
845 let mut queue = self.local_queue.write().expect("lock poisoned");
846 queue.push(task);
847 queue.sort_by(|a, b| b.priority.cmp(&a.priority));
849 }
850
851 pub fn pop(&self) -> Option<WorkStealingTask> {
853 let mut queue = self.local_queue.write().expect("lock poisoned");
854 if queue.is_empty() {
855 return None;
856 }
857 Some(queue.remove(0)) }
859
860 pub fn steal(&self) -> Option<WorkStealingTask> {
862 let mut queue = self.local_queue.write().expect("lock poisoned");
863 if queue.is_empty() {
864 return None;
865 }
866 self.stolen_count.fetch_add(1, Ordering::Relaxed);
867 queue.pop() }
869
870 #[must_use]
872 pub fn is_empty(&self) -> bool {
873 let queue = self.local_queue.read().expect("lock poisoned");
874 queue.is_empty()
875 }
876
877 #[must_use]
879 pub fn len(&self) -> usize {
880 let queue = self.local_queue.read().expect("lock poisoned");
881 queue.len()
882 }
883
884 pub fn mark_completed(&self) {
886 self.completed_count.fetch_add(1, Ordering::Relaxed);
887 }
888
889 #[must_use]
891 pub fn worker_id(&self) -> WorkerId {
892 self.worker_id
893 }
894
895 #[must_use]
897 pub fn completed_count(&self) -> u64 {
898 self.completed_count.load(Ordering::Relaxed)
899 }
900
901 #[must_use]
903 pub fn stolen_count(&self) -> u64 {
904 self.stolen_count.load(Ordering::Relaxed)
905 }
906}
907
908#[derive(Debug)]
920pub struct WorkStealingScheduler {
921 queues: RwLock<HashMap<WorkerId, Arc<WorkerQueue>>>,
923 data_tracker: Arc<BrickDataTracker>,
925 task_counter: AtomicU64,
927 submitted_count: AtomicU64,
929}
930
931impl WorkStealingScheduler {
932 #[must_use]
934 pub fn new(data_tracker: Arc<BrickDataTracker>) -> Self {
935 Self {
936 queues: RwLock::new(HashMap::new()),
937 data_tracker,
938 task_counter: AtomicU64::new(0),
939 submitted_count: AtomicU64::new(0),
940 }
941 }
942
943 pub fn register_worker(&self, worker_id: WorkerId) -> Arc<WorkerQueue> {
945 let queue = Arc::new(WorkerQueue::new(worker_id));
946 let mut queues = self.queues.write().expect("lock poisoned");
947 queues.insert(worker_id, Arc::clone(&queue));
948 queue
949 }
950
951 pub fn unregister_worker(&self, worker_id: WorkerId) {
953 let mut queues = self.queues.write().expect("lock poisoned");
954 queues.remove(&worker_id);
955 }
956
957 pub fn submit(&self, spec: TaskSpec, input_key: String) -> u64 {
959 let task_id = self.task_counter.fetch_add(1, Ordering::SeqCst);
960 let task = WorkStealingTask::new(task_id, spec.clone(), input_key);
961
962 let target_worker = self.find_best_worker_for_task(&spec);
964
965 let queues = self.queues.read().expect("lock poisoned");
966 if let Some(queue) = target_worker.and_then(|w| queues.get(&w)) {
967 queue.push(task);
968 } else if let Some((_, queue)) = queues.iter().next() {
969 queue.push(task);
971 }
972
973 self.submitted_count.fetch_add(1, Ordering::Relaxed);
974 task_id
975 }
976
977 pub fn submit_priority(&self, spec: TaskSpec, input_key: String, priority: u32) -> u64 {
979 let task_id = self.task_counter.fetch_add(1, Ordering::SeqCst);
980 let task = WorkStealingTask::new(task_id, spec.clone(), input_key).with_priority(priority);
981
982 let target_worker = self.find_best_worker_for_task(&spec);
983
984 let queues = self.queues.read().expect("lock poisoned");
985 if let Some(queue) = target_worker.and_then(|w| queues.get(&w)) {
986 queue.push(task);
987 } else if let Some((_, queue)) = queues.iter().next() {
988 queue.push(task);
989 }
990
991 self.submitted_count.fetch_add(1, Ordering::Relaxed);
992 task_id
993 }
994
995 pub fn get_work(&self, worker_id: WorkerId) -> Option<WorkStealingTask> {
997 let queues = self.queues.read().expect("lock poisoned");
998
999 if let Some(queue) = queues.get(&worker_id) {
1001 if let Some(task) = queue.pop() {
1002 return Some(task);
1003 }
1004 }
1005
1006 self.try_steal(worker_id, &queues)
1008 }
1009
1010 fn try_steal(
1012 &self,
1013 stealer_id: WorkerId,
1014 queues: &HashMap<WorkerId, Arc<WorkerQueue>>,
1015 ) -> Option<WorkStealingTask> {
1016 let mut candidates: Vec<_> = queues
1018 .iter()
1019 .filter(|(id, q)| **id != stealer_id && !q.is_empty())
1020 .collect();
1021
1022 if candidates.is_empty() {
1023 return None;
1024 }
1025
1026 candidates.sort_by(|a, b| b.1.len().cmp(&a.1.len()));
1028
1029 for (_, queue) in candidates {
1031 if let Some(task) = queue.steal() {
1032 return Some(task);
1033 }
1034 }
1035
1036 None
1037 }
1038
1039 fn find_best_worker_for_task(&self, spec: &TaskSpec) -> Option<WorkerId> {
1041 if let Some(preferred) = spec.preferred_worker {
1043 return Some(preferred);
1044 }
1045
1046 let affinity = self
1048 .data_tracker
1049 .calculate_affinity(&spec.data_dependencies);
1050 affinity
1051 .into_iter()
1052 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
1053 .map(|(worker, _)| worker)
1054 }
1055
1056 #[must_use]
1058 pub fn stats(&self) -> SchedulerStats {
1059 let queues = self.queues.read().expect("lock poisoned");
1060
1061 let worker_stats: Vec<_> = queues
1062 .values()
1063 .map(|q| WorkerStats {
1064 worker_id: q.worker_id(),
1065 queue_length: q.len(),
1066 completed: q.completed_count(),
1067 stolen_from: q.stolen_count(),
1068 })
1069 .collect();
1070
1071 let total_pending: usize = worker_stats.iter().map(|s| s.queue_length).sum();
1072 let total_completed: u64 = worker_stats.iter().map(|s| s.completed).sum();
1073 let total_stolen: u64 = worker_stats.iter().map(|s| s.stolen_from).sum();
1074
1075 SchedulerStats {
1076 worker_count: queues.len(),
1077 total_submitted: self.submitted_count.load(Ordering::Relaxed),
1078 total_pending,
1079 total_completed,
1080 total_stolen,
1081 workers: worker_stats,
1082 }
1083 }
1084
1085 #[must_use]
1087 pub fn data_tracker(&self) -> &Arc<BrickDataTracker> {
1088 &self.data_tracker
1089 }
1090}
1091
1092#[derive(Debug, Clone)]
1094pub struct WorkerStats {
1095 pub worker_id: WorkerId,
1097 pub queue_length: usize,
1099 pub completed: u64,
1101 pub stolen_from: u64,
1103}
1104
1105#[derive(Debug, Clone)]
1107pub struct SchedulerStats {
1108 pub worker_count: usize,
1110 pub total_submitted: u64,
1112 pub total_pending: usize,
1114 pub total_completed: u64,
1116 pub total_stolen: u64,
1118 pub workers: Vec<WorkerStats>,
1120}
1121
1122#[derive(Debug)]
1130pub struct BrickCoordinator {
1131 subscriptions: RwLock<HashMap<String, Vec<Arc<RwLock<Vec<BrickMessage>>>>>>,
1133 message_counter: AtomicU64,
1135}
1136
1137impl Default for BrickCoordinator {
1138 fn default() -> Self {
1139 Self::new()
1140 }
1141}
1142
1143impl BrickCoordinator {
1144 #[must_use]
1146 pub fn new() -> Self {
1147 Self {
1148 subscriptions: RwLock::new(HashMap::new()),
1149 message_counter: AtomicU64::new(0),
1150 }
1151 }
1152
1153 #[must_use]
1155 pub fn subscribe(&self, topic: &str) -> Subscription {
1156 let messages = Arc::new(RwLock::new(Vec::new()));
1157 {
1158 let mut subs = self.subscriptions.write().expect("lock poisoned");
1159 subs.entry(topic.to_string())
1160 .or_default()
1161 .push(Arc::clone(&messages));
1162 }
1163 Subscription {
1164 topic: topic.to_string(),
1165 messages,
1166 }
1167 }
1168
1169 #[must_use]
1171 pub fn subscribe_brick(&self, brick_name: &str) -> Subscription {
1172 let topic = format!("brick/{}/events", brick_name);
1173 self.subscribe(&topic)
1174 }
1175
1176 pub fn publish(&self, topic: &str, message: BrickMessage) {
1178 let subs = self.subscriptions.read().expect("lock poisoned");
1179 if let Some(subscribers) = subs.get(topic) {
1180 for sub in subscribers {
1181 let mut messages = sub.write().expect("lock poisoned");
1182 messages.push(message.clone());
1183 }
1184 }
1185 }
1186
1187 pub fn broadcast_weights(&self, brick_name: &str, weights: Vec<u8>) {
1189 let topic = format!("brick/{}/weights", brick_name);
1190 let version = self.message_counter.fetch_add(1, Ordering::SeqCst);
1191 self.publish(
1192 &topic,
1193 BrickMessage::WeightUpdate {
1194 brick_name: brick_name.to_string(),
1195 weights,
1196 version,
1197 },
1198 );
1199 }
1200
1201 pub fn broadcast_state_change(&self, brick_name: &str, event: &str) {
1203 let topic = format!("brick/{}/events", brick_name);
1204 self.publish(
1205 &topic,
1206 BrickMessage::StateChange {
1207 brick_name: brick_name.to_string(),
1208 event: event.to_string(),
1209 },
1210 );
1211 }
1212
1213 #[must_use]
1215 pub fn next_request_id(&self) -> u64 {
1216 self.message_counter.fetch_add(1, Ordering::SeqCst)
1217 }
1218}
1219
1220#[cfg(test)]
1221#[allow(clippy::unwrap_used, clippy::expect_used)]
1222mod tests {
1223 use super::*;
1224
1225 struct TestBrick {
1226 name: &'static str,
1227 }
1228
1229 impl Brick for TestBrick {
1230 fn brick_name(&self) -> &'static str {
1231 self.name
1232 }
1233
1234 fn assertions(&self) -> &[BrickAssertion] {
1235 &[BrickAssertion::TextVisible]
1236 }
1237
1238 fn budget(&self) -> BrickBudget {
1239 BrickBudget::uniform(16)
1240 }
1241
1242 fn verify(&self) -> BrickVerification {
1243 BrickVerification {
1244 passed: vec![BrickAssertion::TextVisible],
1245 failed: vec![],
1246 verification_time: Duration::from_micros(100),
1247 }
1248 }
1249
1250 fn to_html(&self) -> String {
1251 format!("<div>{}</div>", self.name)
1252 }
1253
1254 fn to_css(&self) -> String {
1255 ".test { }".into()
1256 }
1257 }
1258
1259 #[test]
1260 fn test_worker_id() {
1261 let id = WorkerId::new(42);
1262 assert_eq!(id.value(), 42);
1263 assert_eq!(format!("{id}"), "worker-42");
1264 }
1265
1266 #[test]
1267 fn test_backend_availability() {
1268 assert!(Backend::Cpu.is_available());
1269 assert!(Backend::Simd.is_available());
1270 }
1272
1273 #[test]
1274 fn test_backend_performance() {
1275 assert!(Backend::Gpu.performance_estimate() > Backend::Simd.performance_estimate());
1276 assert!(Backend::Simd.performance_estimate() > Backend::Cpu.performance_estimate());
1277 }
1278
1279 #[test]
1280 fn test_distributed_brick_creation() {
1281 let inner = TestBrick { name: "Test" };
1282 let distributed = DistributedBrick::new(inner)
1283 .with_backend(Backend::Gpu)
1284 .with_data_dependencies(vec!["weights".into(), "biases".into()])
1285 .with_preferred_worker(WorkerId::new(1));
1286
1287 assert_eq!(distributed.backend(), Backend::Gpu);
1288 assert_eq!(distributed.data_dependencies().len(), 2);
1289 assert_eq!(distributed.preferred_worker(), Some(WorkerId::new(1)));
1290 assert_eq!(distributed.brick_name(), "Test");
1291 }
1292
1293 #[test]
1294 fn test_distributed_brick_implements_brick() {
1295 let inner = TestBrick { name: "Test" };
1296 let distributed = DistributedBrick::new(inner);
1297
1298 assert!(distributed.verify().is_valid());
1300 assert_eq!(distributed.budget().total_ms, 16);
1301 }
1302
1303 #[test]
1304 fn test_task_spec() {
1305 let inner = TestBrick { name: "TestTask" };
1306 let distributed = DistributedBrick::new(inner)
1307 .with_backend(Backend::Simd)
1308 .with_data_dependencies(vec!["model".into()]);
1309
1310 let spec = distributed.to_task_spec();
1311 assert_eq!(spec.brick_name, "TestTask");
1312 assert_eq!(spec.backend, Backend::Simd);
1313 assert_eq!(spec.data_dependencies, vec!["model"]);
1314 }
1315
1316 #[test]
1317 fn test_brick_input_output() {
1318 let input = BrickInput::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
1319 assert_eq!(input.element_count(), 4);
1320 assert_eq!(input.size_bytes(), 16);
1321
1322 let output = BrickOutput::new(vec![5.0, 6.0], vec![2]);
1323 assert_eq!(output.size_bytes(), 8);
1324 }
1325
1326 #[test]
1327 fn test_data_tracker() {
1328 let tracker = BrickDataTracker::new();
1329
1330 tracker.track_data("model_weights", WorkerId::new(1), 1024);
1332 tracker.track_data("model_weights", WorkerId::new(2), 1024);
1333 tracker.track_data("biases", WorkerId::new(1), 256);
1334
1335 let workers = tracker.get_workers_for_data("model_weights");
1337 assert_eq!(workers.len(), 2);
1338
1339 let affinity = tracker.calculate_affinity(&["model_weights".into(), "biases".into()]);
1341 assert!(affinity.get(&WorkerId::new(1)).unwrap_or(&0.0) > &0.0);
1342 }
1343
1344 #[test]
1345 fn test_data_tracker_find_best_worker() {
1346 let tracker = BrickDataTracker::new();
1347
1348 let brick = TestBrick { name: "MelBrick" };
1349 tracker.track_weights("MelBrick", WorkerId::new(5));
1350
1351 let best = tracker.find_best_worker(&brick);
1352 assert_eq!(best, Some(WorkerId::new(5)));
1353 }
1354
1355 #[test]
1356 fn test_backend_selector() {
1357 let selector = BackendSelector::new()
1358 .with_gpu_threshold(1000)
1359 .with_simd_threshold(100);
1360
1361 assert_eq!(selector.select(50, true), Backend::Cpu);
1363
1364 assert_eq!(selector.select(500, true), Backend::Simd);
1366
1367 assert_eq!(selector.select(5000, true), Backend::Gpu);
1369
1370 assert_eq!(selector.select(5000, false), Backend::Simd);
1372 }
1373
1374 #[test]
1375 fn test_multi_executor() {
1376 let tracker = Arc::new(BrickDataTracker::new());
1377 let executor = MultiBrickExecutor::new(tracker);
1378
1379 let brick = TestBrick { name: "Test" };
1380 let input = BrickInput::new(vec![1.0, 2.0, 3.0], vec![3]);
1381
1382 let result = executor.execute(&brick, input);
1383 assert!(result.is_ok());
1384
1385 let output = result.expect("execution should succeed");
1386 assert_eq!(output.data.len(), 3);
1387 assert!(output.metrics.execution_time >= Duration::ZERO);
1388 }
1389
1390 #[test]
1391 fn test_brick_coordinator() {
1392 let coordinator = BrickCoordinator::new();
1393
1394 let sub = coordinator.subscribe_brick("MyBrick");
1396
1397 coordinator.broadcast_state_change("MyBrick", "loaded");
1399
1400 assert!(sub.has_messages());
1402 let messages = sub.drain();
1403 assert_eq!(messages.len(), 1);
1404 matches!(&messages[0], BrickMessage::StateChange { brick_name, .. } if brick_name == "MyBrick");
1405 }
1406
1407 #[test]
1408 fn test_coordinator_weight_broadcast() {
1409 let coordinator = BrickCoordinator::new();
1410
1411 let sub = coordinator.subscribe("brick/Encoder/weights");
1412 coordinator.broadcast_weights("Encoder", vec![1, 2, 3, 4]);
1413
1414 let messages = sub.drain();
1415 assert_eq!(messages.len(), 1);
1416 match &messages[0] {
1417 BrickMessage::WeightUpdate {
1418 brick_name,
1419 weights,
1420 version,
1421 } => {
1422 assert_eq!(brick_name, "Encoder");
1423 assert_eq!(weights, &vec![1, 2, 3, 4]);
1424 assert_eq!(*version, 0);
1425 }
1426 _ => panic!("Expected WeightUpdate message"),
1427 }
1428 }
1429
1430 #[test]
1431 fn test_subscription_topic() {
1432 let coordinator = BrickCoordinator::new();
1433 let sub = coordinator.subscribe("my/topic");
1434 assert_eq!(sub.topic(), "my/topic");
1435 }
1436
1437 #[test]
1438 fn test_execution_metrics() {
1439 let metrics = ExecutionMetrics::new(Duration::from_millis(50), Backend::Gpu);
1440 assert_eq!(metrics.execution_time, Duration::from_millis(50));
1441 assert_eq!(metrics.backend, Backend::Gpu);
1442 assert!(metrics.worker_id.is_none());
1443 }
1444
1445 #[test]
1450 fn test_work_stealing_task() {
1451 let spec = TaskSpec {
1452 brick_name: "TestBrick".into(),
1453 backend: Backend::Cpu,
1454 data_dependencies: vec![],
1455 preferred_worker: None,
1456 };
1457 let task = WorkStealingTask::new(1, spec, "input_key".into()).with_priority(10);
1458
1459 assert_eq!(task.id, 1);
1460 assert_eq!(task.priority, 10);
1461 assert_eq!(task.input_key, "input_key");
1462 assert!(task.age() >= Duration::ZERO);
1463 }
1464
1465 #[test]
1466 fn test_worker_queue_basic() {
1467 let queue = WorkerQueue::new(WorkerId::new(1));
1468
1469 assert!(queue.is_empty());
1470 assert_eq!(queue.len(), 0);
1471
1472 let spec = TaskSpec {
1473 brick_name: "Test".into(),
1474 backend: Backend::Cpu,
1475 data_dependencies: vec![],
1476 preferred_worker: None,
1477 };
1478 let task = WorkStealingTask::new(1, spec, "key".into());
1479 queue.push(task);
1480
1481 assert!(!queue.is_empty());
1482 assert_eq!(queue.len(), 1);
1483
1484 let popped = queue.pop();
1485 assert!(popped.is_some());
1486 assert!(queue.is_empty());
1487 }
1488
1489 #[test]
1490 fn test_worker_queue_priority_ordering() {
1491 let queue = WorkerQueue::new(WorkerId::new(1));
1492
1493 for i in 0..5 {
1495 let spec = TaskSpec {
1496 brick_name: format!("Task{}", i),
1497 backend: Backend::Cpu,
1498 data_dependencies: vec![],
1499 preferred_worker: None,
1500 };
1501 let task = WorkStealingTask::new(i as u64, spec, "key".into()).with_priority(i);
1502 queue.push(task);
1503 }
1504
1505 let task = queue.pop().unwrap();
1507 assert_eq!(task.priority, 4);
1508
1509 let task = queue.pop().unwrap();
1510 assert_eq!(task.priority, 3);
1511 }
1512
/// With priorities 0, 1 and 2 queued, `steal` is expected to return the
/// priority-0 task, record one steal and leave two tasks behind.
#[test]
fn test_worker_queue_steal() {
    let victim = WorkerQueue::new(WorkerId::new(1));

    for level in 0..3 {
        let spec = TaskSpec {
            brick_name: format!("Task{}", level),
            backend: Backend::Cpu,
            data_dependencies: Vec::new(),
            preferred_worker: None,
        };
        victim.push(
            WorkStealingTask::new(level as u64, spec, "key".into()).with_priority(level),
        );
    }

    let loot = victim.steal().unwrap();
    assert_eq!(loot.priority, 0);
    assert_eq!(victim.stolen_count(), 1);

    assert_eq!(victim.len(), 2);
}
1537
/// Registering two workers is reflected in the stats while nothing has
/// been submitted yet.
#[test]
fn test_work_stealing_scheduler_basic() {
    let scheduler = WorkStealingScheduler::new(Arc::new(BrickDataTracker::new()));

    // The values returned by `register_worker` are held until the test ends.
    let _first_handle = scheduler.register_worker(WorkerId::new(1));
    let _second_handle = scheduler.register_worker(WorkerId::new(2));

    let snapshot = scheduler.stats();
    assert_eq!(snapshot.worker_count, 2);
    assert_eq!(snapshot.total_submitted, 0);
}
1551
/// The first submission gets id 0 and shows up as one submitted and one
/// pending task in the stats.
#[test]
fn test_work_stealing_scheduler_submit() {
    let scheduler = WorkStealingScheduler::new(Arc::new(BrickDataTracker::new()));
    scheduler.register_worker(WorkerId::new(1));

    let assigned_id = scheduler.submit(
        TaskSpec {
            brick_name: String::from("Test"),
            backend: Backend::Cpu,
            data_dependencies: Vec::new(),
            preferred_worker: None,
        },
        String::from("input"),
    );
    assert_eq!(assigned_id, 0);

    let snapshot = scheduler.stats();
    assert_eq!(snapshot.total_submitted, 1);
    assert_eq!(snapshot.total_pending, 1);
}
1573
/// A single task that prefers worker 1 is handed to worker 1; a later
/// request from worker 2 comes back empty.
#[test]
fn test_work_stealing_scheduler_get_work() {
    let scheduler = WorkStealingScheduler::new(Arc::new(BrickDataTracker::new()));

    let owner = WorkerId::new(1);
    let bystander = WorkerId::new(2);
    scheduler.register_worker(owner);
    scheduler.register_worker(bystander);

    scheduler.submit(
        TaskSpec {
            brick_name: String::from("Test"),
            backend: Backend::Cpu,
            data_dependencies: Vec::new(),
            preferred_worker: Some(owner),
        },
        String::from("input"),
    );

    let for_owner = scheduler.get_work(owner);
    assert!(for_owner.is_some());

    let for_bystander = scheduler.get_work(bystander);
    assert!(for_bystander.is_none());
}
1599
/// Three tasks all prefer worker 1; worker 2 asking for work is expected
/// to receive one of them, leaving one recorded steal and two pending tasks.
#[test]
fn test_work_stealing_scheduler_steal() {
    let scheduler = WorkStealingScheduler::new(Arc::new(BrickDataTracker::new()));

    let loaded = WorkerId::new(1);
    let idle = WorkerId::new(2);
    scheduler.register_worker(loaded);
    scheduler.register_worker(idle);

    for n in 0..3 {
        let spec = TaskSpec {
            brick_name: format!("Task{}", n),
            backend: Backend::Cpu,
            data_dependencies: Vec::new(),
            preferred_worker: Some(loaded),
        };
        scheduler.submit(spec, format!("input{}", n));
    }

    let taken = scheduler.get_work(idle);
    assert!(taken.is_some());

    let snapshot = scheduler.stats();
    assert_eq!(snapshot.total_stolen, 1);
    assert_eq!(snapshot.total_pending, 2);
}
1627
/// A task with no preferred worker but a dependency tracked on worker 1
/// is obtainable by worker 1.
#[test]
fn test_work_stealing_scheduler_locality() {
    let data_holder = WorkerId::new(1);

    // Record that worker 1 holds "model_weights" before the scheduler exists.
    let tracker = Arc::new(BrickDataTracker::new());
    tracker.track_data("model_weights", data_holder, 1024);

    let scheduler = WorkStealingScheduler::new(Arc::clone(&tracker));
    scheduler.register_worker(data_holder);
    scheduler.register_worker(WorkerId::new(2));

    scheduler.submit(
        TaskSpec {
            brick_name: String::from("MelBrick"),
            backend: Backend::Cpu,
            data_dependencies: vec!["model_weights".into()],
            preferred_worker: None,
        },
        String::from("audio_input"),
    );

    let received = scheduler.get_work(data_holder);
    assert!(received.is_some());
    assert_eq!(received.unwrap().spec.brick_name, "MelBrick");
}
1653
/// Five tasks split between two workers by index parity are all counted
/// as submitted and pending, with one per-worker entry each.
#[test]
fn test_scheduler_stats() {
    let scheduler = WorkStealingScheduler::new(Arc::new(BrickDataTracker::new()));

    scheduler.register_worker(WorkerId::new(1));
    scheduler.register_worker(WorkerId::new(2));

    for n in 0..5 {
        // Even indices go to worker 1, odd indices to worker 2.
        let target = if n % 2 == 0 { 1 } else { 2 };
        let spec = TaskSpec {
            brick_name: format!("Task{}", n),
            backend: Backend::Cpu,
            data_dependencies: Vec::new(),
            preferred_worker: Some(WorkerId::new(target)),
        };
        scheduler.submit(spec, format!("input{}", n));
    }

    let snapshot = scheduler.stats();
    assert_eq!(snapshot.worker_count, 2);
    assert_eq!(snapshot.total_submitted, 5);
    assert_eq!(snapshot.total_pending, 5);
    assert_eq!(snapshot.workers.len(), 2);
}
1683
/// `WorkerId` is `Copy`: the source stays usable after assignment and
/// both values compare equal.
#[test]
fn test_worker_id_copy_clone() {
    let original = WorkerId::new(123);
    let duplicate = original;

    assert_eq!(original, duplicate);
    assert_eq!(original.0, 123);
}
1695
/// Equal ids collapse to a single entry in a `HashSet`.
#[test]
fn test_worker_id_hash() {
    use std::collections::HashSet;

    // Ids 1, 2 and 1 again: the repeated id must not add a third entry.
    let distinct: HashSet<WorkerId> = [1, 2, 1].iter().copied().map(WorkerId::new).collect();

    assert_eq!(distinct.len(), 2);
}
1705
/// The default backend is the CPU.
#[test]
fn test_backend_default() {
    assert_eq!(Backend::default(), Backend::Cpu);
}
1711
/// The remote backend reports itself as unavailable.
#[test]
fn test_backend_remote_not_available() {
    let remote_ready = Backend::Remote.is_available();
    assert!(!remote_ready);
}
1716
/// Pins the performance estimates of the remote (5) and CPU (10) backends.
#[test]
fn test_backend_performance_remote() {
    for &(backend, estimate) in &[(Backend::Remote, 5), (Backend::Cpu, 10)] {
        assert_eq!(backend.performance_estimate(), estimate);
    }
}
1722
/// A default `BrickInput` has no data, no shape and no metadata.
#[test]
fn test_brick_input_default() {
    let blank = BrickInput::default();

    assert!(blank.data.is_empty());
    assert!(blank.shape.is_empty());
    assert!(blank.metadata.is_empty());
}
1730
/// Chained `with_metadata` calls each add their key/value pair.
#[test]
fn test_brick_input_with_metadata() {
    let tagged = BrickInput::new(vec![1.0], vec![1])
        .with_metadata("key1", "value1")
        .with_metadata("key2", "value2");

    let lookup = |key: &str| tagged.metadata.get(key).map(String::as_str);
    assert_eq!(lookup("key1"), Some("value1"));
    assert_eq!(lookup("key2"), Some("value2"));
}
1739
/// A default `BrickOutput` carries neither data nor shape.
#[test]
fn test_brick_output_default() {
    let blank = BrickOutput::default();

    assert!(blank.data.is_empty());
    assert!(blank.shape.is_empty());
}
1746
/// Default metrics: zero execution time, CPU backend, and neither a
/// worker id nor a transfer time.
#[test]
fn test_execution_metrics_default() {
    let baseline = ExecutionMetrics::default();

    assert_eq!(baseline.execution_time, Duration::ZERO);
    assert_eq!(baseline.backend, Backend::Cpu);
    assert!(baseline.worker_id.is_none());
    assert!(baseline.transfer_time.is_none());
}
1755
/// `inner()` exposes the wrapped brick, whose name is still "Inner".
#[test]
fn test_distributed_brick_inner() {
    let wrapper = DistributedBrick::new(TestBrick { name: "Inner" });
    assert_eq!(wrapper.inner().brick_name(), "Inner");
}
1762
/// Smoke test: `inner_mut()` can be called on a mutable wrapper.
/// The returned value is discarded and nothing is asserted about it.
#[test]
fn test_distributed_brick_inner_mut() {
    let mut wrapper = DistributedBrick::new(TestBrick { name: "Inner" });
    let _ = wrapper.inner_mut();
}
1770
/// `to_html` on the wrapper yields `<div>Test</div>` for a `TestBrick`.
#[test]
fn test_distributed_brick_to_html() {
    let wrapper = DistributedBrick::new(TestBrick { name: "Test" });
    let markup = wrapper.to_html();
    assert_eq!(markup, "<div>Test</div>");
}
1777
/// `to_css` on the wrapper yields `.test { }` for a `TestBrick`.
#[test]
fn test_distributed_brick_to_css() {
    let wrapper = DistributedBrick::new(TestBrick { name: "Test" });
    let stylesheet = wrapper.to_css();
    assert_eq!(stylesheet, ".test { }");
}
1784
/// The wrapper reports exactly one assertion for a `TestBrick`.
#[test]
fn test_distributed_brick_assertions() {
    let wrapper = DistributedBrick::new(TestBrick { name: "Test" });
    let declared = wrapper.assertions();
    assert_eq!(declared.len(), 1);
}
1791
/// Cloning a `TaskSpec` preserves its brick name and backend.
#[test]
fn test_task_spec_clone() {
    let template = TaskSpec {
        brick_name: String::from("Test"),
        backend: Backend::Gpu,
        data_dependencies: vec!["dep1".into()],
        preferred_worker: Some(WorkerId::new(5)),
    };

    let duplicate = template.clone();

    assert_eq!(template.brick_name, duplicate.brick_name);
    assert_eq!(template.backend, duplicate.backend);
}
1804
/// A default tracker reports a total data size of zero.
#[test]
fn test_brick_data_tracker_default() {
    let fresh = BrickDataTracker::default();
    assert_eq!(fresh.total_data_size(), 0);
}
1810
/// Removing one of two holders of a key leaves only the other holder.
#[test]
fn test_brick_data_tracker_remove_data() {
    let tracker = BrickDataTracker::new();
    let (evicted, kept) = (WorkerId::new(1), WorkerId::new(2));

    tracker.track_data("data1", evicted, 100);
    tracker.track_data("data1", kept, 100);

    let holders = tracker.get_workers_for_data("data1");
    assert_eq!(holders.len(), 2);

    tracker.remove_data("data1", evicted);

    let holders = tracker.get_workers_for_data("data1");
    assert_eq!(holders.len(), 1);
    assert_eq!(holders[0], kept);
}
1825
/// Sizes of distinct keys (100 and 200) add up to a total of 300.
#[test]
fn test_brick_data_tracker_total_size() {
    let tracker = BrickDataTracker::new();
    let holder = WorkerId::new(1);

    tracker.track_data("data1", holder, 100);
    tracker.track_data("data2", holder, 200);

    assert_eq!(tracker.total_data_size(), 300);
}
1833
/// Looking up an untracked key yields no workers.
#[test]
fn test_brick_data_tracker_get_nonexistent() {
    let holders = BrickDataTracker::new().get_workers_for_data("nonexistent");
    assert!(holders.is_empty());
}
1840
/// Affinity computed over an untracked dependency is empty.
#[test]
fn test_brick_data_tracker_calculate_affinity_empty() {
    let scores = BrickDataTracker::new().calculate_affinity(&["nonexistent".into()]);
    assert!(scores.is_empty());
}
1847
/// With nothing tracked, no best worker is found for a plain brick.
#[test]
fn test_brick_data_tracker_find_best_worker_no_weights() {
    let tracker = BrickDataTracker::new();
    let candidate = tracker.find_best_worker(&TestBrick { name: "NoBrick" });
    assert!(candidate.is_none());
}
1855
/// A distributed brick's explicitly preferred worker is returned even
/// though the tracker holds no data.
#[test]
fn test_brick_data_tracker_find_best_worker_distributed_preferred() {
    let preferred = WorkerId::new(42);
    let tracker = BrickDataTracker::new();
    let wrapper = DistributedBrick::new(TestBrick { name: "Test" }).with_preferred_worker(preferred);

    let chosen = tracker.find_best_worker_for_distributed(&wrapper);
    assert_eq!(chosen, Some(preferred));
}
1865
/// Without a preferred worker, the worker tracked as holding the brick's
/// dependency is chosen.
#[test]
fn test_brick_data_tracker_find_best_worker_distributed_affinity() {
    let holder = WorkerId::new(5);
    let tracker = BrickDataTracker::new();
    tracker.track_data("dep1", holder, 100);

    let wrapper =
        DistributedBrick::new(TestBrick { name: "Test" }).with_data_dependencies(vec!["dep1".into()]);

    let chosen = tracker.find_best_worker_for_distributed(&wrapper);
    assert_eq!(chosen, Some(holder));
}
1877
/// The default selector picks the CPU for `select(50, true)`.
#[test]
fn test_backend_selector_default() {
    let choice = BackendSelector::default().select(50, true);
    assert_eq!(choice, Backend::Cpu);
}
1884
/// With a CPU max threshold of 100 and a SIMD threshold of 50 (second
/// `select` argument false): 200 selects SIMD, 10 selects the CPU.
#[test]
fn test_backend_selector_cpu_max_threshold() {
    let selector = BackendSelector::new()
        .with_cpu_max_threshold(100)
        .with_simd_threshold(50);

    for &(size, expected) in &[(200, Backend::Simd), (10, Backend::Cpu)] {
        let chosen = selector.select(size, false);
        assert_eq!(chosen, expected);
    }
}
1899
/// `select_for_brick(50, 100, true)` on a fresh selector picks the CPU.
#[test]
fn test_backend_selector_select_for_brick() {
    let choice = BackendSelector::new().select_for_brick(50, 100, true);
    assert_eq!(choice, Backend::Cpu);
}
1906
/// An executor given a selector with a SIMD threshold of 1 runs a
/// two-element input successfully and reports the SIMD backend.
#[test]
fn test_multi_executor_with_selector() {
    let executor = MultiBrickExecutor::new(Arc::new(BrickDataTracker::new()))
        .with_selector(BackendSelector::new().with_simd_threshold(1));

    let brick = TestBrick { name: "Test" };
    let outcome = executor.execute(&brick, BrickInput::new(vec![1.0, 2.0], vec![2]));

    assert!(outcome.is_ok());
    assert_eq!(outcome.unwrap().metrics.backend, Backend::Simd);
}
1920
/// Smoke test: the `with_gpu_available` builder and the `data_tracker`
/// accessor can be called. The accessor's result is discarded.
#[test]
fn test_multi_executor_with_gpu_available() {
    let executor =
        MultiBrickExecutor::new(Arc::new(BrickDataTracker::new())).with_gpu_available(true);
    let _ = executor.data_tracker();
}
1927
/// Executing a distributed brick pinned to the CPU backend succeeds.
#[test]
fn test_multi_executor_execute_distributed() {
    let executor = MultiBrickExecutor::new(Arc::new(BrickDataTracker::new()));

    let wrapper = DistributedBrick::new(TestBrick { name: "Test" }).with_backend(Backend::Cpu);
    let payload = BrickInput::new(vec![1.0], vec![1]);

    let outcome = executor.execute_distributed(&wrapper, payload);
    assert!(outcome.is_ok());
}
1940
/// With the SIMD threshold lowered to 1, a two-element input executes
/// successfully on the SIMD backend.
#[test]
fn test_multi_executor_execute_simd() {
    let simd_eager = BackendSelector::new().with_simd_threshold(1);
    let executor =
        MultiBrickExecutor::new(Arc::new(BrickDataTracker::new())).with_selector(simd_eager);

    let brick = TestBrick { name: "Test" };
    let payload = BrickInput::new(vec![1.0, 2.0], vec![2]);

    let outcome = executor.execute(&brick, payload);
    assert!(outcome.is_ok());
    assert_eq!(outcome.unwrap().metrics.backend, Backend::Simd);
}
1954
/// A brick pinned to the GPU backend fails to execute when the executor
/// is configured with the GPU unavailable.
#[test]
fn test_multi_executor_execute_gpu_unavailable() {
    let tracker = Arc::new(BrickDataTracker::new());
    let gpu_bound = DistributedBrick::new(TestBrick { name: "Test" }).with_backend(Backend::Gpu);
    let executor = MultiBrickExecutor::new(tracker).with_gpu_available(false);

    let outcome = executor.execute_distributed(&gpu_bound, BrickInput::new(vec![1.0], vec![1]));
    assert!(outcome.is_err());
}
1966
/// A brick pinned to the remote backend fails to execute on a default
/// executor.
#[test]
fn test_multi_executor_execute_remote_unavailable() {
    let tracker = Arc::new(BrickDataTracker::new());
    let remote_bound =
        DistributedBrick::new(TestBrick { name: "Test" }).with_backend(Backend::Remote);
    let executor = MultiBrickExecutor::new(tracker);

    let outcome = executor.execute_distributed(&remote_bound, BrickInput::new(vec![1.0], vec![1]));
    assert!(outcome.is_err());
}
1978
/// Draining a subscription that never received anything yields nothing.
#[test]
fn test_subscription_drain_empty() {
    let hub = BrickCoordinator::new();
    let listener = hub.subscribe("test/topic");

    let backlog = listener.drain();
    assert!(backlog.is_empty());
}
1986
/// A new subscription reports that it has no messages.
#[test]
fn test_subscription_has_messages_false() {
    let hub = BrickCoordinator::new();
    let listener = hub.subscribe("test/topic");

    let pending = listener.has_messages();
    assert!(!pending);
}
1993
/// A default coordinator hands out request id 0 first.
#[test]
fn test_brick_coordinator_default() {
    let hub = BrickCoordinator::default();
    let first_id = hub.next_request_id();
    assert_eq!(first_id, 0);
}
2000
/// Successive request ids are 0, 1, 2.
#[test]
fn test_brick_coordinator_next_request_id() {
    let hub = BrickCoordinator::new();

    for expected in 0..3 {
        assert_eq!(hub.next_request_id(), expected);
    }
}
2008
/// Smoke test: publishing to a topic nobody subscribed to must simply
/// return. There is nothing to assert beyond the absence of a panic.
#[test]
fn test_brick_coordinator_publish_no_subscribers() {
    let hub = BrickCoordinator::new();

    let unheard = BrickMessage::StateChange {
        brick_name: "Test".into(),
        event: "test".into(),
    };
    hub.publish("nonexistent/topic", unheard);
}
2021
/// An `ExecutionRequest` message destructures back into the values it
/// was built from.
#[test]
fn test_brick_message_execution_request() {
    let request = BrickMessage::ExecutionRequest {
        brick_name: "Test".into(),
        input_key: "key".into(),
        request_id: 42,
    };

    if let BrickMessage::ExecutionRequest {
        brick_name,
        input_key,
        request_id,
    } = request
    {
        assert_eq!(brick_name, "Test");
        assert_eq!(input_key, "key");
        assert_eq!(request_id, 42);
    } else {
        panic!("Wrong message type");
    }
}
2042
/// An `ExecutionResult` message destructures back into the values it
/// was built from.
#[test]
fn test_brick_message_execution_result() {
    let reply = BrickMessage::ExecutionResult {
        request_id: 42,
        output_key: "out".into(),
        success: true,
    };

    if let BrickMessage::ExecutionResult {
        request_id,
        output_key,
        success,
    } = reply
    {
        assert_eq!(request_id, 42);
        assert_eq!(output_key, "out");
        assert!(success);
    } else {
        panic!("Wrong message type");
    }
}
2063
/// A cloned task keeps the id of the task it was cloned from.
#[test]
fn test_work_stealing_task_clone() {
    let source = WorkStealingTask::new(
        1,
        TaskSpec {
            brick_name: String::from("Test"),
            backend: Backend::Cpu,
            data_dependencies: Vec::new(),
            preferred_worker: None,
        },
        "key".into(),
    );

    let duplicate = source.clone();
    assert_eq!(source.id, duplicate.id);
}
2076
/// A queue reports the worker id it was constructed with.
#[test]
fn test_worker_queue_worker_id() {
    let inbox = WorkerQueue::new(WorkerId::new(42));
    assert_eq!(inbox.worker_id(), WorkerId::new(42));
}
2082
/// The completed counter starts at zero and grows by one per
/// `mark_completed` call.
#[test]
fn test_worker_queue_completed_count() {
    let inbox = WorkerQueue::new(WorkerId::new(1));
    assert_eq!(inbox.completed_count(), 0);

    for expected in 1..=2 {
        inbox.mark_completed();
        assert_eq!(inbox.completed_count(), expected);
    }
}
2092
/// Popping from an empty queue yields `None`.
#[test]
fn test_worker_queue_pop_empty() {
    let inbox = WorkerQueue::new(WorkerId::new(1));
    let taken = inbox.pop();
    assert!(taken.is_none());
}
2098
/// Stealing from an empty queue yields `None`.
#[test]
fn test_worker_queue_steal_empty() {
    let victim = WorkerQueue::new(WorkerId::new(1));
    let loot = victim.steal();
    assert!(loot.is_none());
}
2104
/// The worker count goes to one on registration and back to zero after
/// the same worker is unregistered.
#[test]
fn test_scheduler_unregister_worker() {
    let scheduler = WorkStealingScheduler::new(Arc::new(BrickDataTracker::new()));
    let transient = WorkerId::new(1);

    scheduler.register_worker(transient);
    assert_eq!(scheduler.stats().worker_count, 1);

    scheduler.unregister_worker(transient);
    assert_eq!(scheduler.stats().worker_count, 0);
}
2116
/// Submitting with no registered workers still returns id 0 and counts
/// the task as submitted.
#[test]
fn test_scheduler_submit_no_workers() {
    let scheduler = WorkStealingScheduler::new(Arc::new(BrickDataTracker::new()));

    let assigned_id = scheduler.submit(
        TaskSpec {
            brick_name: String::from("Test"),
            backend: Backend::Cpu,
            data_dependencies: Vec::new(),
            preferred_worker: None,
        },
        String::from("input"),
    );

    assert_eq!(assigned_id, 0);
    assert_eq!(scheduler.stats().total_submitted, 1);
}
2134
/// A task submitted through `submit_priority` gets id 0 and carries the
/// requested priority when the worker picks it up.
#[test]
fn test_scheduler_submit_priority() {
    let scheduler = WorkStealingScheduler::new(Arc::new(BrickDataTracker::new()));
    let worker = WorkerId::new(1);
    scheduler.register_worker(worker);

    let urgent = TaskSpec {
        brick_name: String::from("Test"),
        backend: Backend::Cpu,
        data_dependencies: Vec::new(),
        preferred_worker: None,
    };
    let assigned_id = scheduler.submit_priority(urgent, "input".into(), 100);
    assert_eq!(assigned_id, 0);

    let received = scheduler.get_work(worker);
    assert!(received.is_some());
    assert_eq!(received.unwrap().priority, 100);
}
2156
/// Asking for work on behalf of a worker that was never registered
/// yields `None`.
#[test]
fn test_scheduler_get_work_unregistered_worker() {
    let scheduler = WorkStealingScheduler::new(Arc::new(BrickDataTracker::new()));

    let stranger = WorkerId::new(999);
    let received = scheduler.get_work(stranger);
    assert!(received.is_none());
}
2166
/// Smoke test: `data_tracker()` can be called on a scheduler built from
/// a shared tracker. The returned value is discarded.
#[test]
fn test_scheduler_data_tracker_accessor() {
    let shared = Arc::new(BrickDataTracker::new());
    let scheduler = WorkStealingScheduler::new(Arc::clone(&shared));

    let _ = scheduler.data_tracker();
}
2174
/// Every `WorkerStats` field reads back the value it was initialised with.
#[test]
fn test_worker_stats_fields() {
    let owner = WorkerId::new(1);
    let report = WorkerStats {
        worker_id: owner,
        queue_length: 5,
        completed: 10,
        stolen_from: 2,
    };

    assert_eq!(report.worker_id, owner);
    assert_eq!(report.queue_length, 5);
    assert_eq!(report.completed, 10);
    assert_eq!(report.stolen_from, 2);
}
2188
/// The numeric `SchedulerStats` fields read back the values they were
/// initialised with.
#[test]
fn test_scheduler_stats_fields() {
    let report = SchedulerStats {
        worker_count: 2,
        total_submitted: 10,
        total_pending: 5,
        total_completed: 4,
        total_stolen: 1,
        workers: Vec::new(),
    };

    assert_eq!(report.worker_count, 2);
    assert_eq!(report.total_submitted, 10);
    assert_eq!(report.total_pending, 5);
    assert_eq!(report.total_completed, 4);
    assert_eq!(report.total_stolen, 1);
}
2205
/// Cloning a `DataLocation` preserves its key.
#[test]
fn test_data_location_clone() {
    let source = DataLocation {
        key: "test".into(),
        workers: vec![WorkerId::new(1)],
        size_bytes: 100,
        last_access: Instant::now(),
    };

    let duplicate = source.clone();
    assert_eq!(source.key, duplicate.key);
}
2217
/// Tracking the same key for the same worker twice (sizes 100, then 200)
/// still lists that worker only once.
#[test]
fn test_track_data_updates_existing() {
    let tracker = BrickDataTracker::new();
    let holder = WorkerId::new(1);

    tracker.track_data("key", holder, 100);
    tracker.track_data("key", holder, 200);

    let holders = tracker.get_workers_for_data("key");
    assert_eq!(holders.len(), 1);
}
2227}