1use num_complex::{Complex32, Complex64};
4
5use super::traits::Scalar;
6
7#[cfg(feature = "f16")]
8use half::f16;
9
10#[cfg(feature = "f128")]
11use super::extended::QuadFloat;
12
/// Marker trait for scalar types with a fast hardware fused-multiply-add path.
///
/// Implementations carry no methods; the trait is used purely as a bound so
/// FMA-based kernels can be selected at compile time.
pub trait HasFastFma: Scalar {}

impl HasFastFma for f32 {}
impl HasFastFma for f64 {}
// Complex FMA here means component-wise FMA on the underlying f32/f64 parts.
impl HasFastFma for Complex32 {}
impl HasFastFma for Complex64 {}
28
29pub trait SimdCompatible: Scalar {
34 const SIMD_WIDTH: usize;
36
37 #[inline]
39 fn use_simd_for(len: usize) -> bool {
40 len >= Self::SIMD_WIDTH * 2
41 }
42}
43
44impl SimdCompatible for f32 {
45 #[cfg(target_arch = "x86_64")]
46 const SIMD_WIDTH: usize = 8; #[cfg(target_arch = "aarch64")]
49 const SIMD_WIDTH: usize = 4; #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
52 const SIMD_WIDTH: usize = 4;
53}
54
55impl SimdCompatible for f64 {
56 #[cfg(target_arch = "x86_64")]
57 const SIMD_WIDTH: usize = 4; #[cfg(target_arch = "aarch64")]
60 const SIMD_WIDTH: usize = 2; #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
63 const SIMD_WIDTH: usize = 2;
64}
65
impl SimdCompatible for Complex32 {
    // One Complex32 is two f32 components, so the complex lane count is half
    // the f32 width on each target.
    #[cfg(target_arch = "x86_64")]
    const SIMD_WIDTH: usize = 4;

    #[cfg(target_arch = "aarch64")]
    const SIMD_WIDTH: usize = 2;

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    const SIMD_WIDTH: usize = 2;
}
77
impl SimdCompatible for Complex64 {
    // One Complex64 is two f64 components, so the complex lane count is half
    // the f64 width on each target (down to 1, i.e. effectively scalar).
    #[cfg(target_arch = "x86_64")]
    const SIMD_WIDTH: usize = 2;

    #[cfg(target_arch = "aarch64")]
    const SIMD_WIDTH: usize = 1;

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    const SIMD_WIDTH: usize = 1;
}
88
/// Batched (slice-level) BLAS-1-style kernels for a scalar type.
pub trait ScalarBatch: Scalar + SimdCompatible {
    /// Dot product of `x` and `y` (unconjugated for complex types).
    fn dot_batch(x: &[Self], y: &[Self]) -> Self;

    /// Sum of all elements of `x`.
    fn sum_batch(x: &[Self]) -> Self;

    /// Sum of absolute values (BLAS `asum`; for complex impls below this is
    /// the |re| + |im| convention).
    fn asum_batch(x: &[Self]) -> Self::Real;

    /// Index of the element of maximum absolute value (BLAS `iamax`);
    /// 0 for an empty slice.
    fn iamax_batch(x: &[Self]) -> usize;

    /// In-place scaling: `x[i] *= alpha`.
    fn scale_batch(alpha: Self, x: &mut [Self]);

    /// `y[i] += alpha * x[i]` (BLAS `axpy`).
    fn axpy_batch(alpha: Self, x: &[Self], y: &mut [Self]);

