1use core::fmt;
6
7#[cfg(feature = "tensor")]
8use ndarray::Array2;
9
10use crate::tensor::error::TensorError;
11use crate::tensor::traits::{DType, Device, TensorBase, TensorOps};
12
/// Dense, row-major (C-contiguous) tensor of `f64` values stored in a flat `Vec`.
#[derive(Clone, PartialEq)]
pub struct DenseTensor {
    /// Flat element storage in row-major order; length equals the product of `shape`.
    data: Vec<f64>,
    /// Extent of each dimension; empty for a scalar.
    shape: Vec<usize>,
    /// Element step per dimension; always the C-order strides of `shape` (see `compute_strides`).
    strides: Vec<usize>,
    /// Element type tag — constructors in this file always set `DType::F64`.
    dtype: DType,
    /// Placement tag — constructors in this file always set `Device::Cpu`.
    device: Device,
}
29
#[cfg(feature = "tensor")]
impl DenseTensor {
    /// Total size of the element storage in bytes.
    pub fn nbytes(&self) -> usize {
        self.data.len() * self.dtype.size_bytes()
    }

    /// Whether the layout is contiguous (alias for [`Self::is_c_contiguous`]).
    pub fn is_contiguous(&self) -> bool {
        self.is_c_contiguous()
    }

    /// Preferred buffer alignment in bytes.
    ///
    /// NOTE(review): this reports a fixed 64-byte (cache-line) figure, but the
    /// backing `Vec<f64>` only guarantees `f64` alignment — confirm intent.
    pub fn alignment(&self) -> usize {
        64
    }

    /// Builds a tensor from a flat row-major buffer and a shape.
    ///
    /// An empty `shape` produces a scalar (then `data` must hold exactly one
    /// element, since an empty product is 1).
    ///
    /// # Panics
    /// Panics if `data.len()` differs from the product of `shape`.
    pub fn new(data: Vec<f64>, shape: Vec<usize>) -> Self {
        let expected_len = shape.iter().product::<usize>();
        assert_eq!(
            data.len(),
            expected_len,
            "Data length {} does not match shape product {}",
            data.len(),
            expected_len
        );

        let strides = compute_strides(&shape);
        Self {
            data,
            shape,
            strides,
            dtype: DType::F64,
            device: Device::Cpu,
        }
    }

    /// Alias for [`Self::new`].
    pub fn from_vec(data: Vec<f64>, shape: Vec<usize>) -> Self {
        Self::new(data, shape)
    }

    /// All-zero tensor of the given shape.
    pub fn zeros(shape: Vec<usize>) -> Self {
        let data = vec![0.0; shape.iter().product()];
        Self::new(data, shape)
    }

    /// All-one tensor of the given shape.
    pub fn ones(shape: Vec<usize>) -> Self {
        let data = vec![1.0; shape.iter().product()];
        Self::new(data, shape)
    }

    /// Zero-dimensional (scalar) tensor holding a single value.
    pub fn scalar(value: f64) -> Self {
        Self {
            data: vec![value],
            shape: vec![],
            strides: vec![],
            dtype: DType::F64,
            device: Device::Cpu,
        }
    }

    /// 2-D tensor built from row-major `data`.
    ///
    /// # Panics
    /// Panics if `data.len() != rows * cols`.
    pub fn matrix(rows: usize, cols: usize, data: Vec<f64>) -> Self {
        Self::new(data, vec![rows, cols])
    }

    /// Identity matrix of dimension `size` x `size`.
    pub fn eye(size: usize) -> Self {
        let mut data = vec![0.0; size * size];
        for i in 0..size {
            data[i * size + i] = 1.0;
        }
        Self::new(data, vec![size, size])
    }

    /// Converts an `ndarray` matrix into a dense tensor.
    ///
    /// Non-standard-layout arrays (e.g. transposed views) are handled by an
    /// element-wise copy in logical row-major order.
    #[cfg(feature = "tensor")]
    pub fn from_ndarray(arr: &Array2<f64>) -> Self {
        let shape = vec![arr.nrows(), arr.ncols()];
        // `as_slice` only succeeds for standard-layout arrays; the previous
        // `unwrap()` here panicked on views such as `arr.t()`. `iter()` yields
        // elements in logical (row-major) order regardless of memory layout.
        let data = match arr.as_slice() {
            Some(slice) => slice.to_vec(),
            None => arr.iter().copied().collect(),
        };
        Self::new(data, shape)
    }

