use numrs2::expr::{CachedExpr, ExprCache, SharedExpr, SharedExprBuilder};
use numrs2::memory_optimize::access_patterns::{
cache_aware_binary_op, cache_aware_copy, cache_aware_transform, detect_layout, Block,
BlockedIterator, OptimizationHints, StrideOptimizer, Tile2D, TiledIterator2D,
};
use numrs2::prelude::*;
use numrs2::shared_array::SharedArray;
fn main() {
println!("=== NumRS2 Expression Templates Example ===\n");
shared_array_example();
shared_expr_example();
cse_example();
memory_patterns_example();
println!("\n=== All examples completed successfully! ===");
}
/// Demonstrates `SharedArray` arithmetic through overloaded operators.
///
/// Prints two four-element arrays, the reference count reported after cloning one of
/// them, and the results of array-array, array-scalar, chained, and by-reference
/// operations.
fn shared_array_example() {
    println!("--- Part 1: SharedArray with Operator Overloading ---\n");

    let a: SharedArray<f64> = SharedArray::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
    let b: SharedArray<f64> = SharedArray::from_vec(vec![10.0, 20.0, 30.0, 40.0]);
    println!("Array a: {:?}", a.to_vec());
    println!("Array b: {:?}", b.to_vec());

    // A second handle to `a` is kept alive so that it exists when the count is printed;
    // it is reused for the by-reference addition at the end.
    let a_handle = a.clone();
    println!("Reference count after clone: {}", a.ref_count());

    // Array-array operators consume their operands, hence the clones.
    let (added, subtracted, multiplied, divided) = (
        a.clone() + b.clone(),
        a.clone() - b.clone(),
        a.clone() * b.clone(),
        b.clone() / a.clone(),
    );
    println!("\na + b = {:?}", added.to_vec());
    println!("a - b = {:?}", subtracted.to_vec());
    println!("a * b = {:?}", multiplied.to_vec());
    println!("b / a = {:?}", divided.to_vec());

    // Array-scalar operators.
    let doubled = a.clone() * 2.0;
    let offset = a.clone() + 5.0;
    println!("\na * 2 = {:?}", doubled.to_vec());
    println!("a + 5 = {:?}", offset.to_vec());

    // Chained expression, written out one step at a time.
    let pair_sum = a.clone() + b.clone();
    let pair_sum_doubled = pair_sum * 2.0;
    let chained = pair_sum_doubled - 5.0;
    println!("\n(a + b) * 2 - 5 = {:?}", chained.to_vec());

    // Operators on references leave both operands usable.
    let by_reference = &a_handle + &b;
    println!("&a + &b = {:?}", by_reference.to_vec());
    println!();
}
/// Demonstrates building and evaluating expressions with `SharedExprBuilder`.
///
/// First wraps the sum of two arrays in a builder, applies `mul_scalar(2.0)`, and prints
/// the expression's size followed by its evaluated values. Then maps `x * x` over a
/// second array and prints that result.
fn shared_expr_example() {
    println!("--- Part 2: SharedExpr Lazy Evaluation ---\n");

    let a: SharedArray<f64> = SharedArray::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
    let b: SharedArray<f64> = SharedArray::from_vec(vec![10.0, 20.0, 30.0, 40.0]);

    // Neither operand is used again, so both are moved into the addition instead of
    // being cloned first.
    let add_result = a + b;
    let scaled_expr = SharedExprBuilder::from_shared_array(add_result).mul_scalar(2.0);
    println!("Expression built (mul_scalar is lazy)");

    // `into_expr` is called on a clone so that `scaled_expr` is still available for
    // `eval` below.
    println!(
        "Expression size: {} elements",
        SharedExpr::size(&scaled_expr.clone().into_expr())
    );

    let result = scaled_expr.eval();
    println!("Evaluated result: {:?}", result.to_vec());

    // Element-wise mapping over a fresh array.
    let c: SharedArray<f64> = SharedArray::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
    let expr_c = SharedExprBuilder::from_shared_array(c);
    let squared = expr_c.map(|x| x * x);
    let squared_result = squared.eval();
    println!("Squared: {:?}", squared_result.to_vec());
    println!();
}
/// Demonstrates evaluating an expression through `CachedExpr` backed by an `ExprCache`.
///
/// Builds an expression from the sum of two arrays, evaluates it twice through a
/// `CachedExpr`, prints the number of entries the cache reports, then invalidates the
/// cached expression and prints the entry count again.
fn cse_example() {
    println!("--- Part 3: Common Subexpression Elimination (CSE) ---\n");

    let a: SharedArray<f64> = SharedArray::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
    let b: SharedArray<f64> = SharedArray::from_vec(vec![5.0, 6.0, 7.0, 8.0]);
    let cache: ExprCache<f64> = ExprCache::new();

    // `a` and `b` are not used again, so they are moved into the addition instead of
    // being cloned first.
    let sum = a + b;
    let sum_expr = SharedExprBuilder::from_shared_array(sum);

    // The cache is cloned here because `cache` itself is queried with `len()` below.
    let cached_sum = CachedExpr::new(sum_expr.into_expr(), cache.clone());

    println!("First evaluation (computes and caches):");
    let result1 = cached_sum.eval();
    println!("Result: {:?}", result1.to_vec());

    println!("Second evaluation (uses cache):");
    let result2 = cached_sum.eval();
    println!("Result: {:?}", result2.to_vec());

    println!("\nCache contains {} entries", cache.len());
    cached_sum.invalidate();
    println!("After invalidation, cache contains {} entries", cache.len());
    println!();
}
/// Demonstrates the helpers imported from `memory_optimize::access_patterns`.
///
/// In order, this prints: detected layouts for three shape/stride combinations,
/// default and analyzed `OptimizationHints`, the blocks produced by `BlockedIterator`,
/// the tiles produced by `TiledIterator2D`, the queries exposed by `StrideOptimizer`,
/// and the results of the three `cache_aware_*` slice operations.
fn memory_patterns_example() {
    println!("--- Part 4: Memory Access Pattern Optimization ---\n");

    // Layout detection from (shape, strides) pairs.
    let len = 1000;
    let layout_1d = detect_layout(&[len], &[1]);
    println!("Memory layout for 1D contiguous array: {:?}", layout_1d);
    let layout_c = detect_layout(&[100, 100], &[100, 1]);
    println!("Memory layout for 2D C-contiguous: {:?}", layout_c);
    let layout_fortran = detect_layout(&[100, 100], &[1, 100]);
    println!("Memory layout for 2D F-contiguous: {:?}", layout_fortran);

    // Default hints for a given element count.
    let default_hints = OptimizationHints::default_for::<f64>(len);
    println!("\nOptimization hints for {} elements:", len);
    println!(" Layout: {:?}", default_hints.layout);
    println!(" Access pattern: {:?}", default_hints.access_pattern);
    println!(" Block size: {}", default_hints.block_size);
    println!(" Use parallel: {}", default_hints.use_parallel);
    println!(" Cache efficiency: {:.2}", default_hints.cache_efficiency);

    // Hints derived from an explicit shape and strides.
    let analyzed_hints = OptimizationHints::analyze::<f64>(&[100, 100], &[100, 1]);
    println!("\nAnalyzed hints for 100x100 array:");
    println!(" Layout: {:?}", analyzed_hints.layout);
    println!(" Tile size: {:?}", analyzed_hints.tile_size);

    // Blocked 1D iteration.
    println!("\nBlocked iteration (block_size = 64):");
    let blocks = BlockedIterator::new(len, 64).collect::<Vec<Block>>();
    println!(" Total blocks: {}", blocks.len());
    println!(" First block: {:?}", blocks.first());
    println!(" Last block: {:?}", blocks.last());

    // Tiled 2D iteration.
    println!("\nTiled 2D iteration (100x100 matrix, 16x16 tiles):");
    let (rows, cols) = (100, 100);
    let tiles = TiledIterator2D::new(rows, cols, 16, 16).collect::<Vec<Tile2D>>();
    println!(" Total tiles: {}", tiles.len());
    println!(" First tile: {:?}", tiles.first());

    // Stride analysis for an 8x8 array with strides [8, 1].
    println!("\nStride optimization:");
    let shape = [8usize, 8];
    let strides = [8usize, 1];
    let optimizer = StrideOptimizer::new(&shape, &strides);
    println!(
        " Optimal iteration order: {:?}",
        optimizer.optimal_iteration_order()
    );
    println!(" Should copy: {}", optimizer.should_copy());
    println!(
        " Bandwidth efficiency: {:.2}",
        optimizer.bandwidth_efficiency()
    );

    // Slice-level operations: copy, unary transform, binary op.
    println!("\nCache-aware operations:");
    let src: Vec<f64> = (0..1000).map(|i| i as f64).collect();
    let mut dst = vec![0.0f64; 1000];
    cache_aware_copy(&src, &mut dst);
    println!(" cache_aware_copy: copied {} elements", dst.len());
    // Spot-check both ends of the destination against the source values.
    assert_eq!(dst[0], 0.0);
    assert_eq!(dst[999], 999.0);

    let mut doubled = vec![0.0f64; 1000];
    cache_aware_transform(&src, &mut doubled, |x| x * 2.0);
    println!(
        " cache_aware_transform: first 5 values = {:?}",
        &doubled[..5]
    );

    let lhs: Vec<f64> = (0..1000).map(|i| i as f64).collect();
    let rhs: Vec<f64> = (0..1000).map(|i| (i * 2) as f64).collect();
    let mut combined = vec![0.0f64; 1000];
    cache_aware_binary_op(&lhs, &rhs, &mut combined, |x, y| x + y);
    println!(
        " cache_aware_binary_op: first 5 values = {:?}",
        &combined[..5]
    );
    println!();
}