use burn_std::DType;
pub use burn_std::backtrace::BackTrace;
use alloc::string::String;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::element::Element;
use crate::ops::*;
use crate::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
use crate::{QTensorPrimitive, TensorData, TensorMetadata};
use super::DeviceOps;
/// Contract implemented by every tensor backend.
///
/// A backend names its device type, one tensor primitive per tensor kind (float, int, bool and
/// quantized) and the element types of float, int and bool tensors. The tensor operations
/// themselves come from the supertraits listed on the trait.
///
/// See also:
#[cfg_attr(doc, doc = crate::doc_tensor!())]
#[cfg_attr(not(doc), doc = "`Tensor`")]
pub trait Backend:
    FloatTensorOps<Self>
    + BoolTensorOps<Self>
    + IntTensorOps<Self>
    + ModuleOps<Self>
    + ActivationOps<Self>
    + QTensorOps<Self>
    + TransactionOps<Self>
    + Clone
    + Default
    + Sized
    + Send
    + Sync
    + core::fmt::Debug
    + 'static
{
    /// Device type used by this backend.
    type Device: DeviceOps;

    /// Tensor primitive backing float tensors.
    type FloatTensorPrimitive: TensorMetadata + 'static;
    /// Element type associated with float tensors.
    type FloatElem: Element;

    /// Tensor primitive backing int tensors.
    type IntTensorPrimitive: TensorMetadata + 'static;
    /// Element type associated with int tensors.
    type IntElem: Element;

    /// Tensor primitive backing bool tensors.
    type BoolTensorPrimitive: TensorMetadata + 'static;
    /// Element type associated with bool tensors.
    type BoolElem: Element;

    /// Tensor primitive backing quantized tensors.
    type QuantizedTensorPrimitive: TensorMetadata + QTensorPrimitive + 'static;

    /// Reports whether autodiff is enabled for this backend.
    ///
    /// The default implementation returns `false`.
    fn ad_enabled() -> bool {
        false
    }

    /// Calls `func` with `input` and returns its result.
    ///
    /// The default implementation does nothing else and ignores `device`.
    ///
    /// NOTE(review): judging by the name, backends presumably override this so that
    /// allocations made while `func` runs are treated as persistent — confirm against
    /// the implementors.
    // `device` is unused by the default body but keeps its plain name in the signature.
    #[allow(unused_variables)]
    fn memory_persistent_allocations<Output, Input, Func>(
        device: &Self::Device,
        input: Input,
        func: Func,
    ) -> Output
    where
        Func: Fn(Input) -> Output,
    {
        func(input)
    }

    /// Hook for releasing memory held for `device`.
    ///
    /// The default implementation is a no-op.
    // `device` is unused by the default body but keeps its plain name in the signature.
    #[allow(unused_variables)]
    fn memory_cleanup(device: &Self::Device) {}

    /// Returns the name of the backend for the given device.
    fn name(device: &Self::Device) -> String;

    /// Seeds the backend on the given device with `seed`.
    fn seed(device: &Self::Device, seed: u64);

    /// Synchronization hook for the given device.
    ///
    /// The default implementation performs no work and returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Implementations may report a failure as an [`ExecutionError`].
    fn sync(_device: &Self::Device) -> Result<(), ExecutionError> {
        Ok(())
    }

    /// Staging hook that receives mutable access to a sequence of [`TensorData`] together
    /// with a device.
    ///
    /// The default implementation ignores both arguments and does nothing.
    fn staging<'a, Iter>(_data: Iter, _device: &Self::Device)
    where
        Iter: Iterator<Item = &'a mut TensorData>,
    {
    }

    /// Returns whether `dtype` is supported on the given device.
    fn supports_dtype(device: &Self::Device, dtype: DType) -> bool;
}
/// Error describing a failure that happened during execution.
///
/// Both variants render the same message through `Display`: a fixed header followed by the
/// `reason` text.
#[derive(Error, Serialize, Deserialize)]
pub enum ExecutionError {
    /// Failure described only by a textual reason.
    #[error("An error happened during execution\nCaused by:\n {reason}")]
    WithContext {
        /// Human-readable cause of the failure.
        reason: String,
    },

