candle_core/
scalar.rs

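//! TensorScalar Enum and Trait
//!
//! Defines the [`Scalar`] value type, the [`TensorScalar`] enum, and the
//! [`TensorOrScalar`] trait, which is implemented for `&Tensor` and for every
//! `WithDType` scalar type.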
3use crate::{DType, Result, Tensor, WithDType};
4use half::{bf16, f16};
5
6#[derive(Debug, Clone, Copy, PartialEq)]
7pub enum Scalar {
8    U8(u8),
9    U32(u32),
10    I64(i64),
11    BF16(bf16),
12    F16(f16),
13    F32(f32),
14    F64(f64),
15}
16
17impl<T: WithDType> From<T> for Scalar {
18    fn from(value: T) -> Self {
19        value.to_scalar()
20    }
21}
22
23impl Scalar {
24    pub fn zero(dtype: DType) -> Self {
25        match dtype {
26            DType::U8 => Scalar::U8(0),
27            DType::U32 => Scalar::U32(0),
28            DType::I64 => Scalar::I64(0),
29            DType::BF16 => Scalar::BF16(bf16::ZERO),
30            DType::F16 => Scalar::F16(f16::ZERO),
31            DType::F32 => Scalar::F32(0.0),
32            DType::F64 => Scalar::F64(0.0),
33        }
34    }
35
36    pub fn one(dtype: DType) -> Self {
37        match dtype {
38            DType::U8 => Scalar::U8(1),
39            DType::U32 => Scalar::U32(1),
40            DType::I64 => Scalar::I64(1),
41            DType::BF16 => Scalar::BF16(bf16::ONE),
42            DType::F16 => Scalar::F16(f16::ONE),
43            DType::F32 => Scalar::F32(1.0),
44            DType::F64 => Scalar::F64(1.0),
45        }
46    }
47
48    pub fn dtype(&self) -> DType {
49        match self {
50            Scalar::U8(_) => DType::U8,
51            Scalar::U32(_) => DType::U32,
52            Scalar::I64(_) => DType::I64,
53            Scalar::BF16(_) => DType::BF16,
54            Scalar::F16(_) => DType::F16,
55            Scalar::F32(_) => DType::F32,
56            Scalar::F64(_) => DType::F64,
57        }
58    }
59
60    pub fn to_f64(&self) -> f64 {
61        match self {
62            Scalar::U8(v) => *v as f64,
63            Scalar::U32(v) => *v as f64,
64            Scalar::I64(v) => *v as f64,
65            Scalar::BF16(v) => v.to_f64(),
66            Scalar::F16(v) => v.to_f64(),
67            Scalar::F32(v) => *v as f64,
68            Scalar::F64(v) => *v,
69        }
70    }
71}
72
73pub enum TensorScalar {
74    Tensor(Tensor),
75    Scalar(Tensor),
76}
77
/// Conversion of an operand into a [`TensorScalar`].
///
/// Implemented in this file for `&Tensor` (yielding `TensorScalar::Tensor`)
/// and for every `WithDType` type (yielding `TensorScalar::Scalar`).
pub trait TensorOrScalar {
    /// Consumes `self` and wraps it as a [`TensorScalar`].
    ///
    /// # Errors
    /// Implementations may fail when building a tensor; the `WithDType`
    /// implementation propagates errors from `Tensor::new`.
    fn to_tensor_scalar(self) -> Result<TensorScalar>;
}
81
82impl TensorOrScalar for &Tensor {
83    fn to_tensor_scalar(self) -> Result<TensorScalar> {
84        Ok(TensorScalar::Tensor(self.clone()))
85    }
86}
87
88impl<T: WithDType> TensorOrScalar for T {
89    fn to_tensor_scalar(self) -> Result<TensorScalar> {
90        let scalar = Tensor::new(self, &crate::Device::Cpu)?;
91        Ok(TensorScalar::Scalar(scalar))
92    }
93}