Skip to main content

hyper_agent_core/
pipeline.rs

1use async_trait::async_trait;
2use serde::Serialize;
3use std::sync::Arc;
4
5use crate::config::AppConfig;
6use crate::executor::OrderResult;
7use crate::position_manager::PositionManager;
8use crate::signal::TradeSignal;
9use hyper_risk::risk::AccountState;
10
/// Mode the order pipeline runs in.
///
/// All variants are data-less, so `Eq` and `Hash` are derived alongside
/// `PartialEq`; this lets the mode be used as a `HashMap`/`HashSet` key.
/// Its `Display` form is a lowercase label (`paper`, `live`, `dry-run`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TradingMode {
    /// Paper trading; displayed as `paper`.
    Paper,
    /// Live trading; displayed as `live`.
    Live,
    /// Dry run; displayed as `dry-run`.
    DryRun,
}
17
18impl std::fmt::Display for TradingMode {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        match self {
21            TradingMode::Paper => write!(f, "paper"),
22            TradingMode::Live => write!(f, "live"),
23            TradingMode::DryRun => write!(f, "dry-run"),
24        }
25    }
26}
27
/// State shared by every [`PipelineStage`] during one [`OrderPipeline`] run.
///
/// [`OrderPipeline::execute`] hands the same `&mut PipelineContext` to each
/// stage in turn, so changes a stage makes are visible to the stages after it.
pub struct PipelineContext {
    /// Mode the pipeline is running in (paper, live, or dry-run).
    pub mode: TradingMode,
    /// Account state from `hyper_risk`.
    /// NOTE(review): presumably read by risk-check stages — confirm.
    pub account_state: AccountState,
    /// Position manager, shared via `Arc`.
    pub position_manager: Arc<PositionManager>,
    /// Application configuration, shared via `Arc`.
    pub config: Arc<AppConfig>,
    /// Order results collected during the run.
    /// NOTE(review): nothing in this file writes to it; presumably an
    /// execution stage pushes results here — confirm.
    pub execution_results: Vec<OrderResult>,
}
35
/// Reasons an [`OrderPipeline`] run stops before finishing.
///
/// Each variant maps to a process exit code via [`PipelineError::exit_code`].
/// Keep the variant order stable: `Serialize` is derived, and serde formats
/// that encode enums by variant index would change if variants were reordered.
#[derive(Debug, Serialize)]
pub enum PipelineError {
    /// Blocked by risk checks; displayed as `Risk blocked: <msg>`. Exit code 4.
    RiskBlocked(String),
    /// Circuit breaker tripped; no payload. Exit code 4.
    CircuitBreakerTripped,
    /// Execution failed; displayed as `Execution failed: <msg>`. Exit code 5.
    ExecutionFailed(String),
    /// Authentication failure; displayed as `Auth error: <msg>`. Exit code 3.
    AuthError(String),
    /// Connection failure; displayed as `Connection error: <msg>`. Exit code 5.
    ConnectionError(String),
}
44
45impl std::fmt::Display for PipelineError {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match self {
48            Self::RiskBlocked(msg) => write!(f, "Risk blocked: {}", msg),
49            Self::CircuitBreakerTripped => write!(f, "Circuit breaker tripped"),
50            Self::ExecutionFailed(msg) => write!(f, "Execution failed: {}", msg),
51            Self::AuthError(msg) => write!(f, "Auth error: {}", msg),
52            Self::ConnectionError(msg) => write!(f, "Connection error: {}", msg),
53        }
54    }
55}
56
57impl std::error::Error for PipelineError {}
58
59impl PipelineError {
60    pub fn exit_code(&self) -> i32 {
61        match self {
62            Self::RiskBlocked(_) | Self::CircuitBreakerTripped => 4,
63            Self::AuthError(_) => 3,
64            Self::ExecutionFailed(_) | Self::ConnectionError(_) => 5,
65        }
66    }
67}
68
/// A single step of an [`OrderPipeline`].
///
/// A stage takes ownership of the incoming [`TradeSignal`] and returns one of two things:
///
/// * a signal (unchanged, modified, or replaced) for the next stage to receive;
/// * a [`PipelineError`], which makes [`OrderPipeline::execute`] return that
///   error immediately without running later stages.
///
/// All stages in one run share the same `&mut PipelineContext`. The
/// `Send + Sync` supertraits allow boxed stages (`Box<dyn PipelineStage>`)
/// to be sent to or shared between threads.
// `#[async_trait]` rewrites the `async fn` below into a method returning a
// boxed future, which keeps the trait usable as a `dyn` object.
#[async_trait]
pub trait PipelineStage: Send + Sync {
    /// Processes `signal`, optionally reading or updating `ctx`.
    ///
    /// # Errors
    ///
    /// Returns a [`PipelineError`] to halt the pipeline at this stage.
    async fn process(
        &self,
        signal: TradeSignal,
        ctx: &mut PipelineContext,
    ) -> Result<TradeSignal, PipelineError>;
}
77
78pub struct OrderPipeline {
79    stages: Vec<Box<dyn PipelineStage>>,
80}
81
82impl OrderPipeline {
83    pub fn new(stages: Vec<Box<dyn PipelineStage>>) -> Self {
84        Self { stages }
85    }
86
87    pub async fn execute(
88        &self,
89        signal: TradeSignal,
90        ctx: &mut PipelineContext,
91    ) -> Result<TradeSignal, PipelineError> {
92        let mut s = signal;
93        for stage in &self.stages {
94            s = stage.process(s, ctx).await?;
95        }
96        Ok(s)
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::signal::{Side, SignalAction, TradeSignal};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Stage that returns its input unchanged.
    struct PassthroughStage;

    #[async_trait]
    impl PipelineStage for PassthroughStage {
        async fn process(
            &self,
            signal: TradeSignal,
            _ctx: &mut PipelineContext,
        ) -> Result<TradeSignal, PipelineError> {
            Ok(signal)
        }
    }

    /// Stage that always rejects the signal with `RiskBlocked("test block")`.
    struct BlockingStage;

    #[async_trait]
    impl PipelineStage for BlockingStage {
        async fn process(
            &self,
            _signal: TradeSignal,
            _ctx: &mut PipelineContext,
        ) -> Result<TradeSignal, PipelineError> {
            Err(PipelineError::RiskBlocked("test block".into()))
        }
    }

    /// Passthrough stage that counts its invocations, so tests can check which
    /// stages actually ran.
    struct CountingStage(Arc<AtomicUsize>);

    #[async_trait]
    impl PipelineStage for CountingStage {
        async fn process(
            &self,
            signal: TradeSignal,
            _ctx: &mut PipelineContext,
        ) -> Result<TradeSignal, PipelineError> {
            self.0.fetch_add(1, Ordering::SeqCst);
            Ok(signal)
        }
    }

    fn test_ctx() -> PipelineContext {
        PipelineContext {
            mode: TradingMode::DryRun,
            account_state: AccountState::default(),
            position_manager: Arc::new(PositionManager::in_memory().unwrap()),
            config: Arc::new(AppConfig::default()),
            execution_results: vec![],
        }
    }

    fn test_signal() -> TradeSignal {
        TradeSignal::manual(
            "BTC-PERP".into(),
            SignalAction::Open {
                side: Side::Buy,
                size: 0.01,
                price: Some(65000.0),
            },
            "test".into(),
        )
    }

    #[tokio::test]
    async fn pipeline_passthrough() {
        let pipeline =
            OrderPipeline::new(vec![Box::new(PassthroughStage), Box::new(PassthroughStage)]);
        let mut ctx = test_ctx();
        let sig = test_signal();
        let id = sig.id.clone();
        let result = pipeline
            .execute(sig, &mut ctx)
            .await
            .expect("passthrough stages must not fail");
        // The same signal must come out the other end.
        assert_eq!(result.id, id);
    }

    #[tokio::test]
    async fn pipeline_runs_every_stage_once_on_success() {
        let calls = Arc::new(AtomicUsize::new(0));
        let pipeline = OrderPipeline::new(vec![
            Box::new(CountingStage(Arc::clone(&calls))),
            Box::new(CountingStage(Arc::clone(&calls))),
            Box::new(CountingStage(Arc::clone(&calls))),
        ]);
        let mut ctx = test_ctx();
        let result = pipeline.execute(test_signal(), &mut ctx).await;
        assert!(result.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn pipeline_blocks_on_error() {
        let before = Arc::new(AtomicUsize::new(0));
        let after = Arc::new(AtomicUsize::new(0));
        let pipeline = OrderPipeline::new(vec![
            Box::new(CountingStage(Arc::clone(&before))),
            Box::new(BlockingStage),
            Box::new(CountingStage(Arc::clone(&after))),
        ]);
        let mut ctx = test_ctx();
        let result = pipeline.execute(test_signal(), &mut ctx).await;
        assert!(matches!(
            &result,
            Err(PipelineError::RiskBlocked(msg)) if msg == "test block"
        ));
        assert_eq!(
            before.load(Ordering::SeqCst),
            1,
            "the stage before the blocker should run exactly once"
        );
        assert_eq!(
            after.load(Ordering::SeqCst),
            0,
            "stages after a failing stage must not run"
        );
    }

    #[test]
    fn pipeline_error_exit_codes() {
        assert_eq!(PipelineError::RiskBlocked("x".into()).exit_code(), 4);
        assert_eq!(PipelineError::CircuitBreakerTripped.exit_code(), 4);
        assert_eq!(PipelineError::AuthError("x".into()).exit_code(), 3);
        assert_eq!(PipelineError::ExecutionFailed("x".into()).exit_code(), 5);
        assert_eq!(PipelineError::ConnectionError("x".into()).exit_code(), 5);
    }

    #[tokio::test]
    async fn empty_pipeline_passes_signal_through() {
        let pipeline = OrderPipeline::new(vec![]);
        let mut ctx = test_ctx();
        let sig = test_signal();
        let id = sig.id.clone();
        let result = pipeline.execute(sig, &mut ctx).await.unwrap();
        assert_eq!(result.id, id);
    }
}