    /// Element-wise fused multiply-add: `out[i] = a[i] * b[i] + c[i]`.
    fn fma_batch(a: &[Self], b: &[Self], c: &[Self], out: &mut [Self]);
}
118
119impl ScalarBatch for f32 {
120 #[inline]
121 fn dot_batch(x: &[Self], y: &[Self]) -> Self {
122 debug_assert_eq!(x.len(), y.len());
123 let mut sum = 0.0f32;
124 for i in 0..x.len() {
125 sum = x[i].mul_add(y[i], sum);
126 }
127 sum
128 }
129
130 #[inline]
131 fn sum_batch(x: &[Self]) -> Self {
132 x.iter().copied().sum()
133 }
134
135 #[inline]
136 fn asum_batch(x: &[Self]) -> Self::Real {
137 x.iter().map(|&v| v.abs()).sum()
138 }
139
140 #[inline]
141 fn iamax_batch(x: &[Self]) -> usize {
142 x.iter()
143 .enumerate()
144 .max_by(|(_, a), (_, b)| {
145 a.abs()
146 .partial_cmp(&b.abs())
147 .unwrap_or(core::cmp::Ordering::Equal)
148 })
149 .map(|(i, _)| i)
150 .unwrap_or(0)
151 }
152
153 #[inline]
154 fn scale_batch(alpha: Self, x: &mut [Self]) {
155 for xi in x.iter_mut() {
156 *xi *= alpha;
157 }
158 }
159
160 #[inline]
161 fn axpy_batch(alpha: Self, x: &[Self], y: &mut [Self]) {
162 debug_assert_eq!(x.len(), y.len());
163 for i in 0..x.len() {
164 y[i] = alpha.mul_add(x[i], y[i]);
165 }
166 }
167
168 #[inline]
169 fn fma_batch(a: &[Self], b: &[Self], c: &[Self], out: &mut [Self]) {
170 debug_assert_eq!(a.len(), b.len());
171 debug_assert_eq!(a.len(), c.len());
172 debug_assert_eq!(a.len(), out.len());
173 for i in 0..a.len() {
174 out[i] = a[i].mul_add(b[i], c[i]);
175 }
176 }
177}
178
179impl ScalarBatch for f64 {
180 #[inline]
181 fn dot_batch(x: &[Self], y: &[Self]) -> Self {
182 debug_assert_eq!(x.len(), y.len());
183 let mut sum = 0.0f64;
184 for i in 0..x.len() {
185 sum = x[i].mul_add(y[i], sum);
186 }
187 sum
188 }
189
190 #[inline]
191 fn sum_batch(x: &[Self]) -> Self {
192 x.iter().copied().sum()
193 }
194
195 #[inline]
196 fn asum_batch(x: &[Self]) -> Self::Real {
197 x.iter().map(|&v| v.abs()).sum()
198 }
199
200 #[inline]
201 fn iamax_batch(x: &[Self]) -> usize {
202 x.iter()
203 .enumerate()
204 .max_by(|(_, a), (_, b)| {
205 a.abs()
206 .partial_cmp(&b.abs())
207 .unwrap_or(core::cmp::Ordering::Equal)
208 })
209 .map(|(i, _)| i)
210 .unwrap_or(0)
211 }
212
213 #[inline]
214 fn scale_batch(alpha: Self, x: &mut [Self]) {
215 for xi in x.iter_mut() {
216 *xi *= alpha;
217 }
218 }
219
220 #[inline]
221 fn axpy_batch(alpha: Self, x: &[Self], y: &mut [Self]) {
222 debug_assert_eq!(x.len(), y.len());
223 for i in 0..x.len() {
224 y[i] = alpha.mul_add(x[i], y[i]);
225 }
226 }
227
228 #[inline]
229 fn fma_batch(a: &[Self], b: &[Self], c: &[Self], out: &mut [Self]) {
230 debug_assert_eq!(a.len(), b.len());
231 debug_assert_eq!(a.len(), c.len());
232 debug_assert_eq!(a.len(), out.len());
233 for i in 0..a.len() {
234 out[i] = a[i].mul_add(b[i], c[i]);
235 }
236 }
237}
238
239impl ScalarBatch for Complex32 {
240 #[inline]
241 fn dot_batch(x: &[Self], y: &[Self]) -> Self {
242 debug_assert_eq!(x.len(), y.len());
243 let mut sum = Complex32::new(0.0, 0.0);
244 for i in 0..x.len() {
245 sum += x[i] * y[i];
246 }
247 sum
248 }
249
250 #[inline]
251 fn sum_batch(x: &[Self]) -> Self {
252 x.iter().copied().sum()
253 }
254
255 #[inline]
256 fn asum_batch(x: &[Self]) -> Self::Real {
257 x.iter().map(|z| z.re.abs() + z.im.abs()).sum()
258 }
259
260 #[inline]
261 fn iamax_batch(x: &[Self]) -> usize {
262 x.iter()
263 .enumerate()
264 .max_by(|(_, a), (_, b)| {
265 (a.re.abs() + a.im.abs())
266 .partial_cmp(&(b.re.abs() + b.im.abs()))
267 .unwrap_or(core::cmp::Ordering::Equal)
268 })
269 .map(|(i, _)| i)
270 .unwrap_or(0)
271 }
272
273 #[inline]
274 fn scale_batch(alpha: Self, x: &mut [Self]) {
275 for xi in x.iter_mut() {
276 *xi *= alpha;
277 }
278 }
279
280 #[inline]
281 fn axpy_batch(alpha: Self, x: &[Self], y: &mut [Self]) {
282 debug_assert_eq!(x.len(), y.len());
283 for i in 0..x.len() {
284 y[i] += alpha * x[i];
285 }
286 }
287
288 #[inline]
289 fn fma_batch(a: &[Self], b: &[Self], c: &[Self], out: &mut [Self]) {
290 debug_assert_eq!(a.len(), b.len());
291 debug_assert_eq!(a.len(), c.len());
292 debug_assert_eq!(a.len(), out.len());
293 for i in 0..a.len() {
294 out[i] = a[i] * b[i] + c[i];
295 }
296 }
297}
298
299impl ScalarBatch for Complex64 {
300 #[inline]
301 fn dot_batch(x: &[Self], y: &[Self]) -> Self {
302 debug_assert_eq!(x.len(), y.len());
303 let mut sum = Complex64::new(0.0, 0.0);
304 for i in 0..x.len() {
305 sum += x[i] * y[i];
306 }
307 sum
308 }
309
310 #[inline]
311 fn sum_batch(x: &[Self]) -> Self {
312 x.iter().copied().sum()
313 }
314
315 #[inline]
316 fn asum_batch(x: &[Self]) -> Self::Real {
317 x.iter().map(|z| z.re.abs() + z.im.abs()).sum()
318 }
319
320 #[inline]
321 fn iamax_batch(x: &[Self]) -> usize {
322 x.iter()
323 .enumerate()
324 .max_by(|(_, a), (_, b)| {
325 (a.re.abs() + a.im.abs())
326 .partial_cmp(&(b.re.abs() + b.im.abs()))
327 .unwrap_or(core::cmp::Ordering::Equal)
328 })
329 .map(|(i, _)| i)
330 .unwrap_or(0)
331 }
332
333 #[inline]
334 fn scale_batch(alpha: Self, x: &mut [Self]) {
335 for xi in x.iter_mut() {
336 *xi *= alpha;
337 }
338 }
339
340 #[inline]
341 fn axpy_batch(alpha: Self, x: &[Self], y: &mut [Self]) {
342 debug_assert_eq!(x.len(), y.len());
343 for i in 0..x.len() {
344 y[i] += alpha * x[i];
345 }
346 }
347
348 #[inline]
349 fn fma_batch(a: &[Self], b: &[Self], c: &[Self], out: &mut [Self]) {
350 debug_assert_eq!(a.len(), b.len());
351 debug_assert_eq!(a.len(), c.len());
352 debug_assert_eq!(a.len(), out.len());
353 for i in 0..a.len() {
354 out[i] = a[i] * b[i] + c[i];
355 }
356 }
357}
358
/// Coarse runtime classification of a scalar type, usable for dispatch
/// without generics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarClass {
    /// 32-bit real (`f32`).
    RealF32,
    /// 64-bit real (`f64`).
    RealF64,
    /// Complex with 32-bit components (`Complex32`).
    ComplexF32,
    /// Complex with 64-bit components (`Complex64`).
    ComplexF64,
    /// 16-bit real (`half::f16`, behind the `f16` feature).
    RealF16,
    /// Extended-precision real (`QuadFloat`, behind the `f128` feature).
    RealF128,
    /// Any scalar type not covered by the variants above.
    Other,
}
380
/// Compile-time metadata describing a scalar type's class, precision, and
/// storage footprint.
pub trait ScalarClassify: Scalar {
    /// The coarse class of this scalar (real/complex, component width).
    const CLASS: ScalarClass;

