use crate::curves::Curve;
use crate::error::CurveError;
use crate::greeks::Greeks;
use crate::model::BasicAxisTypes;
use crate::{OptionStyle, Options, Side};
use rust_decimal::Decimal;
use std::sync::Arc;
/// Producer of basic option curves for a selected axis.
///
/// Implementors supply [`BasicCurves::curve`]. The provided
/// [`BasicCurves::get_curve_strike_versus`] helper turns a single option into an
/// `(x, y)` pair whose x-coordinate is that option's strike price.
pub trait BasicCurves {
    /// Builds a [`Curve`] for the requested `axis`, `option_style` and `side`.
    ///
    /// # Errors
    ///
    /// Returns a [`CurveError`] when the curve cannot be produced; the exact
    /// conditions are defined by each implementor.
    fn curve(
        &self,
        axis: &BasicAxisTypes,
        option_style: &OptionStyle,
        side: &Side,
    ) -> Result<Curve, CurveError>;

    /// Returns the `(strike, value)` pair for `option` on the given `axis`.
    ///
    /// The first element is always `option.strike_price` as a `Decimal`. The
    /// second element is:
    /// - the corresponding greek for `Delta`, `Gamma`, `Theta`, `Vanna`, `Vega`,
    ///   `Veta`, `Charm` and `Color`;
    /// - `option.implied_volatility` for `Volatility`;
    /// - the result of `calculate_price_black_scholes` for `Price`.
    ///
    /// # Errors
    ///
    /// - Propagates (via `?`) any error from the greek or Black-Scholes
    ///   computation.
    /// - Returns `CurveError::OperationError(InvalidParameters { .. })` with
    ///   operation `"get_axis_value"` for every other axis.
    fn get_curve_strike_versus(
        &self,
        axis: &BasicAxisTypes,
        option: &Arc<Options>,
    ) -> Result<(Decimal, Decimal), CurveError> {
        let value = match axis {
            BasicAxisTypes::Delta => option.delta()?,
            BasicAxisTypes::Gamma => option.gamma()?,
            BasicAxisTypes::Theta => option.theta()?,
            BasicAxisTypes::Vanna => option.vanna()?,
            BasicAxisTypes::Vega => option.vega()?,
            BasicAxisTypes::Veta => option.veta()?,
            BasicAxisTypes::Charm => option.charm()?,
            BasicAxisTypes::Color => option.color()?,
            BasicAxisTypes::Volatility => option.implied_volatility.to_dec(),
            BasicAxisTypes::Price => option.calculate_price_black_scholes()?,
            unsupported => {
                return Err(CurveError::OperationError(
                    crate::error::OperationErrorKind::InvalidParameters {
                        operation: "get_axis_value".to_string(),
                        reason: format!("Axis: {unsupported:?} not supported"),
                    },
                ));
            }
        };

        let strike = option.strike_price.to_dec();
        Ok((strike, value))
    }
}
#[cfg(test)]
mod tests_basic_curves_trait {
    //! Unit tests for the default `BasicCurves::get_curve_strike_versus`
    //! implementation, driven through the minimal `TestBasicCurves` implementor.

    use super::*;
    use crate::curves::Point2D;
    use crate::error::OperationErrorKind;
    use crate::model::types::{OptionStyle, Side};
    use crate::{ExpirationDate, OptionType};
    use positive::{Positive, pos_or_panic};
    use rust_decimal_macros::dec;
    use std::collections::BTreeSet;
    use std::sync::Arc;

    /// Builds the option fixture shared by every test: a long European call on
    /// "AAPL" expiring in 30 days.
    ///
    /// NOTE(review): the remaining positional arguments are assumed to be, in
    /// order, strike (100), implied volatility (0.2), quantity (1), underlying
    /// price (105), risk-free rate (0.05), option style, dividend yield (0.01)
    /// and exotic params — confirm against `Options::new`.
    fn create_test_option() -> Arc<Options> {
        Arc::new(Options::new(
            OptionType::European,
            Side::Long,
            "AAPL".to_string(),
            Positive::HUNDRED,
            ExpirationDate::Days(pos_or_panic!(30.0)),
            pos_or_panic!(0.2),
            Positive::ONE,
            pos_or_panic!(105.0),
            dec!(0.05),
            OptionStyle::Call,
            pos_or_panic!(0.01),
            None,
        ))
    }

    /// Minimal implementor: `curve` ignores style and side and returns a
    /// one-point curve produced by the default `get_curve_strike_versus`.
    struct TestBasicCurves;

    impl BasicCurves for TestBasicCurves {
        fn curve(
            &self,
            axis: &BasicAxisTypes,
            _option_style: &OptionStyle,
            _side: &Side,
        ) -> Result<Curve, CurveError> {
            let option = create_test_option();
            let (x, y) = self.get_curve_strike_versus(axis, &option)?;
            Ok(Curve::new(BTreeSet::from([Point2D::new(x, y)])))
        }
    }

