use crate::curves::Point2D;
use crate::curves::traits::StatisticalCurve;
use crate::curves::utils::detect_peaks_and_valleys;
use crate::error::{CurveError, InterpolationError, MetricsError};
use crate::geometrics::{
Arithmetic, AxisOperations, BasicMetrics, BiLinearInterpolation, ConstructionMethod,
ConstructionParams, CubicInterpolation, GeometricObject, GeometricTransformations, Interpolate,
InterpolationType, LinearInterpolation, MergeAxisInterpolate, MergeOperation, MetricsExtractor,
RangeMetrics, RiskMetrics, ShapeMetrics, SplineInterpolation, TrendMetrics,
};
use crate::utils::Len;
use crate::visualization::{Graph, GraphData};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon::prelude::*;
use rust_decimal::{Decimal, MathematicalOps};
use rust_decimal_macros::dec;
use serde::{Deserialize, Serialize};
use std::collections::BTreeSet;
use std::fmt::{Display, Formatter};
use std::ops::Index;
use utoipa::ToSchema;
/// A two-dimensional curve stored as an ordered set of [`Point2D`] samples.
///
/// Points live in a `BTreeSet`, so iteration — and positional access through the
/// `Index<usize>` impl — follows `Point2D`'s `Ord` ordering, and equal points collapse.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct Curve {
    /// The sampled points of the curve.
    pub points: BTreeSet<Point2D>,
    /// Span of x values, computed by `calculate_range` when built through `Curve::new`.
    /// `merge` reads `.0` as the start and `.1` as the end of the domain.
    /// NOTE(review): this is a public field and is not recomputed if `points` is mutated directly.
    pub x_range: (Decimal, Decimal),
}
impl Display for Curve {
    /// Writes every point on its own line, using `Point2D`'s `Display`, in set order.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        self.points
            .iter()
            .try_for_each(|point| writeln!(f, "{point}"))
    }
}
impl Default for Curve {
    /// An empty curve whose x-range is `(0, 0)`.
    fn default() -> Self {
        let empty_range = (Decimal::ZERO, Decimal::ZERO);
        Self {
            x_range: empty_range,
            points: BTreeSet::default(),
        }
    }
}
impl Curve {
pub fn new(points: BTreeSet<Point2D>) -> Self {
let x_range = Self::calculate_range(points.iter().map(|p| p.x));
Curve { points, x_range }
}
}
impl Len for Curve {
fn len(&self) -> usize {
self.points.len()
}
fn is_empty(&self) -> bool {
self.points.is_empty()
}
}
impl Graph for Curve {
    /// Converts an owned copy of the curve into plotting data via its `Into<GraphData>` conversion.
    fn graph_data(&self) -> GraphData {
        let snapshot: Curve = self.clone();
        Into::<GraphData>::into(snapshot)
    }
}
impl Graph for Vec<Curve> {
    /// Converts an owned copy of the curve list into plotting data via its `Into<GraphData>` conversion.
    fn graph_data(&self) -> GraphData {
        let curves: Vec<Curve> = self.to_vec();
        Into::<GraphData>::into(curves)
    }
}
impl GeometricObject<Point2D, Decimal> for Curve {
    type Error = CurveError;

    /// Borrows every point of the curve, in set order.
    fn get_points(&self) -> BTreeSet<&Point2D> {
        self.points.iter().collect()
    }

    /// Builds a curve from any values convertible into [`Point2D`]; equal points collapse in the set.
    fn from_vector<T>(points: Vec<T>) -> Self
    where
        T: Into<Point2D> + Clone,
    {
        let points: BTreeSet<Point2D> = points.into_iter().map(|p| p.into()).collect();
        // Same construction path as `Curve::new`, so `x_range` is derived identically.
        Curve::new(points)
    }

    /// Builds a curve either from explicit data or by sampling a parametric function.
    ///
    /// * `FromData { points }` — uses the points as-is.
    /// * `Parametric { f, params }` — with `ConstructionParams::D2 { t_start, t_end, steps }`,
    ///   evaluates `f` at `steps + 1` evenly spaced `t` values from `t_start` to `t_end`
    ///   (in parallel) and collects the resulting points.
    ///
    /// # Errors
    /// * `CurveError::Point2DError` when `FromData` carries no points.
    /// * `CurveError::ConstructionError` when the parameters are not `D2`, when `steps` is
    ///   zero (the step size would divide by zero), or when `f` fails for some `t`.
    fn construct<T>(method: T) -> Result<Self, Self::Error>
    where
        Self: Sized,
        T: Into<ConstructionMethod<Point2D, Decimal>>,
    {
        match method.into() {
            ConstructionMethod::FromData { points } => {
                if points.is_empty() {
                    return Err(CurveError::Point2DError {
                        reason: "Empty points array",
                    });
                }
                Ok(Curve::new(points))
            }
            ConstructionMethod::Parametric { f, params } => {
                let (t_start, t_end, steps) = match params {
                    ConstructionParams::D2 {
                        t_start,
                        t_end,
                        steps,
                    } => (t_start, t_end, steps),
                    _ => {
                        return Err(CurveError::ConstructionError(
                            "Invalid parameters".to_string(),
                        ));
                    }
                };
                // `Decimal` division by zero panics; reject the degenerate request instead.
                if steps == 0 {
                    return Err(CurveError::ConstructionError(
                        "Number of steps must be greater than zero".to_string(),
                    ));
                }
                let step_size = (t_end - t_start) / Decimal::from(steps);
                let points: Result<BTreeSet<Point2D>, CurveError> = (0..=steps)
                    .into_par_iter()
                    .map(|i| {
                        let t = t_start + step_size * Decimal::from(i);
                        f(t).map_err(|e| CurveError::ConstructionError(e.to_string()))
                    })
                    .collect();
                points.map(Curve::new)
            }
        }
    }
}
impl Index<usize> for Curve {
    type Output = Point2D;

    /// Returns the `position`-th point in set order.
    ///
    /// This walks the set from the start, so each access is O(position).
    ///
    /// # Panics
    /// Panics with "Index out of bounds" when `position >= len`.
    fn index(&self, position: usize) -> &Self::Output {
        match self.points.iter().nth(position) {
            Some(point) => point,
            None => panic!("Index out of bounds"),
        }
    }
}
/// Opts `Curve` into `Interpolate` using only the trait's provided (default) methods.
///
/// NOTE(review): `interpolate(x, InterpolationType)` and `find_bracket_points`, used throughout
/// this file, presumably come from those defaults and dispatch to the per-method impls below —
/// confirm against the trait definition.
impl Interpolate<Point2D, Decimal> for Curve {}
impl LinearInterpolation<Point2D, Decimal> for Curve {
    /// Linearly interpolates the y value at `x` between the two points that bracket it.
    ///
    /// If the bracket has zero width (both points share the same x), the left point's y is
    /// returned instead of dividing by zero.
    ///
    /// # Errors
    /// Propagates the `InterpolationError` from `find_bracket_points` (e.g. `x` out of range
    /// or too few points).
    fn linear_interpolate(&self, x: Decimal) -> Result<Point2D, InterpolationError> {
        let (i, j) = self.find_bracket_points(x)?;
        let left = &self[i];
        let right = &self[j];
        let run = right.x - left.x;
        if run.is_zero() {
            // Degenerate bracket (e.g. an exact hit reported as (i, i)): `Decimal` would panic
            // on the division below.
            return Ok(Point2D::new(x, left.y));
        }
        let y = left.y + (x - left.x) * (right.y - left.y) / run;
        Ok(Point2D::new(x, y))
    }
}
impl BiLinearInterpolation<Point2D, Decimal> for Curve {
    /// "Bilinear" interpolation on a 1-D curve using a window of four consecutive points.
    ///
    /// The first two points of the window form the bottom row and the next two the top row.
    /// Both rows are interpolated linearly with the same fractional offset `dx` (measured on
    /// the bottom row), and the midpoint of the two results is returned. An exact x match
    /// returns the stored point.
    ///
    /// The window starts at the left bracket index. Near the right end of the curve it is
    /// slid back to the last four points (mirroring `cubic_interpolate`); previously it ran
    /// past the end and panicked.
    ///
    /// # Errors
    /// `InterpolationError::Bilinear` with fewer than four points. Also propagates the error
    /// from `find_bracket_points` (e.g. `x` out of range).
    fn bilinear_interpolate(&self, x: Decimal) -> Result<Point2D, InterpolationError> {
        let len = self.points.len();
        if len < 4 {
            return Err(InterpolationError::Bilinear(
                "Need at least four points for bilinear interpolation".to_string(),
            ));
        }
        if let Some(point) = self.points.iter().find(|p| p.x == x) {
            return Ok(*point);
        }
        let (i, _j) = self.find_bracket_points(x)?;
        // Keep all four window indices in bounds.
        let start = i.min(len - 4);
        let p11 = &self[start];
        let p12 = &self[start + 1];
        let p21 = &self[start + 2];
        let p22 = &self[start + 3];
        let dx = (x - p11.x) / (p12.x - p11.x);
        let bottom = p11.y + dx * (p12.y - p11.y);
        let top = p21.y + dx * (p22.y - p21.y);
        let y = bottom + (top - bottom) / dec!(2);
        Ok(Point2D::new(x, y))
    }
}
impl CubicInterpolation<Point2D, Decimal> for Curve {
    /// Uniform Catmull-Rom cubic interpolation at `x`.
    ///
    /// Uses four consecutive points around the bracket containing `x`. The window is pinned
    /// to the first four points when `x` is in the first segment, and to the last four when it
    /// is in the last segment. An exact x match returns the stored point.
    ///
    /// # Errors
    /// `InterpolationError::Cubic` with fewer than four points. Also propagates the error from
    /// `find_bracket_points`.
    fn cubic_interpolate(&self, x: Decimal) -> Result<Point2D, InterpolationError> {
        let count = self.len();
        if count < 4 {
            return Err(InterpolationError::Cubic(
                "Need at least four points for cubic interpolation".to_string(),
            ));
        }
        if let Some(hit) = self.points.iter().find(|p| p.x == x) {
            return Ok(*hit);
        }

        let (bracket, _) = self.find_bracket_points(x)?;
        // Index of the first of the four control points.
        let first = match bracket {
            0 => 0,
            b if b == count - 2 => count - 4,
            b => b - 1,
        };
        let (p0, p1, p2, p3) = (
            &self[first],
            &self[first + 1],
            &self[first + 2],
            &self[first + 3],
        );

        let t = (x - p1.x) / (p2.x - p1.x);
        let t_sq = t * t;
        let t_cu = t_sq * t;
        let y = dec!(0.5)
            * (dec!(2) * p1.y
                + (-p0.y + p2.y) * t
                + (dec!(2) * p0.y - dec!(5) * p1.y + dec!(4) * p2.y - p3.y) * t_sq
                + (-p0.y + dec!(3) * p1.y - dec!(3) * p2.y + p3.y) * t_cu);
        Ok(Point2D::new(x, y))
    }
}
impl SplineInterpolation<Point2D, Decimal> for Curve {
    /// Natural cubic-spline interpolation at `x`.
    ///
    /// The knot second derivatives `m` are found by solving the tridiagonal system of a
    /// natural spline (`m[0] = m[n-1] = 0`). The spline is then evaluated on the segment that
    /// contains `x`. An exact x match returns the stored point.
    ///
    /// # Errors
    /// `InterpolationError::Spline` when:
    /// * there are fewer than three points,
    /// * `x` lies outside the sampled range,
    /// * two consecutive points share an x value (the system would divide by zero),
    /// * no segment contains `x`.
    fn spline_interpolate(&self, x: Decimal) -> Result<Point2D, InterpolationError> {
        // Snapshot the ordered points once. `self[i]` walks the set from the start on every
        // call, which made the solver below quadratic in the number of points.
        let pts: Vec<&Point2D> = self.points.iter().collect();
        let n = pts.len();
        if n < 3 {
            return Err(InterpolationError::Spline(
                "Need at least three points for spline interpolation".to_string(),
            ));
        }
        if x < pts[0].x || x > pts[n - 1].x {
            return Err(InterpolationError::Spline(
                "x is outside the range of points".to_string(),
            ));
        }
        if let Some(point) = pts.iter().find(|p| p.x == x) {
            return Ok(**point);
        }
        // A zero-width interval would make `hi`/`hi1` zero and `Decimal` division panic.
        if pts.windows(2).any(|w| w[1].x == w[0].x) {
            return Err(InterpolationError::Spline(
                "Spline interpolation requires distinct x values".to_string(),
            ));
        }

        // Tridiagonal system: a = sub-diagonal, b = diagonal, c = super-diagonal, r = RHS.
        let mut a = vec![Decimal::ZERO; n];
        let mut b = vec![Decimal::ZERO; n];
        let mut c = vec![Decimal::ZERO; n];
        let mut r = vec![Decimal::ZERO; n];
        for i in 1..n - 1 {
            let hi = pts[i].x - pts[i - 1].x;
            let hi1 = pts[i + 1].x - pts[i].x;
            a[i] = hi;
            b[i] = dec!(2) * (hi + hi1);
            c[i] = hi1;
            r[i] = dec!(6) * ((pts[i + 1].y - pts[i].y) / hi1 - (pts[i].y - pts[i - 1].y) / hi);
        }
        // Natural boundary rows: m[0] = m[n-1] = 0.
        b[0] = dec!(1);
        b[n - 1] = dec!(1);

        // Thomas algorithm: forward elimination, then back substitution.
        for i in 1..n - 1 {
            let w = a[i] / b[i - 1];
            b[i] -= w * c[i - 1];
            r[i] -= w * r[i - 1];
        }
        let mut m = vec![Decimal::ZERO; n];
        m[n - 1] = r[n - 1] / b[n - 1];
        for i in (1..n - 1).rev() {
            m[i] = (r[i] - c[i] * m[i + 1]) / b[i];
        }

        let segment = pts
            .windows(2)
            .position(|w| w[0].x <= x && x <= w[1].x)
            .ok_or_else(|| {
                InterpolationError::Spline(
                    "Could not find valid segment for interpolation".to_string(),
                )
            })?;
        let left = pts[segment];
        let right = pts[segment + 1];
        let h = right.x - left.x;
        let dx = right.x - x;
        let dx1 = x - left.x;
        let y = m[segment] * dx * dx * dx / (dec!(6) * h)
            + m[segment + 1] * dx1 * dx1 * dx1 / (dec!(6) * h)
            + (left.y / h - m[segment] * h / dec!(6)) * dx
            + (right.y / h - m[segment + 1] * h / dec!(6)) * dx1;
        Ok(Point2D::new(x, y))
    }
}
impl StatisticalCurve for Curve {
    /// Returns the x coordinate of every point, in set order.
    fn get_x_values(&self) -> Vec<Decimal> {
        let mut abscissas = Vec::with_capacity(self.points.len());
        abscissas.extend(self.points.iter().map(|point| point.x));
        abscissas
    }
}
impl MetricsExtractor for Curve {
    /// Mean, median, mode and population standard deviation of the curve's y values.
    ///
    /// An empty curve yields all-zero metrics. When several y values share the highest
    /// frequency, the smallest of them is reported as the mode, so the result is reproducible.
    fn compute_basic_metrics(&self) -> Result<BasicMetrics, MetricsError> {
        let y_values: Vec<Decimal> = self.points.iter().map(|p| p.y).collect();
        if y_values.is_empty() {
            return Ok(BasicMetrics {
                mean: Decimal::ZERO,
                median: Decimal::ZERO,
                mode: Decimal::ZERO,
                std_dev: Decimal::ZERO,
            });
        }
        let count = Decimal::from(y_values.len());
        let mean = y_values.iter().sum::<Decimal>() / count;

        let mut sorted_values = y_values.clone();
        sorted_values.sort();
        let mid = sorted_values.len() / 2;
        let median = if sorted_values.len().is_multiple_of(2) {
            (sorted_values[mid - 1] + sorted_values[mid]) / Decimal::TWO
        } else {
            sorted_values[mid]
        };

        let mode = {
            // Ordered map + explicit tie-break: a HashMap's randomized iteration order made
            // the reported mode change between runs whenever counts tied.
            let mut freq_map = std::collections::BTreeMap::new();
            for &val in &y_values {
                *freq_map.entry(val).or_insert(0usize) += 1;
            }
            freq_map
                .into_iter()
                .max_by(|(val_a, count_a), (val_b, count_b)| {
                    count_a.cmp(count_b).then_with(|| val_b.cmp(val_a))
                })
                .map(|(val, _)| val)
                .unwrap_or(Decimal::ZERO)
        };

        let variance = y_values
            .iter()
            .map(|&x| (x - mean).powu(2))
            .sum::<Decimal>()
            / count;
        let std_dev = variance.sqrt().unwrap_or(Decimal::ZERO);
        Ok(BasicMetrics {
            mean,
            median,
            mode,
            std_dev,
        })
    }

