use crate::error::{FerrotorchError, FerrotorchResult};
/// Converts a numeric value of type `T` into type `U`, reporting failures as errors
/// instead of silently producing a wrong value.
///
/// The conversion itself is delegated to `num_traits::NumCast::from`. On top of that,
/// a finite source must not turn into a non-finite result: if `v.to_f64()` is finite
/// but the converted value's `to_f64()` is ±infinity (or NaN), the cast is rejected.
/// This catches float narrowing that overflows, e.g. `f64::MAX` into `f32`.
///
/// Non-finite sources (NaN, +∞, −∞) are exempt from that second check, so whatever
/// `NumCast` produces for them is returned unchanged. The check is also skipped when
/// either side cannot be viewed as `f64` (`to_f64()` returns `None`).
///
/// # Errors
///
/// Returns `FerrotorchError::InvalidArgument` in two cases. The message names both
/// type names and the offending value.
///
/// * `NumCast::from` returns `None`. Examples are a negative value into an unsigned
///   type, or an infinite float into an integer type.
/// * A finite source converts to a non-finite value in `U`.
#[inline]
pub fn cast<T, U>(v: T) -> FerrotorchResult<U>
where
    T: num_traits::ToPrimitive + std::fmt::Debug + Copy,
    U: num_traits::NumCast,
{
    let Some(converted) = <U as num_traits::NumCast>::from(v) else {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "cast from {} to {} failed: value {:?} not representable",
                std::any::type_name::<T>(),
                std::any::type_name::<U>(),
                v,
            ),
        });
    };

    // Only a finite input can "saturate"; NaN / ±Inf inputs are allowed to map onto
    // their non-finite counterparts in the target type.
    let source_is_finite = matches!(v.to_f64(), Some(x) if x.is_finite());
    if !source_is_finite {
        return Ok(converted);
    }

    match converted.to_f64() {
        Some(r) if !r.is_finite() => Err(FerrotorchError::InvalidArgument {
            message: format!(
                "cast from {} to {} failed: value {:?} saturates to non-finite ({}) and is not representable",
                std::any::type_name::<T>(),
                std::any::type_name::<U>(),
                v,
                r,
            ),
        }),
        _ => Ok(converted),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cast_f64_to_f32_succeeds_for_finite() {
        let x: f32 = cast(3.5_f64).unwrap();
        assert!((x - 3.5_f32).abs() < f32::EPSILON);
    }

    #[test]
    fn cast_f64_inf_to_i32_fails() {
        let r: FerrotorchResult<i32> = cast(f64::INFINITY);
        assert!(r.is_err());
        let msg = format!("{}", r.unwrap_err());
        assert!(msg.contains("not representable"), "got: {msg}");
    }

    #[test]
    fn cast_f64_nan_to_i32_fails() {
        let r: FerrotorchResult<i32> = cast(f64::NAN);
        assert!(r.is_err(), "NaN has no integer representation");
    }

    #[test]
    // Exact comparison is intentional: 42 is exactly representable in f32.
    #[allow(clippy::float_cmp)]
    fn cast_usize_to_f32_succeeds() {
        let x: f32 = cast(42_usize).unwrap();
        assert_eq!(x, 42.0);
    }

    #[test]
    fn cast_to_bf16_round_trip() {
        // 1.5 is exactly representable in bf16, so bf16 -> f64 must recover it.
        let x: half::bf16 = cast(1.5_f64).unwrap();
        let back: f64 = cast(x).unwrap();
        assert!((back - 1.5).abs() < f64::EPSILON, "got {back}");
    }

    #[test]
    fn cast_negative_to_unsigned_fails() {
        let r: FerrotorchResult<u32> = cast(-1_i32);
        assert!(r.is_err());
        let msg = format!("{}", r.unwrap_err());
        assert!(msg.contains("not representable"), "got: {msg}");
    }

    #[test]
    fn cast_out_of_range_int_to_narrower_int_fails() {
        let r: FerrotorchResult<u8> = cast(300_i32);
        assert!(r.is_err(), "300 does not fit in u8");
    }

    #[test]
    fn cast_huge_f64_to_bf16_returns_err() {
        let r: FerrotorchResult<half::bf16> = cast(1e300_f64);
        assert!(r.is_err(), "expected Err for finite-source saturation");
        let msg = format!("{}", r.unwrap_err());
        assert!(
            msg.contains("saturates to non-finite") || msg.contains("not representable"),
            "got: {msg}"
        );
    }

    #[test]
    fn cast_huge_f32_to_bf16_returns_err() {
        let r: FerrotorchResult<half::bf16> = cast(f32::MAX);
        assert!(r.is_err(), "expected Err for finite-source saturation");
    }

    #[test]
    fn cast_f64_inf_to_bf16_passes_through() {
        let v: half::bf16 = cast(f64::INFINITY).expect("Inf passthrough");
        assert!(
            v.is_infinite() && v.is_sign_positive(),
            "expected bf16::INFINITY, got {v}"
        );
    }

    #[test]
    fn cast_f64_neg_inf_to_bf16_passes_through() {
        let v: half::bf16 = cast(f64::NEG_INFINITY).expect("-Inf passthrough");
        // Check the sign too: a sign-flipping conversion would still be "infinite".
        assert!(
            v.is_infinite() && v.is_sign_negative(),
            "expected bf16::NEG_INFINITY, got {v}"
        );
    }

    #[test]
    fn cast_f64_nan_to_bf16_passes_through() {
        let v: half::bf16 = cast(f64::NAN).expect("NaN passthrough");
        assert!(v.is_nan(), "expected bf16::NAN");
    }

    #[test]
    fn cast_f64_in_range_to_bf16_succeeds() {
        let v: half::bf16 = cast(1.5_f64).expect("in-range");
        assert!((v.to_f32() - 1.5).abs() < 0.01);
    }

    #[test]
    fn cast_huge_f64_to_f16_returns_err() {
        let r: FerrotorchResult<half::f16> = cast(1e30_f64);
        assert!(r.is_err(), "expected Err for finite-source saturation");
    }

    #[test]
    fn cast_f64_inf_to_f16_passes_through() {
        let v: half::f16 = cast(f64::INFINITY).expect("Inf passthrough");
        assert!(
            v.is_infinite() && v.is_sign_positive(),
            "expected f16::INFINITY, got {v}"
        );
    }

    #[test]
    fn cast_f64_neg_inf_to_f16_passes_through() {
        let v: half::f16 = cast(f64::NEG_INFINITY).expect("-Inf passthrough");
        assert!(
            v.is_infinite() && v.is_sign_negative(),
            "expected f16::NEG_INFINITY, got {v}"
        );
    }

    #[test]
    fn cast_f64_nan_to_f16_passes_through() {
        let v: half::f16 = cast(f64::NAN).expect("NaN passthrough");
        assert!(v.is_nan(), "expected f16::NAN");
    }

    #[test]
    fn cast_f64_in_range_to_f16_succeeds() {
        let v: half::f16 = cast(1.5_f64).expect("in-range");
        assert!((v.to_f32() - 1.5).abs() < 0.01);
    }

    #[test]
    fn cast_huge_f64_to_f32_returns_err() {
        let r: FerrotorchResult<f32> = cast(f64::MAX);
        assert!(r.is_err(), "expected Err for finite-source saturation");
    }

    #[test]
    fn cast_large_but_in_range_f64_to_f32_succeeds() {
        // 1e30 is well below f32::MAX (~3.4e38), so the saturation guard must not fire.
        let x: f32 = cast(1e30_f64).expect("1e30 fits in f32");
        assert!(x.is_finite());
        assert!((f64::from(x) - 1e30).abs() / 1e30 < 1e-6, "got {x}");
    }

    #[test]
    fn cast_f64_inf_to_f32_passes_through() {
        let v: f32 = cast(f64::INFINITY).expect("Inf passthrough");
        assert!(v.is_infinite() && v.is_sign_positive());
    }

    #[test]
    fn cast_f64_neg_inf_to_f32_passes_through() {
        let v: f32 = cast(f64::NEG_INFINITY).expect("-Inf passthrough");
        assert!(v.is_infinite() && v.is_sign_negative());
    }
}