use crate::curves::{Curve, Point2D};
use crate::error::{CurveError, OperationErrorKind};
use crate::geometrics::{BasicMetrics, MetricsExtractor, RangeMetrics, ShapeMetrics, TrendMetrics};
use num_traits::ToPrimitive;
use rand::rngs::StdRng;
use rand::{RngExt, SeedableRng};
use rust_decimal::Decimal;
use statrs::distribution::{ContinuousCDF, Normal};
use std::collections::BTreeSet;
use tracing::error;
/// Types that can produce a [`Curve`] representation of themselves.
pub trait Curvable {
    /// Builds the curve for `self`.
    ///
    /// # Errors
    ///
    /// Returns a [`CurveError`] when the curve cannot be constructed; the concrete
    /// failure conditions are defined by each implementation.
    fn curve(&self) -> Result<Curve, CurveError>;
}
/// Generates synthetic curves whose statistics approximate a set of target metrics.
///
/// Implementors only supply [`StatisticalCurve::get_x_values`]; the provided methods
/// sample y-values from a normal distribution, perturb them according to the shape and
/// trend metrics, rescale them into the target y-range and pair them with the
/// implementor's x-values.
pub trait StatisticalCurve: MetricsExtractor {
    /// Returns the x-coordinates used when generating a curve.
    ///
    /// [`StatisticalCurve::generate_statistical_curve`] pairs the first `num_points`
    /// entries with generated y-values, so at least that many values are required.
    fn get_x_values(&self) -> Vec<Decimal>;

    /// Generates a pseudo-random curve shaped by the supplied metrics.
    ///
    /// Steps:
    /// 1. Sample `num_points` y-values from `Normal(mean, std_dev)`, or use the constant
    ///    `mean` when `std_dev` is zero.
    /// 2. If `|skewness| > 0.01`, add `skewness * (y - mean)^2` to each value; if
    ///    `|kurtosis| > 0.01` and `std_dev` is non-zero, add `0.1 * kurtosis * z^3`
    ///    where `z = (y - mean) / std_dev`.
    /// 3. If `|slope| > 0.001`, add the linear trend `slope * x + intercept`.
    /// 4. Linearly rescale the values into `[range_metrics.min.y, range_metrics.max.y]`
    ///    (skipped when all values are equal).
    /// 5. For more than three points, overwrite one randomly chosen value in the first
    ///    third with `basic_metrics.mode`.
    ///
    /// `seed` makes the output reproducible; a random seed is drawn when it is `None`.
    ///
    /// # Errors
    ///
    /// - [`CurveError::OperationError`] (`InvalidParameters`) if `num_points < 2`, if
    ///   [`StatisticalCurve::get_x_values`] returns fewer than `num_points` values, or if
    ///   an x-value cannot be converted to `f64`.
    /// - [`CurveError::MetricsError`] if the normal distribution cannot be constructed
    ///   (statrs rejects NaN parameters and non-positive standard deviations).
    /// - Any error produced by `Point2D::from_f64_tuple` while building the points.
    fn generate_statistical_curve(
        &self,
        basic_metrics: &BasicMetrics,
        shape_metrics: &ShapeMetrics,
        range_metrics: &RangeMetrics,
        trend_metrics: &TrendMetrics,
        num_points: usize,
        seed: Option<u64>,
    ) -> Result<Curve, CurveError> {
        if num_points < 2 {
            return Err(CurveError::OperationError(
                OperationErrorKind::InvalidParameters {
                    operation: "generate_statistical_curve".to_string(),
                    reason: "Number of points must be at least 2".to_string(),
                },
            ));
        }

        // Validate the x-values up front: a short vector used to panic on an
        // out-of-bounds index in the trend and point-building loops.
        let raw_x_values = self.get_x_values();
        if raw_x_values.len() < num_points {
            return Err(CurveError::OperationError(
                OperationErrorKind::InvalidParameters {
                    operation: "generate_statistical_curve".to_string(),
                    reason: format!(
                        "get_x_values returned {} values but {} points were requested",
                        raw_x_values.len(),
                        num_points
                    ),
                },
            ));
        }
        let x_values = raw_x_values[..num_points]
            .iter()
            .map(|x| {
                x.to_f64().ok_or_else(|| {
                    CurveError::OperationError(OperationErrorKind::InvalidParameters {
                        operation: "generate_statistical_curve".to_string(),
                        reason: format!("x-value {x} cannot be represented as f64"),
                    })
                })
            })
            .collect::<Result<Vec<f64>, CurveError>>()?;

        // Loop-invariant conversions, hoisted out of the per-point adjustments.
        let mean = basic_metrics.mean.to_f64().unwrap_or(0.0);
        let std_dev = basic_metrics.std_dev.to_f64().unwrap_or(1.0);

        let seed_value = seed.unwrap_or_else(rand::random);
        let mut rng = StdRng::seed_from_u64(seed_value);

        let mut y_values: Vec<f64> = if basic_metrics.std_dev != Decimal::ZERO {
            let normal = Normal::new(mean, std_dev).map_err(|e| {
                error!(
                    "Failed to create normal distribution with mean {} and std_dev {}: {}",
                    basic_metrics.mean, basic_metrics.std_dev, e
                );
                CurveError::MetricsError(e.to_string())
            })?;
            (0..num_points)
                .map(|_| {
                    let u: f64 = rng.random_range(0.0..1.0);
                    // The half-open range can yield exactly 0.0, whose inverse CDF is -inf;
                    // that would turn the min/max rescaling below into NaN.
                    normal.inverse_cdf(u.max(f64::MIN_POSITIVE))
                })
                .collect()
        } else {
            vec![mean; num_points]
        };

        let skewness = shape_metrics.skewness.to_f64().unwrap_or(0.0);
        if skewness.abs() > 0.01 {
            for y in &mut y_values {
                *y += skewness * (*y - mean).powi(2);
            }
        }

        let kurtosis = shape_metrics.kurtosis.to_f64().unwrap_or(0.0);
        // The kurtosis term works on z-scores. With a zero standard deviation every value
        // equals `mean`, so z would be 0/0 = NaN; skip the adjustment in that case.
        if kurtosis.abs() > 0.01 && std_dev != 0.0 {
            for y in &mut y_values {
                let z = (*y - mean) / std_dev;
                *y += kurtosis * 0.1 * z.powi(3);
            }
        }

        let slope = trend_metrics.slope.to_f64().unwrap_or(0.0);
        if slope.abs() > 0.001 {
            let intercept = trend_metrics.intercept.to_f64().unwrap_or(0.0);
            for (y, &x) in y_values.iter_mut().zip(&x_values) {
                *y += slope * x + intercept;
            }
        }

        let current_min = y_values.iter().cloned().fold(f64::INFINITY, f64::min);
        let current_max = y_values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let current_range = current_max - current_min;
        let target_min = range_metrics.min.y.to_f64().unwrap_or(0.0);
        let target_max = range_metrics.max.y.to_f64().unwrap_or(1.0);
        let target_range = target_max - target_min;
        // A flat series (zero range) is left untouched rather than divided by zero.
        if current_range > 0.0 {
            for y in &mut y_values {
                *y = ((*y - current_min) / current_range) * target_range + target_min;
            }
        }

        if num_points > 3 {
            let index = rng.random_range(0..(num_points / 3));
            y_values[index] = basic_metrics.mode.to_f64().unwrap_or(y_values[index]);
        }

        let mut points = BTreeSet::new();
        for (&x, &y) in x_values.iter().zip(&y_values) {
            points.insert(Point2D::from_f64_tuple(x, y)?);
        }
        Ok(Curve::new(points))
    }

