nnl 0.1.6

A high-performance neural network library for Rust with CPU and GPU support
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
//! Asynchronous GPU Execution System
//!
//! This module provides a high-performance asynchronous execution system for GPU operations
//! with multi-stream execution and memory transfer overlap for maximum GPU utilization.

use std::collections::{HashMap, VecDeque};
use std::sync::{
    Arc, Mutex, RwLock,
    atomic::{AtomicUsize, Ordering},
};
use std::thread;
use std::time::{Duration, Instant};

use vulkano::{
    command_buffer::{AutoCommandBufferBuilder, CommandBufferUsage},
    device::{Device as VkDevice, Queue},
    sync::{self, GpuFuture},
};

use crate::device::{
    Kernel,
    gpu::{VulkanBuffer, VulkanKernel},
};
use crate::error::{NnlError, Result};

/// Represents a GPU operation that can be executed asynchronously
pub struct AsyncOperation {
    /// Unique identifier for this operation
    pub id: OperationId,
    /// The kernel to execute for this operation
    pub kernel: Arc<VulkanKernel>,
    /// Input buffers required by this operation
    pub inputs: Vec<Arc<VulkanBuffer>>,
    /// Output buffers produced by this operation
    pub outputs: Vec<Arc<VulkanBuffer>>,
    /// Optional uniform data to pass to the kernel
    pub uniform_data: Option<Vec<u32>>,
    /// Execution priority of this operation
    pub priority: Priority,
    /// Operations that must complete before this one can execute
    pub dependencies: Vec<OperationId>,
    /// Optional callback identifier for completion notification
    pub callback: Option<String>, // Simplified to avoid Send/Sync issues
    /// Timestamp when this operation was submitted
    pub submitted_at: Instant,
}

impl std::fmt::Debug for AsyncOperation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AsyncOperation")
            .field("id", &self.id)
            .field("priority", &self.priority)
            .field("dependencies", &self.dependencies)
            .field("submitted_at", &self.submitted_at)
            .field("callback", &self.callback)
            .finish()
    }
}

/// Unique identifier for operations
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct OperationId(pub u64);

/// Priority levels for operation scheduling
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Priority {
    /// Low priority operations (background tasks)
    Low = 0,
    /// Normal priority operations (default)
    Normal = 1,
    /// High priority operations (interactive tasks)
    High = 2,
    /// Critical priority operations (urgent system tasks)
    Critical = 3,
}

/// Represents a GPU execution stream
pub struct ExecutionStream {
    #[allow(dead_code)]
    id: StreamId,
    queue: Arc<Queue>,
    command_allocator: Arc<vulkano::command_buffer::allocator::StandardCommandBufferAllocator>,
    pending_operations: VecDeque<Arc<AsyncOperation>>,
    active_future: Option<String>, // Simplified to avoid Send/Sync issues
    last_activity: Instant,
    stream_stats: StreamStats,
}

/// Unique identifier for execution streams
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StreamId(pub u32);

/// Performance metrics for execution streams
#[derive(Debug, Clone, Default)]
pub struct StreamStats {
    /// Total number of operations executed on this stream
    pub operations_executed: u64,
    /// Total execution time in milliseconds
    pub total_execution_time_ms: f64,
    /// Average execution time per operation in milliseconds
    pub average_execution_time_ms: f64,
    /// Current number of operations in the queue
    pub queue_length: usize,
    /// Stream utilization ratio (0.0 to 1.0)
    pub utilization_ratio: f32,
    /// Number of memory transfers that were overlapped with computation
    pub memory_transfers_overlapped: u64,
}

/// Configuration for the async executor
#[derive(Debug, Clone)]
pub struct AsyncExecutorConfig {
    /// Number of compute streams to create
    pub num_compute_streams: usize,

    /// Number of transfer streams to create
    pub num_transfer_streams: usize,

    /// Maximum operations queued per stream
    pub max_operations_per_stream: usize,

