1use crate::enums::{BaseKind, TransformKind};
3use crate::traits::{
4 BaseElements, BaseFromOrtho, BaseGradient, BaseMatOpDiffmat, BaseMatOpLaplacian,
5 BaseMatOpStencil, BaseSize, BaseTransform,
6};
7use crate::types::{FloatNum, ScalarNum};
8use ndarray::{s, Array2};
9use rustdct::{Dct1, DctPlanner};
10use std::f64::consts::PI;
11use std::sync::Arc;
12
13#[derive(Clone)]
15pub struct Chebyshev<A> {
16 n: usize,
18 m: usize,
20 plan_dct: Arc<dyn Dct1<A>>,
22}
23
24impl<A: FloatNum> Chebyshev<A> {
25 #[must_use]
39 pub fn new(n: usize) -> Self {
40 let mut planner = DctPlanner::<A>::new();
41 let dct1 = planner.plan_dct1(n);
42 Self {
43 n,
44 m: n,
45 plan_dct: Arc::clone(&dct1),
46 }
47 }
48
49 #[allow(clippy::cast_precision_loss)]
65 fn chebyshev_nodes_2nd_kind(n: usize) -> Vec<A> {
66 let m = (n - 1) as f64;
67 (0..n)
68 .map(|k| {
69 let arg = A::from_f64(PI * (m - 2. * k as f64) / (2. * m)).unwrap();
70 -A::one() * arg.sin()
71 })
72 .collect::<Vec<A>>()
73 }
74
75 #[must_use]
77 pub fn nodes(n: usize) -> Vec<A> {
78 Self::chebyshev_nodes_2nd_kind(n)
79 }
80
81 #[must_use]
87 #[allow(clippy::cast_precision_loss)]
88 pub fn _dmat(n: usize, deriv: usize) -> Array2<A> {
89 let mut dmat = Array2::<f64>::zeros((n, n));
90 if deriv == 1 {
91 for p in 0..n {
92 for q in p + 1..n {
93 if (p + q) % 2 != 0 {
94 dmat[[p, q]] = (q * 2) as f64;
95 }
96 }
97 }
98 } else if deriv == 2 {
99 for p in 0..n {
100 for q in p + 2..n {
101 if (p + q) % 2 == 0 {
102 dmat[[p, q]] = (q * (q * q - p * p)) as f64;
103 }
104 }
105 }
106 } else if deriv == 3 {
107 for p in 0..n {
108 let p2 = p * p;
109 for q in p + 3..n {
110 let q2 = q * q;
111 if (p + q) % 2 != 0 {
112 dmat[[p, q]] =
113 (q * (q2 * (q2 - 2) - 2 * q2 * p2 + p2 * p2 - 2 * p2 + 1)) as f64 / 4.;
114 }
115 }
116 }
117 } else if deriv == 4 {
118 for p in 0..n {
119 let (p2, p4) = (p * p, p * p * p * p);
120 for q in p + 4..n {
121 let (q2, q4) = (q * q, q * q * q * q);
122 if (p + q) % 2 == 0 {
123 dmat[[p, q]] = (q
124 * (q2 * (q2 - 4) * (q2 - 4) + 3 * q2 * p4 - p2 * p4 + 8 * p4
125 - 16 * p2
126 - 3 * q4 * p2)) as f64
127 / 24.;
128 }
129 }
130 }
131 } else {
132 todo!()
133 }
134 for d in dmat.slice_mut(s![0, ..]).iter_mut() {
135 *d *= 0.5;
136 }
137 dmat.mapv(|elem| A::from_f64(elem).unwrap())
138 }
139 #[must_use]
156 #[allow(clippy::cast_precision_loss)]
157 pub fn _pinv(n: usize, deriv: usize) -> Array2<A> {
158 let mut pinv = Array2::<f64>::zeros([n, n]);
164 if deriv == 1 {
165 pinv[[1, 0]] = 1.;
166 for i in 2..n {
167 pinv[[i, i - 1]] = 1. / (2. * i as f64); }
169 for i in 1..n - 2 {
170 pinv[[i, i + 1]] = -1. / (2. * i as f64); }
172 } else if deriv == 2 {
173 pinv[[2, 0]] = 0.25;
174 for i in 3..n {
175 pinv[[i, i - 2]] = 1. / (4 * i * (i - 1)) as f64; }
177 for i in 2..n - 2 {
178 pinv[[i, i]] = -1. / (2 * (i * i - 1)) as f64; }
180 for i in 2..n - 4 {
181 pinv[[i, i + 2]] = 1. / (4 * i * (i + 1)) as f64; }
183 } else if deriv == 3 {
184 for i in 3..n {
186 let d = 8 * i * (i - 1) * (i - 2);
187 pinv[[i, i - 3]] = 1. / d as f64;
188 }
189 pinv[[3, 0]] *= 2.;
190 for i in 3..n - 2 {
192 let d = 8 * (i + 1) * (i - 2) * i;
193 pinv[[i, i - 1]] = -3. / d as f64;
194 }
195 for i in 3..n - 4 {
197 let d = 8 * (i - 1) * (i + 2) * i;
198 pinv[[i, i + 1]] = 3. / d as f64;
199 }
200 for i in 3..n - 6 {
202 let d = 8 * i * (i + 1) * (i + 2);
203 pinv[[i, i + 3]] = -1. / d as f64;
204 }
205 } else if deriv == 4 {
206 for i in 4..n {
208 let d = 16 * i * (i - 1) * (i - 2) * (i - 3);
209 pinv[[i, i - 4]] = 1. / d as f64;
210 }
211 pinv[[4, 0]] *= 2.;
212 for i in 4..n - 2 {
214 let d = 4 * (i - 3) * (i - 1) * i * (i + 1);
215 pinv[[i, i - 2]] = -1. / d as f64;
216 }
217 for i in 4..n - 4 {
219 let d = 8 * (i - 2) * (i - 1) * (i + 1) * (i + 2);
220 pinv[[i, i]] = 3. / d as f64;
221 }
222 for i in 4..n - 6 {
224 let d = 4 * (i - 1) * i * (i + 1) * (i + 3);
225 pinv[[i, i + 2]] = -1. / d as f64;
226 }
227 for i in 4..n - 8 {
229 let d = 16 * i * (i + 1) * (i + 2) * (i + 3);
230 pinv[[i, i + 4]] = 1. / d as f64;
231 }
232 } else {
233 todo!()
234 }
235 pinv.mapv(|elem| A::from_f64(elem).unwrap())
237 }
238
239 #[must_use]
244 pub fn _pinv_eye(n: usize, deriv: usize) -> Array2<A> {
245 let pinv_eye = Array2::<f64>::eye(n).slice(s![deriv.., ..]).to_owned();
246 pinv_eye.mapv(|elem| A::from_f64(elem).unwrap())
247 }
248}
249
250impl<A: FloatNum> BaseSize for Chebyshev<A> {
251 #[must_use]
253 fn len_phys(&self) -> usize {
254 self.n
255 }
256
257 #[must_use]
259 fn len_spec(&self) -> usize {
260 self.m
261 }
262
263 #[must_use]
265 fn len_orth(&self) -> usize {
266 self.m
267 }
268}
269
270impl<A: FloatNum> BaseElements for Chebyshev<A> {
271 type RealNum = A;
273
274 fn base_kind(&self) -> BaseKind {
276 BaseKind::Chebyshev
277 }
278
279 fn transform_kind(&self) -> TransformKind {
281 TransformKind::R2r
282 }
283
284 fn coords(&self) -> Vec<A> {
286 Chebyshev::nodes(self.len_phys())
287 }
288}
289
290impl<A: FloatNum> BaseMatOpDiffmat for Chebyshev<A> {
291 type NumType = A;
293
294 fn diffmat(&self, deriv: usize) -> Array2<Self::NumType> {
298 assert!(deriv > 0);
299 Self::_dmat(self.n, deriv)
300 }
301
302 fn diffmat_pinv(&self, deriv: usize) -> (Array2<Self::NumType>, Array2<Self::NumType>) {
314 assert!(deriv > 0);
315 (Self::_pinv(self.n, deriv), Self::_pinv_eye(self.n, deriv))
316 }
317}
318
319impl<A: FloatNum> BaseMatOpStencil for Chebyshev<A> {
320 type NumType = A;
322
323 fn stencil(&self) -> Array2<Self::NumType> {
325 Array2::<A>::eye(self.len_spec())
326 }
327
328 fn stencil_inv(&self) -> Array2<Self::NumType> {
330 Array2::<A>::eye(self.len_spec())
331 }
332}
333
334impl<A: FloatNum> BaseMatOpLaplacian for Chebyshev<A> {
335 type NumType = A;
337
338 fn laplacian(&self) -> Array2<Self::NumType> {
340 self.diffmat(2)
341 }
342
343 fn laplacian_pinv(&self) -> (Array2<Self::NumType>, Array2<Self::NumType>) {
352 self.diffmat_pinv(2)
353 }
354}
355
356impl<A, T> BaseFromOrtho<T> for Chebyshev<A>
357where
358 A: FloatNum,
359 T: ScalarNum,
360{
361 fn to_ortho_slice(&self, indata: &[T], outdata: &mut [T]) {
363 for (y, x) in outdata.iter_mut().zip(indata.iter()) {
364 *y = *x;
365 }
366 }
367
368 fn from_ortho_slice(&self, indata: &[T], outdata: &mut [T]) {
370 for (y, x) in outdata.iter_mut().zip(indata.iter()) {
371 *y = *x;
372 }
373 }
374}
375
376impl<A, T> BaseGradient<T> for Chebyshev<A>
377where
378 A: FloatNum,
379 T: ScalarNum,
380{
381 fn gradient_slice(&self, indata: &[T], outdata: &mut [T], n_times: usize) {
401 assert!(outdata.len() == self.m);
402 for (y, x) in outdata.iter_mut().zip(indata.iter()) {
404 *y = *x;
405 }
406 let two = T::one() + T::one();
408 for _ in 0..n_times {
409 unsafe {
411 *outdata.get_unchecked_mut(0) = *outdata.get_unchecked(1);
412 for i in 1..self.m - 1 {
413 *outdata.get_unchecked_mut(i) =
414 two * T::from_usize(i + 1).unwrap() * *outdata.get_unchecked(i + 1);
415 }
416 *outdata.get_unchecked_mut(self.m - 1) = T::zero();
417 for i in (1..self.m - 2).rev() {
419 *outdata.get_unchecked_mut(i) =
420 *outdata.get_unchecked(i) + *outdata.get_unchecked(i + 2);
421 }
422 *outdata.get_unchecked_mut(0) =
423 *outdata.get_unchecked(0) + *outdata.get_unchecked(2) / two;
424 }
425 }
426 }
427}
428
429impl<A: FloatNum> BaseTransform for Chebyshev<A> {
430 type Physical = A;
431
432 type Spectral = A;
433
434 fn forward_slice(&self, indata: &[Self::Physical], outdata: &mut [Self::Spectral]) {
435 assert!(indata.len() == self.len_phys());
437 assert!(indata.len() == outdata.len());
438
439 let cor = (A::one() + A::one()) * A::one() / A::from(self.n - 1).unwrap();
441
442 for (y, x) in outdata.iter_mut().zip(indata.iter().rev()) {
444 *y = *x * cor;
445 }
446 self.plan_dct.process_dct1(outdata);
448
449 let half = A::from_f64(0.5).unwrap();
451 outdata[0] *= half;
452 outdata[self.n - 1] *= half;
453 }
454
455 fn backward_slice(&self, indata: &[Self::Spectral], outdata: &mut [Self::Physical]) {
456 assert!(indata.len() == self.len_spec());
458 assert!(indata.len() == outdata.len());
459
460 let two = A::one() + A::one();
464 for (i, (y, x)) in outdata.iter_mut().zip(indata.iter()).enumerate() {
465 if i % 2 == 0 {
466 *y = *x;
467 } else {
468 *y = *x * -A::one();
469 }
470 }
471 outdata[0] *= two;
472 outdata[self.m - 1] *= two;
473
474 self.plan_dct.process_dct1(outdata);
476 }
477}
478
#[cfg(test)]
mod test {
    use super::*;
    use crate::utils::approx_eq;

    /// Forward transform of a linear-ish sample on a 4-point grid.
    #[test]
    fn test_chebyshev_transform_1() {
        let basis = Chebyshev::<f64>::new(4);
        let values = vec![1., 2., 3., 4.];
        let mut coeffs = vec![0.; 4];
        basis.forward_slice(&values, &mut coeffs);
        let expected = vec![2.5, 1.33333333, 0., 0.16666667];
        approx_eq(&coeffs, &expected);
    }

    /// The 4th-order quasi-inverse equals the square of the 2nd-order one,
    /// once the first four rows and last four columns are cleared.
    #[test]
    fn test_chebyshev_pinv() {
        let n = 27;
        let basis = Chebyshev::<f64>::new(n);
        let (b2, _) = basis.diffmat_pinv(2);
        let (b4, _) = basis.diffmat_pinv(4);

        let mut b2_squared = b2.dot(&b2);
        b2_squared.slice_mut(ndarray::s![..4, ..]).fill(0.);
        b2_squared.slice_mut(ndarray::s![.., n - 4..]).fill(0.);

        let all_close = b2_squared
            .iter()
            .zip(b4.iter())
            .all(|(x, y)| (x - y).abs() < 1e-6);
        assert!(all_close);
    }
}