    /// Converts a 2-D tensor into an `ndarray` matrix.
    ///
    /// # Errors
    /// Returns `TensorError::DimensionMismatch` unless the tensor is 2-D.
    #[cfg(feature = "tensor")]
    pub fn to_ndarray(&self) -> Result<Array2<f64>, TensorError> {
        if self.ndim() != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: self.ndim(),
            });
        }
        // Length/shape consistency is an invariant enforced by `new`, so the
        // inner `unwrap` cannot fail here.
        Ok(Array2::from_shape_vec((self.shape[0], self.shape[1]), self.data.clone()).unwrap())
    }

    /// Read-only view of the flat element buffer (row-major order).
    pub fn data(&self) -> &[f64] {
        &self.data
    }

    /// Mutable view of the flat element buffer (row-major order).
    pub fn data_mut(&mut self) -> &mut [f64] {
        &mut self.data
    }

    /// Per-dimension element strides.
    pub fn strides(&self) -> &[usize] {
        &self.strides
    }

    /// Whether the strides describe a C-contiguous (row-major) layout.
    ///
    /// Always true for tensors produced by this module's constructors, since
    /// `new` derives strides from the shape.
    pub fn is_c_contiguous(&self) -> bool {
        if self.ndim() <= 1 {
            return true;
        }
        for i in 0..self.ndim() - 1 {
            if self.strides[i] != self.strides[i + 1] * self.shape[i + 1] {
                return false;
            }
        }
        true
    }

    /// Validates `indices` against the shape and returns the linear offset
    /// into the flat buffer. Shared by `get` and `set`.
    fn linear_offset(&self, indices: &[usize]) -> Result<usize, TensorError> {
        if indices.len() != self.ndim() {
            return Err(TensorError::DimensionMismatch {
                expected: self.ndim(),
                got: indices.len(),
            });
        }

        let mut offset = 0;
        for (i, &idx) in indices.iter().enumerate() {
            if idx >= self.shape[i] {
                return Err(TensorError::IndexOutOfBounds {
                    index: idx,
                    dim: i,
                    size: self.shape[i],
                });
            }
            offset += idx * self.strides[i];
        }
        Ok(offset)
    }

    /// Reads the element at the given multi-dimensional index.
    ///
    /// # Errors
    /// `DimensionMismatch` if `indices.len() != ndim`; `IndexOutOfBounds` if
    /// any index exceeds its dimension.
    pub fn get(&self, indices: &[usize]) -> Result<f64, TensorError> {
        let offset = self.linear_offset(indices)?;
        Ok(self.data[offset])
    }

    /// Writes `value` at the given multi-dimensional index.
    ///
    /// # Errors
    /// Same error conditions as [`Self::get`].
    pub fn set(&mut self, indices: &[usize], value: f64) -> Result<(), TensorError> {
        let offset = self.linear_offset(indices)?;
        self.data[offset] = value;
        Ok(())
    }

    /// Copies out one row of a 2-D tensor.
    ///
    /// Relies on the C-contiguous layout invariant: a row occupies a
    /// contiguous span of the flat buffer.
    ///
    /// # Errors
    /// `DimensionMismatch` for non-2-D tensors; `IndexOutOfBounds` for a bad row.
    pub fn row(&self, row: usize) -> Result<Vec<f64>, TensorError> {
        if self.ndim() != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: self.ndim(),
            });
        }
        if row >= self.shape[0] {
            return Err(TensorError::IndexOutOfBounds {
                index: row,
                dim: 0,
                size: self.shape[0],
            });
        }

        let start = row * self.strides[0];
        let cols = self.shape[1];
        Ok(self.data[start..start + cols].to_vec())
    }

    /// Copies out one column of a 2-D tensor (strided gather).
    ///
    /// # Errors
    /// `DimensionMismatch` for non-2-D tensors; `IndexOutOfBounds` for a bad column.
    pub fn col(&self, col: usize) -> Result<Vec<f64>, TensorError> {
        if self.ndim() != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: self.ndim(),
            });
        }
        if col >= self.shape[1] {
            return Err(TensorError::IndexOutOfBounds {
                index: col,
                dim: 1,
                size: self.shape[1],
            });
        }

        let mut result = Vec::with_capacity(self.shape[0]);
        for row in 0..self.shape[0] {
            let idx = row * self.strides[0] + col;
            result.push(self.data[idx]);
        }
        Ok(result)
    }
}
262
/// Row-major (C-order) strides for `shape`: the last axis is contiguous
/// (stride 1) and each earlier axis steps over the product of all later
/// dimension sizes. Returns an empty vector for a scalar (empty) shape.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    // Walk the dimensions back-to-front, accumulating the running product,
    // then flip the result into front-to-back order.
    let mut strides: Vec<usize> = shape
        .iter()
        .rev()
        .scan(1usize, |running, &dim| {
            let stride = *running;
            *running *= dim;
            Some(stride)
        })
        .collect();
    strides.reverse();
    strides
}
276
#[cfg(feature = "tensor")]
impl TensorBase for DenseTensor {
    /// Dimension extents of the tensor.
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Element type tag stored on the tensor.
    fn dtype(&self) -> DType {
        self.dtype
    }

    /// Placement tag stored on the tensor.
    fn device(&self) -> Device {
        self.device
    }

    /// Already dense — returns a copy of `self`.
    fn to_dense(&self) -> DenseTensor {
        self.clone()
    }

    /// Converts a 2-D tensor to CSR sparse form, keeping only entries whose
    /// absolute value exceeds 1e-10. Returns `None` for any other rank.
    #[cfg(feature = "tensor")]
    fn to_sparse(&self) -> Option<crate::tensor::sparse::SparseTensor> {
        // CSR bookkeeping: entries of row r live in the half-open span
        // row_offsets[r]..row_offsets[r+1] of `col_indices` / `values`.
        let mut row_offsets = vec![0];
        let mut col_indices = Vec::new();
        let mut values = Vec::new();

        if self.ndim() == 2 {
            let rows = self.shape[0];
            let cols = self.shape[1];

            for row in 0..rows {
                for col in 0..cols {
                    // Index is always in range here, so `unwrap` cannot fail.
                    let val = self.get(&[row, col]).unwrap();
                    if val.abs() > 1e-10 {
                        col_indices.push(col);
                        values.push(val);
                    }
                }
                row_offsets.push(col_indices.len());
            }

            // NOTE(review): `values` is cloned but not used afterwards — the
            // clone looks redundant.
            let values_tensor = DenseTensor::new(values.clone(), vec![values.len()]);
            let csr = crate::tensor::sparse::CSRTensor::new(
                row_offsets,
                col_indices,
                values_tensor,
                [self.shape[0], self.shape[1]],
            );
            Some(crate::tensor::sparse::SparseTensor::CSR(csr))
        } else {
            None
        }
    }
}
330
#[cfg(feature = "tensor")]
impl TensorOps for DenseTensor {
    /// Element-wise addition.
    ///
    /// # Panics
    /// Panics on shape mismatch.
    fn add(&self, other: &Self) -> Self {
        assert_eq!(
            self.shape, other.shape,
            "Shape mismatch for addition: {:?} vs {:?}",
            self.shape, other.shape
        );

        let data: Vec<f64> = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(&a, &b)| a + b)
            .collect();