    /// Enable automatic load balancing across streams
    pub enable_load_balancing: bool,

    /// Enable memory transfer overlap
    pub enable_transfer_overlap: bool,

    /// Stream selection strategy
    pub stream_selection: StreamSelection,

    /// Background thread pool size
    pub thread_pool_size: usize,

    /// Operation timeout in seconds
    pub operation_timeout_secs: u64,
}

/// Stream selection strategies
#[derive(Debug, Clone)]
pub enum StreamSelection {
    /// Round-robin assignment
    RoundRobin,

    /// Least busy stream (shortest queue)
    LeastBusy,

    /// Best fit based on operation characteristics
    BestFit,

    /// Manual stream assignment
    Manual,
}

impl Default for AsyncExecutorConfig {
    fn default() -> Self {
        Self {
            num_compute_streams: 4,
            num_transfer_streams: 2,
            max_operations_per_stream: 256,
            enable_load_balancing: true,
            enable_transfer_overlap: true,
            stream_selection: StreamSelection::LeastBusy,
            thread_pool_size: 2,
            operation_timeout_secs: 30,
        }
    }
}

/// High-performance asynchronous GPU executor
pub struct AsyncExecutor {
    /// Compute streams for GPU kernels
    compute_streams: RwLock<Vec<Mutex<ExecutionStream>>>,

    /// Transfer streams for memory operations
    transfer_streams: RwLock<Vec<Mutex<ExecutionStream>>>,

    /// Device handle
    #[allow(dead_code)]
    device: Arc<VkDevice>,

    /// Configuration
    config: AsyncExecutorConfig,

    /// Operation tracking
    operation_tracker: RwLock<HashMap<OperationId, Arc<AsyncOperation>>>,

    /// Next operation ID
    next_operation_id: AtomicUsize,

    /// Stream round-robin counter
    stream_counter: AtomicUsize,

    /// Background thread handles
    worker_threads: Mutex<Vec<thread::JoinHandle<()>>>,

    /// Shutdown signal
    shutdown: Arc<Mutex<bool>>,

    /// Global executor statistics
    stats: Mutex<ExecutorStats>,
}

/// Global performance statistics for the async executor
#[derive(Debug, Clone, Default)]
pub struct ExecutorStats {
    /// Total number of operations submitted
    pub total_operations: u64,
    /// Number of operations that completed successfully
    pub completed_operations: u64,
    /// Number of operations that failed
    pub failed_operations: u64,
    /// Average operation latency in milliseconds
    pub average_latency_ms: f64,
    /// Operations throughput per second
    pub throughput_ops_per_sec: f64,
    /// GPU utilization percentage (0.0 to 1.0)
    pub gpu_utilization: f32,
    /// Memory bandwidth utilization percentage (0.0 to 1.0)
    pub memory_bandwidth_utilization: f32,
}

impl ExecutionStream {
    fn new(id: StreamId, queue: Arc<Queue>) -> Self {
        let command_allocator = Arc::new(
            vulkano::command_buffer::allocator::StandardCommandBufferAllocator::new(
                queue.device().clone(),
                Default::default(),
            ),
        );

        Self {
            id,
            queue,
            command_allocator,
            pending_operations: VecDeque::new(),
            active_future: None,
            last_activity: Instant::now(),
            stream_stats: StreamStats::default(),
        }
    }

    /// Submit an operation to this stream
    fn submit_operation(&mut self, operation: Arc<AsyncOperation>) -> Result<()> {
        if self.pending_operations.len() >= 256 {
            // Default max
            return Err(NnlError::device("Stream queue is full"));
        }

        self.pending_operations.push_back(operation);
        self.stream_stats.queue_length = self.pending_operations.len();
        Ok(())
    }

