1use ndarray::Array1;
2use ndarray_linalg::{Lapack, MatrixLayout, SolveTridiagonal, Tridiagonal};
3use num::One;
4
5use crate::Accelerator;
6use crate::DomainError;
7use crate::InterpType;
8use crate::Interpolation;
9use crate::InterpolationError;
10use crate::types::utils::integ_eval;
11use crate::types::utils::{check_if_inbounds, check1d_data, diff};
12
/// Minimum number of data points accepted by [`Cubic`] and [`CubicPeriodic`]
/// (passed to `check1d_data` and reported by `min_size`).
const MIN_SIZE: usize = 3;
14
/// Cubic spline interpolation with natural boundary conditions.
///
/// The spline coefficients at the first and last knot are fixed to zero, so the
/// second derivative of the interpolant vanishes at both ends of the data.
/// At least three data points are required.
#[doc(alias = "gsl_interp_cspline")]
pub struct Cubic;
27
28impl<T> InterpType<T> for Cubic
29where
30 T: crate::Num + Lapack,
31{
32 type Interpolation = CubicInterp<T>;
33
34 fn build(&self, xa: &[T], ya: &[T]) -> Result<CubicInterp<T>, InterpolationError> {
49 check1d_data(xa, ya, MIN_SIZE)?;
50
51 let sys_size = xa.len() - 2;
53
54 let h = diff(xa);
55 debug_assert_eq!(h.len(), xa.len() - 1);
56
57 let two = T::from(2).unwrap();
58 let three = T::from(3).unwrap();
59
60 let mut g = Vec::<T>::with_capacity(sys_size);
62 let mut diag = Vec::<T>::with_capacity(sys_size);
63 let mut offdiag = Vec::<T>::with_capacity(sys_size);
64 for i in 0..sys_size {
65 g.push(if h[i].is_zero() {
66 T::zero()
67 } else {
68 three * (ya[i + 2] - ya[i + 1]) / h[i + 1] - three * (ya[i + 1] - ya[i]) / h[i]
69 });
70 diag.push(two * (h[i] + h[i + 1]));
71 offdiag.push(h[i + 1]);
72 }
73 offdiag.pop();
77 debug_assert_eq!(diag.len(), offdiag.len() + 1);
78
79 let matrix = Tridiagonal {
80 l: MatrixLayout::C {
81 row: (sys_size) as i32,
82 lda: (sys_size) as i32,
83 },
84 d: diag.clone(),
85 dl: offdiag.clone(),
86 du: offdiag.clone(),
87 };
88
89 let mut c = Vec::<T>::with_capacity(xa.len());
91 c.push(T::zero());
92 if sys_size.is_one() {
93 c.push(g[0] / diag[0]);
94 } else {
95 let coeffs = match matrix.solve_tridiagonal(&Array1::from_vec(g.clone())) {
96 Ok(coeffs) => coeffs,
97 Err(err) => {
98 return Err(InterpolationError::BLASTridiagError {
99 which_interp: "Cubic".into(),
100 source: err,
101 });
102 }
103 };
104 c = [c, coeffs.to_vec()].concat();
105 }
106 c.push(T::zero());
107
108 let state = CubicInterp {
111 c,
112 g,
113 diag,
114 offdiag,
115 };
116 Ok(state)
117 }
118
119 fn name(&self) -> &str {
120 "Cubic"
121 }
122
123 fn min_size(&self) -> usize {
124 MIN_SIZE
125 }
126}
127
/// State of a natural cubic spline, produced by [`Cubic`]'s `build`.
///
/// `c` holds one coefficient per knot; `build` sets the first and last entries to zero.
/// The evaluation routines combine `c[i]` and `c[i + 1]` with the knot data to obtain
/// the cubic polynomial on interval `i`. The other fields keep the tridiagonal system
/// that `build` assembled to compute `c`.
#[allow(dead_code)] // `g`, `diag` and `offdiag` are stored by `build` but not read by the evaluators
pub struct CubicInterp<T>
where
    T: crate::Num,
{
    /// Spline coefficients, one per knot; `2 * c[i]` is the second derivative at knot `i`
    /// (see the `two * c + six * delx * d` expression in `cubic_eval_deriv2`).
    c: Vec<T>,
    /// Right-hand side of the tridiagonal system.
    g: Vec<T>,
    /// Main diagonal of the tridiagonal system.
    diag: Vec<T>,
    /// Off-diagonal of the symmetric tridiagonal system (one shorter than `diag`).
    offdiag: Vec<T>,
}
145
146impl<T> Interpolation<T> for CubicInterp<T>
147where
148 T: crate::Num + Lapack,
149{
150 fn eval(&self, xa: &[T], ya: &[T], x: T, acc: &mut Accelerator) -> Result<T, DomainError> {
151 cubic_eval(xa, ya, &self.c, x, acc)
152 }
153
154 fn eval_deriv(
155 &self,
156 xa: &[T],
157 ya: &[T],
158 x: T,
159 acc: &mut Accelerator,
160 ) -> Result<T, DomainError> {
161 cubic_eval_deriv(xa, ya, &self.c, x, acc)
162 }
163
164 fn eval_deriv2(
165 &self,
166 xa: &[T],
167 ya: &[T],
168 x: T,
169 acc: &mut Accelerator,
170 ) -> Result<T, DomainError> {
171 cubic_eval_deriv2(xa, ya, &self.c, x, acc)
172 }
173
174 fn eval_integ(
175 &self,
176 xa: &[T],
177 ya: &[T],
178 a: T,
179 b: T,
180 acc: &mut Accelerator,
181 ) -> Result<T, DomainError> {
182 cubic_eval_integ(xa, ya, &self.c, a, b, acc)
183 }
184}
185
/// Cubic spline interpolation with periodic boundary conditions.
///
/// The equation for the last knot wraps around to the start of the data, and the
/// coefficient at the first knot is set equal to the one at the last knot.
/// NOTE(review): the data presumably should satisfy `ya[0] == ya[n - 1]`; `build`
/// itself does not check this. At least three data points are required.
#[doc(alias = "gsl_interp_cspline_periodic")]
pub struct CubicPeriodic;
202
203impl<T> InterpType<T> for CubicPeriodic
204where
205 T: crate::Num + Lapack,
206{
207 type Interpolation = CubicPeriodicInterp<T>;
208
209 fn build(&self, xa: &[T], ya: &[T]) -> Result<CubicPeriodicInterp<T>, InterpolationError> {
225 check1d_data(xa, ya, MIN_SIZE)?;
226
227 let sys_size = xa.len() - 1;
229
230 let h = diff(xa);
231 debug_assert!(h.len() == xa.len() - 1);
232
233 let two = T::from(2).unwrap();
234 let three = T::from(3).unwrap();
235
236 let mut c = Vec::<T>::with_capacity(xa.len());
238 let mut g = Vec::<T>::with_capacity(sys_size);
239 let mut diag = Vec::<T>::with_capacity(sys_size);
240 let mut offdiag = Vec::<T>::with_capacity(sys_size);
241
242 if sys_size == 2 {
243 let h0 = xa[1] - xa[0];
244 let h1 = xa[2] - xa[1];
245
246 let a = two * (h0 + h1);
247 let b = h0 + h1;
248
249 g.push(three * ((ya[2] - ya[1]) / h1 - (ya[1] - ya[0]) / h0));
250 g.push(three * ((ya[1] - ya[2]) / h0 - (ya[2] - ya[1]) / h1));
251
252 let det = three * (h0 + h1) * (h0 + h1);
253 c.push((-b * g[0] + a * g[1]) / det);
254 c.push((a * g[0] - b * g[1]) / det);
255 c.push(c[0]);
256 } else {
257 for i in 0..sys_size - 1 {
259 g.push(if h[i].is_zero() {
260 T::zero()
261 } else {
262 three * (ya[i + 2] - ya[i + 1]) / h[i + 1] - three * (ya[i + 1] - ya[i]) / h[i]
263 });
264 diag.push(two * (h[i] + h[i + 1]));
265 offdiag.push(h[i + 1]);
266 }
267
268 let i = sys_size - 1;
270 let hi = xa[i + 1] - xa[i];
271 let hiplus1 = xa[1] - xa[0];
272 let ydiffi = ya[i + 1] - ya[i];
273 let ydiffplus1 = ya[1] - ya[0];
274 let gi = if !hi.is_zero() {
275 T::one() / hi
276 } else {
277 T::zero()
278 };
279 let giplus1 = if !hiplus1.is_zero() {
280 T::one() / hiplus1
281 } else {
282 T::zero()
283 };
284 offdiag.push(hiplus1);
285 diag.push(two * (hiplus1 + hi));
286 g.push(three * (ydiffplus1 * giplus1 - ydiffi * gi));
287 debug_assert_eq!(diag.len(), offdiag.len());
289
290 let matrix = Tridiagonal {
291 l: MatrixLayout::C {
292 row: (sys_size) as i32,
293 lda: (sys_size) as i32,
294 },
295 d: diag.clone(),
296 dl: offdiag.clone(),
297 du: offdiag.clone(),
298 };
299
300 c.push(T::zero());
302 if sys_size.is_one() {
303 c.push(g[0] / diag[0]);
304 } else {
305 let coeffs = match matrix.solve_tridiagonal(&Array1::from_vec(g.clone())) {
308 Ok(coeffs) => coeffs,
309 Err(err) => {
310 return Err(InterpolationError::BLASTridiagError {
311 which_interp: "Cubic Periodic".into(),
312 source: err,
313 });
314 }
315 };
316 c = [c, coeffs.to_vec()].concat();
317 }
318 c[0] = c[sys_size];
319 panic!(
320 "\nNot implemented: Cubic Periodic Splines with more than 3 points require a solver for\
321 cyclically tridiagonal matrices, which is currently not implemented by ndarray_linalg.\n"
322 )
323 }
324
325 let state = CubicPeriodicInterp {
328 c,
329 g,
330 diag,
331 offdiag,
332 };
333 Ok(state)
334 }
335
336 fn name(&self) -> &str {
337 "Cubic Periodic"
338 }
339
340 fn min_size(&self) -> usize {
341 MIN_SIZE
342 }
343}
344
/// State of a periodic cubic spline, produced by [`CubicPeriodic`]'s `build`.
///
/// `c` holds one coefficient per knot, and `c[0]` equals the last entry (periodicity).
/// The evaluation routines combine `c[i]` and `c[i + 1]` with the knot data to obtain
/// the cubic polynomial on interval `i`. `g`, `diag` and `offdiag` keep the linear
/// system assembled by `build`. They are left empty when the three-point closed-form
/// path is taken.
#[allow(dead_code)] // only `c` is read by the evaluation methods
#[doc(alias = "gsl_interp_cspline_periodic")]
pub struct CubicPeriodicInterp<T>
where
    T: crate::Num + Lapack,
{
    /// Spline coefficients, one per knot (`c[0]` duplicates the last entry).
    c: Vec<T>,
    /// Right-hand side of the cyclic system.
    g: Vec<T>,
    /// Main diagonal of the cyclic system.
    diag: Vec<T>,
    /// Off-diagonal of the cyclic system; the last entry is the wrap-around corner term.
    offdiag: Vec<T>,
}
363
364impl<T> Interpolation<T> for CubicPeriodicInterp<T>
365where
366 T: crate::Num + Lapack,
367{
368 fn eval(&self, xa: &[T], ya: &[T], x: T, acc: &mut Accelerator) -> Result<T, DomainError> {
369 cubic_eval(xa, ya, &self.c, x, acc)
370 }
371
372 fn eval_deriv(
373 &self,
374 xa: &[T],
375 ya: &[T],
376 x: T,
377 acc: &mut Accelerator,
378 ) -> Result<T, DomainError> {
379 cubic_eval_deriv(xa, ya, &self.c, x, acc)
380 }
381
382 fn eval_deriv2(
383 &self,
384 xa: &[T],
385 ya: &[T],
386 x: T,
387 acc: &mut Accelerator,
388 ) -> Result<T, DomainError> {
389 cubic_eval_deriv2(xa, ya, &self.c, x, acc)
390 }
391
392 fn eval_integ(
393 &self,
394 xa: &[T],
395 ya: &[T],
396 a: T,
397 b: T,
398 acc: &mut Accelerator,
399 ) -> Result<T, DomainError> {
400 cubic_eval_integ(xa, ya, &self.c, a, b, acc)
401 }
402}
403
404#[inline(always)]
407fn cubic_eval<T>(xa: &[T], ya: &[T], c: &[T], x: T, acc: &mut Accelerator) -> Result<T, DomainError>
408where
409 T: crate::Num + Lapack,
410{
411 check_if_inbounds(xa, x)?;
412 let index = acc.find(xa, x);
413
414 let xlo = xa[index];
415 let xhi = xa[index + 1];
416 let ylo = ya[index];
417 let yhi = ya[index + 1];
418
419 let dx = xhi - xlo;
420 let dy = yhi - ylo;
421
422 let delx = x - xlo;
423 let (b, c, d) = coeff_calc(c, dx, dy, index);
424
425 debug_assert!(dx > T::zero());
426 Ok(ylo + delx * (b + delx * (c + delx * d)))
427}
428
429fn cubic_eval_deriv<T>(
430 xa: &[T],
431 ya: &[T],
432 c: &[T],
433 x: T,
434 acc: &mut Accelerator,
435) -> Result<T, DomainError>
436where
437 T: crate::Num + Lapack,
438{
439 check_if_inbounds(xa, x)?;
440 let index = acc.find(xa, x);
441
442 let xlo = xa[index];
443 let xhi = xa[index + 1];
444 let ylo = ya[index];
445 let yhi = ya[index + 1];
446
447 let dx = xhi - xlo;
448 let dy = yhi - ylo;
449
450 let delx = x - xlo;
451 let (b, c, d) = coeff_calc(c, dx, dy, index);
452
453 let two = T::from(2).unwrap();
454 let three = T::from(3).unwrap();
455
456 debug_assert!(dx > T::zero());
457 Ok(b + delx * (two * c + three * d * delx))
458}
459
460#[inline(always)]
461fn cubic_eval_deriv2<T>(
462 xa: &[T],
463 ya: &[T],
464 c: &[T],
465 x: T,
466 acc: &mut Accelerator,
467) -> Result<T, DomainError>
468where
469 T: crate::Num + Lapack,
470{
471 check_if_inbounds(xa, x)?;
472 let index = acc.find(xa, x);
473
474 let xlo = xa[index];
475 let xhi = xa[index + 1];
476 let ylo = ya[index];
477 let yhi = ya[index + 1];
478
479 let dx = xhi - xlo;
480 let dy = yhi - ylo;
481
482 let delx = x - xlo;
483 let (_, c, d) = coeff_calc(c, dx, dy, index);
484
485 let two = T::from(2).unwrap();
486 let six = T::from(6).unwrap();
487
488 debug_assert!(dx > T::zero());
489 Ok(two * c + six * delx * d)
490}
491
492#[inline(always)]
493fn cubic_eval_integ<T>(
494 xa: &[T],
495 ya: &[T],
496 c: &[T],
497 a: T,
498 b: T,
499 acc: &mut Accelerator,
500) -> Result<T, DomainError>
501where
502 T: crate::Num + Lapack,
503{
504 check_if_inbounds(xa, a)?;
505 check_if_inbounds(xa, b)?;
506 let index_a = acc.find(xa, a);
507 let index_b = acc.find(xa, b);
508
509 let quarter = T::from(0.25).unwrap();
510 let half = T::from(0.5).unwrap();
511 let third = T::from(1.0 / 3.0).unwrap();
512
513 let mut result = T::zero();
514
515 for i in index_a..=index_b {
516 let xlo = xa[i];
517 let xhi = xa[i + 1];
518 let ylo = ya[i];
519 let yhi = ya[i + 1];
520
521 let dx = xhi - xlo;
522 let dy = yhi - ylo;
523
524 if dx.is_zero() {
526 continue;
527 }
528
529 let (bi, ci, di) = coeff_calc(c, dx, dy, i);
530
531 if (i == index_a) | (i == index_b) {
532 let x1 = if i == index_a { a } else { xlo };
533 let x2 = if i == index_b { b } else { xhi };
534 result += integ_eval(ylo, bi, ci, di, xlo, x1, x2);
535 } else {
536 result += dx * (ylo + dx * (half * bi + dx * (third * ci + quarter * di * dx)))
537 }
538 }
539 Ok(result)
540}
541fn coeff_calc<T>(carray: &[T], dx: T, dy: T, index: usize) -> (T, T, T)
543where
544 T: crate::Num + Lapack,
545{
546 let two = T::from(2).unwrap();
547 let three = T::from(3).unwrap();
548
549 let c = carray[index];
550 let cplus1 = carray[index + 1];
551
552 let b = (dy / dx) - dx * (cplus1 + two * c) / three;
553 let d = (cplus1 - c) / (three * dx);
554 (b, c, d)
555}