    /// Skewness, excess kurtosis, peaks and valleys of the y values.
    ///
    /// Fewer than two points yields all-zero/empty metrics. For a flat or nearly flat curve
    /// (std dev below 1e-9) the standardized moments are undefined and are reported as zero;
    /// this used to `panic!`. Peaks and valleys come from `detect_peaks_and_valleys` with
    /// threshold 0.1 and window 2.
    fn compute_shape_metrics(&self) -> Result<ShapeMetrics, MetricsError> {
        let y_values: Vec<Decimal> = self.points.iter().map(|p| p.y).collect();
        if y_values.len() < 2 {
            return Ok(ShapeMetrics {
                skewness: Decimal::ZERO,
                kurtosis: Decimal::ZERO,
                peaks: vec![],
                valleys: vec![],
                inflection_points: vec![],
            });
        }
        let count = Decimal::from(y_values.len());
        let mean = y_values.iter().sum::<Decimal>() / count;
        let centered_values: Vec<Decimal> = y_values.iter().map(|&x| x - mean).collect();
        let variance = centered_values.iter().map(|&x| x.powu(2)).sum::<Decimal>() / count;
        let std_dev = variance.sqrt().unwrap_or(Decimal::ONE);

        let (skewness, kurtosis) = if std_dev < dec!(1e-9) {
            (Decimal::ZERO, Decimal::ZERO)
        } else {
            let skewness = centered_values
                .iter()
                .map(|&x| (x / std_dev).powu(3))
                .sum::<Decimal>()
                / count;
            let kurtosis = centered_values
                .iter()
                .map(|&x| (x / std_dev).powu(4))
                .sum::<Decimal>()
                / count
                - Decimal::from(3);
            (skewness, kurtosis)
        };

        let (peaks, valleys) = detect_peaks_and_valleys(&self.points, dec!(0.1), 2);
        Ok(ShapeMetrics {
            skewness,
            kurtosis,
            peaks,
            valleys,
            inflection_points: vec![],
        })
    }

    /// Lowest/highest points by y, their y range, and simple index-based quartiles.
    ///
    /// Quartiles are the sorted y values at indices `len/4`, `len/2` and `3*len/4`.
    /// An empty curve yields all-zero metrics.
    fn compute_range_metrics(&self) -> Result<RangeMetrics, MetricsError> {
        let (Some(min_point), Some(max_point)) = (
            self.points.iter().min_by(|a, b| a.y.cmp(&b.y)).copied(),
            self.points.iter().max_by(|a, b| a.y.cmp(&b.y)).copied(),
        ) else {
            return Ok(RangeMetrics {
                min: Point2D::new(Decimal::ZERO, Decimal::ZERO),
                max: Point2D::new(Decimal::ZERO, Decimal::ZERO),
                range: Decimal::ZERO,
                quartiles: (Decimal::ZERO, Decimal::ZERO, Decimal::ZERO),
                interquartile_range: Decimal::ZERO,
            });
        };
        let mut y_values: Vec<Decimal> = self.points.iter().map(|p| p.y).collect();
        y_values.sort();
        let len = y_values.len();
        let range = max_point.y - min_point.y;
        let q1 = y_values[len / 4];
        let q2 = y_values[len / 2];
        let q3 = y_values[3 * len / 4];
        let interquartile_range = q3 - q1;
        Ok(RangeMetrics {
            min: min_point,
            max: max_point,
            range,
            quartiles: (q1, q2, q3),
            interquartile_range,
        })
    }

    /// Least-squares line fit (slope, intercept, R²) plus moving averages.
    ///
    /// The moving averages concatenate windows of 3, 5 and 7 points (windows longer than the
    /// curve are skipped). With fewer than two points everything is zero/empty. If every x is
    /// identical the slope is undefined; it is reported as zero instead of dividing by zero.
    fn compute_trend_metrics(&self) -> Result<TrendMetrics, MetricsError> {
        let points: Vec<Point2D> = self.points.iter().copied().collect();
        if points.len() < 2 {
            return Ok(TrendMetrics {
                slope: Decimal::ZERO,
                intercept: Decimal::ZERO,
                r_squared: Decimal::ZERO,
                moving_average: vec![],
            });
        }
        let n = Decimal::from(points.len());
        let x_vals: Vec<Decimal> = points.iter().map(|p| p.x).collect();
        let y_vals: Vec<Decimal> = points.iter().map(|p| p.y).collect();
        let sum_x: Decimal = x_vals.iter().sum();
        let sum_y: Decimal = y_vals.iter().sum();
        let sum_xy: Decimal = x_vals.iter().zip(&y_vals).map(|(x, y)| *x * *y).sum();
        let sum_xx: Decimal = x_vals.iter().map(|x| *x * *x).sum();
        let denominator = n * sum_xx - sum_x * sum_x;
        let slope = if denominator.is_zero() {
            Decimal::ZERO
        } else {
            (n * sum_xy - sum_x * sum_y) / denominator
        };
        let intercept = (sum_y - slope * sum_x) / n;
        let mean_y = sum_y / n;
        let sst: Decimal = y_vals.iter().map(|y| (*y - mean_y).powu(2)).sum();
        let ssr: Decimal = y_vals
            .iter()
            .zip(&x_vals)
            .map(|(y, x)| {
                let y_predicted = slope * *x + intercept;
                (*y - y_predicted).powu(2)
            })
            .sum();
        let r_squared = if sst == Decimal::ZERO {
            Decimal::ONE
        } else {
            Decimal::ONE - (ssr / sst)
        };
        let window_sizes = [3, 5, 7];
        let moving_average: Vec<Point2D> = window_sizes
            .iter()
            .flat_map(|&window| {
                if window > points.len() {
                    vec![]
                } else {
                    points
                        .windows(window)
                        .map(|window_points| {
                            let avg_x = window_points.iter().map(|p| p.x).sum::<Decimal>()
                                / Decimal::from(window_points.len());
                            let avg_y = window_points.iter().map(|p| p.y).sum::<Decimal>()
                                / Decimal::from(window_points.len());
                            Point2D::new(avg_x, avg_y)
                        })
                        .collect::<Vec<Point2D>>()
                }
            })
            .collect();
        Ok(TrendMetrics {
            slope,
            intercept,
            r_squared,
            moving_average,
        })
    }

    /// Volatility (population std dev of y), parametric 95% VaR, expected shortfall, and
    /// simple beta/Sharpe ratios.
    ///
    /// * VaR = mean − 1.645·volatility.
    /// * Expected shortfall is the mean of the y values below VaR (zero if there are none).
    /// * beta = volatility / mean (zero when mean is zero).
    /// * sharpe = mean / volatility.
    ///
    /// An empty or flat curve reports zeros.
    fn compute_risk_metrics(&self) -> Result<RiskMetrics, MetricsError> {
        let y_values: Vec<Decimal> = self.points.iter().map(|p| p.y).collect();
        if y_values.is_empty() {
            return Ok(RiskMetrics {
                volatility: Decimal::ZERO,
                value_at_risk: Decimal::ZERO,
                expected_shortfall: Decimal::ZERO,
                beta: Decimal::ZERO,
                sharpe_ratio: Decimal::ZERO,
            });
        }
        let count = Decimal::from(y_values.len());
        let mean = y_values.iter().sum::<Decimal>() / count;
        // Take the square root of the variance. The previous expression bound `.sqrt()` to
        // the count, yielding sum_sq / sqrt(n) instead of a standard deviation.
        let variance = y_values
            .iter()
            .map(|&x| (x - mean).powu(2))
            .sum::<Decimal>()
            / count;
        let volatility = variance.sqrt().unwrap_or(Decimal::ZERO);
        if volatility == Decimal::ZERO {
            return Ok(RiskMetrics {
                volatility,
                value_at_risk: Decimal::ZERO,
                expected_shortfall: Decimal::ZERO,
                beta: Decimal::ZERO,
                sharpe_ratio: Decimal::ZERO,
            });
        }
        let z_score = dec!(1.645);
        let var = mean - z_score * volatility;
        let tail: Vec<Decimal> = y_values.iter().copied().filter(|&y| y < var).collect();
        let expected_shortfall = if tail.is_empty() {
            Decimal::ZERO
        } else {
            tail.iter().sum::<Decimal>() / Decimal::from(tail.len())
        };
        let beta = if mean != Decimal::ZERO {
            volatility / mean
        } else {
            Decimal::ZERO
        };
        let sharpe_ratio = mean / volatility;
        Ok(RiskMetrics {
            volatility,
            value_at_risk: var,
            expected_shortfall,
            beta,
            sharpe_ratio,
        })
    }
}
impl Arithmetic<Curve> for Curve {
    type Error = CurveError;