    /// Repeatedly generates curves until one matches `basic_metrics` within `tolerance`.
    ///
    /// Up to `max_attempts` candidates are generated (5 when `max_attempts` is 0), using
    /// consecutive seeds starting at `seed` (or a random seed when `None`). Each candidate
    /// is checked with [`StatisticalCurve::verify_curve_metrics`]. If none passes, one more
    /// curve is generated with the next seed and returned **without** verification.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`StatisticalCurve::generate_statistical_curve`] or
    /// [`StatisticalCurve::verify_curve_metrics`].
    // Each metric group is passed separately to mirror `generate_statistical_curve`;
    // bundling them would change the public signature.
    #[allow(clippy::too_many_arguments)]
    fn generate_refined_statistical_curve(
        &self,
        basic_metrics: &BasicMetrics,
        shape_metrics: &ShapeMetrics,
        range_metrics: &RangeMetrics,
        trend_metrics: &TrendMetrics,
        num_points: usize,
        max_attempts: usize,
        tolerance: Decimal,
        seed: Option<u64>,
    ) -> Result<Curve, CurveError> {
        let max_tries = if max_attempts == 0 { 5 } else { max_attempts };
        let mut seed_value = seed.unwrap_or_else(rand::random);
        for _ in 0..max_tries {
            let curve = self.generate_statistical_curve(
                basic_metrics,
                shape_metrics,
                range_metrics,
                trend_metrics,
                num_points,
                Some(seed_value),
            )?;
            if self.verify_curve_metrics(&curve, basic_metrics, tolerance)? {
                return Ok(curve);
            }
            seed_value = seed_value.wrapping_add(1);
        }
        // No candidate met the tolerance: fall back to an unverified curve.
        self.generate_statistical_curve(
            basic_metrics,
            shape_metrics,
            range_metrics,
            trend_metrics,
            num_points,
            Some(seed_value),
        )
    }

    /// Checks whether `curve`'s mean and standard deviation are both within `tolerance`
    /// (inclusive) of `target_metrics`.
    ///
    /// Only `mean` and `std_dev` are compared; `median` and `mode` are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`CurveError::MetricsError`] if the curve's basic metrics cannot be computed.
    fn verify_curve_metrics(
        &self,
        curve: &Curve,
        target_metrics: &BasicMetrics,
        tolerance: Decimal,
    ) -> Result<bool, CurveError> {
        let actual_metrics = curve
            .compute_basic_metrics()
            .map_err(|e| CurveError::MetricsError(format!("Failed to compute metrics: {e}")))?;
        let mean_diff = (actual_metrics.mean - target_metrics.mean).abs();
        let std_dev_diff = (actual_metrics.std_dev - target_metrics.std_dev).abs();
        Ok(mean_diff <= tolerance && std_dev_diff <= tolerance)
    }
}
#[cfg(test)]
mod tests_statistical_curve {
    //! Tests for the default `StatisticalCurve` methods using a fixed-metric generator.
    //! This module also provides the `Default` impls for the metric structs it uses.
    use super::*;
    use crate::error::MetricsError;
    use crate::geometrics::RiskMetrics;
    use crate::utils::Len;
    use rust_decimal_macros::dec;
    use std::collections::BTreeSet;

