use torsh_core::dtype::{DType, TensorElement};
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;
/// Coarse classification of a dtype, ordered from least to most general.
///
/// The derived `Ord` follows declaration order
/// (`Boolean < Integer < FloatingPoint < Complex`); `promote_types` picks the
/// result category of a binary operation as the `max` of the operand categories.
///
/// The enum is a plain fieldless tag, so it is `Copy` and `Hash` and can be passed
/// by value or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum TypeCategory {
    /// The boolean dtype.
    Boolean = 0,
    /// Plain and quantized integer dtypes.
    Integer = 1,
    /// Real floating-point dtypes (including `BF16`).
    FloatingPoint = 2,
    /// Complex dtypes.
    Complex = 3,
}
/// Maps a dtype to its [`TypeCategory`].
///
/// `BF16` counts as floating point. The quantized dtypes (`QInt8`, `QUInt8`, `QInt32`)
/// count as integers.
pub fn get_type_category(dtype: DType) -> TypeCategory {
    match dtype {
        DType::Bool => TypeCategory::Boolean,
        DType::F16 | DType::BF16 | DType::F32 | DType::F64 => TypeCategory::FloatingPoint,
        DType::C64 | DType::C128 => TypeCategory::Complex,
        DType::U8
        | DType::I8
        | DType::I16
        | DType::I32
        | DType::I64
        | DType::U32
        | DType::U64
        | DType::QInt8
        | DType::QUInt8
        | DType::QInt32 => TypeCategory::Integer,
    }
}
/// Returns the storage width of `dtype` in bits.
///
/// `Bool` reports 1. For complex dtypes the value is the total width (`C64` -> 64,
/// `C128` -> 128), presumably covering both real and imaginary components.
pub fn get_type_precision(dtype: DType) -> u8 {
    match dtype {
        DType::Bool => 1,
        DType::U8 | DType::I8 | DType::QInt8 | DType::QUInt8 => 8,
        DType::I16 | DType::F16 | DType::BF16 => 16,
        DType::I32 | DType::U32 | DType::F32 | DType::QInt32 => 32,
        DType::I64 | DType::U64 | DType::F64 | DType::C64 => 64,
        DType::C128 => 128,
    }
}
/// Computes the dtype produced by combining `lhs` and `rhs` in a binary operation.
///
/// Rules:
/// - Identical dtypes are returned unchanged.
/// - The result category is the higher of the two operand categories
///   (`Boolean < Integer < FloatingPoint < Complex`).
/// - Integers follow these sub-rules:
///   - With the same signedness, the wider type wins.
///   - Mixing signed and unsigned yields a signed type able to hold both
///     (e.g. `U8 + I8 -> I16`, `U32 + I32 -> I64`).
///   - `U64` mixed with a signed type saturates at `I64`, because there is no wider
///     signed type; values above `i64::MAX` are then not representable.
/// - Floating point uses the widest bit width among both operands to choose
///   `F16`/`F32`/`F64`:
///   - A 16-bit result stays `BF16` when `BF16` is the only 16-bit float involved.
///   - `F16` combined with `BF16` yields `F32`, since neither 16-bit format can
///     represent the other.
/// - Complex uses the widest per-component width: a complex dtype contributes half
///   its total width. A 32-bit component gives `C64`, otherwise `C128`
///   (so `F64 + C64 -> C128`).
///
/// The result does not depend on operand order.
///
/// # Errors
/// Currently never returns `Err`; the `Result` is kept for API stability.
pub fn promote_types(lhs: DType, rhs: DType) -> Result<DType> {
    if lhs == rhs {
        return Ok(lhs);
    }
    let result_category = std::cmp::max(get_type_category(lhs), get_type_category(rhs));
    let promoted = match result_category {
        // Only both-`Bool` reaches this arm, which the equality check already handled.
        TypeCategory::Boolean => DType::Bool,
        TypeCategory::Integer => promote_integer_types(lhs, rhs),
        TypeCategory::FloatingPoint => {
            let target_precision =
                std::cmp::max(get_type_precision(lhs), get_type_precision(rhs));
            let has_bf16 = matches!(lhs, DType::BF16) || matches!(rhs, DType::BF16);
            let has_f16 = matches!(lhs, DType::F16) || matches!(rhs, DType::F16);
            match target_precision {
                16 if has_bf16 && has_f16 => DType::F32,
                16 if has_bf16 => DType::BF16,
                16 => DType::F16,
                32 => DType::F32,
                64 => DType::F64,
                // Defensive fallback; a float operand guarantees a width of 16, 32 or 64.
                _ => DType::F32,
            }
        }
        TypeCategory::Complex => {
            let component_bits =
                std::cmp::max(complex_component_bits(lhs), complex_component_bits(rhs));
            if component_bits <= 32 {
                DType::C64
            } else {
                DType::C128
            }
        }
    };
    Ok(promoted)
}