    /// Combines several curves point-wise with `operation`.
    ///
    /// The result is sampled at 101 evenly spaced x values spanning the overlap of all input
    /// x-ranges (latest start to earliest end). Each input is evaluated there with cubic
    /// interpolation. A single curve is returned unchanged.
    ///
    /// `Subtract` and `Divide` treat the first curve as the minuend/dividend. A zero divisor
    /// is replaced by `Decimal::MAX`. Products that overflow saturate to `Decimal::MAX` /
    /// `Decimal::MIN` instead of panicking.
    ///
    /// # Errors
    /// Returns a `CurveError` when `curves` is empty, when the x-ranges do not overlap, or
    /// when a curve cannot be interpolated at a sample point.
    fn merge(curves: &[&Curve], operation: MergeOperation) -> Result<Curve, CurveError> {
        // `Decimal`'s `*` panics on overflow. The `Decimal::MAX` stand-in for a zero divisor
        // overflowed as soon as it met a factor larger than one in magnitude.
        fn saturating_mul(a: Decimal, b: Decimal) -> Decimal {
            a.checked_mul(b)
                .unwrap_or(if a.is_sign_negative() == b.is_sign_negative() {
                    Decimal::MAX
                } else {
                    Decimal::MIN
                })
        }

        if curves.is_empty() {
            return Err(CurveError::invalid_parameters(
                "merge_curves",
                "No curves provided for merging",
            ));
        }
        if curves.len() == 1 {
            return Ok(curves[0].clone());
        }
        // Common domain: the latest start and the earliest end across all curves.
        let min_x = curves
            .iter()
            .map(|c| c.x_range.0)
            .max()
            .unwrap_or(Decimal::ZERO);
        let max_x = curves
            .iter()
            .map(|c| c.x_range.1)
            .min()
            .unwrap_or(Decimal::ZERO);
        if min_x >= max_x {
            return Err(CurveError::invalid_parameters(
                "merge_curves",
                "Curves have incompatible x-ranges",
            ));
        }
        let steps = 100;
        let step_size = (max_x - min_x) / Decimal::from(steps);
        let result_points: Result<Vec<Point2D>, CurveError> = (0..=steps)
            .into_par_iter()
            .map(|i| {
                let x = min_x + step_size * Decimal::from(i);
                let y_values: Vec<Decimal> = curves
                    .iter()
                    .map(|curve| {
                        curve
                            .interpolate(x, InterpolationType::Cubic)
                            .map(|point| point.y)
                            .map_err(CurveError::from)
                    })
                    .collect::<Result<Vec<Decimal>, CurveError>>()?;
                let result_y: Decimal = match operation {
                    MergeOperation::Add => y_values.par_iter().sum(),
                    MergeOperation::Subtract => y_values
                        .par_iter()
                        .enumerate()
                        .map(|(i, &val)| if i == 0 { val } else { -val })
                        .reduce(|| Decimal::ZERO, |a, b| a + b),
                    MergeOperation::Multiply => y_values
                        .par_iter()
                        .cloned()
                        .reduce(|| Decimal::ONE, saturating_mul),
                    MergeOperation::Divide => y_values
                        .par_iter()
                        .enumerate()
                        .map(|(i, &val)| {
                            if i == 0 {
                                val
                            } else if val == Decimal::ZERO {
                                Decimal::MAX
                            } else {
                                Decimal::ONE / val
                            }
                        })
                        .reduce(|| Decimal::ONE, saturating_mul),
                    MergeOperation::Max => y_values
                        .par_iter()
                        .cloned()
                        .max()
                        .unwrap_or(Decimal::ZERO),
                    MergeOperation::Min => y_values
                        .par_iter()
                        .cloned()
                        .min()
                        .unwrap_or(Decimal::ZERO),
                };
                Ok(Point2D::new(x, result_y))
            })
            .collect();
        Ok(Curve::from_vector(result_points?))
    }

    /// Convenience wrapper: `merge(&[self, other], operation)`.
    fn merge_with(&self, other: &Curve, operation: MergeOperation) -> Result<Curve, CurveError> {
        Self::merge(&[self, other], operation)
    }
}
impl AxisOperations<Point2D, Decimal> for Curve {
    type Error = CurveError;

    /// `true` when some point has exactly this x coordinate.
    fn contains_point(&self, x: &Decimal) -> bool {
        self.points.iter().any(|point| point.x == *x)
    }

    /// All x coordinates, in set order.
    fn get_index_values(&self) -> Vec<Decimal> {
        self.points.iter().map(|point| point.x).collect::<Vec<_>>()
    }

    /// References to the y values of every point whose x equals `x`.
    fn get_values(&self, x: Decimal) -> Vec<&Decimal> {
        let mut ordinates = Vec::new();
        for point in &self.points {
            if point.x == x {
                ordinates.push(&point.y);
            }
        }
        ordinates
    }

    /// The point whose x is nearest to `x`; the first one in set order wins ties.
    ///
    /// # Errors
    /// `CurveError::Point2DError` when the curve is empty.
    fn get_closest_point(&self, x: &Decimal) -> Result<&Point2D, Self::Error> {
        let target = *x;
        self.points
            .iter()
            .min_by_key(|point| (point.x - target).abs())
            .ok_or(CurveError::Point2DError {
                reason: "No points available",
            })
    }

    /// The first point (in set order) whose x equals `x`, if any.
    fn get_point(&self, x: &Decimal) -> Option<&Point2D> {
        self.points.iter().find(|point| point.x == *x)
    }
}
impl MergeAxisInterpolate<Point2D, Decimal> for Curve
where
    Self: Sized,
{
    /// Resamples `self` and `other` onto the union of their x values.
    ///
    /// Existing points are kept verbatim. Missing x values are filled with `interpolation`.
    /// Both curves are processed x by x in ascending order, so the first failure (self before
    /// other at each x) is the error returned.
    fn merge_axis_interpolate(
        &self,
        other: &Self,
        interpolation: InterpolationType,
    ) -> Result<(Self, Self), Self::Error> {
        let mut axis: Vec<Decimal> = self.merge_axis_index(other).into_iter().collect();
        axis.sort();

        let mut resampled_self = BTreeSet::new();
        let mut resampled_other = BTreeSet::new();
        for x in &axis {
            let own = match self.get_point(x) {
                Some(point) => *point,
                None => self.interpolate(*x, interpolation)?,
            };
            resampled_self.insert(own);

            let theirs = match other.get_point(x) {
                Some(point) => *point,
                None => other.interpolate(*x, interpolation)?,
            };
            resampled_other.insert(theirs);
        }

        Ok((Curve::new(resampled_self), Curve::new(resampled_other)))
    }
}
impl GeometricTransformations<Point2D> for Curve {
    type Error = CurveError;

    /// Shifts every point by `(deltas[0], deltas[1])`.
    ///
    /// # Errors
    /// `CurveError` (invalid parameters) unless exactly two deltas are given.
    fn translate(&self, deltas: Vec<&Decimal>) -> Result<Self, Self::Error> {
        if deltas.len() != 2 {
            return Err(CurveError::invalid_parameters(
                "translate",
                "Expected 2 deltas for 2D translation",
            ));
        }
        let translated_points = self
            .points
            .iter()
            .map(|point| Point2D::new(point.x + deltas[0], point.y + deltas[1]))
            .collect();
        Ok(Curve::new(translated_points))
    }

    /// Multiplies x by `factors[0]` and y by `factors[1]` for every point.
    ///
    /// # Errors
    /// `CurveError` (invalid parameters) unless exactly two factors are given.
    fn scale(&self, factors: Vec<&Decimal>) -> Result<Self, Self::Error> {
        if factors.len() != 2 {
            return Err(CurveError::invalid_parameters(
                "scale",
                "Expected 2 factors for 2D scaling",
            ));
        }
        let scaled_points = self
            .points
            .iter()
            .map(|point| Point2D::new(point.x * factors[0], point.y * factors[1]))
            .collect();
        Ok(Curve::new(scaled_points))
    }

    /// Points of `self` that coincide (within 1e-6 on both axes) with some point of `other`.
    ///
    /// This compares sample points only; crossings between samples are not detected. A point
    /// is pushed once per matching point in `other`.
    fn intersect_with(&self, other: &Self) -> Result<Vec<Point2D>, Self::Error> {
        let mut intersections = Vec::new();
        for p1 in self.get_points() {
            for p2 in other.get_points() {
                if (p1.x - p2.x).abs() < Decimal::new(1, 6)
                    && (p1.y - p2.y).abs() < Decimal::new(1, 6)
                {
                    intersections.push(*p1);
                }
            }
        }
        Ok(intersections)
    }

    /// Estimates dy/dx at `point.x` from the two points bracketing it.
    ///
    /// Fits `y = a·x² + c` through the bracket points and returns `2·a·x`.
    /// NOTE(review): this is exact only for curves of that form; for general curves it is
    /// not the secant slope.
    ///
    /// # Errors
    /// * Propagates the bracket-search error (e.g. `x` out of range).
    /// * `CurveError` (invalid parameters) when the bracket points have equal `x²` (e.g. `-1`
    ///   and `1`); this previously panicked with a division by zero.
    fn derivative_at(&self, point: &Point2D) -> Result<Vec<Decimal>, Self::Error> {
        let (i, j) = self.find_bracket_points(point.x)?;
        let p0 = &self[i];
        let p1 = &self[j];
        let denominator = p1.x * p1.x - p0.x * p0.x;
        if denominator.is_zero() {
            return Err(CurveError::invalid_parameters(
                "derivative_at",
                "Bracketing points have equal squared x; derivative is undefined",
            ));
        }
        let a = (p1.y - p0.y) / denominator;
        let derivative = dec!(2.0) * a * point.x;
        Ok(vec![derivative])
    }

    /// Returns the `(lowest, highest)` points by y.
    ///
    /// # Errors
    /// `CurveError` (invalid parameters) for an empty curve.
    fn extrema(&self) -> Result<(Point2D, Point2D), Self::Error> {
        let (Some(min_point), Some(max_point)) = (
            self.points.iter().min_by(|a, b| a.y.cmp(&b.y)).copied(),
            self.points.iter().max_by(|a, b| a.y.cmp(&b.y)).copied(),
        ) else {
            return Err(CurveError::invalid_parameters(
                "extrema",
                "Curve has no points",
            ));
        };
        Ok((min_point, max_point))
    }

    /// Absolute trapezoidal area between the curve and the horizontal line `y = base_value`.
    ///
    /// Signed contributions above and below the line cancel before the final `abs`. Fewer
    /// than two points yields zero.
    fn measure_under(&self, base_value: &Decimal) -> Result<Decimal, Self::Error> {
        if self.points.len() < 2 {
            return Ok(Decimal::ZERO);
        }
        let mut area = Decimal::ZERO;
        let points: Vec<_> = self.points.iter().collect();
        for pair in points.windows(2) {
            let width = pair[1].x - pair[0].x;
            let height = ((pair[0].y - base_value) + (pair[1].y - base_value)) / Decimal::TWO;
            area += width * height;
        }
        Ok(area.abs())
    }
}
// Tests for `Point2D` construction/conversion and the curve factories in `curves::utils`.
#[cfg(test)]
mod tests_curves {
    use super::*;
    use crate::curves::utils::{create_constant_curve, create_linear_curve};
    use Decimal;
    use positive::{Positive, pos_or_panic};
    use rust_decimal_macros::dec;

    // `Point2D::new` stores Decimal coordinates unchanged.
    #[test]
    fn test_new_with_decimal() {
        let x = dec!(1.5);
        let y = dec!(2.5);
        let point = Point2D::new(x, y);
        assert_eq!(point.x, dec!(1.5));
        assert_eq!(point.y, dec!(2.5));
    }

