use super::types::{QuantizableInteger, QuantizedTensor};
use crate::error::RusTorchResult;
pub mod cpu;
#[cfg(feature = "cuda")]
pub mod cuda;
/// Zero-sized entry point for the quantized operators in this module.
///
/// Each method is a thin dispatcher. When the crate is built with the `cuda`
/// feature and the leading tensor reports a CUDA device, the call is forwarded
/// to the matching function in the `cuda` submodule. Otherwise it goes to the
/// `cpu` submodule.
pub struct OptimizedOps;

impl OptimizedOps {
    /// Quantized matmul dispatcher.
    ///
    /// The computation itself is done by `cuda::qmatmul_cuda` (only with the
    /// `cuda` feature, and only when `a` is on a CUDA device) or by
    /// `cpu::qmatmul_cpu` in every other case.
    ///
    /// # Errors
    ///
    /// Returns whatever error the selected backend returns; this function adds
    /// no checks of its own.
    pub fn qmatmul<Q: QuantizableInteger>(
        a: &QuantizedTensor<Q>,
        b: &QuantizedTensor<Q>,
    ) -> RusTorchResult<QuantizedTensor<Q>> {
        // Only `a`'s device drives the choice.
        // NOTE(review): `b` is assumed to live on the same device as `a` — not
        // verified here; confirm the backends enforce it.
        #[cfg(feature = "cuda")]
        if a.device.is_cuda() {
            return cuda::qmatmul_cuda(a, b);
        }

        cpu::qmatmul_cpu(a, b)
    }

    /// Quantized 2-D convolution dispatcher.
    ///
    /// All arguments, including the optional `bias`, are passed through
    /// unchanged. The call goes to `cuda::qconv2d_cuda` (only with the `cuda`
    /// feature, and only when `input` is on a CUDA device) or to
    /// `cpu::qconv2d_cpu` otherwise.
    ///
    /// `stride` and `padding` are presumably `(height, width)` pairs — confirm
    /// against the backend implementations.
    ///
    /// # Errors
    ///
    /// Returns whatever error the selected backend returns; this function adds
    /// no checks of its own.
    pub fn qconv2d<Q: QuantizableInteger>(
        input: &QuantizedTensor<Q>,
        weight: &QuantizedTensor<Q>,
        bias: Option<&QuantizedTensor<Q>>,
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> RusTorchResult<QuantizedTensor<Q>> {
        // Only `input`'s device drives the choice.
        // NOTE(review): `weight` and `bias` are assumed to be co-located with
        // `input` — not verified here.
        #[cfg(feature = "cuda")]
        if input.device.is_cuda() {
            return cuda::qconv2d_cuda(input, weight, bias, stride, padding);
        }

        cpu::qconv2d_cpu(input, weight, bias, stride, padding)
    }
}
pub use cpu::optimized_ops;
#[cfg(feature = "cuda")]
pub use cuda::cuda_ops;