/// Width in bits of one real component: complex dtypes contribute half their total width
/// (`C64` -> 32, `C128` -> 64); every other dtype contributes its full width.
fn complex_component_bits(dtype: DType) -> u8 {
    match dtype {
        DType::C64 | DType::C128 => get_type_precision(dtype) / 2,
        other => get_type_precision(other),
    }
}

/// Returns `(is_signed, bit_width)` for an integer-category dtype; `Bool` counts as unsigned.
fn integer_layout(dtype: DType) -> (bool, u8) {
    let signed = matches!(
        dtype,
        DType::I8 | DType::I16 | DType::I32 | DType::I64 | DType::QInt8 | DType::QInt32
    );
    (signed, get_type_precision(dtype))
}

/// Promotes two integer-category dtypes (`Bool`, plain and quantized integers) in an
/// order-independent way.
fn promote_integer_types(lhs: DType, rhs: DType) -> DType {
    let (lhs_signed, lhs_bits) = integer_layout(lhs);
    let (rhs_signed, rhs_bits) = integer_layout(rhs);

    if lhs_signed == rhs_signed {
        return match lhs_bits.cmp(&rhs_bits) {
            std::cmp::Ordering::Greater => lhs,
            std::cmp::Ordering::Less => rhs,
            // Equal width and signedness with different dtypes only happens for a plain
            // integer vs. its quantized counterpart (e.g. I8 vs QInt8). Prefer the plain
            // type so the result does not depend on operand order.
            std::cmp::Ordering::Equal => {
                if matches!(lhs, DType::QInt8 | DType::QUInt8 | DType::QInt32) {
                    rhs
                } else {
                    lhs
                }
            }
        };
    }

    let (signed, signed_bits, unsigned_bits) = if lhs_signed {
        (lhs, lhs_bits, rhs_bits)
    } else {
        (rhs, rhs_bits, lhs_bits)
    };
    if signed_bits > unsigned_bits {
        // The signed type already covers the whole unsigned range.
        signed
    } else {
        // Need a signed type strictly wider than the unsigned operand.
        match unsigned_bits {
            0..=8 => DType::I16,
            9..=16 => DType::I32,
            // 32-bit unsigned fits in I64; 64-bit unsigned saturates at I64 (no wider type).
            _ => DType::I64,
        }
    }
}
/// Folds [`promote_types`] left-to-right over `types` and returns the common dtype.
///
/// A single-element slice yields that element unchanged.
///
/// # Errors
/// - Returns `TorshError::InvalidArgument` when `types` is empty.
/// - Propagates any error returned by [`promote_types`].
pub fn promote_multiple_types(types: &[DType]) -> Result<DType> {
    match types.split_first() {
        None => Err(TorshError::InvalidArgument(
            "Cannot promote empty type list".to_string(),
        )),
        Some((&head, tail)) => tail
            .iter()
            .try_fold(head, |accumulated, &next| promote_types(accumulated, next)),
    }
}
pub fn can_cast_safely(from: DType, to: DType) -> bool {
if from == to {
return true;
}
let from_category = get_type_category(from);
let to_category = get_type_category(to);
if to_category > from_category {
return true;
}
if from_category == to_category {
return get_type_precision(to) >= get_type_precision(from);
}
false
}
/// Produces a pair of `f32` tensors with the shapes and devices of `lhs` and `rhs`.
///
/// NOTE(review): the returned tensors are created with `Tensor::zeros`. The element
/// values of `lhs` and `rhs` are NOT copied or converted, so callers receive zero-filled
/// tensors. This looks like a placeholder — TODO implement a real element conversion
/// before relying on the output values.
///
/// # Errors
/// Propagates any error from tensor allocation (`lhs` is allocated first).
pub fn promote_tensors<T, U>(lhs: &Tensor<T>, rhs: &Tensor<U>) -> Result<(Tensor<f32>, Tensor<f32>)>
where
    T: TensorElement,
    U: TensorElement,
{
    Ok((
        Tensor::<f32>::zeros(&lhs.shape().dims(), lhs.device())?,
        Tensor::<f32>::zeros(&rhs.shape().dims(), rhs.device())?,
    ))
}
/// Produces one `f32` tensor per input, matching each input's shape and device.
///
/// NOTE(review): like `promote_tensors`, each output is created with `Tensor::zeros`.
/// The input element values are not carried over — TODO implement a real conversion.
///
/// # Errors
/// Stops at, and returns, the first allocation error.
pub fn promote_tensor_list<T>(tensors: &[&Tensor<T>]) -> Result<Vec<Tensor<f32>>>
where
    T: TensorElement,
{
    tensors
        .iter()
        .map(|source| Tensor::<f32>::zeros(&source.shape().dims(), source.device()))
        .collect()
}
/// Returns the dtype produced by a binary operation on `lhs` and `rhs`.
///
/// The result is obtained by applying [`promote_types`] to the two tensors' dtypes.
pub fn result_type<T, U>(lhs: &Tensor<T>, rhs: &Tensor<U>) -> Result<DType>
where
    T: TensorElement,
    U: TensorElement,
{
    promote_types(lhs.dtype(), rhs.dtype())
}
/// Returns the dtype produced by combining a tensor of `tensor_dtype` with a scalar of type `T`.
///
/// Only the static element type `T` is consulted; the scalar's value is ignored.
pub fn promote_scalar_type<T>(tensor_dtype: DType, _scalar: T) -> Result<DType>
where
    T: TensorElement,
{
    promote_types(tensor_dtype, T::dtype())
}
/// Returns the common dtype of `lhs` and `rhs`.
///
/// This is currently identical to [`result_type`]. It performs no validation beyond what
/// [`promote_types`] does, and errors from it are propagated unchanged.
pub fn ensure_compatible_types<T, U>(lhs: &Tensor<T>, rhs: &Tensor<U>) -> Result<DType>
where
    T: TensorElement,
    U: TensorElement,
{
    let common = result_type(lhs, rhs)?;
    Ok(common)
}
pub fn reduction_result_type(input_dtype: DType, operation: &str) -> Result<DType> {
match operation {
"sum" | "prod" => {
match input_dtype {
DType::Bool | DType::U8 | DType::I8 | DType::I16 => Ok(DType::I64),
DType::I32 | DType::U32 => Ok(DType::I64),
DType::I64 | DType::U64 => Ok(DType::I64),
DType::F16 | DType::BF16 => Ok(DType::F32),
DType::F32 => Ok(DType::F32),
DType::F64 => Ok(DType::F64),
DType::C64 => Ok(DType::C64),
DType::C128 => Ok(DType::C128),
DType::QInt8 | DType::QUInt8 | DType::QInt32 => Ok(DType::I64), }
}
"mean" => {
match input_dtype {
DType::Bool
| DType::U8
| DType::I8
| DType::I16
| DType::I32
| DType::I64
| DType::U32
| DType::U64 => Ok(DType::F32),
DType::F16 | DType::BF16 => Ok(DType::F32),
DType::F32 => Ok(DType::F32),
DType::F64 => Ok(DType::F64),
DType::C64 => Ok(DType::C64),
DType::C128 => Ok(DType::C128),
DType::QInt8 | DType::QUInt8 | DType::QInt32 => Ok(DType::F32), }
}
"max" | "min" | "argmax" | "argmin" => {
if operation.starts_with("arg") {
Ok(DType::I64)
} else {
Ok(input_dtype)
}
}
_ => {
Ok(input_dtype)
}
}
}
pub fn common_dtype_for_operation(dtypes: &[DType], operation: &str) -> Result<DType> {
if dtypes.is_empty() {
return Err(TorshError::InvalidArgument(
"No dtypes provided".to_string(),
));
}
let common_type = promote_multiple_types(dtypes)?;
match operation {
"div" | "true_div" => {
match common_type {
DType::Bool | DType::U8 | DType::I8 | DType::I16 | DType::I32 | DType::I64 => {
Ok(DType::F32)
}
_ => Ok(common_type),
}
}
"floor_div" => {
match common_type {
DType::Bool | DType::U8 | DType::I8 => Ok(DType::I32),
_ => Ok(common_type),
}
}
_ => Ok(common_type),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::{ones, zeros};

    #[test]
    fn test_type_categories() {
        let cases = [
            (DType::Bool, TypeCategory::Boolean),
            (DType::I32, TypeCategory::Integer),
            (DType::F32, TypeCategory::FloatingPoint),
            (DType::C64, TypeCategory::Complex),
        ];
        for (dtype, expected) in cases.iter() {
            assert_eq!(get_type_category(*dtype), *expected);
        }
    }

    #[test]
    fn test_type_precision() {
        // Each pair is (wider, narrower).
        let ordered = [
            (DType::F64, DType::F32),
            (DType::I64, DType::I32),
            (DType::C128, DType::C64),
        ];
        for &(wide, narrow) in ordered.iter() {
            assert!(get_type_precision(wide) > get_type_precision(narrow));
        }
    }

    #[test]
    fn test_basic_type_promotion() {
        let cases = [
            (DType::I32, DType::I32, DType::I32),
            (DType::I32, DType::F32, DType::F32),
            (DType::F32, DType::F64, DType::F64),
            (DType::F32, DType::C64, DType::C64),
        ];
        for &(lhs, rhs, expected) in cases.iter() {
            assert_eq!(promote_types(lhs, rhs).unwrap(), expected);
        }
    }

    #[test]
    fn test_multiple_type_promotion() {
        let floats = [DType::I32, DType::F32, DType::F64];
        assert_eq!(promote_multiple_types(&floats).unwrap(), DType::F64);

        let integers = [DType::Bool, DType::I16, DType::I32];
        assert_eq!(promote_multiple_types(&integers).unwrap(), DType::I32);
    }

    #[test]
    fn test_safe_casting() {
        let allowed = [
            (DType::I32, DType::I64),
            (DType::F32, DType::F64),
            (DType::I32, DType::F32),
        ];
        for &(from, to) in allowed.iter() {
            assert!(can_cast_safely(from, to));
        }

        let forbidden = [(DType::F64, DType::F32), (DType::I64, DType::I32)];
        for &(from, to) in forbidden.iter() {
            assert!(!can_cast_safely(from, to));
        }
    }

    #[test]
    fn test_reduction_result_types() {
        let cases = [
            (DType::I32, "sum", DType::I64),
            (DType::F32, "mean", DType::F32),
            (DType::I32, "argmax", DType::I64),
            (DType::F32, "max", DType::F32),
        ];
        for &(input, operation, expected) in cases.iter() {
            assert_eq!(reduction_result_type(input, operation).unwrap(), expected);
        }
    }

    #[test]
    fn test_operation_dtypes() {
        let mixed = [DType::I32, DType::F32];
        assert_eq!(common_dtype_for_operation(&mixed, "add").unwrap(), DType::F32);
        assert_eq!(common_dtype_for_operation(&mixed, "div").unwrap(), DType::F32);

        let int_types = [DType::I16, DType::I32];
        assert_eq!(
            common_dtype_for_operation(&int_types, "div").unwrap(),
            DType::F32
        );
    }

    #[test]
    fn test_tensor_promotion() -> Result<()> {
        let t1: Tensor<f32> = ones(&[2, 3])?;
        let t2: Tensor<f32> = zeros(&[2, 3])?;
        assert_eq!(result_type(&t1, &t2)?, DType::F32);
        Ok(())
    }
}