ghostflow_core/
tensor.rs

//! Core Tensor type - the foundation of GhostFlow

use std::sync::Arc;
use parking_lot::RwLock;
use rand_distr::{Distribution, Normal, Uniform};

use crate::dtype::{DType, TensorElement};
use crate::shape::{Shape, Strides};
use crate::storage::Storage;
use crate::error::{GhostError, Result};
/// The core Tensor type
///
/// Tensors are multi-dimensional arrays with:
/// - Shared storage (enables zero-copy views)
/// - Shape and strides (enables non-contiguous layouts)
/// - Optional gradient tracking for autograd
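///
/// A minimal usage sketch. The `ghostflow_core::tensor::Tensor` path is an
/// assumption about where this module is mounted, so the example is marked
/// `ignore` rather than run as a doctest:
///
/// ```ignore
/// use ghostflow_core::tensor::Tensor;
///
/// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
/// let b = a.clone();                  // shallow clone: shares the same storage
/// let v = a.transpose(0, 1).unwrap(); // zero-copy view with swapped strides
/// assert_eq!(b.dims(), a.dims());
/// assert_eq!(v.dims(), &[3, 2]);
/// ```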
#[derive(Debug)]
pub struct Tensor {
    /// Underlying data storage (shared for views)
    storage: Storage,
    /// Shape of the tensor
    shape: Shape,
    /// Memory strides
    strides: Strides,
    /// Offset into storage (for views)
    offset: usize,
    /// Whether to track gradients
    requires_grad: bool,
    /// Accumulated gradient
    grad: Option<Arc<RwLock<Tensor>>>,
}

impl Tensor {
    // ==================== Creation ====================

    /// Create a new tensor from a flat slice and shape
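    ///
    /// Returns `GhostError::InvalidShape` if the data length does not match the
    /// shape's element count. A usage sketch (marked `ignore`; doctest setup is
    /// not assumed here):
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
    /// assert_eq!(t.dims(), &[2, 2]);
    /// assert!(Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[2, 2]).is_err());
    /// ```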
    pub fn from_slice<T: TensorElement>(data: &[T], shape: &[usize]) -> Result<Self> {
        let shape = Shape::new(shape);
        if data.len() != shape.numel() {
            return Err(GhostError::InvalidShape(format!(
                "Data length {} doesn't match shape {:?} (numel={})",
                data.len(),
                shape.dims(),
                shape.numel()
            )));
        }

        let strides = shape.default_strides();
        let storage = Storage::from_slice(data);

        Ok(Tensor {
            storage,
            shape,
            strides,
            offset: 0,
            requires_grad: false,
            grad: None,
        })
    }

    /// Create a tensor filled with zeros
    pub fn zeros(shape: &[usize]) -> Self {
        Self::full(shape, 0.0f32)
    }

    /// Create a tensor filled with ones
    pub fn ones(shape: &[usize]) -> Self {
        Self::full(shape, 1.0f32)
    }

    /// Create a tensor filled with a constant value
    pub fn full<T: TensorElement>(shape: &[usize], value: T) -> Self {
        let shape = Shape::new(shape);
        let numel = shape.numel();
        let data: Vec<T> = vec![value; numel];
        let strides = shape.default_strides();
        let storage = Storage::from_slice(&data);

        Tensor {
            storage,
            shape,
            strides,
            offset: 0,
            requires_grad: false,
            grad: None,
        }
    }

    /// Create a tensor with random values drawn from a uniform distribution on [0, 1)
    pub fn rand(shape: &[usize]) -> Self {
        let shape_obj = Shape::new(shape);
        let numel = shape_obj.numel();
        let mut rng = rand::thread_rng();
        let dist = Uniform::new(0.0f32, 1.0);
        let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();

        Tensor::from_slice(&data, shape).unwrap()
    }

    /// Create a tensor with random values drawn from the standard normal distribution
    pub fn randn(shape: &[usize]) -> Self {
        let shape_obj = Shape::new(shape);
        let numel = shape_obj.numel();
        let mut rng = rand::thread_rng();
        let dist = Normal::new(0.0f32, 1.0).unwrap();
        let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();

        Tensor::from_slice(&data, shape).unwrap()
    }

    /// Create an identity matrix
    pub fn eye(n: usize) -> Self {
        let mut data = vec![0.0f32; n * n];
        for i in 0..n {
            data[i * n + i] = 1.0;
        }
        Tensor::from_slice(&data, &[n, n]).unwrap()
    }

    /// Create a 1D tensor with evenly spaced values in `[start, end)` (assumes `step > 0`)
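    ///
    /// A usage sketch (marked `ignore`; values chosen so the f32 arithmetic is exact):
    ///
    /// ```ignore
    /// let t = Tensor::arange(0.0, 5.0, 1.0);
    /// assert_eq!(t.dims(), &[5]);
    /// assert_eq!(t.data_f32(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);
    /// ```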
    pub fn arange(start: f32, end: f32, step: f32) -> Self {
        let mut data = Vec::new();
        let mut val = start;
        while val < end {
            data.push(val);
            val += step;
        }
        let len = data.len();
        Tensor::from_slice(&data, &[len]).unwrap()
    }

    /// Create a 1D tensor with `n` evenly spaced values from `start` to `end` (inclusive)
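    ///
    /// A usage sketch (marked `ignore`; values chosen so the f32 arithmetic is exact):
    ///
    /// ```ignore
    /// let t = Tensor::linspace(0.0, 1.0, 5);
    /// assert_eq!(t.data_f32(), vec![0.0, 0.25, 0.5, 0.75, 1.0]);
    /// ```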
    pub fn linspace(start: f32, end: f32, n: usize) -> Self {
        if n == 0 {
            return Tensor::from_slice::<f32>(&[], &[0]).unwrap();
        }
        if n == 1 {
            return Tensor::from_slice(&[start], &[1]).unwrap();
        }

        let step = (end - start) / (n - 1) as f32;
        let data: Vec<f32> = (0..n).map(|i| start + i as f32 * step).collect();
        Tensor::from_slice(&data, &[n]).unwrap()
    }

    // ==================== Properties ====================

    /// Get the shape of the tensor
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Get dimensions as slice
    pub fn dims(&self) -> &[usize] {
        self.shape.dims()
    }

    /// Number of dimensions
    pub fn ndim(&self) -> usize {
        self.shape.ndim()
    }

    /// Total number of elements
    pub fn numel(&self) -> usize {
        self.shape.numel()
    }

    /// Get the data type
    pub fn dtype(&self) -> DType {
        self.storage.dtype()
    }

    /// Get strides
    pub fn strides(&self) -> &Strides {
        &self.strides
    }

    /// Check if tensor is contiguous in memory
    pub fn is_contiguous(&self) -> bool {
        self.strides.is_contiguous(&self.shape)
    }

    /// Check if gradient tracking is enabled
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    // ==================== Gradient ====================

    /// Enable or disable gradient tracking
    pub fn set_requires_grad(&mut self, requires_grad: bool) {
        self.requires_grad = requires_grad;
    }

    /// Get gradient if available
    pub fn grad(&self) -> Option<Tensor> {
        self.grad.as_ref().map(|g| g.read().clone())
    }

    /// Set gradient
    pub fn set_grad(&mut self, grad: Tensor) {
        self.grad = Some(Arc::new(RwLock::new(grad)));
    }

    /// Zero out gradient
    pub fn zero_grad(&mut self) {
        if let Some(ref grad) = self.grad {
            let mut g = grad.write();
            let zeros = Tensor::zeros(g.dims());
            *g = zeros;
        }
    }

    // ==================== Data Access ====================

    /// Get a copy of the data as a `Vec<f32>` (for f32 tensors)
    pub fn data_f32(&self) -> Vec<f32> {
        if self.is_contiguous() && self.offset == 0 {
            // Fast path: copy the backing buffer directly
            self.storage.as_slice::<f32>().to_vec()
        } else {
            // Views and non-contiguous layouts: gather element by element
            // (avoids holding the storage guard across to_contiguous_data)
            self.to_contiguous_data::<f32>()
        }
    }

    /// Convert to contiguous data (handles views and non-contiguous layouts)
    fn to_contiguous_data<T: TensorElement>(&self) -> Vec<T> {
        let numel = self.numel();
        let mut result = Vec::with_capacity(numel);
        let guard = self.storage.as_slice::<T>();

        // Iterate through all indices
        self.for_each_index(|indices| {
            let offset = self.compute_offset(indices);
            result.push(guard[offset]);
        });

        result
    }

    /// Compute linear offset from indices
    fn compute_offset(&self, indices: &[usize]) -> usize {
        self.offset + self.strides.offset(indices)
    }

    /// Iterate through all valid indices
    fn for_each_index<F: FnMut(&[usize])>(&self, mut f: F) {
        let dims = self.dims();
        if dims.is_empty() {
            f(&[]);
            return;
        }

        let mut indices = vec![0usize; dims.len()];
        loop {
            f(&indices);

            // Increment indices
            let mut i = dims.len() - 1;
            loop {
                indices[i] += 1;
                if indices[i] < dims[i] {
                    break;
                }
                indices[i] = 0;
                if i == 0 {
                    return;
                }
                i -= 1;
            }
        }
    }

    // ==================== Shape Operations ====================

    /// Reshape tensor to new shape (must have same numel)
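    ///
    /// Contiguous tensors are reshaped as zero-copy views; non-contiguous
    /// tensors are copied. A usage sketch (marked `ignore`):
    ///
    /// ```ignore
    /// let t = Tensor::arange(0.0, 12.0, 1.0);
    /// let m = t.reshape(&[3, 4]).unwrap();
    /// assert_eq!(m.dims(), &[3, 4]);
    /// assert!(t.reshape(&[5, 5]).is_err()); // numel mismatch
    /// ```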
    pub fn reshape(&self, new_shape: &[usize]) -> Result<Tensor> {
        let new_shape = Shape::new(new_shape);
        if new_shape.numel() != self.numel() {
            return Err(GhostError::InvalidShape(format!(
                "Cannot reshape tensor of {} elements to shape {:?}",
                self.numel(),
                new_shape.dims()
            )));
        }

        // If contiguous, can just change shape/strides
        if self.is_contiguous() {
            let new_strides = new_shape.default_strides();
            return Ok(Tensor {
                storage: self.storage.clone(),
                shape: new_shape,
                strides: new_strides,
                offset: self.offset,
                requires_grad: self.requires_grad,
                grad: None,
            });
        }

        // Non-contiguous: need to copy data (this path currently assumes f32 storage)
        let data = self.to_contiguous_data::<f32>();
        Tensor::from_slice(&data, new_shape.dims())
    }

    /// Flatten tensor to 1D
    pub fn flatten(&self) -> Result<Tensor> {
        self.reshape(&[self.numel()])
    }

    /// Transpose dimensions
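    ///
    /// Returns a view over the same storage with the two dimensions' sizes and
    /// strides swapped; no data is copied. A usage sketch (marked `ignore`):
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
    /// let v = t.transpose(0, 1).unwrap();
    /// assert_eq!(v.dims(), &[3, 2]);
    /// assert!(!v.is_contiguous());
    /// assert_eq!(v.data_f32(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    /// ```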
    pub fn transpose(&self, dim0: usize, dim1: usize) -> Result<Tensor> {
        if dim0 >= self.ndim() || dim1 >= self.ndim() {
            return Err(GhostError::DimOutOfBounds {
                dim: dim0.max(dim1),
                ndim: self.ndim(),
            });
        }

        let mut new_shape = self.shape.dims().to_vec();
        let mut new_strides = self.strides.as_slice().to_vec();

        new_shape.swap(dim0, dim1);
        new_strides.swap(dim0, dim1);

        Ok(Tensor {
            storage: self.storage.clone(),
            shape: Shape::from(new_shape),
            strides: Strides::from(new_strides.as_slice()),
            offset: self.offset,
            requires_grad: self.requires_grad,
            grad: None,
        })
    }

    /// Transpose for 2D tensors (matrix transpose)
    pub fn t(&self) -> Result<Tensor> {
        if self.ndim() != 2 {
            return Err(GhostError::InvalidOperation(
                "t() only works on 2D tensors".to_string()
            ));
        }
        self.transpose(0, 1)
    }

    /// Squeeze: remove dimensions of size 1
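    ///
    /// A usage sketch (marked `ignore`):
    ///
    /// ```ignore
    /// let t = Tensor::zeros(&[1, 3, 1]);
    /// assert_eq!(t.squeeze().dims(), &[3]);
    /// ```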
    pub fn squeeze(&self) -> Tensor {
        let new_dims: Vec<usize> = self.dims().iter()
            .filter(|&&d| d != 1)
            .copied()
            .collect();

        if new_dims.is_empty() {
            // Scalar case
            let data = self.data_f32();
            Tensor::from_slice(&data, &[]).unwrap()
        } else {
            self.reshape(&new_dims).unwrap()
        }
    }

    /// Unsqueeze: add dimension of size 1 at position
    pub fn unsqueeze(&self, dim: usize) -> Result<Tensor> {
        if dim > self.ndim() {
            return Err(GhostError::DimOutOfBounds {
                dim,
                ndim: self.ndim() + 1,
            });
        }

        let mut new_dims = self.dims().to_vec();
        new_dims.insert(dim, 1);
        self.reshape(&new_dims)
    }

    // ==================== Clone ====================

    /// Deep clone (copies data)
    pub fn deep_clone(&self) -> Self {
        let data = self.data_f32();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }
}

impl Clone for Tensor {
    /// Shallow clone (shares storage)
    fn clone(&self) -> Self {
        Tensor {
            storage: self.storage.clone(),
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            offset: self.offset,
            requires_grad: self.requires_grad,
            grad: self.grad.clone(),
        }
    }
}

impl std::fmt::Display for Tensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Tensor(shape={}, dtype={})", self.shape, self.dtype())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_creation() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(t.dims(), &[2, 2]);
        assert_eq!(t.numel(), 4);
    }

    #[test]
    fn test_zeros_ones() {
        let zeros = Tensor::zeros(&[3, 3]);
        let ones = Tensor::ones(&[3, 3]);

        assert!(zeros.data_f32().iter().all(|&x| x == 0.0));
        assert!(ones.data_f32().iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_reshape() {
        let t = Tensor::arange(0.0, 12.0, 1.0);
        let reshaped = t.reshape(&[3, 4]).unwrap();
        assert_eq!(reshaped.dims(), &[3, 4]);
    }

    #[test]
    fn test_transpose() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let transposed = t.t().unwrap();
        assert_eq!(transposed.dims(), &[3, 2]);
    }
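
    // A few additional sanity-check sketches against the API defined above;
    // expected values are chosen so the f32 arithmetic is exact.

    #[test]
    fn test_arange_linspace_values() {
        let a = Tensor::arange(0.0, 5.0, 1.0);
        assert_eq!(a.data_f32(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let l = Tensor::linspace(0.0, 1.0, 5);
        assert_eq!(l.data_f32(), vec![0.0, 0.25, 0.5, 0.75, 1.0]);
    }

    #[test]
    fn test_transpose_data_order() {
        // Transposing yields a non-contiguous view; data_f32 should gather
        // elements in the transposed logical order.
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let v = t.transpose(0, 1).unwrap();
        assert!(!v.is_contiguous());
        assert_eq!(v.data_f32(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_squeeze_unsqueeze() {
        let t = Tensor::zeros(&[1, 3, 1]);
        assert_eq!(t.squeeze().dims(), &[3]);
        assert_eq!(t.squeeze().unsqueeze(0).unwrap().dims(), &[1, 3]);
    }

    #[test]
    fn test_grad_roundtrip() {
        let mut t = Tensor::ones(&[2, 2]);
        t.set_requires_grad(true);
        assert!(t.requires_grad());

        t.set_grad(Tensor::ones(&[2, 2]));
        t.zero_grad();
        let g = t.grad().unwrap();
        assert!(g.data_f32().iter().all(|&x| x == 0.0));
    }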
}