        Self::new(data, self.shape.clone())
    }

    /// Element-wise subtraction.
    ///
    /// # Panics
    /// Panics on shape mismatch.
    fn sub(&self, other: &Self) -> Self {
        assert_eq!(
            self.shape, other.shape,
            "Shape mismatch for subtraction: {:?} vs {:?}",
            self.shape, other.shape
        );

        let data: Vec<f64> = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(&a, &b)| a - b)
            .collect();

        Self::new(data, self.shape.clone())
    }

    /// Element-wise (Hadamard) multiplication.
    ///
    /// # Panics
    /// Panics on shape mismatch.
    fn mul(&self, other: &Self) -> Self {
        assert_eq!(
            self.shape, other.shape,
            "Shape mismatch for element-wise multiplication: {:?} vs {:?}",
            self.shape, other.shape
        );

        let data: Vec<f64> = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(&a, &b)| a * b)
            .collect();

        Self::new(data, self.shape.clone())
    }

    /// Element-wise division. Division by zero follows IEEE-754 semantics
    /// (yields +/-inf or NaN rather than panicking).
    ///
    /// # Panics
    /// Panics on shape mismatch.
    fn div(&self, other: &Self) -> Self {
        assert_eq!(
            self.shape, other.shape,
            "Shape mismatch for division: {:?} vs {:?}",
            self.shape, other.shape
        );

        let data: Vec<f64> = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(&a, &b)| a / b)
            .collect();

        Self::new(data, self.shape.clone())
    }

    /// Matrix product of two 2-D tensors: `(m, k) x (k, n) -> (m, n)`.
    /// Naive triple loop over the contiguous row-major buffers.
    ///
    /// # Panics
    /// Panics unless both tensors are 2-D with compatible inner dimensions.
    fn matmul(&self, other: &Self) -> Self {
        assert_eq!(
            self.ndim(),
            2,
            "matmul requires 2D tensors, got {}D",
            self.ndim()
        );
        assert_eq!(
            other.ndim(),
            2,
            "matmul requires 2D tensors, got {}D",
            other.ndim()
        );
        assert_eq!(
            self.shape[1], other.shape[0],
            "Shape mismatch for matmul: {:?} x {:?}",
            self.shape, other.shape
        );

        let m = self.shape[0];
        let k = self.shape[1];
        let n = other.shape[1];

        let mut result = vec![0.0; m * n];

        for i in 0..m {
            for j in 0..n {
                let mut sum = 0.0;
                for p in 0..k {
                    sum += self.data[i * k + p] * other.data[p * n + j];
                }
                result[i * n + j] = sum;
            }
        }

        Self::new(result, vec![m, n])
    }

    /// Permutes the tensor's axes.
    ///
    /// `axes[j]` names the source dimension that becomes output dimension `j`;
    /// `None` reverses the axis order. Scalars are returned unchanged and the
    /// 2-D case takes a direct row/column-swap fast path.
    ///
    /// # Panics
    /// Panics if `axes` is provided with a length other than `ndim`.
    fn transpose(&self, axes: Option<&[usize]>) -> Self {
        if self.ndim() == 0 {
            return self.clone();
        }

        if self.ndim() == 2 {
            let rows = self.shape[0];
            let cols = self.shape[1];
            let mut result = vec![0.0; cols * rows];

            for i in 0..rows {
                for j in 0..cols {
                    result[j * rows + i] = self.data[i * cols + j];
                }
            }

            Self::new(result, vec![cols, rows])
        } else {
            // General N-D case via an explicit index remap. (The previous
            // implementation decomposed the linear index with the permuted
            // dimension sizes, which scattered elements incorrectly whenever
            // the permuted sizes differed from the C-order decomposition.)
            let ndim = self.ndim();
            let default_axes: Vec<usize> = (0..ndim).rev().collect();
            let axes = axes.unwrap_or(&default_axes);

            assert_eq!(axes.len(), ndim, "Axes length must match ndim");

            let new_shape: Vec<usize> = axes.iter().map(|&a| self.shape[a]).collect();

            // C-order strides for an arbitrary shape (local helper; trait
            // impls cannot carry extra associated functions).
            let c_strides = |shape: &[usize]| -> Vec<usize> {
                let mut s = vec![1usize; shape.len()];
                for i in (0..shape.len().saturating_sub(1)).rev() {
                    s[i] = s[i + 1] * shape[i + 1];
                }
                s
            };
            let old_strides = c_strides(&self.shape);
            let new_strides = c_strides(&new_shape);

            // Inverse permutation: source dimension d lands at output
            // dimension inv[d].
            let mut inv = vec![0usize; ndim];
            for (out_dim, &src_dim) in axes.iter().enumerate() {
                inv[src_dim] = out_dim;
            }

            let mut result = vec![0.0; self.numel()];
            for (i, &val) in self.data.iter().enumerate() {
                // Decompose the source linear index into per-axis coordinates
                // and re-linearize them under the permuted layout.
                let mut rem = i;
                let mut out_off = 0;
                for d in 0..ndim {
                    let coord = rem / old_strides[d];
                    rem %= old_strides[d];
                    out_off += coord * new_strides[inv[d]];
                }
                result[out_off] = val;
            }

            Self::new(result, new_shape)
        }
    }

    /// Sums over the given axes.
    ///
    /// `None` (or a set covering every axis) reduces to a scalar; an empty
    /// axis list is a no-op clone. Otherwise the reduced axes are removed
    /// from the shape and the remaining dimensions keep their order. 2-D
    /// single-axis reductions take dedicated fast paths.
    ///
    /// # Panics
    /// Panics if any axis is out of range.
    fn sum(&self, axes: Option<&[usize]>) -> Self {
        let axes = match axes {
            None => return Self::scalar(self.data.iter().sum::<f64>()),
            Some(axes) => axes,
        };
        if axes.is_empty() {
            return self.clone();
        }

        // Fast paths: column sums / row sums of a matrix.
        if axes.len() == 1 && self.ndim() == 2 {
            let axis = axes[0];
            if axis == 0 {
                let cols = self.shape[1];
                let mut result = vec![0.0; cols];
                for row in self.data.chunks(cols) {
                    for (j, val) in row.iter().enumerate() {
                        result[j] += val;
                    }
                }
                return Self::new(result, vec![cols]);
            } else if axis == 1 {
                let rows = self.shape[0];
                let cols = self.shape[1];
                let mut result = vec![0.0; rows];
                for (i, row_sum) in result.iter_mut().enumerate().take(rows) {
                    let row_start = i * cols;
                    *row_sum = self.data[row_start..row_start + cols].iter().sum();
                }
                return Self::new(result, vec![rows]);
            }
        }

        // General case: accumulate each element into the output cell obtained
        // by dropping the reduced coordinates. (Previously any non-2-D axis
        // reduction silently fell back to a total sum, which also made
        // `mean(axes)` wrong for those shapes.)
        let ndim = self.ndim();
        let mut reduced = vec![false; ndim];
        for &a in axes {
            assert!(a < ndim, "sum axis {} out of range for {}D tensor", a, ndim);
            reduced[a] = true;
        }

        let out_shape: Vec<usize> = self
            .shape
            .iter()
            .enumerate()
            .filter(|&(d, _)| !reduced[d])
            .map(|(_, &s)| s)
            .collect();
        if out_shape.is_empty() {
            // Every axis reduced: scalar result.
            return Self::scalar(self.data.iter().sum::<f64>());
        }

        // C-order strides of the output shape.
        let mut out_strides = vec![1usize; out_shape.len()];
        for i in (0..out_shape.len() - 1).rev() {
            out_strides[i] = out_strides[i + 1] * out_shape[i + 1];
        }

        let mut result = vec![0.0; out_shape.iter().product()];
        for (i, &val) in self.data.iter().enumerate() {
            let mut rem = i;
            let mut out_off = 0;
            let mut k = 0; // next output dimension
            for d in 0..ndim {
                let coord = rem / self.strides[d];
                rem %= self.strides[d];
                if !reduced[d] {
                    out_off += coord * out_strides[k];
                    k += 1;
                }
            }
            result[out_off] += val;
        }

        Self::new(result, out_shape)
    }

    /// Mean over the given axes (or over all elements when `axes` is `None`).
    /// Divides the corresponding `sum` by the product of the reduced
    /// dimensions' sizes (`numel` for a full reduction).
    fn mean(&self, axes: Option<&[usize]>) -> Self {
        let sum = self.sum(axes);
        let count = if let Some(axes) = axes {
            if axes.is_empty() {
                1
            } else {
                axes.iter().map(|&a| self.shape[a]).product::<usize>()
            }
        } else {
            self.numel()
        };

        sum.mul_scalar(1.0 / count as f64)
    }

    /// Multiplies every element by `scalar`.
    fn mul_scalar(&self, scalar: f64) -> Self {
        let data: Vec<f64> = self.data.iter().map(|&x| x * scalar).collect();
        Self::new(data, self.shape.clone())
    }

    /// Adds `scalar` to every element.
    fn add_scalar(&self, scalar: f64) -> Self {
        let data: Vec<f64> = self.data.iter().map(|&x| x + scalar).collect();
        Self::new(data, self.shape.clone())
    }

    /// Applies `f` to every element, producing a new tensor of the same shape.
    fn map<F>(&self, f: F) -> Self
    where
        F: Fn(f64) -> f64 + Send + Sync,
    {
        let data: Vec<f64> = self.data.iter().copied().map(f).collect();
        Self::new(data, self.shape.clone())
    }

    /// Reinterprets the flat buffer under a new shape with the same element count.
    ///
    /// # Panics
    /// Panics if the element counts differ.
    fn reshape(&self, new_shape: &[usize]) -> Self {
        let new_size: usize = new_shape.iter().product();
        assert_eq!(
            new_size,
            self.numel(),
            "Reshape size mismatch: {} vs {}",
            new_size,
            self.numel()
        );

        Self::new(self.data.clone(), new_shape.to_vec())
    }

    /// Extracts a rectangular sub-block.
    ///
    /// Only fully implemented for 2-D tensors with both axes sliced; the
    /// axis order in `axes` may be either `[0, 1]` or `[1, 0]`.
    ///
    /// NOTE(review): any other rank/axes combination currently returns a full
    /// clone, silently ignoring the requested ranges.
    ///
    /// # Panics
    /// Panics if `axes` and `ranges` have different lengths.
    fn slice(&self, axes: &[usize], ranges: &[core::ops::Range<usize>]) -> Self {
        assert_eq!(axes.len(), ranges.len(), "Axes and ranges length mismatch");

        if self.ndim() == 2 && axes.len() == 2 {
            // Pair each range with its axis regardless of the order given.
            let row_range = if axes[0] == 0 {
                ranges[0].clone()
            } else {
                ranges[1].clone()
            };
            let col_range = if axes[1] == 1 {
                ranges[1].clone()
            } else {
                ranges[0].clone()
            };

            let new_rows = row_range.len();
            let new_cols = col_range.len();
            let mut result = Vec::with_capacity(new_rows * new_cols);

            for i in row_range {
                for j in col_range.clone() {
                    result.push(self.data[i * self.shape[1] + j]);
                }
            }

            return Self::new(result, vec![new_rows, new_cols]);
        }

        self.clone()
    }

    /// Concatenates two tensors along `axis`.
    ///
    /// Implemented for 2-D tensors along either axis; any other case panics
    /// with `unimplemented!`.
    ///
    /// # Panics
    /// Panics on rank mismatch, an out-of-range axis, a mismatch on any
    /// non-concatenated dimension, or an unsupported rank/axis combination.
    fn concat(&self, other: &Self, axis: usize) -> Self {
        assert_eq!(
            self.ndim(),
            other.ndim(),
            "Concat ndim mismatch: {} vs {}",
            self.ndim(),
            other.ndim()
        );
        assert!(
            axis < self.ndim(),
            "Axis {} out of range for {}D tensor",
            axis,
            self.ndim()
        );

        // All dimensions except the concatenation axis must agree.
        for (i, (&s, &o)) in self.shape.iter().zip(other.shape.iter()).enumerate() {
            if i != axis {
                assert_eq!(s, o, "Shape mismatch at dim {}", i);
            }
        }

        if self.ndim() == 2 && axis == 0 {
            assert_eq!(
                self.shape[1], other.shape[1],
                "Column count mismatch for concat"
            );

            // Row-wise concat is just buffer concatenation in row-major order.
            let new_rows = self.shape[0] + other.shape[0];
            let cols = self.shape[1];
            let mut result = Vec::with_capacity(new_rows * cols);

            result.extend_from_slice(&self.data);
            result.extend_from_slice(&other.data);

            return Self::new(result, vec![new_rows, cols]);
        }

        if self.ndim() == 2 && axis == 1 {
            // Column-wise concat: interleave the two buffers row by row.
            let rows = self.shape[0];
            let cols_a = self.shape[1];
            let cols_b = other.shape[1];
            let mut result = Vec::with_capacity(rows * (cols_a + cols_b));

            for r in 0..rows {
                result.extend_from_slice(&self.data[r * cols_a..(r + 1) * cols_a]);
                result.extend_from_slice(&other.data[r * cols_b..(r + 1) * cols_b]);
            }

            return Self::new(result, vec![rows, cols_a + cols_b]);
        }

        unimplemented!("Concat for this case is not implemented")
    }

    /// Largest element (`-inf` for an empty tensor).
    fn max(&self) -> f64 {
        self.data.iter().cloned().fold(f64::NEG_INFINITY, f64::max)
    }

    /// Smallest element (`+inf` for an empty tensor).
    fn min(&self) -> f64 {
        self.data.iter().cloned().fold(f64::INFINITY, f64::min)
    }

    /// Euclidean (L2 / Frobenius) norm over all elements.
    fn norm(&self) -> f64 {
        self.data.iter().map(|&x| x * x).sum::<f64>().sqrt()
    }

    /// Scales the tensor to unit L2 norm; near-zero tensors (norm <= 1e-10)
    /// are returned unchanged to avoid dividing by ~0.
    fn normalize(&self) -> Self {
        let norm = self.norm();
        if norm > 1e-10 {
            self.mul_scalar(1.0 / norm)
        } else {
            self.clone()
        }
    }
}
676
#[cfg(feature = "tensor")]
impl DenseTensor {
    /// SiLU (swish) activation, element-wise: `x * sigmoid(x)` written as
    /// `x / (1 + e^-x)`.
    pub fn silu(&self) -> Self {
        self.map(|x| x / (1.0 + (-x).exp()))
    }