    // `Point2D::new` also accepts `Positive` inputs; they compare equal to the Decimal values.
    #[test]
    fn test_new_with_positive() {
        let x = pos_or_panic!(1.5_f64);
        let y = pos_or_panic!(2.5_f64);
        let point = Point2D::new(x, y);
        assert_eq!(point.x, dec!(1.5));
        assert_eq!(point.y, dec!(2.5));
    }

    // `to_tuple` is generic over its output type; here the target is Decimal.
    #[test]
    fn test_to_tuple_with_decimal() {
        let point = Point2D::new(dec!(1.5), dec!(2.5));
        let tuple: (Decimal, Decimal) = point.to_tuple().unwrap();
        assert_eq!(tuple, (dec!(1.5), dec!(2.5)));
    }

    // Same conversion, with `Positive` as the target type.
    #[test]
    fn test_to_tuple_with_positive() {
        let point = Point2D::new(dec!(1.5), dec!(2.5));
        let tuple: (Positive, Positive) = point.to_tuple().unwrap();
        assert_eq!(tuple, (pos_or_panic!(1.5), pos_or_panic!(2.5)));
    }

    #[test]
    fn test_from_tuple_with_decimal() {
        let x = dec!(1.5);
        let y = dec!(2.5);
        let point = Point2D::from_tuple(x, y).unwrap();
        assert_eq!(point, Point2D::new(dec!(1.5), dec!(2.5)));
    }

    #[test]
    fn test_from_tuple_with_positive() {
        let x = pos_or_panic!(1.5_f64);
        let y = pos_or_panic!(2.5_f64);
        let point = Point2D::from_tuple(x, y).unwrap();
        assert_eq!(point, Point2D::new(dec!(1.5), dec!(2.5)));
    }

    // Mixing a Decimal x with a Positive y is accepted.
    #[test]
    fn test_new_with_mixed_types() {
        let x = dec!(1.5);
        let y = pos_or_panic!(2.5_f64);
        let point = Point2D::new(x, y);
        assert_eq!(point.x, dec!(1.5));
        assert_eq!(point.y, dec!(2.5));
    }

    // Expects 11 samples, all with y equal to the requested constant.
    #[test]
    fn test_create_constant_curve() {
        let curve = create_constant_curve(dec!(1.0), dec!(2.0), dec!(5.0));
        assert_eq!(curve.get_points().len(), 11);
        for point in curve.get_points() {
            assert_eq!(point.y, dec!(5.0));
        }
    }

    // Expects 11 samples whose y values start at 2.0 and grow by 0.2 per point, in set order.
    #[test]
    fn test_create_linear_curve() {
        let curve = create_linear_curve(dec!(1.0), dec!(2.0), dec!(2.0));
        assert_eq!(curve.get_points().len(), 11);
        let mut slope = dec!(2.0);
        for point in curve.get_points() {
            assert_eq!(point.y, slope);
            slope += dec!(0.2);
        }
    }
}
// Tests for `InterpolationType::Linear` through the `Interpolate` entry point.
#[cfg(test)]
mod tests_linear_interpolate {
    use super::*;
    use crate::geometrics::InterpolationType;
    use rust_decimal_macros::dec;

    // Interpolating exactly at a sample returns that sample.
    #[test]
    fn test_linear_interpolation_exact_points() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(Decimal::ZERO, Decimal::ZERO),
            Point2D::new(Decimal::ONE, Decimal::TWO),
        ]));
        let p0 = curve
            .interpolate(Decimal::ZERO, InterpolationType::Linear)
            .unwrap();
        assert_eq!(p0.x, Decimal::ZERO);
        assert_eq!(p0.y, Decimal::ZERO);
        let p1 = curve
            .interpolate(Decimal::ONE, InterpolationType::Linear)
            .unwrap();
        assert_eq!(p1.x, Decimal::ONE);
        assert_eq!(p1.y, Decimal::TWO);
    }

    // On the line y = 2x, the midpoint x = 0.5 maps to y = 1.
    #[test]
    fn test_linear_interpolation_midpoint() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(Decimal::ZERO, Decimal::ZERO),
            Point2D::new(Decimal::ONE, Decimal::TWO),
        ]));
        let mid = curve
            .interpolate(dec!(0.5), InterpolationType::Linear)
            .unwrap();
        assert_eq!(mid.x, dec!(0.5));
        assert_eq!(mid.y, dec!(1.0));
    }

    #[test]
    fn test_linear_interpolation_quarter_points() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(Decimal::ZERO, Decimal::ZERO),
            Point2D::new(Decimal::ONE, Decimal::TWO),
        ]));
        let p25 = curve
            .interpolate(dec!(0.25), InterpolationType::Linear)
            .unwrap();
        assert_eq!(p25.x, dec!(0.25));
        assert_eq!(p25.y, dec!(0.5));
        let p75 = curve
            .interpolate(dec!(0.75), InterpolationType::Linear)
            .unwrap();
        assert_eq!(p75.x, dec!(0.75));
        assert_eq!(p75.y, dec!(1.5));
    }

    // x values outside the sampled domain must be rejected, not extrapolated.
    #[test]
    fn test_linear_interpolation_out_of_range() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(2.0)),
        ]));
        assert!(
            curve
                .interpolate(dec!(-0.1), InterpolationType::Linear)
                .is_err()
        );
        assert!(
            curve
                .interpolate(dec!(1.1), InterpolationType::Linear)
                .is_err()
        );
    }

    // A single point cannot bracket any x.
    #[test]
    fn test_linear_interpolation_insufficient_points() {
        let curve = Curve::new(BTreeSet::from_iter(vec![Point2D::new(
            dec!(0.0),
            dec!(0.0),
        )]));
        assert!(
            curve
                .interpolate(dec!(0.5), InterpolationType::Linear)
                .is_err()
        );
    }

    // The bracketing segment is used even when y is not monotonic across the curve.
    #[test]
    fn test_linear_interpolation_non_monotonic() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(2.0)),
            Point2D::new(dec!(2.0), dec!(1.0)),
        ]));
        let p15 = curve
            .interpolate(dec!(1.5), InterpolationType::Linear)
            .unwrap();
        assert_eq!(p15.x, dec!(1.5));
        assert_eq!(p15.y, dec!(1.5));
    }
}
// Tests for `InterpolationType::Bilinear` (four-point window scheme on a 1-D curve).
#[cfg(test)]
mod tests_bilinear_interpolate {
    use super::*;
    use crate::geometrics::InterpolationType;
    use rust_decimal_macros::dec;

    // Exact hit returns the stored point; x = 0.5 blends the window (0,0),(1,2) | (2,1),(10,1).
    #[test]
    fn test_bilinear_interpolation() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(Decimal::ZERO, Decimal::ZERO),
            Point2D::new(Decimal::TWO, Decimal::ONE),
            Point2D::new(Decimal::TEN, Decimal::ONE),
            Point2D::new(Decimal::ONE, Decimal::TWO),
        ]));
        let corner = curve
            .interpolate(Decimal::ZERO, InterpolationType::Bilinear)
            .unwrap();
        assert_eq!(corner.x, Decimal::ZERO);
        assert_eq!(corner.y, Decimal::ZERO);
        let mid = curve
            .interpolate(dec!(0.5), InterpolationType::Bilinear)
            .unwrap();
        assert_eq!(mid.x, dec!(0.5));
        assert_eq!(mid.y, dec!(1.0));
    }

    // Points share x values. Whether all four survive in the set depends on `Point2D`'s Ord
    // (NOTE(review): confirm); either way x outside [0, 1] must be an error.
    #[test]
    fn test_bilinear_interpolation_out_of_range() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
            Point2D::new(dec!(0.0), dec!(1.0)),
            Point2D::new(dec!(1.0), dec!(2.0)),
        ]));
        assert!(
            curve
                .interpolate(dec!(-0.5), InterpolationType::Bilinear)
                .is_err()
        );
        assert!(
            curve
                .interpolate(dec!(1.5), InterpolationType::Bilinear)
                .is_err()
        );
    }

    // Bilinear requires at least four points.
    #[test]
    fn test_bilinear_interpolation_insufficient_points() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
            Point2D::new(dec!(2.0), dec!(2.0)),
        ]));
        assert!(
            curve
                .interpolate(dec!(0.5), InterpolationType::Bilinear)
                .is_err()
        );
    }

    #[test]
    fn test_bilinear_interpolation_quarter_points() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(Decimal::ZERO, Decimal::ZERO),
            Point2D::new(Decimal::ONE, Decimal::ONE),
            Point2D::new(Decimal::TWO, Decimal::ONE),
            Point2D::new(Decimal::TEN, Decimal::TWO),
        ]));
        let p25 = curve
            .interpolate(dec!(0.25), InterpolationType::Bilinear)
            .unwrap();
        assert_eq!(p25.x, dec!(0.25));
        assert_eq!(p25.y, dec!(0.75));
        let p75 = curve
            .interpolate(dec!(0.75), InterpolationType::Bilinear)
            .unwrap();
        assert_eq!(p75.x, dec!(0.75));
        assert_eq!(p75.y, dec!(1.25));
    }

    // Just outside the domain on both sides must error.
    #[test]
    fn test_bilinear_interpolation_boundaries() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
            Point2D::new(dec!(0.0), dec!(1.0)),
            Point2D::new(dec!(1.0), dec!(2.0)),
        ]));
        assert!(
            curve
                .interpolate(dec!(-0.1), InterpolationType::Bilinear)
                .is_err()
        );
        assert!(
            curve
                .interpolate(dec!(1.1), InterpolationType::Bilinear)
                .is_err()
        );
    }

    #[test]
    fn test_out_of_range() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(Decimal::ZERO, Decimal::ZERO),
            Point2D::new(Decimal::ONE, Decimal::ONE),
            Point2D::new(Decimal::ZERO, Decimal::ONE),
            Point2D::new(Decimal::ONE, Decimal::TWO),
        ]));
        assert!(
            curve
                .interpolate(dec!(-1), InterpolationType::Bilinear)
                .is_err()
        );
        assert!(
            curve
                .interpolate(Decimal::TWO, InterpolationType::Bilinear)
                .is_err()
        );
    }
}
// Tests for `InterpolationType::Cubic` on samples of y = x².
#[cfg(test)]
mod tests_cubic_interpolate {
    use super::*;
    use crate::geometrics::InterpolationType;
    use rust_decimal_macros::dec;
    use tracing::info;

    // An exact sample is returned unchanged.
    #[test]
    fn test_cubic_interpolation_exact_points() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
            Point2D::new(dec!(2.0), dec!(4.0)),
            Point2D::new(dec!(3.0), dec!(9.0)),
        ]));
        let p1 = curve
            .interpolate(dec!(1.0), InterpolationType::Cubic)
            .unwrap();
        assert_eq!(p1.x, dec!(1.0));
        assert_eq!(p1.y, dec!(1.0));
    }

    // Between x = 1 and x = 2 the result must lie strictly between their y values.
    #[test]
    fn test_cubic_interpolation_midpoints() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
            Point2D::new(dec!(2.0), dec!(4.0)),
            Point2D::new(dec!(3.0), dec!(9.0)),
        ]));
        let mid = curve
            .interpolate(dec!(1.5), InterpolationType::Cubic)
            .unwrap();
        assert_eq!(mid.x, dec!(1.5));
        assert!(mid.y > dec!(1.0) && mid.y < dec!(4.0));
    }

    // Cubic requires at least four points.
    #[test]
    fn test_cubic_interpolation_insufficient_points() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
            Point2D::new(dec!(2.0), dec!(4.0)),
        ]));
        assert!(
            curve
                .interpolate(dec!(1.5), InterpolationType::Cubic)
                .is_err()
        );
    }

    #[test]
    fn test_cubic_interpolation_out_of_range() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
            Point2D::new(dec!(2.0), dec!(4.0)),
            Point2D::new(dec!(3.0), dec!(9.0)),
        ]));
        assert!(
            curve
                .interpolate(dec!(-0.5), InterpolationType::Cubic)
                .is_err()
        );
        assert!(
            curve
                .interpolate(dec!(3.5), InterpolationType::Cubic)
                .is_err()
        );
    }

    // Increasing input data should give increasing values at the segment midpoints,
    // covering the first, middle and last windows.
    #[test]
    fn test_cubic_interpolation_monotonicity() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
            Point2D::new(dec!(2.0), dec!(4.0)),
            Point2D::new(dec!(3.0), dec!(9.0)),
        ]));
        let p1 = curve
            .interpolate(dec!(0.5), InterpolationType::Cubic)
            .unwrap();
        let p2 = curve
            .interpolate(dec!(1.5), InterpolationType::Cubic)
            .unwrap();
        let p3 = curve
            .interpolate(dec!(2.5), InterpolationType::Cubic)
            .unwrap();
        assert!(p1.y < p2.y);
        assert!(p2.y < p3.y);
        info!("p1: {:?}, p2: {:?}, p3: {:?}", p1, p2, p3);
    }
}
// Tests for `InterpolationType::Spline` (natural cubic spline).
#[cfg(test)]
mod tests_spline_interpolate {
    use super::*;
    use crate::geometrics::InterpolationType;
    use rust_decimal_macros::dec;

