use std::sync::Arc;

use parking_lot::RwLock;
use rand_distr::{Distribution, Normal, Uniform};

use crate::dtype::{DType, TensorElement};
use crate::shape::{Shape, Strides};
use crate::storage::Storage;
use crate::error::{GhostError, Result};
/// A multi-dimensional array described by a storage buffer, a shape, strides,
/// and an optional gradient.
#[derive(Debug)]
pub struct Tensor {
    /// Underlying element storage.
    storage: Storage,
    /// Logical dimensions of the tensor.
    shape: Shape,
    /// Per-dimension strides into the storage, in elements.
    strides: Strides,
    /// Starting offset into the storage, in elements.
    offset: usize,
    /// Whether gradients should be tracked for this tensor.
    requires_grad: bool,
    /// Accumulated gradient, if any.
    grad: Option<Arc<RwLock<Tensor>>>,
}

impl Tensor {
    /// Builds a tensor from a flat slice of elements and the desired shape.
    /// Returns an error if the data length does not match the shape's element count.
    pub fn from_slice<T: TensorElement>(data: &[T], shape: &[usize]) -> Result<Self> {
        let shape = Shape::new(shape);
        if data.len() != shape.numel() {
            return Err(GhostError::InvalidShape(format!(
                "Data length {} doesn't match shape {:?} (numel={})",
                data.len(),
                shape.dims(),
                shape.numel()
            )));
        }

        let strides = shape.default_strides();
        let storage = Storage::from_slice(data);

        Ok(Tensor {
            storage,
            shape,
            strides,
            offset: 0,
            requires_grad: false,
            grad: None,
        })
    }

    /// Creates a tensor of the given shape filled with zeros.
    pub fn zeros(shape: &[usize]) -> Self {
        Self::full(shape, 0.0f32)
    }

    /// Creates a tensor of the given shape filled with ones.
    pub fn ones(shape: &[usize]) -> Self {
        Self::full(shape, 1.0f32)
    }

    /// Creates a tensor of the given shape with every element set to `value`.
    pub fn full<T: TensorElement>(shape: &[usize], value: T) -> Self {
        let shape = Shape::new(shape);
        let numel = shape.numel();
        let data: Vec<T> = vec![value; numel];
        let strides = shape.default_strides();
        let storage = Storage::from_slice(&data);

        Tensor {
            storage,
            shape,
            strides,
            offset: 0,
            requires_grad: false,
            grad: None,
        }
    }

    /// Creates a tensor with elements sampled uniformly from [0, 1).
    pub fn rand(shape: &[usize]) -> Self {
        let shape_obj = Shape::new(shape);
        let numel = shape_obj.numel();
        let mut rng = rand::thread_rng();
        let dist = Uniform::new(0.0f32, 1.0);
        let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();

        Tensor::from_slice(&data, shape).unwrap()
    }

    /// Creates a tensor with elements sampled from the standard normal distribution.
    pub fn randn(shape: &[usize]) -> Self {
        let shape_obj = Shape::new(shape);
        let numel = shape_obj.numel();
        let mut rng = rand::thread_rng();
        let dist = Normal::new(0.0f32, 1.0).unwrap();
        let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();

        Tensor::from_slice(&data, shape).unwrap()
    }

    /// Creates an `n x n` identity matrix.
    pub fn eye(n: usize) -> Self {
        let mut data = vec![0.0f32; n * n];
        for i in 0..n {
            data[i * n + i] = 1.0;
        }
        Tensor::from_slice(&data, &[n, n]).unwrap()
    }

    /// Creates a 1-D tensor with values in `[start, end)` spaced by `step`.
    /// `step` is assumed to be positive; values accumulate in f32, so very long
    /// ranges may drift slightly.
    pub fn arange(start: f32, end: f32, step: f32) -> Self {
        let mut data = Vec::new();
        let mut val = start;
        while val < end {
            data.push(val);
            val += step;
        }
        let len = data.len();
        Tensor::from_slice(&data, &[len]).unwrap()
    }

    /// Creates a 1-D tensor of `n` evenly spaced values from `start` to `end` inclusive.
    pub fn linspace(start: f32, end: f32, n: usize) -> Self {
        if n == 0 {
            return Tensor::from_slice::<f32>(&[], &[0]).unwrap();
        }
        if n == 1 {
            return Tensor::from_slice(&[start], &[1]).unwrap();
        }

        let step = (end - start) / (n - 1) as f32;
        let data: Vec<f32> = (0..n).map(|i| start + i as f32 * step).collect();
        Tensor::from_slice(&data, &[n]).unwrap()
    }

    /// Returns the tensor's shape.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Returns the dimensions as a slice.
    pub fn dims(&self) -> &[usize] {
        self.shape.dims()
    }

    /// Returns the number of dimensions.
    pub fn ndim(&self) -> usize {
        self.shape.ndim()
    }

    /// Returns the total number of elements.
    pub fn numel(&self) -> usize {
        self.shape.numel()
    }

    /// Returns the element data type.
    pub fn dtype(&self) -> DType {
        self.storage.dtype()
    }

    /// Returns the strides.
    pub fn strides(&self) -> &Strides {
        &self.strides
    }

    /// Returns true if the elements are laid out contiguously in row-major order.
    pub fn is_contiguous(&self) -> bool {
        self.strides.is_contiguous(&self.shape)
    }

    /// Returns true if gradients are tracked for this tensor.
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Enables or disables gradient tracking.
    pub fn set_requires_grad(&mut self, requires_grad: bool) {
        self.requires_grad = requires_grad;
    }

    /// Returns a clone of the accumulated gradient, if one exists.
    pub fn grad(&self) -> Option<Tensor> {
        self.grad.as_ref().map(|g| g.read().clone())
    }

    /// Returns a reference to the underlying storage.
    pub fn storage(&self) -> &Storage {
        &self.storage
    }

    /// Replaces the accumulated gradient.
    pub fn set_grad(&mut self, grad: Tensor) {
        self.grad = Some(Arc::new(RwLock::new(grad)));
    }

    /// Resets the accumulated gradient to zeros, if one exists.
    pub fn zero_grad(&mut self) {
        if let Some(ref grad) = self.grad {
            let mut g = grad.write();
            let zeros = Tensor::zeros(g.dims());
            *g = zeros;
        }
    }

    /// Copies the elements out as a contiguous `Vec<f32>`, following strides and
    /// offset if the tensor is a non-contiguous view.
    pub fn data_f32(&self) -> Vec<f32> {
        let guard = self.storage.as_slice::<f32>();
        if self.is_contiguous() && self.offset == 0 {
            guard.to_vec()
        } else {
            self.to_contiguous_data::<f32>()
        }
    }

    /// Gathers the elements into a contiguous `Vec` by walking every logical index.
    fn to_contiguous_data<T: TensorElement>(&self) -> Vec<T> {
        let numel = self.numel();
        if numel == 0 {
            // Zero-sized tensor: nothing to gather, and indexing storage would panic.
            return Vec::new();
        }
        let mut result = Vec::with_capacity(numel);
        let guard = self.storage.as_slice::<T>();

        self.for_each_index(|indices| {
            let offset = self.compute_offset(indices);
            result.push(guard[offset]);
        });

        result
    }

    /// Maps a logical index to a flat offset into the storage.
    fn compute_offset(&self, indices: &[usize]) -> usize {
        self.offset + self.strides.offset(indices)
    }

    /// Calls `f` with every logical index in row-major order.
    fn for_each_index<F: FnMut(&[usize])>(&self, mut f: F) {
        let dims = self.dims();
        if dims.is_empty() {
            f(&[]);
            return;
        }

        let mut indices = vec![0usize; dims.len()];
        loop {
            f(&indices);

            // Advance like an odometer: increment the last dimension and carry
            // into earlier dimensions; stop once the first dimension overflows.
            let mut i = dims.len() - 1;
            loop {
                indices[i] += 1;
                if indices[i] < dims[i] {
                    break;
                }
                indices[i] = 0;
                if i == 0 {
                    return;
                }
                i -= 1;
            }
        }
    }

    /// Returns this tensor with a new shape containing the same number of elements.
    /// Contiguous tensors share storage; non-contiguous views are materialised first.
    pub fn reshape(&self, new_shape: &[usize]) -> Result<Tensor> {
        let new_shape = Shape::new(new_shape);
        if new_shape.numel() != self.numel() {
            return Err(GhostError::InvalidShape(format!(
                "Cannot reshape tensor of {} elements to shape {:?}",
                self.numel(),
                new_shape.dims()
            )));
        }

        if self.is_contiguous() {
            // Same storage, new shape and strides.
            let new_strides = new_shape.default_strides();
            return Ok(Tensor {
                storage: self.storage.clone(),
                shape: new_shape,
                strides: new_strides,
                offset: self.offset,
                requires_grad: self.requires_grad,
                grad: None,
            });
        }

        // Non-contiguous: copy into a fresh contiguous buffer.
        let data = self.to_contiguous_data::<f32>();
        Tensor::from_slice(&data, new_shape.dims())
    }

    /// Flattens the tensor into one dimension.
    pub fn flatten(&self) -> Result<Tensor> {
        self.reshape(&[self.numel()])
    }

    /// Returns a view with dimensions `dim0` and `dim1` swapped. No data is copied;
    /// only the shape and strides change.
    pub fn transpose(&self, dim0: usize, dim1: usize) -> Result<Tensor> {
        if dim0 >= self.ndim() || dim1 >= self.ndim() {
            return Err(GhostError::DimOutOfBounds {
                dim: dim0.max(dim1),
                ndim: self.ndim(),
            });
        }

        let mut new_shape = self.shape.dims().to_vec();
        let mut new_strides = self.strides.as_slice().to_vec();

        new_shape.swap(dim0, dim1);
        new_strides.swap(dim0, dim1);

        Ok(Tensor {
            storage: self.storage.clone(),
            shape: Shape::from(new_shape),
            strides: Strides::from(new_strides.as_slice()),
            offset: self.offset,
            requires_grad: self.requires_grad,
            grad: None,
        })
    }

    /// Transposes a 2-D tensor (matrix transpose).
    pub fn t(&self) -> Result<Tensor> {
        if self.ndim() != 2 {
            return Err(GhostError::InvalidOperation(
                "t() only works on 2D tensors".to_string()
            ));
        }
        self.transpose(0, 1)
    }

    /// Removes all dimensions of size 1. A tensor whose dimensions are all 1
    /// collapses to a scalar (zero-dimensional) tensor.
    pub fn squeeze(&self) -> Tensor {
        let new_dims: Vec<usize> = self.dims().iter()
            .filter(|&&d| d != 1)
            .copied()
            .collect();

        if new_dims.is_empty() {
            // Every dimension was 1: produce a scalar tensor.
            let data = self.data_f32();
            Tensor::from_slice(&data, &[]).unwrap()
        } else {
            self.reshape(&new_dims).unwrap()
        }
    }

    /// Inserts a new dimension of size 1 at position `dim`.
    pub fn unsqueeze(&self, dim: usize) -> Result<Tensor> {
        if dim > self.ndim() {
            return Err(GhostError::DimOutOfBounds {
                dim,
                ndim: self.ndim() + 1,
            });
        }

        let mut new_dims = self.dims().to_vec();
        new_dims.insert(dim, 1);
        self.reshape(&new_dims)
    }

    /// Returns a copy of this tensor backed by a new, contiguous storage buffer.
    pub fn deep_clone(&self) -> Self {
        let data = self.data_f32();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }
}

impl Clone for Tensor {
    fn clone(&self) -> Self {
        Tensor {
            storage: self.storage.clone(),
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            offset: self.offset,
            requires_grad: self.requires_grad,
            grad: self.grad.clone(),
        }
    }
}

impl std::fmt::Display for Tensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Tensor(shape={}, dtype={})", self.shape, self.dtype())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_creation() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(t.dims(), &[2, 2]);
        assert_eq!(t.numel(), 4);
    }
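
    // Added sketch (not from the original suite): exercises the error path of
    // from_slice, which should reject data whose length does not match the
    // product of the requested dimensions.
    #[test]
    fn test_from_slice_shape_mismatch() {
        let result = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[2, 2]);
        assert!(result.is_err());
    }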

    #[test]
    fn test_zeros_ones() {
        let zeros = Tensor::zeros(&[3, 3]);
        let ones = Tensor::ones(&[3, 3]);

        assert!(zeros.data_f32().iter().all(|&x| x == 0.0));
        assert!(ones.data_f32().iter().all(|&x| x == 1.0));
    }
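
    // Added sketch (not from the original suite): checks that eye() places ones
    // on the diagonal and zeros elsewhere, and that full() fills every element.
    // Assumes data_f32() returns elements in row-major order, as the contiguous
    // default strides suggest.
    #[test]
    fn test_eye_and_full() {
        let id = Tensor::eye(3);
        assert_eq!(id.dims(), &[3, 3]);
        let data = id.data_f32();
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_eq!(data[i * 3 + j], expected);
            }
        }

        let filled = Tensor::full(&[2, 2], 7.0f32);
        assert!(filled.data_f32().iter().all(|&x| x == 7.0));
    }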

    #[test]
    fn test_reshape() {
        let t = Tensor::arange(0.0, 12.0, 1.0);
        let reshaped = t.reshape(&[3, 4]).unwrap();
        assert_eq!(reshaped.dims(), &[3, 4]);
    }
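
    // Added sketch (not from the original suite): reshaping a transposed,
    // non-contiguous view should fall back to to_contiguous_data() and
    // materialise elements in the view's logical (row-major) order.
    #[test]
    fn test_reshape_non_contiguous() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let transposed = t.t().unwrap();
        assert!(!transposed.is_contiguous());

        let reshaped = transposed.reshape(&[2, 3]).unwrap();
        assert_eq!(reshaped.dims(), &[2, 3]);
        assert_eq!(reshaped.data_f32(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }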

    #[test]
    fn test_transpose() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let transposed = t.t().unwrap();
        assert_eq!(transposed.dims(), &[3, 2]);
    }
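
    // Added sketch (not from the original suite): checks element counts of
    // arange()/linspace() and round-trips a tensor through unsqueeze()/squeeze().
    #[test]
    fn test_range_and_squeeze() {
        let r = Tensor::arange(0.0, 5.0, 1.0);
        assert_eq!(r.dims(), &[5]);
        assert_eq!(r.data_f32(), vec![0.0, 1.0, 2.0, 3.0, 4.0]);

        let l = Tensor::linspace(0.0, 1.0, 5);
        assert_eq!(l.numel(), 5);

        let t = Tensor::ones(&[3]);
        let unsqueezed = t.unsqueeze(0).unwrap();
        assert_eq!(unsqueezed.dims(), &[1, 3]);
        assert_eq!(unsqueezed.squeeze().dims(), &[3]);
    }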
}