Skip to main content

burn_tch/
element.rs

1use burn_backend::Element;
2use burn_backend::{bf16, f16};
3
4/// The element type for the tch backend.
5pub trait TchElement: Element + tch::kind::Element {
6    /// Returns the associated tensor kind for [`tch::kind::Element`].
7    fn kind() -> tch::Kind {
8        Self::KIND
9    }
10}
11
12impl TchElement for f64 {}
13impl TchElement for f32 {}
14impl TchElement for f16 {}
15impl TchElement for bf16 {
16    fn kind() -> tch::Kind {
17        let mut kind = <Self as tch::kind::Element>::KIND;
18        // Incorrect kind mapping in tch definitions, force bfloat16
19        if matches!(Self::dtype(), burn_backend::DType::BF16) && kind == tch::Kind::Half {
20            kind = tch::Kind::BFloat16
21        }
22        kind
23    }
24}
25
26impl TchElement for i64 {}
27impl TchElement for i32 {}
28impl TchElement for i16 {}
29impl TchElement for i8 {}
30
31impl TchElement for u8 {}
32
33impl TchElement for bool {}
34
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `TchElement` implementation in this module must report the tch kind that
    /// matches its Rust type.
    #[test]
    fn test_elem_kinds() {
        assert_eq!(f64::kind(), tch::Kind::Double);
        assert_eq!(f32::kind(), tch::Kind::Float);
        assert_eq!(f16::kind(), tch::Kind::Half);
        // Exercises the override that corrects tch's bf16 -> Half mapping.
        assert_eq!(bf16::kind(), tch::Kind::BFloat16);
        assert_eq!(i64::kind(), tch::Kind::Int64);
        assert_eq!(i32::kind(), tch::Kind::Int);
        assert_eq!(i16::kind(), tch::Kind::Int16);
        assert_eq!(i8::kind(), tch::Kind::Int8);
        // `u8` implements `TchElement` too and was previously left unchecked.
        assert_eq!(u8::kind(), tch::Kind::Uint8);
        assert_eq!(bool::kind(), tch::Kind::Bool);
    }
}