use scirs2_core::ndarray::{array, Array3, ArrayView2};
use scirs2_linalg::attention::scaled_dot_product_attention;
use scirs2_linalg::matrix_functions::{expm, logm, sqrtm};
use scirs2_linalg::prelude::*;
use scirs2_linalg::specialized::TridiagonalMatrix;
use scirs2_linalg::structured::{CirculantMatrix, ToeplitzMatrix};
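
/// Returns `true` if `a` and `b` have the same shape and agree element-wise
/// to within absolute tolerance `tol`.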
fn matrices_approx_equal(a: &ArrayView2<f64>, b: &ArrayView2<f64>, tol: f64) -> bool {
    if a.shape() != b.shape() {
        return false;
    }
    a.iter().zip(b.iter()).all(|(x, y)| (x - y).abs() < tol)
}
#[allow(dead_code)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("=== Advanced Linear Algebra Tutorial ===\n");
println!("1. Matrix Functions");
println!("------------------");
let a = array![[1.0, 1.0], [0.0, 2.0]];
println!("Matrix A:");
println!("{}", a);
    let exp_a = expm(&a.view(), None)?;
    println!("\nexp(A) =");
    println!("{}", exp_a);
    match logm(&exp_a.view()) {
        Ok(log_exp_a) => {
            println!("\nlog(exp(A)) =");
            println!("{}", log_exp_a);
            println!(
                "Should equal A? {}\n",
                matrices_approx_equal(&a.view(), &log_exp_a.view(), 1e-10)
            );
        }
        Err(e) => {
            println!("\nMatrix logarithm failed: {}", e);
            println!("This is a known issue with the current implementation.\n");
        }
    }
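
    // For the same A, the principal square root is [[1, sqrt(2) - 1], [0, sqrt(2)]]:
    // squaring it gives [[1, (sqrt(2) - 1)(1 + sqrt(2))], [0, 2]] = [[1, 1], [0, 2]] = A.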
    let sqrt_a = sqrtm(&a.view(), 10, 1e-10)?;
    println!("sqrt(A) =");
    println!("{}", sqrt_a);
    let sqrt_squared = sqrt_a.dot(&sqrt_a);
    println!("\nsqrt(A) * sqrt(A) =");
    println!("{}", sqrt_squared);
    println!(
        "Should equal A? {}\n",
        matrices_approx_equal(&a.view(), &sqrt_squared.view(), 1e-10)
    );
println!("2. Iterative Solvers");
println!("-------------------");
let spd = array![[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 4.0]];
let b = array![1.0, 2.0, 3.0];
println!("Solving Ax = b with iterative methods");
println!("A =");
println!("{}", spd);
println!("b = {:?}\n", b);
    let x_cg = conjugate_gradient(&spd.view(), &b.view(), 10, 1e-10, None)?;
    println!("Conjugate Gradient solution: {:?}", x_cg);
    // A GMRES solve (restarted, e.g. every 3 iterations) would also fit here;
    // as a stand-in this example re-runs conjugate gradient, so the two
    // solutions are identical.
    let max_iter = 10;
    let x_gmres = conjugate_gradient(&spd.view(), &b.view(), max_iter, 1e-10, None)?;
    println!("GMRES solution (CG stand-in): {:?}", x_gmres);
    match solve(&spd.view(), &b.view(), None) {
        Ok(x_direct) => {
            println!("Direct solution: {:?}\n", x_direct);
        }
        Err(e) => {
            println!("Direct solution failed: {}", e);
            println!("This is a known limitation: matrix inverse for 3x3+ matrices is not yet implemented.");
            // Fall back to a 2x2 system, which the direct solver does handle.
            let spd_2x2 = array![[4.0, 1.0], [1.0, 3.0]];
            let b_2x2 = array![1.0, 2.0];
            let x_direct_2x2 = solve(&spd_2x2.view(), &b_2x2.view(), None)?;
            println!("Direct solution (2x2): {:?}\n", x_direct_2x2);
        }
    }
println!("3. Specialized Matrix Types");
println!("--------------------------");
let main_diag = array![2.0, 2.0, 2.0];
let upper_diag = array![-1.0, -1.0];
let lower_diag = array![-1.0, -1.0];
let tridiag = TridiagonalMatrix::new(main_diag.view(), upper_diag.view(), lower_diag.view())?;
println!("Tridiagonal matrix:");
println!("Main diagonal: {:?}", main_diag);
println!("Upper diagonal: {:?}", upper_diag);
println!("Lower diagonal: {:?}", lower_diag);
let b_tri = array![1.0, 0.0, 1.0];
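
    // Dense form of this system:
    //   [ 2 -1  0 ] [x1]   [1]
    //   [-1  2 -1 ] [x2] = [0]
    //   [ 0 -1  2 ] [x3]   [1]
    // By symmetry x1 = x3; the middle row then forces x2 = x1, and the first
    // row gives x1 = 1, so the exact solution is [1, 1, 1].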
    let x_tri = tridiag.solve(&b_tri.view())?;
    println!("\nSolution to tridiagonal system: {:?}", x_tri);
println!("4. Structured Matrices");
println!("---------------------");
let first_col = array![4.0, 3.0, 2.0, 1.0];
let first_row = array![4.0, 5.0, 6.0, 7.0];
let toeplitz = ToeplitzMatrix::new(first_row.view(), first_col.view())?;
println!("Toeplitz matrix:");
let toeplitz_dense = toeplitz.to_dense()?;
println!("{}", toeplitz_dense);

    let first_row_circ = array![1.0, 2.0, 3.0, 4.0];
    let circulant = CirculantMatrix::new(first_row_circ.view())?;
    println!("\nCirculant matrix:");
    let circulant_dense = circulant.to_dense()?;
    println!("{}", circulant_dense);
println!("5. Machine Learning Operations");
println!("-----------------------------");
let batchsize = 1;
let seq_len = 4;
let d_k = 8;
let d_v = 8;
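
    // Scaled dot-product attention computes softmax(Q K^T * scale) V, where
    // Q and K have shape (batch, seq_len, d_k), V has shape (batch, seq_len, d_v),
    // and the output has shape (batch, seq_len, d_v).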
    let query = Array3::from_shape_fn((batch_size, seq_len, d_k), |_| {
        scirs2_core::random::random::<f32>()
    });
    let key = Array3::from_shape_fn((batch_size, seq_len, d_k), |_| {
        scirs2_core::random::random::<f32>()
    });
    let value = Array3::from_shape_fn((batch_size, seq_len, d_v), |_| {
        scirs2_core::random::random::<f32>()
    });
    // Standard scaling factor: 1 / sqrt(d_k).
    let scale = 1.0 / (d_k as f32).sqrt();
    let attention_output = scaled_dot_product_attention(
        &query.view(),
        &key.view(),
        &value.view(),
        None, // optional argument (likely a mask) left unset
        scale,
    )?;
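    // The output shape should be (batch_size, seq_len, d_v) = (1, 4, 8).
    assert_eq!(attention_output.shape(), &[batch_size, seq_len, d_v][..]);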
println!("Attention output shape: {:?}", attention_output.shape());
println!("Multi-head attention: (implementation details vary)");
println!("6. Complex Number Support");
println!("------------------------");
use scirs2_core::Complex;
use scirs2_linalg::complex::{complex_inverse, complex_matmul};
let complex_a = array![
[Complex::new(1.0, 1.0), Complex::new(0.0, 1.0)],
[Complex::new(0.0, -1.0), Complex::new(2.0, 0.0)]
];
println!("Complex matrix A:");
println!("{}", complex_a);
    let complex_product = complex_matmul(&complex_a.view(), &complex_a.view())?;
    println!("\nA * A =");
    println!("{}", complex_product);
    let complex_inv = complex_inverse(&complex_a.view())?;
    println!("\ninv(A) =");
    println!("{}", complex_inv);
    // Up to rounding, this product should be the 2x2 identity matrix.
    let should_be_identity = complex_matmul(&complex_a.view(), &complex_inv.view())?;
    println!("\nA * inv(A) =");
    println!("{}\n", should_be_identity);
println!("=== Tutorial Complete ===");
Ok(())
}