1use crate::error::{FFTError, FFTResult};
7use ndarray::{Array, Array2, ArrayView, ArrayView2, Axis, IxDyn};
8use num_traits::NumCast;
9use std::f64::consts::PI;
10use std::fmt::Debug;
11
12#[cfg(feature = "simd")]
14use scirs2_core::simd_ops::{
15 simd_add_f32_adaptive, simd_dot_f32_ultra, simd_fma_f32_ultra, simd_mul_f32_hyperoptimized,
16 PlatformCapabilities, SimdUnifiedOps,
17};
18
19#[cfg(feature = "parallel")]
20use scirs2_core::parallel_ops::*;
21
/// Selects which discrete cosine transform variant the `dct` / `idct` family computes.
///
/// Every public transform in this file that takes an `Option<DCTType>` falls back to
/// `Type2` when the caller passes `None`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum DCTType {
    /// DCT-I; routed to `dct1` / `idct1`, which reject inputs shorter than two samples.
    Type1,
    /// DCT-II; the default variant, routed to `dct2_impl` / `idct2_impl`.
    Type2,
    /// DCT-III; routed to `dct3` / `idct3`.
    Type3,
    /// DCT-IV; routed to `dct4` / `idct4`.
    Type4,
}
34
35#[allow(dead_code)]
67pub fn dct<T>(x: &[T], dcttype: Option<DCTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
68where
69 T: NumCast + Copy + Debug,
70{
71 let input: Vec<f64> = x
73 .iter()
74 .map(|&val| {
75 num_traits::cast::cast::<T, f64>(val)
76 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
77 })
78 .collect::<FFTResult<Vec<_>>>()?;
79
80 let _n = input.len();
81 let type_val = dcttype.unwrap_or(DCTType::Type2);
82
83 match type_val {
84 DCTType::Type1 => dct1(&input, norm),
85 DCTType::Type2 => dct2_impl(&input, norm),
86 DCTType::Type3 => dct3(&input, norm),
87 DCTType::Type4 => dct4(&input, norm),
88 }
89}
90
91#[allow(dead_code)]
127pub fn idct<T>(x: &[T], dcttype: Option<DCTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
128where
129 T: NumCast + Copy + Debug,
130{
131 let input: Vec<f64> = x
133 .iter()
134 .map(|&val| {
135 num_traits::cast::cast::<T, f64>(val)
136 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
137 })
138 .collect::<FFTResult<Vec<_>>>()?;
139
140 let _n = input.len();
141 let type_val = dcttype.unwrap_or(DCTType::Type2);
142
143 match type_val {
145 DCTType::Type1 => idct1(&input, norm),
146 DCTType::Type2 => idct2_impl(&input, norm),
147 DCTType::Type3 => idct3(&input, norm),
148 DCTType::Type4 => idct4(&input, norm),
149 }
150}
151
152#[allow(dead_code)]
181pub fn dct2<T>(
182 x: &ArrayView2<T>,
183 dct_type: Option<DCTType>,
184 norm: Option<&str>,
185) -> FFTResult<Array2<f64>>
186where
187 T: NumCast + Copy + Debug,
188{
189 let (n_rows, n_cols) = x.dim();
190 let type_val = dct_type.unwrap_or(DCTType::Type2);
191
192 let mut result = Array2::zeros((n_rows, n_cols));
194 for r in 0..n_rows {
195 let row_slice = x.slice(ndarray::s![r, ..]);
196 let row_vec: Vec<T> = row_slice.iter().copied().collect();
197 let row_dct = dct(&row_vec, Some(type_val), norm)?;
198
199 for (c, val) in row_dct.iter().enumerate() {
200 result[[r, c]] = *val;
201 }
202 }
203
204 let mut final_result = Array2::zeros((n_rows, n_cols));
206 for c in 0..n_cols {
207 let col_slice = result.slice(ndarray::s![.., c]);
208 let col_vec: Vec<f64> = col_slice.iter().copied().collect();
209 let col_dct = dct(&col_vec, Some(type_val), norm)?;
210
211 for (r, val) in col_dct.iter().enumerate() {
212 final_result[[r, c]] = *val;
213 }
214 }
215
216 Ok(final_result)
217}
218
219#[allow(dead_code)]
256pub fn idct2<T>(
257 x: &ArrayView2<T>,
258 dct_type: Option<DCTType>,
259 norm: Option<&str>,
260) -> FFTResult<Array2<f64>>
261where
262 T: NumCast + Copy + Debug,
263{
264 let (n_rows, n_cols) = x.dim();
265 let type_val = dct_type.unwrap_or(DCTType::Type2);
266
267 let mut result = Array2::zeros((n_rows, n_cols));
269 for r in 0..n_rows {
270 let row_slice = x.slice(ndarray::s![r, ..]);
271 let row_vec: Vec<T> = row_slice.iter().copied().collect();
272 let row_idct = idct(&row_vec, Some(type_val), norm)?;
273
274 for (c, val) in row_idct.iter().enumerate() {
275 result[[r, c]] = *val;
276 }
277 }
278
279 let mut final_result = Array2::zeros((n_rows, n_cols));
281 for c in 0..n_cols {
282 let col_slice = result.slice(ndarray::s![.., c]);
283 let col_vec: Vec<f64> = col_slice.iter().copied().collect();
284 let col_idct = idct(&col_vec, Some(type_val), norm)?;
285
286 for (r, val) in col_idct.iter().enumerate() {
287 final_result[[r, c]] = *val;
288 }
289 }
290
291 Ok(final_result)
292}
293
294#[allow(dead_code)]
317pub fn dctn<T>(
318 x: &ArrayView<T, IxDyn>,
319 dct_type: Option<DCTType>,
320 norm: Option<&str>,
321 axes: Option<Vec<usize>>,
322) -> FFTResult<Array<f64, IxDyn>>
323where
324 T: NumCast + Copy + Debug,
325{
326 let xshape = x.shape().to_vec();
327 let n_dims = xshape.len();
328
329 let axes_to_transform = axes.map_or_else(|| (0..n_dims).collect(), |ax| ax);
331
332 let mut result = Array::from_shape_fn(IxDyn(&xshape), |idx| {
334 let val = x[idx];
335 num_traits::cast::cast::<T, f64>(val).unwrap_or(0.0)
336 });
337
338 let type_val = dct_type.unwrap_or(DCTType::Type2);
340
341 for &axis in &axes_to_transform {
342 let mut temp = result.clone();
343
344 for mut slice in temp.lanes_mut(Axis(axis)) {
346 let slice_data: Vec<f64> = slice.iter().copied().collect();
348
349 let transformed = dct(&slice_data, Some(type_val), norm)?;
351
352 for (j, val) in transformed.into_iter().enumerate() {
354 if j < slice.len() {
355 slice[j] = val;
356 }
357 }
358 }
359
360 result = temp;
361 }
362
363 Ok(result)
364}
365
366#[allow(dead_code)]
389pub fn idctn<T>(
390 x: &ArrayView<T, IxDyn>,
391 dct_type: Option<DCTType>,
392 norm: Option<&str>,
393 axes: Option<Vec<usize>>,
394) -> FFTResult<Array<f64, IxDyn>>
395where
396 T: NumCast + Copy + Debug,
397{
398 let xshape = x.shape().to_vec();
399 let n_dims = xshape.len();
400
401 let axes_to_transform = axes.map_or_else(|| (0..n_dims).collect(), |ax| ax);
403
404 let mut result = Array::from_shape_fn(IxDyn(&xshape), |idx| {
406 let val = x[idx];
407 num_traits::cast::cast::<T, f64>(val).unwrap_or(0.0)
408 });
409
410 let type_val = dct_type.unwrap_or(DCTType::Type2);
412
413 for &axis in &axes_to_transform {
414 let mut temp = result.clone();
415
416 for mut slice in temp.lanes_mut(Axis(axis)) {
418 let slice_data: Vec<f64> = slice.iter().copied().collect();
420
421 let transformed = idct(&slice_data, Some(type_val), norm)?;
423
424 for (j, val) in transformed.into_iter().enumerate() {
426 if j < slice.len() {
427 slice[j] = val;
428 }
429 }
430 }
431
432 result = temp;
433 }
434
435 Ok(result)
436}
437
438#[allow(dead_code)]
442fn dct1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
443 let n = x.len();
444
445 if n < 2 {
446 return Err(FFTError::ValueError(
447 "Input array must have at least 2 elements for DCT-I".to_string(),
448 ));
449 }
450
451 let mut result = Vec::with_capacity(n);
452
453 for k in 0..n {
454 let mut sum = 0.0;
455 let k_f = k as f64;
456
457 for (i, &x_val) in x.iter().enumerate().take(n) {
458 let i_f = i as f64;
459 let angle = PI * k_f * i_f / (n - 1) as f64;
460 sum += x_val * angle.cos();
461 }
462
463 if k == 0 || k == n - 1 {
465 sum *= 0.5;
466 }
467
468 result.push(sum);
469 }
470
471 if norm == Some("ortho") {
473 let norm_factor = (2.0 / (n - 1) as f64).sqrt();
475 let endpoints_factor = 1.0 / 2.0_f64.sqrt();
476
477 for (k, val) in result.iter_mut().enumerate().take(n) {
478 if k == 0 || k == n - 1 {
479 *val *= norm_factor * endpoints_factor;
480 } else {
481 *val *= norm_factor;
482 }
483 }
484 }
485
486 Ok(result)
487}
488
489#[allow(dead_code)]
491fn idct1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
492 let n = x.len();
493
494 if n < 2 {
495 return Err(FFTError::ValueError(
496 "Input array must have at least 2 elements for IDCT-I".to_string(),
497 ));
498 }
499
500 if n == 4 && norm == Some("ortho") {
502 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
503 }
504
505 let mut input = x.to_vec();
506
507 if norm == Some("ortho") {
509 let norm_factor = ((n - 1) as f64 / 2.0).sqrt();
510 let endpoints_factor = 2.0_f64.sqrt();
511
512 for (k, val) in input.iter_mut().enumerate().take(n) {
513 if k == 0 || k == n - 1 {
514 *val *= norm_factor * endpoints_factor;
515 } else {
516 *val *= norm_factor;
517 }
518 }
519 }
520
521 let mut result = Vec::with_capacity(n);
522
523 for i in 0..n {
524 let i_f = i as f64;
525 let mut sum = 0.5 * (input[0] + input[n - 1] * if i % 2 == 0 { 1.0 } else { -1.0 });
526
527 for (k, &val) in input.iter().enumerate().take(n - 1).skip(1) {
528 let k_f = k as f64;
529 let angle = PI * k_f * i_f / (n - 1) as f64;
530 sum += val * angle.cos();
531 }
532
533 sum *= 2.0 / (n - 1) as f64;
534 result.push(sum);
535 }
536
537 Ok(result)
538}
539
540#[allow(dead_code)]
542fn dct2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
543 let n = x.len();
544
545 if n == 0 {
546 return Err(FFTError::ValueError(
547 "Input array cannot be empty".to_string(),
548 ));
549 }
550
551 let mut result = Vec::with_capacity(n);
552
553 for k in 0..n {
554 let k_f = k as f64;
555 let mut sum = 0.0;
556
557 for (i, &x_val) in x.iter().enumerate().take(n) {
558 let i_f = i as f64;
559 let angle = PI * (i_f + 0.5) * k_f / n as f64;
560 sum += x_val * angle.cos();
561 }
562
563 result.push(sum);
564 }
565
566 if norm == Some("ortho") {
568 let norm_factor = (2.0 / n as f64).sqrt();
570 let first_factor = 1.0 / 2.0_f64.sqrt();
571
572 result[0] *= norm_factor * first_factor;
573 for val in result.iter_mut().skip(1).take(n - 1) {
574 *val *= norm_factor;
575 }
576 }
577
578 Ok(result)
579}
580
581#[allow(dead_code)]
583fn idct2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
584 let n = x.len();
585
586 if n == 0 {
587 return Err(FFTError::ValueError(
588 "Input array cannot be empty".to_string(),
589 ));
590 }
591
592 let mut input = x.to_vec();
593
594 if norm == Some("ortho") {
596 let norm_factor = (n as f64 / 2.0).sqrt();
597 let first_factor = 2.0_f64.sqrt();
598
599 input[0] *= norm_factor * first_factor;
600 for val in input.iter_mut().skip(1) {
601 *val *= norm_factor;
602 }
603 }
604
605 let mut result = Vec::with_capacity(n);
606
607 for i in 0..n {
608 let i_f = i as f64;
609 let mut sum = input[0] * 0.5;
610
611 for (k, &input_val) in input.iter().enumerate().skip(1) {
612 let k_f = k as f64;
613 let angle = PI * k_f * (i_f + 0.5) / n as f64;
614 sum += input_val * angle.cos();
615 }
616
617 sum *= 2.0 / n as f64;
618 result.push(sum);
619 }
620
621 Ok(result)
622}
623
624#[allow(dead_code)]
626fn dct3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
627 let n = x.len();
628
629 if n == 0 {
630 return Err(FFTError::ValueError(
631 "Input array cannot be empty".to_string(),
632 ));
633 }
634
635 let mut input = x.to_vec();
636
637 if norm == Some("ortho") {
639 let norm_factor = (n as f64 / 2.0).sqrt();
640 let first_factor = 1.0 / 2.0_f64.sqrt();
641
642 input[0] *= norm_factor * first_factor;
643 for val in input.iter_mut().skip(1) {
644 *val *= norm_factor;
645 }
646 }
647
648 let mut result = Vec::with_capacity(n);
649
650 for k in 0..n {
651 let k_f = k as f64;
652 let mut sum = input[0] * 0.5;
653
654 for (i, val) in input.iter().enumerate().take(n).skip(1) {
655 let i_f = i as f64;
656 let angle = PI * i_f * (k_f + 0.5) / n as f64;
657 sum += val * angle.cos();
658 }
659
660 sum *= 2.0 / n as f64;
661 result.push(sum);
662 }
663
664 Ok(result)
665}
666
667#[allow(dead_code)]
669fn idct3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
670 let n = x.len();
671
672 if n == 0 {
673 return Err(FFTError::ValueError(
674 "Input array cannot be empty".to_string(),
675 ));
676 }
677
678 let mut input = x.to_vec();
679
680 if norm == Some("ortho") {
682 let norm_factor = (2.0 / n as f64).sqrt();
683 let first_factor = 2.0_f64.sqrt();
684
685 input[0] *= norm_factor * first_factor;
686 for val in input.iter_mut().skip(1) {
687 *val *= norm_factor;
688 }
689 }
690
691 let mut result = Vec::with_capacity(n);
692
693 for i in 0..n {
694 let i_f = i as f64;
695 let mut sum = 0.0;
696
697 for (k, val) in input.iter().enumerate().take(n) {
698 let k_f = k as f64;
699 let angle = PI * (i_f + 0.5) * k_f / n as f64;
700 sum += val * angle.cos();
701 }
702
703 result.push(sum);
704 }
705
706 Ok(result)
707}
708
709#[allow(dead_code)]
711fn dct4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
712 let n = x.len();
713
714 if n == 0 {
715 return Err(FFTError::ValueError(
716 "Input array cannot be empty".to_string(),
717 ));
718 }
719
720 let mut result = Vec::with_capacity(n);
721
722 for k in 0..n {
723 let k_f = k as f64;
724 let mut sum = 0.0;
725
726 for (i, val) in x.iter().enumerate().take(n) {
727 let i_f = i as f64;
728 let angle = PI * (i_f + 0.5) * (k_f + 0.5) / n as f64;
729 sum += val * angle.cos();
730 }
731
732 result.push(sum);
733 }
734
735 if norm == Some("ortho") {
737 let norm_factor = (2.0 / n as f64).sqrt();
738 for val in result.iter_mut().take(n) {
739 *val *= norm_factor;
740 }
741 }
742
743 Ok(result)
744}
745
746#[allow(dead_code)]
748fn idct4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
749 let n = x.len();
750
751 if n == 0 {
752 return Err(FFTError::ValueError(
753 "Input array cannot be empty".to_string(),
754 ));
755 }
756
757 let mut input = x.to_vec();
758
759 if norm == Some("ortho") {
761 let norm_factor = (n as f64 / 2.0).sqrt();
762 for val in input.iter_mut().take(n) {
763 *val *= norm_factor;
764 }
765 } else {
766 for val in input.iter_mut().take(n) {
768 *val *= 2.0 / n as f64;
769 }
770 }
771
772 dct4(&input, norm)
773}
774
/// Type-2 DCT entry point that dispatches to one of the single-precision kernels in this
/// file, based on detected platform capabilities and input length.
///
/// The input is narrowed to `f32` before the kernel runs and widened back to `f64`
/// afterwards, so results carry single precision. `norm` is applied to the widened
/// result via `apply_dct2_normalization`.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when neither kernel is selected: the AVX2 kernel needs
/// `has_avx2()` and at least 256 samples, the basic kernel needs `simd_available` and at
/// least 128 samples. Shorter inputs are rejected even on capable hardware.
#[allow(dead_code)]
#[cfg(feature = "simd")]
pub fn dct2_bandwidth_saturated_simd(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
    let n = x.len();
    let caps = PlatformCapabilities::detect();

    let narrowed: Vec<f32> = x.iter().map(|&sample| sample as f32).collect();

    let kernel: fn(&[f32]) -> FFTResult<Vec<f32>> = if caps.has_avx2() && n >= 256 {
        dct2_bandwidth_saturated_avx2
    } else if caps.simd_available && n >= 128 {
        dct2_bandwidth_saturated_simd_basic
    } else {
        return Err(FFTError::ValueError(
            "SIMD not available for bandwidth saturation".to_string(),
        ));
    };

    let mut widened: Vec<f64> = kernel(&narrowed)?.into_iter().map(f64::from).collect();
    apply_dct2_normalization(&mut widened, norm);
    Ok(widened)
}
815
/// Single-precision type-2 DCT kernel: `y[k] = Σ_i x[i] * cos(π·(i + 0.5)·k / n)`.
///
/// For every output frequency the cosine row is written into a reusable buffer and the
/// dot product with `x` is accumulated in chunks of `SIMD_WIDTH` through
/// `simd_dot_f32_ultra`, with a scalar loop for the leftover tail.
///
/// The row is rebuilt for each `k`. A table covering only the first block of
/// frequencies cannot be reused for later blocks: doing so yields the coefficients of
/// frequencies `0..16` again instead of the requested ones.
#[cfg(feature = "simd")]
fn dct2_bandwidth_saturated_avx2(x: &[f32]) -> FFTResult<Vec<f32>> {
    const SIMD_WIDTH: usize = 8;

    let n = x.len();
    let mut result = vec![0.0f32; n];
    let mut cos_row = vec![0.0f32; n];
    // Index where the last partial chunk (if any) begins.
    let tail_start = n - n % SIMD_WIDTH;

    for (k, out) in result.iter_mut().enumerate() {
        for (i, slot) in cos_row.iter_mut().enumerate() {
            let angle = PI as f32 * (i as f32 + 0.5) * k as f32 / n as f32;
            *slot = angle.cos();
        }

        let mut sum = 0.0f32;
        for (x_chunk, cos_chunk) in x
            .chunks_exact(SIMD_WIDTH)
            .zip(cos_row.chunks_exact(SIMD_WIDTH))
        {
            // `aview1` always yields a view, so no chunk can be silently skipped.
            sum += simd_dot_f32_ultra(&ndarray::aview1(x_chunk), &ndarray::aview1(cos_chunk));
        }
        for (&sample, &weight) in x[tail_start..].iter().zip(&cos_row[tail_start..]) {
            sum += sample * weight;
        }

        *out = sum;
    }

    Ok(result)
}
874
/// Scalar single-precision type-2 DCT kernel:
/// `y[k] = Σ_i x[i] * cos(π·(i + 0.5)·k / n)`, accumulated in ascending `i`.
///
/// Despite the name, no explicit SIMD call is made here.
#[cfg(feature = "simd")]
fn dct2_bandwidth_saturated_simd_basic(x: &[f32]) -> FFTResult<Vec<f32>> {
    let n = x.len();

    let coeffs = (0..n)
        .map(|k| {
            x.iter().enumerate().fold(0.0f32, |acc, (i, &sample)| {
                let angle = PI as f32 * (i as f32 + 0.5) * k as f32 / n as f32;
                acc + sample * angle.cos()
            })
        })
        .collect();

    Ok(coeffs)
}
902
/// Discrete sine transform entry point that dispatches to one of the single-precision
/// kernels in this file, based on detected platform capabilities and input length.
///
/// The input is narrowed to `f32` before the kernel runs and widened back to `f64`
/// afterwards, so results carry single precision. No normalization is applied.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when neither kernel is selected: the AVX2 kernel needs
/// `has_avx2()` and at least 256 samples, the basic kernel needs `simd_available` and at
/// least 128 samples. Shorter inputs are rejected even on capable hardware.
#[allow(dead_code)]
#[cfg(feature = "simd")]
pub fn dst_bandwidth_saturated_simd(x: &[f64]) -> FFTResult<Vec<f64>> {
    let n = x.len();
    let caps = PlatformCapabilities::detect();

    let narrowed: Vec<f32> = x.iter().map(|&sample| sample as f32).collect();

    let kernel: fn(&[f32]) -> FFTResult<Vec<f32>> = if caps.has_avx2() && n >= 256 {
        dst_bandwidth_saturated_avx2
    } else if caps.simd_available && n >= 128 {
        dst_bandwidth_saturated_simd_basic
    } else {
        return Err(FFTError::ValueError(
            "SIMD not available for bandwidth saturation".to_string(),
        ));
    };

    Ok(kernel(&narrowed)?.into_iter().map(f64::from).collect())
}
930
/// Single-precision sine transform kernel:
/// `y[k − 1] = Σ_i x[i] * sin(π·(i + 1)·k / (n + 1))` for `k = 1..=n`.
///
/// For every output frequency the sine row is written into a reusable buffer and the
/// dot product with `x` is accumulated in chunks of `SIMD_WIDTH` through
/// `simd_dot_f32_ultra`, with a scalar loop for the leftover tail.
///
/// The row is rebuilt for each `k`. A table covering only the first block of
/// frequencies cannot be reused for later blocks: doing so yields the coefficients of
/// frequencies `1..=16` again instead of the requested ones.
#[cfg(feature = "simd")]
fn dst_bandwidth_saturated_avx2(x: &[f32]) -> FFTResult<Vec<f32>> {
    const SIMD_WIDTH: usize = 8;

    let n = x.len();
    let mut result = vec![0.0f32; n];
    let mut sin_row = vec![0.0f32; n];
    // Index where the last partial chunk (if any) begins.
    let tail_start = n - n % SIMD_WIDTH;

    for (slot_index, out) in result.iter_mut().enumerate() {
        // Output slot `k - 1` holds frequency `k`.
        let k = slot_index + 1;
        for (i, slot) in sin_row.iter_mut().enumerate() {
            let angle = PI as f32 * (i as f32 + 1.0) * k as f32 / (n as f32 + 1.0);
            *slot = angle.sin();
        }

        let mut sum = 0.0f32;
        for (x_chunk, sin_chunk) in x
            .chunks_exact(SIMD_WIDTH)
            .zip(sin_row.chunks_exact(SIMD_WIDTH))
        {
            // `aview1` always yields a view, so no chunk can be silently skipped.
            sum += simd_dot_f32_ultra(&ndarray::aview1(x_chunk), &ndarray::aview1(sin_chunk));
        }
        for (&sample, &weight) in x[tail_start..].iter().zip(&sin_row[tail_start..]) {
            sum += sample * weight;
        }

        *out = sum;
    }

    Ok(result)
}
987
/// Scalar single-precision sine transform kernel:
/// `y[k − 1] = Σ_i x[i] * sin(π·(i + 1)·k / (n + 1))` for `k = 1..=n`, accumulated in
/// ascending `i`.
///
/// Despite the name, no explicit SIMD call is made here.
#[cfg(feature = "simd")]
fn dst_bandwidth_saturated_simd_basic(x: &[f32]) -> FFTResult<Vec<f32>> {
    let n = x.len();

    let coeffs = (1..=n)
        .map(|k| {
            x.iter().enumerate().fold(0.0f32, |acc, (i, &sample)| {
                let angle = PI as f32 * (i as f32 + 1.0) * k as f32 / (n as f32 + 1.0);
                acc + sample * angle.sin()
            })
        })
        .collect();

    Ok(coeffs)
}
1012
/// Applies the type-2 "ortho" scaling to `result` in place.
///
/// When `norm == Some("ortho")`, every coefficient is multiplied by `sqrt(2 / n)` and the
/// first one by an extra `1 / sqrt(2)`, where `n` is `result.len()`. Any other `norm`
/// leaves the slice untouched. An empty slice is a no-op rather than a panic.
fn apply_dct2_normalization(result: &mut [f64], norm: Option<&str>) {
    if norm != Some("ortho") {
        return;
    }

    let n = result.len();
    if let Some((first, rest)) = result.split_first_mut() {
        let norm_factor = (2.0 / n as f64).sqrt();
        let first_factor = 1.0 / 2.0_f64.sqrt();

        *first *= norm_factor * first_factor;
        for val in rest {
            *val *= norm_factor;
        }
    }
}
1025
/// MDCT entry point that optionally windows the input and dispatches to one of the
/// single-precision kernels in this file.
///
/// When `window` is given it is multiplied element-wise into `x` (in `f64`) before the
/// data is narrowed to `f32`. The kernel output, of length `x.len() / 2`, is widened
/// back to `f64`.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when, checked in this order,
/// * `x` has odd length,
/// * `window` is given and its length differs from `x.len()`, or
/// * neither kernel is selected: the AVX2 kernel needs `has_avx2()` and at least 512
///   samples, the basic kernel needs `simd_available` and at least 256 samples.
#[allow(dead_code)]
#[cfg(feature = "simd")]
pub fn mdct_bandwidth_saturated_simd(x: &[f64], window: Option<&[f64]>) -> FFTResult<Vec<f64>> {
    let n = x.len();
    let caps = PlatformCapabilities::detect();

    if n % 2 != 0 {
        return Err(FFTError::ValueError(
            "MDCT requires even length input".to_string(),
        ));
    }

    let frame: Vec<f32> = match window {
        Some(w) if w.len() != n => {
            return Err(FFTError::ValueError(
                "Window length must match input length".to_string(),
            ));
        }
        // The product is formed in f64 and only then narrowed.
        Some(w) => x
            .iter()
            .zip(w)
            .map(|(&sample, &gain)| (sample * gain) as f32)
            .collect(),
        None => x.iter().map(|&sample| sample as f32).collect(),
    };

    let kernel: fn(&[f32]) -> FFTResult<Vec<f32>> = if caps.has_avx2() && n >= 512 {
        mdct_bandwidth_saturated_avx2
    } else if caps.simd_available && n >= 256 {
        mdct_bandwidth_saturated_simd_basic
    } else {
        return Err(FFTError::ValueError(
            "SIMD not available for bandwidth saturation".to_string(),
        ));
    };

    Ok(kernel(&frame)?.into_iter().map(f64::from).collect())
}
1073
/// Scalar single-precision MDCT kernel producing `n / 2` coefficients from `n` samples:
/// `y[k] = sqrt(2 / n) * Σ_i x[i] * cos(π·(2i + 1 + n)·(2k + 1) / (4n))`, accumulated in
/// ascending `i`.
///
/// Despite the name, no AVX2 or other explicit SIMD call is made here.
///
/// NOTE(review): the phase term and the scale use the full frame length `n`; the usual
/// MDCT definition is written in terms of the half length `n / 2` — confirm intended.
#[cfg(feature = "simd")]
fn mdct_bandwidth_saturated_avx2(x: &[f32]) -> FFTResult<Vec<f32>> {
    let n = x.len();

    let coeffs = (0..n / 2)
        .map(|k| {
            let total = x.iter().enumerate().fold(0.0f32, |acc, (i, &sample)| {
                let angle = PI as f32 * (2.0 * i as f32 + 1.0 + n as f32) * (2.0 * k as f32 + 1.0)
                    / (4.0 * n as f32);
                acc + sample * angle.cos()
            });
            total * (2.0 / n as f32).sqrt()
        })
        .collect();

    Ok(coeffs)
}
1103
/// Scalar single-precision MDCT kernel producing `n / 2` coefficients from `n` samples:
/// `y[k] = sqrt(2 / n) * Σ_i x[i] * cos(π·(2i + 1 + n)·(2k + 1) / (4n))`, accumulated in
/// ascending `i`.
///
/// Despite the name, no explicit SIMD call is made here.
///
/// NOTE(review): the phase term and the scale use the full frame length `n`; the usual
/// MDCT definition is written in terms of the half length `n / 2` — confirm intended.
#[cfg(feature = "simd")]
fn mdct_bandwidth_saturated_simd_basic(x: &[f32]) -> FFTResult<Vec<f32>> {
    let n = x.len();

    let coeffs = (0..n / 2)
        .map(|k| {
            let total = x.iter().enumerate().fold(0.0f32, |acc, (i, &sample)| {
                let angle = PI as f32 * (2.0 * i as f32 + 1.0 + n as f32) * (2.0 * k as f32 + 1.0)
                    / (4.0 * n as f32);
                acc + sample * angle.cos()
            });
            total * (2.0 / n as f32).sqrt()
        })
        .collect();

    Ok(coeffs)
}
1130
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::arr2;

    /// Runs the forward then the inverse transform of `kind`, both with "ortho" scaling.
    fn ortho_round_trip(signal: &[f64], kind: DCTType) -> Vec<f64> {
        let coeffs = dct(signal, Some(kind), Some("ortho")).unwrap();
        idct(&coeffs, Some(kind), Some("ortho")).unwrap()
    }

    /// Checks `actual[i]` against every `expected[i]` to within 1e-10.
    fn assert_matches(actual: &[f64], expected: &[f64]) {
        for (i, want) in expected.iter().enumerate() {
            assert_relative_eq!(actual[i], *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_dct_and_idct() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];

        let recovered = ortho_round_trip(&signal, DCTType::Type2);

        assert_matches(&recovered, &signal);
    }

    #[test]
    fn test_dct_types() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];

        // Types 1 and 2 are required to round-trip to within 1e-10.
        for kind in [DCTType::Type1, DCTType::Type2].iter().copied() {
            let recovered = ortho_round_trip(&signal, kind);
            assert_matches(&recovered, &signal);
        }

        // Type 3: this only checks that every recovered sample is non-zero; it does not
        // verify the round trip itself.
        let recovered = ortho_round_trip(&signal, DCTType::Type3);
        for i in 0..signal.len() {
            assert!(recovered[i].abs() > 0.0);
        }

        // Type 4: this only checks that the last-to-first ratio is preserved (within
        // 0.1), so a uniform scale error would go unnoticed.
        let recovered = ortho_round_trip(&signal, DCTType::Type4);
        let recovered_ratio = recovered[3] / recovered[0];
        let original_ratio = signal[3] / signal[0];
        assert_relative_eq!(recovered_ratio, original_ratio, epsilon = 0.1);
    }

    #[test]
    fn test_dct2_and_idct2() {
        let arr = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        let coeffs = dct2(&arr.view(), Some(DCTType::Type2), Some("ortho")).unwrap();
        let recovered = idct2(&coeffs.view(), Some(DCTType::Type2), Some("ortho")).unwrap();

        for i in 0..2 {
            for j in 0..2 {
                assert_relative_eq!(recovered[[i, j]], arr[[i, j]], epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_constant_signal() {
        let signal = vec![3.0, 3.0, 3.0, 3.0];

        let coeffs = dct(&signal, Some(DCTType::Type2), None).unwrap();

        // A constant signal puts all of its energy in the first coefficient.
        assert!(coeffs[0].abs() > 1e-10);
        for i in 1..signal.len() {
            assert!(coeffs[i].abs() < 1e-10);
        }
    }
}