use alloc::string::String;
use crate::tensor::Element;
use crate::TensorMetadata;
use crate::{ops::*, quantization::QTensorPrimitive};
use super::DeviceOps;
/// Core abstraction over a tensor compute backend.
///
/// An implementor must provide every operation set listed as a supertrait
/// (float, bool, int, module, activation, quantized and transaction ops, each
/// parameterized by the backend itself). It must also be `Clone`, `Default`,
/// `Debug`, `Send`, `Sync`, `Sized` and `'static`.
pub trait Backend:
    FloatTensorOps<Self>
    + IntTensorOps<Self>
    + BoolTensorOps<Self>
    + QTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + TransactionOps<Self>
    + core::fmt::Debug
    + Clone
    + Default
    + Send
    + Sync
    + Sized
    + 'static
{
    /// Device handle type for this backend (must implement `DeviceOps`).
    type Device: DeviceOps;

    // --- Tensor primitive types: each exposes `TensorMetadata` and is `'static`. ---

    /// Concrete representation of a float tensor on this backend.
    type FloatTensorPrimitive: TensorMetadata + 'static;
    /// Concrete representation of an integer tensor on this backend.
    type IntTensorPrimitive: TensorMetadata + 'static;
    /// Concrete representation of a boolean tensor on this backend.
    type BoolTensorPrimitive: TensorMetadata + 'static;
    /// Concrete representation of a quantized tensor; additionally implements `QTensorPrimitive`.
    type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;

    // --- Element types associated with each tensor kind. ---

    /// Element type associated with float tensors.
    type FloatElem: Element;
    /// Element type associated with integer tensors.
    type IntElem: Element;
    /// Element type associated with boolean tensors.
    type BoolElem: Element;
    /// Element type associated with quantized tensors.
    /// NOTE(review): presumably the storage encoding of quantized values — confirm.
    type QuantizedEncoding: Element;

    /// Returns the backend's name.
    fn name() -> String;

    /// Sets the backend seed (presumably for random number generation — confirm per implementor).
    fn seed(seed: u64);

    /// Reports whether autodiff is enabled for this backend.
    ///
    /// The default implementation returns `false`. Autodiff-capable backends
    /// presumably override it.
    fn ad_enabled() -> bool {
        false
    }

    /// Synchronizes the given device.
    ///
    /// The default implementation does nothing. Backends that need an explicit
    /// synchronization point can override it.
    fn sync(_device: &Self::Device) {}
}
/// A [`Backend`] that supports automatic differentiation on top of an inner backend.
///
/// The inner backend must use the same device, float element and int element
/// types as this backend. Tensors of every kind can be converted between the
/// two backends with the `*_inner` / `*_from_inner` functions below.
pub trait AutodiffBackend: Backend {
    /// The wrapped backend whose tensors hold gradient values and unwrapped data.
    type InnerBackend: Backend<
        IntElem = Self::IntElem,
        FloatElem = Self::FloatElem,
        Device = Self::Device,
    >;

    /// Container holding gradients produced by [`AutodiffBackend::backward`]; must be `Send`.
    type Gradients: Send;

    /// Runs the backward pass starting from `tensor` and returns the resulting gradients.
    fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;

    /// Looks up the gradient of `tensor` in `grads`, returning `None` when absent.
    fn grad(
        tensor: &FloatTensor<Self>,
        grads: &Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Takes the gradient of `tensor` out of `grads`, returning `None` when absent.
    /// NOTE(review): presumably the entry is removed from `grads` — confirm per implementor.
    fn grad_remove(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Stores `grad` as the gradient of `tensor` inside `grads`.
    fn grad_replace(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
        grad: FloatTensor<Self::InnerBackend>,
    );

    /// Converts a float tensor of this backend into the inner backend's float tensor.
    fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;
    /// Wraps an inner-backend float tensor as a float tensor of this backend.
    fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;

    /// Converts an int tensor of this backend into the inner backend's int tensor.
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;
    /// Wraps an inner-backend int tensor as an int tensor of this backend.
    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;

    /// Converts a bool tensor of this backend into the inner backend's bool tensor.
    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;
    /// Wraps an inner-backend bool tensor as a bool tensor of this backend.
    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;

    /// Converts a quantized tensor of this backend into the inner backend's quantized tensor.
    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;
    /// Wraps an inner-backend quantized tensor as a quantized tensor of this backend.
    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;
}