    // An exact sample is returned unchanged.
    #[test]
    fn test_spline_interpolation_exact_points() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
            Point2D::new(dec!(2.0), dec!(4.0)),
            Point2D::new(dec!(3.0), dec!(9.0)),
        ]));
        let p1 = curve
            .interpolate(dec!(1.0), InterpolationType::Spline)
            .unwrap();
        assert_eq!(p1.x, dec!(1.0));
        assert_eq!(p1.y, dec!(1.0));
    }

    #[test]
    fn test_spline_interpolation_midpoints() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
            Point2D::new(dec!(2.0), dec!(4.0)),
            Point2D::new(dec!(3.0), dec!(9.0)),
        ]));
        let mid = curve
            .interpolate(dec!(1.5), InterpolationType::Spline)
            .unwrap();
        assert_eq!(mid.x, dec!(1.5));
        assert!(mid.y > dec!(1.0) && mid.y < dec!(4.0));
    }

    // Spline requires at least three points.
    #[test]
    fn test_spline_interpolation_insufficient_points() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
        ]));
        assert!(
            curve
                .interpolate(dec!(0.5), InterpolationType::Spline)
                .is_err()
        );
    }

    #[test]
    fn test_spline_interpolation_out_of_range() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
            Point2D::new(dec!(2.0), dec!(4.0)),
        ]));
        assert!(
            curve
                .interpolate(dec!(-0.5), InterpolationType::Spline)
                .is_err()
        );
        assert!(
            curve
                .interpolate(dec!(2.5), InterpolationType::Spline)
                .is_err()
        );
    }

    // Around x = 1.5 the spline must be strictly increasing, and consecutive first
    // differences (step 0.01) must change by less than 0.001.
    #[test]
    fn test_spline_interpolation_smoothness() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
            Point2D::new(dec!(2.0), dec!(4.0)),
            Point2D::new(dec!(3.0), dec!(9.0)),
        ]));
        let p1 = curve
            .interpolate(dec!(1.48), InterpolationType::Spline)
            .unwrap();
        let p2 = curve
            .interpolate(dec!(1.49), InterpolationType::Spline)
            .unwrap();
        let p3 = curve
            .interpolate(dec!(1.50), InterpolationType::Spline)
            .unwrap();
        let p4 = curve
            .interpolate(dec!(1.51), InterpolationType::Spline)
            .unwrap();
        let p5 = curve
            .interpolate(dec!(1.52), InterpolationType::Spline)
            .unwrap();
        assert!(p1.y < p2.y);
        assert!(p2.y < p3.y);
        assert!(p3.y < p4.y);
        assert!(p4.y < p5.y);
        let d1 = p2.y - p1.y;
        let d2 = p3.y - p2.y;
        let d3 = p4.y - p3.y;
        let d4 = p5.y - p4.y;
        assert!((d2 - d1).abs() < dec!(0.001));
        assert!((d3 - d2).abs() < dec!(0.001));
        assert!((d4 - d3).abs() < dec!(0.001));
    }
}
#[cfg(test)]
mod tests_curve_arithmetic {
    use super::*;
    use crate::curves::utils::create_linear_curve;
    use crate::geometrics::InterpolationType;

    /// Abscissae at which merged curves are compared against the expected
    /// combination of their inputs: both endpoints and the midpoint.
    fn sample_xs() -> [Decimal; 3] {
        [dec!(0.0), dec!(5.0), dec!(10.0)]
    }

    /// Cubic-interpolated ordinate of `curve` at `x`; panics if interpolation fails.
    fn cubic_y(curve: &Curve, x: Decimal) -> Decimal {
        curve.interpolate(x, InterpolationType::Cubic).unwrap().y
    }

    /// Merges `curves` with `operation` and checks that, at every sample
    /// abscissa, the merged ordinate lies within 0.001 of `expected(ys)`.
    ///
    /// `ys` holds the input curves' cubic-interpolated ordinates, in the same
    /// order as `curves`. Returning `None` from `expected` skips that abscissa
    /// (used to avoid division by zero).
    fn assert_merge_matches<F>(curves: &[&Curve], operation: MergeOperation, expected: F)
    where
        F: Fn(&[Decimal]) -> Option<Decimal>,
    {
        let merged = Curve::merge(curves, operation).unwrap();
        for x in sample_xs() {
            let ys: Vec<Decimal> = curves.iter().map(|curve| cubic_y(curve, x)).collect();
            if let Some(expected_y) = expected(&ys) {
                let actual_y = cubic_y(&merged, x);
                assert!(
                    (actual_y - expected_y).abs() < dec!(0.001),
                    "Failed at x = {}, expected {}, got {}",
                    x,
                    expected_y,
                    actual_y
                );
            }
        }
    }

    #[test]
    fn test_merge_curves_add() {
        let curve1 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(1.0));
        let curve2 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(2.0));
        assert_merge_matches(&[&curve1, &curve2], MergeOperation::Add, |ys| {
            Some(ys[0] + ys[1])
        });
    }

    #[test]
    fn test_merge_curves_subtract() {
        let curve1 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(3.0));
        let curve2 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(1.0));
        assert_merge_matches(&[&curve1, &curve2], MergeOperation::Subtract, |ys| {
            Some(ys[0] - ys[1])
        });
    }

    #[test]
    fn test_merge_curves_multiply() {
        let curve1 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(2.0));
        let curve2 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(3.0));
        assert_merge_matches(&[&curve1, &curve2], MergeOperation::Multiply, |ys| {
            Some(ys[0] * ys[1])
        });
    }

    #[test]
    fn test_merge_curves_divide() {
        let curve1 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(6.0));
        let curve2 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(2.0));
        // Abscissae where the divisor is zero have no defined quotient; skip them.
        assert_merge_matches(&[&curve1, &curve2], MergeOperation::Divide, |ys| {
            (ys[1] != Decimal::ZERO).then(|| ys[0] / ys[1])
        });
    }

    #[test]
    fn test_merge_curves_max() {
        let curve1 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(2.0));
        let curve2 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(3.0));
        assert_merge_matches(&[&curve1, &curve2], MergeOperation::Max, |ys| {
            Some(ys[0].max(ys[1]))
        });
    }

    #[test]
    fn test_merge_curves_min() {
        let curve1 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(2.0));
        let curve2 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(3.0));
        assert_merge_matches(&[&curve1, &curve2], MergeOperation::Min, |ys| {
            Some(ys[0].min(ys[1]))
        });
    }

    #[test]
    fn test_merge_with_single_operation() {
        let curve1 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(2.0));
        let curve2 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(3.0));
        let result = curve1.merge_with(&curve2, MergeOperation::Add).unwrap();
        let merged_result = Curve::merge(&[&curve1, &curve2], MergeOperation::Add).unwrap();
        assert_eq!(result.points.len(), merged_result.points.len());
        // Walk both ordered point sets together rather than indexing
        // `result[i]`, which on a `BTreeSet`-backed curve may cost a linear
        // scan per access.
        for (lhs, rhs) in result.points.iter().zip(merged_result.points.iter()) {
            assert!((lhs.x - rhs.x).abs() < dec!(0.001));
            assert!((lhs.y - rhs.y).abs() < dec!(0.001));
        }
    }

    #[test]
    fn test_merge_curves_error_handling() {
        // Merging an empty list of curves is an error.
        let result = Curve::merge(&[], MergeOperation::Add);
        assert!(result.is_err());
        // Curves built over different spans can still be merged.
        let curve1 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(1.0));
        let curve2 = create_linear_curve(dec!(5.0), dec!(15.0), dec!(2.0));
        let result = Curve::merge(&[&curve1, &curve2], MergeOperation::Add);
        assert!(result.is_ok());
    }

    #[test]
    fn test_merge_multiple_curves() {
        let curve1 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(1.0));
        let curve2 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(2.0));
        let curve3 = create_linear_curve(dec!(0.0), dec!(10.0), dec!(3.0));
        assert_merge_matches(&[&curve1, &curve2, &curve3], MergeOperation::Add, |ys| {
            Some(ys[0] + ys[1] + ys[2])
        });
    }
}
#[cfg(test)]
mod tests_extended {
    use super::*;
    use crate::error::CurveError::OperationError;
    use crate::error::{ChainError, OperationErrorKind};
    use crate::geometrics::{ConstructionMethod, ConstructionParams};

    #[test]
    fn test_construct_from_data_empty() {
        let result = Curve::construct(ConstructionMethod::FromData {
            points: BTreeSet::new(),
        });
        assert!(result.is_err());
        let error = result.unwrap_err();
        match error {
            CurveError::Point2DError { reason } => {
                assert_eq!(reason, "Empty points array");
            }
            // Report the actual variant so a failure is diagnosable.
            other => {
                panic!("Unexpected error type: {other:?}");
            }
        }
    }

    #[test]
    fn test_construct_parametric_valid() {
        let f = |t: Decimal| Ok(Point2D::new(t, t * dec!(2.0)));
        let params = ConstructionParams::D2 {
            t_start: Decimal::ZERO,
            t_end: dec!(10.0),
            steps: 10,
        };
        let result = Curve::construct(ConstructionMethod::Parametric {
            f: Box::new(f),
            params,
        });
        assert!(result.is_ok());
    }

    #[test]
    fn test_construct_parametric_invalid_function() {
        // A parametric function that always fails; its message must be
        // carried through into the construction error.
        let f = |_t: Decimal| -> Result<Point2D, ChainError> {
            Err(ChainError::DynError {
                message: "Function evaluation failed".to_string(),
            })
        };
        let params = ConstructionParams::D2 {
            t_start: Decimal::ZERO,
            t_end: dec!(10.0),
            steps: 10,
        };
        let result = Curve::construct(ConstructionMethod::Parametric {
            f: Box::new(f),
            params,
        });
        assert!(result.is_err());
        let error = result.unwrap_err();
        match error {
            CurveError::ConstructionError(reason) => {
                assert!(reason.contains("Function evaluation failed"));
            }
            other => {
                panic!("Unexpected error type: {other:?}");
            }
        }
    }

    #[test]
    fn test_segment_not_found_error() {
        let segment: Option<Point2D> = None;
        let result: Result<Point2D, CurveError> = segment.ok_or_else(|| CurveError::StdError {
            reason: "Could not find valid segment for interpolation".to_string(),
        });
        assert!(result.is_err());
        let error = result.unwrap_err();
        match error {
            CurveError::StdError { reason } => {
                assert_eq!(reason, "Could not find valid segment for interpolation");
            }
            other => {
                panic!("Unexpected error type: {other:?}");
            }
        }
    }

    #[test]
    fn test_compute_basic_metrics_placeholder() {
        // Basic metrics of an empty curve succeed and report a zero mean.
        let curve = Curve {
            points: BTreeSet::new(),
            x_range: (Default::default(), Default::default()),
        };
        let metrics = curve.compute_basic_metrics();
        assert!(metrics.is_ok());
        let metrics = metrics.unwrap();
        assert_eq!(metrics.mean, Decimal::ZERO);
    }

    #[test]
    fn test_single_curve_return() {
        let curve = Curve {
            points: BTreeSet::new(),
            x_range: (Default::default(), Default::default()),
        };
        // A fixed-size array is enough here; a heap `Vec` is unnecessary.
        let curves = [curve.clone()];
        let result = if curves.len() == 1 {
            Ok(curve)
        } else {
            Err(CurveError::invalid_parameters(
                "merge_curves",
                "Invalid state",
            ))
        };
        assert!(result.is_ok());
    }