    /// Generator whose metrics equal the `Default` values defined at the bottom of this
    /// module and whose x-values are `0..=9`.
    struct TestCurveGenerator;

    impl Len for TestCurveGenerator {
        fn len(&self) -> usize {
            // Not expected to be called by these tests.
            unreachable!()
        }
    }

    impl MetricsExtractor for TestCurveGenerator {
        fn compute_basic_metrics(&self) -> Result<BasicMetrics, MetricsError> {
            Ok(BasicMetrics::default())
        }

        fn compute_shape_metrics(&self) -> Result<ShapeMetrics, MetricsError> {
            Ok(ShapeMetrics::default())
        }

        fn compute_range_metrics(&self) -> Result<RangeMetrics, MetricsError> {
            Ok(RangeMetrics::default())
        }

        fn compute_trend_metrics(&self) -> Result<TrendMetrics, MetricsError> {
            Ok(TrendMetrics::default())
        }

        fn compute_risk_metrics(&self) -> Result<RiskMetrics, MetricsError> {
            // Not expected to be called by these tests.
            unreachable!()
        }
    }

    impl StatisticalCurve for TestCurveGenerator {
        fn get_x_values(&self) -> Vec<Decimal> {
            (0i32..10).map(Decimal::from).collect()
        }
    }

    impl Curvable for TestCurveGenerator {
        /// The identity line `y = x` sampled at `x = 0..=9`.
        fn curve(&self) -> Result<Curve, CurveError> {
            let points = (0i32..10)
                .map(Decimal::from)
                .map(|v| Point2D::new(v, v))
                .collect::<BTreeSet<Point2D>>();
            Ok(Curve::new(points))
        }
    }

    #[test]
    fn test_get_x_values() {
        let xs = TestCurveGenerator.get_x_values();
        assert_eq!(xs.len(), 10);
        assert_eq!(xs.first(), Some(&dec!(0)));
        assert_eq!(xs.last(), Some(&dec!(9)));
    }

