#![warn(missing_debug_implementations)]
use std::fmt::Debug;
#[cfg(feature = "cuda")]
pub use kn_cuda_eval::executor::CudaExecutor;
#[cfg(feature = "cuda")]
pub use kn_cuda_sys::wrapper::handle::CudaDevice;
use kn_graph::cpu::cpu_eval_graph;
use kn_graph::dtype::DTensor;
use kn_graph::graph::Graph;
/// Returns whether this crate was compiled with CUDA support,
/// i.e. with the `cuda` cargo feature enabled.
pub fn compiled_with_cuda_support() -> bool {
    // `cfg!` expands to a compile-time `true`/`false`, replacing the
    // previous pair of feature-gated `return` statements with a single
    // expression.
    cfg!(feature = "cuda")
}
/// A device on which a `Graph` can be prepared (see [`Device::prepare`])
/// and then evaluated.
#[derive(Debug)]
pub enum Device {
    /// The host CPU; always available.
    Cpu,
    /// A specific CUDA GPU. Only available with the `cuda` feature.
    #[cfg(feature = "cuda")]
    Cuda(CudaDevice),
}
/// A graph that has been prepared for evaluation on a specific [`Device`].
///
/// Built via [`Device::prepare`]; run with [`PreparedGraph::eval`].
#[derive(Debug)]
pub enum PreparedGraph {
    // NOTE(review): variant name `CPU` is inconsistent with `Device::Cpu`
    // casing; renaming would break callers that match on it, so left as-is.
    /// Evaluate on the CPU via `cpu_eval_graph`, using the stored batch size.
    CPU { graph: Graph, batch_size: usize },
    /// Evaluate on a CUDA device through a pre-built executor.
    /// Only available with the `cuda` feature.
    #[cfg(feature = "cuda")]
    Cuda { executor: CudaExecutor },
}
impl Device {
    /// Prepare `graph` for evaluation on this device with the given batch size.
    ///
    /// For the CPU this just stores the graph and batch size; for CUDA it
    /// eagerly builds a [`CudaExecutor`] so [`PreparedGraph::eval`] only has
    /// to run it.
    pub fn prepare(&self, graph: Graph, batch_size: usize) -> PreparedGraph {
        match self {
            Device::Cpu => PreparedGraph::CPU { graph, batch_size },
            #[cfg(feature = "cuda")]
            Device::Cuda(device) => {
                let executor = CudaExecutor::new(*device, &graph, batch_size);
                PreparedGraph::Cuda { executor }
            }
        }
    }

    /// The best available device: the first CUDA device if one exists,
    /// otherwise the CPU.
    pub fn best() -> Device {
        Device::first_cuda().unwrap_or(Device::Cpu)
    }

    /// The first available CUDA device, or `None` when no device is present
    /// (or when the crate was built without the `cuda` feature).
    pub fn first_cuda() -> Option<Device> {
        #[cfg(feature = "cuda")]
        {
            if let Some(device) = CudaDevice::all().next() {
                return Some(Device::Cuda(device));
            }
        }
        None
    }
}
impl PreparedGraph {
    /// Evaluate the prepared graph on `inputs`, returning the output tensors.
    ///
    /// Dispatches to `cpu_eval_graph` for CPU-prepared graphs, or to the
    /// stored CUDA executor when built with the `cuda` feature.
    pub fn eval(&mut self, inputs: &[DTensor]) -> Vec<DTensor> {
        match self {
            PreparedGraph::CPU { graph, batch_size } => {
                cpu_eval_graph(graph, *batch_size, inputs)
            }
            #[cfg(feature = "cuda")]
            PreparedGraph::Cuda { executor } => {
                // `evaluate` yields borrowed results; `to_owned` converts them
                // into the owned `Vec` this method returns, matching the CPU path.
                executor.evaluate(inputs).to_owned()
            }
        }
    }
}