use crate::{FloatElement, Tensor, TensorElement};
use torsh_core::error::{Result, TorshError};
use torsh_core::device::DeviceType;
impl<T: TensorElement> Tensor<T> {
pub fn to_f32(&self) -> Result<Tensor<f32>> {
let data = self.data()?;
let converted_data: std::result::Result<Vec<f32>, _> = data
.iter()
.map(|&x| {
<T as TensorElement>::to_f64(&x)
.and_then(|f| if f.is_finite() && f >= f32::MIN as f64 && f <= f32::MAX as f64 {
Some(f as f32)
} else {
None
})
.ok_or_else(|| TorshError::InvalidArgument(
format!("Cannot convert value to f32: {}", f64::from_bits(<T as TensorElement>::to_f64(&x).unwrap_or(0.0) as u64))
))
})
.collect();
let converted_data = converted_data?;
Tensor::from_data(
converted_data,
self.shape().dims().to_vec(),
self.device,
)
}
pub fn to_f64(&self) -> Result<Tensor<f64>> {
let data = self.data()?;
let converted_data: std::result::Result<Vec<f64>, _> = data
.iter()
.map(|&x| {
<T as TensorElement>::to_f64(&x)
.ok_or_else(|| TorshError::InvalidArgument(
"Cannot convert value to f64".to_string()
))
})
.collect();
let converted_data = converted_data?;
Tensor::from_data(
converted_data,
self.shape().dims().to_vec(),
self.device,
)
}
pub fn to_i32(&self) -> Result<Tensor<i32>> {
let data = self.data()?;
let converted_data: std::result::Result<Vec<i32>, _> = data
.iter()
.map(|&x| {
<T as TensorElement>::to_f64(&x)
.and_then(|f| {
if f.is_finite() && f >= i32::MIN as f64 && f <= i32::MAX as f64 {
Some(f.round() as i32)
} else {
None
}
})
.ok_or_else(|| TorshError::InvalidArgument(
"Cannot convert value to i32: value out of range or not finite".to_string()
))
})
.collect();
let converted_data = converted_data?;
Tensor::from_data(
converted_data,
self.shape().dims().to_vec(),
self.device,
)
}
pub fn to_i64(&self) -> Result<Tensor<i64>> {
let data = self.data()?;
let converted_data: std::result::Result<Vec<i64>, _> = data
.iter()
.map(|&x| {
<T as TensorElement>::to_f64(&x)
.and_then(|f| {
if f.is_finite() && f >= i64::MIN as f64 && f <= i64::MAX as f64 {
Some(f.round() as i64)
} else {
None
}
})
.ok_or_else(|| TorshError::InvalidArgument(
"Cannot convert value to i64: value out of range or not finite".to_string()
))
})
.collect();
let converted_data = converted_data?;
Tensor::from_data(
converted_data,
self.shape().dims().to_vec(),
self.device,
)
}
pub fn to_bool(&self) -> Result<Tensor<bool>> {
let data = self.data()?;
let converted_data: Vec<bool> = data
.iter()
.map(|&x| {
<T as TensorElement>::to_f64(&x)
.map(|f| f != 0.0)
.unwrap_or(false)
})
.collect();
Tensor::from_data(
converted_data,
self.shape().dims().to_vec(),
self.device,
)
}
pub fn to_f16(&self) -> Result<Tensor<torsh_core::dtype::f16>> {
let data = self.data()?;
let converted_data: std::result::Result<Vec<torsh_core::dtype::f16>, _> = data
.iter()
.map(|&x| {
<T as TensorElement>::to_f64(&x)
.map(|f| torsh_core::dtype::f16::from_f64(f))
.ok_or_else(|| TorshError::InvalidArgument(
"Cannot convert value to f16".to_string()
))
})
.collect();
let converted_data = converted_data?;
Tensor::from_data(
converted_data,
self.shape().dims().to_vec(),
self.device,
)
}
pub fn to_bf16(&self) -> Result<Tensor<torsh_core::dtype::bf16>> {
let data = self.data()?;
let converted_data: std::result::Result<Vec<torsh_core::dtype::bf16>, _> = data
.iter()
.map(|&x| {
<T as TensorElement>::to_f64(&x)
.map(|f| torsh_core::dtype::bf16::from_f64(f))
.ok_or_else(|| TorshError::InvalidArgument(
"Cannot convert value to bf16".to_string()
))
})
.collect();
let converted_data = converted_data?;
Tensor::from_data(
converted_data,
self.shape().dims().to_vec(),
self.device,
)
}
pub fn to_u8(&self) -> Result<Tensor<u8>> {
let data = self.data()?;
let converted_data: std::result::Result<Vec<u8>, _> = data
.iter()
.map(|&x| {
<T as TensorElement>::to_f64(&x)
.and_then(|f| {
if f.is_finite() && f >= 0.0 && f <= u8::MAX as f64 {
Some(f.round() as u8)
} else {
None
}
})
.ok_or_else(|| TorshError::InvalidArgument(
"Cannot convert value to u8: value out of range or not finite".to_string()
))
})
.collect();
let converted_data = converted_data?;
Tensor::from_data(
converted_data,
self.shape().dims().to_vec(),
self.device,
)
}
pub fn to_i8(&self) -> Result<Tensor<i8>> {
let data = self.data()?;
let converted_data: std::result::Result<Vec<i8>, _> = data
.iter()
.map(|&x| {
<T as TensorElement>::to_f64(&x)
.and_then(|f| {
if f.is_finite() && f >= i8::MIN as f64 && f <= i8::MAX as f64 {
Some(f.round() as i8)
} else {
None
}
})
.ok_or_else(|| TorshError::InvalidArgument(
"Cannot convert value to i8: value out of range or not finite".to_string()
))
})
.collect();
let converted_data = converted_data?;
Tensor::from_data(
converted_data,
self.shape().dims().to_vec(),
self.device,
)
}
pub fn to_i16(&self) -> Result<Tensor<i16>> {
let data = self.data()?;
let converted_data: std::result::Result<Vec<i16>, _> = data
.iter()
.map(|&x| {
<T as TensorElement>::to_f64(&x)
.and_then(|f| {
if f.is_finite() && f >= i16::MIN as f64 && f <= i16::MAX as f64 {
Some(f.round() as i16)
} else {
None
}
})
.ok_or_else(|| TorshError::InvalidArgument(
"Cannot convert value to i16: value out of range or not finite".to_string()
))
})
.collect();
let converted_data = converted_data?;
Tensor::from_data(
converted_data,
self.shape().dims().to_vec(),
self.device,
)
}
pub fn to_cpu(&self) -> Result<Self> {
self.to_device(DeviceType::Cpu)
}
pub fn to_cuda(&self, device_id: usize) -> Result<Self> {
self.to_device(DeviceType::Cuda(device_id))
}
pub fn to_tensor<U: TensorElement>(&self) -> Result<Tensor<U>> {
let data = self.data()?;
let converted_data: std::result::Result<Vec<U>, _> = data
.iter()
.map(|&x| {
<T as TensorElement>::to_f64(&x)
.and_then(|f| U::from_f64(f))
.ok_or_else(|| TorshError::InvalidArgument(
format!("Cannot convert value to target type")
))
})
.collect();
let converted_data = converted_data?;
Tensor::from_data(
converted_data,
self.shape().dims().to_vec(),
self.device,
)
}
}
impl<T: FloatElement> Tensor<T> {
    /// Promotes each real element into a `Complex32` whose imaginary part is
    /// zero. Elements that cannot be widened to `f64` become `0.0 + 0i`.
    pub fn to_complex32(&self) -> Result<Tensor<torsh_core::dtype::Complex32>> {
        let values = self.data()?;
        let mut promoted = Vec::with_capacity(values.len());
        for &v in values.iter() {
            let re = <T as TensorElement>::to_f64(&v).unwrap_or(0.0) as f32;
            promoted.push(torsh_core::dtype::Complex32::new(re, 0.0));
        }
        Tensor::from_data(promoted, self.shape().dims().to_vec(), self.device)
    }

