candle_core/
scalar.rs

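//! The `TensorScalar` enum and the `TensorOrScalar` conversion trait.
//!
//! `TensorOrScalar` lets either a tensor reference or a plain scalar value be converted,
//! through one trait, into a `TensorScalar` that records which kind of input it came from.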
3use crate::{Result, Tensor, WithDType};
4
5pub enum TensorScalar {
6    Tensor(Tensor),
7    Scalar(Tensor),
8}
9
10pub trait TensorOrScalar {
11    fn to_tensor_scalar(self) -> Result<TensorScalar>;
12}
13
14impl TensorOrScalar for &Tensor {
15    fn to_tensor_scalar(self) -> Result<TensorScalar> {
16        Ok(TensorScalar::Tensor(self.clone()))
17    }
18}
19
20impl<T: WithDType> TensorOrScalar for T {
21    fn to_tensor_scalar(self) -> Result<TensorScalar> {
22        let scalar = Tensor::new(self, &crate::Device::Cpu)?;
23        Ok(TensorScalar::Scalar(scalar))
24    }
25}