use std::sync::Arc;

use parking_lot::RwLock;
use rand_distr::{Distribution, Normal, Uniform};

use crate::dtype::{DType, TensorElement};
use crate::shape::{Shape, Strides};
use crate::storage::Storage;
use crate::error::{GhostError, Result};

/// A multi-dimensional tensor backed by shared storage and described by a
/// shape, per-dimension strides, and an offset into the storage buffer.
/// Views (reshape, transpose, ...) share storage; `deep_clone` copies it.
#[derive(Debug)]
pub struct Tensor {
    /// Underlying data buffer, shared between views.
    storage: Storage,
    /// Logical dimensions.
    shape: Shape,
    /// Element stride for each dimension.
    strides: Strides,
    /// Starting offset into the storage buffer.
    offset: usize,
    /// Whether this tensor should track gradients.
    requires_grad: bool,
    /// Accumulated gradient, if one has been set.
    grad: Option<Arc<RwLock<Tensor>>>,
}

impl Tensor {
    /// Creates a tensor from a flat slice, laid out contiguously in the
    /// given shape. Errors if `data.len()` does not match the shape's numel.
    pub fn from_slice<T: TensorElement>(data: &[T], shape: &[usize]) -> Result<Self> {
        let shape = Shape::new(shape);
        if data.len() != shape.numel() {
            return Err(GhostError::InvalidShape(format!(
                "Data length {} doesn't match shape {:?} (numel={})",
                data.len(),
                shape.dims(),
                shape.numel()
            )));
        }

        let strides = shape.default_strides();
        let storage = Storage::from_slice(data);

        Ok(Tensor {
            storage,
            shape,
            strides,
            offset: 0,
            requires_grad: false,
            grad: None,
        })
    }

    /// Creates a tensor of the given shape filled with zeros (f32).
    pub fn zeros(shape: &[usize]) -> Self {
        Self::full(shape, 0.0f32)
    }

    /// Creates a tensor of the given shape filled with ones (f32).
    pub fn ones(shape: &[usize]) -> Self {
        Self::full(shape, 1.0f32)
    }

    /// Creates a tensor of the given shape with every element set to `value`.
    pub fn full<T: TensorElement>(shape: &[usize], value: T) -> Self {
        let shape = Shape::new(shape);
        let numel = shape.numel();
        let data: Vec<T> = vec![value; numel];
        let strides = shape.default_strides();
        let storage = Storage::from_slice(&data);

        Tensor {
            storage,
            shape,
            strides,
            offset: 0,
            requires_grad: false,
            grad: None,
        }
    }

    /// Creates a tensor with elements drawn uniformly from `[0, 1)`.
    pub fn rand(shape: &[usize]) -> Self {
        let numel = Shape::new(shape).numel();
        let mut rng = rand::thread_rng();
        let dist = Uniform::new(0.0f32, 1.0);
        let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();

        Tensor::from_slice(&data, shape).unwrap()
    }

    /// Creates a tensor with elements drawn from the standard normal
    /// distribution (mean 0, standard deviation 1).
    pub fn randn(shape: &[usize]) -> Self {
        let numel = Shape::new(shape).numel();
        let mut rng = rand::thread_rng();
        // Normal::new only fails for an invalid standard deviation, so
        // unwrapping with std dev 1.0 is safe.
        let dist = Normal::new(0.0f32, 1.0).unwrap();
        let data: Vec<f32> = (0..numel).map(|_| dist.sample(&mut rng)).collect();

        Tensor::from_slice(&data, shape).unwrap()
    }

    /// Creates an `n x n` identity matrix.
    pub fn eye(n: usize) -> Self {
        let mut data = vec![0.0f32; n * n];
        for i in 0..n {
            data[i * n + i] = 1.0;
        }
        Tensor::from_slice(&data, &[n, n]).unwrap()
    }

    /// Creates a 1-D tensor with values `start, start + step, ...` strictly
    /// below `end`. Requires `step > 0`; a non-positive step would either
    /// yield an empty tensor or never terminate.
    pub fn arange(start: f32, end: f32, step: f32) -> Self {
        debug_assert!(step > 0.0, "arange requires a positive step");
        let mut data = Vec::new();
        let mut val = start;
        while val < end {
            data.push(val);
            val += step;
        }
        let len = data.len();
        Tensor::from_slice(&data, &[len]).unwrap()
    }

    /// Creates a 1-D tensor of `n` evenly spaced values from `start` to
    /// `end`, inclusive of both endpoints.
    pub fn linspace(start: f32, end: f32, n: usize) -> Self {
        if n == 0 {
            return Tensor::from_slice::<f32>(&[], &[0]).unwrap();
        }
        if n == 1 {
            return Tensor::from_slice(&[start], &[1]).unwrap();
        }

        let step = (end - start) / (n - 1) as f32;
        let data: Vec<f32> = (0..n).map(|i| start + i as f32 * step).collect();
        Tensor::from_slice(&data, &[n]).unwrap()
    }

    /// Returns the tensor's shape.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Returns the dimensions as a slice.
    pub fn dims(&self) -> &[usize] {
        self.shape.dims()
    }

    /// Returns the number of dimensions.
    pub fn ndim(&self) -> usize {
        self.shape.ndim()
    }

    /// Returns the total number of elements.
    pub fn numel(&self) -> usize {
        self.shape.numel()
    }

    /// Returns the element data type.
    pub fn dtype(&self) -> DType {
        self.storage.dtype()
    }

    /// Returns the strides.
    pub fn strides(&self) -> &Strides {
        &self.strides
    }

    /// Returns true if the elements are laid out contiguously in row-major
    /// order.
    pub fn is_contiguous(&self) -> bool {
        self.strides.is_contiguous(&self.shape)
    }

    /// Returns true if this tensor tracks gradients.
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }

    /// Enables or disables gradient tracking.
    pub fn set_requires_grad(&mut self, requires_grad: bool) {
        self.requires_grad = requires_grad;
    }

    /// Returns a copy of the current gradient, if one has been set.
    pub fn grad(&self) -> Option<Tensor> {
        self.grad.as_ref().map(|g| g.read().clone())
    }

    /// Replaces the gradient with `grad`.
    pub fn set_grad(&mut self, grad: Tensor) {
        self.grad = Some(Arc::new(RwLock::new(grad)));
    }

    /// Resets the gradient to all zeros, if one has been set.
    pub fn zero_grad(&mut self) {
        if let Some(ref grad) = self.grad {
            let mut g = grad.write();
            let zeros = Tensor::zeros(g.dims());
            *g = zeros;
        }
    }

    /// Returns the elements as a contiguous `Vec<f32>`, copying through the
    /// strides when this tensor is a non-contiguous view. The storage guard
    /// is only taken in the contiguous branch, since `to_contiguous_data`
    /// acquires its own.
    pub fn data_f32(&self) -> Vec<f32> {
        if self.is_contiguous() && self.offset == 0 {
            self.storage.as_slice::<f32>().to_vec()
        } else {
            self.to_contiguous_data::<f32>()
        }
    }

    /// Copies the elements into a new row-major `Vec`, following this
    /// tensor's strides and offset.
    fn to_contiguous_data<T: TensorElement>(&self) -> Vec<T> {
        let numel = self.numel();
        let mut result = Vec::with_capacity(numel);
        let guard = self.storage.as_slice::<T>();

        self.for_each_index(|indices| {
            let offset = self.compute_offset(indices);
            result.push(guard[offset]);
        });

        result
    }

    /// Maps logical indices to a linear position in storage.
    fn compute_offset(&self, indices: &[usize]) -> usize {
        self.offset + self.strides.offset(indices)
    }

    /// Invokes `f` with every index tuple in row-major order, treating a
    /// 0-dimensional tensor as a single scalar element.
    fn for_each_index<F: FnMut(&[usize])>(&self, mut f: F) {
        let dims = self.dims();
        if dims.is_empty() {
            f(&[]);
            return;
        }
        // An empty tensor (any zero-sized dimension) has no valid indices.
        if dims.iter().any(|&d| d == 0) {
            return;
        }

        let mut indices = vec![0usize; dims.len()];
        loop {
            f(&indices);

            // Advance like an odometer: increment the last index, carrying
            // into earlier dimensions; once the first dimension wraps, every
            // index has been visited.
            let mut i = dims.len() - 1;
            loop {
                indices[i] += 1;
                if indices[i] < dims[i] {
                    break;
                }
                indices[i] = 0;
                if i == 0 {
                    return;
                }
                i -= 1;
            }
        }
    }

    /// Returns a tensor with the same elements in a new shape. Contiguous
    /// tensors are reshaped as zero-copy views; non-contiguous tensors are
    /// copied (currently as f32) into fresh storage.
    pub fn reshape(&self, new_shape: &[usize]) -> Result<Tensor> {
        let new_shape = Shape::new(new_shape);
        if new_shape.numel() != self.numel() {
            return Err(GhostError::InvalidShape(format!(
                "Cannot reshape tensor of {} elements to shape {:?}",
                self.numel(),
                new_shape.dims()
            )));
        }

        if self.is_contiguous() {
            let new_strides = new_shape.default_strides();
            return Ok(Tensor {
                storage: self.storage.clone(),
                shape: new_shape,
                strides: new_strides,
                offset: self.offset,
                requires_grad: self.requires_grad,
                grad: None,
            });
        }

        let data = self.to_contiguous_data::<f32>();
        Tensor::from_slice(&data, new_shape.dims())
    }

    /// Collapses the tensor into one dimension.
    pub fn flatten(&self) -> Result<Tensor> {
        self.reshape(&[self.numel()])
    }

    /// Returns a view with dimensions `dim0` and `dim1` swapped. No data is
    /// copied; only the shape and strides are exchanged.
    pub fn transpose(&self, dim0: usize, dim1: usize) -> Result<Tensor> {
        if dim0 >= self.ndim() || dim1 >= self.ndim() {
            return Err(GhostError::DimOutOfBounds {
                dim: dim0.max(dim1),
                ndim: self.ndim(),
            });
        }

        let mut new_shape = self.shape.dims().to_vec();
        let mut new_strides = self.strides.as_slice().to_vec();

        new_shape.swap(dim0, dim1);
        new_strides.swap(dim0, dim1);

        Ok(Tensor {
            storage: self.storage.clone(),
            shape: Shape::from(new_shape),
            strides: Strides::from(new_strides.as_slice()),
            offset: self.offset,
            requires_grad: self.requires_grad,
            grad: None,
        })
    }

    /// Transposes a 2-D tensor (matrix transpose).
    pub fn t(&self) -> Result<Tensor> {
        if self.ndim() != 2 {
            return Err(GhostError::InvalidOperation(
                "t() only works on 2D tensors".to_string()
            ));
        }
        self.transpose(0, 1)
    }

    /// Removes all dimensions of size 1. If every dimension is 1, the result
    /// is a 0-dimensional scalar tensor.
    pub fn squeeze(&self) -> Tensor {
        let new_dims: Vec<usize> = self.dims().iter()
            .filter(|&&d| d != 1)
            .copied()
            .collect();

        if new_dims.is_empty() {
            let data = self.data_f32();
            Tensor::from_slice(&data, &[]).unwrap()
        } else {
            self.reshape(&new_dims).unwrap()
        }
    }

    /// Inserts a dimension of size 1 at position `dim`.
    pub fn unsqueeze(&self, dim: usize) -> Result<Tensor> {
        if dim > self.ndim() {
            return Err(GhostError::DimOutOfBounds {
                dim,
                ndim: self.ndim() + 1,
            });
        }

        let mut new_dims = self.dims().to_vec();
        new_dims.insert(dim, 1);
        self.reshape(&new_dims)
    }

    /// Returns a tensor with its own freshly allocated storage, unlike
    /// `clone`, which shares storage with the original.
    pub fn deep_clone(&self) -> Self {
        let data = self.data_f32();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }
}

impl Clone for Tensor {
    /// Shallow clone: the new tensor shares storage and gradient with the
    /// original. Use `deep_clone` for an independent copy of the data.
    fn clone(&self) -> Self {
        Tensor {
            storage: self.storage.clone(),
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            offset: self.offset,
            requires_grad: self.requires_grad,
            grad: self.grad.clone(),
        }
    }
}

impl std::fmt::Display for Tensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Tensor(shape={}, dtype={})", self.shape, self.dtype())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_creation() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert_eq!(t.dims(), &[2, 2]);
        assert_eq!(t.numel(), 4);
    }

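    // Added checks (sketches): the error path of `from_slice` when data and
    // shape disagree, plus basic properties of the random constructors.
    // `rand` draws from Uniform(0, 1), so samples must land in [0.0, 1.0);
    // `randn` is unbounded, so only its shape is checked.
    #[test]
    fn test_from_slice_shape_mismatch() {
        assert!(Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[2, 2]).is_err());
    }

    #[test]
    fn test_rand_randn() {
        let u = Tensor::rand(&[4, 5]);
        assert_eq!(u.dims(), &[4, 5]);
        assert!(u.data_f32().iter().all(|&x| (0.0..1.0).contains(&x)));

        let n = Tensor::randn(&[4, 5]);
        assert_eq!(n.numel(), 20);
    }
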
    #[test]
    fn test_zeros_ones() {
        let zeros = Tensor::zeros(&[3, 3]);
        let ones = Tensor::ones(&[3, 3]);

        assert!(zeros.data_f32().iter().all(|&x| x == 0.0));
        assert!(ones.data_f32().iter().all(|&x| x == 1.0));
    }

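    // Added check (sketch): `eye` should place ones on the main diagonal and
    // zeros elsewhere, and `linspace` should include both endpoints. The
    // chosen values (multiples of 0.25) are exact in f32, so exact equality
    // is safe here.
    #[test]
    fn test_eye_linspace() {
        let id = Tensor::eye(3);
        let data = id.data_f32();
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_eq!(data[i * 3 + j], expected);
            }
        }

        let ls = Tensor::linspace(0.0, 1.0, 5);
        assert_eq!(ls.data_f32(), vec![0.0, 0.25, 0.5, 0.75, 1.0]);
    }
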
    #[test]
    fn test_reshape() {
        let t = Tensor::arange(0.0, 12.0, 1.0);
        let reshaped = t.reshape(&[3, 4]).unwrap();
        assert_eq!(reshaped.dims(), &[3, 4]);
    }

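    // Added check (sketch): `arange` stops strictly below `end`, `reshape`
    // rejects a mismatched element count, and `flatten` collapses any shape
    // into one dimension.
    #[test]
    fn test_arange_flatten() {
        let t = Tensor::arange(0.0, 3.0, 1.0);
        assert_eq!(t.data_f32(), vec![0.0, 1.0, 2.0]);
        assert!(t.reshape(&[2, 2]).is_err());

        let m = Tensor::ones(&[2, 3]);
        let flat = m.flatten().unwrap();
        assert_eq!(flat.dims(), &[6]);
    }
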
    #[test]
    fn test_transpose() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let transposed = t.t().unwrap();
        assert_eq!(transposed.dims(), &[3, 2]);
    }
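
    // Added checks (sketches): squeeze/unsqueeze round-trips, the strided
    // copy taken by `data_f32` on a transposed (non-contiguous) view, and
    // the grad lifecycle (`set_grad` followed by `zero_grad`).
    #[test]
    fn test_squeeze_unsqueeze() {
        let t = Tensor::ones(&[1, 3, 1]);
        assert_eq!(t.squeeze().dims(), &[3]);
        assert_eq!(t.squeeze().unsqueeze(0).unwrap().dims(), &[1, 3]);
    }

    #[test]
    fn test_transpose_data() {
        // [[1, 2, 3], [4, 5, 6]] transposed reads back as [[1, 4], [2, 5], [3, 6]].
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let tt = t.t().unwrap();
        assert!(!tt.is_contiguous());
        assert_eq!(tt.data_f32(), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_grad_lifecycle() {
        let mut t = Tensor::zeros(&[2, 2]);
        t.set_requires_grad(true);
        assert!(t.grad().is_none());

        t.set_grad(Tensor::ones(&[2, 2]));
        assert!(t.grad().unwrap().data_f32().iter().all(|&x| x == 1.0));

        t.zero_grad();
        assert!(t.grad().unwrap().data_f32().iter().all(|&x| x == 0.0));
    }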
}