    /// Element-wise derivative of the tanh-approximated GELU,
    /// i.e. d/dx of `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
    pub fn gelu_derivative(&self) -> Self {
        const SQRT_2_OVER_PI: f64 = 0.7978845608028654;
        const COEF: f64 = 0.044715;
        self.map(|x| {
            let x2 = x * x;
            let x3 = x * x2;
            let tanh_arg = SQRT_2_OVER_PI * (x + COEF * x3);
            let tanh_val = tanh_arg.tanh();
            // Product rule: 0.5*(1 + tanh(u)) + 0.5*x*sech^2(u)*u', with
            // sech^2(u) = 1 - tanh^2(u) and u' = sqrt(2/pi)*(1 + 3*0.044715*x^2).
            0.5 * (1.0 + tanh_val) + x * 0.5 * (1.0 - tanh_val * tanh_val) * SQRT_2_OVER_PI * (1.0 + 3.0 * COEF * x2)
        })
    }

    /// Mean along one dimension; `dim` may be negative (counted from the end).
    /// The reduced dimension is kept with size 1.
    ///
    /// Specialized for 2-D (either axis) and 3-D (last axis only). Any other
    /// rank/axis combination falls back to the mean over ALL elements,
    /// returned as a scalar.
    pub fn mean_dim(&self, dim: isize) -> Self {
        let ndim = self.ndim();
        let axis = if dim < 0 { (ndim as isize + dim) as usize } else { dim as usize };

        if ndim == 2 && axis == 0 {
            // Column means of a (rows, cols) matrix -> shape [1, cols].
            let cols = self.shape[1];
            let rows = self.shape[0];
            let mut result = vec![0.0; cols];
            #[allow(clippy::needless_range_loop)]
            for col in 0..cols {
                for row in 0..rows {
                    result[col] += self.data[row * cols + col];
                }
                result[col] /= rows as f64;
            }
            Self::new(result, vec![1, cols])
        } else if ndim == 2 && axis == 1 {
            // Row means -> shape [rows, 1].
            let rows = self.shape[0];
            let cols = self.shape[1];
            let mut result = vec![0.0; rows];
            #[allow(clippy::needless_range_loop)]
            for row in 0..rows {
                let row_start = row * cols;
                result[row] = self.data[row_start..row_start + cols].iter().sum::<f64>() / cols as f64;
            }
            Self::new(result, vec![rows, 1])
        } else if ndim == 3 && axis == 2 {
            // Mean over the feature dimension -> shape [batch, seq, 1].
            let batch = self.shape[0];
            let seq = self.shape[1];
            let dim = self.shape[2];
            let mut result = vec![0.0; batch * seq];
            for b in 0..batch {
                for s in 0..seq {
                    let start = (b * seq + s) * dim;
                    let sum: f64 = self.data[start..start + dim].iter().sum();
                    result[b * seq + s] = sum / dim as f64;
                }
            }
            Self::new(result, vec![batch, seq, 1])
        } else {
            // Fallback: mean over everything, as a scalar.
            let sum: f64 = self.data.iter().sum();
            Self::scalar(sum / self.numel() as f64)
        }
    }