    /// Promotes each real element into a `Complex64` whose imaginary part is
    /// zero. Elements that cannot be widened to `f64` become `0.0 + 0i`.
    pub fn to_complex64(&self) -> Result<Tensor<torsh_core::dtype::Complex64>> {
        let values = self.data()?;
        let mut promoted = Vec::with_capacity(values.len());
        for &v in values.iter() {
            let re = <T as TensorElement>::to_f64(&v).unwrap_or(0.0);
            promoted.push(torsh_core::dtype::Complex64::new(re, 0.0));
        }
        Tensor::from_data(promoted, self.shape().dims().to_vec(), self.device)
    }
}
impl Tensor<torsh_core::dtype::Complex32> {
    /// Applies `component` to every complex element and packs the results
    /// into a real `f32` tensor with the same shape and device.
    fn component_tensor(
        &self,
        component: impl Fn(&torsh_core::dtype::Complex32) -> f32,
    ) -> Result<Tensor<f32>> {
        let values = self.data()?;
        let extracted: Vec<f32> = values.iter().map(component).collect();
        Tensor::from_data(extracted, self.shape().dims().to_vec(), self.device)
    }

    /// Real component of each element.
    pub fn real_part(&self) -> Result<Tensor<f32>> {
        self.component_tensor(|z| z.re)
    }