    #[test]
    fn test_merge_curves_invalid_x_range() {
        let min_x = dec!(10.0);
        let max_x = dec!(5.0);
        let result = if min_x >= max_x {
            Err(CurveError::invalid_parameters(
                "merge_curves",
                "Curves have incompatible x-ranges",
            ))
        } else {
            Ok(())
        };
        assert!(result.is_err());
        let error = result.unwrap_err();
        match error {
            OperationError(OperationErrorKind::InvalidParameters { operation, reason }) => {
                assert_eq!(operation, "merge_curves");
                assert_eq!(reason, "Curves have incompatible x-ranges");
            }
            other => {
                panic!("Unexpected error type: {other:?}");
            }
        }
    }
}
#[cfg(test)]
mod tests_curve_metrics {
    use super::*;
    use crate::assert_decimal_eq;
    use rust_decimal_macros::dec;
    use std::collections::BTreeSet;

    /// Five collinear points on y = 2x for x = 0, 1, 2, 3, 4.
    fn create_linear_curve() -> Curve {
        let points = BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(2.0)),
            Point2D::new(dec!(2.0), dec!(4.0)),
            Point2D::new(dec!(3.0), dec!(6.0)),
            Point2D::new(dec!(4.0), dec!(8.0)),
        ]);
        Curve::new(points)
    }

    /// Points `(x, x² mod 7)` for `x` in `0..=20`. The ordinates rise and
    /// fall repeatedly, so the curve is clearly non-linear.
    fn create_non_linear_curve() -> Curve {
        // Build through `Curve::new` so `x_range` is derived from the points,
        // instead of a hard-coded `(0, 0)` placeholder that contradicts the
        // data (the abscissae span 0..=20).
        Curve::new(
            (0..=20)
                .map(|x| Point2D {
                    x: Decimal::from(x),
                    y: Decimal::from(x * x % 7),
                })
                .collect(),
        )
    }

    /// Three points with the same ordinate, y = 5.
    fn create_constant_curve() -> Curve {
        let points = BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(5.0)),
            Point2D::new(dec!(1.0), dec!(5.0)),
            Point2D::new(dec!(2.0), dec!(5.0)),
        ]);
        Curve::new(points)
    }

    #[test]
    fn test_basic_metrics() {
        let linear_curve = create_linear_curve();
        let basic_metrics = linear_curve.compute_basic_metrics().unwrap();
        assert_decimal_eq!(basic_metrics.mean, dec!(4.0), dec!(0.001));
        assert_decimal_eq!(basic_metrics.median, dec!(4.0), dec!(0.001));
        assert_decimal_eq!(basic_metrics.std_dev, dec!(2.82842712), dec!(0.001));
        let constant_curve = create_constant_curve();
        let constant_metrics = constant_curve.compute_basic_metrics().unwrap();
        assert_decimal_eq!(constant_metrics.mean, dec!(5.0), dec!(0.001));
        assert_decimal_eq!(constant_metrics.median, dec!(5.0), dec!(0.001));
        assert_decimal_eq!(constant_metrics.std_dev, dec!(0.0), dec!(0.001));
    }

    #[test]
    fn test_shape_metrics() {
        let linear_curve = create_linear_curve();
        let shape_metrics = linear_curve.compute_shape_metrics().unwrap();
        assert!(
            shape_metrics.skewness.abs() < dec!(0.5),
            "Skewness for linear curve should be very close to 0, got {}",
            shape_metrics.skewness
        );
        assert!(
            shape_metrics.kurtosis.abs() < dec!(2.0),
            "Kurtosis for linear curve should be close to 0, got {}",
            shape_metrics.kurtosis
        );
        let non_linear_curve = create_non_linear_curve();
        let non_linear_metrics = non_linear_curve.compute_shape_metrics().unwrap();
        assert!(
            non_linear_metrics.skewness.abs() > dec!(0.003),
            "Non-linear curve should have significant skewness, got {}",
            non_linear_metrics.skewness
        );
        assert!(
            non_linear_metrics.kurtosis.abs() > dec!(1.0),
            "Non-linear curve should have significant kurtosis, got {}",
            non_linear_metrics.kurtosis
        );
        assert!(
            !non_linear_metrics.peaks.is_empty(),
            "Peaks should be detected"
        );
        assert!(
            !non_linear_metrics.valleys.is_empty(),
            "Valleys should be detected"
        );
    }

    #[test]
    fn test_range_metrics() {
        let linear_curve = create_linear_curve();
        let range_metrics = linear_curve.compute_range_metrics().unwrap();
        assert_decimal_eq!(range_metrics.min.y, dec!(0.0), dec!(0.001));
        assert_decimal_eq!(range_metrics.max.y, dec!(8.0), dec!(0.001));
        assert_decimal_eq!(range_metrics.range, dec!(8.0), dec!(0.001));
        let constant_curve = create_constant_curve();
        let constant_range_metrics = constant_curve.compute_range_metrics().unwrap();
        assert_decimal_eq!(constant_range_metrics.min.y, dec!(5.0), dec!(0.001));
        assert_decimal_eq!(constant_range_metrics.max.y, dec!(5.0), dec!(0.001));
        assert_decimal_eq!(constant_range_metrics.range, dec!(0.0), dec!(0.001));
    }

    #[test]
    fn test_trend_metrics() {
        let linear_curve = create_linear_curve();
        let trend_metrics = linear_curve.compute_trend_metrics().unwrap();
        assert_decimal_eq!(trend_metrics.slope, dec!(2.0), dec!(0.001));
        assert_decimal_eq!(trend_metrics.intercept, dec!(0.0), dec!(0.001));
        assert_decimal_eq!(trend_metrics.r_squared, dec!(1.0), dec!(0.001));
        let non_linear_curve = create_non_linear_curve();
        let non_linear_trend_metrics = non_linear_curve.compute_trend_metrics().unwrap();
        assert!(non_linear_trend_metrics.r_squared < dec!(1.0));
        assert!(!non_linear_trend_metrics.moving_average.is_empty());
    }

    #[test]
    fn test_constant_curve_risk_metrics() {
        let constant_curve = create_constant_curve();
        let risk_metrics = constant_curve.compute_risk_metrics().unwrap();
        assert_eq!(risk_metrics.volatility, dec!(0.0));
        assert_eq!(risk_metrics.beta, dec!(0.0));
        assert_eq!(risk_metrics.sharpe_ratio, dec!(0.0));
    }

    #[test]
    fn test_risk_metrics() {
        let linear_curve = create_linear_curve();
        let risk_metrics = linear_curve.compute_risk_metrics().unwrap();
        assert!(
            risk_metrics.volatility > dec!(0.0),
            "Volatility debe ser mayor a cero."
        );
        assert!(
            risk_metrics.value_at_risk != dec!(0.0),
            "Value at Risk no debe ser cero."
        );
        assert!(risk_metrics.beta != dec!(0.0), "Beta no debe ser cero.");
    }

    #[test]
    fn test_risk_metrics_bis() {
        let linear_curve = create_linear_curve();
        let risk_metrics = linear_curve.compute_risk_metrics().unwrap();
        assert!(risk_metrics.volatility > dec!(0.0));
        assert!(risk_metrics.value_at_risk != dec!(0.0));
        assert!(risk_metrics.beta != dec!(0.0));
        let constant_curve = create_constant_curve();
        let constant_risk_metrics = constant_curve.compute_risk_metrics().unwrap();
        assert_decimal_eq!(constant_risk_metrics.volatility, dec!(0.0), dec!(0.001));
    }

    #[test]
    fn test_edge_cases() {
        // Every metric family must succeed on an empty curve...
        let empty_curve = Curve::new(BTreeSet::new());
        assert!(empty_curve.compute_basic_metrics().is_ok());
        assert!(empty_curve.compute_shape_metrics().is_ok());
        assert!(empty_curve.compute_range_metrics().is_ok());
        assert!(empty_curve.compute_trend_metrics().is_ok());
        assert!(empty_curve.compute_risk_metrics().is_ok());
        // ...and on a curve with a single point.
        let single_point_curve = Curve::new(BTreeSet::from_iter(vec![Point2D::new(
            dec!(1.0),
            dec!(1.0),
        )]));
        assert!(single_point_curve.compute_basic_metrics().is_ok());
        assert!(single_point_curve.compute_shape_metrics().is_ok());
        assert!(single_point_curve.compute_range_metrics().is_ok());
        assert!(single_point_curve.compute_trend_metrics().is_ok());
        assert!(single_point_curve.compute_risk_metrics().is_ok());
    }
}
#[cfg(test)]
mod tests_merge_axis_interpolate {
    use super::*;
    use crate::curves::utils::create_linear_curve;
    use crate::geometrics::InterpolationType;
    use rust_decimal_macros::dec;

    /// Aligns two linear curves built by `create_linear_curve` using `method`.
    /// Checks that both aligned outputs share the same x-range and index
    /// values, and that each holds exactly 10 points.
    fn check_merge_axis(method: InterpolationType) {
        let left = create_linear_curve(dec!(0.0), dec!(10.0), dec!(0.5));
        let right = create_linear_curve(dec!(4.0), dec!(20.0), dec!(1.0));
        let outcome = left.merge_axis_interpolate(&right, method);
        assert!(outcome.is_ok());
        let (aligned_left, aligned_right) = outcome.unwrap();

        let (left_lo, left_hi) = aligned_left.x_range;
        let (right_lo, right_hi) = aligned_right.x_range;
        assert_eq!(left_lo, right_lo);
        assert_eq!(left_hi, right_hi);

        for aligned in [&aligned_left, &aligned_right] {
            assert_eq!(aligned.points.len(), 10);
        }

        assert_eq!(aligned_left.x_range, aligned_right.x_range);
        assert_eq!(
            aligned_left.get_index_values(),
            aligned_right.get_index_values()
        );
    }

    #[test]
    fn test_merge_axis_interpolate_linear() {
        check_merge_axis(InterpolationType::Linear);
    }

    #[test]
    fn test_merge_axis_interpolate_cubic() {
        check_merge_axis(InterpolationType::Cubic);
    }
}
#[cfg(test)]
mod tests_geometric_transformations {
    use super::*;
    use rust_decimal_macros::dec;

    /// Fixture: the four points of y = x² sampled at x = 0, 1, 2, 3.
    fn create_test_curve() -> Curve {
        let samples = [
            (dec!(0.0), dec!(0.0)),
            (dec!(1.0), dec!(1.0)),
            (dec!(2.0), dec!(4.0)),
            (dec!(3.0), dec!(9.0)),
        ];
        Curve::new(samples.iter().map(|&(x, y)| Point2D::new(x, y)).collect())
    }

    mod test_translate {
        use super::*;

        /// Successive y-differences between neighbouring points, in set order.
        fn y_steps(curve: &Curve) -> Vec<Decimal> {
            curve
                .points
                .iter()
                .zip(curve.points.iter().skip(1))
                .map(|(a, b)| b.y - a.y)
                .collect()
        }

        #[test]
        fn test_translate_positive() {
            let curve = create_test_curve();
            let shifted = curve.translate(vec![&dec!(2.0), &dec!(3.0)]).unwrap();
            let moved: Vec<_> = shifted.points.iter().collect();
            let expected = [
                (dec!(2.0), dec!(3.0)),
                (dec!(3.0), dec!(4.0)),
                (dec!(4.0), dec!(7.0)),
                (dec!(5.0), dec!(12.0)),
            ];
            for (idx, &(want_x, want_y)) in expected.iter().enumerate() {
                assert_eq!(moved[idx].x, want_x);
                assert_eq!(moved[idx].y, want_y);
            }
        }

        #[test]
        fn test_translate_negative() {
            let curve = create_test_curve();
            let shifted = curve.translate(vec![&dec!(-1.0), &dec!(-2.0)]).unwrap();
            let moved: Vec<_> = shifted.points.iter().collect();
            let first = moved[0];
            assert_eq!(first.x, dec!(-1.0));
            assert_eq!(first.y, dec!(-2.0));
        }

        #[test]
        fn test_translate_zero() {
            let original = create_test_curve();
            let unchanged = original.translate(vec![&dec!(0.0), &dec!(0.0)]).unwrap();
            assert_eq!(original.points, unchanged.points);
        }

        #[test]
        fn test_translate_wrong_dimensions() {
            // A single offset is not enough for a 2-D translation.
            let curve = create_test_curve();
            let outcome = curve.translate(vec![&dec!(1.0)]);
            assert!(outcome.is_err());
        }

        #[test]
        fn test_translate_preserves_shape() {
            let curve = create_test_curve();
            let shifted = curve.translate(vec![&dec!(1.0), &dec!(1.0)]).unwrap();
            assert_eq!(y_steps(&curve), y_steps(&shifted));
        }
    }