    /// Population variance along one dimension (divides by N, not N-1);
    /// `dim` may be negative. Same specialization/fallback structure as
    /// [`Self::mean_dim`], whose output it reuses for the per-slice means.
    pub fn var_dim(&self, dim: isize) -> Self {
        let mean = self.mean_dim(dim);
        let ndim = self.ndim();
        let axis = if dim < 0 { (ndim as isize + dim) as usize } else { dim as usize };

        if ndim == 2 && axis == 0 {
            // Column variances -> shape [1, cols]; mean has shape [1, cols].
            let cols = self.shape[1];
            let rows = self.shape[0];
            let mut result = vec![0.0; cols];
            #[allow(clippy::needless_range_loop)]
            for col in 0..cols {
                for row in 0..rows {
                    let diff = self.data[row * cols + col] - mean.data()[col];
                    result[col] += diff * diff;
                }
                result[col] /= rows as f64;
            }
            Self::new(result, vec![1, cols])
        } else if ndim == 2 && axis == 1 {
            // Row variances -> shape [rows, 1]; mean has shape [rows, 1].
            let rows = self.shape[0];
            let cols = self.shape[1];
            let mut result = vec![0.0; rows];
            #[allow(clippy::needless_range_loop)]
            for row in 0..rows {
                let row_start = row * cols;
                let m = mean.data()[row];
                let var: f64 = self.data[row_start..row_start + cols]
                    .iter()
                    .map(|&x| (x - m) * (x - m))
                    .sum::<f64>() / cols as f64;
                result[row] = var;
            }
            Self::new(result, vec![rows, 1])
        } else if ndim == 3 && axis == 2 {
            // Feature-dimension variances -> shape [batch, seq, 1].
            let batch = self.shape[0];
            let seq = self.shape[1];
            let dim = self.shape[2];
            let mut result = vec![0.0; batch * seq];
            for b in 0..batch {
                for s in 0..seq {
                    let start = (b * seq + s) * dim;
                    let m = mean.data()[b * seq + s];
                    let var: f64 = self.data[start..start + dim]
                        .iter()
                        .map(|&x| (x - m) * (x - m))
                        .sum::<f64>() / dim as f64;
                    result[b * seq + s] = var;
                }
            }
            Self::new(result, vec![batch, seq, 1])
        } else {
            // Fallback: variance over everything, as a scalar.
            let mean_val = self.data.iter().sum::<f64>() / self.numel() as f64;
            let var: f64 = self.data.iter().map(|&x| (x - mean_val) * (x - mean_val)).sum::<f64>() / self.numel() as f64;
            Self::scalar(var)
        }
    }

