1use crate::{Result, Tensor, WithDType};
4
5pub enum TensorScalar {
6 Tensor(Tensor),
7 Scalar(Tensor),
8}
9
10pub trait TensorOrScalar {
11 fn to_tensor_scalar(self) -> Result<TensorScalar>;
12}
13
14impl TensorOrScalar for &Tensor {
15 fn to_tensor_scalar(self) -> Result<TensorScalar> {
16 Ok(TensorScalar::Tensor(self.clone()))
17 }
18}
19
20impl<T: WithDType> TensorOrScalar for T {
21 fn to_tensor_scalar(self) -> Result<TensorScalar> {
22 let scalar = Tensor::new(self, &crate::Device::Cpu)?;
23 Ok(TensorScalar::Scalar(scalar))
24 }
25}