use std::collections::HashMap;
use chrono::{DateTime, FixedOffset, Utc};
use thiserror::Error;
use crate::unified_data::{OrderRequest, OrderSide, OrderType, Position};
/// Tunable risk limits consumed by `RiskManager`.
///
/// All three fields are expressed as fractions rather than percentages
/// (for example `0.1` means 10%).
#[derive(Debug, Clone)]
pub struct RiskConfig {
/// Largest allowed order notional, as a fraction of the portfolio value.
pub max_position_size_pct: f64,
/// Fractional distance from a position's entry price at which its stop-loss
/// trigger is placed; a value `<= 0.0` disables stop-loss generation.
pub stop_loss_pct: f64,
/// Fractional distance from a position's entry price at which its take-profit
/// trigger is placed; a value `<= 0.0` disables take-profit generation.
pub take_profit_pct: f64,
}
impl Default for RiskConfig {
fn default() -> Self {
Self {
max_position_size_pct: 0.1,
stop_loss_pct: 0.05,
take_profit_pct: 0.1,
}
}
}
/// Reasons an order can be rejected by `RiskManager::validate_order`.
#[derive(Debug, Error, Clone)]
pub enum RiskError {
/// The order's notional value is larger than the configured share of the
/// portfolio; `message` carries the human-readable details.
#[error("position size exceeds configured limit: {message}")]
PositionSizeExceeded { message: String },
/// The emergency-stop flag is set, so every order is refused.
#[error("trading is halted by the emergency stop toggle")]
TradingHalted,
}
/// Convenience alias for results whose error type is [`RiskError`].
pub type Result<T> = std::result::Result<T, RiskError>;
/// A protective (stop-loss or take-profit) order that fires once the market
/// price crosses `trigger_price`.
#[derive(Debug, Clone)]
pub struct RiskOrder {
/// Identifier of the order this protective order was generated for.
pub parent_order_id: String,
/// Instrument the order applies to; used as the key into the price map
/// when checking for triggers.
pub symbol: String,
/// Side of the protective order itself (the generators use `Sell` for a
/// positive position size and `Buy` for a negative one).
pub side: OrderSide,
/// Execution type; `RiskOrder::new` always sets this to `OrderType::Market`.
pub order_type: OrderType,
/// Quantity to trade (the generators pass the absolute position size).
pub quantity: f64,
/// Price level at which the order should fire.
pub trigger_price: f64,
/// `true` when this is a stop-loss. Takes precedence over `is_take_profit`
/// when triggers are evaluated.
pub is_stop_loss: bool,
/// `true` when this is a take-profit.
pub is_take_profit: bool,
/// Creation timestamp; `RiskOrder::new` records the current UTC time with a
/// zero offset.
pub created_at: DateTime<FixedOffset>,
}
impl RiskOrder {
fn new(
parent_order_id: &str,
symbol: &str,
side: OrderSide,
quantity: f64,
trigger_price: f64,
is_stop_loss: bool,
is_take_profit: bool,
) -> Self {
Self {
parent_order_id: parent_order_id.to_string(),
symbol: symbol.to_string(),
side,
order_type: OrderType::Market,
quantity,
trigger_price,
is_stop_loss,
is_take_profit,
created_at: Utc::now().with_timezone(&FixedOffset::east_opt(0).unwrap()),
}
}
}
/// Validates orders against the configured limits and tracks pending
/// stop-loss / take-profit orders.
#[derive(Debug, Clone)]
pub struct RiskManager {
/// Limits applied when validating orders and generating protective orders.
config: RiskConfig,
/// Portfolio value used to size the per-order notional cap;
/// `update_portfolio_value` keeps it non-negative.
portfolio_value: f64,
/// Registered stop-loss orders that have not triggered yet.
stop_losses: Vec<RiskOrder>,
/// Registered take-profit orders that have not triggered yet.
take_profits: Vec<RiskOrder>,
/// When `true`, `validate_order` rejects every order.
emergency_stop: bool,
}
impl RiskManager {
    /// Creates a manager with no registered protective orders and the
    /// emergency stop released.
    ///
    /// `portfolio_value` is clamped to be non-negative (a NaN becomes `0.0`),
    /// the same normalisation [`RiskManager::update_portfolio_value`] applies.
    pub fn new(config: RiskConfig, portfolio_value: f64) -> Self {
        Self {
            config,
            portfolio_value: portfolio_value.max(0.0),
            stop_losses: Vec::new(),
            take_profits: Vec::new(),
            emergency_stop: false,
        }
    }