    /// Element-wise square root (NaN for negative inputs).
    pub fn sqrt(&self) -> Self {
        self.map(|x| x.sqrt())
    }

    /// Element-wise negation.
    pub fn neg(&self) -> Self {
        self.mul_scalar(-1.0)
    }

    /// Element-wise greater-than mask: 1.0 where `x > value`, else 0.0.
    pub fn gt(&self, value: f64) -> Self {
        self.map(|x| if x > value { 1.0 } else { 0.0 })
    }

    /// Replaces elements with `value` wherever the mask element exceeds 0.5.
    ///
    /// # Panics
    /// Panics if `mask` has a different shape.
    pub fn mask_fill(&self, mask: &Self, value: f64) -> Self {
        assert_eq!(self.shape, mask.shape, "Shape mismatch for mask_fill");
        let data: Vec<f64> = self.data.iter()
            .zip(mask.data.iter())
            .map(|(&v, &m)| if m > 0.5 { value } else { v })
            .collect();
        Self::new(data, self.shape.clone())
    }

    /// Convenience alias for `transpose(None)`.
    pub fn transpose_2d(&self) -> Self {
        self.transpose(None)
    }

    /// Extracts a "row" depending on rank:
    /// - 2-D: row `row` as a [1, cols] tensor;
    /// - 3-D: sequence position `row` from every batch, as [batch, dim].
    ///
    /// NOTE(review): for any other rank this returns `data[0]` as a scalar,
    /// ignoring `row` entirely (and panics on an empty tensor) — confirm this
    /// fallback is intended.
    pub fn get_row(&self, row: usize) -> Self {
        if self.ndim() == 2 {
            let cols = self.shape[1];
            let start = row * cols;
            Self::new(self.data[start..start + cols].to_vec(), vec![1, cols])
        } else if self.ndim() == 3 {
            let batch = self.shape[0];
            let dim = self.shape[2];
            let mut result_data = Vec::with_capacity(batch * dim);

            for b in 0..batch {
                // Slice [b, row, :] out of the contiguous (batch, seq, dim) buffer.
                let offset = (b * self.shape[1] + row) * dim;
                result_data.extend_from_slice(&self.data[offset..offset + dim]);
            }

            Self::new(result_data, vec![batch, dim])
        } else {
            Self::scalar(self.data[0])
        }
    }