    /// Imaginary component of each element.
    pub fn imag_part(&self) -> Result<Tensor<f32>> {
        self.component_tensor(|z| z.im)
    }

    /// Modulus (absolute value) of each element.
    pub fn magnitude(&self) -> Result<Tensor<f32>> {
        self.component_tensor(|z| z.norm())
    }

    /// Argument (angle, in radians) of each element.
    pub fn phase(&self) -> Result<Tensor<f32>> {
        self.component_tensor(|z| z.arg())
    }
}
impl Tensor<torsh_core::dtype::Complex64> {
    /// Applies `component` to every complex element and packs the results
    /// into a real `f64` tensor with the same shape and device.
    fn component_tensor(
        &self,
        component: impl Fn(&torsh_core::dtype::Complex64) -> f64,
    ) -> Result<Tensor<f64>> {
        let values = self.data()?;
        let extracted: Vec<f64> = values.iter().map(component).collect();
        Tensor::from_data(extracted, self.shape().dims().to_vec(), self.device)
    }

    /// Real component of each element.
    pub fn real_part(&self) -> Result<Tensor<f64>> {
        self.component_tensor(|z| z.re)
    }

    /// Imaginary component of each element.
    pub fn imag_part(&self) -> Result<Tensor<f64>> {
        self.component_tensor(|z| z.im)
    }

    /// Modulus (absolute value) of each element.
    pub fn magnitude(&self) -> Result<Tensor<f64>> {
        self.component_tensor(|z| z.norm())
    }

