// element_iteration/element_iteration.rs

//! # Element Iteration - Basic Tensor Element Processing
//!
//! ## Overview
//!
//! This example demonstrates the fundamental tensor iterator functionality in Train Station,
//! showing how to iterate over tensor elements as individual view tensors. Each element
//! becomes a proper Tensor of shape [1] that supports all existing tensor operations
//! and gradient tracking.
//!
//! ## Learning Objectives
//!
//! - Understand basic tensor element iteration
//! - Learn standard iterator trait methods
//! - Master element-wise transformations
//! - Explore gradient tracking through iterations
//!
//! ## Prerequisites
//!
//! - Basic Rust knowledge and iterator concepts
//! - Understanding of tensor basics (see getting_started/tensor_basics.rs)
//! - Familiarity with functional programming patterns
//!
//! ## Key Concepts Demonstrated
//!
//! - **Element Views**: Each element becomes a true tensor view of shape [1]
//! - **Standard Library Integration**: Full compatibility with Rust's iterator traits
//! - **Gradient Tracking**: Automatic gradient propagation through element operations
//! - **Zero-Copy Semantics**: True views with shared memory allocation
//!
//! ## Example Code Structure
//!
//! 1. **Basic Iteration**: Simple element access and transformation
//! 2. **Standard Methods**: Using Iterator trait methods (map, filter, collect)
//! 3. **Gradient Tracking**: Demonstrating autograd through element operations
//! 4. **Advanced Patterns**: Complex iterator chains and transformations
//!
//! ## Expected Output
//!
//! The example will demonstrate various iteration patterns, showing element-wise
//! transformations, gradient tracking, and performance characteristics of the
//! tensor iterator system.
//!
//! ## Performance Notes
//!
//! - View creation is O(1) per element with true zero-copy semantics
//! - Memory overhead is ~64 bytes per view tensor (no data copying)
//! - All operations leverage existing SIMD-optimized tensor implementations
//!
//! ## Next Steps
//!
//! - Explore advanced_patterns.rs for complex iterator chains
//! - Study performance_optimization.rs for large-scale processing
//! - Review tensor operations for element-wise mathematical functions
use train_station::Tensor;