    /// Returns the limits this manager enforces.
    pub fn config(&self) -> &RiskConfig {
        &self.config
    }

    /// Replaces the tracked portfolio value.
    ///
    /// Negative values (and NaN, because `f64::max` returns the non-NaN
    /// operand) are stored as `0.0`. `_realized_pnl_delta` is currently unused.
    ///
    /// # Errors
    ///
    /// Never fails at present; the `Result` is kept for interface stability.
    pub fn update_portfolio_value(
        &mut self,
        new_value: f64,
        _realized_pnl_delta: f64,
    ) -> Result<()> {
        self.portfolio_value = new_value.max(0.0);
        Ok(())
    }

    /// Checks `order` against the emergency stop and the per-order notional cap.
    ///
    /// The notional is `|price * quantity|` and the cap is
    /// `max_position_size_pct * portfolio_value`. Orders without a price skip
    /// the size check because no notional can be computed for them.
    /// `_positions` is currently unused.
    ///
    /// # Errors
    ///
    /// * [`RiskError::TradingHalted`] while the emergency stop is active.
    /// * [`RiskError::PositionSizeExceeded`] when the notional is above the cap
    ///   or is NaN (a NaN price or quantity must not slip past the limit).
    pub fn validate_order(
        &self,
        order: &OrderRequest,
        _positions: &HashMap<String, Position>,
    ) -> Result<()> {
        if self.emergency_stop {
            return Err(RiskError::TradingHalted);
        }

        if let Some(price) = order.price {
            // Take the magnitude of the whole product so a negative price
            // cannot produce a negative notional that trivially passes.
            let notional = (price * order.quantity).abs();
            let max_notional = self.config.max_position_size_pct * self.portfolio_value;

            // NOTE(review): a non-positive cap (zero portfolio value or zero
            // limit) disables the check entirely, as in the original code.
            // Confirm this is intended rather than "reject everything".
            let cap_is_active = max_notional > 0.0;
            // `NaN > x` is always false, so NaN is tested for explicitly to
            // fail closed.
            let breaches_cap = notional.is_nan() || notional > max_notional;

            if cap_is_active && breaches_cap {
                return Err(RiskError::PositionSizeExceeded {
                    message: format!(
                        "order notional {:.2} exceeds {:.2} ({:.2}% of portfolio)",
                        notional,
                        max_notional,
                        self.config.max_position_size_pct * 100.0,
                    ),
                });
            }
        }

        Ok(())
    }

    /// Builds the stop-loss order for `position`, or `None` when no sensible
    /// order exists.
    ///
    /// The trigger sits `stop_loss_pct` below the entry price for a positive
    /// position size and the same fraction above it for a negative size.
    /// Returns `None` when the size is zero or non-finite, the entry price is
    /// non-finite, or `stop_loss_pct` is not a finite positive number.
    pub fn generate_stop_loss(&self, position: &Position, order_id: &str) -> Option<RiskOrder> {
        Self::protective_order(position, order_id, self.config.stop_loss_pct, true)
    }

    /// Builds the take-profit order for `position`, or `None` when no sensible
    /// order exists.
    ///
    /// The trigger sits `take_profit_pct` above the entry price for a positive
    /// position size and the same fraction below it for a negative size.
    /// Returns `None` when the size is zero or non-finite, the entry price is
    /// non-finite, or `take_profit_pct` is not a finite positive number.
    pub fn generate_take_profit(&self, position: &Position, order_id: &str) -> Option<RiskOrder> {
        Self::protective_order(position, order_id, self.config.take_profit_pct, false)
    }