    /// Argument (angle, in radians) of each element.
    pub fn phase(&self) -> Result<Tensor<f64>> {
        self.component_tensor(|z| z.arg())
    }
}
// NOTE(review): intentionally empty impl — presumably a placeholder for
// bool-tensor-specific operations (e.g. logical reductions). Confirm it is
// still wanted, or remove it.
impl Tensor<bool> {
}
/// Promotes two tensors of possibly different element types to a common
/// `f64` representation so they can be combined arithmetically.
///
/// # Errors
/// Propagates any failure from either tensor's `to_f64` conversion.
pub fn promote_types<T1: TensorElement, T2: TensorElement>(
    tensor1: &Tensor<T1>,
    tensor2: &Tensor<T2>,
) -> Result<(Tensor<f64>, Tensor<f64>)> {
    Ok((tensor1.to_f64()?, tensor2.to_f64()?))
}
/// Builds a `Complex64` tensor from separate real and imaginary tensors.
///
/// The result lives on `real`'s device and takes `real`'s shape. Elements
/// that cannot be widened to `f64` contribute `0.0` for that component.
///
/// # Errors
/// Returns `TorshError::ShapeMismatch` when the two shapes differ; otherwise
/// propagates data-access failures.
pub fn complex_from_parts<T: FloatElement>(
    real: &Tensor<T>,
    imag: &Tensor<T>,
) -> Result<Tensor<torsh_core::dtype::Complex64>> {
    if real.shape() != imag.shape() {
        return Err(TorshError::ShapeMismatch {
            expected: real.shape().dims().to_vec(),
            got: imag.shape().dims().to_vec(),
        });
    }
    let re_values = real.data()?;
    let im_values = imag.data()?;
    let mut pairs = Vec::with_capacity(re_values.len());
    for (&r, &i) in re_values.iter().zip(im_values.iter()) {
        pairs.push(torsh_core::dtype::Complex64::new(
            <T as TensorElement>::to_f64(&r).unwrap_or(0.0),
            <T as TensorElement>::to_f64(&i).unwrap_or(0.0),
        ));
    }
    Tensor::from_data(pairs, real.shape().dims().to_vec(), real.device)
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    #[test]
    fn test_to_f32() {
        let tensor = Tensor::from_data(vec![1i32, 2, 3, 4], vec![4], DeviceType::Cpu)
            .expect("operation should succeed");
        let f32_tensor = tensor.to_f32().expect("f32 conversion should succeed");
        let data = f32_tensor.data().expect("data retrieval should succeed");
        assert_eq!(data.as_slice(), &[1.0f32, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_to_i32() {
        let tensor = Tensor::from_data(vec![1.7f32, 2.3, -3.9, 4.0], vec![4], DeviceType::Cpu)
            .expect("operation should succeed");
        let i32_tensor = tensor.to_i32().expect("i32 conversion should succeed");
        let data = i32_tensor.data().expect("data retrieval should succeed");
        assert_eq!(data.as_slice(), &[2i32, 2, -4, 4]);
    }

    #[test]
    fn test_to_bool() {
        let tensor = Tensor::from_data(vec![0.0f32, 1.0, -2.5, 0.0], vec![4], DeviceType::Cpu)
            .expect("operation should succeed");
        let bool_tensor = tensor.to_bool().expect("bool conversion should succeed");
        let data = bool_tensor.data().expect("data retrieval should succeed");
        assert_eq!(data.as_slice(), &[false, true, true, false]);
    }

    #[test]
    fn test_bool_to_f32() {
        let tensor = Tensor::from_data(vec![true, false, true, false], vec![4], DeviceType::Cpu)
            .expect("operation should succeed");
        let f32_tensor = tensor.to_f32().expect("f32 conversion should succeed");
        let data = f32_tensor.data().expect("data retrieval should succeed");
        assert_eq!(data.as_slice(), &[1.0f32, 0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_to_device_cpu() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0], vec![3], DeviceType::Cpu)
            .expect("operation should succeed");
        let cpu_tensor = tensor.to_cpu().expect("cpu transfer should succeed");
        assert_eq!(cpu_tensor.device, DeviceType::Cpu);
    }

    #[test]
    fn test_to_complex32() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0], vec![3], DeviceType::Cpu)
            .expect("operation should succeed");
        let complex_tensor = tensor
            .to_complex32()
            .expect("complex32 conversion should succeed");
        let data = complex_tensor.data().expect("data retrieval should succeed");
        assert_eq!(data[0].re, 1.0);
        assert_eq!(data[0].im, 0.0);
        assert_eq!(data[1].re, 2.0);
        assert_eq!(data[1].im, 0.0);
    }

    #[test]
    fn test_complex_real_imag_parts() {
        let real_data = vec![1.0f32, 2.0, 3.0];
        let imag_data = vec![4.0f32, 5.0, 6.0];
        let complex_data: Vec<torsh_core::dtype::Complex32> = real_data
            .iter()
            .zip(imag_data.iter())
            .map(|(&r, &i)| torsh_core::dtype::Complex32::new(r, i))
            .collect();
        let complex_tensor = Tensor::from_data(complex_data, vec![3], DeviceType::Cpu)
            .expect("operation should succeed");
        let real_part = complex_tensor
            .real_part()
            .expect("real_part extraction should succeed");
        let imag_part = complex_tensor
            .imag_part()
            .expect("imag_part extraction should succeed");
        let real_data_result = real_part.data().expect("data retrieval should succeed");
        let imag_data_result = imag_part.data().expect("data retrieval should succeed");
        assert_eq!(real_data_result.as_slice(), &[1.0f32, 2.0, 3.0]);
        assert_eq!(imag_data_result.as_slice(), &[4.0f32, 5.0, 6.0]);
    }

    #[test]
    fn test_complex_magnitude() {
        let complex_data = vec![
            torsh_core::dtype::Complex32::new(3.0, 4.0),
            torsh_core::dtype::Complex32::new(0.0, 1.0),
        ];
        let complex_tensor = Tensor::from_data(complex_data, vec![2], DeviceType::Cpu)
            .expect("operation should succeed");
        let magnitude = complex_tensor
            .magnitude()
            .expect("magnitude computation should succeed");
        let mag_data = magnitude.data().expect("data retrieval should succeed");
        assert!((mag_data[0] - 5.0).abs() < 1e-6);
        assert!((mag_data[1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_complex_from_parts() {
        let real = Tensor::from_data(vec![1.0f32, 2.0], vec![2], DeviceType::Cpu)
            .expect("operation should succeed");
        let imag = Tensor::from_data(vec![3.0f32, 4.0], vec![2], DeviceType::Cpu)
            .expect("operation should succeed");
        let complex_tensor = complex_from_parts(&real, &imag).expect("operation should succeed");
        let data = complex_tensor.data().expect("data retrieval should succeed");
        assert_eq!(data[0].re, 1.0);
        assert_eq!(data[0].im, 3.0);
        assert_eq!(data[1].re, 2.0);
        assert_eq!(data[1].im, 4.0);
    }

    #[test]
    fn test_promote_types() {
        let tensor1 = Tensor::from_data(vec![1i32, 2], vec![2], DeviceType::Cpu)
            .expect("operation should succeed");
        let tensor2 = Tensor::from_data(vec![3.5f32, 4.5], vec![2], DeviceType::Cpu)
            .expect("operation should succeed");
        let (promoted1, promoted2) =
            promote_types(&tensor1, &tensor2).expect("operation should succeed");
        let data1 = promoted1.data().expect("data retrieval should succeed");
        let data2 = promoted2.data().expect("data retrieval should succeed");
        assert_eq!(data1.as_slice(), &[1.0f64, 2.0]);
        assert_eq!(data2.as_slice(), &[3.5f64, 4.5]);
    }

    #[test]
    fn test_conversion_error_handling() {
        // f64::MAX / f64::MIN are far outside the finite f32 range, so the
        // checked narrowing must report an error rather than saturate.
        let large_tensor = Tensor::from_data(vec![f64::MAX, f64::MIN], vec![2], DeviceType::Cpu)
            .expect("operation should succeed");
        assert!(large_tensor.to_f32().is_err());
    }

    #[test]
    fn test_to_f16() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.5, -3.75, 4.0], vec![4], DeviceType::Cpu)
            .expect("operation should succeed");
        let f16_tensor = tensor.to_f16().expect("f16 conversion should succeed");
        let data = f16_tensor.data().expect("data retrieval should succeed");
        assert!((data[0].to_f32() - 1.0).abs() < 1e-3);
        assert!((data[1].to_f32() - 2.5).abs() < 1e-3);
        assert!((data[2].to_f32() + 3.75).abs() < 1e-3);
        assert!((data[3].to_f32() - 4.0).abs() < 1e-3);
    }

    #[test]
    fn test_to_bf16() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.5, -3.75, 4.0], vec![4], DeviceType::Cpu)
            .expect("operation should succeed");
        let bf16_tensor = tensor.to_bf16().expect("bf16 conversion should succeed");
        let data = bf16_tensor.data().expect("data retrieval should succeed");
        assert!((data[0].to_f32() - 1.0).abs() < 1e-2);
        assert!((data[1].to_f32() - 2.5).abs() < 1e-2);
        assert!((data[2].to_f32() + 3.75).abs() < 1e-2);
        assert!((data[3].to_f32() - 4.0).abs() < 1e-2);
    }

    #[test]
    fn test_to_u8() {
        let tensor = Tensor::from_data(vec![0.0f32, 127.5, 255.0, 100.7], vec![4], DeviceType::Cpu)
            .expect("operation should succeed");
        let u8_tensor = tensor.to_u8().expect("u8 conversion should succeed");
        let data = u8_tensor.data().expect("data retrieval should succeed");
        assert_eq!(data.as_slice(), &[0u8, 128, 255, 101]);
    }

    #[test]
    fn test_to_i8() {
        let tensor =
            Tensor::from_data(vec![-128.0f32, -50.5, 0.0, 127.0], vec![4], DeviceType::Cpu)
                .expect("operation should succeed");
        let i8_tensor = tensor.to_i8().expect("i8 conversion should succeed");
        let data = i8_tensor.data().expect("data retrieval should succeed");
        // Bug fix: -50.5 rounds to -51, not -50 — f64::round rounds ties away
        // from zero (consistent with 127.5 -> 128 in test_to_u8).
        assert_eq!(data.as_slice(), &[-128i8, -51, 0, 127]);
    }

    #[test]
    fn test_to_i16() {
        let tensor =
            Tensor::from_data(vec![-1000.0f32, -50.5, 0.0, 1000.7], vec![4], DeviceType::Cpu)
                .expect("operation should succeed");
        let i16_tensor = tensor.to_i16().expect("i16 conversion should succeed");
        let data = i16_tensor.data().expect("data retrieval should succeed");
        // Bug fix: -50.5 rounds to -51 (ties away from zero), not -50.
        assert_eq!(data.as_slice(), &[-1000i16, -51, 0, 1001]);
    }

    #[test]
    fn test_generic_to_tensor() {
        let tensor = Tensor::from_data(vec![1.5f32, 2.5, 3.5], vec![3], DeviceType::Cpu)
            .expect("operation should succeed");
        let f64_tensor: Tensor<f64> = tensor
            .to_tensor()
            .expect("tensor type conversion should succeed");
        let data = f64_tensor.data().expect("data retrieval should succeed");
        assert!((data[0] - 1.5).abs() < 1e-6);
        assert!((data[1] - 2.5).abs() < 1e-6);
        assert!((data[2] - 3.5).abs() < 1e-6);
    }

    #[test]
    fn test_mixed_precision_workflow() {
        // f32 -> f16 -> f32 round trip should preserve small integral values.
        let f32_tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![4], DeviceType::Cpu)
            .expect("operation should succeed");
        let f16_tensor = f32_tensor.to_f16().expect("f16 conversion should succeed");
        let f32_result: Tensor<f32> = f16_tensor
            .to_tensor()
            .expect("tensor type conversion should succeed");
        let data = f32_result.data().expect("data retrieval should succeed");
        assert!((data[0] - 1.0).abs() < 1e-3);
        assert!((data[1] - 2.0).abs() < 1e-3);
        assert!((data[2] - 3.0).abs() < 1e-3);
        assert!((data[3] - 4.0).abs() < 1e-3);
    }
}