    /// Relative precision rank; higher means more precise. The impls below
    /// use f16 = 1, f32/Complex32 = 2, f64/Complex64 = 3, f128 = 4.
    const PRECISION_LEVEL: u8;

    /// In-memory size of one element in bytes.
    const STORAGE_BYTES: usize = core::mem::size_of::<Self>();
}
392
// Precision ranks below: f16 = 1 < f32 = 2 < f64 = 3 < f128 = 4.
// Complex types share the precision level of their component type.
impl ScalarClassify for f32 {
    const CLASS: ScalarClass = ScalarClass::RealF32;
    const PRECISION_LEVEL: u8 = 2;
}

impl ScalarClassify for f64 {
    const CLASS: ScalarClass = ScalarClass::RealF64;
    const PRECISION_LEVEL: u8 = 3;
}

impl ScalarClassify for Complex32 {
    const CLASS: ScalarClass = ScalarClass::ComplexF32;
    const PRECISION_LEVEL: u8 = 2;
}

impl ScalarClassify for Complex64 {
    const CLASS: ScalarClass = ScalarClass::ComplexF64;
    const PRECISION_LEVEL: u8 = 3;
}

#[cfg(feature = "f16")]
impl ScalarClassify for f16 {
    const CLASS: ScalarClass = ScalarClass::RealF16;
    const PRECISION_LEVEL: u8 = 1;
}

