// ghostflow_core/tensor.rs

//! Core Tensor type - the foundation of GhostFlow

use std::sync::Arc;
use parking_lot::RwLock;
use rand_distr::{Distribution, Normal, Uniform};

use crate::dtype::{DType, TensorElement};
use crate::shape::{Shape, Strides};
use crate::storage::Storage;
use crate::error::{GhostError, Result};

/// The core Tensor type
///
/// Tensors are multi-dimensional arrays with:
/// - Shared storage (enables zero-copy views)
/// - Shape and strides (enables non-contiguous layouts)
/// - Optional gradient tracking for autograd
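///
/// A minimal usage sketch; the `ignore` marker is used because the exact
/// re-export path of `Tensor` from the crate root is an assumption here:
///
/// ```ignore
/// use ghostflow_core::tensor::Tensor;
///
/// // A 2x3 tensor; t() returns a transposed view sharing the same storage.
/// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
/// let view = t.t().unwrap();
/// assert_eq!(view.dims(), &[3, 2]);
/// ```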
#[derive(Debug)]
pub struct Tensor {
    /// Underlying data storage (shared for views)
    storage: Storage,
    /// Shape of the tensor
    shape: Shape,
    /// Memory strides
    strides: Strides,
    /// Offset into storage (for views)
    offset: usize,
    /// Whether to track gradients
    requires_grad: bool,
    /// Accumulated gradient
    grad: Option<Arc<RwLock<Tensor>>>,
}

impl Tensor {
    // ==================== Creation ====================

    /// Create a new tensor from a flat slice and shape
    pub fn from_slice<T: TensorElement>(data: &[T], shape: &[usize]) -> Result<Self> {
        let shape = Shape::new(shape);
        if data.len() != shape.numel() {
            return Err(GhostError::InvalidShape(format!(
                "Data length {} doesn't match shape {:?} (numel={})",
                data.len(),
                shape.dims(),
                shape.numel()
            )));
        }

        let strides = shape.default_strides();
        let storage = Storage::from_slice(data);

        Ok(Tensor {
            storage,
            shape,
            strides,
            offset: 0,
            requires_grad: false,
            grad: None,
        })
    }

    /// Create a tensor filled with zeros
    pub fn zeros(shape: &[usize]) -> Self {
        Self::full(shape, 0.0f32)
    }

    /// Create a tensor filled with ones
    pub fn ones(shape: &[usize]) -> Self {
        Self::full(shape, 1.0f32)
    }

    /// Create a tensor filled with a constant value
    pub fn full<T: TensorElement>(shape: &[usize], value: T) -> Self {
        let shape = Shape::new(shape);
        let numel = shape.numel();
        let data: Vec<T> = vec![value; numel];
        let strides = shape.default_strides();
        let storage = Storage::from_slice(&data);

        Tensor {
            storage,
            shape,
            strides,
            offset: 0,
            requires_grad: false,
            grad: None,
        }
    }

    /// Create a tensor with random values from the uniform distribution [0, 1)
    pub fn rand(shape: &[usize]) -> Self {
        let shape_obj = Shape::new(shape);
        let numel = shape_obj.numel();
        let mut rng = rand::thread_rng();
        let dist = Uniform::new(0.0f32, 1.0);
        let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();

        Tensor::from_slice(&data, shape).unwrap()
    }

    /// Create a tensor with random values from the standard normal distribution
    pub fn randn(shape: &[usize]) -> Self {
        let shape_obj = Shape::new(shape);
        let numel = shape_obj.numel();
        let mut rng = rand::thread_rng();
        let dist = Normal::new(0.0f32, 1.0).unwrap();
        let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();

        Tensor::from_slice(&data, shape).unwrap()
    }

    /// Create an identity matrix
    pub fn eye(n: usize) -> Self {
        let mut data = vec![0.0f32; n * n];
        for i in 0..n {
            data[i * n + i] = 1.0;
        }
        Tensor::from_slice(&data, &[n, n]).unwrap()
    }

    /// Create a 1D tensor with evenly spaced values in [start, end)
    pub fn arange(start: f32, end: f32, step: f32) -> Self {
        assert!(step != 0.0, "arange requires a non-zero step");
        // Compute the element count up front instead of repeatedly adding
        // `step`, which avoids floating-point accumulation drift and also
        // handles negative steps without looping forever.
        let len = ((end - start) / step).ceil().max(0.0) as usize;
        let data: Vec<f32> = (0..len).map(|i| start + i as f32 * step).collect();
        Tensor::from_slice(&data, &[len]).unwrap()
    }

    /// Create a 1D tensor with n evenly spaced values between start and end
    pub fn linspace(start: f32, end: f32, n: usize) -> Self {
        if n == 0 {
            return Tensor::from_slice::<f32>(&[], &[0]).unwrap();
        }
        if n == 1 {
            return Tensor::from_slice(&[start], &[1]).unwrap();
        }

        let step = (end - start) / (n - 1) as f32;
        let data: Vec<f32> = (0..n).map(|i| start + i as f32 * step).collect();
        Tensor::from_slice(&data, &[n]).unwrap()
    }

    // ==================== Properties ====================

    /// Get the shape of the tensor
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Get dimensions as slice
    pub fn dims(&self) -> &[usize] {
        self.shape.dims()
    }

    /// Number of dimensions
    pub fn ndim(&self) -> usize {
        self.shape.ndim()
    }

    /// Total number of elements
    pub fn numel(&self) -> usize {
        self.shape.numel()
    }

    /// Get the data type
    pub fn dtype(&self) -> DType {
        self.storage.dtype()
    }

    /// Get strides
    pub fn strides(&self) -> &Strides {
        &self.strides
    }

    /// Check if tensor is contiguous in memory
    pub fn is_contiguous(&self) -> bool {
        self.strides.is_contiguous(&self.shape)
    }

    /// Check if gradient tracking is enabled
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    // ==================== Gradient ====================

    /// Enable or disable gradient tracking
    pub fn set_requires_grad(&mut self, requires_grad: bool) {
        self.requires_grad = requires_grad;
    }

    /// Get gradient if available
    pub fn grad(&self) -> Option<Tensor> {
        self.grad.as_ref().map(|g| g.read().clone())
    }

    /// Get reference to underlying storage
    pub fn storage(&self) -> &Storage {
        &self.storage
    }

    /// Set gradient
    pub fn set_grad(&mut self, grad: Tensor) {
        self.grad = Some(Arc::new(RwLock::new(grad)));
    }

    /// Zero out gradient
    pub fn zero_grad(&mut self) {
        if let Some(ref grad) = self.grad {
            let mut g = grad.write();
            let zeros = Tensor::zeros(g.dims());
            *g = zeros;
        }
    }

    // ==================== Data Access ====================

    /// Get a copy of the data as a `Vec<f32>` (for f32 tensors)
    pub fn data_f32(&self) -> Vec<f32> {
        let guard = self.storage.as_slice::<f32>();
        if self.is_contiguous() && self.offset == 0 {
            guard.to_vec()
        } else {
            // Handle views and non-contiguous layouts
            self.to_contiguous_data::<f32>()
        }
    }

    /// Convert to contiguous data (handles views and non-contiguous layouts)
    fn to_contiguous_data<T: TensorElement>(&self) -> Vec<T> {
        let numel = self.numel();
        let mut result = Vec::with_capacity(numel);
        let guard = self.storage.as_slice::<T>();

        // Iterate through all indices
        self.for_each_index(|indices| {
            let offset = self.compute_offset(indices);
            result.push(guard[offset]);
        });

        result
    }

    /// Compute linear offset from indices
    fn compute_offset(&self, indices: &[usize]) -> usize {
        self.offset + self.strides.offset(indices)
    }

    /// Iterate through all valid indices
    fn for_each_index<F: FnMut(&[usize])>(&self, mut f: F) {
        let dims = self.dims();
        if dims.is_empty() {
            f(&[]);
            return;
        }
        // A zero-sized dimension means the tensor has no elements and
        // therefore no valid indices to visit.
        if dims.contains(&0) {
            return;
        }

        let mut indices = vec![0usize; dims.len()];
        loop {
            f(&indices);

            // Increment indices, carrying from the innermost dimension outwards
            let mut i = dims.len() - 1;
            loop {
                indices[i] += 1;
                if indices[i] < dims[i] {
                    break;
                }
                indices[i] = 0;
                if i == 0 {
                    return;
                }
                i -= 1;
            }
        }
    }

    // ==================== Shape Operations ====================

    /// Reshape tensor to new shape (must have same numel)
    pub fn reshape(&self, new_shape: &[usize]) -> Result<Tensor> {
        let new_shape = Shape::new(new_shape);
        if new_shape.numel() != self.numel() {
            return Err(GhostError::InvalidShape(format!(
                "Cannot reshape tensor of {} elements to shape {:?}",
                self.numel(),
                new_shape.dims()
            )));
        }

        // If contiguous, can just change shape/strides
        if self.is_contiguous() {
            let new_strides = new_shape.default_strides();
            return Ok(Tensor {
                storage: self.storage.clone(),
                shape: new_shape,
                strides: new_strides,
                offset: self.offset,
                requires_grad: self.requires_grad,
                grad: None,
            });
        }

        // Non-contiguous: need to copy data
        let data = self.to_contiguous_data::<f32>();
        Tensor::from_slice(&data, new_shape.dims())
    }

    /// Flatten tensor to 1D
    pub fn flatten(&self) -> Result<Tensor> {
        self.reshape(&[self.numel()])
    }

    /// Transpose dimensions
    pub fn transpose(&self, dim0: usize, dim1: usize) -> Result<Tensor> {
        if dim0 >= self.ndim() || dim1 >= self.ndim() {
            return Err(GhostError::DimOutOfBounds {
                dim: dim0.max(dim1),
                ndim: self.ndim(),
            });
        }

        let mut new_shape = self.shape.dims().to_vec();
        let mut new_strides = self.strides.as_slice().to_vec();

        new_shape.swap(dim0, dim1);
        new_strides.swap(dim0, dim1);

        Ok(Tensor {
            storage: self.storage.clone(),
            shape: Shape::from(new_shape),
            strides: Strides::from(new_strides.as_slice()),
            offset: self.offset,
            requires_grad: self.requires_grad,
            grad: None,
        })
    }

    /// Transpose for 2D tensors (matrix transpose)
    pub fn t(&self) -> Result<Tensor> {
        if self.ndim() != 2 {
            return Err(GhostError::InvalidOperation(
                "t() only works on 2D tensors".to_string()
            ));
        }
        self.transpose(0, 1)
    }

    /// Squeeze: remove dimensions of size 1
    pub fn squeeze(&self) -> Tensor {
        let new_dims: Vec<usize> = self.dims().iter()
            .filter(|&&d| d != 1)
            .copied()
            .collect();

        if new_dims.is_empty() {
            // Scalar case
            let data = self.data_f32();
            Tensor::from_slice(&data, &[]).unwrap()
        } else {
            self.reshape(&new_dims).unwrap()
        }
    }

    /// Unsqueeze: add dimension of size 1 at position
    pub fn unsqueeze(&self, dim: usize) -> Result<Tensor> {
        if dim > self.ndim() {
            return Err(GhostError::DimOutOfBounds {
                dim,
                ndim: self.ndim() + 1,
            });
        }

        let mut new_dims = self.dims().to_vec();
        new_dims.insert(dim, 1);
        self.reshape(&new_dims)
    }

    // ==================== Clone ====================

    /// Deep clone (copies data)
    pub fn deep_clone(&self) -> Self {
        let data = self.data_f32();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }
}

impl Clone for Tensor {
    /// Shallow clone (shares storage)
    fn clone(&self) -> Self {
        Tensor {
            storage: self.storage.clone(),
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            offset: self.offset,
            requires_grad: self.requires_grad,
            grad: self.grad.clone(),
        }
    }
}

impl std::fmt::Display for Tensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Tensor(shape={}, dtype={})", self.shape, self.dtype())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_creation() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(t.dims(), &[2, 2]);
        assert_eq!(t.numel(), 4);
    }
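
    // A minimal test sketch for the from_slice() error path and the full()
    // constructor; assumes a length/shape mismatch surfaces as an Err, as
    // the implementation above suggests.
    #[test]
    fn test_from_slice_shape_mismatch_and_full() {
        // 3 elements cannot fill a 2x2 tensor.
        assert!(Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[2, 2]).is_err());

        // full() should broadcast the constant to every element.
        let t = Tensor::full(&[2, 3], 7.0f32);
        assert_eq!(t.numel(), 6);
        assert!(t.data_f32().iter().all(|&x| x == 7.0));
    }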

    #[test]
    fn test_zeros_ones() {
        let zeros = Tensor::zeros(&[3, 3]);
        let ones = Tensor::ones(&[3, 3]);

        assert!(zeros.data_f32().iter().all(|&x| x == 0.0));
        assert!(ones.data_f32().iter().all(|&x| x == 1.0));
    }
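
    // A sketch checking eye(), arange(), and linspace() against the
    // row-major ordering data_f32() returns for contiguous tensors; the
    // chosen values are exactly representable in f32, so equality is safe.
    #[test]
    fn test_eye_arange_linspace() {
        let eye = Tensor::eye(3);
        let data = eye.data_f32();
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_eq!(data[i * 3 + j], expected);
            }
        }

        let r = Tensor::arange(0.0, 5.0, 1.0);
        assert_eq!(r.data_f32(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let l = Tensor::linspace(0.0, 1.0, 5);
        assert_eq!(l.dims(), &[5]);
        assert_eq!(l.data_f32(), vec![0.0, 0.25, 0.5, 0.75, 1.0]);
    }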

    #[test]
    fn test_reshape() {
        let t = Tensor::arange(0.0, 12.0, 1.0);
        let reshaped = t.reshape(&[3, 4]).unwrap();
        assert_eq!(reshaped.dims(), &[3, 4]);
    }
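
    // A sketch that round-trips a shape through unsqueeze, squeeze, and
    // flatten, relying only on the methods defined in this module.
    #[test]
    fn test_squeeze_unsqueeze_flatten() {
        let t = Tensor::ones(&[2, 3]);

        let u = t.unsqueeze(1).unwrap();
        assert_eq!(u.dims(), &[2, 1, 3]);

        let s = u.squeeze();
        assert_eq!(s.dims(), &[2, 3]);

        let f = s.flatten().unwrap();
        assert_eq!(f.dims(), &[6]);
    }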

    #[test]
    fn test_transpose() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let transposed = t.t().unwrap();
        assert_eq!(transposed.dims(), &[3, 2]);
    }
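
    // A sketch covering the non-contiguous read path of data_f32() through a
    // transposed view, plus the gradient helpers and rand() bounds; assumes
    // default strides are row-major, as default_strides() is used elsewhere.
    #[test]
    fn test_non_contiguous_data_and_grad() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let tt = t.t().unwrap();
        assert!(!tt.is_contiguous());
        // Column-major read of the original 2x3 data.
        assert_eq!(tt.data_f32(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

        let mut g = Tensor::ones(&[2, 2]);
        g.set_requires_grad(true);
        g.set_grad(Tensor::ones(&[2, 2]));
        g.zero_grad();
        assert!(g.grad().unwrap().data_f32().iter().all(|&x| x == 0.0));

        let r = Tensor::rand(&[4, 4]);
        assert!(r.data_f32().iter().all(|&x| (0.0..1.0).contains(&x)));
    }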
}