#![allow(unreachable_code, dead_code)]
use crate::dispatch_dtype;
use crate::dtype::{DType, Element};
use crate::error::{Error, Result};
use crate::ops::{
ActivationOps, BinaryOp, BinaryOps, CompareOp, CompareOps, ConditionalOps, MatmulOps, ReduceOp,
ReduceOps, ScalarOps, UnaryOp, UnaryOps, broadcast_shape, reduce_output_shape,
};
use crate::runtime::{Device, Runtime, cpu};
use crate::tensor::Tensor;
/// Bundles a CPU device and its client so fallback paths can stage
/// tensors on the host, run the CPU implementation, and copy results back.
pub struct CpuFallbackContext {
    /// Host-side device used to allocate staging tensors.
    pub device: cpu::CpuDevice,
    /// Client that executes the CPU implementations of the ops.
    pub client: cpu::CpuClient,
}
impl CpuFallbackContext {
    /// Creates a fresh CPU device together with its default client.
    #[inline]
    pub fn new() -> Self {
        let device = cpu::CpuDevice::new();
        Self {
            client: cpu::CpuRuntime::default_client(&device),
            device,
        }
    }

    /// Copies `tensor`'s elements to host memory and rebuilds them as a
    /// CPU tensor with the same shape.
    #[inline]
    pub fn tensor_from_gpu<T: Element, R: Runtime<DType = DType>>(
        &self,
        tensor: &Tensor<R>,
    ) -> Tensor<cpu::CpuRuntime> {
        let host: Vec<T> = tensor.to_vec();
        Tensor::<cpu::CpuRuntime>::from_slice(&host, tensor.shape(), &self.device)
    }
}
impl Default for CpuFallbackContext {
fn default() -> Self {
Self::new()
}
}
#[inline]
pub fn validate_binary_dtypes<R: Runtime<DType = DType>>(
a: &Tensor<R>,
b: &Tensor<R>,
) -> Result<DType> {
if a.dtype() != b.dtype() {
return Err(Error::DTypeMismatch {
lhs: a.dtype(),
rhs: b.dtype(),
});
}
Ok(a.dtype())
}
#[inline]
pub fn compute_broadcast_shape<R: Runtime<DType = DType>>(
a: &Tensor<R>,
b: &Tensor<R>,
) -> Result<Vec<usize>> {
broadcast_shape(a.shape(), b.shape()).ok_or_else(|| Error::BroadcastError {
lhs: a.shape().to_vec(),
rhs: b.shape().to_vec(),
})
}
/// Executes an elementwise binary op on the CPU and copies the result back
/// to `device`.
///
/// Stages both operands on the host, dispatches on the common dtype, runs the
/// CPU implementation of `op`, then rebuilds the result as a tensor on the
/// target runtime with the broadcasted shape.
///
/// # Errors
/// Propagates dtype-mismatch, broadcast, and CPU-op errors.
pub fn binary_op_fallback<R, D>(
    a: &Tensor<R>,
    b: &Tensor<R>,
    op: BinaryOp,
    device: &D,
    op_name: &'static str,
) -> Result<Tensor<R>>
where
    R: Runtime<Device = D, DType = DType>,
    D: Device + Clone,
{
    let dtype = validate_binary_dtypes(a, b)?;
    let out_shape = compute_broadcast_shape(a, b)?;
    let ctx = CpuFallbackContext::new();
    dispatch_dtype!(dtype, T => {
        let lhs: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(a);
        let rhs: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(b);
        let out = match op {
            BinaryOp::Add => ctx.client.add(&lhs, &rhs)?,
            BinaryOp::Sub => ctx.client.sub(&lhs, &rhs)?,
            BinaryOp::Mul => ctx.client.mul(&lhs, &rhs)?,
            BinaryOp::Div => ctx.client.div(&lhs, &rhs)?,
            BinaryOp::Pow => ctx.client.pow(&lhs, &rhs)?,
            BinaryOp::Max => ctx.client.maximum(&lhs, &rhs)?,
            BinaryOp::Min => ctx.client.minimum(&lhs, &rhs)?,
            BinaryOp::Atan2 => ctx.client.atan2(&lhs, &rhs)?,
        };
        let host: Vec<T> = out.to_vec();
        return Ok(Tensor::<R>::from_slice(&host, &out_shape, device));
    }, op_name);
    // Every dispatch arm returns; this is only here to satisfy the type checker.
    unreachable!()
}
/// Executes an elementwise unary op on the CPU and copies the result back
/// to `device`. The output shape matches the input shape.
///
/// # Errors
/// Propagates any error from the CPU implementation of `op`.
pub fn unary_op_fallback<R, D>(
    a: &Tensor<R>,
    op: UnaryOp,
    device: &D,
    op_name: &'static str,
) -> Result<Tensor<R>>
where
    R: Runtime<Device = D, DType = DType>,
    D: Device + Clone,
{
    let dtype = a.dtype();
    let ctx = CpuFallbackContext::new();
    dispatch_dtype!(dtype, T => {
        let input: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(a);
        let out = match op {
            UnaryOp::Neg => ctx.client.neg(&input)?,
            UnaryOp::Abs => ctx.client.abs(&input)?,
            UnaryOp::Sqrt => ctx.client.sqrt(&input)?,
            UnaryOp::Rsqrt => ctx.client.rsqrt(&input)?,
            UnaryOp::Cbrt => ctx.client.cbrt(&input)?,
            UnaryOp::Exp => ctx.client.exp(&input)?,
            UnaryOp::Exp2 => ctx.client.exp2(&input)?,
            UnaryOp::Expm1 => ctx.client.expm1(&input)?,
            UnaryOp::Log => ctx.client.log(&input)?,
            UnaryOp::Log2 => ctx.client.log2(&input)?,
            UnaryOp::Log10 => ctx.client.log10(&input)?,
            UnaryOp::Log1p => ctx.client.log1p(&input)?,
            UnaryOp::Sin => ctx.client.sin(&input)?,
            UnaryOp::Cos => ctx.client.cos(&input)?,
            UnaryOp::Tan => ctx.client.tan(&input)?,
            UnaryOp::Asin => ctx.client.asin(&input)?,
            UnaryOp::Acos => ctx.client.acos(&input)?,
            UnaryOp::Atan => ctx.client.atan(&input)?,
            UnaryOp::Sinh => ctx.client.sinh(&input)?,
            UnaryOp::Cosh => ctx.client.cosh(&input)?,
            UnaryOp::Tanh => ctx.client.tanh(&input)?,
            UnaryOp::Asinh => ctx.client.asinh(&input)?,
            UnaryOp::Acosh => ctx.client.acosh(&input)?,
            UnaryOp::Atanh => ctx.client.atanh(&input)?,
            UnaryOp::Recip => ctx.client.recip(&input)?,
            UnaryOp::Square => ctx.client.square(&input)?,
            UnaryOp::Floor => ctx.client.floor(&input)?,
            UnaryOp::Ceil => ctx.client.ceil(&input)?,
            UnaryOp::Round => ctx.client.round(&input)?,
            UnaryOp::Trunc => ctx.client.trunc(&input)?,
            UnaryOp::Sign => ctx.client.sign(&input)?,
        };
        let host: Vec<T> = out.to_vec();
        return Ok(Tensor::<R>::from_slice(&host, a.shape(), device));
    }, op_name);
    // Every dispatch arm returns; this is only here to satisfy the type checker.
    unreachable!()
}
/// Executes a tensor-scalar binary op on the CPU and copies the result back
/// to `device`. Only `Add`/`Sub`/`Mul`/`Div`/`Pow` have scalar forms.
///
/// # Errors
/// Returns an error for binary ops without a scalar variant, and propagates
/// any error from the CPU implementation.
pub fn scalar_op_fallback<R, D>(
    a: &Tensor<R>,
    op: BinaryOp,
    scalar: f64,
    device: &D,
    op_name: &'static str,
) -> Result<Tensor<R>>
where
    R: Runtime<Device = D, DType = DType>,
    D: Device + Clone,
{
    let dtype = a.dtype();
    let ctx = CpuFallbackContext::new();
    dispatch_dtype!(dtype, T => {
        let input: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(a);
        let out = match op {
            BinaryOp::Add => ctx.client.add_scalar(&input, scalar)?,
            BinaryOp::Sub => ctx.client.sub_scalar(&input, scalar)?,
            BinaryOp::Mul => ctx.client.mul_scalar(&input, scalar)?,
            BinaryOp::Div => ctx.client.div_scalar(&input, scalar)?,
            BinaryOp::Pow => ctx.client.pow_scalar(&input, scalar)?,
            // NOTE(review): the failing condition is an unsupported *op*
            // variant, not the dtype; `UnsupportedDType` is reused here —
            // consider a dedicated error variant if one exists.
            _ => return Err(Error::UnsupportedDType { dtype, op: op_name }),
        };
        let host: Vec<T> = out.to_vec();
        return Ok(Tensor::<R>::from_slice(&host, a.shape(), device));
    }, op_name);
    // Every dispatch arm returns; this is only here to satisfy the type checker.
    unreachable!()
}
/// Executes a reduction over `dims` on the CPU and copies the result back
/// to `device`. Only `Sum`/`Mean`/`Max`/`Min` are handled here.
///
/// # Errors
/// Returns an error for reduce ops not listed above, and propagates any
/// error from the CPU implementation.
pub fn reduce_op_fallback<R, D>(
    a: &Tensor<R>,
    op: ReduceOp,
    dims: &[usize],
    keepdim: bool,
    device: &D,
    op_name: &'static str,
) -> Result<Tensor<R>>
where
    R: Runtime<Device = D, DType = DType>,
    D: Device + Clone,
{
    let dtype = a.dtype();
    let out_shape = reduce_output_shape(a.shape(), dims, keepdim);
    let ctx = CpuFallbackContext::new();
    dispatch_dtype!(dtype, T => {
        let input: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(a);
        let out = match op {
            ReduceOp::Sum => ctx.client.sum(&input, dims, keepdim)?,
            ReduceOp::Mean => ctx.client.mean(&input, dims, keepdim)?,
            ReduceOp::Max => ctx.client.max(&input, dims, keepdim)?,
            ReduceOp::Min => ctx.client.min(&input, dims, keepdim)?,
            // NOTE(review): the failing condition is an unsupported *op*
            // variant, not the dtype; `UnsupportedDType` is reused here —
            // consider a dedicated error variant if one exists.
            _ => return Err(Error::UnsupportedDType { dtype, op: op_name }),
        };
        let host: Vec<T> = out.to_vec();
        return Ok(Tensor::<R>::from_slice(&host, &out_shape, device));
    }, op_name);
    // Every dispatch arm returns; this is only here to satisfy the type checker.
    unreachable!()
}
/// Runs an arbitrary shape-preserving activation `op_fn` on the CPU and
/// copies the result back to `device`.
///
/// `op_fn` receives the CPU client plus the staged input tensor, so callers
/// can route to any `ActivationOps` method.
///
/// # Errors
/// Propagates any error returned by `op_fn`.
pub fn activation_fallback<R, D, F>(
    a: &Tensor<R>,
    device: &D,
    op_name: &'static str,
    op_fn: F,
) -> Result<Tensor<R>>
where
    R: Runtime<Device = D, DType = DType>,
    D: Device + Clone,
    F: Fn(&cpu::CpuClient, &Tensor<cpu::CpuRuntime>) -> Result<Tensor<cpu::CpuRuntime>>,
{
    let dtype = a.dtype();
    let ctx = CpuFallbackContext::new();
    dispatch_dtype!(dtype, T => {
        let input: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(a);
        let host: Vec<T> = op_fn(&ctx.client, &input)?.to_vec();
        return Ok(Tensor::<R>::from_slice(&host, a.shape(), device));
    }, op_name);
    // Every dispatch arm returns; this is only here to satisfy the type checker.
    unreachable!()
}
/// Computes softmax along `dim` on the CPU and copies the result back
/// to `device`. The output shape matches the input shape.
///
/// # Errors
/// Propagates any error from the CPU softmax implementation.
pub fn softmax_fallback<R, D>(
    a: &Tensor<R>,
    dim: isize,
    device: &D,
    op_name: &'static str,
) -> Result<Tensor<R>>
where
    R: Runtime<Device = D, DType = DType>,
    D: Device + Clone,
{
    let dtype = a.dtype();
    let ctx = CpuFallbackContext::new();
    dispatch_dtype!(dtype, T => {
        let input: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(a);
        let host: Vec<T> = ctx.client.softmax(&input, dim)?.to_vec();
        return Ok(Tensor::<R>::from_slice(&host, a.shape(), device));
    }, op_name);
    // Every dispatch arm returns; this is only here to satisfy the type checker.
    unreachable!()
}
/// Performs a matrix multiply on the CPU and copies the result back to
/// `device`, rebuilt with the caller-supplied `out_shape`.
///
/// # Errors
/// Returns a dtype-mismatch error when the operands disagree, and
/// propagates any error from the CPU matmul.
pub fn matmul_fallback<R, D>(
    a: &Tensor<R>,
    b: &Tensor<R>,
    out_shape: &[usize],
    device: &D,
    op_name: &'static str,
) -> Result<Tensor<R>>
where
    R: Runtime<Device = D, DType = DType>,
    D: Device + Clone,
{
    let dtype = validate_binary_dtypes(a, b)?;
    let ctx = CpuFallbackContext::new();
    dispatch_dtype!(dtype, T => {
        let lhs: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(a);
        let rhs: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(b);
        let host: Vec<T> = ctx.client.matmul(&lhs, &rhs)?.to_vec();
        return Ok(Tensor::<R>::from_slice(&host, out_shape, device));
    }, op_name);
    // Every dispatch arm returns; this is only here to satisfy the type checker.
    unreachable!()
}
/// Executes an elementwise comparison on the CPU and copies the result back
/// to `device` with the broadcasted shape.
///
/// # Errors
/// Propagates dtype-mismatch, broadcast, and CPU-op errors.
pub fn compare_op_fallback<R, D>(
    a: &Tensor<R>,
    b: &Tensor<R>,
    op: CompareOp,
    device: &D,
    op_name: &'static str,
) -> Result<Tensor<R>>
where
    R: Runtime<Device = D, DType = DType>,
    D: Device + Clone,
{
    let dtype = validate_binary_dtypes(a, b)?;
    let out_shape = compute_broadcast_shape(a, b)?;
    let ctx = CpuFallbackContext::new();
    dispatch_dtype!(dtype, T => {
        let lhs: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(a);
        let rhs: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(b);
        let out = match op {
            CompareOp::Eq => ctx.client.eq(&lhs, &rhs)?,
            CompareOp::Ne => ctx.client.ne(&lhs, &rhs)?,
            CompareOp::Lt => ctx.client.lt(&lhs, &rhs)?,
            CompareOp::Le => ctx.client.le(&lhs, &rhs)?,
            CompareOp::Gt => ctx.client.gt(&lhs, &rhs)?,
            CompareOp::Ge => ctx.client.ge(&lhs, &rhs)?,
        };
        let host: Vec<T> = out.to_vec();
        return Ok(Tensor::<R>::from_slice(&host, &out_shape, device));
    }, op_name);
    // Every dispatch arm returns; this is only here to satisfy the type checker.
    unreachable!()
}
#[inline]
pub fn compute_ternary_broadcast_shape<R: Runtime<DType = DType>>(
cond: &Tensor<R>,
x: &Tensor<R>,
y: &Tensor<R>,
) -> Result<Vec<usize>> {
let xy_shape = broadcast_shape(x.shape(), y.shape()).ok_or_else(|| Error::BroadcastError {
lhs: x.shape().to_vec(),
rhs: y.shape().to_vec(),
})?;
broadcast_shape(cond.shape(), &xy_shape).ok_or_else(|| Error::BroadcastError {
lhs: cond.shape().to_vec(),
rhs: xy_shape,
})
}
/// Executes `where(cond, x, y)` on the CPU and copies the result back to
/// `device` with the ternary-broadcasted shape.
///
/// Dispatches twice: once on the condition's dtype and once on the shared
/// value dtype of `x`/`y`.
///
/// # Errors
/// Propagates dtype-mismatch (between `x` and `y`), broadcast, and CPU-op
/// errors.
pub fn where_cond_fallback<R, D>(
    cond: &Tensor<R>,
    x: &Tensor<R>,
    y: &Tensor<R>,
    device: &D,
    op_name: &'static str,
) -> Result<Tensor<R>>
where
    R: Runtime<Device = D, DType = DType>,
    D: Device + Clone,
{
    let value_dtype = validate_binary_dtypes(x, y)?;
    let cond_dtype = cond.dtype();
    let out_shape = compute_ternary_broadcast_shape(cond, x, y)?;
    let ctx = CpuFallbackContext::new();
    dispatch_dtype!(cond_dtype, C => {
        dispatch_dtype!(value_dtype, T => {
            let mask: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<C, R>(cond);
            let lhs: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(x);
            let rhs: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(y);
            let host: Vec<T> = ctx.client.where_cond(&mask, &lhs, &rhs)?.to_vec();
            return Ok(Tensor::<R>::from_slice(&host, &out_shape, device));
        }, op_name);
    }, op_name);
    // Every dispatch arm returns; this is only here to satisfy the type checker.
    unreachable!()
}
/// Elementwise merge of two CSC matrices via the CPU implementation.
///
/// Stages the three component tensors of each operand (column pointers as
/// `i64`, row indices as `i64`, values as `T`) on the host, runs
/// `merge_csc_impl` with the caller's merge `strategy`/`semantics` and
/// per-position closures, then rebuilds the three result tensors on the
/// original values tensor's device.
///
/// # Errors
/// Propagates any error from the CPU merge implementation.
#[cfg(feature = "sparse")]
#[allow(private_interfaces)]
pub fn csc_elementwise_fallback<T: Element, R: Runtime<DType = DType>, F, FA, FB>(
    a_col_ptrs: &Tensor<R>,
    a_row_indices: &Tensor<R>,
    a_values: &Tensor<R>,
    b_col_ptrs: &Tensor<R>,
    b_row_indices: &Tensor<R>,
    b_values: &Tensor<R>,
    shape: [usize; 2],
    strategy: super::cpu::sparse::MergeStrategy,
    semantics: super::cpu::sparse::OperationSemantics,
    op: F,
    only_a_op: FA,
    only_b_op: FB,
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>
where
    R::Device: Device + Clone,
    F: Fn(T, T) -> T + Copy,
    FA: Fn(T) -> T + Copy,
    FB: Fn(T) -> T + Copy,
{
    let device = a_values.device();
    let ctx = CpuFallbackContext::new();
    // Stage operand A on the host.
    let a_ptrs: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<i64, R>(a_col_ptrs);
    let a_rows: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<i64, R>(a_row_indices);
    let a_vals: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(a_values);
    // Stage operand B on the host.
    let b_ptrs: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<i64, R>(b_col_ptrs);
    let b_rows: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<i64, R>(b_row_indices);
    let b_vals: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(b_values);
    let (out_ptrs, out_rows, out_vals) = super::cpu::sparse::merge_csc_impl(
        &a_ptrs,
        &a_rows,
        &a_vals,
        &b_ptrs,
        &b_rows,
        &b_vals,
        shape,
        strategy,
        semantics,
        op,
        only_a_op,
        only_b_op,
    )?;
    // Copy the merged components back to the original device.
    let ptrs_host: Vec<i64> = out_ptrs.to_vec();
    let rows_host: Vec<i64> = out_rows.to_vec();
    let vals_host: Vec<T> = out_vals.to_vec();
    Ok((
        Tensor::<R>::from_slice(&ptrs_host, out_ptrs.shape(), device),
        Tensor::<R>::from_slice(&rows_host, out_rows.shape(), device),
        Tensor::<R>::from_slice(&vals_host, out_vals.shape(), device),
    ))
}
/// Elementwise merge of two COO matrices via the CPU implementation.
///
/// Stages each operand's row indices, column indices (both `i64`) and
/// values (`T`) on the host, runs `merge_coo_impl` with the caller's
/// `semantics` and per-position closures, then rebuilds the three result
/// tensors on the original values tensor's device.
///
/// # Errors
/// Propagates any error from the CPU merge implementation.
#[cfg(feature = "sparse")]
#[allow(private_interfaces)]
pub fn coo_elementwise_fallback<T: Element, R: Runtime<DType = DType>, F, FA, FB>(
    a_row_indices: &Tensor<R>,
    a_col_indices: &Tensor<R>,
    a_values: &Tensor<R>,
    b_row_indices: &Tensor<R>,
    b_col_indices: &Tensor<R>,
    b_values: &Tensor<R>,
    semantics: super::cpu::sparse::OperationSemantics,
    op: F,
    only_a_op: FA,
    only_b_op: FB,
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>
where
    R::Device: Device + Clone,
    F: Fn(T, T) -> T + Copy,
    FA: Fn(T) -> T + Copy,
    FB: Fn(T) -> T + Copy,
{
    let device = a_values.device();
    let ctx = CpuFallbackContext::new();
    // Stage operand A on the host.
    let a_rows: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<i64, R>(a_row_indices);
    let a_cols: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<i64, R>(a_col_indices);
    let a_vals: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(a_values);
    // Stage operand B on the host.
    let b_rows: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<i64, R>(b_row_indices);
    let b_cols: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<i64, R>(b_col_indices);
    let b_vals: Tensor<cpu::CpuRuntime> = ctx.tensor_from_gpu::<T, R>(b_values);
    let (out_rows, out_cols, out_vals) = super::cpu::sparse::merge_coo_impl(
        &a_rows,
        &a_cols,
        &a_vals,
        &b_rows,
        &b_cols,
        &b_vals,
        semantics,
        op,
        only_a_op,
        only_b_op,
    )?;
    // Copy the merged components back to the original device.
    let rows_host: Vec<i64> = out_rows.to_vec();
    let cols_host: Vec<i64> = out_cols.to_vec();
    let vals_host: Vec<T> = out_vals.to_vec();
    Ok((
        Tensor::<R>::from_slice(&rows_host, out_rows.shape(), device),
        Tensor::<R>::from_slice(&cols_host, out_cols.shape(), device),
        Tensor::<R>::from_slice(&vals_host, out_vals.shape(), device),
    ))
}