    /// Overwrites row `row` of a 2-D tensor with `data`'s buffer.
    ///
    /// NOTE(review): silently does nothing unless both `self` and `data` are
    /// 2-D; panics if `data` does not hold exactly `cols` elements.
    pub fn set_row(&mut self, row: usize, data: &Self) {
        if self.ndim() == 2 && data.ndim() == 2 {
            let cols = self.shape[1];
            let start = row * cols;
            self.data[start..start + cols].copy_from_slice(data.data());
        }
    }

    /// Tensor of the given shape with every element set to `value`.
    pub fn full(shape: &[usize], value: f64) -> Self {
        let size: usize = shape.iter().product();
        let data = vec![value; size];
        Self::new(data, shape.to_vec())
    }

    /// Alias for [`Self::mul_scalar`].
    pub fn scale(&self, scalar: f64) -> Self {
        self.mul_scalar(scalar)
    }

    /// Softmax along `dim`; delegates to the shared activations module.
    pub fn softmax(&self, dim: isize) -> Self {
        crate::tensor::ops::activations::softmax(self, dim)
    }

    /// ReLU activation; delegates to the shared activations module.
    pub fn relu(&self) -> Self {
        crate::tensor::ops::activations::relu(self)
    }

    /// GELU activation; delegates to the shared activations module.
    pub fn gelu(&self) -> Self {
        crate::tensor::ops::activations::gelu(self)
    }

    /// Element-wise cosine.
    pub fn cos(&self) -> Self {
        self.map(|x| x.cos())
    }

    /// Element-wise sine.
    pub fn sin(&self) -> Self {
        self.map(|x| x.sin())
    }

    /// Element-wise natural logarithm (NaN/-inf for non-positive inputs).
    pub fn ln(&self) -> Self {
        self.map(|x| x.ln())
    }

    /// Batched matmul with a single broadcast weight matrix:
    /// `(batch, seq, hidden) x (hidden, out) -> (batch, seq, out)`.
    ///
    /// # Panics
    /// Panics unless `self` is 3-D, `weight` is 2-D, and the hidden
    /// dimensions match.
    pub fn bmm_broadcast_weight(&self, weight: &DenseTensor) -> Self {
        assert_eq!(self.ndim(), 3, "bmm_broadcast_weight requires 3D tensor, got {}D", self.ndim());
        assert_eq!(weight.ndim(), 2, "weight must be 2D, got {}D", weight.ndim());
        assert_eq!(
            self.shape[2], weight.shape[0],
            "Shape mismatch for bmm: {:?} x {:?}",
            self.shape, weight.shape
        );

        let batch = self.shape[0];
        let seq = self.shape[1];
        let hidden = self.shape[2];
        let out = weight.shape[1];

        let mut result = vec![0.0; batch * seq * out];

        // Treat every (batch, seq) position as an independent row-vector x
        // matrix product.
        for b in 0..batch {
            for s in 0..seq {
                let input_start = (b * seq + s) * hidden;
                let output_start = (b * seq + s) * out;

                for o in 0..out {
                    let mut sum = 0.0;
                    for h in 0..hidden {
                        sum += self.data[input_start + h] * weight.data[h * out + o];
                    }
                    result[output_start + o] = sum;
                }
            }
        }

        Self::new(result, vec![batch, seq, out])
    }

