use super::client::{CpuAllocator, CpuClient};
use super::device::CpuDevice;
use crate::runtime::{NoOpGraph, Runtime};
use std::alloc::{Layout as AllocLayout, alloc, dealloc};
/// Zero-sized marker type selecting the CPU backend.
///
/// It carries no state; all behaviour lives in the trait implementations
/// attached to it.
#[derive(Debug, Default, Clone)]
pub struct CpuRuntime;
impl Runtime for CpuRuntime {
type Device = CpuDevice;
type Client = CpuClient;
type Allocator = CpuAllocator;
type Graph = NoOpGraph;
type RawHandle = ();
type DType = crate::dtype::DType;
fn name() -> &'static str {
"cpu"
}
fn capture_graph<F, T>(client: &Self::Client, f: F) -> crate::error::Result<(Self::Graph, T)>
where
F: FnOnce(&Self::Client) -> crate::error::Result<T>,
{
let result = f(client)?;
Ok((NoOpGraph, result))
}
fn allocate(size_bytes: usize, _device: &Self::Device) -> crate::error::Result<u64> {
if size_bytes == 0 {
return Ok(0);
}
let align = 64; let layout = AllocLayout::from_size_align(size_bytes, align)
.map_err(|_| crate::error::Error::OutOfMemory { size: size_bytes })?;
let ptr = unsafe { alloc(layout) };
if ptr.is_null() {
return Err(crate::error::Error::OutOfMemory { size: size_bytes });
}
Ok(ptr as u64)
}
fn deallocate(ptr: u64, size_bytes: usize, _device: &Self::Device) {
if ptr == 0 || size_bytes == 0 {
return;
}
let align = 64;
let layout =
AllocLayout::from_size_align(size_bytes, align).expect("Invalid allocation layout");
unsafe {
dealloc(ptr as *mut u8, layout);
}
}
fn copy_to_device(src: &[u8], dst: u64, _device: &Self::Device) -> crate::error::Result<()> {
if src.is_empty() || dst == 0 {
return Ok(());
}
unsafe {
std::ptr::copy_nonoverlapping(src.as_ptr(), dst as *mut u8, src.len());
}
Ok(())
}
fn copy_from_device(
src: u64,
dst: &mut [u8],
_device: &Self::Device,
) -> crate::error::Result<()> {
if dst.is_empty() || src == 0 {
return Ok(());
}
unsafe {
std::ptr::copy_nonoverlapping(src as *const u8, dst.as_mut_ptr(), dst.len());
}
Ok(())
}
fn copy_within_device(
src: u64,
dst: u64,
size_bytes: usize,
_device: &Self::Device,
) -> crate::error::Result<()> {
if size_bytes == 0 || src == 0 || dst == 0 {
return Ok(());
}
unsafe {
std::ptr::copy(src as *const u8, dst as *mut u8, size_bytes);
}
Ok(())
}
fn copy_strided(
src_handle: u64,
src_byte_offset: usize,
dst_handle: u64,
shape: &[usize],
strides: &[isize],
elem_size: usize,
_device: &Self::Device,
) -> crate::error::Result<()> {
if src_handle == 0 || dst_handle == 0 || shape.is_empty() {
return Ok(());
}
let numel: usize = shape.iter().product();
if numel == 0 {
return Ok(());
}
let src_base = (src_handle as usize + src_byte_offset) as *const u8;
let dst_base = dst_handle as *mut u8;
let mut indices = vec![0usize; shape.len()];
for dst_offset in 0..numel {
let mut src_elem_offset: isize = 0;
for (i, &idx) in indices.iter().enumerate() {
src_elem_offset += (idx as isize) * strides[i];
}
unsafe {
std::ptr::copy_nonoverlapping(
src_base.offset(src_elem_offset * elem_size as isize),
dst_base.add(dst_offset * elem_size),
elem_size,
);
}
for dim in (0..shape.len()).rev() {
indices[dim] += 1;
if indices[dim] < shape[dim] {
break;
}
indices[dim] = 0;
}
}
Ok(())
}
fn default_device() -> Self::Device {
CpuDevice::new()
}
fn default_client(device: &Self::Device) -> Self::Client {
CpuClient::new(device.clone())
}
fn raw_handle(_client: &Self::Client) -> &Self::RawHandle {
&()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::Graph;

    /// The CPU backend must report that graph capture is unavailable.
    #[test]
    fn test_cpu_supports_graph_capture() {
        let supported = CpuRuntime::supports_graph_capture();
        assert!(!supported);
    }

    /// `capture_graph` runs the closure on the spot and hands back a graph
    /// that cannot be replayed but can still be launched without error.
    #[test]
    fn test_cpu_capture_graph_executes_eagerly() {
        let cpu = CpuRuntime::default_device();
        let cpu_client = CpuRuntime::default_client(&cpu);

        let mut closure_ran = false;
        let captured = CpuRuntime::capture_graph(&cpu_client, |_| {
            closure_ran = true;
            Ok(42)
        });
        let (noop_graph, value) = captured.unwrap();

        assert!(closure_ran);
        assert_eq!(value, 42);
        assert!(!noop_graph.is_replay_capable());
        assert!(noop_graph.launch().is_ok());
    }

    /// An error produced inside the closure surfaces from `capture_graph`.
    #[test]
    fn test_cpu_capture_graph_propagates_error() {
        let cpu = CpuRuntime::default_device();
        let cpu_client = CpuRuntime::default_client(&cpu);

        let outcome = CpuRuntime::capture_graph(&cpu_client, |_| -> crate::error::Result<()> {
            Err(crate::error::Error::Internal("test error".into()))
        });

        assert!(outcome.is_err());
    }
}