use rustorch::tensor::{complex::Complex, Tensor};
fn main() {
println!("=== RusTorch Complex Tensor Operations Demo ===");
println!("=== RusTorch複素テンソル操作デモ ===\n");
demo_complex_numbers();
demo_complex_tensor_creation();
demo_complex_mathematical_functions();
demo_complex_matrix_operations();
demo_complex_fft_operations();
println!("Demo completed successfully! すべてのデモが正常に完了しました!");
}
/// Prints a walkthrough of scalar complex-number operations.
///
/// Two values, `3 + 4i` and `1 - 2i`, are created and their sum, product and
/// quotient are printed, followed by the magnitude, argument and conjugate of
/// the first value. The first value is then converted to polar form and
/// rebuilt from that polar pair.
fn demo_complex_numbers() {
    println!("📊 Complex Number Operations / 複素数演算");
    println!("----------------------------------------");

    let lhs = Complex::new(3.0, 4.0);
    let rhs = Complex::new(1.0, -2.0);
    println!("z1 = {}", lhs);
    println!("z2 = {}", rhs);

    // Arithmetic between the two operands.
    let sum = lhs + rhs;
    println!("z1 + z2 = {}", sum);
    let product = lhs * rhs;
    println!("z1 * z2 = {}", product);
    let quotient = lhs / rhs;
    println!("z1 / z2 = {}", quotient);

    // Properties of the first operand alone.
    println!("Magnitude of z1: |z1| = {}", lhs.abs());
    println!("Phase of z1: arg(z1) = {:.4} radians", lhs.arg());
    println!("Complex conjugate: conj(z1) = {}", lhs.conj());

    // Round trip through the polar representation.
    let (radius, angle) = lhs.to_polar();
    println!("Polar form: z1 = {:.3} * exp(i * {:.3})", radius, angle);
    let rebuilt = Complex::from_polar(radius, angle);
    println!("Reconstructed from polar: {}", rebuilt);

    println!();
}
/// Prints a walkthrough of building complex tensors.
///
/// A complex tensor is assembled from a three-element real tensor and a
/// three-element imaginary tensor, its elements are listed, and the real part,
/// imaginary part, magnitude and phase tensors derived from it are printed.
/// Finally the `complex_zeros`, `complex_ones` and `complex_i` constructors
/// are exercised.
///
/// # Panics
///
/// Panics if `Complex::from_tensors` returns an error or if any tensor's data
/// cannot be viewed as a slice (`as_slice()` returning `None`).
fn demo_complex_tensor_creation() {
    println!("🎯 Complex Tensor Creation / 複素テンソル作成");
    println!("------------------------------------------");

    let re = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
    let im = Tensor::from_vec(vec![4.0, 5.0, 6.0], vec![3]);
    let combined = Complex::from_tensors(&re, &im).unwrap();

    println!("Real part: {:?}", re.data.as_slice().unwrap());
    println!("Imaginary part: {:?}", im.data.as_slice().unwrap());
    println!("Complex tensor:");
    for (index, element) in combined.data.iter().enumerate() {
        println!(" [{:?}] = {}", index, element);
    }

    // Derive every component tensor first, then report them.
    let re_out = Complex::tensor_real_part(&combined);
    let im_out = Complex::tensor_imag_part(&combined);
    let abs_out = Complex::tensor_abs(&combined);
    let arg_out = Complex::tensor_arg(&combined);

    println!("Extracted real: {:?}", re_out.data.as_slice().unwrap());
    println!("Extracted imag: {:?}", im_out.data.as_slice().unwrap());
    println!("Magnitude: {:?}", abs_out.data.as_slice().unwrap());
    println!("Phase: {:?}", arg_out.data.as_slice().unwrap());

    // Constant-filled constructors.
    let all_zero = Tensor::<Complex<f64>>::complex_zeros(&[2, 3]);
    let all_one = Tensor::<Complex<f64>>::complex_ones(&[2, 3]);
    let all_i = Tensor::<Complex<f64>>::complex_i(&[3]);

    println!("Complex zeros shape: {:?}", all_zero.shape());
    println!("Complex ones shape: {:?}", all_one.shape());
    println!("Imaginary unit tensor: {:?}", all_i.data.as_slice().unwrap());

    println!();
}
/// Prints a walkthrough of element-wise complex math functions.
///
/// A one-dimensional tensor is built from the sample values `1`, `i`, `1 + i`
/// and `2 - i`. The element-wise `exp`, `ln`, `sqrt`, `sin` and `cos` results
/// are printed, and `sin² + cos²` is evaluated per element so the reader can
/// compare it against `1 + 0i`.
///
/// The tensor shape and every loop bound are derived from the sample data, so
/// adding or removing sample values needs no further changes.
fn demo_complex_mathematical_functions() {
    println!("🧮 Complex Mathematical Functions / 複素数学関数");
    println!("-----------------------------------------------");

    let complex_data = vec![
        Complex::new(1.0, 0.0),
        Complex::new(0.0, 1.0),
        Complex::new(1.0, 1.0),
        Complex::new(2.0, -1.0),
    ];

    println!("Input tensor z:");
    for (i, val) in complex_data.iter().enumerate() {
        println!(" z[{}] = {}", i, val);
    }

    // The samples are moved into the tensor once they have been printed, so
    // no copy is needed; the shape follows the number of samples.
    let len = complex_data.len();
    let z = Tensor::from_vec(complex_data, vec![len]);

    let exp_z = z.exp();
    println!("\nExponential e^z:");
    for (i, val) in exp_z.data.iter().enumerate() {
        println!(" e^z[{}] = {}", i, val);
    }

    let ln_z = z.ln();
    println!("\nNatural logarithm ln(z):");
    for (i, val) in ln_z.data.iter().enumerate() {
        println!(" ln(z[{}]) = {}", i, val);
    }

    let sqrt_z = z.sqrt();
    println!("\nSquare root √z:");
    for (i, val) in sqrt_z.data.iter().enumerate() {
        println!(" √z[{}] = {}", i, val);
    }

    let sin_z = z.sin();
    let cos_z = z.cos();
    println!("\nTrigonometric functions:");
    // Walk the sine and cosine results in lockstep instead of indexing with a
    // hard-coded element count.
    for (i, (s, c)) in sin_z.data.iter().zip(cos_z.data.iter()).enumerate() {
        println!(" sin(z[{}]) = {}, cos(z[{}]) = {}", i, s, i, c);
    }

    println!("\nVerifying sin²(z) + cos²(z) = 1:");
    for (i, (s, c)) in sin_z.data.iter().zip(cos_z.data.iter()).enumerate() {
        let sin_sq = *s * *s;
        let cos_sq = *c * *c;
        let identity = sin_sq + cos_sq;
        println!(" z[{}]: sin²+cos² = {} (should be 1+0i)", i, identity);
    }

    println!();
}
/// Prints a walkthrough of complex matrix operations on two 2x2 matrices.
///
/// Both matrices are printed, followed by their product, the transpose and
/// conjugate transpose of the first matrix, and its trace and determinant.
///
/// # Panics
///
/// Panics if `matmul`, `transpose`, `conj_transpose`, `trace` or
/// `determinant` returns an error.
fn demo_complex_matrix_operations() {
    println!("🔢 Complex Matrix Operations / 複素行列演算");
    println!("--------------------------------------------");

    // Element order is the order given to `from_vec` with shape [2, 2].
    let lhs = Tensor::from_vec(
        vec![
            Complex::new(1.0, 1.0),
            Complex::new(2.0, 0.0),
            Complex::new(0.0, 1.0),
            Complex::new(1.0, -1.0),
        ],
        vec![2, 2],
    );
    let rhs = Tensor::from_vec(
        vec![
            Complex::new(1.0, 0.0),
            Complex::new(0.0, 1.0),
            Complex::new(1.0, 1.0),
            Complex::new(1.0, 0.0),
        ],
        vec![2, 2],
    );

    println!("Matrix A:");
    print_complex_matrix(&lhs);
    println!("\nMatrix B:");
    print_complex_matrix(&rhs);

    let product = lhs.matmul(&rhs).unwrap();
    println!("\nMatrix multiplication A * B:");
    print_complex_matrix(&product);

    let transposed = lhs.transpose().unwrap();
    println!("\nTranspose A^T:");
    print_complex_matrix(&transposed);

    let hermitian = lhs.conj_transpose().unwrap();
    println!("\nConjugate transpose A^H:");
    print_complex_matrix(&hermitian);

    let trace_value = lhs.trace().unwrap();
    println!("\nTrace of A: tr(A) = {}", trace_value);

    let determinant_value = lhs.determinant().unwrap();
    println!("Determinant of A: det(A) = {}", determinant_value);

    println!();
}
/// Prints a walkthrough of FFT operations on a short complex signal.
///
/// The signal has 8 samples taken at `t = i / 8`, each equal to
/// `exp(i*2πt) + 0.5 * exp(i*4πt)` written out as cosine (real part) and sine
/// (imaginary part) terms. The demo prints the signal, its FFT with magnitude
/// and phase, the inverse FFT next to the original samples, the `fftshift` of
/// the spectrum, and the squared magnitude of every FFT bin.
///
/// # Panics
///
/// Panics if `fft`, `ifft` or `fftshift` returns an error, or if the inverse
/// FFT yields more elements than the original signal has samples.
fn demo_complex_fft_operations() {
    use std::f64::consts::PI;

    println!("🌊 Complex FFT Operations / 複素FFT演算");
    println!("---------------------------------------");

    let n = 8;
    let signal_data: Vec<_> = (0..n)
        .map(|i| {
            let t = i as f64 / n as f64;
            let re = (2.0 * PI * t).cos() + 0.5 * (4.0 * PI * t).cos();
            let im = (2.0 * PI * t).sin() + 0.5 * (4.0 * PI * t).sin();
            Complex::new(re, im)
        })
        .collect();
    // The samples are still needed below for the side-by-side comparison,
    // so the tensor receives its own copy.
    let signal = Tensor::from_vec(signal_data.clone(), vec![n]);

    println!("Input signal:");
    for (k, sample) in signal_data.iter().enumerate() {
        println!(" x[{}] = {}", k, sample);
    }

    let spectrum = signal.fft(None, None, None).unwrap();
    println!("\nFFT result:");
    for (k, bin) in spectrum.data.iter().enumerate() {
        println!(
            " X[{}] = {} (mag: {:.3}, phase: {:.3})",
            k,
            bin,
            bin.abs(),
            bin.arg()
        );
    }

    let restored = spectrum.ifft(None, None, None).unwrap();
    println!("\nIFFT result (reconstructed signal):");
    for (k, sample) in restored.data.iter().enumerate() {
        println!(" x'[{}] = {} (original: {})", k, sample, signal_data[k]);
    }

    let shifted = spectrum.fftshift(None).unwrap();
    println!("\nFFT with fftshift (DC in center):");
    for (k, bin) in shifted.data.iter().enumerate() {
        println!(" X_shifted[{}] = {}", k, bin);
    }

    println!("\nPower spectrum:");
    for (k, bin) in spectrum.data.iter().enumerate() {
        println!(" |X[{}]|² = {:.3}", k, bin.abs_sq());
    }

    println!();
}
/// Prints a two-dimensional complex tensor one bracketed row per line.
///
/// Each element is rendered with its `Display` output, right-aligned in a
/// 12-character field, with a single space between neighbouring cells. If the
/// tensor's shape does not have exactly two dimensions, the message
/// `Not a 2D matrix` is printed instead.
fn print_complex_matrix(matrix: &Tensor<Complex<f64>>) {
    let dims = matrix.shape();
    if dims.len() != 2 {
        println!("Not a 2D matrix");
        return;
    }

    let (row_count, col_count) = (dims[0], dims[1]);
    for row in 0..row_count {
        print!(" [");
        for col in 0..col_count {
            // The separator goes in front of every cell except the first one.
            if col > 0 {
                print!(" ");
            }
            // Render to a `String` first so the width applies to the whole
            // formatted value.
            let cell = matrix.data[[row, col]].to_string();
            print!("{:>12}", cell);
        }
        println!("]");
    }
}