use std::sync::Arc;
use xlog_core::MemoryBudget;
use xlog_cuda::device_runtime::{
AsyncCudaResource, DeviceMemoryResource, GlobalDeviceBudget, InMemorySink, LoggingResource,
LoggingSink, StreamPool, XlogDeviceRuntime,
};
use xlog_cuda::{CudaDevice, CudaKernelProvider, GpuMemoryManager};
#[allow(dead_code)] pub fn setup_provider() -> Option<Arc<CudaKernelProvider>> {
let device = match CudaDevice::new(0) {
Ok(d) => Arc::new(d),
Err(e) => {
eprintln!("Skipping: CUDA runtime unavailable: {}", e);
return None;
}
};
let memory = Arc::new(GpuMemoryManager::new(
device.clone(),
MemoryBudget::with_limit(1024 * 1024 * 1024),
));
CudaKernelProvider::new(device, memory).ok().map(Arc::new)
}
/// Bundle of shared handles returned by `setup_provider_with_runtime`.
///
/// Every field is an `Arc`, so callers can clone the individual handles they
/// need and keep them alive independently of this struct.
// NOTE(review): `dead_code` is presumably allowed because not every consumer
// reads every field — confirm.
#[allow(dead_code)]
pub struct RuntimeProviderHandles {
    /// Kernel provider constructed with `CudaKernelProvider::with_runtime`.
    pub provider: Arc<CudaKernelProvider>,
    /// Memory manager constructed with `GpuMemoryManager::with_runtime`; the
    /// same instance that was handed to `provider`.
    pub memory: Arc<GpuMemoryManager>,
    /// Device runtime the memory manager was built on.
    pub runtime: Arc<XlogDeviceRuntime>,
    /// In-memory sink that was passed to the `LoggingResource` wrapped inside
    /// the runtime's memory resource.
    pub sink: Arc<InMemorySink>,
}
/// Builds a [`CudaKernelProvider`] on top of an explicit
/// [`XlogDeviceRuntime`] and returns it together with the intermediate
/// handles.
///
/// The runtime's memory resource is assembled as a stack, innermost first:
/// [`AsyncCudaResource`] → [`LoggingResource`] (writing to a fresh
/// [`InMemorySink`]) → [`GlobalDeviceBudget`] with a limit of
/// `1024 * 1024 * 1024`. The same numeric limit is used for the
/// [`MemoryBudget`] given to the [`GpuMemoryManager`].
///
/// Returns `None` — after printing a `Skipping: ...` line to stderr — if
/// `CudaDevice::new(0)` or `CudaKernelProvider::with_runtime` fails.
// NOTE(review): `dead_code` is presumably allowed because this is a shared
// helper that not every consumer calls — confirm.
#[allow(dead_code)]
pub fn setup_provider_with_runtime() -> Option<RuntimeProviderHandles> {
    // Local shorthand for the boxed trait object each resource layer is
    // coerced to.
    type BoxedResource = Box<dyn DeviceMemoryResource + Send + Sync>;

    let device = CudaDevice::new(0)
        .map(Arc::new)
        .map_err(|e| eprintln!("Skipping: CUDA runtime unavailable: {}", e))
        .ok()?;

    let streams = Arc::new(StreamPool::with_defaults(Arc::clone(&device)));
    let sink: Arc<InMemorySink> = Arc::new(InMemorySink::new());

    // Innermost layer: the async CUDA resource.
    let base_layer: BoxedResource = Box::new(AsyncCudaResource::new(
        Arc::clone(&device),
        0,
        Arc::clone(&streams),
    ));

    // Middle layer: logging wrapper. The method-call `clone()` keeps the
    // concrete `Arc<InMemorySink>` type so it can unsize-coerce at the `let`.
    let log_target: Arc<dyn LoggingSink> = sink.clone();
    let logged_layer: BoxedResource = Box::new(LoggingResource::new(base_layer, log_target));

    // Outermost layer: global budget wrapper.
    let budgeted_layer: BoxedResource =
        Box::new(GlobalDeviceBudget::new(logged_layer, 1024 * 1024 * 1024));

    let runtime = Arc::new(XlogDeviceRuntime::with_resource(
        Arc::clone(&device),
        0,
        Arc::clone(&streams),
        budgeted_layer,
    ));

    let memory = Arc::new(GpuMemoryManager::with_runtime(
        Arc::clone(&device),
        MemoryBudget::with_limit(1024 * 1024 * 1024),
        Arc::clone(&runtime),
    ));

    let provider = CudaKernelProvider::with_runtime(Arc::clone(&device), Arc::clone(&memory))
        .map(Arc::new)
        .map_err(|e| eprintln!("Skipping: provider with_runtime construction failed: {}", e))
        .ok()?;

    Some(RuntimeProviderHandles {
        provider,
        memory,
        runtime,
        sink,
    })
}