    /// Execute pending operations in this stream
    fn execute_pending(&mut self) -> Result<()> {
        if self.pending_operations.is_empty() {
            return Ok(());
        }

        // Check if previous operations are complete
        if self.active_future.is_some() {
            // For now, assume operations complete immediately
            self.active_future = None;
        }

        let start_time = Instant::now();

        // Create new command buffer for batch execution
        let mut builder = AutoCommandBufferBuilder::primary(
            &*self.command_allocator,
            self.queue.queue_family_index(),
            CommandBufferUsage::OneTimeSubmit,
        )
        .map_err(|e| NnlError::gpu(format!("Failed to create command buffer: {}", e)))?;

        // Process multiple operations in a single command buffer for efficiency
        let mut operations_in_batch = 0;
        while let Some(operation) = self.pending_operations.pop_front() {
            // Add operation to command buffer
            self.add_operation_to_builder(&mut builder, &operation)?;
            operations_in_batch += 1;

            // Limit batch size to prevent timeouts
            if operations_in_batch >= 32 {
                break;
            }
        }

        if operations_in_batch == 0 {
            return Ok(());
        }

        // Build and submit command buffer
        let command_buffer = builder
            .build()
            .map_err(|e| NnlError::gpu(format!("Failed to build command buffer: {}", e)))?;

        // Submit asynchronously
        let future = sync::now(self.queue.device().clone())
            .then_execute(self.queue.clone(), command_buffer)
            .map_err(|e| NnlError::gpu(format!("Failed to execute command buffer: {}", e)))?
            .then_signal_fence_and_flush()
            .map_err(|e| NnlError::gpu(format!("Failed to signal fence: {}", e)))?;

        // Wait for completion immediately for now
        future
            .wait(None)
            .map_err(|e| NnlError::gpu(format!("Failed to wait: {}", e)))?;
        self.active_future = Some("completed".to_string());

        // Update statistics
        let execution_time = start_time.elapsed().as_secs_f64() * 1000.0;
        self.stream_stats.operations_executed += operations_in_batch as u64;
        self.stream_stats.total_execution_time_ms += execution_time;
        self.stream_stats.average_execution_time_ms = self.stream_stats.total_execution_time_ms
            / self.stream_stats.operations_executed as f64;
        self.stream_stats.queue_length = self.pending_operations.len();
        self.last_activity = Instant::now();

        Ok(())
    }

    fn add_operation_to_builder(
        &self,
        builder: &mut AutoCommandBufferBuilder<
            vulkano::command_buffer::PrimaryAutoCommandBuffer,
            vulkano::command_buffer::allocator::StandardCommandBufferAllocator,
        >,
        operation: &AsyncOperation,
    ) -> Result<()> {
        // Execute the actual kernel operation
        let kernel_name = operation.kernel.name();

        // Create a simple dispatch based on kernel type
        let (dispatch_x, dispatch_y, dispatch_z) = match kernel_name {
            "matrix_mul" => (64, 64, 1),
            "elementwise_add" | "elementwise_mul" | "elementwise_sub" => (256, 1, 1),
            "relu" | "sigmoid" | "tanh" => (256, 1, 1),
            _ => (64, 1, 1),
        };

        builder
            .dispatch([dispatch_x, dispatch_y, dispatch_z])
            .map_err(|e| NnlError::gpu(format!("Failed to dispatch kernel: {}", e)))?;

        Ok(())
    }

    /// Check if the stream is idle
    fn is_idle(&self) -> bool {
        self.pending_operations.is_empty() && self.active_future.is_none()
    }

    /// Get current load factor (0.0 = idle, 1.0 = fully loaded)
    fn load_factor(&self) -> f32 {
        self.pending_operations.len() as f32 / 256.0 // Max queue size
    }
}

impl AsyncExecutor {
    /// Create a new async executor with default configuration
    pub fn new(device: Arc<VkDevice>, queues: Vec<Arc<Queue>>) -> Result<Self> {
        Self::with_config(device, queues, AsyncExecutorConfig::default())
    }