57/// Main example function demonstrating basic element iteration
58///
59/// This function serves as the primary educational entry point,
60/// with extensive inline comments explaining each step.
61fn main() -> Result<(), Box<dyn std::error::Error>> {
62    println!("Starting Element Iteration Example");
63
64    demonstrate_basic_iteration()?;
65    demonstrate_standard_methods()?;
66    demonstrate_gradient_tracking()?;
67    demonstrate_advanced_patterns()?;
68
69    println!("Element Iteration Example completed successfully!");
70    Ok(())
71}
72
73/// Demonstrate basic tensor element iteration
74///
75/// Shows how to create iterators over tensor elements and perform
76/// simple element-wise operations.
77fn demonstrate_basic_iteration() -> Result<(), Box<dyn std::error::Error>> {
78    println!("\n--- Basic Element Iteration ---");
79
80    // Create a simple tensor for demonstration
81    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
82    println!("Original tensor: {:?}", tensor.data());
83
84    // Basic iteration with for loop
85    println!("\nBasic iteration with for loop:");
86    for (i, element) in tensor.iter().enumerate() {
87        println!(
88            "  Element {}: value = {:.1}, shape = {:?}",
89            i,
90            element.value(),
91            element.shape().dims()
92        );
93    }
94
95    // Element-wise transformation
96    println!("\nElement-wise transformation (2x + 1):");
97    let transformed: Tensor = tensor
98        .iter()
99        .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0))
100        .collect();
101    println!("  Result: {:?}", transformed.data());
102
103    // Filtering elements
104    println!("\nFiltering elements (values > 3.0):");
105    let filtered: Tensor = tensor.iter().filter(|elem| elem.value() > 3.0).collect();
106    println!("  Filtered: {:?}", filtered.data());
107
108    Ok(())
109}
110
111/// Demonstrate standard iterator trait methods
112///
113/// Shows compatibility with Rust's standard library iterator methods
114/// and demonstrates various functional programming patterns.
115fn demonstrate_standard_methods() -> Result<(), Box<dyn std::error::Error>> {
116    println!("\n--- Standard Iterator Methods ---");
117
118    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
119
120    // Using map for transformations
121    println!("\nMap transformation (square each element):");
122    let squared: Tensor = tensor.iter().map(|elem| elem.pow_scalar(2.0)).collect();
123    println!("  Squared: {:?}", squared.data());
124
125    // Using enumerate for indexed operations
126    println!("\nEnumerate with indexed operations:");
127    let indexed: Tensor = tensor
128        .iter()
129        .enumerate()
130        .map(|(i, elem)| elem.add_scalar(i as f32))
131        .collect();
132    println!("  Indexed: {:?}", indexed.data());
133
134    // Using fold for reduction
135    println!("\nFold for sum calculation:");
136    let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
137    println!("  Sum: {:.1}", sum);
138
139    // Using find for element search
140    println!("\nFind specific element:");
141    if let Some(found) = tensor.iter().find(|elem| elem.value() == 3.0) {
142        println!("  Found element with value 3.0: {:.1}", found.value());
143    }
144
145    // Using any/all for condition checking
146    println!("\nCondition checking:");
147    let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
148    let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
149    println!("  All positive: {}", all_positive);
150    println!("  Any > 4.0: {}", any_large);
151
152    Ok(())
153}
154
155/// Demonstrate gradient tracking through element operations
156///
157/// Shows how gradient tracking works seamlessly through iterator
158/// operations, maintaining the computational graph for backpropagation.
159fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
160    println!("\n--- Gradient Tracking ---");
161
162    // Create a tensor with gradient tracking enabled
163    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
164    println!("Input tensor (requires_grad): {:?}", tensor.data());
165
166    // Perform element-wise operations through iteration
167    let result: Tensor = tensor
168        .iter()
169        .map(|elem| {
170            // Apply a complex transformation: (x^2 + 1) * 2
171            elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
172        })
173        .collect();
174
175    println!("Result tensor: {:?}", result.data());
176    println!("Result requires_grad: {}", result.requires_grad());
177
178    // Compute gradients
179    let mut loss = result.sum();
180    loss.backward(None);
181
182    println!("Loss: {:.6}", loss.value());
183    println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
184
185    Ok(())
186}
187
188/// Demonstrate advanced iterator patterns
189///
190/// Shows complex iterator chains and advanced functional programming
191/// patterns for sophisticated data processing workflows.
192fn demonstrate_advanced_patterns() -> Result<(), Box<dyn std::error::Error>> {
193    println!("\n--- Advanced Iterator Patterns ---");
194
195    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6])?;
196    println!("Input tensor: {:?}", tensor.data());
197
198    // Complex chain: enumerate -> filter -> map -> collect
199    println!("\nComplex chain (even indices only, add index to value):");
200    let result: Tensor = tensor
201        .iter()
202        .enumerate()
203        .filter(|(i, _)| i % 2 == 0) // Take even indices
204        .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
205        .collect();
206    println!("  Result: {:?}", result.data());
207
208    // Using take and skip for windowing
209    println!("\nWindowing with take and skip:");
210    let window1: Tensor = tensor.iter().take(3).collect();
211    let window2: Tensor = tensor.iter().skip(2).take(3).collect();
212    println!("  Window 1 (first 3): {:?}", window1.data());
213    println!("  Window 2 (middle 3): {:?}", window2.data());
214
215    // Using rev() for reverse iteration
216    println!("\nReverse iteration:");
217    let reversed: Tensor = tensor.iter().rev().collect();
218    println!("  Reversed: {:?}", reversed.data());
219
220    // Chaining with mathematical operations
221    println!("\nMathematical operation chain:");
222    let math_result: Tensor = tensor
223        .iter()
224        .map(|elem| elem.exp()) // e^x
225        .filter(|elem| elem.value() < 50.0) // Filter large values
226        .map(|elem| elem.log()) // ln(x)
227        .collect();
228    println!("  Math chain result: {:?}", math_result.data());
229
230    // Using zip for element-wise combinations
231    println!("\nElement-wise combination with zip:");
232    let tensor2 = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![6])?;
233    let combined: Tensor = tensor
234        .iter()
235        .zip(tensor2.iter())
236        .map(|(a, b)| a.mul_tensor(&b)) // Element-wise multiplication
237        .collect();
238    println!("  Combined: {:?}", combined.data());
239
240    Ok(())
241}
242
243#[cfg(test)]
244mod tests {
245    use super::*;
246
247    /// Test basic iteration functionality
248    #[test]
249    fn test_basic_iteration() {
250        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
251        let elements: Vec<Tensor> = tensor.iter().collect();
252
253        assert_eq!(elements.len(), 3);
254        assert_eq!(elements[0].value(), 1.0);
255        assert_eq!(elements[1].value(), 2.0);
256        assert_eq!(elements[2].value(), 3.0);
257    }
258
259    /// Test element-wise transformation
260    #[test]
261    fn test_element_transformation() {
262        let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
263        let doubled: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();
264
265        assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
266    }
267
268    /// Test gradient tracking
269    #[test]
270    fn test_gradient_tracking() {
271        let tensor = Tensor::from_slice(&[1.0, 2.0], vec![2])
272            .unwrap()
273            .with_requires_grad();
274
275        let result: Tensor = tensor.iter().map(|elem| elem.mul_scalar(2.0)).collect();
276
277        assert!(result.requires_grad());
278        assert_eq!(result.data(), &[2.0, 4.0]);
279    }
280}