use super::kernel::{GpuKernel, CompiledKernel};
use super::buffer::BufferPool;
use crate::error::{Error, Result};
use std::sync::Arc;
use parking_lot::RwLock;
use wgpu;
/// Shared handle to a wgpu GPU device, its submission queue, and a pool of
/// reusable buffers.
///
/// Construct directly with [`GpuRuntime::new`], or obtain the process-wide
/// singleton via [`GpuRuntime::get_or_init`].
pub struct GpuRuntime {
    // Logical GPU device used to create pipelines and resources.
    device: Arc<wgpu::Device>,
    // Command queue for submitting work to `device`.
    queue: Arc<wgpu::Queue>,
    // Static description of the adapter the device was created from.
    adapter_info: wgpu::AdapterInfo,
    // Pool of reusable GPU buffers tied to `device`.
    buffer_pool: BufferPool,
}
impl GpuRuntime {
pub async fn new() -> Result<Self> {
let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
backends: wgpu::Backends::all(),
..Default::default()
});
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
compatible_surface: None,
force_fallback_adapter: false,
})
.await
.ok_or_else(|| Error::gpu("No GPU adapter found"))?;
let adapter_info = adapter.get_info();
let (device, queue) = adapter
.request_device(
&wgpu::DeviceDescriptor {
label: Some("veda-gpu-device"),
required_features: wgpu::Features::empty(),
required_limits: wgpu::Limits::default(),
},
None,
)
.await
.map_err(|e| Error::gpu(format!("Failed to request device: {}", e)))?;
let device = Arc::new(device);
let queue = Arc::new(queue);
let buffer_pool = BufferPool::new(Arc::clone(&device));
Ok(Self {
device,
queue,
adapter_info,
buffer_pool,
})
}
pub async fn get_or_init() -> Result<Arc<Self>> {
static RUNTIME: RwLock<Option<Arc<GpuRuntime>>> = RwLock::new(None);
{
let runtime = RUNTIME.read();
if let Some(rt) = runtime.as_ref() {
return Ok(Arc::clone(rt));
}
}
let mut runtime = RUNTIME.write();
if let Some(rt) = runtime.as_ref() {
return Ok(Arc::clone(rt));
}
let rt = Arc::new(Self::new().await?);
*runtime = Some(Arc::clone(&rt));
Ok(rt)
}
pub async fn execute_kernel<K: GpuKernel>(&self, kernel: K) -> Result<Vec<u8>> {
let compiled = kernel.compile(&self.device)?;
let input_size = kernel.input_size();
let output_size = kernel.output_size();
let input_buffer = self.buffer_pool.acquire(input_size);
let output_buffer = self.buffer_pool.acquire(output_size);
compiled.execute(&self.queue, &input_buffer, &output_buffer).await?;
let result = output_buffer.read_data(&self.queue).await?;
Ok(result)
}
    /// Borrows the underlying wgpu device.
    pub fn device(&self) -> &wgpu::Device {
        &self.device
    }
    /// Borrows the command submission queue.
    pub fn queue(&self) -> &wgpu::Queue {
        &self.queue
    }
    /// Borrows the description of the adapter this runtime was created from.
    pub fn adapter_info(&self) -> &wgpu::AdapterInfo {
        &self.adapter_info
    }
    /// Borrows the shared buffer pool.
    pub fn buffer_pool(&self) -> &BufferPool {
        &self.buffer_pool
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: on machines with a usable GPU, runtime construction
    /// succeeds and the device reports a sane compute workgroup limit.
    /// On machines without an adapter this is a silent no-op.
    #[tokio::test]
    async fn test_gpu_runtime_init() {
        match GpuRuntime::new().await {
            Ok(runtime) => {
                println!("GPU: {:?}", runtime.adapter_info().name);
                let limits = runtime.device().limits();
                assert!(limits.max_compute_workgroup_size_x > 0);
            }
            // No adapter available in this environment; nothing to check.
            Err(_) => {}
        }
    }
}