    /// Broadcasts a trailing size-1 dimension up to `target_dim` by repeating
    /// each element `target_dim` times.
    ///
    /// # Panics
    /// Panics unless the last dimension is exactly 1 (and thus ndim >= 1).
    pub fn expand_last_dim(&self, target_dim: usize) -> Self {
        assert!(
            self.ndim() >= 1 && self.shape()[self.ndim() - 1] == 1,
            "Last dimension must be 1 for expansion"
        );

        let mut new_shape = self.shape.to_vec();
        new_shape[self.ndim() - 1] = target_dim;

        let mut data = Vec::with_capacity(self.numel() * target_dim);
        for &val in self.data.iter() {
            for _ in 0..target_dim {
                data.push(val);
            }
        }

        Self::new(data, new_shape)
    }

    /// Tiles a 1-D `[hidden]` tensor into `[batch, seq, hidden]` by repeating
    /// it at every (batch, seq) position.
    ///
    /// # Panics
    /// Panics unless `self` is 1-D.
    pub fn expand_to_3d(&self, batch: usize, seq: usize) -> Self {
        assert_eq!(self.ndim(), 1, "Must be 1D tensor for 3D expansion");
        let hidden = self.shape[0];

        let mut data = Vec::with_capacity(batch * seq * hidden);
        for _ in 0..batch * seq {
            data.extend_from_slice(&self.data);
        }

        Self::new(data, vec![batch, seq, hidden])
    }

    /// Broadcasts a `[seq, 1]` tensor into `[seq, target_dim]` by repeating
    /// each element along the last axis.
    ///
    /// # Panics
    /// Panics unless `self` is 2-D with a size-1 second dimension.
    pub fn expand_last_dim_2d(&self, target_dim: usize) -> Self {
        assert!(
            self.ndim() == 2 && self.shape()[1] == 1,
            "Must be 2D tensor with last dim 1 for expansion"
        );

        let seq = self.shape[0];
        let mut data = Vec::with_capacity(seq * target_dim);
        for &val in self.data.iter() {
            for _ in 0..target_dim {
                data.push(val);
            }
        }

        Self::new(data, vec![seq, target_dim])
    }

    /// Tiles a 1-D `[hidden]` tensor into `[seq, hidden]` by repeating it
    /// `seq` times.
    ///
    /// # Panics
    /// Panics unless `self` is 1-D.
    pub fn expand_to_2d(&self, seq: usize) -> Self {
        assert_eq!(self.ndim(), 1, "Must be 1D tensor for 2D expansion");
        let hidden = self.shape[0];

        let mut data = Vec::with_capacity(seq * hidden);
        for _ in 0..seq {
            data.extend_from_slice(&self.data);
        }

        Self::new(data, vec![seq, hidden])
    }
}
1017
#[cfg(feature = "tensor")]
impl fmt::Debug for DenseTensor {
    /// Summarized debug output: shape, dtype, device, and element count —
    /// deliberately omits the (potentially huge) element buffer.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DenseTensor")
            .field("shape", &self.shape)
            .field("dtype", &self.dtype)
            .field("device", &self.device)
            .field("numel", &self.numel())
            .finish()
    }
}
1029
#[cfg(feature = "tensor")]
impl Default for DenseTensor {
    /// A single-element tensor of shape `[1]` holding `0.0` — identical to
    /// `DenseTensor::zeros(vec![1])`.
    fn default() -> Self {
        Self::new(vec![0.0], vec![1])
    }
}
1036
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dense_tensor_creation() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = DenseTensor::new(values.clone(), vec![2, 3]);

        // Shape/metadata and the raw buffer all round-trip.
        assert_eq!(tensor.ndim(), 2);
        assert_eq!(tensor.numel(), 6);
        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.data(), values.as_slice());
    }

    #[test]
    fn test_zeros_and_ones() {
        assert_eq!(DenseTensor::zeros(vec![2, 3]).data(), &[0.0; 6]);
        assert_eq!(DenseTensor::ones(vec![2, 3]).data(), &[1.0; 6]);
    }

    #[test]
    fn test_matrix_operations() {
        let a = DenseTensor::matrix(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let b = DenseTensor::matrix(2, 2, vec![5.0, 6.0, 7.0, 8.0]);

        assert_eq!(a.add(&b).data(), &[6.0, 8.0, 10.0, 12.0]);
        assert_eq!(a.sub(&b).data(), &[-4.0; 4]);
    }

    #[test]
    fn test_matmul() {
        let lhs = DenseTensor::matrix(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let rhs = DenseTensor::matrix(3, 2, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);

        let product = lhs.matmul(&rhs);
        assert_eq!(product.shape(), &[2, 2]);
        assert_eq!(product.data(), &[58.0, 64.0, 139.0, 154.0]);
    }

    #[test]
    fn test_transpose() {
        let transposed =
            DenseTensor::matrix(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).transpose(None);

        assert_eq!(transposed.shape(), &[3, 2]);
        assert_eq!(transposed.data(), &[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_scalar_operations() {
        let base = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![3]);

        assert_eq!(base.mul_scalar(2.0).data(), &[2.0, 4.0, 6.0]);
        assert_eq!(base.add_scalar(1.0).data(), &[2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_norm_and_normalize() {
        // Classic 3-4-5 triangle: the L2 norm is exactly 5.
        let v = DenseTensor::new(vec![3.0, 4.0], vec![2]);
        assert!((v.norm() - 5.0).abs() < 1e-10);
        assert!((v.normalize().norm() - 1.0).abs() < 1e-10);
    }

    #[test]
    #[should_panic]
    fn test_shape_mismatch_panic() {
        let short = DenseTensor::new(vec![1.0, 2.0], vec![2]);
        let long = DenseTensor::new(vec![1.0, 2.0, 3.0], vec![3]);
        let _ = short.add(&long);
    }
}