    /// Evaluates `axis` on a fresh fixture option and checks that the
    /// x-coordinate equals the option's strike.
    ///
    /// Returns the option and the y-coordinate. On failure it panics with the
    /// underlying `CurveError`, so the test report says *why* the call failed
    /// instead of a bare `assertion failed: result.is_ok()`.
    fn strike_versus(axis: BasicAxisTypes) -> (Arc<Options>, Decimal) {
        let curves = TestBasicCurves;
        let option = create_test_option();
        let (strike, value) = curves
            .get_curve_strike_versus(&axis, &option)
            .unwrap_or_else(|err| panic!("axis {axis:?} should be supported, got: {err:?}"));
        assert_eq!(strike, option.strike_price.to_dec());
        (option, value)
    }

    /// Evaluates the unsupported `Expiration` axis and returns the
    /// `(operation, reason)` fields of the expected `InvalidParameters` error.
    fn unsupported_axis_error() -> (String, String) {
        let curves = TestBasicCurves;
        let option = create_test_option();
        match curves.get_curve_strike_versus(&BasicAxisTypes::Expiration, &option) {
            Err(CurveError::OperationError(OperationErrorKind::InvalidParameters {
                operation,
                reason,
            })) => (operation, reason),
            _ => panic!("Expected OperationError with InvalidParameters"),
        }
    }

    #[test]
    fn test_get_strike_versus_delta() {
        let (_, delta) = strike_versus(BasicAxisTypes::Delta);
        assert!(delta.abs() <= dec!(1.0));
    }

    #[test]
    fn test_get_strike_versus_gamma() {
        let (_, gamma) = strike_versus(BasicAxisTypes::Gamma);
        assert!(gamma >= Decimal::ZERO);
    }

    #[test]
    fn test_get_strike_versus_theta() {
        // Only the strike (checked inside the helper) is asserted for theta.
        strike_versus(BasicAxisTypes::Theta);
    }

    #[test]
    fn test_get_strike_versus_vega() {
        let (_, vega) = strike_versus(BasicAxisTypes::Vega);
        assert!(vega >= Decimal::ZERO);
    }

    #[test]
    fn test_get_strike_versus_vanna() {
        let (_, vanna) = strike_versus(BasicAxisTypes::Vanna);
        assert!(vanna <= Decimal::ZERO);
    }

    #[test]
    fn test_get_strike_versus_veta() {
        let (_, veta) = strike_versus(BasicAxisTypes::Veta);
        assert!(veta >= Decimal::ZERO);
    }

    #[test]
    fn test_get_strike_versus_charm() {
        // Only the strike (checked inside the helper) is asserted for charm.
        strike_versus(BasicAxisTypes::Charm);
    }

    #[test]
    fn test_get_strike_versus_color() {
        // Only the strike (checked inside the helper) is asserted for color.
        strike_versus(BasicAxisTypes::Color);
    }

    #[test]
    fn test_get_strike_versus_volatility() {
        let (option, volatility) = strike_versus(BasicAxisTypes::Volatility);
        assert_eq!(volatility, option.implied_volatility.to_dec());
    }

    #[test]
    fn test_get_strike_versus_price() {
        let (_, price) = strike_versus(BasicAxisTypes::Price);
        assert!(price > Decimal::ZERO);
    }

    #[test]
    fn test_curve_method() {
        let curves = TestBasicCurves;
        let curve = curves
            .curve(&BasicAxisTypes::Delta, &OptionStyle::Call, &Side::Long)
            .unwrap_or_else(|err| panic!("Delta curve should build, got: {err:?}"));
        assert_eq!(curve.points.len(), 1);
    }

    #[test]
    fn test_get_strike_versus_black_scholes_price() {
        let (option, price) = strike_versus(BasicAxisTypes::Price);
        assert!(price > Decimal::ZERO);
        // The Price axis must report exactly what the option's own pricer returns.
        let direct_bs_price = option.calculate_price_black_scholes().unwrap();
        assert_eq!(price, direct_bs_price);
    }

    #[test]
    fn test_get_strike_versus_unsupported_axis() {
        let (operation, reason) = unsupported_axis_error();
        assert_eq!(operation, "get_axis_value");
        assert!(reason.contains("not supported"));
        assert!(reason.contains("Expiration"));
    }

    #[test]
    fn test_invalid_axis_error_message() {
        let (operation, reason) = unsupported_axis_error();
        assert_eq!(operation, "get_axis_value");
        assert!(reason.contains("Axis: Expiration not supported"));
    }

    #[test]
    fn test_curve_with_various_params() {
        let curves = TestBasicCurves;
        // Every style/side combination must build; TestBasicCurves ignores both,
        // so this only guards against the combination itself causing an error.
        for style in [OptionStyle::Call, OptionStyle::Put] {
            for side in [Side::Long, Side::Short] {
                if let Err(err) = curves.curve(&BasicAxisTypes::Delta, &style, &side) {
                    panic!("Delta curve failed for a style/side combination: {err:?}");
                }
            }
        }
    }
}