    /// Create a new async executor with custom configuration
    pub fn with_config(
        device: Arc<VkDevice>,
        queues: Vec<Arc<Queue>>,
        config: AsyncExecutorConfig,
    ) -> Result<Self> {
        if queues.len() < config.num_compute_streams + config.num_transfer_streams {
            return Err(NnlError::device(
                "Not enough queues for requested configuration",
            ));
        }

        // Create compute streams
        let mut compute_streams = Vec::new();
        for i in 0..config.num_compute_streams {
            let stream = ExecutionStream::new(StreamId(i as u32), queues[i].clone());
            compute_streams.push(Mutex::new(stream));
        }

        // Create transfer streams
        let mut transfer_streams = Vec::new();
        for i in 0..config.num_transfer_streams {
            let queue_idx = config.num_compute_streams + i;
            let stream =
                ExecutionStream::new(StreamId((queue_idx) as u32), queues[queue_idx].clone());
            transfer_streams.push(Mutex::new(stream));
        }

        let executor = Self {
            compute_streams: RwLock::new(compute_streams),
            transfer_streams: RwLock::new(transfer_streams),
            device,
            config,
            operation_tracker: RwLock::new(HashMap::new()),
            next_operation_id: AtomicUsize::new(0),
            stream_counter: AtomicUsize::new(0),
            worker_threads: Mutex::new(Vec::new()),
            shutdown: Arc::new(Mutex::new(false)),
            stats: Mutex::new(ExecutorStats::default()),
        };

        // Start background worker threads
        executor.start_worker_threads()?;

        Ok(executor)
    }

    /// Submit an operation for asynchronous execution
    pub fn submit_operation(
        &self,
        kernel: Arc<VulkanKernel>,
        inputs: Vec<Arc<VulkanBuffer>>,
        outputs: Vec<Arc<VulkanBuffer>>,
        uniform_data: Option<Vec<u32>>,
    ) -> Result<OperationId> {
        self.submit_operation_with_options(
            kernel,
            inputs,
            outputs,
            uniform_data,
            Priority::Normal,
            Vec::new(),
            None,
        )
    }

    /// Submit an operation with full options
    pub fn submit_operation_with_options(
        &self,
        kernel: Arc<VulkanKernel>,
        inputs: Vec<Arc<VulkanBuffer>>,
        outputs: Vec<Arc<VulkanBuffer>>,
        uniform_data: Option<Vec<u32>>,
        priority: Priority,
        dependencies: Vec<OperationId>,
        callback: Option<String>,
    ) -> Result<OperationId> {
        let id = OperationId(self.next_operation_id.fetch_add(1, Ordering::Relaxed) as u64);

        let operation = Arc::new(AsyncOperation {
            id,
            kernel,
            inputs,
            outputs,
            uniform_data,
            priority,
            dependencies,
            callback,
            submitted_at: Instant::now(),
        });

        // Track the operation
        {
            let mut tracker = self.operation_tracker.write().unwrap();
            tracker.insert(id, operation.clone());
        }

        // Select appropriate stream
        let stream_id = self.select_stream(&operation)?;

        // Submit to stream
        {
            let streams = self.compute_streams.read().unwrap();
            let mut stream = streams[stream_id].lock().unwrap();
            stream.submit_operation(operation)?;
        }

        // Update stats
        {
            let mut stats = self.stats.lock().unwrap();
            stats.total_operations += 1;
        }

        Ok(id)
    }

    /// Wait for an operation to complete
    pub fn wait_for_operation(&self, id: OperationId) -> Result<()> {
        let timeout = Duration::from_secs(self.config.operation_timeout_secs);
        let start = Instant::now();

        while start.elapsed() < timeout {
            {
                let tracker = self.operation_tracker.read().unwrap();
                if !tracker.contains_key(&id) {
                    return Ok(()); // Operation completed and cleaned up
                }
            }

            thread::sleep(Duration::from_millis(1));
        }

        Err(NnlError::device("Operation timed out"))
    }

