use super::super::types::{QuantizableInteger, QuantizedTensor};
use crate::error::{RusTorchError, RusTorchResult};
#[cfg(feature = "cuda")]
use crate::gpu::DeviceType;
/// Quantized matrix multiplication for the CUDA backend.
///
/// No CUDA kernel is invoked here. Whether or not the `cuda` feature is
/// enabled, the work is delegated to `super::cpu::qmatmul_cpu`, and its result
/// (including any error) is returned unchanged.
pub fn qmatmul_cuda<Q: QuantizableInteger>(
    a: &QuantizedTensor<Q>,
    b: &QuantizedTensor<Q>,
) -> RusTorchResult<QuantizedTensor<Q>> {
    // Both feature configurations share the CPU path, so no `cfg` split is needed.
    use super::cpu::qmatmul_cpu as cpu_fallback;
    cpu_fallback(a, b)
}
/// Quantized 2-D convolution for the CUDA backend.
///
/// No CUDA kernel is invoked here. Whether or not the `cuda` feature is
/// enabled, every argument (`input`, `weight`, optional `bias`, `stride`,
/// `padding`) is forwarded unchanged to `super::cpu::qconv2d_cpu`, and its
/// result is returned as-is.
pub fn qconv2d_cuda<Q: QuantizableInteger>(
    input: &QuantizedTensor<Q>,
    weight: &QuantizedTensor<Q>,
    bias: Option<&QuantizedTensor<Q>>,
    stride: (usize, usize),
    padding: (usize, usize),
) -> RusTorchResult<QuantizedTensor<Q>> {
    // Identical behavior with and without the `cuda` feature: always the CPU path.
    use super::cpu::qconv2d_cpu as cpu_fallback;
    cpu_fallback(input, weight, bias, stride, padding)
}
/// Runtime queries about CUDA device availability.
///
/// When the crate is built without the `cuda` feature, every query reports
/// that nothing is available. Callers can therefore use these functions
/// unconditionally.
#[derive(Debug, Clone, Copy, Default)]
pub struct CudaOps;

impl CudaOps {
    /// Returns `true` if CUDA device ordinal 0 reports itself as available.
    ///
    /// Always returns `false` when the `cuda` feature is disabled.
    pub fn is_available() -> bool {
        // `DeviceType` comes from the module-level `#[cfg(feature = "cuda")]`
        // import. Re-importing it here would shadow that import and leave it
        // unused.
        #[cfg(feature = "cuda")]
        {
            DeviceType::Cuda(0).is_available()
        }
        #[cfg(not(feature = "cuda"))]
        {
            false
        }
    }

    /// Returns how many CUDA devices are available consecutively, starting at
    /// ordinal 0.
    ///
    /// Probing stops at the first ordinal that is not available, and at most
    /// 8 ordinals (0 through 7) are examined, so the result is in `0..=8`.
    /// Always returns `0` when the `cuda` feature is disabled.
    pub fn device_count() -> usize {
        #[cfg(feature = "cuda")]
        {
            // The literal 8 is the probe cap. The element type of the range is
            // left to inference from `DeviceType::Cuda`'s ordinal type.
            (0..8)
                .take_while(|&i| DeviceType::Cuda(i).is_available())
                .count()
        }
        #[cfg(not(feature = "cuda"))]
        {
            0
        }
    }
}
pub use self::CudaOps as cuda_ops;
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that the two availability queries agree and respect the probe cap.
    #[test]
    fn test_cuda_availability() {
        let available = CudaOps::is_available();
        let count = CudaOps::device_count();
        // device_count() probes ordinals from 0 and stops at the first
        // unavailable one. So "device 0 is available" and "at least one device
        // was counted" must agree.
        assert_eq!(available, count > 0);
        // device_count() never examines more than 8 ordinals.
        assert!(count <= 8);
    }

    /// Without the `cuda` feature, both queries must report no devices.
    #[cfg(not(feature = "cuda"))]
    #[test]
    fn test_cuda_disabled_reports_no_devices() {
        assert!(!CudaOps::is_available());
        assert_eq!(CudaOps::device_count(), 0);
    }
}