use std::collections::{HashMap, VecDeque};
use std::sync::{
Arc, Mutex, RwLock,
atomic::{AtomicUsize, Ordering},
};
use std::thread;
use std::time::{Duration, Instant};
use vulkano::{
command_buffer::{AutoCommandBufferBuilder, CommandBufferUsage},
device::{Device as VkDevice, Queue},
sync::{self, GpuFuture},
};
use crate::device::{
Kernel,
gpu::{VulkanBuffer, VulkanKernel},
};
use crate::error::{NnlError, Result};
pub struct AsyncOperation {
pub id: OperationId,
pub kernel: Arc<VulkanKernel>,
pub inputs: Vec<Arc<VulkanBuffer>>,
pub outputs: Vec<Arc<VulkanBuffer>>,
pub uniform_data: Option<Vec<u32>>,
pub priority: Priority,
pub dependencies: Vec<OperationId>,
pub callback: Option<String>, pub submitted_at: Instant,
}
impl std::fmt::Debug for AsyncOperation {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AsyncOperation")
.field("id", &self.id)
.field("priority", &self.priority)
.field("dependencies", &self.dependencies)
.field("submitted_at", &self.submitted_at)
.field("callback", &self.callback)
.finish()
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct OperationId(pub u64);
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Priority {
Low = 0,
Normal = 1,
High = 2,
Critical = 3,
}
pub struct ExecutionStream {
#[allow(dead_code)]
id: StreamId,
queue: Arc<Queue>,
command_allocator: Arc<vulkano::command_buffer::allocator::StandardCommandBufferAllocator>,
pending_operations: VecDeque<Arc<AsyncOperation>>,
active_future: Option<String>, last_activity: Instant,
stream_stats: StreamStats,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StreamId(pub u32);
#[derive(Debug, Clone, Default)]
pub struct StreamStats {
pub operations_executed: u64,
pub total_execution_time_ms: f64,
pub average_execution_time_ms: f64,
pub queue_length: usize,
pub utilization_ratio: f32,
pub memory_transfers_overlapped: u64,
}
#[derive(Debug, Clone)]
pub struct AsyncExecutorConfig {
pub num_compute_streams: usize,
pub num_transfer_streams: usize,
pub max_operations_per_stream: usize,
pub enable_load_balancing: bool,
pub enable_transfer_overlap: bool,
pub stream_selection: StreamSelection,
pub thread_pool_size: usize,
pub operation_timeout_secs: u64,
}
#[derive(Debug, Clone)]
pub enum StreamSelection {
RoundRobin,
LeastBusy,
BestFit,
Manual,
}
impl Default for AsyncExecutorConfig {
fn default() -> Self {
Self {
num_compute_streams: 4,
num_transfer_streams: 2,
max_operations_per_stream: 256,
enable_load_balancing: true,
enable_transfer_overlap: true,
stream_selection: StreamSelection::LeastBusy,
thread_pool_size: 2,
operation_timeout_secs: 30,
}
}
}
pub struct AsyncExecutor {
compute_streams: RwLock<Vec<Mutex<ExecutionStream>>>,
transfer_streams: RwLock<Vec<Mutex<ExecutionStream>>>,
#[allow(dead_code)]
device: Arc<VkDevice>,
config: AsyncExecutorConfig,
operation_tracker: RwLock<HashMap<OperationId, Arc<AsyncOperation>>>,
next_operation_id: AtomicUsize,
stream_counter: AtomicUsize,
worker_threads: Mutex<Vec<thread::JoinHandle<()>>>,
shutdown: Arc<Mutex<bool>>,
stats: Mutex<ExecutorStats>,
}
#[derive(Debug, Clone, Default)]
pub struct ExecutorStats {
pub total_operations: u64,
pub completed_operations: u64,
pub failed_operations: u64,
pub average_latency_ms: f64,
pub throughput_ops_per_sec: f64,
pub gpu_utilization: f32,
pub memory_bandwidth_utilization: f32,
}
impl ExecutionStream {
fn new(id: StreamId, queue: Arc<Queue>) -> Self {
let command_allocator = Arc::new(
vulkano::command_buffer::allocator::StandardCommandBufferAllocator::new(
queue.device().clone(),
Default::default(),
),
);
Self {
id,
queue,
command_allocator,
pending_operations: VecDeque::new(),
active_future: None,
last_activity: Instant::now(),
stream_stats: StreamStats::default(),
}
}
fn submit_operation(&mut self, operation: Arc<AsyncOperation>) -> Result<()> {
if self.pending_operations.len() >= 256 {
return Err(NnlError::device("Stream queue is full"));
}
self.pending_operations.push_back(operation);
self.stream_stats.queue_length = self.pending_operations.len();
Ok(())
}
fn execute_pending(&mut self) -> Result<()> {
if self.pending_operations.is_empty() {
return Ok(());
}
if self.active_future.is_some() {
self.active_future = None;
}
let start_time = Instant::now();
let mut builder = AutoCommandBufferBuilder::primary(
&*self.command_allocator,
self.queue.queue_family_index(),
CommandBufferUsage::OneTimeSubmit,
)
.map_err(|e| NnlError::gpu(format!("Failed to create command buffer: {}", e)))?;
let mut operations_in_batch = 0;
while let Some(operation) = self.pending_operations.pop_front() {
self.add_operation_to_builder(&mut builder, &operation)?;
operations_in_batch += 1;
if operations_in_batch >= 32 {
break;
}
}
if operations_in_batch == 0 {
return Ok(());
}
let command_buffer = builder
.build()
.map_err(|e| NnlError::gpu(format!("Failed to build command buffer: {}", e)))?;
let future = sync::now(self.queue.device().clone())
.then_execute(self.queue.clone(), command_buffer)
.map_err(|e| NnlError::gpu(format!("Failed to execute command buffer: {}", e)))?
.then_signal_fence_and_flush()
.map_err(|e| NnlError::gpu(format!("Failed to signal fence: {}", e)))?;
future
.wait(None)
.map_err(|e| NnlError::gpu(format!("Failed to wait: {}", e)))?;
self.active_future = Some("completed".to_string());
let execution_time = start_time.elapsed().as_secs_f64() * 1000.0;
self.stream_stats.operations_executed += operations_in_batch as u64;
self.stream_stats.total_execution_time_ms += execution_time;
self.stream_stats.average_execution_time_ms = self.stream_stats.total_execution_time_ms
/ self.stream_stats.operations_executed as f64;
self.stream_stats.queue_length = self.pending_operations.len();
self.last_activity = Instant::now();
Ok(())
}
fn add_operation_to_builder(
&self,
builder: &mut AutoCommandBufferBuilder<
vulkano::command_buffer::PrimaryAutoCommandBuffer,
vulkano::command_buffer::allocator::StandardCommandBufferAllocator,
>,
operation: &AsyncOperation,
) -> Result<()> {
let kernel_name = operation.kernel.name();
let (dispatch_x, dispatch_y, dispatch_z) = match kernel_name {
"matrix_mul" => (64, 64, 1),
"elementwise_add" | "elementwise_mul" | "elementwise_sub" => (256, 1, 1),
"relu" | "sigmoid" | "tanh" => (256, 1, 1),
_ => (64, 1, 1),
};
builder
.dispatch([dispatch_x, dispatch_y, dispatch_z])
.map_err(|e| NnlError::gpu(format!("Failed to dispatch kernel: {}", e)))?;
Ok(())
}
fn is_idle(&self) -> bool {
self.pending_operations.is_empty() && self.active_future.is_none()
}
fn load_factor(&self) -> f32 {
self.pending_operations.len() as f32 / 256.0 }
}
impl AsyncExecutor {
pub fn new(device: Arc<VkDevice>, queues: Vec<Arc<Queue>>) -> Result<Self> {
Self::with_config(device, queues, AsyncExecutorConfig::default())
}
pub fn with_config(
device: Arc<VkDevice>,
queues: Vec<Arc<Queue>>,
config: AsyncExecutorConfig,
) -> Result<Self> {
if queues.len() < config.num_compute_streams + config.num_transfer_streams {
return Err(NnlError::device(
"Not enough queues for requested configuration",
));
}
let mut compute_streams = Vec::new();
for i in 0..config.num_compute_streams {
let stream = ExecutionStream::new(StreamId(i as u32), queues[i].clone());
compute_streams.push(Mutex::new(stream));
}
let mut transfer_streams = Vec::new();
for i in 0..config.num_transfer_streams {
let queue_idx = config.num_compute_streams + i;
let stream =
ExecutionStream::new(StreamId((queue_idx) as u32), queues[queue_idx].clone());
transfer_streams.push(Mutex::new(stream));
}
let executor = Self {
compute_streams: RwLock::new(compute_streams),
transfer_streams: RwLock::new(transfer_streams),
device,
config,
operation_tracker: RwLock::new(HashMap::new()),
next_operation_id: AtomicUsize::new(0),
stream_counter: AtomicUsize::new(0),
worker_threads: Mutex::new(Vec::new()),
shutdown: Arc::new(Mutex::new(false)),
stats: Mutex::new(ExecutorStats::default()),
};
executor.start_worker_threads()?;
Ok(executor)
}
pub fn submit_operation(
&self,
kernel: Arc<VulkanKernel>,
inputs: Vec<Arc<VulkanBuffer>>,
outputs: Vec<Arc<VulkanBuffer>>,
uniform_data: Option<Vec<u32>>,
) -> Result<OperationId> {
self.submit_operation_with_options(
kernel,
inputs,
outputs,
uniform_data,
Priority::Normal,
Vec::new(),
None,
)
}
pub fn submit_operation_with_options(
&self,
kernel: Arc<VulkanKernel>,
inputs: Vec<Arc<VulkanBuffer>>,
outputs: Vec<Arc<VulkanBuffer>>,
uniform_data: Option<Vec<u32>>,
priority: Priority,
dependencies: Vec<OperationId>,
callback: Option<String>,
) -> Result<OperationId> {
let id = OperationId(self.next_operation_id.fetch_add(1, Ordering::Relaxed) as u64);
let operation = Arc::new(AsyncOperation {
id,
kernel,
inputs,
outputs,
uniform_data,
priority,
dependencies,
callback,
submitted_at: Instant::now(),
});
{
let mut tracker = self.operation_tracker.write().unwrap();
tracker.insert(id, operation.clone());
}
let stream_id = self.select_stream(&operation)?;
{
let streams = self.compute_streams.read().unwrap();
let mut stream = streams[stream_id].lock().unwrap();
stream.submit_operation(operation)?;
}
{
let mut stats = self.stats.lock().unwrap();
stats.total_operations += 1;
}
Ok(id)
}
pub fn wait_for_operation(&self, id: OperationId) -> Result<()> {
let timeout = Duration::from_secs(self.config.operation_timeout_secs);
let start = Instant::now();
while start.elapsed() < timeout {
{
let tracker = self.operation_tracker.read().unwrap();
if !tracker.contains_key(&id) {
return Ok(()); }
}
thread::sleep(Duration::from_millis(1));
}
Err(NnlError::device("Operation timed out"))
}
pub fn synchronize(&self) -> Result<()> {
{
let streams = self.compute_streams.read().unwrap();
for stream_mutex in streams.iter() {
let mut stream = stream_mutex.lock().unwrap();
while !stream.is_idle() {
stream.execute_pending()?;
}
}
}
{
let streams = self.transfer_streams.read().unwrap();
for stream_mutex in streams.iter() {
let mut stream = stream_mutex.lock().unwrap();
while !stream.is_idle() {
stream.execute_pending()?;
}
}
}
Ok(())
}
pub fn get_stats(&self) -> ExecutorStats {
let mut stats = self.stats.lock().unwrap();
let compute_streams = self.compute_streams.read().unwrap();
let total_utilization: f32 = compute_streams
.iter()
.map(|s| s.lock().unwrap().load_factor())
.sum::<f32>()
/ compute_streams.len() as f32;
stats.gpu_utilization = total_utilization;
if stats.total_operations > 0 {
stats.throughput_ops_per_sec = stats.completed_operations as f64 / 1.0;
}
stats.clone()
}
fn select_stream(&self, operation: &AsyncOperation) -> Result<usize> {
let streams = self.compute_streams.read().unwrap();
match self.config.stream_selection {
StreamSelection::RoundRobin => {
let idx = self.stream_counter.fetch_add(1, Ordering::Relaxed) % streams.len();
Ok(idx)
}
StreamSelection::LeastBusy => {
let mut best_idx = 0;
let mut lowest_load = f32::MAX;
for (i, stream_mutex) in streams.iter().enumerate() {
let stream = stream_mutex.lock().unwrap();
let load = stream.load_factor();
if load < lowest_load {
lowest_load = load;
best_idx = i;
}
}
Ok(best_idx)
}
StreamSelection::BestFit => {
self.select_stream(operation)
}
StreamSelection::Manual => {
Ok(0)
}
}
}
fn start_worker_threads(&self) -> Result<()> {
let mut threads = self.worker_threads.lock().unwrap();
let shutdown = self.shutdown.clone();
for _i in 0..self.config.thread_pool_size {
let shutdown_clone = shutdown.clone();
let handle = thread::spawn(move || {
while !*shutdown_clone.lock().unwrap() {
thread::sleep(Duration::from_millis(10));
}
});
threads.push(handle);
}
Ok(())
}
}
impl Drop for AsyncExecutor {
fn drop(&mut self) {
{
let mut shutdown = self.shutdown.lock().unwrap();
*shutdown = true;
}
let mut threads = self.worker_threads.lock().unwrap();
while let Some(handle) = threads.pop() {
let _ = handle.join();
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_operation_id_generation() {
let device = create_test_device(); let queues = create_test_queues(); let executor = AsyncExecutor::new(device, queues).unwrap();
}
#[test]
fn test_stream_selection_strategies() {
}
#[test]
fn test_load_balancing() {
}
fn create_test_device() -> Arc<VkDevice> {
todo!("Create test device")
}
fn create_test_queues() -> Vec<Arc<Queue>> {
todo!("Create test queues")
}
}