#[cfg(feature = "f128")]
impl ScalarClassify for QuadFloat {
    const CLASS: ScalarClass = ScalarClass::RealF128;
    const PRECISION_LEVEL: u8 = 4;
}
424
/// Per-type tuning hints for loop unrolling and cache blocking.
///
/// These are heuristics consumed by kernels elsewhere; the values themselves
/// carry no correctness requirements.
pub trait UnrollHints: Scalar {
    /// Suggested unroll factor for inner loops over this type.
    const UNROLL_FACTOR: usize;

    /// Suggested block (tile) size for cache-blocked algorithms.
    const BLOCK_SIZE: usize;

    /// Whether non-temporal / streaming stores are preferred for large
    /// writes of this type. NOTE(review): all impls below set this to
    /// `true`; confirm consumers actually branch on it.
    const PREFER_STREAMING: bool;
}
439
// Hints scale inversely with element size: wider elements fill SIMD
// registers and cache lines with fewer values, so unroll/block shrink.
impl UnrollHints for f32 {
    const UNROLL_FACTOR: usize = 8;
    const BLOCK_SIZE: usize = 64;
    const PREFER_STREAMING: bool = true;
}

impl UnrollHints for f64 {
    const UNROLL_FACTOR: usize = 4;
    const BLOCK_SIZE: usize = 32;
    const PREFER_STREAMING: bool = true;
}

impl UnrollHints for Complex32 {
    const UNROLL_FACTOR: usize = 4;
    const BLOCK_SIZE: usize = 32;
    const PREFER_STREAMING: bool = true;
}

impl UnrollHints for Complex64 {
    const UNROLL_FACTOR: usize = 2;
    const BLOCK_SIZE: usize = 16;
    const PREFER_STREAMING: bool = true;
}
463
/// Maps a scalar type to a wider accumulator type for reduced rounding error
/// in long reductions (e.g. f32 sums accumulated in f64).
pub trait ExtendedPrecision: Scalar {
    /// The (possibly wider) type used for accumulation.
    type Accumulator: Scalar;

    /// Widens `self` into the accumulator type.
    fn to_accumulator(self) -> Self::Accumulator;

    /// Narrows an accumulated value back to `Self` (may round).
    fn from_accumulator(acc: Self::Accumulator) -> Self;
}
478
479impl ExtendedPrecision for f32 {
480 type Accumulator = f64;
481
482 #[inline]
483 fn to_accumulator(self) -> f64 {
484 self as f64
485 }
486
487 #[inline]
488 fn from_accumulator(acc: f64) -> f32 {
489 acc as f32
490 }
491}
492
// f64 accumulates in itself: both conversions are identities.
impl ExtendedPrecision for f64 {
    type Accumulator = f64;

    #[inline]
    fn to_accumulator(self) -> f64 {
        self
    }

    #[inline]
    fn from_accumulator(acc: f64) -> f64 {
        acc
    }
}
507
508impl ExtendedPrecision for Complex32 {
509 type Accumulator = Complex64;
510
511 #[inline]
512 fn to_accumulator(self) -> Complex64 {
513 Complex64::new(self.re as f64, self.im as f64)
514 }
515
516 #[inline]
517 fn from_accumulator(acc: Complex64) -> Complex32 {
518 Complex32::new(acc.re as f32, acc.im as f32)
519 }
520}
521
// Complex64 accumulates in itself: both conversions are identities.
impl ExtendedPrecision for Complex64 {
    type Accumulator = Complex64;