    mod test_scale {
        use super::*;

        #[test]
        fn test_scale_uniform() {
            let curve = create_test_curve();
            let scaled = curve.scale(vec![&dec!(2.0), &dec!(2.0)]).unwrap();
            let stretched: Vec<_> = scaled.points.iter().collect();
            let expected = [
                (dec!(0.0), dec!(0.0)),
                (dec!(2.0), dec!(2.0)),
                (dec!(4.0), dec!(8.0)),
                (dec!(6.0), dec!(18.0)),
            ];
            for (idx, &(want_x, want_y)) in expected.iter().enumerate() {
                assert_eq!(stretched[idx].x, want_x);
                assert_eq!(stretched[idx].y, want_y);
            }
        }

        #[test]
        fn test_scale_non_uniform() {
            let curve = create_test_curve();
            let scaled = curve.scale(vec![&dec!(2.0), &dec!(3.0)]).unwrap();
            let stretched: Vec<_> = scaled.points.iter().collect();
            let second = stretched[1];
            assert_eq!(second.x, dec!(2.0));
            assert_eq!(second.y, dec!(3.0));
        }

        #[test]
        fn test_scale_zero() {
            let curve = create_test_curve();
            let collapsed = curve.scale(vec![&dec!(0.0), &dec!(0.0)]).unwrap();
            // No point may remain away from the origin.
            let any_off_origin = collapsed
                .points
                .iter()
                .any(|p| p.x != dec!(0.0) || p.y != dec!(0.0));
            assert!(!any_off_origin);
        }

        #[test]
        fn test_scale_wrong_dimensions() {
            let curve = create_test_curve();
            let outcome = curve.scale(vec![&dec!(2.0)]);
            assert!(outcome.is_err());
        }

        #[test]
        fn test_scale_negative() {
            let curve = create_test_curve();
            let mirrored = curve.scale(vec![&dec!(-1.0), &dec!(-1.0)]).unwrap();
            // Negating both axes reverses the x order of the points.
            let second = &mirrored[1];
            assert_eq!(second.x, dec!(-2.0));
            assert_eq!(second.y, dec!(-4.0));
            let last = &mirrored[3];
            assert_eq!(last.x, dec!(0.0));
            assert_eq!(last.y, dec!(0.0));
        }
    }

    mod test_intersect_with {
        use super::*;

        #[test]
        fn test_curves_intersect() {
            let parabola = create_test_curve();
            let chord = Curve::new(BTreeSet::from([
                Point2D::new(dec!(0.0), dec!(0.0)),
                Point2D::new(dec!(1.0), dec!(2.0)),
            ]));
            let hits = parabola.intersect_with(&chord).unwrap();
            assert_eq!(hits.len(), 1);
        }

        #[test]
        fn test_no_intersection() {
            let parabola = create_test_curve();
            let far_away = Curve::new(BTreeSet::from([
                Point2D::new(dec!(10.0), dec!(10.0)),
                Point2D::new(dec!(11.0), dec!(11.0)),
            ]));
            let hits = parabola.intersect_with(&far_away).unwrap();
            assert!(hits.is_empty());
        }

        #[test]
        fn test_multiple_intersections() {
            let first = create_test_curve();
            let twin = create_test_curve();
            let hits = first.intersect_with(&twin).unwrap();
            assert_eq!(hits.len(), first.points.len());
        }

        #[test]
        fn test_self_intersection() {
            let curve = create_test_curve();
            let hits = curve.intersect_with(&curve).unwrap();
            assert_eq!(hits.len(), curve.points.len());
        }

        #[test]
        fn test_empty_curves() {
            let blank_a = Curve::new(BTreeSet::new());
            let blank_b = Curve::new(BTreeSet::new());
            let hits = blank_a.intersect_with(&blank_b).unwrap();
            assert!(hits.is_empty());
        }
    }

    mod test_derivative_at {
        use super::*;

        #[test]
        fn test_linear_derivative() {
            let line = Curve::new(BTreeSet::from([
                Point2D::new(dec!(0.0), dec!(0.0)),
                Point2D::new(dec!(1.0), dec!(1.0)),
            ]));
            let probe = Point2D::new(dec!(0.5), dec!(0.5));
            let slope = line.derivative_at(&probe).unwrap();
            assert_eq!(slope[0], dec!(1.0));
        }

        #[test]
        fn test_quadratic_derivative() {
            let curve = create_test_curve();
            // (x, y, expected dy/dx) on y = x².
            let probes = [
                (dec!(1.0), dec!(1.0), dec!(2.0)),
                (dec!(2.0), dec!(4.0), dec!(4.0)),
            ];
            for (x, y, expected) in probes {
                let slope = curve.derivative_at(&Point2D::new(x, y)).unwrap();
                assert_eq!(slope[0], expected);
            }
        }

        #[test]
        fn test_out_of_range() {
            let curve = create_test_curve();
            let probe = Point2D::new(dec!(10.0), dec!(0.0));
            assert!(curve.derivative_at(&probe).is_err());
        }

        #[test]
        fn test_at_endpoint() {
            let curve = create_test_curve();
            let probe = Point2D::new(dec!(0.0), dec!(0.0));
            let slope = curve.derivative_at(&probe).unwrap();
            assert_eq!(slope[0], dec!(0.0));
        }

        #[test]
        fn test_vertical_line() {
            let upright = Curve::new(BTreeSet::from([
                Point2D::new(dec!(1.0), dec!(0.0)),
                Point2D::new(dec!(1.0), dec!(1.0)),
            ]));
            let probe = Point2D::new(dec!(1.0), dec!(0.5));
            assert!(upright.derivative_at(&probe).is_err());
        }
    }

    mod test_extrema {
        use super::*;

        #[test]
        fn test_find_extrema() {
            let curve = create_test_curve();
            let (lowest, highest) = curve.extrema().unwrap();
            assert_eq!(lowest.y, dec!(0.0));
            assert_eq!(highest.y, dec!(9.0));
        }

        #[test]
        fn test_empty_curve() {
            let blank = Curve::new(BTreeSet::new());
            assert!(blank.extrema().is_err());
        }

        #[test]
        fn test_single_point() {
            let lone = Curve::new(BTreeSet::from([Point2D::new(dec!(1.0), dec!(1.0))]));
            let (lowest, highest) = lone.extrema().unwrap();
            assert_eq!(lowest, highest);
        }

        #[test]
        fn test_flat_curve() {
            let level = Curve::new(BTreeSet::from([
                Point2D::new(dec!(0.0), dec!(1.0)),
                Point2D::new(dec!(1.0), dec!(1.0)),
            ]));
            let (lowest, highest) = level.extrema().unwrap();
            assert_eq!(lowest.y, highest.y);
        }

        #[test]
        fn test_multiple_extrema() {
            let tent = Curve::new(BTreeSet::from([
                Point2D::new(dec!(0.0), dec!(0.0)),
                Point2D::new(dec!(1.0), dec!(1.0)),
                Point2D::new(dec!(2.0), dec!(0.0)),
            ]));
            let (lowest, highest) = tent.extrema().unwrap();
            assert_eq!(lowest.y, dec!(0.0));
            assert_eq!(highest.y, dec!(1.0));
        }
    }

    mod test_measure_under {
        use super::*;

        #[test]
        fn test_area_under_linear() {
            let ramp = Curve::new(BTreeSet::from([
                Point2D::new(dec!(0.0), dec!(0.0)),
                Point2D::new(dec!(1.0), dec!(1.0)),
            ]));
            let area = ramp.measure_under(&dec!(0.0)).unwrap();
            assert_eq!(area, dec!(0.5));
        }

        #[test]
        fn test_area_empty_curve() {
            let blank = Curve::new(BTreeSet::new());
            let area = blank.measure_under(&dec!(0.0)).unwrap();
            assert_eq!(area, dec!(0.0));
        }

        #[test]
        fn test_area_single_point() {
            let lone = Curve::new(BTreeSet::from([Point2D::new(dec!(1.0), dec!(1.0))]));
            let area = lone.measure_under(&dec!(0.0)).unwrap();
            assert_eq!(area, dec!(0.0));
        }

        #[test]
        fn test_area_with_base_value() {
            let curve = create_test_curve();
            let from_zero = curve.measure_under(&dec!(0.0)).unwrap();
            let from_one = curve.measure_under(&dec!(1.0)).unwrap();
            // Raising the baseline must shrink the measured area.
            assert!(from_one < from_zero);
        }

        #[test]
        fn test_negative_area() {
            let below_axis = Curve::new(BTreeSet::from([
                Point2D::new(dec!(0.0), dec!(-1.0)),
                Point2D::new(dec!(1.0), dec!(-2.0)),
            ]));
            let area = below_axis.measure_under(&dec!(0.0)).unwrap();
            // The reported area is positive even though the curve lies below zero.
            assert!(area > Decimal::ZERO);
        }
    }
}
#[cfg(test)]
mod tests_curve_serde {
    use super::*;
    use rust_decimal_macros::dec;

    /// Builds a curve directly from `(x, y)` pairs and an explicit `x_range`,
    /// bypassing `Curve::new` so the stored range is exactly what is given.
    fn curve_from(pairs: &[(Decimal, Decimal)], x_range: (Decimal, Decimal)) -> Curve {
        let points: BTreeSet<Point2D> = pairs.iter().map(|&(x, y)| Point2D { x, y }).collect();
        Curve { points, x_range }
    }

    /// Serializes `curve` to compact JSON and parses it back.
    fn roundtrip(curve: &Curve) -> Curve {
        let json = serde_json::to_string(curve).unwrap();
        serde_json::from_str(&json).unwrap()
    }

    /// Fixture: three points, (1,2), (3,4), (5,6), with x-range [1, 5].
    fn create_test_curve() -> Curve {
        curve_from(
            &[
                (dec!(1.0), dec!(2.0)),
                (dec!(3.0), dec!(4.0)),
                (dec!(5.0), dec!(6.0)),
            ],
            (dec!(1.0), dec!(5.0)),
        )
    }

    #[test]
    fn test_basic_serialization() {
        let curve = create_test_curve();
        let restored = roundtrip(&curve);
        assert_eq!(curve.points, restored.points);
        assert_eq!(curve.x_range, restored.x_range);
    }

    #[test]
    fn test_pretty_print() {
        let curve = create_test_curve();
        let pretty = serde_json::to_string_pretty(&curve).unwrap();
        assert!(pretty.contains('\n'));
        assert!(pretty.contains(" "));
        let restored: Curve = serde_json::from_str(&pretty).unwrap();
        assert_eq!(curve.points, restored.points);
    }

    #[test]
    fn test_empty_curve() {
        let blank = Curve {
            points: BTreeSet::new(),
            x_range: (dec!(0.0), dec!(0.0)),
        };
        let restored = roundtrip(&blank);
        assert!(restored.points.is_empty());
        assert_eq!(restored.x_range, (dec!(0.0), dec!(0.0)));
    }

    #[test]
    fn test_curve_with_negative_values() {
        let curve = curve_from(
            &[(dec!(-1.0), dec!(-2.0)), (dec!(-3.0), dec!(-4.0))],
            (dec!(-3.0), dec!(-1.0)),
        );
        let restored = roundtrip(&curve);
        assert_eq!(curve.points, restored.points);
        assert_eq!(curve.x_range, restored.x_range);
    }

    #[test]
    fn test_curve_with_high_precision() {
        let curve = curve_from(
            &[
                (dec!(1.12345678901234567890), dec!(2.12345678901234567890)),
                (dec!(3.12345678901234567890), dec!(4.12345678901234567890)),
            ],
            (dec!(1.12345678901234567890), dec!(3.12345678901234567890)),
        );
        let restored = roundtrip(&curve);
        assert_eq!(curve.points, restored.points);
        assert_eq!(curve.x_range, restored.x_range);
    }

    #[test]
    fn test_invalid_json() {
        // Missing `x_range`, non-object points, and a non-array `x_range`.
        let malformed = [
            r#"{"points": []}"#,
            r#"{"points": [1, 2, 3], "x_range": [0, 1]}"#,
            r#"{"points": [], "x_range": "invalid"}"#,
        ];
        for json_str in malformed {
            assert!(serde_json::from_str::<Curve>(json_str).is_err());
        }
    }

