use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use crate::primitives::Vector;

use super::grad_fn::GradFn;
use super::with_graph;

/// Unique identifier assigned to every tensor; the autograd graph keys
/// its nodes by this id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorId(u64);

impl TensorId {
    /// Mints a fresh, process-wide unique id from an atomic counter.
    pub fn new() -> Self {
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        TensorId(COUNTER.fetch_add(1, Ordering::SeqCst))
    }
}

impl Default for TensorId {
    fn default() -> Self {
        Self::new()
    }
}

/// An owning tensor of `f32` values with optional gradient tracking.
#[derive(Clone)]
pub struct Tensor {
    /// Flat storage for the element values.
    data: Vector<f32>,

    /// Extent of each dimension; empty for a scalar.
    shape: Vec<usize>,

    /// Gradient accumulated by the backward pass, if any.
    grad: Option<Box<Tensor>>,

    /// Whether this tensor participates in gradient computation.
    requires_grad: bool,

    /// True for user-created tensors; false once a `grad_fn` is attached.
    is_leaf: bool,

    /// Backward function of the operation that produced this tensor.
    grad_fn: Option<Arc<dyn GradFn>>,

    /// Unique id used to locate this tensor in the autograd graph.
    id: TensorId,
}

impl Tensor {
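    /// Builds a tensor from a flat slice of values and an explicit shape.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` does not equal the product of `shape`.
    ///
    /// Sketch of typical usage (illustrative, not compiled as a doctest):
    ///
    /// ```ignore
    /// // Four values arranged as a 2x2 matrix.
    /// let t = Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    /// assert_eq!(t.shape(), &[2, 2]);
    /// assert_eq!(t.numel(), 4);
    /// ```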
    pub fn new(data: &[f32], shape: &[usize]) -> Self {
        let expected_len: usize = shape.iter().product();
        assert_eq!(
            data.len(),
            expected_len,
            "Data length {} doesn't match shape {:?} (expected {})",
            data.len(),
            shape,
            expected_len
        );

        Self {
            data: Vector::from_slice(data),
            shape: shape.to_vec(),
            grad: None,
            requires_grad: false,
            is_leaf: true,
            grad_fn: None,
            id: TensorId::new(),
        }
    }

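    /// Builds a 1-D tensor whose shape is inferred as `[data.len()]`.
    ///
    /// Sketch (illustrative, not compiled as a doctest):
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0, 2.0, 3.0]);
    /// assert_eq!(t.shape(), &[3]);
    /// ```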
    pub fn from_slice(data: &[f32]) -> Self {
        Self::new(data, &[data.len()])
    }

    /// Builds a tensor of the given shape filled with zeros.
    pub fn zeros(shape: &[usize]) -> Self {
        let len: usize = shape.iter().product();
        Self::new(&vec![0.0; len], shape)
    }

    /// Builds a tensor of the given shape filled with ones.
    pub fn ones(shape: &[usize]) -> Self {
        let len: usize = shape.iter().product();
        Self::new(&vec![1.0; len], shape)
    }

    /// Builds a zero-filled tensor with the same shape as `other`.
    pub fn zeros_like(other: &Tensor) -> Self {
        Self::zeros(&other.shape)
    }

    /// Builds a one-filled tensor with the same shape as `other`.
    pub fn ones_like(other: &Tensor) -> Self {
        Self::ones(&other.shape)
    }

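    /// Consuming builder that turns gradient tracking on.
    ///
    /// Sketch (illustrative, not compiled as a doctest):
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
    /// assert!(t.requires_grad_enabled());
    /// ```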
    pub fn requires_grad(mut self) -> Self {
        self.requires_grad = true;
        self
    }

    /// Sets gradient tracking in place and returns `&mut self` for chaining.
    pub fn requires_grad_(&mut self, requires: bool) -> &mut Self {
        self.requires_grad = requires;
        self
    }

    /// Returns whether gradient tracking is enabled.
    pub fn requires_grad_enabled(&self) -> bool {
        self.requires_grad
    }

    /// Returns whether this tensor is a leaf (created directly by the user
    /// rather than produced by a tracked operation).
    pub fn is_leaf(&self) -> bool {
        self.is_leaf
    }

    /// Returns this tensor's unique id.
    pub fn id(&self) -> TensorId {
        self.id
    }

    /// Returns the shape as a slice of dimension sizes.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns the total number of elements (the product of the shape).
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Returns the number of dimensions.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Read-only view of the underlying element data.
    pub fn data(&self) -> &[f32] {
        self.data.as_slice()
    }

    /// Mutable view of the underlying element data.
    pub fn data_mut(&mut self) -> &mut [f32] {
        self.data.as_mut_slice()
    }

    /// Returns the accumulated gradient, if one has been set.
    pub fn grad(&self) -> Option<&Tensor> {
        self.grad.as_deref()
    }

    /// Clears the stored gradient.
    pub fn zero_grad_(&mut self) {
        self.grad = None;
    }

    /// Clears the stored gradient; synonym of [`Tensor::zero_grad_`].
    pub fn clear_grad(&mut self) {
        self.grad = None;
    }

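    /// Adds `grad` element-wise into the stored gradient, initializing it
    /// to `grad` on first call. Repeated calls therefore sum:
    ///
    /// ```ignore
    /// // Illustrative, not compiled as a doctest.
    /// t.accumulate_grad(Tensor::from_slice(&[0.1, 0.2]));
    /// t.accumulate_grad(Tensor::from_slice(&[0.1, 0.2]));
    /// // t.grad() now holds [0.2, 0.4].
    /// ```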
    pub(crate) fn accumulate_grad(&mut self, grad: Tensor) {
        match &mut self.grad {
            Some(existing) => {
                // Element-wise sum of the existing and incoming gradients.
                let new_data: Vec<f32> = existing
                    .data()
                    .iter()
                    .zip(grad.data().iter())
                    .map(|(a, b)| a + b)
                    .collect();
                **existing = Tensor::new(&new_data, &self.shape);
            }
            None => {
                self.grad = Some(Box::new(grad));
            }
        }
    }

    /// Attaches the backward function that produced this tensor; doing so
    /// also marks the tensor as a non-leaf.
    pub(crate) fn set_grad_fn(&mut self, grad_fn: Arc<dyn GradFn>) {
        self.grad_fn = Some(grad_fn);
        self.is_leaf = false;
    }

    #[allow(dead_code)]
    pub(crate) fn grad_fn(&self) -> Option<&Arc<dyn GradFn>> {
        self.grad_fn.as_ref()
    }

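    /// Returns a copy that is disconnected from the autograd graph: a leaf
    /// with a fresh id, no gradient, no `grad_fn`, and tracking disabled.
    /// Element data is cloned, not shared.
    ///
    /// Sketch (illustrative, not compiled as a doctest):
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
    /// let d = t.detach();
    /// assert!(d.is_leaf());
    /// assert!(!d.requires_grad_enabled());
    /// ```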
    pub fn detach(&self) -> Tensor {
        Tensor {
            data: self.data.clone(),
            shape: self.shape.clone(),
            grad: None,
            requires_grad: false,
            is_leaf: true,
            grad_fn: None,
            id: TensorId::new(),
        }
    }

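    /// Extracts the single value from a one-element tensor (shape `[1]` or
    /// the scalar shape `[]`).
    ///
    /// # Panics
    ///
    /// Panics if the tensor holds more than one element.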
    pub fn item(&self) -> f32 {
        assert_eq!(
            self.numel(),
            1,
            "item() only works on tensors with exactly 1 element, got {}",
            self.numel()
        );
        self.data[0]
    }

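    /// Runs the backward pass from a scalar output, seeding it with a
    /// gradient of ones; non-scalar outputs must use
    /// [`Tensor::backward_with_grad`].
    ///
    /// Sketch of the intended flow (illustrative, not compiled; `sum` here
    /// stands in for any tracked op that yields a one-element tensor):
    ///
    /// ```ignore
    /// let x = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
    /// let loss = x.sum();
    /// loss.backward();
    /// // x.grad() now holds d(loss)/dx.
    /// ```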
    pub fn backward(&self) {
        assert_eq!(
            self.numel(),
            1,
            "backward() requires scalar output, got shape {:?}. Use backward_with_grad() instead.",
            self.shape
        );

        self.backward_with_grad(Tensor::ones(&self.shape));
    }

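    /// Runs the backward pass with an explicit upstream gradient, which is
    /// what non-scalar outputs require; the traversal itself is delegated
    /// to the autograd graph through `with_graph`.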
    pub fn backward_with_grad(&self, grad_output: Tensor) {
        with_graph(|graph| {
            graph.backward(self.id, grad_output);
        });
    }
}

// Manual Debug impl: summarizes metadata instead of dumping element data.
impl fmt::Debug for Tensor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Tensor")
            .field("shape", &self.shape)
            .field("requires_grad", &self.requires_grad)
            .field("is_leaf", &self.is_leaf)
            .field("has_grad", &self.grad.is_some())
            .field("id", &self.id)
            .finish_non_exhaustive()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_creation() {
        let t = Tensor::new(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
        assert_eq!(t.shape(), &[2, 2]);
        assert_eq!(t.numel(), 4);
        assert_eq!(t.ndim(), 2);
    }

    #[test]
    fn test_tensor_from_slice() {
        let t = Tensor::from_slice(&[1.0, 2.0, 3.0]);
        assert_eq!(t.shape(), &[3]);
        assert_eq!(t.numel(), 3);
    }

    #[test]
    fn test_tensor_zeros_ones() {
        let z = Tensor::zeros(&[2, 3]);
        assert!(z.data().iter().all(|&x| x == 0.0));

        let o = Tensor::ones(&[2, 3]);
        assert!(o.data().iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_requires_grad() {
        let t = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
        assert!(t.requires_grad_enabled());

        let t2 = Tensor::from_slice(&[1.0, 2.0]);
        assert!(!t2.requires_grad_enabled());
    }

    #[test]
    fn test_detach() {
        let t = Tensor::from_slice(&[1.0, 2.0]).requires_grad();
        let d = t.detach();

        assert!(t.requires_grad_enabled());
        assert!(!d.requires_grad_enabled());
        assert!(d.is_leaf());
    }

    #[test]
    fn test_item() {
        let t = Tensor::new(&[42.0], &[1]);
        assert_eq!(t.item(), 42.0);

        // The empty shape is the scalar shape; its element count is 1.
        let t2 = Tensor::new(&[42.0], &[]);
        assert_eq!(t2.item(), 42.0);
    }

    #[test]
    #[should_panic(expected = "item() only works on tensors with exactly 1 element")]
    fn test_item_panics_multi_element() {
        let t = Tensor::from_slice(&[1.0, 2.0]);
        t.item();
    }

    #[test]
    fn test_tensor_id_unique() {
        let t1 = Tensor::from_slice(&[1.0]);
        let t2 = Tensor::from_slice(&[1.0]);
        assert_ne!(t1.id(), t2.id());
    }

    #[test]
    fn test_gradient_accumulation() {
        let mut t = Tensor::from_slice(&[1.0, 2.0, 3.0]).requires_grad();

        t.accumulate_grad(Tensor::from_slice(&[0.1, 0.2, 0.3]));

        let grad1 = t
            .grad()
            .expect("grad should exist after accumulate")
            .data()
            .to_vec();
        assert_eq!(grad1, vec![0.1, 0.2, 0.3]);

        t.accumulate_grad(Tensor::from_slice(&[0.1, 0.2, 0.3]));
        let grad2 = t
            .grad()
            .expect("grad should exist after second accumulate")
            .data()
            .to_vec();
        assert_eq!(grad2, vec![0.2, 0.4, 0.6]);
    }
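
    // Illustrative addition: the *_like constructors mirror the source
    // tensor's shape.
    #[test]
    fn test_like_constructors_match_shape() {
        let t = Tensor::new(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

        let z = Tensor::zeros_like(&t);
        assert_eq!(z.shape(), t.shape());
        assert!(z.data().iter().all(|&x| x == 0.0));

        let o = Tensor::ones_like(&t);
        assert_eq!(o.shape(), t.shape());
        assert!(o.data().iter().all(|&x| x == 1.0));
    }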
}