    #[inline]
    fn to_accumulator(self) -> Complex64 {
        self
    }

    #[inline]
    fn from_accumulator(acc: Complex64) -> Complex64 {
        acc
    }
}
535
/// Kahan compensated summation accumulator.
///
/// Tracks the rounding error of each addition in `compensation` so long sums
/// lose far less precision than a naive running total.
#[derive(Debug, Clone, Copy)]
pub struct KahanSum<T: Scalar> {
    /// Running (high-order) total.
    sum: T,
    /// Accumulated low-order error from previous additions.
    compensation: T,
}
548
impl<T: Scalar> Default for KahanSum<T> {
    /// Equivalent to [`KahanSum::new`]: a zeroed accumulator.
    fn default() -> Self {
        Self::new()
    }
}
554
555impl<T: Scalar> KahanSum<T> {
556 #[inline]
558 pub fn new() -> Self {
559 Self {
560 sum: T::zero(),
561 compensation: T::zero(),
562 }
563 }
564
565 #[inline]
567 pub fn add(&mut self, value: T) {
568 let y = value - self.compensation;
569 let t = self.sum + y;
570 self.compensation = (t - self.sum) - y;
571 self.sum = t;
572 }
573
574 #[inline]
576 pub fn sum(self) -> T {
577 self.sum
578 }
579}
580
581#[inline]
585pub fn pairwise_sum<T: Scalar>(values: &[T]) -> T {
586 const THRESHOLD: usize = 32;
587
588 if values.is_empty() {
589 return T::zero();
590 }
591 if values.len() <= THRESHOLD {
592 return values.iter().copied().fold(T::zero(), |acc, x| acc + x);
593 }
594
595 let mid = values.len() / 2;
596 pairwise_sum(&values[..mid]) + pairwise_sum(&values[mid..])
597}
598
/// Second-order compensated summation accumulator (Kahan–Babuška style with
/// two levels of error carry), more accurate than [`KahanSum`] for
/// ill-conditioned sums.
#[derive(Debug, Clone, Copy)]
pub struct KBKSum<T: Scalar> {
    /// Running (high-order) total.
    sum: T,
    /// First-order compensation: errors lost from `sum` updates.
    cs: T,
    /// Second-order compensation: errors lost from `cs` updates.
    ccs: T,
}
608
impl<T: Scalar> Default for KBKSum<T> {
    /// Equivalent to [`KBKSum::new`]: a zeroed accumulator.
    fn default() -> Self {
        Self::new()
    }
}
614
impl<T: Scalar> KBKSum<T> {
    /// Creates an empty (zeroed) accumulator.
    #[inline]
    pub fn new() -> Self {
        Self {
            sum: T::zero(),
            cs: T::zero(),
            ccs: T::zero(),
        }
    }

    /// Folds `value` into the total with two levels of error compensation.
    ///
    /// Each level uses the Neumaier trick: the operand of larger magnitude
    /// is the one whose low-order bits were kept by `+`, so the error term
    /// is recovered from the smaller operand's perspective. The branch and
    /// operation order are numerically load-bearing — do not reassociate.
    #[inline]
    pub fn add(&mut self, value: T) {
        let t = self.sum + value;
        // First-order error of `sum + value`.
        let c = if Scalar::abs(self.sum) >= Scalar::abs(value) {
            (self.sum - t) + value
        } else {
            (value - t) + self.sum
        };
        self.sum = t;

        // Fold the first-order error into `cs`, capturing ITS rounding
        // error into `ccs` (second order).
        let t2 = self.cs + c;
        let cc = if Scalar::abs(self.cs) >= Scalar::abs(c) {
            (self.cs - t2) + c
        } else {
            (c - t2) + self.cs
        };
        self.cs = t2;
        self.ccs += cc;
    }

    /// Consumes the accumulator and returns the doubly-compensated total.
    #[inline]
    pub fn sum(self) -> T {
        self.sum + self.cs + self.ccs
    }
}