// jugar_probar/brick/distributed.rs

//! DistributedBrick: Work-stealing and data locality (PROBAR-SPEC-009-P10)
//!
//! This module enables distributed brick execution with:
//! - Work-stealing across nodes
//! - Data locality awareness
//! - Multi-backend dispatch (CPU/GPU/Remote/SIMD)
//!
//! # Architecture
//!
//! ```text
//! ┌──────────────────────────────────────────────────────────────┐
//! │                    DISTRIBUTED BRICK FLOW                     │
//! ├──────────────────────────────────────────────────────────────┤
//! │                                                               │
//! │  1. DistributedBrick<B> wraps any Brick                       │
//! │  2. BrickDataTracker tracks data locality                     │
//! │  3. MultiBrickExecutor selects best backend                   │
//! │  4. BrickCoordinator handles PUB/SUB coordination             │
//! │                                                               │
//! └──────────────────────────────────────────────────────────────┘
//! ```
//!
//! # References
//!
//! - PROBAR-SPEC-009-P10: Distribution - Repartir Integration
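//!
//! # Example
//!
//! A minimal sketch of the flow, marked `ignore` because the crate path and
//! the concrete `MyBrick` type (any `Brick` implementation) are assumptions,
//! not part of this module:
//!
//! ```ignore
//! use std::sync::Arc;
//!
//! // Wrap a brick and declare what data it needs.
//! let brick = DistributedBrick::new(MyBrick::default())
//!     .with_backend(Backend::Simd)
//!     .with_data_dependencies(vec!["model_weights".into()]);
//!
//! // Record where that data lives.
//! let tracker = Arc::new(BrickDataTracker::new());
//! tracker.track_data("model_weights", WorkerId::new(1), 1024);
//!
//! // Execute; the metrics report the locality-preferred worker.
//! let executor = MultiBrickExecutor::new(Arc::clone(&tracker));
//! let output = executor
//!     .execute_distributed(&brick, BrickInput::new(vec![1.0], vec![1]))
//!     .expect("execution failed");
//! assert_eq!(output.metrics.worker_id, Some(WorkerId::new(1)));
//! ```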

// Allow expect for RwLock - lock poisoning is truly exceptional
#![allow(clippy::expect_used)]

use std::collections::HashMap;
use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};

use super::{Brick, BrickAssertion, BrickBudget, BrickError, BrickResult, BrickVerification};

/// Unique identifier for a worker node
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WorkerId(pub u64);

impl WorkerId {
    /// Create a new worker ID
    #[must_use]
    pub const fn new(id: u64) -> Self {
        Self(id)
    }

    /// Get the underlying ID value
    #[must_use]
    pub const fn value(&self) -> u64 {
        self.0
    }
}

impl fmt::Display for WorkerId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "worker-{}", self.0)
    }
}

/// Execution backend for brick operations
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Backend {
    /// CPU execution with standard instructions
    Cpu,
    /// GPU execution via WebGPU/wgpu
    Gpu,
    /// Remote execution on another node
    Remote,
    /// CPU execution with SIMD acceleration
    Simd,
}

impl Backend {
    /// Check if backend is available on current system
    #[must_use]
    pub fn is_available(&self) -> bool {
        match self {
            Self::Cpu | Self::Simd => true,
            Self::Gpu => cfg!(feature = "gpu"),
            // Remote backend requires distributed feature (not yet implemented)
            Self::Remote => false,
        }
    }

    /// Get relative performance estimate (higher = faster)
    #[must_use]
    pub const fn performance_estimate(&self) -> u32 {
        match self {
            Self::Gpu => 100,
            Self::Simd => 50,
            Self::Cpu => 10,
            Self::Remote => 5, // Network latency
        }
    }
}

impl Default for Backend {
    fn default() -> Self {
        Self::Cpu
    }
}

/// Input data for brick execution
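///
/// # Example
///
/// A small sketch (marked `ignore` since the crate path is not shown here):
///
/// ```ignore
/// let input = BrickInput::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])
///     .with_metadata("source", "unit-test");
/// assert_eq!(input.element_count(), 4);
/// assert_eq!(input.size_bytes(), 16); // 4 elements * 4 bytes per f32
/// ```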
#[derive(Debug, Clone, Default)]
pub struct BrickInput {
    /// Input tensor data
    pub data: Vec<f32>,
    /// Input shape dimensions
    pub shape: Vec<usize>,
    /// Additional metadata
    pub metadata: HashMap<String, String>,
}

impl BrickInput {
    /// Create new brick input
    #[must_use]
    pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
        Self {
            data,
            shape,
            metadata: HashMap::new(),
        }
    }

    /// Get total size in bytes
    #[must_use]
    pub fn size_bytes(&self) -> usize {
        self.data.len() * std::mem::size_of::<f32>()
    }

    /// Get total element count
    #[must_use]
    pub fn element_count(&self) -> usize {
        self.data.len()
    }

    /// Add metadata
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }
}

/// Output data from brick execution
#[derive(Debug, Clone, Default)]
pub struct BrickOutput {
    /// Output tensor data
    pub data: Vec<f32>,
    /// Output shape dimensions
    pub shape: Vec<usize>,
    /// Execution metrics
    pub metrics: ExecutionMetrics,
}

impl BrickOutput {
    /// Create new brick output
    #[must_use]
    pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
        Self {
            data,
            shape,
            metrics: ExecutionMetrics::default(),
        }
    }

    /// Get total size in bytes
    #[must_use]
    pub fn size_bytes(&self) -> usize {
        self.data.len() * std::mem::size_of::<f32>()
    }
}

/// Metrics from brick execution
#[derive(Debug, Clone, Default)]
pub struct ExecutionMetrics {
    /// Time to execute
    pub execution_time: Duration,
    /// Backend used
    pub backend: Backend,
    /// Worker that executed
    pub worker_id: Option<WorkerId>,
    /// Data transfer time (if remote)
    pub transfer_time: Option<Duration>,
}

impl ExecutionMetrics {
    /// Create new execution metrics
    #[must_use]
    pub fn new(execution_time: Duration, backend: Backend) -> Self {
        Self {
            execution_time,
            backend,
            worker_id: None,
            transfer_time: None,
        }
    }
}

/// Distributed brick wrapper for multi-backend execution
///
/// Wraps any `Brick` to enable distributed execution with:
/// - Backend selection (CPU/GPU/Remote/SIMD)
/// - Data dependency tracking for locality
/// - Work-stealing support
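///
/// # Example
///
/// A minimal sketch of the builder API (marked `ignore`: `MyBrick` is a
/// hypothetical `Brick` implementation, not part of this module):
///
/// ```ignore
/// let brick = DistributedBrick::new(MyBrick::default())
///     .with_backend(Backend::Gpu)
///     .with_data_dependencies(vec!["weights".into()])
///     .with_preferred_worker(WorkerId::new(3));
///
/// assert_eq!(brick.backend(), Backend::Gpu);
/// let spec = brick.to_task_spec(); // hand this to a scheduler
/// assert_eq!(spec.preferred_worker, Some(WorkerId::new(3)));
/// ```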
#[derive(Debug)]
pub struct DistributedBrick<B: Brick> {
    inner: B,
    backend: Backend,
    data_dependencies: Vec<String>,
    preferred_worker: Option<WorkerId>,
}

impl<B: Brick> DistributedBrick<B> {
    /// Create a new distributed brick wrapper
    #[must_use]
    pub fn new(inner: B) -> Self {
        Self {
            inner,
            backend: Backend::default(),
            data_dependencies: Vec::new(),
            preferred_worker: None,
        }
    }

    /// Set the preferred execution backend
    #[must_use]
    pub fn with_backend(mut self, backend: Backend) -> Self {
        self.backend = backend;
        self
    }

    /// Add data dependencies for locality-aware scheduling
    #[must_use]
    pub fn with_data_dependencies(mut self, deps: Vec<String>) -> Self {
        self.data_dependencies = deps;
        self
    }

    /// Set preferred worker for execution
    #[must_use]
    pub fn with_preferred_worker(mut self, worker: WorkerId) -> Self {
        self.preferred_worker = Some(worker);
        self
    }

    /// Get the inner brick
    #[must_use]
    pub fn inner(&self) -> &B {
        &self.inner
    }

    /// Get mutable access to inner brick
    pub fn inner_mut(&mut self) -> &mut B {
        &mut self.inner
    }

    /// Get current backend
    #[must_use]
    pub fn backend(&self) -> Backend {
        self.backend
    }

    /// Get data dependencies
    #[must_use]
    pub fn data_dependencies(&self) -> &[String] {
        &self.data_dependencies
    }

    /// Get preferred worker
    #[must_use]
    pub fn preferred_worker(&self) -> Option<WorkerId> {
        self.preferred_worker
    }

    /// Convert to a task specification for distributed execution
    #[must_use]
    pub fn to_task_spec(&self) -> TaskSpec {
        TaskSpec {
            brick_name: self.inner.brick_name().to_string(),
            backend: self.backend,
            data_dependencies: self.data_dependencies.clone(),
            preferred_worker: self.preferred_worker,
        }
    }
}

impl<B: Brick> Brick for DistributedBrick<B> {
    fn brick_name(&self) -> &'static str {
        self.inner.brick_name()
    }

    fn assertions(&self) -> &[BrickAssertion] {
        self.inner.assertions()
    }

    fn budget(&self) -> BrickBudget {
        self.inner.budget()
    }

    fn verify(&self) -> BrickVerification {
        self.inner.verify()
    }

    fn to_html(&self) -> String {
        self.inner.to_html()
    }

    fn to_css(&self) -> String {
        self.inner.to_css()
    }
}

/// Task specification for distributed execution
#[derive(Debug, Clone)]
pub struct TaskSpec {
    /// Brick name for identification
    pub brick_name: String,
    /// Requested backend
    pub backend: Backend,
    /// Data dependencies
    pub data_dependencies: Vec<String>,
    /// Preferred worker
    pub preferred_worker: Option<WorkerId>,
}

/// Data location entry for a specific piece of data
#[derive(Debug, Clone)]
pub struct DataLocation {
    /// Data key/identifier
    pub key: String,
    /// Workers that have this data
    pub workers: Vec<WorkerId>,
    /// Size of data in bytes
    pub size_bytes: usize,
    /// Last access time
    pub last_access: Instant,
}

/// Track where brick weights/data reside across workers
///
/// Used for locality-aware scheduling to minimize data movement.
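///
/// # Example
///
/// A short sketch of locality tracking (marked `ignore` since the crate
/// path is not shown here):
///
/// ```ignore
/// let tracker = BrickDataTracker::new();
/// tracker.track_data("model_weights", WorkerId::new(1), 1024);
/// tracker.track_data("model_weights", WorkerId::new(2), 1024);
///
/// // Both workers hold the data, so the affinity map scores both of them.
/// let affinity = tracker.calculate_affinity(&["model_weights".into()]);
/// assert_eq!(affinity.len(), 2);
/// ```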
#[derive(Debug)]
pub struct BrickDataTracker {
    /// Map from data key to location info
    locations: RwLock<HashMap<String, DataLocation>>,
}

impl Default for BrickDataTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl BrickDataTracker {
    /// Create a new data tracker
    #[must_use]
    pub fn new() -> Self {
        Self {
            locations: RwLock::new(HashMap::new()),
        }
    }

    /// Register that a worker has certain data
    pub fn track_data(&self, key: &str, worker_id: WorkerId, size_bytes: usize) {
        let mut locations = self.locations.write().expect("lock poisoned");
        locations
            .entry(key.to_string())
            .and_modify(|loc| {
                if !loc.workers.contains(&worker_id) {
                    loc.workers.push(worker_id);
                }
                loc.last_access = Instant::now();
            })
            .or_insert_with(|| DataLocation {
                key: key.to_string(),
                workers: vec![worker_id],
                size_bytes,
                last_access: Instant::now(),
            });
    }

    /// Register that a worker has brick weights
    pub fn track_weights(&self, brick_name: &str, worker_id: WorkerId) {
        let key = format!("{}_weights", brick_name);
        self.track_data(&key, worker_id, 0);
    }

    /// Remove data location from a worker
    pub fn remove_data(&self, key: &str, worker_id: WorkerId) {
        let mut locations = self.locations.write().expect("lock poisoned");
        if let Some(loc) = locations.get_mut(key) {
            loc.workers.retain(|w| *w != worker_id);
        }
    }

    /// Get workers that have specific data
    #[must_use]
    pub fn get_workers_for_data(&self, key: &str) -> Vec<WorkerId> {
        let locations = self.locations.read().expect("lock poisoned");
        locations
            .get(key)
            .map_or(Vec::new(), |loc| loc.workers.clone())
    }

    /// Calculate affinity scores for workers based on data dependencies
    pub fn calculate_affinity(&self, dependencies: &[String]) -> HashMap<WorkerId, f64> {
        let locations = self.locations.read().expect("lock poisoned");
        let mut affinity: HashMap<WorkerId, f64> = HashMap::new();

        for dep in dependencies {
            if let Some(loc) = locations.get(dep) {
                let score_per_worker = 1.0 / loc.workers.len() as f64;
                for worker in &loc.workers {
                    *affinity.entry(*worker).or_insert(0.0) += score_per_worker;
                }
            }
        }

        // Normalize scores
        if !affinity.is_empty() {
            let max_score = affinity.values().cloned().fold(0.0_f64, f64::max);
            if max_score > 0.0 {
                for score in affinity.values_mut() {
                    *score /= max_score;
                }
            }
        }

        affinity
    }

    /// Find the best worker for a brick based on data locality
    #[must_use]
    pub fn find_best_worker(&self, brick: &dyn Brick) -> Option<WorkerId> {
        // Use brick name to find weights
        let weights_key = format!("{}_weights", brick.brick_name());
        let workers = self.get_workers_for_data(&weights_key);
        workers.first().copied()
    }

    /// Find best worker for distributed brick with dependencies
    #[must_use]
    pub fn find_best_worker_for_distributed<B: Brick>(
        &self,
        brick: &DistributedBrick<B>,
    ) -> Option<WorkerId> {
        // Check preferred worker first
        if let Some(preferred) = brick.preferred_worker() {
            return Some(preferred);
        }

        // Calculate affinity based on dependencies
        let affinity = self.calculate_affinity(brick.data_dependencies());
        affinity
            .into_iter()
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(worker, _)| worker)
    }

    /// Get total data size tracked
    #[must_use]
    pub fn total_data_size(&self) -> usize {
        let locations = self.locations.read().expect("lock poisoned");
        locations.values().map(|loc| loc.size_bytes).sum()
    }
}

/// Backend selector for choosing optimal execution backend
#[derive(Debug)]
pub struct BackendSelector {
    /// Minimum element count for GPU execution
    gpu_threshold: usize,
    /// Minimum element count for SIMD execution
    simd_threshold: usize,
    /// Maximum element count for CPU (else remote)
    cpu_max_threshold: usize,
}

impl Default for BackendSelector {
    fn default() -> Self {
        Self::new()
    }
}

impl BackendSelector {
    /// Create a new backend selector with default thresholds
    #[must_use]
    pub fn new() -> Self {
        Self {
            gpu_threshold: 1_000_000,       // 1M elements for GPU
            simd_threshold: 10_000,         // 10K elements for SIMD
            cpu_max_threshold: 100_000_000, // 100M elements max for local
        }
    }

    /// Configure GPU threshold
    #[must_use]
    pub fn with_gpu_threshold(mut self, threshold: usize) -> Self {
        self.gpu_threshold = threshold;
        self
    }

    /// Configure SIMD threshold
    #[must_use]
    pub fn with_simd_threshold(mut self, threshold: usize) -> Self {
        self.simd_threshold = threshold;
        self
    }

    /// Configure CPU max threshold
    #[must_use]
    pub fn with_cpu_max_threshold(mut self, threshold: usize) -> Self {
        self.cpu_max_threshold = threshold;
        self
    }

    /// Select best backend based on input characteristics
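    ///
    /// # Example
    ///
    /// A sketch with the default thresholds (marked `ignore` since the
    /// crate path is not shown here):
    ///
    /// ```ignore
    /// let selector = BackendSelector::new();
    /// assert_eq!(selector.select(100, false), Backend::Cpu);        // small
    /// assert_eq!(selector.select(50_000, false), Backend::Simd);    // medium
    /// assert_eq!(selector.select(2_000_000, true), Backend::Gpu);   // large, GPU present
    /// assert_eq!(selector.select(2_000_000, false), Backend::Simd); // large, no GPU
    /// ```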
    #[must_use]
    pub fn select(&self, element_count: usize, gpu_available: bool) -> Backend {
        // Too large for local - use remote if available (not yet implemented)
        if element_count > self.cpu_max_threshold && Backend::Remote.is_available() {
            return Backend::Remote;
        }

        // Large enough for GPU
        if element_count >= self.gpu_threshold && gpu_available {
            return Backend::Gpu;
        }

        // Medium size - use SIMD
        if element_count >= self.simd_threshold {
            return Backend::Simd;
        }

        // Small - use plain CPU
        Backend::Cpu
    }

    /// Select backend for a brick with input
    #[must_use]
    pub fn select_for_brick(
        &self,
        _brick_complexity: u32,
        input_size: usize,
        gpu_available: bool,
    ) -> Backend {
        // Future: factor in brick_complexity
        self.select(input_size, gpu_available)
    }
}

/// Multi-backend executor for brick operations
///
/// Dispatches brick execution to the best available backend.
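///
/// # Example
///
/// A minimal sketch (marked `ignore`: `MyBrick` is a hypothetical `Brick`
/// implementation, not part of this module):
///
/// ```ignore
/// let tracker = Arc::new(BrickDataTracker::new());
/// let executor = MultiBrickExecutor::new(tracker).with_gpu_available(false);
///
/// let input = BrickInput::new(vec![0.0; 32], vec![32]);
/// let output = executor
///     .execute(&MyBrick::default(), input)
///     .expect("execution failed");
/// assert_eq!(output.metrics.backend, Backend::Cpu); // 32 elements stays on plain CPU
/// ```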
#[derive(Debug)]
pub struct MultiBrickExecutor {
    selector: BackendSelector,
    gpu_available: bool,
    data_tracker: Arc<BrickDataTracker>,
}

impl MultiBrickExecutor {
    /// Create a new multi-backend executor
    #[must_use]
    pub fn new(data_tracker: Arc<BrickDataTracker>) -> Self {
        Self {
            selector: BackendSelector::new(),
            gpu_available: cfg!(feature = "gpu"),
            data_tracker,
        }
    }

    /// Create with custom backend selector
    #[must_use]
    pub fn with_selector(mut self, selector: BackendSelector) -> Self {
        self.selector = selector;
        self
    }

    /// Set GPU availability
    #[must_use]
    pub fn with_gpu_available(mut self, available: bool) -> Self {
        self.gpu_available = available;
        self
    }

    /// Execute a brick on the best backend
    pub fn execute(&self, brick: &dyn Brick, input: BrickInput) -> BrickResult<BrickOutput> {
        let start = Instant::now();

        // Select backend
        let backend = self
            .selector
            .select(input.element_count(), self.gpu_available);

        // Execute on selected backend
        let (output_data, output_shape) = match backend {
            Backend::Cpu => self.execute_cpu(brick, &input)?,
            Backend::Simd => self.execute_simd(brick, &input)?,
            Backend::Gpu => self.execute_gpu(brick, &input)?,
            Backend::Remote => self.execute_remote(brick, &input)?,
        };

        let execution_time = start.elapsed();

        // Build output with metrics
        let mut output = BrickOutput::new(output_data, output_shape);
        output.metrics = ExecutionMetrics::new(execution_time, backend);

        Ok(output)
    }

    /// Execute distributed brick
    pub fn execute_distributed<B: Brick>(
        &self,
        brick: &DistributedBrick<B>,
        input: BrickInput,
    ) -> BrickResult<BrickOutput> {
        let start = Instant::now();

        // Use brick's preferred backend or select automatically
        let backend = brick.backend();

        // Find best worker for locality
        let worker_id = self.data_tracker.find_best_worker_for_distributed(brick);

        // Execute
        let (output_data, output_shape) = match backend {
            Backend::Cpu => self.execute_cpu(brick.inner(), &input)?,
            Backend::Simd => self.execute_simd(brick.inner(), &input)?,
            Backend::Gpu => self.execute_gpu(brick.inner(), &input)?,
            Backend::Remote => self.execute_remote(brick.inner(), &input)?,
        };

        let execution_time = start.elapsed();

        // Build output with metrics
        let mut output = BrickOutput::new(output_data, output_shape);
        output.metrics = ExecutionMetrics {
            execution_time,
            backend,
            worker_id,
            transfer_time: None,
        };

        Ok(output)
    }

    fn execute_cpu(
        &self,
        _brick: &dyn Brick,
        input: &BrickInput,
    ) -> BrickResult<(Vec<f32>, Vec<usize>)> {
        // Simple passthrough for now - real implementation would execute brick
        Ok((input.data.clone(), input.shape.clone()))
    }

    fn execute_simd(
        &self,
        _brick: &dyn Brick,
        input: &BrickInput,
    ) -> BrickResult<(Vec<f32>, Vec<usize>)> {
        // SIMD path - would use actual SIMD instructions
        Ok((input.data.clone(), input.shape.clone()))
    }

    fn execute_gpu(
        &self,
        _brick: &dyn Brick,
        input: &BrickInput,
    ) -> BrickResult<(Vec<f32>, Vec<usize>)> {
        // GPU path - would use WebGPU/wgpu
        if !self.gpu_available {
            return Err(BrickError::HtmlGenerationFailed {
                reason: "GPU not available".into(),
            });
        }
        Ok((input.data.clone(), input.shape.clone()))
    }

    fn execute_remote(
        &self,
        _brick: &dyn Brick,
        input: &BrickInput,
    ) -> BrickResult<(Vec<f32>, Vec<usize>)> {
        // Remote path - would serialize and send to remote worker
        if !Backend::Remote.is_available() {
            return Err(BrickError::HtmlGenerationFailed {
                reason: "Distributed execution not available".into(),
            });
        }
        Ok((input.data.clone(), input.shape.clone()))
    }

    /// Get the data tracker
    #[must_use]
    pub fn data_tracker(&self) -> &Arc<BrickDataTracker> {
        &self.data_tracker
    }
}

/// Message for PUB/SUB coordination
#[derive(Debug, Clone)]
pub enum BrickMessage {
    /// Weight update message
    WeightUpdate {
        /// Name of the brick whose weights are being updated
        brick_name: String,
        /// Serialized weight data
        weights: Vec<u8>,
        /// Weight version number
        version: u64,
    },
    /// State change notification
    StateChange {
        /// Name of the brick that changed state
        brick_name: String,
        /// Event description
        event: String,
    },
    /// Request brick execution
    ExecutionRequest {
        /// Name of brick to execute
        brick_name: String,
        /// Key to input data
        input_key: String,
        /// Unique request ID for correlation
        request_id: u64,
    },
    /// Execution result
    ExecutionResult {
        /// Request ID this result corresponds to
        request_id: u64,
        /// Key to output data
        output_key: String,
        /// Whether execution succeeded
        success: bool,
    },
}

/// Subscription to brick events
#[derive(Debug)]
pub struct Subscription {
    topic: String,
    messages: Arc<RwLock<Vec<BrickMessage>>>,
}

impl Subscription {
    /// Get all pending messages
    #[must_use]
    pub fn drain(&self) -> Vec<BrickMessage> {
        let mut messages = self.messages.write().expect("lock poisoned");
        std::mem::take(&mut *messages)
    }

    /// Check if there are pending messages
    #[must_use]
    pub fn has_messages(&self) -> bool {
        let messages = self.messages.read().expect("lock poisoned");
        !messages.is_empty()
    }

    /// Get subscription topic
    #[must_use]
    pub fn topic(&self) -> &str {
        &self.topic
    }
}

// ============================================================================
// Work-Stealing Scheduler (Phase 10e)
// ============================================================================

/// A task that can be executed by workers and potentially stolen
#[derive(Debug, Clone)]
pub struct WorkStealingTask {
    /// Unique task ID
    pub id: u64,
    /// Task specification
    pub spec: TaskSpec,
    /// Input data key
    pub input_key: String,
    /// Priority (higher = more urgent)
    pub priority: u32,
    /// Creation time
    pub created_at: Instant,
}

impl WorkStealingTask {
    /// Create a new work-stealing task
    #[must_use]
    pub fn new(id: u64, spec: TaskSpec, input_key: String) -> Self {
        Self {
            id,
            spec,
            input_key,
            priority: 0,
            created_at: Instant::now(),
        }
    }

    /// Set task priority
    #[must_use]
    pub fn with_priority(mut self, priority: u32) -> Self {
        self.priority = priority;
        self
    }

    /// Get task age
    #[must_use]
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }
}

/// Per-worker task queue supporting work-stealing
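///
/// # Example
///
/// A sketch of push/pop/steal ordering (marked `ignore` since the crate
/// path is not shown here):
///
/// ```ignore
/// let queue = WorkerQueue::new(WorkerId::new(1));
/// for priority in 0..3 {
///     let spec = TaskSpec {
///         brick_name: "Demo".into(),
///         backend: Backend::Cpu,
///         data_dependencies: vec![],
///         preferred_worker: None,
///     };
///     let task = WorkStealingTask::new(priority as u64, spec, "key".into());
///     queue.push(task.with_priority(priority));
/// }
///
/// assert_eq!(queue.pop().unwrap().priority, 2);   // owner takes the most urgent task
/// assert_eq!(queue.steal().unwrap().priority, 0); // thieves take the least urgent task
/// ```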
#[derive(Debug)]
pub struct WorkerQueue {
    /// Worker ID
    worker_id: WorkerId,
    /// Local task queue (owned tasks)
    local_queue: RwLock<Vec<WorkStealingTask>>,
    /// Number of tasks completed
    completed_count: AtomicU64,
    /// Number of tasks stolen from this queue
    stolen_count: AtomicU64,
}

impl WorkerQueue {
    /// Create a new worker queue
    #[must_use]
    pub fn new(worker_id: WorkerId) -> Self {
        Self {
            worker_id,
            local_queue: RwLock::new(Vec::new()),
            completed_count: AtomicU64::new(0),
            stolen_count: AtomicU64::new(0),
        }
    }

    /// Push a task to the local queue
    pub fn push(&self, task: WorkStealingTask) {
        let mut queue = self.local_queue.write().expect("lock poisoned");
        queue.push(task);
        // Sort by priority (higher first)
        queue.sort_by(|a, b| b.priority.cmp(&a.priority));
    }

    /// Pop a task from the local queue (highest priority first)
    pub fn pop(&self) -> Option<WorkStealingTask> {
        let mut queue = self.local_queue.write().expect("lock poisoned");
        if queue.is_empty() {
            return None;
        }
        Some(queue.remove(0)) // Get highest priority (front after sort)
    }

    /// Steal a task from this queue (lowest priority - be nice to owner)
    pub fn steal(&self) -> Option<WorkStealingTask> {
        let mut queue = self.local_queue.write().expect("lock poisoned");
        if queue.is_empty() {
            return None;
        }
        self.stolen_count.fetch_add(1, Ordering::Relaxed);
        queue.pop() // Steal lowest priority (back after sort)
    }

    /// Check if queue is empty
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let queue = self.local_queue.read().expect("lock poisoned");
        queue.is_empty()
    }

    /// Get queue length
    #[must_use]
    pub fn len(&self) -> usize {
        let queue = self.local_queue.read().expect("lock poisoned");
        queue.len()
    }

    /// Mark a task as completed
    pub fn mark_completed(&self) {
        self.completed_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Get worker ID
    #[must_use]
    pub fn worker_id(&self) -> WorkerId {
        self.worker_id
    }

    /// Get completed count
    #[must_use]
    pub fn completed_count(&self) -> u64 {
        self.completed_count.load(Ordering::Relaxed)
    }

    /// Get stolen count
    #[must_use]
    pub fn stolen_count(&self) -> u64 {
        self.stolen_count.load(Ordering::Relaxed)
    }
}

/// Work-stealing scheduler for distributed brick execution
///
/// Implements a work-stealing algorithm where idle workers steal tasks
/// from busy workers' queues. This provides automatic load balancing.
///
/// # Algorithm
///
/// 1. Each worker owns a local priority queue of tasks
/// 2. Owners pop their own highest-priority task first
/// 3. Idle workers steal the lowest-priority task from the busiest other queue
/// 4. Task placement considers data locality via `BrickDataTracker`
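///
/// # Example
///
/// A minimal sketch of registering workers and stealing (marked `ignore`
/// since the crate path is not shown here):
///
/// ```ignore
/// let scheduler = WorkStealingScheduler::new(Arc::new(BrickDataTracker::new()));
/// scheduler.register_worker(WorkerId::new(1));
/// scheduler.register_worker(WorkerId::new(2));
///
/// let spec = TaskSpec {
///     brick_name: "Demo".into(),
///     backend: Backend::Cpu,
///     data_dependencies: vec![],
///     preferred_worker: Some(WorkerId::new(1)),
/// };
/// scheduler.submit(spec, "input".into());
///
/// // Worker 2 is idle, so it steals the task queued for worker 1.
/// assert!(scheduler.get_work(WorkerId::new(2)).is_some());
/// ```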
#[derive(Debug)]
pub struct WorkStealingScheduler {
    /// Worker queues indexed by worker ID
    queues: RwLock<HashMap<WorkerId, Arc<WorkerQueue>>>,
    /// Data tracker for locality-aware stealing
    data_tracker: Arc<BrickDataTracker>,
    /// Task ID counter
    task_counter: AtomicU64,
    /// Total tasks submitted
    submitted_count: AtomicU64,
}

impl WorkStealingScheduler {
    /// Create a new work-stealing scheduler
    #[must_use]
    pub fn new(data_tracker: Arc<BrickDataTracker>) -> Self {
        Self {
            queues: RwLock::new(HashMap::new()),
            data_tracker,
            task_counter: AtomicU64::new(0),
            submitted_count: AtomicU64::new(0),
        }
    }

    /// Register a worker with the scheduler
    pub fn register_worker(&self, worker_id: WorkerId) -> Arc<WorkerQueue> {
        let queue = Arc::new(WorkerQueue::new(worker_id));
        let mut queues = self.queues.write().expect("lock poisoned");
        queues.insert(worker_id, Arc::clone(&queue));
        queue
    }

    /// Unregister a worker
    pub fn unregister_worker(&self, worker_id: WorkerId) {
        let mut queues = self.queues.write().expect("lock poisoned");
        queues.remove(&worker_id);
    }

    /// Submit a task to the best worker based on locality
    pub fn submit(&self, spec: TaskSpec, input_key: String) -> u64 {
        let task_id = self.task_counter.fetch_add(1, Ordering::SeqCst);
        let task = WorkStealingTask::new(task_id, spec.clone(), input_key);

        // Find best worker based on data locality
        let target_worker = self.find_best_worker_for_task(&spec);

        let queues = self.queues.read().expect("lock poisoned");
        if let Some(queue) = target_worker.and_then(|w| queues.get(&w)) {
            queue.push(task);
        } else if let Some((_, queue)) = queues.iter().next() {
            // Fallback to first available worker
            queue.push(task);
        }

        self.submitted_count.fetch_add(1, Ordering::Relaxed);
        task_id
    }

    /// Submit with explicit priority
    pub fn submit_priority(&self, spec: TaskSpec, input_key: String, priority: u32) -> u64 {
        let task_id = self.task_counter.fetch_add(1, Ordering::SeqCst);
        let task = WorkStealingTask::new(task_id, spec.clone(), input_key).with_priority(priority);

        let target_worker = self.find_best_worker_for_task(&spec);

        let queues = self.queues.read().expect("lock poisoned");
        if let Some(queue) = target_worker.and_then(|w| queues.get(&w)) {
            queue.push(task);
        } else if let Some((_, queue)) = queues.iter().next() {
            queue.push(task);
        }

        self.submitted_count.fetch_add(1, Ordering::Relaxed);
        task_id
    }

    /// Try to get work for a worker (local pop or steal)
    pub fn get_work(&self, worker_id: WorkerId) -> Option<WorkStealingTask> {
        let queues = self.queues.read().expect("lock poisoned");

        // First try local queue
        if let Some(queue) = queues.get(&worker_id) {
            if let Some(task) = queue.pop() {
                return Some(task);
            }
        }

        // Try to steal from other workers
        self.try_steal(worker_id, &queues)
    }

    /// Try to steal work from another worker's queue
    fn try_steal(
        &self,
        stealer_id: WorkerId,
        queues: &HashMap<WorkerId, Arc<WorkerQueue>>,
    ) -> Option<WorkStealingTask> {
        // Find other workers' queues that currently have work
        let mut candidates: Vec<_> = queues
            .iter()
            .filter(|(id, q)| **id != stealer_id && !q.is_empty())
            .collect();

        if candidates.is_empty() {
            return None;
        }

        // Sort by queue length (steal from busiest)
        candidates.sort_by(|a, b| b.1.len().cmp(&a.1.len()));

        // Try to steal from the busiest queue
        for (_, queue) in candidates {
            if let Some(task) = queue.steal() {
                return Some(task);
            }
        }

        None
    }

    /// Find best worker for a task based on data locality
    fn find_best_worker_for_task(&self, spec: &TaskSpec) -> Option<WorkerId> {
        // Check preferred worker
        if let Some(preferred) = spec.preferred_worker {
            return Some(preferred);
        }

        // Calculate affinity based on data dependencies
        let affinity = self
            .data_tracker
            .calculate_affinity(&spec.data_dependencies);
        affinity
            .into_iter()
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(worker, _)| worker)
    }

    /// Get scheduler statistics
    #[must_use]
    pub fn stats(&self) -> SchedulerStats {
        let queues = self.queues.read().expect("lock poisoned");

        let worker_stats: Vec<_> = queues
            .values()
            .map(|q| WorkerStats {
                worker_id: q.worker_id(),
                queue_length: q.len(),
                completed: q.completed_count(),
                stolen_from: q.stolen_count(),
            })
            .collect();

        let total_pending: usize = worker_stats.iter().map(|s| s.queue_length).sum();
        let total_completed: u64 = worker_stats.iter().map(|s| s.completed).sum();
        let total_stolen: u64 = worker_stats.iter().map(|s| s.stolen_from).sum();

        SchedulerStats {
            worker_count: queues.len(),
            total_submitted: self.submitted_count.load(Ordering::Relaxed),
            total_pending,
            total_completed,
            total_stolen,
            workers: worker_stats,
        }
    }

    /// Get the data tracker
    #[must_use]
    pub fn data_tracker(&self) -> &Arc<BrickDataTracker> {
        &self.data_tracker
    }
}

/// Statistics for a single worker
#[derive(Debug, Clone)]
pub struct WorkerStats {
    /// Worker ID
    pub worker_id: WorkerId,
    /// Current queue length
    pub queue_length: usize,
    /// Tasks completed
    pub completed: u64,
    /// Tasks stolen from this worker
    pub stolen_from: u64,
}

/// Scheduler-wide statistics
#[derive(Debug, Clone)]
pub struct SchedulerStats {
    /// Number of registered workers
    pub worker_count: usize,
    /// Total tasks submitted
    pub total_submitted: u64,
    /// Total tasks pending across all queues
    pub total_pending: usize,
    /// Total tasks completed
    pub total_completed: u64,
    /// Total tasks stolen (indicates load balancing activity)
    pub total_stolen: u64,
    /// Per-worker statistics
    pub workers: Vec<WorkerStats>,
}

// ============================================================================
// PUB/SUB Coordinator
// ============================================================================

/// PUB/SUB coordinator for brick communication
///
/// Enables distributed coordination via publish/subscribe messaging.
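///
/// # Example
///
/// A short sketch of the in-process PUB/SUB flow (marked `ignore` since the
/// crate path is not shown here):
///
/// ```ignore
/// let coordinator = BrickCoordinator::new();
/// let sub = coordinator.subscribe_brick("Encoder");
///
/// coordinator.broadcast_state_change("Encoder", "loaded");
///
/// assert!(sub.has_messages());
/// assert_eq!(sub.drain().len(), 1);
/// ```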
#[derive(Debug)]
pub struct BrickCoordinator {
    /// Active subscriptions by topic
    subscriptions: RwLock<HashMap<String, Vec<Arc<RwLock<Vec<BrickMessage>>>>>>,
    /// Message counter for request IDs
    message_counter: AtomicU64,
}

impl Default for BrickCoordinator {
    fn default() -> Self {
        Self::new()
    }
}

impl BrickCoordinator {
    /// Create a new coordinator
    #[must_use]
    pub fn new() -> Self {
        Self {
            subscriptions: RwLock::new(HashMap::new()),
            message_counter: AtomicU64::new(0),
        }
    }

    /// Subscribe to a topic
    #[must_use]
    pub fn subscribe(&self, topic: &str) -> Subscription {
        let messages = Arc::new(RwLock::new(Vec::new()));
        {
            let mut subs = self.subscriptions.write().expect("lock poisoned");
            subs.entry(topic.to_string())
                .or_default()
                .push(Arc::clone(&messages));
        }
        Subscription {
            topic: topic.to_string(),
            messages,
        }
    }

    /// Subscribe to brick events
    #[must_use]
    pub fn subscribe_brick(&self, brick_name: &str) -> Subscription {
        let topic = format!("brick/{}/events", brick_name);
        self.subscribe(&topic)
    }

    /// Publish a message to a topic
    pub fn publish(&self, topic: &str, message: BrickMessage) {
        let subs = self.subscriptions.read().expect("lock poisoned");
        if let Some(subscribers) = subs.get(topic) {
            for sub in subscribers {
                let mut messages = sub.write().expect("lock poisoned");
                messages.push(message.clone());
            }
        }
    }

    /// Broadcast weight updates for a brick
    pub fn broadcast_weights(&self, brick_name: &str, weights: Vec<u8>) {
        let topic = format!("brick/{}/weights", brick_name);
        let version = self.message_counter.fetch_add(1, Ordering::SeqCst);
        self.publish(
            &topic,
            BrickMessage::WeightUpdate {
                brick_name: brick_name.to_string(),
                weights,
                version,
            },
        );
    }

    /// Broadcast state change for a brick
    pub fn broadcast_state_change(&self, brick_name: &str, event: &str) {
        let topic = format!("brick/{}/events", brick_name);
        self.publish(
            &topic,
            BrickMessage::StateChange {
                brick_name: brick_name.to_string(),
                event: event.to_string(),
            },
        );
    }

    /// Generate a unique request ID
    #[must_use]
    pub fn next_request_id(&self) -> u64 {
        self.message_counter.fetch_add(1, Ordering::SeqCst)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    struct TestBrick {
        name: &'static str,
    }

    impl Brick for TestBrick {
        fn brick_name(&self) -> &'static str {
            self.name
        }

        fn assertions(&self) -> &[BrickAssertion] {
            &[BrickAssertion::TextVisible]
        }

        fn budget(&self) -> BrickBudget {
            BrickBudget::uniform(16)
        }

        fn verify(&self) -> BrickVerification {
            BrickVerification {
                passed: vec![BrickAssertion::TextVisible],
                failed: vec![],
                verification_time: Duration::from_micros(100),
            }
        }

        fn to_html(&self) -> String {
            format!("<div>{}</div>", self.name)
        }

        fn to_css(&self) -> String {
            ".test { }".into()
        }
    }

    #[test]
    fn test_worker_id() {
        let id = WorkerId::new(42);
        assert_eq!(id.value(), 42);
        assert_eq!(format!("{id}"), "worker-42");
    }

    #[test]
    fn test_backend_availability() {
        assert!(Backend::Cpu.is_available());
        assert!(Backend::Simd.is_available());
        // GPU availability depends on the `gpu` feature; Remote is currently always unavailable
    }

    #[test]
    fn test_backend_performance() {
        assert!(Backend::Gpu.performance_estimate() > Backend::Simd.performance_estimate());
        assert!(Backend::Simd.performance_estimate() > Backend::Cpu.performance_estimate());
    }

    #[test]
    fn test_distributed_brick_creation() {
        let inner = TestBrick { name: "Test" };
        let distributed = DistributedBrick::new(inner)
            .with_backend(Backend::Gpu)
            .with_data_dependencies(vec!["weights".into(), "biases".into()])
            .with_preferred_worker(WorkerId::new(1));

        assert_eq!(distributed.backend(), Backend::Gpu);
        assert_eq!(distributed.data_dependencies().len(), 2);
        assert_eq!(distributed.preferred_worker(), Some(WorkerId::new(1)));
        assert_eq!(distributed.brick_name(), "Test");
    }

    #[test]
    fn test_distributed_brick_implements_brick() {
        let inner = TestBrick { name: "Test" };
        let distributed = DistributedBrick::new(inner);

        // Verify it implements Brick trait
        assert!(distributed.verify().is_valid());
        assert_eq!(distributed.budget().total_ms, 16);
    }

    #[test]
    fn test_task_spec() {
        let inner = TestBrick { name: "TestTask" };
        let distributed = DistributedBrick::new(inner)
            .with_backend(Backend::Simd)
            .with_data_dependencies(vec!["model".into()]);

        let spec = distributed.to_task_spec();
        assert_eq!(spec.brick_name, "TestTask");
        assert_eq!(spec.backend, Backend::Simd);
        assert_eq!(spec.data_dependencies, vec!["model"]);
    }

    #[test]
    fn test_brick_input_output() {
        let input = BrickInput::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        assert_eq!(input.element_count(), 4);
        assert_eq!(input.size_bytes(), 16);

        let output = BrickOutput::new(vec![5.0, 6.0], vec![2]);
        assert_eq!(output.size_bytes(), 8);
    }

    #[test]
    fn test_data_tracker() {
        let tracker = BrickDataTracker::new();

        // Track some data
        tracker.track_data("model_weights", WorkerId::new(1), 1024);
        tracker.track_data("model_weights", WorkerId::new(2), 1024);
        tracker.track_data("biases", WorkerId::new(1), 256);

        // Check workers
        let workers = tracker.get_workers_for_data("model_weights");
        assert_eq!(workers.len(), 2);

        // Calculate affinity
        let affinity = tracker.calculate_affinity(&["model_weights".into(), "biases".into()]);
        assert!(affinity.get(&WorkerId::new(1)).unwrap_or(&0.0) > &0.0);
    }

    #[test]
    fn test_data_tracker_find_best_worker() {
        let tracker = BrickDataTracker::new();

        let brick = TestBrick { name: "MelBrick" };
        tracker.track_weights("MelBrick", WorkerId::new(5));

        let best = tracker.find_best_worker(&brick);
        assert_eq!(best, Some(WorkerId::new(5)));
    }

    #[test]
    fn test_backend_selector() {
        let selector = BackendSelector::new()
            .with_gpu_threshold(1000)
            .with_simd_threshold(100);

        // Small input -> CPU
        assert_eq!(selector.select(50, true), Backend::Cpu);

        // Medium input -> SIMD
        assert_eq!(selector.select(500, true), Backend::Simd);

        // Large input with GPU -> GPU
        assert_eq!(selector.select(5000, true), Backend::Gpu);

        // Large input without GPU -> SIMD
        assert_eq!(selector.select(5000, false), Backend::Simd);
    }

    #[test]
    fn test_multi_executor() {
        let tracker = Arc::new(BrickDataTracker::new());
        let executor = MultiBrickExecutor::new(tracker);

        let brick = TestBrick { name: "Test" };
        let input = BrickInput::new(vec![1.0, 2.0, 3.0], vec![3]);

        let result = executor.execute(&brick, input);
        assert!(result.is_ok());

        let output = result.expect("execution should succeed");
        assert_eq!(output.data.len(), 3);
        assert!(output.metrics.execution_time >= Duration::ZERO);
    }

    #[test]
    fn test_brick_coordinator() {
        let coordinator = BrickCoordinator::new();

        // Subscribe to events
        let sub = coordinator.subscribe_brick("MyBrick");

        // Broadcast event
        coordinator.broadcast_state_change("MyBrick", "loaded");

        // Check subscription received message
        assert!(sub.has_messages());
        let messages = sub.drain();
        assert_eq!(messages.len(), 1);
        assert!(
            matches!(&messages[0], BrickMessage::StateChange { brick_name, .. } if brick_name == "MyBrick")
        );
    }

    #[test]
    fn test_coordinator_weight_broadcast() {
        let coordinator = BrickCoordinator::new();

        let sub = coordinator.subscribe("brick/Encoder/weights");
        coordinator.broadcast_weights("Encoder", vec![1, 2, 3, 4]);

        let messages = sub.drain();
        assert_eq!(messages.len(), 1);
        match &messages[0] {
            BrickMessage::WeightUpdate {
                brick_name,
                weights,
                version,
            } => {
                assert_eq!(brick_name, "Encoder");
                assert_eq!(weights, &vec![1, 2, 3, 4]);
                assert_eq!(*version, 0);
            }
            _ => panic!("Expected WeightUpdate message"),
        }
    }

    #[test]
    fn test_subscription_topic() {
        let coordinator = BrickCoordinator::new();
        let sub = coordinator.subscribe("my/topic");
        assert_eq!(sub.topic(), "my/topic");
    }

    #[test]
    fn test_execution_metrics() {
        let metrics = ExecutionMetrics::new(Duration::from_millis(50), Backend::Gpu);
        assert_eq!(metrics.execution_time, Duration::from_millis(50));
        assert_eq!(metrics.backend, Backend::Gpu);
        assert!(metrics.worker_id.is_none());
    }

    // ========================================================================
    // Work-Stealing Scheduler Tests (Phase 10e)
    // ========================================================================

    #[test]
    fn test_work_stealing_task() {
        let spec = TaskSpec {
            brick_name: "TestBrick".into(),
            backend: Backend::Cpu,
            data_dependencies: vec![],
            preferred_worker: None,
        };
        let task = WorkStealingTask::new(1, spec, "input_key".into()).with_priority(10);

        assert_eq!(task.id, 1);
        assert_eq!(task.priority, 10);
        assert_eq!(task.input_key, "input_key");
        assert!(task.age() >= Duration::ZERO);
    }

    #[test]
    fn test_worker_queue_basic() {
        let queue = WorkerQueue::new(WorkerId::new(1));

        assert!(queue.is_empty());
        assert_eq!(queue.len(), 0);

        let spec = TaskSpec {
            brick_name: "Test".into(),
            backend: Backend::Cpu,
            data_dependencies: vec![],
            preferred_worker: None,
        };
        let task = WorkStealingTask::new(1, spec, "key".into());
        queue.push(task);

        assert!(!queue.is_empty());
        assert_eq!(queue.len(), 1);

        let popped = queue.pop();
        assert!(popped.is_some());
        assert!(queue.is_empty());
    }

    #[test]
    fn test_worker_queue_priority_ordering() {
        let queue = WorkerQueue::new(WorkerId::new(1));

        // Push tasks with different priorities
        for i in 0..5 {
            let spec = TaskSpec {
                brick_name: format!("Task{}", i),
                backend: Backend::Cpu,
                data_dependencies: vec![],
                preferred_worker: None,
            };
            let task = WorkStealingTask::new(i as u64, spec, "key".into()).with_priority(i);
            queue.push(task);
        }

        // Pop should return highest priority first
        let task = queue.pop().unwrap();
        assert_eq!(task.priority, 4);

        let task = queue.pop().unwrap();
        assert_eq!(task.priority, 3);
    }

    #[test]
    fn test_worker_queue_steal() {
        let queue = WorkerQueue::new(WorkerId::new(1));

        // Push 3 tasks with priorities 0, 1, 2
        for i in 0..3 {
            let spec = TaskSpec {
                brick_name: format!("Task{}", i),
                backend: Backend::Cpu,
                data_dependencies: vec![],
                preferred_worker: None,
            };
            let task = WorkStealingTask::new(i as u64, spec, "key".into()).with_priority(i);
            queue.push(task);
        }

        // Steal takes from the back of the queue (lowest priority after sort)
        let stolen = queue.steal().unwrap();
        assert_eq!(stolen.priority, 0);
        assert_eq!(queue.stolen_count(), 1);

        // Queue still has 2 tasks
        assert_eq!(queue.len(), 2);
    }

    #[test]
    fn test_work_stealing_scheduler_basic() {
        let tracker = Arc::new(BrickDataTracker::new());
        let scheduler = WorkStealingScheduler::new(tracker);

        // Register workers
        let _q1 = scheduler.register_worker(WorkerId::new(1));
        let _q2 = scheduler.register_worker(WorkerId::new(2));

        let stats = scheduler.stats();
        assert_eq!(stats.worker_count, 2);
        assert_eq!(stats.total_submitted, 0);
    }

    #[test]
    fn test_work_stealing_scheduler_submit() {
        let tracker = Arc::new(BrickDataTracker::new());
        let scheduler = WorkStealingScheduler::new(tracker);

        scheduler.register_worker(WorkerId::new(1));

        let spec = TaskSpec {
            brick_name: "Test".into(),
            backend: Backend::Cpu,
            data_dependencies: vec![],
            preferred_worker: None,
        };

        let task_id = scheduler.submit(spec, "input".into());
        assert_eq!(task_id, 0);

        let stats = scheduler.stats();
        assert_eq!(stats.total_submitted, 1);
        assert_eq!(stats.total_pending, 1);
    }

    #[test]
    fn test_work_stealing_scheduler_get_work() {
        let tracker = Arc::new(BrickDataTracker::new());
        let scheduler = WorkStealingScheduler::new(tracker);

        scheduler.register_worker(WorkerId::new(1));
        scheduler.register_worker(WorkerId::new(2));

        // Submit task preferring worker 1
        let spec = TaskSpec {
            brick_name: "Test".into(),
            backend: Backend::Cpu,
            data_dependencies: vec![],
            preferred_worker: Some(WorkerId::new(1)),
        };
        scheduler.submit(spec, "input".into());

        // Worker 1 should get the task
        let task = scheduler.get_work(WorkerId::new(1));
        assert!(task.is_some());

        // Worker 2 has nothing to pop or steal, since worker 1's queue is now empty
        let task = scheduler.get_work(WorkerId::new(2));
        assert!(task.is_none());
    }

    #[test]
    fn test_work_stealing_scheduler_steal() {
        let tracker = Arc::new(BrickDataTracker::new());
        let scheduler = WorkStealingScheduler::new(tracker);

        scheduler.register_worker(WorkerId::new(1));
        scheduler.register_worker(WorkerId::new(2));

        // Submit 3 tasks to worker 1
        for i in 0..3 {
            let spec = TaskSpec {
                brick_name: format!("Task{}", i),
                backend: Backend::Cpu,
                data_dependencies: vec![],
                preferred_worker: Some(WorkerId::new(1)),
            };
            scheduler.submit(spec, format!("input{}", i));
        }

        // Worker 2 should be able to steal a task
        let stolen = scheduler.get_work(WorkerId::new(2));
        assert!(stolen.is_some());

        let stats = scheduler.stats();
        assert_eq!(stats.total_stolen, 1);
        assert_eq!(stats.total_pending, 2); // 3 submitted - 1 stolen
    }

    #[test]
    fn test_work_stealing_scheduler_locality() {
        let tracker = Arc::new(BrickDataTracker::new());

        // Track data on worker 1
        tracker.track_data("model_weights", WorkerId::new(1), 1024);

        let scheduler = WorkStealingScheduler::new(Arc::clone(&tracker));
        scheduler.register_worker(WorkerId::new(1));
        scheduler.register_worker(WorkerId::new(2));

        // Submit task with data dependency - should go to worker 1
        let spec = TaskSpec {
            brick_name: "MelBrick".into(),
            backend: Backend::Cpu,
            data_dependencies: vec!["model_weights".into()],
            preferred_worker: None,
        };
        scheduler.submit(spec, "audio_input".into());

        // Worker 1 should have the task
        let task = scheduler.get_work(WorkerId::new(1));
        assert!(task.is_some());
        assert_eq!(task.unwrap().spec.brick_name, "MelBrick");
    }

    #[test]
    fn test_scheduler_stats() {
        let tracker = Arc::new(BrickDataTracker::new());
        let scheduler = WorkStealingScheduler::new(tracker);

        scheduler.register_worker(WorkerId::new(1));
        scheduler.register_worker(WorkerId::new(2));

        // Submit some tasks
        for i in 0..5 {
            let spec = TaskSpec {
                brick_name: format!("Task{}", i),
                backend: Backend::Cpu,
                data_dependencies: vec![],
                preferred_worker: if i % 2 == 0 {
                    Some(WorkerId::new(1))
                } else {
                    Some(WorkerId::new(2))
                },
            };
            scheduler.submit(spec, format!("input{}", i));
        }

        let stats = scheduler.stats();
        assert_eq!(stats.worker_count, 2);
        assert_eq!(stats.total_submitted, 5);
        assert_eq!(stats.total_pending, 5);
        assert_eq!(stats.workers.len(), 2);
    }
}