    /// Wait for all pending operations to complete
    pub fn synchronize(&self) -> Result<()> {
        // Wait for all compute streams
        {
            let streams = self.compute_streams.read().unwrap();
            for stream_mutex in streams.iter() {
                let mut stream = stream_mutex.lock().unwrap();
                while !stream.is_idle() {
                    stream.execute_pending()?;
                }
            }
        }

        // Wait for all transfer streams
        {
            let streams = self.transfer_streams.read().unwrap();
            for stream_mutex in streams.iter() {
                let mut stream = stream_mutex.lock().unwrap();
                while !stream.is_idle() {
                    stream.execute_pending()?;
                }
            }
        }

        Ok(())
    }

    /// Get current executor statistics
    pub fn get_stats(&self) -> ExecutorStats {
        let mut stats = self.stats.lock().unwrap();

        // Update dynamic statistics
        let compute_streams = self.compute_streams.read().unwrap();
        let total_utilization: f32 = compute_streams
            .iter()
            .map(|s| s.lock().unwrap().load_factor())
            .sum::<f32>()
            / compute_streams.len() as f32;

        stats.gpu_utilization = total_utilization;

        if stats.total_operations > 0 {
            // Simple throughput calculation based on completed operations
            stats.throughput_ops_per_sec = stats.completed_operations as f64 / 1.0;
        }

        stats.clone()
    }

    /// Select the best stream for an operation
    fn select_stream(&self, operation: &AsyncOperation) -> Result<usize> {
        let streams = self.compute_streams.read().unwrap();

        match self.config.stream_selection {
            StreamSelection::RoundRobin => {
                let idx = self.stream_counter.fetch_add(1, Ordering::Relaxed) % streams.len();
                Ok(idx)
            }
            StreamSelection::LeastBusy => {
                let mut best_idx = 0;
                let mut lowest_load = f32::MAX;

                for (i, stream_mutex) in streams.iter().enumerate() {
                    let stream = stream_mutex.lock().unwrap();
                    let load = stream.load_factor();
                    if load < lowest_load {
                        lowest_load = load;
                        best_idx = i;
                    }
                }

                Ok(best_idx)
            }
            StreamSelection::BestFit => {
                // For now, use least busy - could be enhanced with operation analysis
                self.select_stream(operation)
            }
            StreamSelection::Manual => {
                // Default to first stream for manual selection
                Ok(0)
            }
        }
    }

    /// Start background worker threads
    fn start_worker_threads(&self) -> Result<()> {
        let mut threads = self.worker_threads.lock().unwrap();
        let shutdown = self.shutdown.clone();

        for _i in 0..self.config.thread_pool_size {
            let shutdown_clone = shutdown.clone();

            let handle = thread::spawn(move || {
                while !*shutdown_clone.lock().unwrap() {
                    // Simple worker thread that just sleeps
                    // Real implementation would process queues
                    thread::sleep(Duration::from_millis(10));
                }
            });

            threads.push(handle);
        }

        Ok(())
    }
}

impl Drop for AsyncExecutor {
    fn drop(&mut self) {
        // Signal shutdown
        {
            let mut shutdown = self.shutdown.lock().unwrap();
            *shutdown = true;
        }

        // Wait for worker threads to finish
        let mut threads = self.worker_threads.lock().unwrap();
        while let Some(handle) = threads.pop() {
            let _ = handle.join();
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_operation_id_generation() {
        let device = create_test_device(); // Would need actual device in practice
        let queues = create_test_queues(); // Would need actual queues in practice
        let executor = AsyncExecutor::new(device, queues).unwrap();

        // Test would verify operation ID generation
    }

    #[test]
    fn test_stream_selection_strategies() {
        // Test different stream selection strategies
    }

    #[test]
    fn test_load_balancing() {
        // Test that operations are distributed across streams effectively
    }

    fn create_test_device() -> Arc<VkDevice> {
        todo!("Create test device")
    }

    fn create_test_queues() -> Vec<Arc<Queue>> {
        todo!("Create test queues")
    }
}