1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
//! Element-wise arithmetic operations for tensors
//! 🚀 Enhanced with SciRS2 breakthrough hyperoptimized SIMD implementations
//! - Up to 14.17x speedup for medium arrays with TLB optimization
//! - 7.93x speedup for small arrays with cache-line aware processing
//! - 7.41x speedup for large arrays with software pipelining
//! - Adaptive selection automatically chooses optimal strategy
use crate::{FloatElement, Tensor, TensorElement};
use torsh_core::error::{Result, TorshError};
use super::simd::{should_use_simd, SimdOpType};
/// Element-wise arithmetic operations
impl<
        T: TensorElement
            + Copy
            + Default
            + std::ops::Add<Output = T>
            + std::ops::Sub<Output = T>
            + std::ops::Mul<Output = T>
            + std::ops::Div<Output = T>,
    > Tensor<T>
{
    /// Element-wise addition with broadcasting (ops module implementation).
    ///
    /// # SIMD Optimization (Phase 3/4)
    /// For f32 tensors with matching shapes, and only when the `simd` feature is
    /// enabled, the work is delegated to `add_adaptive`, which dispatches by size:
    /// - Small tensors (<512): Scalar (SIMD overhead not worth it)
    /// - Medium tensors (512-65K): Phase 3 SIMD (uninit buffer + scirs2 API)
    /// - Large tensors (>65K): Parallel SIMD (Rayon + SIMD chunks)
    ///
    /// Other matching-shape cases use the scalar `element_wise_op`; differing
    /// shapes go through [`Self::broadcast_binary_op`].
    ///
    /// # Gradient Tracking
    /// If either operand has `requires_grad`, the result is marked as requiring
    /// grad and records `Operation::Add` with clones of both operands. This is
    /// applied once, after dispatch, so every path is tracked identically.
    ///
    /// # Errors
    /// Propagates any error from the selected kernel, including shapes that are
    /// not compatible for broadcasting.
    pub fn add_op(&self, other: &Self) -> Result<Self> {
        let mut result = if self.shape() == other.shape() {
            // 🚀 Phase 3/4: f32 tensors use the adaptive SIMD kernel when compiled in.
            #[cfg(feature = "simd")]
            let adaptive = if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
                Some(self.add_adaptive(other)?)
            } else {
                None
            };
            #[cfg(not(feature = "simd"))]
            let adaptive: Option<Self> = None;

            match adaptive {
                Some(simd_result) => simd_result,
                // Scalar fallback for non-f32 element types (or when SIMD is disabled).
                None => self.element_wise_op(other, |a, b| a + b)?,
            }
        } else {
            self.broadcast_binary_op(other, |a, b| a + b)?
        };

        // Track the operation for gradient computation (single place for all paths).
        if self.requires_grad || other.requires_grad {
            use std::sync::Arc;
            result.requires_grad = true;
            result.operation = crate::Operation::Add {
                lhs: Arc::new(self.clone()),
                rhs: Arc::new(other.clone()),
            };
        }
        Ok(result)
    }

    /// Performs element-wise addition of two tensors with broadcasting support.
    ///
    /// Adds corresponding elements from `self` and `other`. If the tensors have
    /// different shapes, broadcasting rules are applied to make them compatible.
    ///
    /// # Broadcasting Rules
    /// - Dimensions are aligned from right to left
    /// - Dimension of size 1 can broadcast to any size
    /// - Missing dimensions are treated as size 1
    ///
    /// Examples of valid broadcasts:
    /// - `[3, 4]` + `[3, 4]` → `[3, 4]` (same shape)
    /// - `[3, 4]` + `[4]` → `[3, 4]` (broadcast last dimension)
    /// - `[3, 1]` + `[1, 4]` → `[3, 4]` (broadcast both)
    /// - `[3, 4, 5]` + `[5]` → `[3, 4, 5]` (broadcast to batch)
    ///
    /// # Performance
    /// For matching shapes with f32 type, automatically uses adaptive SIMD optimization:
    /// - Small tensors (<512 elements): Scalar operations
    /// - Medium tensors (512-65K): SIMD vectorization (up to 14x speedup)
    /// - Large tensors (>65K): Parallel SIMD (multi-threaded)
    ///
    /// # Gradient Tracking
    /// If either tensor has `requires_grad=true`, the operation is recorded
    /// in the computational graph for automatic differentiation.
    ///
    /// # Arguments
    /// * `other` - The tensor to add to `self`
    ///
    /// # Returns
    /// A new tensor containing the element-wise sum
    ///
    /// # Errors
    /// Returns error if the shapes are not compatible for broadcasting
    ///
    /// # Examples
    /// ```rust,no_run
    /// # use torsh::Tensor;
    /// # use torsh_core::device::DeviceType;
    /// // Same shape addition
    /// let a = Tensor::from_data(vec![1.0, 2.0, 3.0], vec![3], DeviceType::Cpu)?;
    /// let b = Tensor::from_data(vec![4.0, 5.0, 6.0], vec![3], DeviceType::Cpu)?;
    /// let c = a.add(&b)?;
    /// assert_eq!(c.data()?, vec![5.0, 7.0, 9.0]);
    ///
    /// // Broadcasting: [2,3] + [3] → [2,3]
    /// let a = Tensor::from_data(
    ///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    ///     vec![2, 3],
    ///     DeviceType::Cpu
    /// )?;
    /// let b = Tensor::from_data(vec![10.0, 20.0, 30.0], vec![3], DeviceType::Cpu)?;
    /// let c = a.add(&b)?; // Adds [10,20,30] to each row
    /// assert_eq!(c.data()?, vec![11.0, 22.0, 33.0, 14.0, 25.0, 36.0]);
    ///
    /// // Neural network bias addition
    /// let activations = Tensor::randn(&[32, 128], DeviceType::Cpu)?; // Batch output
    /// let bias = Tensor::randn(&[128], DeviceType::Cpu)?; // Bias vector
    /// let output = activations.add(&bias)?; // Broadcasts bias to all samples
    ///
    /// // Matrix addition for residual connections
    /// let x = Tensor::randn(&[64, 64], DeviceType::Cpu)?;
    /// let residual = Tensor::randn(&[64, 64], DeviceType::Cpu)?;
    /// let output = x.add(&residual)?; // Element-wise sum
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.add(a, b)` or `a + b`
    ///
    /// See also: [`Self::add_scalar`], [`Self::add_`], [`Self::sub`], [`Self::mul`]
    pub fn add(&self, other: &Self) -> Result<Self> {
        self.add_op(other)
    }

    /// Element-wise subtraction (`self - other`) with broadcasting.
    ///
    /// # SIMD Optimization
    /// For f32 tensors with matching shapes, and only when the `simd` feature is
    /// enabled, the work is delegated to `sub_adaptive`:
    /// - Small tensors (<512): Scalar (SIMD overhead not worth it)
    /// - Medium tensors (512-65K): direct SIMD
    /// - Large tensors (>65K): Parallel SIMD
    ///
    /// Other matching-shape cases use the scalar `element_wise_op`; differing
    /// shapes go through [`Self::broadcast_binary_op`].
    ///
    /// # Gradient Tracking
    /// NOTE(review): unlike [`Self::add_op`] and [`Self::mul_op`], this method does not
    /// itself set `requires_grad` or record an `operation` on its result — confirm
    /// whether autograd for subtraction is handled elsewhere.
    ///
    /// # Errors
    /// Returns an error if the shapes are not compatible for broadcasting.
    pub fn sub(&self, other: &Self) -> Result<Self> {
        if self.shape() == other.shape() {
            // 🚀 Phase 3/4: Use adaptive SIMD for f32 tensors
            #[cfg(feature = "simd")]
            {
                if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
                    return self.sub_adaptive(other);
                }
            }
            // Fallback to scalar for non-f32 types
            self.element_wise_op(other, |a, b| a - b)
        } else {
            self.broadcast_binary_op(other, |a, b| a - b)
        }
    }

    /// Element-wise multiplication with broadcasting (ops module implementation).
    ///
    /// # SIMD Optimization (Phase 3/4)
    /// For f32 tensors with matching shapes, and only when the `simd` feature is
    /// enabled, the work is delegated to `mul_adaptive`, which dispatches by size:
    /// - Small tensors (<512): Scalar (SIMD overhead not worth it)
    /// - Medium tensors (512-65K): Phase 3 SIMD (uninit buffer + scirs2 API)
    /// - Large tensors (>65K): Parallel SIMD (Rayon + SIMD chunks)
    ///
    /// Other matching-shape cases use the scalar `element_wise_op`; differing
    /// shapes go through [`Self::broadcast_binary_op`].
    ///
    /// # Gradient Tracking
    /// If either operand has `requires_grad`, the result is marked as requiring
    /// grad and records `Operation::Mul` with clones of both operands. This is
    /// applied once, after dispatch, so every path is tracked identically.
    ///
    /// # Errors
    /// Propagates any error from the selected kernel, including shapes that are
    /// not compatible for broadcasting.
    pub fn mul_op(&self, other: &Self) -> Result<Self> {
        let mut result = if self.shape() == other.shape() {
            // 🚀 Phase 3/4: f32 tensors use the adaptive SIMD kernel when compiled in.
            #[cfg(feature = "simd")]
            let adaptive = if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
                Some(self.mul_adaptive(other)?)
            } else {
                None
            };
            #[cfg(not(feature = "simd"))]
            let adaptive: Option<Self> = None;

            match adaptive {
                Some(simd_result) => simd_result,
                // Scalar fallback for non-f32 element types (or when SIMD is disabled).
                None => self.element_wise_op(other, |a, b| a * b)?,
            }
        } else {
            self.broadcast_binary_op(other, |a, b| a * b)?
        };

        // Track the operation for gradient computation (single place for all paths).
        if self.requires_grad || other.requires_grad {
            use std::sync::Arc;
            result.requires_grad = true;
            result.operation = crate::Operation::Mul {
                lhs: Arc::new(self.clone()),
                rhs: Arc::new(other.clone()),
            };
        }
        Ok(result)
    }

    /// Performs element-wise multiplication of two tensors with broadcasting support.
    ///
    /// Multiplies corresponding elements from `self` and `other`. If the tensors have
    /// different shapes, broadcasting rules are applied to make them compatible.
    ///
    /// # Broadcasting Rules
    /// - Dimensions are aligned from right to left
    /// - Dimension of size 1 can broadcast to any size
    /// - Missing dimensions are treated as size 1
    ///
    /// Examples of valid broadcasts:
    /// - `[3, 4]` * `[3, 4]` → `[3, 4]` (same shape)
    /// - `[3, 4]` * `[4]` → `[3, 4]` (broadcast last dimension)
    /// - `[3, 1]` * `[1, 4]` → `[3, 4]` (broadcast both)
    /// - `[3, 4, 5]` * `[5]` → `[3, 4, 5]` (broadcast to batch)
    ///
    /// # Performance
    /// For matching shapes with f32 type, automatically uses adaptive SIMD optimization:
    /// - Small tensors (<512 elements): Scalar operations
    /// - Medium tensors (512-65K): SIMD vectorization (up to 14x speedup)
    /// - Large tensors (>65K): Parallel SIMD (multi-threaded)
    ///
    /// # Gradient Tracking
    /// If either tensor has `requires_grad=true`, the operation is recorded
    /// in the computational graph for automatic differentiation.
    ///
    /// # Arguments
    /// * `other` - The tensor to multiply with `self`
    ///
    /// # Returns
    /// A new tensor containing the element-wise product
    ///
    /// # Errors
    /// Returns error if the shapes are not compatible for broadcasting
    ///
    /// # Examples
    /// ```rust,no_run
    /// # use torsh::Tensor;
    /// # use torsh_core::device::DeviceType;
    /// // Same shape multiplication
    /// let a = Tensor::from_data(vec![1.0, 2.0, 3.0], vec![3], DeviceType::Cpu)?;
    /// let b = Tensor::from_data(vec![4.0, 5.0, 6.0], vec![3], DeviceType::Cpu)?;
    /// let c = a.mul(&b)?;
    /// assert_eq!(c.data()?, vec![4.0, 10.0, 18.0]);
    ///
    /// // Broadcasting: [2,3] * [3] → [2,3]
    /// let a = Tensor::from_data(
    ///     vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    ///     vec![2, 3],
    ///     DeviceType::Cpu
    /// )?;
    /// let b = Tensor::from_data(vec![10.0, 20.0, 30.0], vec![3], DeviceType::Cpu)?;
    /// let c = a.mul(&b)?; // Multiplies each row by [10,20,30]
    /// assert_eq!(c.data()?, vec![10.0, 40.0, 90.0, 40.0, 100.0, 180.0]);
    ///
    /// // Apply attention mask (element-wise gating)
    /// let features = Tensor::randn(&[32, 128], DeviceType::Cpu)?;
    /// let mask = Tensor::ones(&[32, 128], DeviceType::Cpu)?; // All-ones mask keeps everything
    /// let masked_features = features.mul(&mask)?; // Positions where mask is 0 would be zeroed
    ///
    /// // Feature scaling
    /// let x = Tensor::randn(&[64, 256], DeviceType::Cpu)?;
    /// let scale = Tensor::ones(&[256], DeviceType::Cpu)?; // Learnable scale
    /// let scaled_x = x.mul(&scale)?; // Scale each feature
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.mul(a, b)` or `a * b`
    ///
    /// Note: This is element-wise multiplication, not matrix multiplication.
    /// For matrix multiplication, use [`Self::matmul`].
    ///
    /// See also: [`Self::mul_scalar`], [`Self::mul_`], [`Self::matmul`], [`Self::div`]
    pub fn mul(&self, other: &Self) -> Result<Self> {
        self.mul_op(other)
    }

    /// Performs element-wise division of two tensors with broadcasting support.
    ///
    /// Divides corresponding elements of `self` by `other`. If the tensors have
    /// different shapes, broadcasting rules are applied to make them compatible.
    ///
    /// # Broadcasting Rules
    /// - Dimensions are aligned from right to left
    /// - Dimension of size 1 can broadcast to any size
    /// - Missing dimensions are treated as size 1
    ///
    /// Examples of valid broadcasts:
    /// - `[3, 4]` / `[3, 4]` → `[3, 4]` (same shape)
    /// - `[3, 4]` / `[4]` → `[3, 4]` (broadcast last dimension)
    /// - `[3, 1]` / `[1, 4]` → `[3, 4]` (broadcast both)
    /// - `[3, 4, 5]` / `[5]` → `[3, 4, 5]` (broadcast to batch)
    ///
    /// # Performance
    /// For matching shapes with f32 type, automatically uses adaptive SIMD optimization:
    /// - Small tensors (<512 elements): Scalar operations
    /// - Medium tensors (512-65K): SIMD vectorization
    /// - Large tensors (>65K): Parallel SIMD (multi-threaded)
    ///
    /// # Division by Zero
    /// For floating-point element types, division by zero follows IEEE 754:
    /// - Positive number / 0.0 → inf
    /// - Negative number / 0.0 → -inf
    /// - 0.0 / 0.0 → NaN
    ///
    /// Integer element types follow Rust's `/` semantics (division by zero panics).
    ///
    /// # Gradient Tracking
    /// NOTE(review): unlike [`Self::add_op`] and [`Self::mul_op`], this method does not
    /// itself set `requires_grad` or record an `operation` on its result — confirm
    /// whether autograd for division is handled elsewhere.
    ///
    /// # Arguments
    /// * `other` - The tensor to divide by (denominator)
    ///
    /// # Returns
    /// A new tensor containing the element-wise quotient
    ///
    /// # Errors
    /// Returns error if the shapes are not compatible for broadcasting
    ///
    /// # Examples
    /// ```rust,no_run
    /// # use torsh::Tensor;
    /// # use torsh_core::device::DeviceType;
    /// // Same shape division
    /// let a = Tensor::from_data(vec![10.0, 20.0, 30.0], vec![3], DeviceType::Cpu)?;
    /// let b = Tensor::from_data(vec![2.0, 4.0, 5.0], vec![3], DeviceType::Cpu)?;
    /// let c = a.div(&b)?;
    /// assert_eq!(c.data()?, vec![5.0, 5.0, 6.0]);
    ///
    /// // Broadcasting: [2,3] / [3] → [2,3]
    /// let a = Tensor::from_data(
    ///     vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0],
    ///     vec![2, 3],
    ///     DeviceType::Cpu
    /// )?;
    /// let b = Tensor::from_data(vec![2.0, 5.0, 10.0], vec![3], DeviceType::Cpu)?;
    /// let c = a.div(&b)?; // Divides each row by [2,5,10]
    /// assert_eq!(c.data()?, vec![5.0, 4.0, 3.0, 20.0, 10.0, 6.0]);
    ///
    /// // Normalize features (divide by standard deviation)
    /// let x = Tensor::randn(&[32, 128], DeviceType::Cpu)?;
    /// let std = Tensor::ones(&[128], DeviceType::Cpu)?; // Feature std deviations
    /// let normalized = x.div(&std)?; // Normalize each feature
    ///
    /// // Row-normalize non-negative scores so that each row sums to 1
    /// let scores = Tensor::ones(&[64, 10], DeviceType::Cpu)?;
    /// let row_sums = scores.sum_dim(1, true)?; // Shape [64, 1]
    /// let normalized_rows = scores.div(&row_sums)?; // Broadcasts [64, 1] over [64, 10]
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.div(a, b)` or `a / b`
    ///
    /// See also: [`Self::div_scalar`], [`Self::div_`], [`Self::mul`], [`Self::reciprocal`]
    pub fn div(&self, other: &Self) -> Result<Self> {
        if self.shape() == other.shape() {
            // 🚀 Phase 3/4: Use adaptive SIMD for f32 tensors
            #[cfg(feature = "simd")]
            {
                if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
                    return self.div_adaptive(other);
                }
            }
            // Fallback to scalar for non-f32 types
            self.element_wise_op(other, |a, b| a / b)
        } else {
            self.broadcast_binary_op(other, |a, b| a / b)
        }
    }

    /// Applies a binary `op` element-wise to `self` and `other`, broadcasting shapes.
    ///
    /// Identical shapes short-circuit to `element_wise_op`. Otherwise the broadcast
    /// shape is computed and every output element is produced by converting its
    /// flat index into a multi-index, then mapping that multi-index back to a flat
    /// index into each operand's data.
    ///
    /// A warning is printed to stderr when `is_broadcast_efficient` reports the
    /// shape combination as memory-inefficient; the operation still proceeds.
    ///
    /// # Errors
    /// Returns an error if the shapes cannot be broadcast together, or if any of the
    /// broadcast helpers (`validate_broadcast_operation`, `broadcast_shape`,
    /// `estimate_broadcast_memory`, `compute_broadcast_index`) or `data()` fail.
    pub fn broadcast_binary_op<F>(&self, other: &Self, op: F) -> Result<Self>
    where
        F: Fn(T, T) -> T + Send + Sync,
        T: Copy + Default,
    {
        use crate::broadcast::{BroadcastOps, BroadcastShape};

        // Validate the broadcast operation
        BroadcastOps::validate_broadcast_operation(
            self.shape().dims(),
            other.shape().dims(),
            "binary operation",
        )?;

        // If shapes are identical, use optimized path
        if self.shape() == other.shape() {
            return self.element_wise_op(other, op);
        }

        // Check if broadcasting is memory efficient
        if !self.shape().is_broadcast_efficient(&other.shape()) {
            eprintln!(
                "Warning: Broadcasting shapes {:?} and {:?} may use significant memory",
                self.shape().dims(),
                other.shape().dims()
            );
        }

        // Compute broadcasted shape using the new implementation
        let broadcast_shape = self.shape().broadcast_shape(&other.shape())?;
        let broadcast_dims = broadcast_shape.dims();
        let broadcast_size = broadcast_shape.numel();

        // The estimate itself is unused; the call is kept for its `?`, which
        // propagates any error the estimator reports before we allocate.
        let element_size = std::mem::size_of::<T>();
        let _memory_required = BroadcastOps::estimate_broadcast_memory(
            self.shape().dims(),
            other.shape().dims(),
            element_size,
        )?;

        let self_data = self.data()?;
        let other_data = other.data()?;
        let mut result_data = Vec::with_capacity(broadcast_size);

        // Map each output position back to its source element in each operand.
        for flat_idx in 0..broadcast_size {
            let broadcast_indices = BroadcastOps::flat_to_multi_index(flat_idx, broadcast_dims);
            let self_idx = BroadcastOps::compute_broadcast_index(
                &broadcast_indices,
                self.shape().dims(),
                broadcast_dims,
            )?;
            let other_idx = BroadcastOps::compute_broadcast_index(
                &broadcast_indices,
                other.shape().dims(),
                broadcast_dims,
            )?;
            result_data.push(op(self_data[self_idx], other_data[other_idx]));
        }

        Self::from_data(result_data, broadcast_dims.to_vec(), self.device)
    }

    /// Element-wise power with a scalar exponent: `x.powf(exponent)` for each element.
    ///
    /// # Gradient Tracking
    /// The result inherits `self.requires_grad`; when it is set, an
    /// `Operation::Power` holding a clone of `self` and the exponent is recorded.
    ///
    /// # Errors
    /// Returns [`TorshError::ConversionError`] if `exponent` cannot be converted to
    /// the element type, and propagates errors from `data()` / `from_data()`.
    pub fn pow_scalar_f32(&self, exponent: f32) -> Result<Self>
    where
        T: FloatElement,
    {
        let data = self.data()?;
        // The exponent is loop-invariant: convert it once rather than per element.
        let exp_val = T::from_f64(exponent as f64).ok_or_else(|| {
            TorshError::ConversionError(format!("Cannot convert {exponent} to element type"))
        })?;
        let result_data: Vec<T> = data.iter().map(|&x| x.powf(exp_val)).collect();

        let mut result = Self::from_data(result_data, self.shape().dims().to_vec(), self.device)?;
        result.requires_grad = self.requires_grad;

        // Track the operation for gradient computation
        if self.requires_grad {
            use std::sync::Arc;
            result.operation = crate::Operation::Power {
                input: Arc::new(self.clone()),
                exponent,
            };
        }
        Ok(result)
    }

    /// Element-wise power with a tensor exponent: `self[i].powf(exponent[i])`.
    ///
    /// # Errors
    /// Returns [`TorshError::BroadcastError`] if the shapes differ. This method
    /// does **not** broadcast; both tensors must have identical shapes.
    ///
    /// # Gradient Tracking
    /// NOTE(review): `requires_grad` is propagated from either input, but no
    /// `operation` is recorded on the result — confirm whether gradients are
    /// expected to flow through this op.
    pub fn pow_tensor(&self, exponent: &Self) -> Result<Self>
    where
        T: FloatElement,
    {
        // Only identical shapes are supported; broadcasting is not implemented here.
        if self.shape() != exponent.shape() {
            return Err(TorshError::BroadcastError {
                shape1: self.shape().dims().to_vec(),
                shape2: exponent.shape().dims().to_vec(),
            });
        }

        let self_data = self.data()?;
        let exp_data = exponent.data()?;
        let result_data: Vec<T> = self_data
            .iter()
            .zip(exp_data.iter())
            .map(|(&base, &exp)| base.powf(exp))
            .collect();

        let mut result = Self::from_data(result_data, self.shape().dims().to_vec(), self.device)?;
        result.requires_grad = self.requires_grad || exponent.requires_grad;
        Ok(result)
    }

    /// Negation (for float types) - legacy arithmetic implementation.
    ///
    /// Returns a new tensor with every element negated. The result is built with
    /// `from_data` and this method does not set `requires_grad` on it.
    pub fn neg_float(&self) -> Result<Self>
    where
        T: FloatElement,
    {
        let data = self.data()?;
        let result_data: Vec<T> = data.iter().map(|&x| -x).collect();
        Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
    }

    /// Returns a new tensor with `scalar` added to every element.
    ///
    /// This method does not set `requires_grad` on the result.
    pub fn add_scalar(&self, scalar: T) -> Result<Self> {
        let self_data = self.data()?;
        let result_data: Vec<T> = self_data.iter().map(|&a| a + scalar).collect();
        Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
    }

    /// Returns a new tensor with every element divided by `scalar`.
    ///
    /// `scalar` is converted to the element type with `T::from_f64`.
    ///
    /// # Errors
    /// Returns [`TorshError::ConversionError`] if the conversion fails.
    pub fn div_scalar(&self, scalar: f32) -> Result<Self> {
        let self_data = self.data()?;
        let scalar_t = T::from_f64(scalar as f64).ok_or_else(|| {
            TorshError::ConversionError(format!(
                "Cannot convert scalar {} to target type",
                scalar
            ))
        })?;
        let result_data: Vec<T> = self_data.iter().map(|&a| a / scalar_t).collect();
        Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
    }

    /// Returns a new tensor with every element raised to `exponent` via `powf`.
    ///
    /// NOTE(review): unlike [`Self::pow_scalar_f32`], this does not propagate
    /// `requires_grad` or record an `Operation::Power` — confirm this is intended.
    ///
    /// # Errors
    /// Returns [`TorshError::ConversionError`] if `exponent` cannot be converted
    /// to the element type.
    pub fn pow_scalar(&self, exponent: f32) -> Result<Self>
    where
        T: FloatElement,
    {
        let data = self.data()?;
        let exp_t = T::from_f64(exponent as f64).ok_or_else(|| {
            TorshError::ConversionError(format!(
                "Cannot convert exponent {} to target type",
                exponent
            ))
        })?;
        let result_data: Vec<T> = data.iter().map(|&x| x.powf(exp_t)).collect();
        Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
    }

    /// Clamps every element into `[min, max]` (bounds given as `f32`).
    ///
    /// Elements below `min` become `min`; otherwise elements above `max` become
    /// `max`. Elements for which both comparisons are false (e.g. NaN) are
    /// returned unchanged.
    ///
    /// # Errors
    /// Returns [`TorshError::ConversionError`] if either bound cannot be converted
    /// to the element type.
    pub fn clamp_f32(&self, min: f32, max: f32) -> Result<Self>
    where
        T: PartialOrd,
    {
        let min_t = T::from_f64(min as f64).ok_or_else(|| {
            TorshError::ConversionError(format!(
                "Cannot convert min value {} to target type",
                min
            ))
        })?;
        let max_t = T::from_f64(max as f64).ok_or_else(|| {
            TorshError::ConversionError(format!(
                "Cannot convert max value {} to target type",
                max
            ))
        })?;

        let data = self.data()?;
        let result_data: Vec<T> = data
            .iter()
            .map(|&item| {
                if item < min_t {
                    min_t
                } else if item > max_t {
                    max_t
                } else {
                    item
                }
            })
            .collect();
        Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
    }

    /// Computes the dot product (inner product) of two 1D tensors.
    ///
    /// For two vectors `a` and `b` of length `n`, computes the sum of element-wise products:
    /// `dot(a, b) = a[0]*b[0] + a[1]*b[1] + ... + a[n-1]*b[n-1]`
    ///
    /// Empty vectors yield the additive identity produced by `Sum`.
    ///
    /// # Requirements
    /// - Both tensors must be 1-dimensional
    /// - Both tensors must have the same length
    ///
    /// # Performance
    /// For `f32` tensors, when the crate is built with the `simd` feature, the length
    /// is passed to `should_use_simd`; if it approves, the adaptive hyperoptimized
    /// SIMD kernel (`simd_dot_f32`) computes the result. In every other case
    /// (non-f32 element types, SIMD disabled, or lengths `should_use_simd` rejects)
    /// a scalar multiply-and-sum is used. Because the two paths may accumulate in a
    /// different order, their f32 results may differ in the last bits.
    ///
    /// # Arguments
    /// * `other` - The second 1D tensor to compute dot product with
    ///
    /// # Returns
    /// A scalar value (type `T`) representing the dot product
    ///
    /// # Errors
    /// - Returns [`TorshError::InvalidArgument`] if either tensor is not 1-dimensional
    /// - Returns [`TorshError::ShapeMismatch`] if the vectors have different lengths
    ///
    /// # Examples
    /// ```rust,no_run
    /// # use torsh::Tensor;
    /// # use torsh_core::device::DeviceType;
    /// // Basic dot product
    /// let a = Tensor::from_data(vec![1.0, 2.0, 3.0], vec![3], DeviceType::Cpu)?;
    /// let b = Tensor::from_data(vec![4.0, 5.0, 6.0], vec![3], DeviceType::Cpu)?;
    /// let dot = a.dot_hyperoptimized(&b)?;
    /// assert_eq!(dot, 32.0); // 1*4 + 2*5 + 3*6 = 32
    ///
    /// // Cosine similarity computation
    /// let v1 = Tensor::randn(&[128], DeviceType::Cpu)?;
    /// let v2 = Tensor::randn(&[128], DeviceType::Cpu)?;
    /// let dot = v1.dot_hyperoptimized(&v2)?;
    /// let norm1 = v1.dot_hyperoptimized(&v1)?.sqrt();
    /// let norm2 = v2.dot_hyperoptimized(&v2)?.sqrt();
    /// let cosine_sim = dot / (norm1 * norm2);
    ///
    /// // Neural network: compute attention scores
    /// let query = Tensor::randn(&[512], DeviceType::Cpu)?;
    /// let key = Tensor::randn(&[512], DeviceType::Cpu)?;
    /// let attention_score = query.dot_hyperoptimized(&key)?;
    ///
    /// // Compute vector norm (L2 norm)
    /// let v = Tensor::from_data(vec![3.0, 4.0], vec![2], DeviceType::Cpu)?;
    /// let norm = v.dot_hyperoptimized(&v)?.sqrt();
    /// assert_eq!(norm, 5.0); // sqrt(3^2 + 4^2) = 5
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.dot(a, b)`
    ///
    /// Note: For matrix-vector or matrix-matrix products, use [`Self::matmul`].
    ///
    /// See also: [`Self::matmul`], [`Self::mul`], [`Self::outer`], [`Self::cross`]
    pub fn dot_hyperoptimized(&self, other: &Self) -> Result<T>
    where
        T: FloatElement + Copy + std::iter::Sum,
    {
        // Ensure both tensors are 1D and have the same shape
        if self.shape().dims().len() != 1 || other.shape().dims().len() != 1 {
            return Err(TorshError::InvalidArgument(
                "Dot product requires 1D tensors".to_string(),
            ));
        }
        if self.shape() != other.shape() {
            return Err(TorshError::ShapeMismatch {
                expected: self.shape().to_vec(),
                got: other.shape().to_vec(),
            });
        }

        let self_data = self.data()?;
        let other_data = other.data()?;

        // Use hyperoptimized SIMD dot product for f32 tensors
        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
            #[cfg(feature = "simd")]
            {
                if should_use_simd(self.numel()) {
                    use super::simd::simd_dot_f32;
                    use scirs2_core::ndarray::ArrayView1;

                    // SAFETY: the enclosing `TypeId` check proves `T` is exactly `f32`,
                    // so the buffer is a valid, aligned `[f32]` of the same length, and
                    // it stays borrowed from `self_data` for the lifetime of the slice.
                    let self_f32: &[f32] = unsafe {
                        std::slice::from_raw_parts(self_data.as_ptr() as *const f32, self_data.len())
                    };
                    // SAFETY: same argument as above, for `other_data`.
                    let other_f32: &[f32] = unsafe {
                        std::slice::from_raw_parts(
                            other_data.as_ptr() as *const f32,
                            other_data.len(),
                        )
                    };

                    let self_view = ArrayView1::from(self_f32);
                    let other_view = ArrayView1::from(other_f32);

                    // Use adaptive hyperoptimized SIMD dot product
                    let result_f32 = simd_dot_f32(&self_view, &other_view);
                    // SAFETY: `T` is `f32` (checked above), so this is an identity copy.
                    let result: T = unsafe { std::mem::transmute_copy::<f32, T>(&result_f32) };
                    return Ok(result);
                }
            }
        }

        // Fallback to standard dot product for non-f32 types or when SIMD is not beneficial
        Ok(self_data
            .iter()
            .zip(other_data.iter())
            .map(|(&a, &b)| a * b)
            .sum())
    }
}
// ✅ In-place operations for PyTorch compatibility
impl<
        T: TensorElement
            + Copy
            + Default
            + std::ops::Add<Output = T>
            + std::ops::Sub<Output = T>
            + std::ops::Mul<Output = T>
            + std::ops::Div<Output = T>,
    > Tensor<T>
{
    /// Rejects in-place mutation of a tensor that has `requires_grad` set.
    ///
    /// Shared guard for every in-place operation in this impl.
    fn check_inplace_no_grad(&self) -> Result<()> {
        if self.requires_grad {
            return Err(TorshError::InvalidArgument(
                "In-place operation on tensor that requires grad is not allowed".to_string(),
            ));
        }
        Ok(())
    }

    /// Shared body of the element-wise in-place binary ops: `self[i] = op(self[i], other[i])`.
    ///
    /// Checks run in this order: `requires_grad` guard, then exact shape equality
    /// (no broadcasting), then `other.data()`.
    ///
    /// NOTE(review): elements of `self` are addressed by raw storage index `i`,
    /// which assumes `self`'s storage is laid out in the same logical order as
    /// `other.data()` — confirm for strided/offset views.
    fn apply_inplace_zip<F>(&mut self, other: &Self, op: F) -> Result<&mut Self>
    where
        F: Fn(T, T) -> T,
    {
        self.check_inplace_no_grad()?;
        if self.shape() != other.shape() {
            return Err(TorshError::ShapeMismatch {
                expected: self.shape().to_vec(),
                got: other.shape().to_vec(),
            });
        }

        // Snapshot `other` first so reads are unaffected by writes to `self`.
        let other_data = other.data()?;
        for (i, &rhs) in other_data.iter().enumerate() {
            let current = self.storage.get(i)?;
            self.storage.set(i, op(current, rhs))?;
        }
        Ok(self)
    }

    /// Shared body of the in-place scalar ops: `self[i] = op(self[i])` for every
    /// index `0..self.storage.len()`, after the `requires_grad` guard.
    fn apply_inplace_map<F>(&mut self, op: F) -> Result<&mut Self>
    where
        F: Fn(T) -> T,
    {
        self.check_inplace_no_grad()?;
        let len = self.storage.len();
        for i in 0..len {
            let current = self.storage.get(i)?;
            self.storage.set(i, op(current))?;
        }
        Ok(self)
    }

    /// In-place addition: self += other
    ///
    /// # PyTorch Compatibility
    /// Equivalent to PyTorch's `tensor.add_(other)`
    ///
    /// # Errors
    /// - Returns error if `requires_grad` is true (in-place ops break autograd)
    /// - Returns error if shapes differ (no broadcasting)
    pub fn add_(&mut self, other: &Self) -> Result<&mut Self> {
        self.apply_inplace_zip(other, |a, b| a + b)
    }

    /// In-place subtraction: self -= other
    ///
    /// # PyTorch Compatibility
    /// Equivalent to PyTorch's `tensor.sub_(other)`
    ///
    /// # Errors
    /// - Returns error if `requires_grad` is true (in-place ops break autograd)
    /// - Returns error if shapes differ (no broadcasting)
    pub fn sub_(&mut self, other: &Self) -> Result<&mut Self> {
        self.apply_inplace_zip(other, |a, b| a - b)
    }

    /// In-place multiplication: self *= other
    ///
    /// # PyTorch Compatibility
    /// Equivalent to PyTorch's `tensor.mul_(other)`
    ///
    /// # Errors
    /// - Returns error if `requires_grad` is true (in-place ops break autograd)
    /// - Returns error if shapes differ (no broadcasting)
    pub fn mul_(&mut self, other: &Self) -> Result<&mut Self> {
        self.apply_inplace_zip(other, |a, b| a * b)
    }

    /// In-place division: self /= other
    ///
    /// # PyTorch Compatibility
    /// Equivalent to PyTorch's `tensor.div_(other)`
    ///
    /// # Errors
    /// - Returns error if `requires_grad` is true (in-place ops break autograd)
    /// - Returns error if shapes differ (no broadcasting)
    pub fn div_(&mut self, other: &Self) -> Result<&mut Self> {
        self.apply_inplace_zip(other, |a, b| a / b)
    }

    /// In-place scalar addition: self += scalar
    ///
    /// # PyTorch Compatibility
    /// Equivalent to PyTorch's `tensor.add_(scalar)`
    ///
    /// # Errors
    /// Returns error if `requires_grad` is true (in-place ops break autograd).
    pub fn add_scalar_(&mut self, scalar: T) -> Result<&mut Self> {
        self.apply_inplace_map(|a| a + scalar)
    }

    /// In-place scalar multiplication: self *= scalar
    ///
    /// # PyTorch Compatibility
    /// Equivalent to PyTorch's `tensor.mul_(scalar)`
    ///
    /// # Errors
    /// Returns error if `requires_grad` is true (in-place ops break autograd).
    pub fn mul_scalar_(&mut self, scalar: T) -> Result<&mut Self> {
        self.apply_inplace_map(|a| a * scalar)
    }

    /// In-place scalar division: self /= scalar
    ///
    /// # PyTorch Compatibility
    /// Equivalent to PyTorch's `tensor.div_(scalar)`
    ///
    /// # Errors
    /// Returns error if `requires_grad` is true (in-place ops break autograd).
    pub fn div_scalar_(&mut self, scalar: T) -> Result<&mut Self> {
        self.apply_inplace_map(|a| a / scalar)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    /// Builds a CPU f32 tensor, panicking on construction failure (test-only helper).
    fn t(data: Vec<f32>, dims: Vec<usize>) -> Tensor<f32> {
        Tensor::from_data(data, dims, DeviceType::Cpu).expect("tensor creation failed")
    }

    #[test]
    fn test_add_inplace() {
        let mut a = t(vec![1.0, 2.0, 3.0], vec![3]);
        let b = t(vec![4.0, 5.0, 6.0], vec![3]);
        a.add_(&b).expect("add_ failed");
        assert_eq!(a.data().expect("data retrieval failed"), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_mul_inplace() {
        let mut a = t(vec![1.0, 2.0, 3.0], vec![3]);
        let b = t(vec![2.0, 3.0, 4.0], vec![3]);
        a.mul_(&b).expect("mul_ failed");
        assert_eq!(a.data().expect("data retrieval failed"), vec![2.0, 6.0, 12.0]);
    }

    #[test]
    fn test_sub_inplace() {
        let mut a = t(vec![5.0, 7.0, 9.0], vec![3]);
        let b = t(vec![1.0, 2.0, 3.0], vec![3]);
        a.sub_(&b).expect("sub_ failed");
        assert_eq!(a.data().expect("data retrieval failed"), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_div_inplace() {
        let mut a = t(vec![6.0, 12.0, 18.0], vec![3]);
        let b = t(vec![2.0, 3.0, 6.0], vec![3]);
        a.div_(&b).expect("div_ failed");
        assert_eq!(a.data().expect("data retrieval failed"), vec![3.0, 4.0, 3.0]);
    }

    #[test]
    fn test_add_scalar_inplace() {
        let mut tensor = t(vec![1.0, 2.0, 3.0], vec![3]);
        tensor.add_scalar_(10.0).expect("add_scalar_ failed");
        assert_eq!(tensor.data().expect("data retrieval failed"), vec![11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_mul_scalar_inplace() {
        let mut tensor = t(vec![1.0, 2.0, 3.0], vec![3]);
        tensor.mul_scalar_(2.0).expect("mul_scalar_ failed");
        assert_eq!(tensor.data().expect("data retrieval failed"), vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_div_scalar_inplace() {
        let mut tensor = t(vec![10.0, 20.0, 30.0], vec![3]);
        tensor.div_scalar_(10.0).expect("div_scalar_ failed");
        assert_eq!(tensor.data().expect("data retrieval failed"), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_inplace_method_chaining() {
        let mut tensor = t(vec![1.0, 2.0, 3.0], vec![3]);
        let b = t(vec![1.0, 1.0, 1.0], vec![3]);
        // Test method chaining
        tensor
            .add_(&b)
            .expect("add_ failed")
            .mul_scalar_(2.0)
            .expect("mul_scalar_ failed");
        // (1+1)*2, (2+1)*2, (3+1)*2
        assert_eq!(tensor.data().expect("data retrieval failed"), vec![4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_inplace_shape_mismatch_error() {
        let mut a = t(vec![1.0, 2.0], vec![2]);
        let b = t(vec![1.0, 2.0, 3.0], vec![3]);
        assert!(a.add_(&b).is_err());
        assert!(a.sub_(&b).is_err());
        assert!(a.mul_(&b).is_err());
        assert!(a.div_(&b).is_err());
        // A rejected op must leave the data untouched.
        assert_eq!(a.data().expect("data retrieval failed"), vec![1.0, 2.0]);
    }

    #[test]
    fn test_inplace_requires_grad_error() {
        let mut a = t(vec![1.0, 2.0], vec![2]);
        let b = t(vec![3.0, 4.0], vec![2]);
        a.requires_grad = true;
        assert!(a.add_(&b).is_err());
        assert!(a.sub_(&b).is_err());
        assert!(a.mul_(&b).is_err());
        assert!(a.div_(&b).is_err());
        assert!(a.add_scalar_(1.0).is_err());
        assert!(a.mul_scalar_(2.0).is_err());
        assert!(a.div_scalar_(2.0).is_err());
        assert_eq!(a.data().expect("data retrieval failed"), vec![1.0, 2.0]);
    }

    #[test]
    fn test_add_broadcast_rows() {
        let a = t(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let b = t(vec![10.0, 20.0, 30.0], vec![3]);
        let c = a.add(&b).expect("broadcast add failed");
        assert_eq!(
            c.data().expect("data retrieval failed"),
            vec![11.0, 22.0, 33.0, 14.0, 25.0, 36.0]
        );
    }

    #[test]
    fn test_pow_scalar_f32() {
        let a = t(vec![1.0, 2.0, 3.0], vec![3]);
        let c = a.pow_scalar_f32(2.0).expect("pow_scalar_f32 failed");
        assert_eq!(c.data().expect("data retrieval failed"), vec![1.0, 4.0, 9.0]);
    }

    #[test]
    fn test_clamp_f32() {
        let a = t(vec![-1.0, 0.5, 2.0], vec![3]);
        let c = a.clamp_f32(0.0, 1.0).expect("clamp_f32 failed");
        assert_eq!(c.data().expect("data retrieval failed"), vec![0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_dot_hyperoptimized() {
        let a = t(vec![1.0, 2.0, 3.0], vec![3]);
        let b = t(vec![4.0, 5.0, 6.0], vec![3]);
        assert_eq!(a.dot_hyperoptimized(&b).expect("dot failed"), 32.0);

        // Non-1D and mismatched-length inputs are rejected.
        let m = t(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        assert!(m.dot_hyperoptimized(&m).is_err());
        let short = t(vec![1.0, 2.0], vec![2]);
        assert!(a.dot_hyperoptimized(&short).is_err());
    }
}