1use async_trait::async_trait;
2use serde::Serialize;
3use std::sync::Arc;
4
5use crate::config::AppConfig;
6use crate::executor::OrderResult;
7use crate::position_manager::PositionManager;
8use crate::signal::TradeSignal;
9use hyper_risk::risk::AccountState;
10
/// Trading mode a pipeline run operates under; stored in `PipelineContext::mode`.
///
/// The `Display` impl renders the variants as `paper`, `live` and `dry-run`.
/// `Eq` and `Hash` are derived alongside `PartialEq` so the mode can be used as a
/// set/map key and with APIs that require total equality.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TradingMode {
    /// Displayed as `paper`.
    Paper,
    /// Displayed as `live`.
    Live,
    /// Displayed as `dry-run`.
    DryRun,
}
17
18impl std::fmt::Display for TradingMode {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 match self {
21 TradingMode::Paper => write!(f, "paper"),
22 TradingMode::Live => write!(f, "live"),
23 TradingMode::DryRun => write!(f, "dry-run"),
24 }
25 }
26}
27
28pub struct PipelineContext {
29 pub mode: TradingMode,
30 pub account_state: AccountState,
31 pub position_manager: Arc<PositionManager>,
32 pub config: Arc<AppConfig>,
33 pub execution_results: Vec<OrderResult>,
34}
35
36#[derive(Debug, Serialize)]
37pub enum PipelineError {
38 RiskBlocked(String),
39 CircuitBreakerTripped,
40 ExecutionFailed(String),
41 AuthError(String),
42 ConnectionError(String),
43}
44
45impl std::fmt::Display for PipelineError {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 match self {
48 Self::RiskBlocked(msg) => write!(f, "Risk blocked: {}", msg),
49 Self::CircuitBreakerTripped => write!(f, "Circuit breaker tripped"),
50 Self::ExecutionFailed(msg) => write!(f, "Execution failed: {}", msg),
51 Self::AuthError(msg) => write!(f, "Auth error: {}", msg),
52 Self::ConnectionError(msg) => write!(f, "Connection error: {}", msg),
53 }
54 }
55}
56
57impl std::error::Error for PipelineError {}
58
59impl PipelineError {
60 pub fn exit_code(&self) -> i32 {
61 match self {
62 Self::RiskBlocked(_) | Self::CircuitBreakerTripped => 4,
63 Self::AuthError(_) => 3,
64 Self::ExecutionFailed(_) | Self::ConnectionError(_) => 5,
65 }
66 }
67}
68
69#[async_trait]
70pub trait PipelineStage: Send + Sync {
71 async fn process(
72 &self,
73 signal: TradeSignal,
74 ctx: &mut PipelineContext,
75 ) -> Result<TradeSignal, PipelineError>;
76}
77
78pub struct OrderPipeline {
79 stages: Vec<Box<dyn PipelineStage>>,
80}
81
82impl OrderPipeline {
83 pub fn new(stages: Vec<Box<dyn PipelineStage>>) -> Self {
84 Self { stages }
85 }
86
87 pub async fn execute(
88 &self,
89 signal: TradeSignal,
90 ctx: &mut PipelineContext,
91 ) -> Result<TradeSignal, PipelineError> {
92 let mut s = signal;
93 for stage in &self.stages {
94 s = stage.process(s, ctx).await?;
95 }
96 Ok(s)
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::signal::{Side, SignalAction, TradeSignal};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Stage that returns the signal untouched.
    struct PassthroughStage;

    #[async_trait]
    impl PipelineStage for PassthroughStage {
        async fn process(
            &self,
            signal: TradeSignal,
            _ctx: &mut PipelineContext,
        ) -> Result<TradeSignal, PipelineError> {
            Ok(signal)
        }
    }

    /// Stage that always rejects the signal with `RiskBlocked`.
    struct BlockingStage;

    #[async_trait]
    impl PipelineStage for BlockingStage {
        async fn process(
            &self,
            _signal: TradeSignal,
            _ctx: &mut PipelineContext,
        ) -> Result<TradeSignal, PipelineError> {
            Err(PipelineError::RiskBlocked("test block".into()))
        }
    }

    /// Pass-through stage that counts its invocations, so tests can assert which
    /// stages actually ran.
    struct CountingStage(Arc<AtomicUsize>);

    #[async_trait]
    impl PipelineStage for CountingStage {
        async fn process(
            &self,
            signal: TradeSignal,
            _ctx: &mut PipelineContext,
        ) -> Result<TradeSignal, PipelineError> {
            self.0.fetch_add(1, Ordering::SeqCst);
            Ok(signal)
        }
    }

    fn test_ctx() -> PipelineContext {
        PipelineContext {
            mode: TradingMode::DryRun,
            account_state: AccountState::default(),
            position_manager: Arc::new(PositionManager::in_memory().unwrap()),
            config: Arc::new(AppConfig::default()),
            execution_results: vec![],
        }
    }

    fn test_signal() -> TradeSignal {
        TradeSignal::manual(
            "BTC-PERP".into(),
            SignalAction::Open {
                side: Side::Buy,
                size: 0.01,
                price: Some(65000.0),
            },
            "test".into(),
        )
    }

    #[tokio::test]
    async fn pipeline_passthrough() {
        let calls = Arc::new(AtomicUsize::new(0));
        let pipeline = OrderPipeline::new(vec![
            Box::new(PassthroughStage),
            Box::new(CountingStage(Arc::clone(&calls))),
            Box::new(CountingStage(Arc::clone(&calls))),
        ]);
        let mut ctx = test_ctx();
        let sig = test_signal();
        let id = sig.id.clone();
        let result = pipeline
            .execute(sig, &mut ctx)
            .await
            .expect("pass-through stages must not fail");
        // The same signal comes out the other end, and every stage ran once.
        assert_eq!(result.id, id);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn pipeline_blocks_on_error() {
        let after_block = Arc::new(AtomicUsize::new(0));
        let pipeline = OrderPipeline::new(vec![
            Box::new(PassthroughStage),
            Box::new(BlockingStage),
            Box::new(CountingStage(Arc::clone(&after_block))),
        ]);
        let mut ctx = test_ctx();
        let result = pipeline.execute(test_signal(), &mut ctx).await;
        assert!(matches!(result, Err(PipelineError::RiskBlocked(_))));
        // The pipeline must short-circuit: stages after the failing one never run.
        assert_eq!(after_block.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn pipeline_error_exit_codes() {
        assert_eq!(PipelineError::RiskBlocked("x".into()).exit_code(), 4);
        assert_eq!(PipelineError::CircuitBreakerTripped.exit_code(), 4);
        assert_eq!(PipelineError::AuthError("x".into()).exit_code(), 3);
        assert_eq!(PipelineError::ExecutionFailed("x".into()).exit_code(), 5);
        assert_eq!(PipelineError::ConnectionError("x".into()).exit_code(), 5);
    }

    #[test]
    fn pipeline_error_display() {
        assert_eq!(
            PipelineError::RiskBlocked("x".into()).to_string(),
            "Risk blocked: x"
        );
        assert_eq!(
            PipelineError::CircuitBreakerTripped.to_string(),
            "Circuit breaker tripped"
        );
        assert_eq!(
            PipelineError::ExecutionFailed("x".into()).to_string(),
            "Execution failed: x"
        );
        assert_eq!(
            PipelineError::AuthError("x".into()).to_string(),
            "Auth error: x"
        );
        assert_eq!(
            PipelineError::ConnectionError("x".into()).to_string(),
            "Connection error: x"
        );
    }

    #[test]
    fn trading_mode_display() {
        assert_eq!(TradingMode::Paper.to_string(), "paper");
        assert_eq!(TradingMode::Live.to_string(), "live");
        assert_eq!(TradingMode::DryRun.to_string(), "dry-run");
    }

    #[tokio::test]
    async fn empty_pipeline_passes_signal_through() {
        let pipeline = OrderPipeline::new(vec![]);
        let mut ctx = test_ctx();
        let sig = test_signal();
        let id = sig.id.clone();
        let result = pipeline.execute(sig, &mut ctx).await.unwrap();
        assert_eq!(result.id, id);
    }
}