    #[test]
    fn test_generate_statistical_curve_invalid_points() {
        let generator = TestCurveGenerator;
        let result = generator.generate_statistical_curve(
            &BasicMetrics::default(),
            &ShapeMetrics::default(),
            &RangeMetrics::default(),
            &TrendMetrics::default(),
            1,
            None,
        );
        assert!(result.is_err());
        match result {
            Err(CurveError::OperationError(OperationErrorKind::InvalidParameters {
                operation,
                reason,
            })) => {
                assert_eq!(operation, "generate_statistical_curve");
                assert!(reason.contains("Number of points must be at least 2"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_verify_curve_metrics() {
        let generator = TestCurveGenerator;
        let curve = generator.curve().unwrap();
        let target = BasicMetrics {
            mean: dec!(4.5),
            median: dec!(4.5),
            mode: dec!(0.0),
            std_dev: dec!(3.0),
        };

        let loose = generator.verify_curve_metrics(&curve, &target, dec!(1.0));
        assert!(loose.is_ok());
        assert!(loose.unwrap());

        let strict = generator.verify_curve_metrics(&curve, &target, dec!(0.1));
        assert!(strict.is_ok());
    }

    impl Default for BasicMetrics {
        /// Mean 0, standard deviation 1.
        fn default() -> Self {
            Self {
                std_dev: dec!(1.0),
                mean: dec!(0.0),
                median: dec!(0.0),
                mode: dec!(0.0),
            }
        }
    }

    impl Default for ShapeMetrics {
        /// No skewness, no kurtosis, no detected features.
        fn default() -> Self {
            Self {
                peaks: vec![],
                valleys: vec![],
                inflection_points: vec![],
                skewness: dec!(0.0),
                kurtosis: dec!(0.0),
            }
        }
    }

    impl Default for RangeMetrics {
        /// The square from `(0, 0)` to `(10, 10)` with evenly spaced quartiles.
        fn default() -> Self {
            Self {
                quartiles: (dec!(2.5), dec!(5.0), dec!(7.5)),
                interquartile_range: dec!(5.0),
                min: Point2D::new(dec!(0.0), dec!(0.0)),
                max: Point2D::new(dec!(10.0), dec!(10.0)),
                range: dec!(10.0),
            }
        }
    }

    impl Default for TrendMetrics {
        /// A flat trend with no fit quality.
        fn default() -> Self {
            Self {
                moving_average: vec![],
                slope: dec!(0.0),
                intercept: dec!(0.0),
                r_squared: dec!(0.0),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    //! Tests for `Curvable` and `StatisticalCurve` using small in-memory mocks.
    use super::*;
    use crate::error::MetricsError;
    use crate::geometrics::RiskMetrics;
    use crate::utils::Len;
    use rust_decimal::Decimal;
    use rust_decimal_macros::dec;
    use std::collections::BTreeSet;

    /// Basic metrics reported by `MockStatisticalCurve` and used by the generation tests.
    fn fixture_basic_metrics() -> BasicMetrics {
        BasicMetrics {
            mean: dec!(3.0),
            median: dec!(3.0),
            mode: dec!(3.0),
            std_dev: dec!(1.5),
        }
    }

    /// Symmetric shape with no detected features.
    fn fixture_shape_metrics() -> ShapeMetrics {
        ShapeMetrics {
            skewness: dec!(0.0),
            kurtosis: dec!(0.0),
            peaks: vec![],
            valleys: vec![],
            inflection_points: vec![],
        }
    }

    /// Range from `(1, 1)` to `(5, 5)`; quartiles are left at their defaults.
    fn fixture_range_metrics() -> RangeMetrics {
        RangeMetrics {
            min: Point2D::new(dec!(1.0), dec!(1.0)),
            max: Point2D::new(dec!(5.0), dec!(5.0)),
            range: dec!(4.0),
            quartiles: (Default::default(), Default::default(), Default::default()),
            interquartile_range: Default::default(),
        }
    }

    /// A perfect unit-slope trend through the origin.
    fn fixture_trend_metrics() -> TrendMetrics {
        TrendMetrics {
            slope: dec!(1.0),
            intercept: dec!(0.0),
            r_squared: dec!(1.0),
            moving_average: vec![],
        }
    }

    /// `Curvable` mock that either yields three fixed points or a fixed error.
    struct MockCurvable {
        points: BTreeSet<Point2D>,
        should_fail: bool,
    }

    impl MockCurvable {
        fn new(should_fail: bool) -> Self {
            let points = if should_fail {
                BTreeSet::new()
            } else {
                BTreeSet::from([
                    Point2D::new(dec!(1.0), dec!(2.0)),
                    Point2D::new(dec!(2.0), dec!(3.0)),
                    Point2D::new(dec!(3.0), dec!(4.0)),
                ])
            };
            Self {
                points,
                should_fail,
            }
        }
    }

    impl Curvable for MockCurvable {
        fn curve(&self) -> Result<Curve, CurveError> {
            if !self.should_fail {
                return Ok(Curve::new(self.points.clone()));
            }
            Err(CurveError::OperationError(
                OperationErrorKind::InvalidParameters {
                    operation: "curve".to_string(),
                    reason: "Test failure".to_string(),
                },
            ))
        }
    }

    /// `StatisticalCurve` mock with x-values `1.0..=5.0` and the fixture metrics above.
    struct MockStatisticalCurve {
        x_values: Vec<Decimal>,
    }

    impl MockStatisticalCurve {
        fn new() -> Self {
            Self {
                x_values: vec![dec!(1.0), dec!(2.0), dec!(3.0), dec!(4.0), dec!(5.0)],
            }
        }
    }

    impl Len for MockStatisticalCurve {
        fn len(&self) -> usize {
            self.x_values.len()
        }
    }

    impl MetricsExtractor for MockStatisticalCurve {
        fn compute_basic_metrics(&self) -> Result<BasicMetrics, MetricsError> {
            Ok(fixture_basic_metrics())
        }

        fn compute_shape_metrics(&self) -> Result<ShapeMetrics, MetricsError> {
            Ok(fixture_shape_metrics())
        }

        fn compute_range_metrics(&self) -> Result<RangeMetrics, MetricsError> {
            Ok(fixture_range_metrics())
        }

        fn compute_trend_metrics(&self) -> Result<TrendMetrics, MetricsError> {
            Ok(fixture_trend_metrics())
        }

        fn compute_risk_metrics(&self) -> Result<RiskMetrics, MetricsError> {
            Ok(RiskMetrics {
                volatility: Default::default(),
                value_at_risk: Default::default(),
                expected_shortfall: Default::default(),
                beta: Default::default(),
                sharpe_ratio: Default::default(),
            })
        }
    }

    impl StatisticalCurve for MockStatisticalCurve {
        fn get_x_values(&self) -> Vec<Decimal> {
            self.x_values.clone()
        }
    }

    #[test]
    fn test_curvable_success() {
        let outcome = MockCurvable::new(false).curve();
        assert!(outcome.is_ok(), "Curve generation should succeed");
        assert_eq!(outcome.unwrap().len(), 3, "Curve should have 3 points");
    }

    #[test]
    fn test_curvable_failure() {
        let outcome = MockCurvable::new(true).curve();
        assert!(outcome.is_err(), "Curve generation should fail");
        match outcome {
            Err(CurveError::OperationError(OperationErrorKind::InvalidParameters {
                operation,
                reason,
            })) => {
                assert_eq!(operation, "curve", "Operation name should match");
                assert_eq!(reason, "Test failure", "Error reason should match");
            }
            _ => panic!("Unexpected error type"),
        }
    }

    #[test]
    fn test_get_x_values() {
        let xs = MockStatisticalCurve::new().get_x_values();
        assert_eq!(xs.len(), 5, "Should return 5 x values");
        assert_eq!(xs[0], dec!(1.0), "First x value should be 1.0");
        assert_eq!(xs[4], dec!(5.0), "Last x value should be 5.0");
    }

    #[test]
    fn test_generate_statistical_curve_invalid_points() {
        let outcome = MockStatisticalCurve::new().generate_statistical_curve(
            &fixture_basic_metrics(),
            &fixture_shape_metrics(),
            &fixture_range_metrics(),
            &fixture_trend_metrics(),
            1,
            None,
        );
        assert!(outcome.is_err(), "Should fail with less than 2 points");
        match outcome {
            Err(CurveError::OperationError(OperationErrorKind::InvalidParameters {
                operation,
                reason,
            })) => {
                assert_eq!(operation, "generate_statistical_curve");
                assert!(reason.contains("at least 2"));
            }
            _ => panic!("Unexpected error type"),
        }
    }

    #[test]
    fn test_generate_statistical_curve_success() {
        let outcome = MockStatisticalCurve::new().generate_statistical_curve(
            &fixture_basic_metrics(),
            &fixture_shape_metrics(),
            &fixture_range_metrics(),
            &fixture_trend_metrics(),
            5,
            Some(42),
        );
        assert!(outcome.is_ok(), "Should successfully generate a curve");
        let curve = outcome.unwrap();
        assert!(curve.len() > 0, "Generated curve should contain points");
    }

    #[test]
    fn test_verify_curve_metrics() {
        let mock = MockStatisticalCurve::new();
        let curve = Curve::new(BTreeSet::from([
            Point2D::new(dec!(1.0), dec!(2.0)),
            Point2D::new(dec!(2.0), dec!(3.0)),
            Point2D::new(dec!(3.0), dec!(4.0)),
        ]));
        let target = BasicMetrics {
            mean: dec!(3.1),
            median: dec!(3.0),
            mode: dec!(3.0),
            std_dev: dec!(1.0),
        };

        let loose = mock.verify_curve_metrics(&curve, &target, dec!(0.5));
        assert!(loose.is_ok(), "Verification should not fail");
        assert!(loose.unwrap(), "Metrics should be within tolerance");

        let strict = mock.verify_curve_metrics(&curve, &target, dec!(0.05));
        assert!(strict.is_ok(), "Verification should not fail");
        assert!(!strict.unwrap(), "Metrics should not be within tolerance");
    }

    #[test]
    fn test_refined_statistical_curve() {
        let outcome = MockStatisticalCurve::new().generate_refined_statistical_curve(
            &fixture_basic_metrics(),
            &fixture_shape_metrics(),
            &fixture_range_metrics(),
            &fixture_trend_metrics(),
            5,
            3,
            dec!(0.2),
            Some(42),
        );
        assert!(
            outcome.is_ok(),
            "Should successfully generate a refined curve"
        );
        let curve = outcome.unwrap();
        assert!(curve.len() > 0, "Generated curve should contain points");
    }
}
#[cfg(test)]
mod tests_statistical_curve_generation {
    //! Generation tests exercising the skewness, kurtosis and trend adjustments.
    use super::*;
    use crate::error::MetricsError;
    use crate::geometrics::RiskMetrics;
    use crate::utils::Len;
    use rust_decimal_macros::dec;

    /// Generator with x-values `0..=9`, non-zero shape/trend metrics and an optional
    /// failing `compute_basic_metrics`.
    struct EnhancedTestCurveGenerator {
        x_values: Vec<Decimal>,
        /// When true, `compute_basic_metrics` returns an error.
        fail_metrics: bool,
    }

    impl EnhancedTestCurveGenerator {
        fn new(fail_metrics: bool) -> Self {
            Self {
                x_values: (0..10).map(Decimal::from).collect(),
                fail_metrics,
            }
        }
    }

    impl Len for EnhancedTestCurveGenerator {
        fn len(&self) -> usize {
            self.x_values.len()
        }

        fn is_empty(&self) -> bool {
            self.x_values.is_empty()
        }
    }

    impl MetricsExtractor for EnhancedTestCurveGenerator {
        fn compute_basic_metrics(&self) -> Result<BasicMetrics, MetricsError> {
            if self.fail_metrics {
                Err(MetricsError::BasicError(
                    "compute_basic_metrics error".to_string(),
                ))
            } else {
                Ok(BasicMetrics {
                    mean: dec!(5.0),
                    median: dec!(5.0),
                    mode: dec!(5.0),
                    std_dev: dec!(2.0),
                })
            }
        }

        fn compute_shape_metrics(&self) -> Result<ShapeMetrics, MetricsError> {
            // Both values exceed the 0.01 threshold, so both adjustments are applied.
            Ok(ShapeMetrics {
                skewness: dec!(0.5),
                kurtosis: dec!(0.5),
                peaks: vec![],
                valleys: vec![],
                inflection_points: vec![],
            })
        }

        fn compute_range_metrics(&self) -> Result<RangeMetrics, MetricsError> {
            Ok(RangeMetrics {
                min: Point2D::new(dec!(0.0), dec!(0.0)),
                max: Point2D::new(dec!(10.0), dec!(10.0)),
                range: dec!(10.0),
                quartiles: (dec!(2.5), dec!(5.0), dec!(7.5)),
                interquartile_range: dec!(5.0),
            })
        }

        fn compute_trend_metrics(&self) -> Result<TrendMetrics, MetricsError> {
            // Slope exceeds the 0.001 threshold, so the trend term is applied.
            Ok(TrendMetrics {
                slope: dec!(1.0),
                intercept: dec!(0.5),
                r_squared: dec!(0.9),
                moving_average: vec![],
            })
        }

        fn compute_risk_metrics(&self) -> Result<RiskMetrics, MetricsError> {
            Ok(RiskMetrics {
                volatility: dec!(0.5),
                value_at_risk: dec!(1.0),
                expected_shortfall: dec!(1.5),
                beta: dec!(0.8),
                sharpe_ratio: dec!(1.2),
            })
        }
    }

    impl StatisticalCurve for EnhancedTestCurveGenerator {
        fn get_x_values(&self) -> Vec<Decimal> {
            self.x_values.clone()
        }
    }

    impl Curvable for EnhancedTestCurveGenerator {
        fn curve(&self) -> Result<Curve, CurveError> {
            let points: BTreeSet<Point2D> = (0..10)
                .map(|i| Point2D::new(Decimal::from(i), Decimal::from(i)))
                .collect();
            Ok(Curve::new(points))
        }
    }

    #[test]
    fn test_generate_statistical_curve_with_skewness() {
        let generator = EnhancedTestCurveGenerator::new(false);
        let basic_metrics = generator.compute_basic_metrics().unwrap();
        let shape_metrics = generator.compute_shape_metrics().unwrap();
        let range_metrics = generator.compute_range_metrics().unwrap();
        let trend_metrics = generator.compute_trend_metrics().unwrap();
        let result = generator.generate_statistical_curve(
            &basic_metrics,
            &shape_metrics,
            &range_metrics,
            &trend_metrics,
            10,
            Some(42),
        );
        assert!(result.is_ok(), "Failed to generate curve with skewness");
        let curve = result.unwrap();
        assert_eq!(curve.points.len(), 10);
    }

    #[test]
    fn test_generate_statistical_curve_with_kurtosis() {
        let generator = EnhancedTestCurveGenerator::new(false);
        let basic_metrics = generator.compute_basic_metrics().unwrap();
        let mut shape_metrics = generator.compute_shape_metrics().unwrap();
        // Isolate the kurtosis adjustment.
        shape_metrics.skewness = dec!(0.0);
        shape_metrics.kurtosis = dec!(1.0);
        let range_metrics = generator.compute_range_metrics().unwrap();
        let trend_metrics = generator.compute_trend_metrics().unwrap();
        let result = generator.generate_statistical_curve(
            &basic_metrics,
            &shape_metrics,
            &range_metrics,
            &trend_metrics,
            10,
            Some(42),
        );
        assert!(result.is_ok(), "Failed to generate curve with kurtosis");
    }

    #[test]
    fn test_generate_statistical_curve_with_trend() {
        let generator = EnhancedTestCurveGenerator::new(false);
        let basic_metrics = generator.compute_basic_metrics().unwrap();
        let shape_metrics = generator.compute_shape_metrics().unwrap();
        let range_metrics = generator.compute_range_metrics().unwrap();
        let mut trend_metrics = generator.compute_trend_metrics().unwrap();
        trend_metrics.slope = dec!(2.0);
        let result = generator.generate_statistical_curve(
            &basic_metrics,
            &shape_metrics,
            &range_metrics,
            &trend_metrics,
            10,
            Some(42),
        );
        assert!(result.is_ok(), "Failed to generate curve with trend");
    }

    /// `verify_curve_metrics` computes metrics from the *curve*, not from the generator,
    /// so `fail_metrics` does not make it error; it must instead report a mismatch.
    #[test]
    fn test_verify_curve_metrics_failure() {
        let generator = EnhancedTestCurveGenerator::new(true);
        let points = BTreeSet::from_iter(vec![
            Point2D::new(dec!(0.0), dec!(0.0)),
            Point2D::new(dec!(1.0), dec!(1.0)),
        ]);
        let curve = Curve::new(points);
        let target_metrics = BasicMetrics {
            mean: dec!(0.0),
            median: dec!(-5.0),
            mode: dec!(5.0),
            std_dev: dec!(0.0),
        };
        let result = generator.verify_curve_metrics(&curve, &target_metrics, dec!(0.0000001));
        assert!(result.is_ok());
        // Both coordinates are {0, 1}, so the curve's mean (0.5) is far outside the 1e-7
        // tolerance around the target mean of 0.
        assert!(!result.unwrap(), "Metrics should not be within tolerance");
    }

    #[test]
    fn test_generate_refined_statistical_curve() {
        let generator = EnhancedTestCurveGenerator::new(false);
        let basic_metrics = generator.compute_basic_metrics().unwrap();
        let shape_metrics = generator.compute_shape_metrics().unwrap();
        let range_metrics = generator.compute_range_metrics().unwrap();
        let trend_metrics = generator.compute_trend_metrics().unwrap();
        let result = generator.generate_refined_statistical_curve(
            &basic_metrics,
            &shape_metrics,
            &range_metrics,
            &trend_metrics,
            10,
            3,
            dec!(0.01),
            Some(42),
        );
        assert!(result.is_ok(), "Failed to generate refined curve");

        // `max_attempts == 0` falls back to the default number of attempts.
        let result = generator.generate_refined_statistical_curve(
            &basic_metrics,
            &shape_metrics,
            &range_metrics,
            &trend_metrics,
            10,
            0,
            dec!(0.5),
            Some(42),
        );
        assert!(
            result.is_ok(),
            "Failed to generate refined curve with default attempts"
        );
    }
}
#[cfg(test)]
mod tests_statistical_curve_edge_cases {
    //! Edge-case tests: zero standard deviation, custom ranges, too few points.
    use super::*;
    use crate::error::MetricsError;
    use crate::geometrics::RiskMetrics;
    use crate::utils::Len;
    use rust_decimal_macros::dec;

    /// Generator with no x-values and all-zero metrics.
    // Not constructed by any test in this module, hence the dead-code allowance.
    #[allow(dead_code)]
    struct EmptyXValuesGenerator;

    impl Len for EmptyXValuesGenerator {
        fn len(&self) -> usize {
            0
        }

        fn is_empty(&self) -> bool {
            true
        }
    }

    impl MetricsExtractor for EmptyXValuesGenerator {
        fn compute_basic_metrics(&self) -> Result<BasicMetrics, MetricsError> {
            let zero = dec!(0.0);
            Ok(BasicMetrics {
                mean: zero,
                median: zero,
                mode: zero,
                std_dev: zero,
            })
        }

        fn compute_shape_metrics(&self) -> Result<ShapeMetrics, MetricsError> {
            let zero = dec!(0.0);
            Ok(ShapeMetrics {
                skewness: zero,
                kurtosis: zero,
                peaks: Vec::new(),
                valleys: Vec::new(),
                inflection_points: Vec::new(),
            })
        }

        fn compute_range_metrics(&self) -> Result<RangeMetrics, MetricsError> {
            let zero = dec!(0.0);
            Ok(RangeMetrics {
                min: Point2D::new(zero, zero),
                max: Point2D::new(zero, zero),
                range: zero,
                quartiles: (zero, zero, zero),
                interquartile_range: zero,
            })
        }

        fn compute_trend_metrics(&self) -> Result<TrendMetrics, MetricsError> {
            let zero = dec!(0.0);
            Ok(TrendMetrics {
                slope: zero,
                intercept: zero,
                r_squared: zero,
                moving_average: Vec::new(),
            })
        }

        fn compute_risk_metrics(&self) -> Result<RiskMetrics, MetricsError> {
            let zero = dec!(0.0);
            Ok(RiskMetrics {
                volatility: zero,
                value_at_risk: zero,
                expected_shortfall: zero,
                beta: zero,
                sharpe_ratio: zero,
            })
        }
    }

    impl StatisticalCurve for EmptyXValuesGenerator {
        fn get_x_values(&self) -> Vec<Decimal> {
            Vec::new()
        }
    }

    /// Generator with x-values `0..=9`, mean 5 and a switchable zero standard deviation.
    struct SpecialStatisticalGenerator {
        zero_std_dev: bool,
    }

    impl SpecialStatisticalGenerator {
        fn new(zero_std_dev: bool) -> Self {
            Self { zero_std_dev }
        }
    }

    impl Len for SpecialStatisticalGenerator {
        fn len(&self) -> usize {
            10
        }

        fn is_empty(&self) -> bool {
            false
        }
    }

    impl MetricsExtractor for SpecialStatisticalGenerator {
        fn compute_basic_metrics(&self) -> Result<BasicMetrics, MetricsError> {
            let std_dev = if self.zero_std_dev {
                dec!(0.0)
            } else {
                dec!(1.0)
            };
            Ok(BasicMetrics {
                mean: dec!(5.0),
                median: dec!(5.0),
                mode: dec!(5.0),
                std_dev,
            })
        }

        fn compute_shape_metrics(&self) -> Result<ShapeMetrics, MetricsError> {
            Ok(ShapeMetrics {
                skewness: dec!(0),
                kurtosis: dec!(0),
                peaks: Vec::new(),
                valleys: Vec::new(),
                inflection_points: Vec::new(),
            })
        }

        fn compute_range_metrics(&self) -> Result<RangeMetrics, MetricsError> {
            Ok(RangeMetrics {
                min: Point2D::new(dec!(0.0), dec!(0.0)),
                max: Point2D::new(dec!(10.0), dec!(10.0)),
                range: dec!(10.0),
                quartiles: (dec!(2.5), dec!(5.0), dec!(7.5)),
                interquartile_range: dec!(5.0),
            })
        }

        fn compute_trend_metrics(&self) -> Result<TrendMetrics, MetricsError> {
            Ok(TrendMetrics {
                slope: dec!(0.0),
                intercept: dec!(0.0),
                r_squared: dec!(1.0),
                moving_average: Vec::new(),
            })
        }

        fn compute_risk_metrics(&self) -> Result<RiskMetrics, MetricsError> {
            let zero = dec!(0.0);
            Ok(RiskMetrics {
                volatility: zero,
                value_at_risk: zero,
                expected_shortfall: zero,
                beta: zero,
                sharpe_ratio: zero,
            })
        }
    }

    impl StatisticalCurve for SpecialStatisticalGenerator {
        fn get_x_values(&self) -> Vec<Decimal> {
            (0i32..10).map(Decimal::from).collect()
        }
    }

    /// Runs `generate_statistical_curve` with seed 42, using `generator`'s own shape and
    /// trend metrics together with the given basic and range metrics.
    fn generate_with(
        generator: &SpecialStatisticalGenerator,
        basic_metrics: &BasicMetrics,
        range_metrics: &RangeMetrics,
        num_points: usize,
    ) -> Result<Curve, CurveError> {
        let shape_metrics = generator.compute_shape_metrics().unwrap();
        let trend_metrics = generator.compute_trend_metrics().unwrap();
        generator.generate_statistical_curve(
            basic_metrics,
            &shape_metrics,
            range_metrics,
            &trend_metrics,
            num_points,
            Some(42),
        )
    }

    #[test]
    fn test_generate_statistical_curve_mode_inclusion() {
        let generator = SpecialStatisticalGenerator::new(false);
        let basic_metrics = BasicMetrics {
            mean: dec!(5.0),
            median: dec!(5.0),
            mode: dec!(7.5),
            std_dev: dec!(1.0),
        };
        let range_metrics = generator.compute_range_metrics().unwrap();
        let outcome = generate_with(&generator, &basic_metrics, &range_metrics, 10);
        assert!(outcome.is_ok(), "Failed to generate curve including mode");
    }

    #[test]
    fn test_generate_statistical_curve_with_zero_std_dev() {
        let generator = SpecialStatisticalGenerator::new(true);
        let basic_metrics = generator.compute_basic_metrics().unwrap();
        let range_metrics = generator.compute_range_metrics().unwrap();
        let outcome = generate_with(&generator, &basic_metrics, &range_metrics, 10);
        assert!(outcome.is_ok(), "Failed to handle zero standard deviation");
    }

    #[test]
    fn test_generate_statistical_curve_scale_range() {
        let generator = SpecialStatisticalGenerator::new(false);
        let wide_range = RangeMetrics {
            min: Point2D::new(dec!(0.0), dec!(-10.0)),
            max: Point2D::new(dec!(10.0), dec!(10.0)),
            range: dec!(20.0),
            quartiles: (dec!(2.5), dec!(5.0), dec!(7.5)),
            interquartile_range: dec!(5.0),
        };

        let spread = generator.compute_basic_metrics().unwrap();
        let outcome = generate_with(&generator, &spread, &wide_range, 10);
        assert!(outcome.is_ok(), "Failed to scale curve to target range");

        // A zero standard deviation yields a flat series, i.e. zero current range.
        let flat = BasicMetrics {
            mean: dec!(5.0),
            median: dec!(5.0),
            mode: dec!(5.0),
            std_dev: dec!(0.0),
        };
        let outcome = generate_with(&generator, &flat, &wide_range, 10);
        assert!(outcome.is_ok(), "Failed to handle zero current range");
    }

    #[test]
    fn test_normal_distribution_error() {
        let generator = SpecialStatisticalGenerator::new(false);
        let basic_metrics = BasicMetrics {
            mean: dec!(5.0),
            median: dec!(5.0),
            mode: dec!(5.0),
            std_dev: dec!(1.0),
        };
        let range_metrics = generator.compute_range_metrics().unwrap();
        let outcome = generate_with(&generator, &basic_metrics, &range_metrics, 10);
        assert!(
            outcome.is_ok(),
            "Failed to generate curve with normal distribution"
        );
    }

    #[test]
    fn test_generate_statistical_curve_insufficient_points() {
        let generator = SpecialStatisticalGenerator::new(false);
        let basic_metrics = generator.compute_basic_metrics().unwrap();
        let range_metrics = generator.compute_range_metrics().unwrap();
        let outcome = generate_with(&generator, &basic_metrics, &range_metrics, 1);
        assert!(outcome.is_err());
        match outcome {
            Err(CurveError::OperationError(OperationErrorKind::InvalidParameters {
                operation,
                reason,
            })) => {
                assert_eq!(operation, "generate_statistical_curve");
                assert!(reason.contains("at least 2"));
            }
            _ => panic!("Expected InvalidParameters error"),
        }
    }

    #[test]
    fn test_verify_metrics_within_tolerance() {
        let generator = SpecialStatisticalGenerator::new(false);
        let curve = Curve::new(BTreeSet::from([
            Point2D::new(dec!(0.0), dec!(1.0)),
            Point2D::new(dec!(1.0), dec!(2.0)),
            Point2D::new(dec!(2.0), dec!(3.0)),
        ]));
        let target = BasicMetrics {
            mean: dec!(2.0),
            median: dec!(2.0),
            mode: dec!(2.0),
            std_dev: dec!(0.81),
        };

        let loose = generator.verify_curve_metrics(&curve, &target, dec!(0.2));
        assert!(loose.is_ok());
        assert!(loose.unwrap(), "Metrics should be within tolerance");

        let strict = generator.verify_curve_metrics(&curve, &target, dec!(0.001));
        assert!(strict.is_ok());
        assert!(!strict.unwrap(), "Metrics should be outside tolerance");
    }
}