use crate::error::{FerrotorchError, FerrotorchResult};
/// Converts the numeric value `v` of type `T` into the numeric type `U`.
///
/// The conversion itself is delegated to [`num_traits::NumCast::from`]; this
/// wrapper only turns its `None` result into a descriptive error so callers
/// can propagate it with `?`.
///
/// # Errors
///
/// Returns [`FerrotorchError::InvalidArgument`] when `NumCast::from` reports
/// that `v` cannot be represented as a `U`. The message names the source
/// type, the target type, and the `Debug` rendering of the offending value.
#[inline]
pub fn cast<T, U>(v: T) -> FerrotorchResult<U>
where
    T: num_traits::ToPrimitive + std::fmt::Debug + Copy,
    U: num_traits::NumCast,
{
    match <U as num_traits::NumCast>::from(v) {
        Some(converted) => Ok(converted),
        None => {
            // Type names are only resolved on the failure path, so the
            // success path performs no string formatting.
            let source_ty = std::any::type_name::<T>();
            let target_ty = std::any::type_name::<U>();
            Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "cast from {} to {} failed: value {:?} not representable",
                    source_ty, target_ty, v,
                ),
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A finite `f64` converts to the equivalent `f32` value.
    #[test]
    fn cast_f64_to_f32_succeeds_for_finite() {
        let converted = cast::<f64, f32>(3.5_f64).unwrap();
        let delta = (converted - 3.5_f32).abs();
        assert!(delta < f32::EPSILON);
    }

    /// Positive infinity has no `i32` representation; the error message must say so.
    #[test]
    fn cast_f64_inf_to_i32_fails() {
        let outcome = cast::<f64, i32>(f64::INFINITY);
        assert!(outcome.is_err());
        let msg = outcome.unwrap_err().to_string();
        assert!(msg.contains("not representable"), "got: {msg}");
    }

    /// A small `usize` converts exactly to `f32`.
    #[test]
    fn cast_usize_to_f32_succeeds() {
        let converted = cast::<usize, f32>(42_usize).unwrap();
        assert_eq!(converted, 42.0);
    }

    /// Casting `1.5_f64` to `bf16` and widening back to `f32` stays within 0.01.
    #[test]
    fn cast_to_bf16_round_trip() {
        let converted = cast::<f64, half::bf16>(1.5_f64).unwrap();
        let widened = converted.to_f32();
        assert!((widened - 1.5).abs() < 0.01);
    }

    /// A negative signed integer cannot be cast to an unsigned type.
    #[test]
    fn cast_negative_to_unsigned_fails() {
        assert!(cast::<i32, u32>(-1_i32).is_err());
    }
}