use super::types::ValidationError;
use crate::contracts::TagValue;
/// Encodes a flag as the tag-value string `"1"` (for `true`) or `"0"` (for `false`).
fn bool_param(v: bool) -> String {
    // `u8::from(bool)` is 1 for true and 0 for false; its decimal rendering is the wire value.
    u8::from(v).to_string()
}
/// Inclusive lower bound accepted for percentage-of-volume parameters (`pctVol`, `maxPctVol`).
///
/// NOTE(review): values look like fractions (0.1 = 10%) — confirm against the broker docs.
pub const MIN_PCT_VOL: f64 = 0.1;
/// Inclusive upper bound accepted for percentage-of-volume parameters (`pctVol`, `maxPctVol`).
pub const MAX_PCT_VOL: f64 = 0.5;
/// Checks that `value` lies within `MIN_PCT_VOL..=MAX_PCT_VOL`.
///
/// `field` is the builder-side name reported back in the error.
///
/// # Errors
///
/// Returns `ValidationError::InvalidPercentage` carrying `field`, the rejected `value` and
/// both bounds when the value is out of range. NaN never satisfies the range check, so it is
/// rejected as well.
fn validate_pct_vol(field: &'static str, value: f64) -> Result<(), ValidationError> {
    let in_range = (MIN_PCT_VOL..=MAX_PCT_VOL).contains(&value);
    if in_range {
        return Ok(());
    }
    Err(ValidationError::InvalidPercentage {
        field,
        value,
        min: MIN_PCT_VOL,
        max: MAX_PCT_VOL,
    })
}
/// A named algorithmic strategy together with the tag/value pairs that configure it.
///
/// The builders in this module produce it from `build()`. A bare strategy name can also be
/// converted directly through `From<String>` / `From<&str>`, which yields no parameters.
#[derive(Debug, Clone, Default)]
pub struct AlgoParams {
    /// Strategy identifier, e.g. `"Vwap"` or `"Twap"`.
    pub strategy: String,
    /// Strategy parameters as tag/value pairs.
    pub params: Vec<TagValue>,
}
impl From<String> for AlgoParams {
    /// Wraps `strategy` with an empty parameter list.
    fn from(strategy: String) -> Self {
        let params = Vec::new();
        AlgoParams { strategy, params }
    }
}
impl From<&str> for AlgoParams {
    /// Copies `strategy` into an owned name and delegates to the `From<String>` impl.
    fn from(strategy: &str) -> Self {
        AlgoParams::from(strategy.to_owned())
    }
}
#[derive(Debug, Clone, Default)]
pub struct VwapBuilder {
max_pct_vol: Option<f64>,
start_time: Option<String>,
end_time: Option<String>,
allow_past_end_time: Option<bool>,
no_take_liq: Option<bool>,
speed_up: Option<bool>,
}
impl VwapBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn max_pct_vol(mut self, pct: f64) -> Self {
self.max_pct_vol = Some(pct);
self
}
pub fn start_time(mut self, time: impl Into<String>) -> Self {
self.start_time = Some(time.into());
self
}
pub fn end_time(mut self, time: impl Into<String>) -> Self {
self.end_time = Some(time.into());
self
}
pub fn allow_past_end_time(mut self, allow: bool) -> Self {
self.allow_past_end_time = Some(allow);
self
}
pub fn no_take_liq(mut self, no_take: bool) -> Self {
self.no_take_liq = Some(no_take);
self
}
pub fn speed_up(mut self, speed_up: bool) -> Self {
self.speed_up = Some(speed_up);
self
}
pub fn build(self) -> Result<AlgoParams, ValidationError> {
let mut params = Vec::new();
if let Some(v) = self.max_pct_vol {
validate_pct_vol("max_pct_vol", v)?;
params.push(TagValue {
tag: "maxPctVol".to_string(),
value: v.to_string(),
});
}
if let Some(v) = self.start_time {
params.push(TagValue {
tag: "startTime".to_string(),
value: v,
});
}
if let Some(v) = self.end_time {
params.push(TagValue {
tag: "endTime".to_string(),
value: v,
});
}
if let Some(v) = self.allow_past_end_time {
params.push(TagValue {
tag: "allowPastEndTime".to_string(),
value: bool_param(v),
});
}
if let Some(v) = self.no_take_liq {
params.push(TagValue {
tag: "noTakeLiq".to_string(),
value: bool_param(v),
});
}
if let Some(v) = self.speed_up {
params.push(TagValue {
tag: "speedUp".to_string(),
value: bool_param(v),
});
}
Ok(AlgoParams {
strategy: "Vwap".to_string(),
params,
})
}
}
impl TryFrom<VwapBuilder> for AlgoParams {
type Error = ValidationError;
fn try_from(builder: VwapBuilder) -> Result<Self, Self::Error> {
builder.build()
}
}
/// Execution style for the `"Twap"` strategy, emitted as the `strategyType` tag.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum TwapStrategyType {
    /// Wire value `"Marketable"`; this is the `Default`.
    #[default]
    Marketable,
    /// Wire value `"Matching Midpoint"`.
    MatchingMidpoint,
    /// Wire value `"Matching Same Side"`.
    MatchingSameSide,
    /// Wire value `"Matching Last"`.
    MatchingLast,
}
impl TwapStrategyType {
    /// Returns the exact string sent as the `strategyType` tag value.
    pub fn as_str(&self) -> &'static str {
        match self {
            TwapStrategyType::Marketable => "Marketable",
            TwapStrategyType::MatchingMidpoint => "Matching Midpoint",
            TwapStrategyType::MatchingSameSide => "Matching Same Side",
            TwapStrategyType::MatchingLast => "Matching Last",
        }
    }
}
#[derive(Debug, Clone, Default)]
pub struct TwapBuilder {
strategy_type: Option<TwapStrategyType>,
start_time: Option<String>,
end_time: Option<String>,
allow_past_end_time: Option<bool>,
}
impl TwapBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn strategy_type(mut self, strategy: TwapStrategyType) -> Self {
self.strategy_type = Some(strategy);
self
}
pub fn start_time(mut self, time: impl Into<String>) -> Self {
self.start_time = Some(time.into());
self
}
pub fn end_time(mut self, time: impl Into<String>) -> Self {
self.end_time = Some(time.into());
self
}
pub fn allow_past_end_time(mut self, allow: bool) -> Self {
self.allow_past_end_time = Some(allow);
self
}
pub fn build(self) -> Result<AlgoParams, ValidationError> {
let mut params = Vec::new();
if let Some(v) = self.strategy_type {
params.push(TagValue {
tag: "strategyType".to_string(),
value: v.as_str().to_string(),
});
}
if let Some(v) = self.start_time {
params.push(TagValue {
tag: "startTime".to_string(),
value: v,
});
}
if let Some(v) = self.end_time {
params.push(TagValue {
tag: "endTime".to_string(),
value: v,
});
}
if let Some(v) = self.allow_past_end_time {
params.push(TagValue {
tag: "allowPastEndTime".to_string(),
value: bool_param(v),
});
}
Ok(AlgoParams {
strategy: "Twap".to_string(),
params,
})
}
}
impl TryFrom<TwapBuilder> for AlgoParams {
type Error = ValidationError;
fn try_from(builder: TwapBuilder) -> Result<Self, Self::Error> {
builder.build()
}
}
#[derive(Debug, Clone, Default)]
pub struct PctVolBuilder {
pct_vol: Option<f64>,
start_time: Option<String>,
end_time: Option<String>,
no_take_liq: Option<bool>,
}
impl PctVolBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn pct_vol(mut self, pct: f64) -> Self {
self.pct_vol = Some(pct);
self
}
pub fn start_time(mut self, time: impl Into<String>) -> Self {
self.start_time = Some(time.into());
self
}
pub fn end_time(mut self, time: impl Into<String>) -> Self {
self.end_time = Some(time.into());
self
}
pub fn no_take_liq(mut self, no_take: bool) -> Self {
self.no_take_liq = Some(no_take);
self
}
pub fn build(self) -> Result<AlgoParams, ValidationError> {
let mut params = Vec::new();
if let Some(v) = self.pct_vol {
validate_pct_vol("pct_vol", v)?;
params.push(TagValue {
tag: "pctVol".to_string(),
value: v.to_string(),
});
}
if let Some(v) = self.start_time {
params.push(TagValue {
tag: "startTime".to_string(),
value: v,
});
}
if let Some(v) = self.end_time {
params.push(TagValue {
tag: "endTime".to_string(),
value: v,
});
}
if let Some(v) = self.no_take_liq {
params.push(TagValue {
tag: "noTakeLiq".to_string(),
value: bool_param(v),
});
}
Ok(AlgoParams {
strategy: "PctVol".to_string(),
params,
})
}
}
impl TryFrom<PctVolBuilder> for AlgoParams {
type Error = ValidationError;
fn try_from(builder: PctVolBuilder) -> Result<Self, Self::Error> {
builder.build()
}
}
/// Risk-aversion level for the `"ArrivalPx"` strategy, emitted as the `riskAversion` tag.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum RiskAversion {
    /// Wire value `"Get Done"`.
    GetDone,
    /// Wire value `"Aggressive"`.
    Aggressive,
    /// Wire value `"Neutral"`; this is the `Default`.
    #[default]
    Neutral,
    /// Wire value `"Passive"`.
    Passive,
}
impl RiskAversion {
    /// Returns the exact string sent as the `riskAversion` tag value.
    pub fn as_str(&self) -> &'static str {
        match self {
            RiskAversion::GetDone => "Get Done",
            RiskAversion::Aggressive => "Aggressive",
            RiskAversion::Neutral => "Neutral",
            RiskAversion::Passive => "Passive",
        }
    }
}
#[derive(Debug, Clone, Default)]
pub struct ArrivalPriceBuilder {
max_pct_vol: Option<f64>,
risk_aversion: Option<RiskAversion>,
start_time: Option<String>,
end_time: Option<String>,
force_completion: Option<bool>,
allow_past_end_time: Option<bool>,
}
impl ArrivalPriceBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn max_pct_vol(mut self, pct: f64) -> Self {
self.max_pct_vol = Some(pct);
self
}
pub fn risk_aversion(mut self, risk: RiskAversion) -> Self {
self.risk_aversion = Some(risk);
self
}
pub fn start_time(mut self, time: impl Into<String>) -> Self {
self.start_time = Some(time.into());
self
}
pub fn end_time(mut self, time: impl Into<String>) -> Self {
self.end_time = Some(time.into());
self
}
pub fn force_completion(mut self, force: bool) -> Self {
self.force_completion = Some(force);
self
}
pub fn allow_past_end_time(mut self, allow: bool) -> Self {
self.allow_past_end_time = Some(allow);
self
}
pub fn build(self) -> Result<AlgoParams, ValidationError> {
let mut params = Vec::new();
if let Some(v) = self.max_pct_vol {
validate_pct_vol("max_pct_vol", v)?;
params.push(TagValue {
tag: "maxPctVol".to_string(),
value: v.to_string(),
});
}
if let Some(v) = self.risk_aversion {
params.push(TagValue {
tag: "riskAversion".to_string(),
value: v.as_str().to_string(),
});
}
if let Some(v) = self.start_time {
params.push(TagValue {
tag: "startTime".to_string(),
value: v,
});
}
if let Some(v) = self.end_time {
params.push(TagValue {
tag: "endTime".to_string(),
value: v,
});
}
if let Some(v) = self.force_completion {
params.push(TagValue {
tag: "forceCompletion".to_string(),
value: bool_param(v),
});
}
if let Some(v) = self.allow_past_end_time {
params.push(TagValue {
tag: "allowPastEndTime".to_string(),
value: bool_param(v),
});
}
Ok(AlgoParams {
strategy: "ArrivalPx".to_string(),
params,
})
}
}
impl TryFrom<ArrivalPriceBuilder> for AlgoParams {
type Error = ValidationError;
fn try_from(builder: ArrivalPriceBuilder) -> Result<Self, Self::Error> {
builder.build()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the value emitted for `tag`, or `None` if the builder omitted it.
    fn find_param<'a>(params: &'a AlgoParams, tag: &str) -> Option<&'a str> {
        params
            .params
            .iter()
            .find(|p| p.tag == tag)
            .map(|p| p.value.as_str())
    }

    /// Returns the emitted tags in emission order.
    fn tags(params: &AlgoParams) -> Vec<&str> {
        params.params.iter().map(|p| p.tag.as_str()).collect()
    }

    #[test]
    fn test_algo_params_from_string() {
        let params: AlgoParams = "Vwap".into();
        assert_eq!(params.strategy, "Vwap");
        assert!(params.params.is_empty());
        let params = AlgoParams::from(String::from("Twap"));
        assert_eq!(params.strategy, "Twap");
        assert!(params.params.is_empty());
    }

    #[test]
    fn test_vwap_builder() {
        let params = VwapBuilder::new()
            .max_pct_vol(0.2)
            .start_time("09:00:00 US/Eastern")
            .end_time("16:00:00 US/Eastern")
            .allow_past_end_time(true)
            .no_take_liq(true)
            .speed_up(true)
            .build()
            .unwrap();
        assert_eq!(params.strategy, "Vwap");
        assert_eq!(params.params.len(), 6);
        assert_eq!(find_param(&params, "maxPctVol"), Some("0.2"));
        assert_eq!(find_param(&params, "startTime"), Some("09:00:00 US/Eastern"));
        assert_eq!(find_param(&params, "endTime"), Some("16:00:00 US/Eastern"));
        assert_eq!(find_param(&params, "allowPastEndTime"), Some("1"));
        assert_eq!(find_param(&params, "noTakeLiq"), Some("1"));
        assert_eq!(find_param(&params, "speedUp"), Some("1"));
    }

    #[test]
    fn test_twap_builder() {
        let params = TwapBuilder::new()
            .strategy_type(TwapStrategyType::MatchingMidpoint)
            .start_time("09:00:00 US/Eastern")
            .end_time("16:00:00 US/Eastern")
            .allow_past_end_time(false)
            .build()
            .unwrap();
        assert_eq!(params.strategy, "Twap");
        assert_eq!(params.params.len(), 4);
        assert_eq!(find_param(&params, "strategyType"), Some("Matching Midpoint"));
        assert_eq!(find_param(&params, "startTime"), Some("09:00:00 US/Eastern"));
        assert_eq!(find_param(&params, "allowPastEndTime"), Some("0"));
    }

    #[test]
    fn test_pct_vol_builder() {
        let params = PctVolBuilder::new()
            .pct_vol(0.15)
            .start_time("09:30:00 US/Eastern")
            .end_time("15:30:00 US/Eastern")
            .no_take_liq(false)
            .build()
            .unwrap();
        assert_eq!(params.strategy, "PctVol");
        assert_eq!(params.params.len(), 4);
        assert_eq!(find_param(&params, "pctVol"), Some("0.15"));
        assert_eq!(find_param(&params, "endTime"), Some("15:30:00 US/Eastern"));
        assert_eq!(find_param(&params, "noTakeLiq"), Some("0"));
    }

    #[test]
    fn test_arrival_price_builder() {
        let params = ArrivalPriceBuilder::new()
            .max_pct_vol(0.1)
            .risk_aversion(RiskAversion::Aggressive)
            .start_time("09:00:00 US/Eastern")
            .end_time("16:00:00 US/Eastern")
            .force_completion(true)
            .allow_past_end_time(true)
            .build()
            .unwrap();
        assert_eq!(params.strategy, "ArrivalPx");
        assert_eq!(params.params.len(), 6);
        assert_eq!(find_param(&params, "maxPctVol"), Some("0.1"));
        assert_eq!(find_param(&params, "riskAversion"), Some("Aggressive"));
        assert_eq!(find_param(&params, "forceCompletion"), Some("1"));
        assert_eq!(find_param(&params, "allowPastEndTime"), Some("1"));
    }

    #[test]
    fn test_builder_minimal() {
        let vwap = VwapBuilder::new().build().unwrap();
        assert_eq!(vwap.strategy, "Vwap");
        assert!(vwap.params.is_empty());
        let twap = TwapBuilder::new().build().unwrap();
        assert_eq!(twap.strategy, "Twap");
        assert!(twap.params.is_empty());
        let pct = PctVolBuilder::new().build().unwrap();
        assert_eq!(pct.strategy, "PctVol");
        assert!(pct.params.is_empty());
        let arrival = ArrivalPriceBuilder::new().build().unwrap();
        assert_eq!(arrival.strategy, "ArrivalPx");
        assert!(arrival.params.is_empty());
    }

    #[test]
    fn test_params_emitted_in_fixed_order() {
        // Setters are called in reverse to show emission order does not follow call order.
        let vwap = VwapBuilder::new()
            .speed_up(false)
            .no_take_liq(false)
            .allow_past_end_time(false)
            .end_time("16:00:00 US/Eastern")
            .start_time("09:00:00 US/Eastern")
            .max_pct_vol(0.3)
            .build()
            .unwrap();
        assert_eq!(
            tags(&vwap),
            ["maxPctVol", "startTime", "endTime", "allowPastEndTime", "noTakeLiq", "speedUp"]
        );
        let arrival = ArrivalPriceBuilder::new()
            .allow_past_end_time(false)
            .force_completion(false)
            .end_time("16:00:00 US/Eastern")
            .start_time("09:00:00 US/Eastern")
            .risk_aversion(RiskAversion::Passive)
            .max_pct_vol(0.2)
            .build()
            .unwrap();
        assert_eq!(
            tags(&arrival),
            [
                "maxPctVol",
                "riskAversion",
                "startTime",
                "endTime",
                "forceCompletion",
                "allowPastEndTime"
            ]
        );
    }

    #[test]
    fn test_pct_vol_out_of_range_errors() {
        let result = PctVolBuilder::new().pct_vol(0.8).build();
        assert!(matches!(result, Err(ValidationError::InvalidPercentage { field: "pct_vol", .. })));
        let result = VwapBuilder::new().max_pct_vol(1.0).build();
        assert!(matches!(result, Err(ValidationError::InvalidPercentage { field: "max_pct_vol", .. })));
        let result = PctVolBuilder::new().pct_vol(0.05).build();
        assert!(matches!(result, Err(ValidationError::InvalidPercentage { field: "pct_vol", .. })));
        let result = ArrivalPriceBuilder::new().max_pct_vol(0.01).build();
        assert!(matches!(result, Err(ValidationError::InvalidPercentage { field: "max_pct_vol", .. })));
    }

    #[test]
    fn test_pct_vol_error_reports_bounds() {
        let result = PctVolBuilder::new().pct_vol(0.8).build();
        assert!(matches!(
            result,
            Err(ValidationError::InvalidPercentage { field: "pct_vol", min, max, .. })
                if min == MIN_PCT_VOL && max == MAX_PCT_VOL
        ));
    }

    #[test]
    fn test_pct_vol_rejects_non_finite() {
        for bad in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            assert!(VwapBuilder::new().max_pct_vol(bad).build().is_err());
            assert!(PctVolBuilder::new().pct_vol(bad).build().is_err());
            assert!(ArrivalPriceBuilder::new().max_pct_vol(bad).build().is_err());
        }
    }

    #[test]
    fn test_pct_vol_valid_values_succeed() {
        let params = PctVolBuilder::new().pct_vol(0.25).build().unwrap();
        assert_eq!(find_param(&params, "pctVol"), Some("0.25"));
        let params = VwapBuilder::new().max_pct_vol(0.1).build().unwrap();
        assert_eq!(find_param(&params, "maxPctVol"), Some("0.1"));
        let params = VwapBuilder::new().max_pct_vol(0.5).build().unwrap();
        assert_eq!(find_param(&params, "maxPctVol"), Some("0.5"));
    }

    #[test]
    fn test_pct_vol_boundary_values() {
        assert!(VwapBuilder::new().max_pct_vol(0.1).build().is_ok());
        assert!(PctVolBuilder::new().pct_vol(0.1).build().is_ok());
        assert!(ArrivalPriceBuilder::new().max_pct_vol(0.1).build().is_ok());
        assert!(VwapBuilder::new().max_pct_vol(0.5).build().is_ok());
        assert!(PctVolBuilder::new().pct_vol(0.5).build().is_ok());
        assert!(ArrivalPriceBuilder::new().max_pct_vol(0.5).build().is_ok());
        assert!(VwapBuilder::new().max_pct_vol(0.09).build().is_err());
        assert!(VwapBuilder::new().max_pct_vol(0.51).build().is_err());
    }

    #[test]
    fn test_try_from_matches_build() {
        let vwap = AlgoParams::try_from(VwapBuilder::new().max_pct_vol(0.2)).unwrap();
        assert_eq!(vwap.strategy, "Vwap");
        assert_eq!(find_param(&vwap, "maxPctVol"), Some("0.2"));
        let twap = AlgoParams::try_from(TwapBuilder::new()).unwrap();
        assert_eq!(twap.strategy, "Twap");
        assert!(AlgoParams::try_from(PctVolBuilder::new().pct_vol(0.9)).is_err());
        let arrival =
            AlgoParams::try_from(ArrivalPriceBuilder::new().risk_aversion(RiskAversion::GetDone))
                .unwrap();
        assert_eq!(find_param(&arrival, "riskAversion"), Some("Get Done"));
    }

    #[test]
    fn test_enum_wire_values() {
        assert_eq!(TwapStrategyType::default().as_str(), "Marketable");
        assert_eq!(TwapStrategyType::MatchingSameSide.as_str(), "Matching Same Side");
        assert_eq!(TwapStrategyType::MatchingLast.as_str(), "Matching Last");
        assert_eq!(RiskAversion::default().as_str(), "Neutral");
        assert_eq!(RiskAversion::GetDone.as_str(), "Get Done");
        assert_eq!(RiskAversion::Passive.as_str(), "Passive");
    }
}