etensor_core/tensor.rs
1//! The central Tensor representation and global identity tracking.
2
3use std::sync::atomic::{AtomicU64, Ordering};
4use crate::buffer::Buffer;
5use crate::device::Device;
6use crate::dtypes::DType;
7use crate::shape::Shape;
8
9// =====================================================================
10// GLOBAL IDENTITY TRACKER
11// =====================================================================
12
/// Process-wide source of tensor identifiers.
///
/// The counter starts at 1 and is only ever advanced through an atomic
/// `fetch_add`, so callers on different threads each receive a distinct value
/// without any additional locking.
static NEXT_TENSOR_ID: AtomicU64 = AtomicU64::new(1);

/// Handle that identifies one tensor inside the computation graph.
///
/// This is a thin newtype over `u64`; the derived equality and hashing compare
/// the wrapped number.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorId(pub u64);

impl TensorId {
    /// Reserves the next identifier from the global counter and returns it.
    ///
    /// Every call advances the counter by one, so two calls never return the
    /// same value (short of the `u64` counter wrapping around).
    // NOTE(review): the `Default` lint is silenced, presumably because a
    // `default()` with a global side effect would be surprising; confirm.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        // `Relaxed` is enough here: uniqueness comes from the atomicity of the
        // read-modify-write itself, and the counter is not used to order
        // access to any other memory.
        let reserved = NEXT_TENSOR_ID.fetch_add(1, Ordering::Relaxed);
        Self(reserved)
    }
}
31
32// =====================================================================
33// THE ATOM
34// =====================================================================
35
36/// The core Tensor struct.
37///
38/// Strictly acts as a metadata wrapper. It points to a physical memory `Buffer`
39/// and maps it using a geometric `Shape`. It contains NO recursive graph pointers.
40#[derive(Debug, Clone)]
41pub struct Tensor {
42 /// The globally unique token used by the Tape to track gradients.
43 pub id: TensorId,
44 /// An optional string for explainability (e.g., "layer_1_weights").
45 pub name: Option<String>,
46 /// The physical memory container (wrapped in Arc for zero-copy sharing).
47 pub data: Buffer,
48 /// The geometric layout and memory strides.
49 pub shape: Shape,
50 /// The hardware context where the physical memory resides.
51 pub device: Device,
52 /// The precision format of the memory buffer.
53 pub dtype: DType,
54 /// Flag indicating whether the Autograd engine should track this tensor.
55 pub requires_grad: bool,
56}
57
58impl Tensor {
59 /// Constructs a new Tensor from raw components.
60 /// Automatically assigns a new unique `TensorId`.
61 pub fn new(
62 data: Buffer,
63 shape: Shape,
64 device: Device,
65 dtype: DType,
66 requires_grad: bool,
67 ) -> Self {
68 Self {
69 id: TensorId::new(),
70 name: None,
71 data,
72 shape,
73 device,
74 dtype,
75 requires_grad,
76 }
77 }
78
79 /// Builder pattern helper to attach an explainability name to the tensor.
80 pub fn with_name(mut self, name: &str) -> Self {
81 self.name = Some(name.to_string());
82 self
83 }
84
85 /// Performs an O(1) mathematical transpose.
86 ///
87 /// This generates a NEW tensor (with a new ID for the graph) but explicitly
88 /// clones only the `Arc` buffer pointer, never the underlying memory.
89 pub fn transpose(&self) -> Self {
90 let new_shape = self.shape.transpose();
91
92 Self {
93 id: TensorId::new(), // New node in the computation graph!
94 name: self.name.as_ref().map(|n| format!("{}_T", n)), // Explainability tracker
95 data: self.data.clone(), // Zero-copy Arc increment
96 shape: new_shape,
97 device: self.device,
98 dtype: self.dtype,
99 requires_grad: self.requires_grad,
100 }
101 }
102}
103
104
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Two back-to-back IDs must differ.
    #[test]
    fn test_tensor_id_uniqueness() {
        let (first, second) = (TensorId::new(), TensorId::new());
        assert_ne!(first.0, second.0, "Sequential IDs must not overlap!");
    }

    /// IDs drawn concurrently from several threads must never collide.
    #[test]
    fn test_atomic_id_thread_safety() {
        const WORKERS: usize = 10;
        const IDS_PER_WORKER: usize = 100;

        // Collect the handles eagerly so every worker is running before the
        // first `join`.
        let workers: Vec<_> = (0..WORKERS)
            .map(|_| {
                thread::spawn(|| {
                    (0..IDS_PER_WORKER)
                        .map(|_| TensorId::new().0)
                        .collect::<Vec<u64>>()
                })
            })
            .collect();

        let mut gathered: Vec<u64> = workers
            .into_iter()
            .flat_map(|worker| worker.join().unwrap())
            .collect();

        // After sorting, `dedup` removes any repeated ID, so a shorter vector
        // means a duplicate was handed out.
        gathered.sort_unstable();
        gathered.dedup();

        assert_eq!(
            gathered.len(),
            WORKERS * IDS_PER_WORKER,
            "Race condition detected! Duplicate IDs generated."
        );
    }

    /// `with_name` must store the supplied label on the tensor.
    #[test]
    fn test_tensor_explainability_name() {
        let storage = Buffer::new_cpu_zeros(4, DType::F32);
        let layout = Shape::new(vec![2, 2]);

        let named = Tensor::new(storage, layout, Device::Cpu, DType::F32, true)
            .with_name("attention_weights");

        assert_eq!(named.name.as_deref(), Some("attention_weights"));
    }

    /// `transpose` must yield a new node that shares the source's storage.
    #[test]
    fn test_zero_copy_transpose_view() {
        let layout = Shape::new(vec![3, 4]);
        let storage = Buffer::new_cpu_zeros(12, DType::F32);

        let source =
            Tensor::new(storage, layout, Device::Cpu, DType::F32, true).with_name("matrix");
        let refs_before = source.data.strong_count().unwrap();

        let flipped = source.transpose();

        // 1. The transpose is a separate node: its ID differs from the source.
        assert_ne!(source.id, flipped.id);

        // 2. The dimensions are swapped.
        assert_eq!(flipped.shape.dims, vec![4, 3]);

        // 3. The name is carried over with the `_T` suffix.
        assert_eq!(flipped.name.as_deref(), Some("matrix_T"));

        // 4. The storage is shared, not copied: exactly one extra strong
        //    reference exists after the transpose.
        assert_eq!(flipped.data.strong_count().unwrap(), refs_before + 1);
    }
}