    /// Failure described by a textual reason plus the backtrace stored with it.
    #[error("An error happened during execution\nCaused by:\n {reason}")]
    Generic {
        /// Human-readable cause of the failure.
        reason: String,
        /// Backtrace stored alongside the reason; excluded from (de)serialization and not
        /// part of the rendered message.
        #[serde(skip)]
        backtrace: BackTrace,
    },
}
/// `Debug` output for [`ExecutionError`] is exactly its `Display` output.
impl core::fmt::Debug for ExecutionError {
    /// Writes the `Display` rendering of `self` into the formatter.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Route through `format_args!` (via `write!`) rather than calling `Display::fmt`
        // directly, so the caller's formatter flags are not forwarded.
        write!(f, "{self}")
    }
}
/// A [`Backend`] that additionally exposes gradients and conversions to and from an inner
/// backend.
pub trait AutodiffBackend: Backend {
    /// The inner backend. It shares this backend's device type and its float and int
    /// element types.
    type InnerBackend: Backend<
            Device = Self::Device,
            FloatElem = Self::FloatElem,
            IntElem = Self::IntElem,
        >;

    /// Gradients type produced by [`backward`](AutodiffBackend::backward) and consumed by
    /// the `grad*` functions.
    type Gradients: Send;

    /// Consumes `tensor` and returns the resulting gradients.
    ///
    /// NOTE(review): presumably runs the backward pass starting from `tensor` — confirm
    /// against the implementors.
    fn backward(tensor: FloatTensor<Self>) -> Self::Gradients;

    /// Returns the gradient associated with `tensor` in `grads`, as an inner-backend float
    /// tensor, or `None`.
    ///
    /// NOTE(review): `None` presumably means no gradient is recorded for `tensor` — confirm.
    fn grad(
        tensor: &FloatTensor<Self>,
        grads: &Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Like [`grad`](AutodiffBackend::grad), but takes `grads` mutably.
    ///
    /// NOTE(review): by its name this removes the returned gradient from `grads` — confirm.
    fn grad_remove(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend>>;

    /// Stores `grad` in `grads` as the gradient associated with `tensor`.
    ///
    /// NOTE(review): by its name this overwrites any existing entry — confirm.
    fn grad_replace(
        tensor: &FloatTensor<Self>,
        grads: &mut Self::Gradients,
        grad: FloatTensor<Self::InnerBackend>,
    );

    /// Converts a float tensor of this backend into its inner-backend counterpart.
    fn inner(tensor: FloatTensor<Self>) -> FloatTensor<Self::InnerBackend>;

    /// Converts an int tensor of this backend into its inner-backend counterpart.
    fn int_inner(tensor: IntTensor<Self>) -> IntTensor<Self::InnerBackend>;

    /// Converts a bool tensor of this backend into its inner-backend counterpart.
    fn bool_inner(tensor: BoolTensor<Self>) -> BoolTensor<Self::InnerBackend>;

    /// Converts a quantized tensor of this backend into its inner-backend counterpart.
    fn q_inner(tensor: QuantizedTensor<Self>) -> QuantizedTensor<Self::InnerBackend>;

    /// Converts an inner-backend float tensor into a float tensor of this backend.
    fn from_inner(tensor: FloatTensor<Self::InnerBackend>) -> FloatTensor<Self>;

    /// Converts an inner-backend int tensor into an int tensor of this backend.
    fn int_from_inner(tensor: IntTensor<Self::InnerBackend>) -> IntTensor<Self>;

    /// Converts an inner-backend bool tensor into a bool tensor of this backend.
    fn bool_from_inner(tensor: BoolTensor<Self::InnerBackend>) -> BoolTensor<Self>;

    /// Converts an inner-backend quantized tensor into a quantized tensor of this backend.
    fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;
}