    #[test]
    fn test_json_structure() {
        let curve = create_test_curve();
        let text = serde_json::to_string(&curve).unwrap();
        let value: serde_json::Value = serde_json::from_str(&text).unwrap();
        assert!(value.is_object());
        let points = value.get("points");
        let x_range = value.get("x_range");
        assert!(points.is_some());
        assert!(x_range.is_some());
        assert!(points.unwrap().is_array());
        let bounds = x_range.unwrap().as_array().unwrap();
        assert_eq!(bounds.len(), 2);
    }

    #[test]
    fn test_multiple_curves() {
        let first = create_test_curve();
        let second = Curve {
            x_range: (dec!(6.0), dec!(10.0)),
            ..create_test_curve()
        };
        let curves = vec![first, second];
        let text = serde_json::to_string(&curves).unwrap();
        let restored: Vec<Curve> = serde_json::from_str(&text).unwrap();
        assert_eq!(curves.len(), restored.len());
        for (before, after) in curves.iter().zip(restored.iter()) {
            assert_eq!(before.points, after.points);
        }
    }

    #[test]
    fn test_ordering_preservation() {
        let curve = create_test_curve();
        let restored = roundtrip(&curve);
        let before: Vec<_> = curve.points.into_iter().collect();
        let after: Vec<_> = restored.points.into_iter().collect();
        assert_eq!(before, after);
    }

    #[test]
    fn test_curve_with_extremes() {
        let curve = curve_from(
            &[(Decimal::MAX, Decimal::MAX), (Decimal::MIN, Decimal::MIN)],
            (Decimal::MIN, Decimal::MAX),
        );
        let restored = roundtrip(&curve);
        assert_eq!(curve.points, restored.points);
        assert_eq!(curve.x_range, restored.x_range);
    }
}
#[cfg(test)]
mod tests_curve_display_and_default {
    use crate::curves::{Curve, Point2D};
    use rust_decimal::Decimal;
    use rust_decimal_macros::dec;
    use std::collections::BTreeSet;

    #[test]
    fn test_curve_display() {
        let curve = Curve::new(BTreeSet::from([
            Point2D::new(dec!(1.0), dec!(2.0)),
            Point2D::new(dec!(3.0), dec!(4.0)),
        ]));
        // `Display` should render every point of the curve.
        let rendered = curve.to_string();
        for needle in ["(x: 1.0, y: 2.0)", "(x: 3.0, y: 4.0)"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_curve_default() {
        let Curve { points, x_range } = Curve::default();
        assert!(points.is_empty());
        assert_eq!(x_range, (Decimal::ZERO, Decimal::ZERO));
    }
}
#[cfg(test)]
mod tests_curve_len_and_geometric {
    use crate::curves::{Curve, Point2D};
    use crate::error::CurveError;
    use crate::geometrics::{ConstructionMethod, ConstructionParams, GeometricObject};
    use crate::utils::Len;
    use rust_decimal_macros::dec;
    use std::collections::BTreeSet;

    #[test]
    fn test_curve_len() {
        let empty = Curve::default();
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());

        let single = Curve::new(BTreeSet::from([Point2D::new(dec!(1.0), dec!(2.0))]));
        assert_eq!(single.len(), 1);
        assert!(!single.is_empty());
    }

    #[test]
    fn test_curve_get_points() {
        let curve = Curve::new(BTreeSet::from([
            Point2D::new(dec!(1.0), dec!(2.0)),
            Point2D::new(dec!(3.0), dec!(4.0)),
        ]));
        let retrieved = curve.get_points();
        assert_eq!(retrieved.len(), 2);
        for (want_x, want_y) in [(dec!(1.0), dec!(2.0)), (dec!(3.0), dec!(4.0))] {
            assert!(retrieved.iter().any(|p| p.x == want_x && p.y == want_y));
        }
    }

    #[test]
    fn test_construct_method_error() {
        // A parametric function that always fails must surface its message
        // inside `ConstructionError`.
        let failing = Curve::construct(ConstructionMethod::Parametric {
            f: Box::new(|_| Err("Test error".into())),
            params: ConstructionParams::D2 {
                t_start: dec!(0.0),
                t_end: dec!(1.0),
                steps: 10,
            },
        });
        assert!(failing.is_err());
        if let Err(CurveError::ConstructionError(msg)) = failing {
            assert!(msg.contains("Test error"));
        } else {
            panic!("Expected ConstructionError");
        }

        // Three-dimensional parameters are rejected for a 2-D curve.
        let wrong_params = Curve::construct(ConstructionMethod::Parametric {
            f: Box::new(|t| Ok(Point2D::new(t, t * dec!(2.0)))),
            params: ConstructionParams::D3 {
                x_start: dec!(0.0),
                x_end: dec!(1.0),
                y_start: dec!(0.0),
                y_end: dec!(1.0),
                x_steps: 10,
                y_steps: 10,
            },
        });
        assert!(wrong_params.is_err());
        if let Err(CurveError::ConstructionError(msg)) = wrong_params {
            assert_eq!(msg, "Invalid parameters");
        } else {
            panic!("Expected ConstructionError");
        }
    }
}
#[cfg(test)]
mod tests_interpolation_edge_cases {
    use crate::curves::{Curve, Point2D};
    use crate::geometrics::{AxisOperations, Interpolate, InterpolationType};
    use rust_decimal_macros::dec;
    use std::collections::BTreeSet;
    use tracing::info;

    /// The four points of y = x² at x = 0, 1, 2, 3.
    fn squares_curve() -> Curve {
        Curve::new(BTreeSet::from([
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
            Point2D::new(dec!(2.0), dec!(4.0)),
            Point2D::new(dec!(3.0), dec!(9.0)),
        ]))
    }

    #[test]
    fn test_cubic_interpolation_edge_cases() {
        let curve = squares_curve();
        // Probes inside the first and last segments.
        for x in [dec!(0.25), dec!(2.75)] {
            assert!(curve.interpolate(x, InterpolationType::Cubic).is_ok());
        }
    }

    #[test]
    fn test_spline_interpolation_edge_cases() {
        let curve = squares_curve();
        // Interpolating exactly on a knot returns that knot.
        let on_knot = curve.interpolate(dec!(1.0), InterpolationType::Spline);
        assert!(on_knot.is_ok());
        let knot = on_knot.unwrap();
        assert_eq!(knot.x, dec!(1.0));
        assert_eq!(knot.y, dec!(1.0));
        // Between knots interpolation must also succeed.
        assert!(curve
            .interpolate(dec!(1.5), InterpolationType::Spline)
            .is_ok());
    }

    #[test]
    fn test_axis_operations() {
        let curve = Curve::new(BTreeSet::from([
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
            Point2D::new(dec!(2.0), dec!(4.0)),
        ]));

        assert_eq!(
            curve.get_index_values(),
            vec![dec!(0.0), dec!(1.0), dec!(2.0)]
        );

        let values = curve.get_values(dec!(1.0));
        assert_eq!(values.len(), 1);
        assert_eq!(*values[0], dec!(1.0));

        let nearest = curve.get_closest_point(&dec!(0.9)).unwrap();
        assert_eq!(nearest.x, dec!(1.0));
        assert_eq!(nearest.y, dec!(1.0));

        info!("{:?}", curve);

        let hit = curve.get_point(&dec!(1.0));
        assert!(hit.is_some());
        assert_eq!(hit.unwrap().y, dec!(1.0));

        let miss = curve.get_point(&dec!(1.5));
        assert!(miss.is_none());
    }
}
#[cfg(test)]
mod tests_axis_merge_and_transformations {
    use super::*;
    use rust_decimal_macros::dec;

    /// Merging the index axes of two curves keeps only the x-values they share:
    /// {0, 1} merged with {1, 2} yields {1}.
    #[test]
    fn test_merge_axis_index() {
        let curve1 = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
        ]));
        let curve2 = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(1.0), dec!(2.0)),
            Point2D::new(dec!(2.0), dec!(4.0)),
        ]));
        let merged = curve1.merge_axis_index(&curve2);
        assert_eq!(merged.len(), 1);
        assert!(merged.contains(&dec!(1.0)));
    }

    /// Translating by the negative vector (-2, -3) shifts every point by
    /// exactly that amount.
    #[test]
    fn test_translate_with_negative_values() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
        ]));
        let result = curve
            .translate(vec![&dec!(-2.0), &dec!(-3.0)])
            .expect("translating a two-point curve should succeed");
        assert_eq!(result.points.len(), 2);
        let translated_points: Vec<_> = result.points.iter().collect();
        assert_eq!(translated_points[0].x, dec!(-2.0));
        assert_eq!(translated_points[0].y, dec!(-3.0));
        assert_eq!(translated_points[1].x, dec!(-1.0));
        assert_eq!(translated_points[1].y, dec!(-2.0));
    }

    /// Intersecting two empty curves succeeds and yields no points.
    #[test]
    fn test_intersect_with_empty_curves() {
        let curve1 = Curve::new(BTreeSet::new());
        let curve2 = Curve::new(BTreeSet::new());
        let result = curve1
            .intersect_with(&curve2)
            .expect("intersecting empty curves should not error");
        assert!(result.is_empty());
    }

    /// Pins the derivative of the segment (0, 0)–(1, 1) at its first node and
    /// at its midpoint.
    #[test]
    fn test_derivative_edge_cases() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
        ]));
        // NOTE(review): a derivative of 0 at x = 0 on the line y = x looks like
        // boundary handling inside `derivative_at` — confirm this is intended.
        let result = curve
            .derivative_at(&Point2D::new(dec!(0.0), dec!(0.0)))
            .expect("derivative at the first node should succeed");
        assert_eq!(result[0], dec!(0.0));

        // In the interior of the segment the derivative equals its slope, 1.
        let result = curve
            .derivative_at(&Point2D::new(dec!(0.5), dec!(0.5)))
            .expect("derivative inside the segment should succeed");
        assert_eq!(result[0], dec!(1.0));
    }

    /// The input points share x = 1.0, so no finite slope exists and
    /// `derivative_at` must return an error.
    #[test]
    fn test_derivative_vertical_line() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(1.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
        ]));
        let result = curve.derivative_at(&Point2D::new(dec!(1.0), dec!(0.5)));
        assert!(result.is_err());
    }

    /// On a single-point curve the minimum and the maximum are that same point.
    #[test]
    fn test_extrema_single_point() {
        let curve = Curve::new(BTreeSet::from_iter(vec![Point2D::new(
            dec!(1.0),
            dec!(1.0),
        )]));
        let (min, max) = curve
            .extrema()
            .expect("extrema of a non-empty curve should succeed");
        assert_eq!(min, max);
        assert_eq!(min.x, dec!(1.0));
        assert_eq!(min.y, dec!(1.0));
    }

    /// On a constant curve the minimum and the maximum share the same y.
    #[test]
    fn test_extrema_flat_curve() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(1.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
        ]));
        let (min, max) = curve
            .extrema()
            .expect("extrema of a non-empty curve should succeed");
        assert_eq!(min.y, max.y);
        assert_eq!(min.y, dec!(1.0));
    }

    /// A single point encloses no area.
    #[test]
    fn test_measure_under_single_point() {
        let curve = Curve::new(BTreeSet::from_iter(vec![Point2D::new(
            dec!(1.0),
            dec!(1.0),
        )]));
        let area = curve
            .measure_under(&dec!(0.0))
            .expect("measuring a single-point curve should succeed");
        assert_eq!(area, dec!(0.0));
    }

    /// The curve lies entirely below zero. These assertions pin that
    /// `measure_under` reports the enclosed area as a positive magnitude:
    /// the trapezoid |(-1 + -2) / 2| × 1 = 1.5.
    #[test]
    fn test_measure_under_negative_area() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(-1.0)),
            Point2D::new(dec!(1.0), dec!(-2.0)),
        ]));
        let area = curve
            .measure_under(&dec!(0.0))
            .expect("measuring a two-point curve should succeed");
        assert!(area > dec!(0.0));
        assert_eq!(area, dec!(1.5));
    }

    /// With interior local extrema present, the global minimum and maximum are
    /// still the lowest (0) and highest (3) y-values.
    #[test]
    fn test_extrema_multiple_extrema() {
        let curve = Curve::new(BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(2.0)),
            Point2D::new(dec!(2.0), dec!(1.0)),
            Point2D::new(dec!(3.0), dec!(3.0)),
        ]));
        let (min, max) = curve
            .extrema()
            .expect("extrema of a non-empty curve should succeed");
        assert_eq!(min.y, dec!(0.0));
        assert_eq!(max.y, dec!(3.0));
    }
}