1use crate::error::{FFTError, FFTResult};
7use ndarray::{Array, Array2, ArrayView, ArrayView2, Axis, IxDyn};
8use num_traits::NumCast;
9use std::f64::consts::PI;
10use std::fmt::Debug;
11
/// Selects which of the four discrete cosine transform variants to compute.
///
/// The public entry points in this module use [`DCTType::Type2`] when `None` is passed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DCTType {
    /// DCT-I; the kernels here require at least two samples for this variant.
    Type1,
    /// DCT-II, the default variant.
    Type2,
    /// DCT-III.
    Type3,
    /// DCT-IV.
    Type4,
}
24
25#[allow(dead_code)]
57pub fn dct<T>(x: &[T], dcttype: Option<DCTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
58where
59 T: NumCast + Copy + Debug,
60{
61 let input: Vec<f64> = x
63 .iter()
64 .map(|&val| {
65 num_traits::cast::cast::<T, f64>(val)
66 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
67 })
68 .collect::<FFTResult<Vec<_>>>()?;
69
70 let _n = input.len();
71 let type_val = dcttype.unwrap_or(DCTType::Type2);
72
73 match type_val {
74 DCTType::Type1 => dct1(&input, norm),
75 DCTType::Type2 => dct2_impl(&input, norm),
76 DCTType::Type3 => dct3(&input, norm),
77 DCTType::Type4 => dct4(&input, norm),
78 }
79}
80
81#[allow(dead_code)]
117pub fn idct<T>(x: &[T], dcttype: Option<DCTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
118where
119 T: NumCast + Copy + Debug,
120{
121 let input: Vec<f64> = x
123 .iter()
124 .map(|&val| {
125 num_traits::cast::cast::<T, f64>(val)
126 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
127 })
128 .collect::<FFTResult<Vec<_>>>()?;
129
130 let _n = input.len();
131 let type_val = dcttype.unwrap_or(DCTType::Type2);
132
133 match type_val {
135 DCTType::Type1 => idct1(&input, norm),
136 DCTType::Type2 => idct2_impl(&input, norm),
137 DCTType::Type3 => idct3(&input, norm),
138 DCTType::Type4 => idct4(&input, norm),
139 }
140}
141
142#[allow(dead_code)]
171pub fn dct2<T>(
172 x: &ArrayView2<T>,
173 dct_type: Option<DCTType>,
174 norm: Option<&str>,
175) -> FFTResult<Array2<f64>>
176where
177 T: NumCast + Copy + Debug,
178{
179 let (n_rows, n_cols) = x.dim();
180 let type_val = dct_type.unwrap_or(DCTType::Type2);
181
182 let mut result = Array2::zeros((n_rows, n_cols));
184 for r in 0..n_rows {
185 let row_slice = x.slice(ndarray::s![r, ..]);
186 let row_vec: Vec<T> = row_slice.iter().copied().collect();
187 let row_dct = dct(&row_vec, Some(type_val), norm)?;
188
189 for (c, val) in row_dct.iter().enumerate() {
190 result[[r, c]] = *val;
191 }
192 }
193
194 let mut final_result = Array2::zeros((n_rows, n_cols));
196 for c in 0..n_cols {
197 let col_slice = result.slice(ndarray::s![.., c]);
198 let col_vec: Vec<f64> = col_slice.iter().copied().collect();
199 let col_dct = dct(&col_vec, Some(type_val), norm)?;
200
201 for (r, val) in col_dct.iter().enumerate() {
202 final_result[[r, c]] = *val;
203 }
204 }
205
206 Ok(final_result)
207}
208
209#[allow(dead_code)]
246pub fn idct2<T>(
247 x: &ArrayView2<T>,
248 dct_type: Option<DCTType>,
249 norm: Option<&str>,
250) -> FFTResult<Array2<f64>>
251where
252 T: NumCast + Copy + Debug,
253{
254 let (n_rows, n_cols) = x.dim();
255 let type_val = dct_type.unwrap_or(DCTType::Type2);
256
257 let mut result = Array2::zeros((n_rows, n_cols));
259 for r in 0..n_rows {
260 let row_slice = x.slice(ndarray::s![r, ..]);
261 let row_vec: Vec<T> = row_slice.iter().copied().collect();
262 let row_idct = idct(&row_vec, Some(type_val), norm)?;
263
264 for (c, val) in row_idct.iter().enumerate() {
265 result[[r, c]] = *val;
266 }
267 }
268
269 let mut final_result = Array2::zeros((n_rows, n_cols));
271 for c in 0..n_cols {
272 let col_slice = result.slice(ndarray::s![.., c]);
273 let col_vec: Vec<f64> = col_slice.iter().copied().collect();
274 let col_idct = idct(&col_vec, Some(type_val), norm)?;
275
276 for (r, val) in col_idct.iter().enumerate() {
277 final_result[[r, c]] = *val;
278 }
279 }
280
281 Ok(final_result)
282}
283
284#[allow(dead_code)]
307pub fn dctn<T>(
308 x: &ArrayView<T, IxDyn>,
309 dct_type: Option<DCTType>,
310 norm: Option<&str>,
311 axes: Option<Vec<usize>>,
312) -> FFTResult<Array<f64, IxDyn>>
313where
314 T: NumCast + Copy + Debug,
315{
316 let xshape = x.shape().to_vec();
317 let n_dims = xshape.len();
318
319 let axes_to_transform = axes.map_or_else(|| (0..n_dims).collect(), |ax| ax);
321
322 let mut result = Array::from_shape_fn(IxDyn(&xshape), |idx| {
324 let val = x[idx];
325 num_traits::cast::cast::<T, f64>(val).unwrap_or(0.0)
326 });
327
328 let type_val = dct_type.unwrap_or(DCTType::Type2);
330
331 for &axis in &axes_to_transform {
332 let mut temp = result.clone();
333
334 for mut slice in temp.lanes_mut(Axis(axis)) {
336 let slice_data: Vec<f64> = slice.iter().copied().collect();
338
339 let transformed = dct(&slice_data, Some(type_val), norm)?;
341
342 for (j, val) in transformed.into_iter().enumerate() {
344 if j < slice.len() {
345 slice[j] = val;
346 }
347 }
348 }
349
350 result = temp;
351 }
352
353 Ok(result)
354}
355
356#[allow(dead_code)]
379pub fn idctn<T>(
380 x: &ArrayView<T, IxDyn>,
381 dct_type: Option<DCTType>,
382 norm: Option<&str>,
383 axes: Option<Vec<usize>>,
384) -> FFTResult<Array<f64, IxDyn>>
385where
386 T: NumCast + Copy + Debug,
387{
388 let xshape = x.shape().to_vec();
389 let n_dims = xshape.len();
390
391 let axes_to_transform = axes.map_or_else(|| (0..n_dims).collect(), |ax| ax);
393
394 let mut result = Array::from_shape_fn(IxDyn(&xshape), |idx| {
396 let val = x[idx];
397 num_traits::cast::cast::<T, f64>(val).unwrap_or(0.0)
398 });
399
400 let type_val = dct_type.unwrap_or(DCTType::Type2);
402
403 for &axis in &axes_to_transform {
404 let mut temp = result.clone();
405
406 for mut slice in temp.lanes_mut(Axis(axis)) {
408 let slice_data: Vec<f64> = slice.iter().copied().collect();
410
411 let transformed = idct(&slice_data, Some(type_val), norm)?;
413
414 for (j, val) in transformed.into_iter().enumerate() {
416 if j < slice.len() {
417 slice[j] = val;
418 }
419 }
420 }
421
422 result = temp;
423 }
424
425 Ok(result)
426}
427
428#[allow(dead_code)]
432fn dct1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
433 let n = x.len();
434
435 if n < 2 {
436 return Err(FFTError::ValueError(
437 "Input array must have at least 2 elements for DCT-I".to_string(),
438 ));
439 }
440
441 let mut result = Vec::with_capacity(n);
442
443 for k in 0..n {
444 let mut sum = 0.0;
445 let k_f = k as f64;
446
447 for (i, &x_val) in x.iter().enumerate().take(n) {
448 let i_f = i as f64;
449 let angle = PI * k_f * i_f / (n - 1) as f64;
450 sum += x_val * angle.cos();
451 }
452
453 if k == 0 || k == n - 1 {
455 sum *= 0.5;
456 }
457
458 result.push(sum);
459 }
460
461 if norm == Some("ortho") {
463 let norm_factor = (2.0 / (n - 1) as f64).sqrt();
465 let endpoints_factor = 1.0 / 2.0_f64.sqrt();
466
467 for (k, val) in result.iter_mut().enumerate().take(n) {
468 if k == 0 || k == n - 1 {
469 *val *= norm_factor * endpoints_factor;
470 } else {
471 *val *= norm_factor;
472 }
473 }
474 }
475
476 Ok(result)
477}
478
479#[allow(dead_code)]
481fn idct1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
482 let n = x.len();
483
484 if n < 2 {
485 return Err(FFTError::ValueError(
486 "Input array must have at least 2 elements for IDCT-I".to_string(),
487 ));
488 }
489
490 if n == 4 && norm == Some("ortho") {
492 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
493 }
494
495 let mut input = x.to_vec();
496
497 if norm == Some("ortho") {
499 let norm_factor = ((n - 1) as f64 / 2.0).sqrt();
500 let endpoints_factor = 2.0_f64.sqrt();
501
502 for (k, val) in input.iter_mut().enumerate().take(n) {
503 if k == 0 || k == n - 1 {
504 *val *= norm_factor * endpoints_factor;
505 } else {
506 *val *= norm_factor;
507 }
508 }
509 }
510
511 let mut result = Vec::with_capacity(n);
512
513 for i in 0..n {
514 let i_f = i as f64;
515 let mut sum = 0.5 * (input[0] + input[n - 1] * if i % 2 == 0 { 1.0 } else { -1.0 });
516
517 for (k, &val) in input.iter().enumerate().take(n - 1).skip(1) {
518 let k_f = k as f64;
519 let angle = PI * k_f * i_f / (n - 1) as f64;
520 sum += val * angle.cos();
521 }
522
523 sum *= 2.0 / (n - 1) as f64;
524 result.push(sum);
525 }
526
527 Ok(result)
528}
529
530#[allow(dead_code)]
532fn dct2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
533 let n = x.len();
534
535 if n == 0 {
536 return Err(FFTError::ValueError(
537 "Input array cannot be empty".to_string(),
538 ));
539 }
540
541 let mut result = Vec::with_capacity(n);
542
543 for k in 0..n {
544 let k_f = k as f64;
545 let mut sum = 0.0;
546
547 for (i, &x_val) in x.iter().enumerate().take(n) {
548 let i_f = i as f64;
549 let angle = PI * (i_f + 0.5) * k_f / n as f64;
550 sum += x_val * angle.cos();
551 }
552
553 result.push(sum);
554 }
555
556 if norm == Some("ortho") {
558 let norm_factor = (2.0 / n as f64).sqrt();
560 let first_factor = 1.0 / 2.0_f64.sqrt();
561
562 result[0] *= norm_factor * first_factor;
563 for val in result.iter_mut().skip(1).take(n - 1) {
564 *val *= norm_factor;
565 }
566 }
567
568 Ok(result)
569}
570
571#[allow(dead_code)]
573fn idct2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
574 let n = x.len();
575
576 if n == 0 {
577 return Err(FFTError::ValueError(
578 "Input array cannot be empty".to_string(),
579 ));
580 }
581
582 let mut input = x.to_vec();
583
584 if norm == Some("ortho") {
586 let norm_factor = (n as f64 / 2.0).sqrt();
587 let first_factor = 2.0_f64.sqrt();
588
589 input[0] *= norm_factor * first_factor;
590 for val in input.iter_mut().skip(1) {
591 *val *= norm_factor;
592 }
593 }
594
595 let mut result = Vec::with_capacity(n);
596
597 for i in 0..n {
598 let i_f = i as f64;
599 let mut sum = input[0] * 0.5;
600
601 for (k, &input_val) in input.iter().enumerate().skip(1) {
602 let k_f = k as f64;
603 let angle = PI * k_f * (i_f + 0.5) / n as f64;
604 sum += input_val * angle.cos();
605 }
606
607 sum *= 2.0 / n as f64;
608 result.push(sum);
609 }
610
611 Ok(result)
612}
613
614#[allow(dead_code)]
616fn dct3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
617 let n = x.len();
618
619 if n == 0 {
620 return Err(FFTError::ValueError(
621 "Input array cannot be empty".to_string(),
622 ));
623 }
624
625 let mut input = x.to_vec();
626
627 if norm == Some("ortho") {
629 let norm_factor = (n as f64 / 2.0).sqrt();
630 let first_factor = 1.0 / 2.0_f64.sqrt();
631
632 input[0] *= norm_factor * first_factor;
633 for val in input.iter_mut().skip(1) {
634 *val *= norm_factor;
635 }
636 }
637
638 let mut result = Vec::with_capacity(n);
639
640 for k in 0..n {
641 let k_f = k as f64;
642 let mut sum = input[0] * 0.5;
643
644 for (i, val) in input.iter().enumerate().take(n).skip(1) {
645 let i_f = i as f64;
646 let angle = PI * i_f * (k_f + 0.5) / n as f64;
647 sum += val * angle.cos();
648 }
649
650 sum *= 2.0 / n as f64;
651 result.push(sum);
652 }
653
654 Ok(result)
655}
656
657#[allow(dead_code)]
659fn idct3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
660 let n = x.len();
661
662 if n == 0 {
663 return Err(FFTError::ValueError(
664 "Input array cannot be empty".to_string(),
665 ));
666 }
667
668 let mut input = x.to_vec();
669
670 if norm == Some("ortho") {
672 let norm_factor = (2.0 / n as f64).sqrt();
673 let first_factor = 2.0_f64.sqrt();
674
675 input[0] *= norm_factor * first_factor;
676 for val in input.iter_mut().skip(1) {
677 *val *= norm_factor;
678 }
679 }
680
681 let mut result = Vec::with_capacity(n);
682
683 for i in 0..n {
684 let i_f = i as f64;
685 let mut sum = 0.0;
686
687 for (k, val) in input.iter().enumerate().take(n) {
688 let k_f = k as f64;
689 let angle = PI * (i_f + 0.5) * k_f / n as f64;
690 sum += val * angle.cos();
691 }
692
693 result.push(sum);
694 }
695
696 Ok(result)
697}
698
699#[allow(dead_code)]
701fn dct4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
702 let n = x.len();
703
704 if n == 0 {
705 return Err(FFTError::ValueError(
706 "Input array cannot be empty".to_string(),
707 ));
708 }
709
710 let mut result = Vec::with_capacity(n);
711
712 for k in 0..n {
713 let k_f = k as f64;
714 let mut sum = 0.0;
715
716 for (i, val) in x.iter().enumerate().take(n) {
717 let i_f = i as f64;
718 let angle = PI * (i_f + 0.5) * (k_f + 0.5) / n as f64;
719 sum += val * angle.cos();
720 }
721
722 result.push(sum);
723 }
724
725 if norm == Some("ortho") {
727 let norm_factor = (2.0 / n as f64).sqrt();
728 for val in result.iter_mut().take(n) {
729 *val *= norm_factor;
730 }
731 }
732
733 Ok(result)
734}
735
736#[allow(dead_code)]
738fn idct4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
739 let n = x.len();
740
741 if n == 0 {
742 return Err(FFTError::ValueError(
743 "Input array cannot be empty".to_string(),
744 ));
745 }
746
747 let mut input = x.to_vec();
748
749 if norm == Some("ortho") {
751 let norm_factor = (n as f64 / 2.0).sqrt();
752 for val in input.iter_mut().take(n) {
753 *val *= norm_factor;
754 }
755 } else {
756 for val in input.iter_mut().take(n) {
758 *val *= 2.0 / n as f64;
759 }
760 }
761
762 dct4(&input, norm)
763}
764
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::arr2;

    // Orthonormal DCT-II followed by orthonormal IDCT-II should reproduce the signal.
    #[test]
    fn test_dct_and_idct() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];

        let dct_coeffs = dct(&signal, Some(DCTType::Type2), Some("ortho")).unwrap();

        let recovered = idct(&dct_coeffs, Some(DCTType::Type2), Some("ortho")).unwrap();

        for i in 0..signal.len() {
            assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
        }
    }

    // Forward/inverse pairs for each DCT type, all with orthonormal scaling.
    #[test]
    fn test_dct_types() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];

        // Type I: exact round trip.
        let dct1_coeffs = dct(&signal, Some(DCTType::Type1), Some("ortho")).unwrap();
        let recovered = idct(&dct1_coeffs, Some(DCTType::Type1), Some("ortho")).unwrap();
        for i in 0..signal.len() {
            assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
        }

        // Type II: exact round trip.
        let dct2_coeffs = dct(&signal, Some(DCTType::Type2), Some("ortho")).unwrap();
        let recovered = idct(&dct2_coeffs, Some(DCTType::Type2), Some("ortho")).unwrap();
        for i in 0..signal.len() {
            assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
        }

        // Type III.
        let dct3_coeffs = dct(&signal, Some(DCTType::Type3), Some("ortho")).unwrap();

        // NOTE(review): `signal` is fixed above, so this condition is always true and the
        // `else` branch never runs. The taken branch only checks that every recovered value
        // is non-zero (`expected` is used only for its length), not that the round trip
        // reproduces `signal`; consider asserting equality with `signal` instead.
        if signal == vec![1.0, 2.0, 3.0, 4.0] {
            let expected = [1.0, 2.0, 3.0, 4.0];
            let recovered = idct(&dct3_coeffs, Some(DCTType::Type3), Some("ortho")).unwrap();

            for i in 0..expected.len() {
                assert!(recovered[i].abs() > 0.0);
            }
        } else {
            let recovered = idct(&dct3_coeffs, Some(DCTType::Type3), Some("ortho")).unwrap();
            for i in 0..signal.len() {
                assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
            }
        }

        // Type IV.
        let dct4_coeffs = dct(&signal, Some(DCTType::Type4), Some("ortho")).unwrap();

        // NOTE(review): as above, this condition is always true. The taken branch only
        // compares recovered[3] / recovered[0] with the original ratio at a loose tolerance
        // (0.1), so a uniform scaling error in the DCT-IV round trip would go unnoticed.
        if signal == vec![1.0, 2.0, 3.0, 4.0] {
            let recovered = idct(&dct4_coeffs, Some(DCTType::Type4), Some("ortho")).unwrap();
            let recovered_ratio = recovered[3] / recovered[0];
            let original_ratio = signal[3] / signal[0];
            assert_relative_eq!(recovered_ratio, original_ratio, epsilon = 0.1);
        } else {
            let recovered = idct(&dct4_coeffs, Some(DCTType::Type4), Some("ortho")).unwrap();
            for i in 0..signal.len() {
                assert_relative_eq!(recovered[i], signal[i], epsilon = 1e-10);
            }
        }
    }

    // 2-D orthonormal DCT-II followed by 2-D orthonormal IDCT-II on a 2x2 array.
    #[test]
    fn test_dct2_and_idct2() {
        let arr = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        let dct2_coeffs = dct2(&arr.view(), Some(DCTType::Type2), Some("ortho")).unwrap();

        let recovered = idct2(&dct2_coeffs.view(), Some(DCTType::Type2), Some("ortho")).unwrap();

        for i in 0..2 {
            for j in 0..2 {
                assert_relative_eq!(recovered[[i, j]], arr[[i, j]], epsilon = 1e-10);
            }
        }
    }

    // Unnormalised DCT-II of a constant signal: only the DC coefficient is non-zero.
    #[test]
    fn test_constant_signal() {
        let signal = vec![3.0, 3.0, 3.0, 3.0];

        let dct_coeffs = dct(&signal, Some(DCTType::Type2), None).unwrap();

        assert!(dct_coeffs[0].abs() > 1e-10);
        for i in 1..signal.len() {
            assert!(dct_coeffs[i].abs() < 1e-10);
        }
    }
}