candle_core/
scalar.rs

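//! Scalar values and tensor-or-scalar operands.
//!
//! Defines the [`Scalar`] enum, the [`TensorScalar`] enum and the
//! [`TensorOrScalar`] conversion trait.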
3use crate::{DType, Result, Tensor, WithDType};
4use float8::F8E4M3 as f8e4m3;
5use half::{bf16, f16};
6
/// A single numeric value tagged with its element type.
///
/// Each variant corresponds to exactly one [`DType`] (see [`Scalar::dtype`]).
/// The dummy dtypes `F6E2M3`, `F6E3M2`, `F4` and `F8E8M0` have no variant;
/// [`Scalar::zero`] and [`Scalar::one`] panic for them.
///
/// `PartialEq` is derived. Values held in different variants are never equal,
/// even when numerically equal (`U8(1) != I32(1)`). Float variants compare
/// with their payload type's `PartialEq`; for `f32`/`f64` that means NaN is
/// not equal to itself.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Scalar {
    /// `u8` value ([`DType::U8`]).
    U8(u8),
    /// `u32` value ([`DType::U32`]).
    U32(u32),
    /// `i16` value ([`DType::I16`]).
    I16(i16),
    /// `i32` value ([`DType::I32`]).
    I32(i32),
    /// `i64` value ([`DType::I64`]).
    I64(i64),
    /// `half::bf16` value ([`DType::BF16`]).
    BF16(bf16),
    /// `half::f16` value ([`DType::F16`]).
    F16(f16),
    /// `f32` value ([`DType::F32`]).
    F32(f32),
    /// `f64` value ([`DType::F64`]).
    F64(f64),
    /// `float8::F8E4M3` value ([`DType::F8E4M3`]).
    F8E4M3(f8e4m3),
}
20
/// Builds a [`Scalar`] from any [`WithDType`] value.
///
/// The variant choice is delegated entirely to [`WithDType::to_scalar`].
impl<T: WithDType> From<T> for Scalar {
    fn from(value: T) -> Self {
        value.to_scalar()
    }
}
26
27impl Scalar {
28    pub fn zero(dtype: DType) -> Self {
29        match dtype {
30            DType::U8 => Scalar::U8(0),
31            DType::U32 => Scalar::U32(0),
32            DType::I16 => Scalar::I16(0),
33            DType::I32 => Scalar::I32(0),
34            DType::I64 => Scalar::I64(0),
35            DType::BF16 => Scalar::BF16(bf16::ZERO),
36            DType::F16 => Scalar::F16(f16::ZERO),
37            DType::F32 => Scalar::F32(0.0),
38            DType::F64 => Scalar::F64(0.0),
39            DType::F8E4M3 => Scalar::F8E4M3(f8e4m3::ZERO),
40            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
41                panic!("Cannot create zero scalar for dummy type {dtype:?}")
42            }
43        }
44    }
45
46    pub fn one(dtype: DType) -> Self {
47        match dtype {
48            DType::U8 => Scalar::U8(1),
49            DType::U32 => Scalar::U32(1),
50            DType::I16 => Scalar::I16(1),
51            DType::I32 => Scalar::I32(1),
52            DType::I64 => Scalar::I64(1),
53            DType::BF16 => Scalar::BF16(bf16::ONE),
54            DType::F16 => Scalar::F16(f16::ONE),
55            DType::F32 => Scalar::F32(1.0),
56            DType::F64 => Scalar::F64(1.0),
57            DType::F8E4M3 => Scalar::F8E4M3(f8e4m3::ONE),
58            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
59                panic!("Cannot create one scalar for dummy type {dtype:?}")
60            }
61        }
62    }
63
64    pub fn dtype(&self) -> DType {
65        match self {
66            Scalar::U8(_) => DType::U8,
67            Scalar::U32(_) => DType::U32,
68            Scalar::I16(_) => DType::I16,
69            Scalar::I32(_) => DType::I32,
70            Scalar::I64(_) => DType::I64,
71            Scalar::BF16(_) => DType::BF16,
72            Scalar::F16(_) => DType::F16,
73            Scalar::F32(_) => DType::F32,
74            Scalar::F64(_) => DType::F64,
75            Scalar::F8E4M3(_) => DType::F8E4M3,
76        }
77    }
78
79    pub fn to_f64(&self) -> f64 {
80        match self {
81            Scalar::U8(v) => *v as f64,
82            Scalar::U32(v) => *v as f64,
83            Scalar::I16(v) => *v as f64,
84            Scalar::I32(v) => *v as f64,
85            Scalar::I64(v) => *v as f64,
86            Scalar::BF16(v) => v.to_f64(),
87            Scalar::F16(v) => v.to_f64(),
88            Scalar::F32(v) => *v as f64,
89            Scalar::F64(v) => *v,
90            Scalar::F8E4M3(v) => v.to_f64(),
91        }
92    }
93}
94
95pub enum TensorScalar {
96    Tensor(Tensor),
97    Scalar(Tensor),
98}
99
/// Conversion of a value into a [`TensorScalar`] operand.
///
/// This module implements it for `&Tensor` (producing
/// [`TensorScalar::Tensor`]) and for every `T: WithDType` value (producing
/// [`TensorScalar::Scalar`]). Generic code can therefore accept either a
/// tensor reference or a plain value.
pub trait TensorOrScalar {
    /// Consumes `self` and returns the corresponding [`TensorScalar`].
    ///
    /// # Errors
    ///
    /// Implementation-defined. The `WithDType` impl in this module propagates
    /// any error from building the CPU tensor; the `&Tensor` impl never fails.
    fn to_tensor_scalar(self) -> Result<TensorScalar>;
}
103
104impl TensorOrScalar for &Tensor {
105    fn to_tensor_scalar(self) -> Result<TensorScalar> {
106        Ok(TensorScalar::Tensor(self.clone()))
107    }
108}
109
110impl<T: WithDType> TensorOrScalar for T {
111    fn to_tensor_scalar(self) -> Result<TensorScalar> {
112        let scalar = Tensor::new(self, &crate::Device::Cpu)?;
113        Ok(TensorScalar::Scalar(scalar))
114    }
115}