1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
use burn_tensor::Element;
use half::{bf16, f16};

/// Marker trait for scalar types that can be stored in tensors of the tch backend.
///
/// A type qualifies only when it is both a burn [`Element`] and a
/// `tch::kind::Element`, i.e. when both libraries know how to handle it.
/// The trait adds no methods of its own.
pub trait TchElement: Element + tch::kind::Element {}

/// Emits an empty `impl TchElement for T {}` for every type in the list.
macro_rules! impl_tch_element {
    ($($elem:ty),* $(,)?) => {
        $(impl TchElement for $elem {})*
    };
}

// Floating-point element types (including the `half` crate's 16-bit floats).
impl_tch_element!(f64, f32, f16, bf16);

// Signed integer element types.
impl_tch_element!(i64, i32, i16);

// Unsigned integer element types.
impl_tch_element!(u8);