    /// Shared builder behind the stop-loss and take-profit generators.
    fn protective_order(
        position: &Position,
        order_id: &str,
        offset_pct: f64,
        is_stop_loss: bool,
    ) -> Option<RiskOrder> {
        // `offset_pct > 0.0` is false for NaN, so a NaN percentage is rejected
        // here as well.
        let offset_usable = offset_pct.is_finite() && offset_pct > 0.0;
        let position_usable = position.size.is_finite()
            && position.size != 0.0
            && position.entry_price.is_finite();
        if !offset_usable || !position_usable {
            return None;
        }

        let is_long = position.size > 0.0;
        // The trigger lies below the entry for a long stop-loss and for a
        // short take-profit; above it in the other two cases.
        let trigger_below_entry = is_long == is_stop_loss;
        let factor = if trigger_below_entry {
            1.0 - offset_pct
        } else {
            1.0 + offset_pct
        };
        // A protective order always closes the position: sell a long, buy
        // back a short.
        let side = if is_long {
            OrderSide::Sell
        } else {
            OrderSide::Buy
        };

        Some(RiskOrder::new(
            order_id,
            &position.symbol,
            side,
            position.size.abs(),
            position.entry_price * factor,
            is_stop_loss,
            !is_stop_loss,
        ))
    }

    /// Adds `order` to the pending stop-loss list.
    pub fn register_stop_loss(&mut self, order: RiskOrder) {
        self.stop_losses.push(order);
    }

    /// Adds `order` to the pending take-profit list.
    pub fn register_take_profit(&mut self, order: RiskOrder) {
        self.take_profits.push(order);
    }

    /// Removes every pending stop-loss and take-profit whose
    /// `parent_order_id` equals the given id and returns how many were removed.
    ///
    /// Call this once a position has been closed (manually or because one of
    /// its protective orders fired) so the remaining orders cannot trigger
    /// later against a position that no longer exists.
    pub fn cancel_risk_orders(&mut self, parent_order_id: &str) -> usize {
        let pending_before = self.stop_losses.len() + self.take_profits.len();
        self.stop_losses
            .retain(|order| order.parent_order_id != parent_order_id);
        self.take_profits
            .retain(|order| order.parent_order_id != parent_order_id);
        pending_before - (self.stop_losses.len() + self.take_profits.len())
    }

    /// Removes and returns every pending order whose trigger has been reached
    /// at the prices in `current_prices`.
    ///
    /// Triggered stop-losses come first, then triggered take-profits, each in
    /// registration order. Orders whose symbol has no entry in
    /// `current_prices` stay pending.
    ///
    /// NOTE(review): when one protective order of a position fires, any other
    /// order registered for the same parent stays pending; callers should
    /// follow up with [`RiskManager::cancel_risk_orders`].
    pub fn check_risk_orders(&mut self, current_prices: &HashMap<String, f64>) -> Vec<RiskOrder> {
        /// Whether `price` has reached the order's trigger level.
        fn should_trigger(order: &RiskOrder, price: f64) -> bool {
            if order.is_stop_loss {
                // A stop-loss fires when price moves against the position.
                match order.side {
                    OrderSide::Sell => price <= order.trigger_price,
                    OrderSide::Buy => price >= order.trigger_price,
                }
            } else if order.is_take_profit {
                // A take-profit fires when price moves in the position's favour.
                match order.side {
                    OrderSide::Sell => price >= order.trigger_price,
                    OrderSide::Buy => price <= order.trigger_price,
                }
            } else {
                false
            }
        }

        /// Moves the triggered orders out of `pending` into `fired`,
        /// preserving the relative order of both groups.
        fn take_triggered(
            pending: &mut Vec<RiskOrder>,
            prices: &HashMap<String, f64>,
            fired: &mut Vec<RiskOrder>,
        ) {
            let (hit, kept): (Vec<RiskOrder>, Vec<RiskOrder>) = std::mem::take(pending)
                .into_iter()
                .partition(|order| {
                    prices
                        .get(&order.symbol)
                        .map_or(false, |&price| should_trigger(order, price))
                });
            *pending = kept;
            fired.extend(hit);
        }

        let mut triggered = Vec::new();
        take_triggered(&mut self.stop_losses, current_prices, &mut triggered);
        take_triggered(&mut self.take_profits, current_prices, &mut triggered);
        triggered
    }

    /// Engages the emergency stop; `validate_order` rejects everything until
    /// it is released.
    pub fn activate_emergency_stop(&mut self) {
        self.emergency_stop = true;
    }

    /// Releases the emergency stop.
    pub fn deactivate_emergency_stop(&mut self) {
        self.emergency_stop = false;
    }

    /// Reports whether the emergency stop is currently engaged.
    pub fn should_stop_trading(&self) -> bool {
        self.emergency_stop
    }
}