Skip to main content

etensor_core/
tensor.rs

1//! The central Tensor representation and global identity tracking.
2
3use std::sync::atomic::{AtomicU64, Ordering};
4use crate::buffer::Buffer;
5use crate::device::Device;
6use crate::dtypes::DType;
7use crate::shape::Shape;
8
9// =====================================================================
10// GLOBAL IDENTITY TRACKER
11// =====================================================================
12
/// Process-wide counter that backs [`TensorId::new`].
///
/// The counter starts at 1 and is advanced by one on every allocation, so the
/// generator never hands out 0. Being an atomic, it can be bumped from any
/// thread without a lock. It is intended to let gradient-tracking code key its
/// records on tensor identity.
static NEXT_TENSOR_ID: AtomicU64 = AtomicU64::new(1);

/// A unique identifier for a Tensor within the computation graph.
///
/// The wrapped `u64` is public, so an ID can also be built by hand (for example
/// `TensorId(7)`); only IDs obtained from [`TensorId::new`] are guaranteed to be
/// distinct from one another.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorId(pub u64);

impl TensorId {
    /// Allocates the next globally unique Tensor ID.
    ///
    /// Every call consumes one value from the process-wide counter, including
    /// calls whose result is thrown away; hence the `#[must_use]`.
    ///
    /// `fetch_add` wraps around on overflow, which would require 2^64
    /// allocations and is not guarded against here.
    // `Default` is deliberately not implemented: a "default" value that differs
    // on every call and mutates global state would be surprising.
    #[allow(clippy::new_without_default)]
    #[must_use = "each call consumes a global ID; dropping the result wastes it"]
    pub fn new() -> Self {
        // `Relaxed` suffices: an atomic read-modify-write always operates on the
        // latest value in the counter's modification order, so no two callers
        // can observe the same number. Nothing else is published through this
        // counter, so no stronger ordering is needed.
        let raw = NEXT_TENSOR_ID.fetch_add(1, Ordering::Relaxed);
        Self(raw)
    }
}
31
32// =====================================================================
33// THE ATOM
34// =====================================================================
35
/// The core Tensor struct.
///
/// A plain metadata record: it pairs a memory [`Buffer`] with the [`Shape`]
/// used to interpret it, together with the device, dtype and autograd flag.
/// None of its own fields refers to another `Tensor`, so the struct itself
/// carries no graph links.
///
/// All fields are public, so callers can read or overwrite them directly.
///
/// NOTE(review): the derived `Clone` copies `id` verbatim, so a cloned tensor
/// reports the same `TensorId` as its source. Confirm this is what the
/// gradient-tracking code expects.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// The globally unique token intended for gradient tracking.
    /// `Tensor::new` and `Tensor::transpose` each allocate a fresh one.
    pub id: TensorId,
    /// An optional string for explainability (e.g., "layer_1_weights").
    /// `None` until `with_name` is called.
    pub name: Option<String>,
    /// The physical memory container.
    /// NOTE(review): presumably reference-counted (the tests in this file
    /// expect `strong_count()` to grow by one when it is cloned), which would
    /// make sharing it between tensors cheap. Verify against `Buffer`.
    pub data: Buffer,
    /// The geometric layout used to interpret `data`.
    pub shape: Shape,
    /// The hardware context where the physical memory resides.
    pub device: Device,
    /// The precision format of the memory buffer.
    pub dtype: DType,
    /// Flag indicating whether the Autograd engine should track this tensor.
    pub requires_grad: bool,
}
57
58impl Tensor {
59    /// Constructs a new Tensor from raw components.
60    /// Automatically assigns a new unique `TensorId`.
61    pub fn new(
62        data: Buffer,
63        shape: Shape,
64        device: Device,
65        dtype: DType,
66        requires_grad: bool,
67    ) -> Self {
68        Self {
69            id: TensorId::new(),
70            name: None,
71            data,
72            shape,
73            device,
74            dtype,
75            requires_grad,
76        }
77    }
78
79    /// Builder pattern helper to attach an explainability name to the tensor.
80    pub fn with_name(mut self, name: &str) -> Self {
81        self.name = Some(name.to_string());
82        self
83    }
84
85    /// Performs an O(1) mathematical transpose.
86    /// 
87    /// This generates a NEW tensor (with a new ID for the graph) but explicitly 
88    /// clones only the `Arc` buffer pointer, never the underlying memory.
89    pub fn transpose(&self) -> Self {
90        let new_shape = self.shape.transpose();
91        
92        Self {
93            id: TensorId::new(), // New node in the computation graph!
94            name: self.name.as_ref().map(|n| format!("{}_T", n)), // Explainability tracker
95            data: self.data.clone(), // Zero-copy Arc increment
96            shape: new_shape,
97            device: self.device,
98            dtype: self.dtype,
99            requires_grad: self.requires_grad,
100        }
101    }
102}
103
104
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Two back-to-back allocations must yield different raw IDs.
    #[test]
    fn test_tensor_id_uniqueness() {
        let first = TensorId::new();
        let second = TensorId::new();
        assert_ne!(first.0, second.0, "Sequential IDs must not overlap!");
    }

    /// Ten threads allocate 100 IDs each; after sorting and de-duplicating,
    /// all 1000 values must still be present.
    #[test]
    fn test_atomic_id_thread_safety() {
        // Collect eagerly so every worker is spawned before the first join.
        let workers: Vec<_> = (0..10)
            .map(|_| {
                thread::spawn(|| {
                    (0..100)
                        .map(|_| TensorId::new().0)
                        .collect::<Vec<_>>()
                })
            })
            .collect();

        let mut gathered: Vec<u64> = workers
            .into_iter()
            .flat_map(|worker| worker.join().unwrap())
            .collect();

        // Duplicates, if any, collapse here and shrink the vector.
        gathered.sort_unstable();
        gathered.dedup();

        assert_eq!(gathered.len(), 1000, "Race condition detected! Duplicate IDs generated.");
    }

    /// `with_name` must store the label it was given.
    #[test]
    fn test_tensor_explainability_name() {
        let geometry = Shape::new(vec![2, 2]);
        let storage = Buffer::new_cpu_zeros(4, DType::F32);

        let unnamed = Tensor::new(storage, geometry, Device::Cpu, DType::F32, true);
        let named = unnamed.with_name("attention_weights");

        assert_eq!(named.name.unwrap(), "attention_weights");
    }

    /// `transpose` must produce a new node that shares the source buffer.
    #[test]
    fn test_zero_copy_transpose_view() {
        let geometry = Shape::new(vec![3, 4]);
        let storage = Buffer::new_cpu_zeros(12, DType::F32);

        let source =
            Tensor::new(storage, geometry, Device::Cpu, DType::F32, true).with_name("matrix");
        let count_before = source.data.strong_count().unwrap();

        let flipped = source.transpose();

        // 1. A distinct graph node, hence a distinct ID.
        assert_ne!(source.id, flipped.id);

        // 2. The dimensions are swapped.
        assert_eq!(flipped.shape.dims, vec![4, 3]);

        // 3. The name is carried over with the `_T` suffix.
        assert_eq!(flipped.name.unwrap(), "matrix_T");

        // 4. The buffer is shared: exactly one more strong reference exists.
        assert_eq!(flipped.data.strong_count().unwrap(), count_before + 1);
    }
}