Skip to main content

rustrade/
services.rs

//! Optional framework-side services wired in via builder methods on
//! [`Bot`](crate::Bot):
//!
//! - [`MarketFeedService`] — `Bot::with_market_source(...)`. Drives a
//!   [`MarketSource`] under supervisor control; the source publishes
//!   events to the in-process `MarketDataBus` (obtaining the bus reference
//!   is the source implementor's responsibility — typically via
//!   `bot.market_data_bus().clone()` before construction).
//! - [`FillRoutingService`] — `Bot::with_fill_source(...)`. Consumes fills
//!   from a [`FillSource`], calls [`Brain::on_fill`] on each brain,
//!   refreshes the per-symbol position cache from the exchange, and
//!   auto-feeds realised PnL into the risk gates using weighted-average
//!   entry accounting.
//! - [`CandlePollerService`] — `Bot::with_candle_poller(...)`. Periodically
//!   polls a [`CandleSource`] and publishes to the market-data bus every
//!   candle newer than the last one it published for its
//!   `(symbol, interval)` pair.
17
18use std::sync::Arc;
19use std::sync::atomic::{AtomicU64, Ordering};
20use std::time::Duration;
21
22use async_trait::async_trait;
23use rustrade_core::{
24    Brain, CandleSource, Exchange, ExchangeClient, Fill, FillSource, MarketDataBus,
25    MarketDataEvent, MarketSource, MetricsSink, Side, Symbol,
26};
27use rustrade_supervisor::{RestartPolicy, TradingService};
28use tokio_util::sync::CancellationToken;
29
30use crate::risk_state::{PositionCache, RiskPersister, RiskStateMap};
31
32// ───────────────────────────────────────────────────────────────────────
33// MarketFeedService
34// ───────────────────────────────────────────────────────────────────────
35
36/// Drives a [`MarketSource`] under supervisor control.
37///
38/// The wrapper does not interact with the bus directly — the source's
39/// `run` method is expected to publish events to whatever bus it was
40/// constructed with. This service just makes the source restartable and
41/// drop-safe under the supervisor's cancellation contract.
42pub struct MarketFeedService {
43    name: String,
44    source: Arc<dyn MarketSource>,
45}
46
47impl MarketFeedService {
48    /// Wrap a [`MarketSource`] into a [`TradingService`].
49    pub fn new(source: Arc<dyn MarketSource>) -> Self {
50        let name = format!("market-feed[{}]", source.name());
51        Self { name, source }
52    }
53}
54
55#[async_trait]
56impl TradingService for MarketFeedService {
57    fn name(&self) -> &str {
58        &self.name
59    }
60
61    fn restart_policy(&self) -> RestartPolicy {
62        RestartPolicy::OnFailure
63    }
64
65    async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
66        tracing::info!(service = %self.name, "market feed starting");
67        tokio::select! {
68            _ = cancel.cancelled() => {
69                tracing::info!(service = %self.name, "market feed cancelled");
70                Ok(())
71            }
72            r = self.source.run() => {
73                match &r {
74                    Ok(()) => tracing::info!(service = %self.name, "market feed exited cleanly"),
75                    Err(e) => tracing::warn!(service = %self.name, error = %e, "market feed exited with error"),
76                }
77                r.map_err(|e| anyhow::anyhow!("market source error: {e}"))
78            }
79        }
80    }
81}
82
83// ───────────────────────────────────────────────────────────────────────
84// FillRoutingService
85// ───────────────────────────────────────────────────────────────────────
86
87/// Routes fills from a [`FillSource`] to every brain, refreshes the
88/// position cache, and auto-feeds realised PnL into the risk state.
89///
90/// # PnL accounting
91///
92/// The service uses a **weighted-average entry** model (the same model
93/// the backtest engine uses). It reads the cached `Position` *before*
94/// refreshing it from the exchange, so the `entry_price` available is
95/// the pre-fill average. From that:
96///
97/// - A fill in the same direction as the open position **adds** to it.
98///   No realised PnL emitted; the post-refresh average from
99///   `exchange.get_position` becomes the new entry.
100/// - A fill in the opposite direction **reduces** the position. Gross
101///   PnL = `(fill_price - entry) * closed_qty * direction`. The
102///   service calls `BotHandle::record_trade_outcome` on the closed
103///   portion to feed `SessionPnl` + `CircuitBreaker`.
104/// - A fill that **flips** the position emits realised PnL for the
105///   closed portion only; the opening leg is left for the next
106///   reducing fill.
107///
108/// Fees come from `Fill.fee`. Hosts that need a different accounting
109/// model (FIFO, LIFO, tax-lot) should compute PnL themselves and call
110/// `BotHandle::record_trade_outcome` directly — but cannot also wire a
111/// `FillRoutingService`, since the two would double-count.
112pub struct FillRoutingService {
113    source: Arc<dyn FillSource>,
114    brains: Arc<Vec<Arc<dyn Brain>>>,
115    exchange: Arc<dyn ExchangeClient>,
116    positions: PositionCache,
117    risk: RiskStateMap,
118    metrics: Arc<dyn MetricsSink>,
119    persister: Option<RiskPersister>,
120    oco: Option<crate::order_tracker::OcoRegistry>,
121    fills_routed: AtomicU64,
122    refresh_errors: AtomicU64,
123    trades_recorded: AtomicU64,
124    oco_cancels: AtomicU64,
125}
126
127impl FillRoutingService {
128    #[allow(clippy::too_many_arguments)]
129    pub(crate) fn new(
130        source: Arc<dyn FillSource>,
131        brains: Arc<Vec<Arc<dyn Brain>>>,
132        exchange: Arc<dyn ExchangeClient>,
133        positions: PositionCache,
134        risk: RiskStateMap,
135        metrics: Arc<dyn MetricsSink>,
136        persister: Option<RiskPersister>,
137        oco: Option<crate::order_tracker::OcoRegistry>,
138    ) -> Self {
139        Self {
140            source,
141            brains,
142            exchange,
143            positions,
144            risk,
145            metrics,
146            persister,
147            oco,
148            fills_routed: AtomicU64::new(0),
149            refresh_errors: AtomicU64::new(0),
150            trades_recorded: AtomicU64::new(0),
151            oco_cancels: AtomicU64::new(0),
152        }
153    }
154
155    /// Total OCO siblings cancelled in response to a bracket leg filling.
156    pub fn oco_cancels(&self) -> u64 {
157        self.oco_cancels.load(Ordering::Relaxed)
158    }
159
160    /// Total fills delivered to brains since service start.
161    pub fn fills_routed(&self) -> u64 {
162        self.fills_routed.load(Ordering::Relaxed)
163    }
164
165    /// Total `exchange.get_position` failures during cache refresh.
166    pub fn refresh_errors(&self) -> u64 {
167        self.refresh_errors.load(Ordering::Relaxed)
168    }
169
170    /// Total realised-PnL closures fed into the risk state.
171    pub fn trades_recorded(&self) -> u64 {
172        self.trades_recorded.load(Ordering::Relaxed)
173    }
174
175    /// Compute realised PnL from a reducing fill and feed the risk state.
176    /// Returns the gross PnL portion attributable to this fill.
177    async fn maybe_record_pnl(&self, fill: &Fill, prior_qty: f64, prior_entry: Option<f64>) {
178        // Only reducing or flipping fills produce realised PnL.
179        let signed_fill_qty = match fill.side {
180            Side::Buy => fill.size.value(),
181            Side::Sell => -fill.size.value(),
182        };
183        if prior_qty == 0.0 || prior_qty.signum() == signed_fill_qty.signum() {
184            return;
185        }
186        let Some(entry) = prior_entry else {
187            // Reducing fill but no entry price recorded — can't compute
188            // PnL. Log and skip.
189            tracing::debug!(
190                symbol = %fill.symbol,
191                "reducing fill but cached position has no entry price; skipping auto-PnL"
192            );
193            return;
194        };
195        let closed_qty = prior_qty.abs().min(fill.size.value());
196        if closed_qty <= 0.0 {
197            return;
198        }
199        let direction = prior_qty.signum();
200        let gross = (fill.price.value() - entry) * direction * closed_qty;
201        // Apportion fee by closing fraction so a flip fill charges
202        // fees pro-rata to the closing portion.
203        let fee_share = if fill.size.value() > 0.0 {
204            fill.fee * (closed_qty / fill.size.value())
205        } else {
206            0.0
207        };
208
209        // Update the per-symbol risk state directly.
210        let recorded = {
211            let mut map = self.risk.write().await;
212            if let Some(risk) = map.get_mut(&fill.symbol) {
213                risk.session_pnl.record_close(gross, fee_share);
214                let net = gross - fee_share;
215                if net > 0.0 {
216                    risk.circuit_breaker.record_win();
217                } else if net < 0.0 {
218                    risk.circuit_breaker.record_loss();
219                }
220                self.trades_recorded.fetch_add(1, Ordering::Relaxed);
221                self.metrics.histogram(
222                    "rustrade_realised_pnl_quote",
223                    &[("symbol", fill.symbol.as_str())],
224                    net,
225                );
226                true
227            } else {
228                tracing::debug!(
229                    symbol = %fill.symbol,
230                    "auto-PnL: symbol not in risk-state map (was it configured?)"
231                );
232                false
233            }
234        };
235
236        // Persist the updated risk state (lock released) if a store is wired.
237        if recorded && let Some(persister) = &self.persister {
238            persister.persist_symbol(&self.risk, &fill.symbol).await;
239        }
240    }
241}
242
243#[async_trait]
244impl TradingService for FillRoutingService {
245    fn name(&self) -> &str {
246        "fill-routing"
247    }
248
249    fn restart_policy(&self) -> RestartPolicy {
250        RestartPolicy::OnFailure
251    }
252
253    async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
254        tracing::info!("fill-routing service starting");
255        loop {
256            tokio::select! {
257                _ = cancel.cancelled() => {
258                    tracing::info!(
259                        routed = self.fills_routed(),
260                        refresh_errors = self.refresh_errors(),
261                        trades_recorded = self.trades_recorded(),
262                        "fill-routing service shutting down"
263                    );
264                    return Ok(());
265                }
266                next = self.source.next_fill() => {
267                    let Some(fill) = next else {
268                        tracing::info!("fill source closed; exiting");
269                        return Ok(());
270                    };
271
272                    let symbol = fill.symbol.clone();
273
274                    // OCO: if this fill belongs to a bracket leg, cancel its
275                    // sibling so the position isn't closed twice.
276                    if let Some(oco) = &self.oco
277                        && let Some((sym, sibling)) = oco.take_sibling(&fill.order_id).await
278                    {
279                        match self.exchange.cancel_order(&sym, &sibling).await {
280                            Ok(_) => {
281                                self.oco_cancels.fetch_add(1, Ordering::Relaxed);
282                                self.metrics.inc("rustrade_oco_cancels_total");
283                                tracing::info!(symbol = %sym, filled = %fill.order_id, cancelled = %sibling, "OCO: cancelled sibling after bracket leg filled");
284                            }
285                            Err(e) => tracing::warn!(symbol = %sym, sibling = %sibling, error = %e, "OCO: failed to cancel sibling (it may already be gone)"),
286                        }
287                    }
288
289                    // Snapshot the pre-fill position so we can compute
290                    // realised PnL before the exchange refreshes the
291                    // entry price.
292                    let (prior_qty, prior_entry) = {
293                        let map = self.positions.read().await;
294                        let p = map.get(&symbol).copied().unwrap_or(rustrade_core::Position::FLAT);
295                        (p.qty, p.entry_price)
296                    };
297
298                    // Route to every brain. Errors are logged but don't
299                    // stop the service — the brain's on_fill is
300                    // informational by contract.
301                    for brain in self.brains.iter() {
302                        if let Err(e) = brain.on_fill(&fill).await {
303                            tracing::warn!(
304                                brain = brain.name(),
305                                error = %e,
306                                "brain on_fill returned error"
307                            );
308                        }
309                    }
310
311                    self.maybe_record_pnl(&fill, prior_qty, prior_entry).await;
312
313                    // Refresh position cache from the exchange.
314                    match self.exchange.get_position(&symbol).await {
315                        Ok(p) => {
316                            self.positions.write().await.insert(symbol.clone(), p);
317                            tracing::debug!(symbol = %symbol, qty = p.qty, "refreshed position");
318                        }
319                        Err(e) => {
320                            self.refresh_errors.fetch_add(1, Ordering::Relaxed);
321                            self.metrics.inc("rustrade_position_refresh_errors_total");
322                            tracing::warn!(
323                                symbol = %symbol,
324                                error = %e,
325                                "failed to refresh position after fill"
326                            );
327                        }
328                    }
329
330                    self.fills_routed.fetch_add(1, Ordering::Relaxed);
331                    self.metrics.counter(
332                        "rustrade_fills_routed_total",
333                        &[("symbol", symbol.as_str())],
334                        1,
335                    );
336                }
337            }
338        }
339    }
340}
341
342// ───────────────────────────────────────────────────────────────────────
343// CandlePollerService
344// ───────────────────────────────────────────────────────────────────────
345
346/// Periodic poll of a [`CandleSource`] for a single `(symbol, interval)`
347/// pair. Publishes each newly-closed candle to the
348/// [`MarketDataBus`].
349///
350/// Per-symbol cadences are achieved by spawning multiple services —
351/// `Bot::with_candle_poller(...)` accepts repeated calls and spawns one
352/// service per registered tuple.
353///
354/// # Deduplication
355///
356/// The service tracks the highest `Candle::time` it has already
357/// published; only candles with a strictly greater timestamp are
358/// re-published. This is robust against exchanges that return overlapping
359/// windows on consecutive polls.
360pub struct CandlePollerService {
361    name: String,
362    source: Arc<dyn CandleSource>,
363    symbol: Symbol,
364    interval: Duration,
365    poll_cadence: Duration,
366    limit: usize,
367    bus: MarketDataBus,
368    metrics: Arc<dyn MetricsSink>,
369    last_time: std::sync::Mutex<i64>,
370    polled: AtomicU64,
371    poll_errors: AtomicU64,
372    published: AtomicU64,
373}
374
375impl CandlePollerService {
376    pub(crate) fn new(
377        source: Arc<dyn CandleSource>,
378        symbol: Symbol,
379        interval: Duration,
380        poll_cadence: Duration,
381        limit: usize,
382        bus: MarketDataBus,
383        metrics: Arc<dyn MetricsSink>,
384    ) -> Self {
385        let name = format!("candle-poller[{}@{}s]", symbol.as_str(), interval.as_secs());
386        Self {
387            name,
388            source,
389            symbol,
390            interval,
391            poll_cadence,
392            limit,
393            bus,
394            metrics,
395            last_time: std::sync::Mutex::new(i64::MIN),
396            polled: AtomicU64::new(0),
397            poll_errors: AtomicU64::new(0),
398            published: AtomicU64::new(0),
399        }
400    }
401
402    /// Total successful polls.
403    pub fn polled(&self) -> u64 {
404        self.polled.load(Ordering::Relaxed)
405    }
406    /// Total failed polls.
407    pub fn poll_errors(&self) -> u64 {
408        self.poll_errors.load(Ordering::Relaxed)
409    }
410    /// Total candles published (deduplicated).
411    pub fn published(&self) -> u64 {
412        self.published.load(Ordering::Relaxed)
413    }
414}
415
416#[async_trait]
417impl TradingService for CandlePollerService {
418    fn name(&self) -> &str {
419        &self.name
420    }
421
422    fn restart_policy(&self) -> RestartPolicy {
423        RestartPolicy::OnFailure
424    }
425
426    async fn run(&self, cancel: CancellationToken) -> anyhow::Result<()> {
427        tracing::info!(service = %self.name, "candle poller starting");
428        let exchange = Exchange::from(self.source.name());
429
430        loop {
431            tokio::select! {
432                _ = cancel.cancelled() => {
433                    tracing::info!(
434                        service = %self.name,
435                        polled = self.polled(),
436                        published = self.published(),
437                        errors = self.poll_errors(),
438                        "candle poller shutting down"
439                    );
440                    return Ok(());
441                }
442                _ = tokio::time::sleep(self.poll_cadence) => {
443                    match self.source.poll(&self.symbol, self.interval, self.limit).await {
444                        Ok(candles) => {
445                            self.polled.fetch_add(1, Ordering::Relaxed);
446                            let mut last = self.last_time.lock().expect("last_time poisoned");
447                            let mut new_high = *last;
448                            for candle in candles {
449                                if candle.time <= *last {
450                                    continue;
451                                }
452                                new_high = new_high.max(candle.time);
453                                self.bus.publish(MarketDataEvent::Candle {
454                                    exchange: exchange.clone(),
455                                    symbol: self.symbol.clone(),
456                                    candle,
457                                });
458                                self.published.fetch_add(1, Ordering::Relaxed);
459                                self.metrics.counter(
460                                    "rustrade_candles_published_total",
461                                    &[("symbol", self.symbol.as_str())],
462                                    1,
463                                );
464                            }
465                            *last = new_high;
466                        }
467                        Err(e) => {
468                            self.poll_errors.fetch_add(1, Ordering::Relaxed);
469                            self.metrics.inc("rustrade_candle_poll_errors_total");
470                            tracing::warn!(
471                                service = %self.name,
472                                error = %e,
473                                "candle poll failed"
474                            );
475                        }
476                    }
477                }
478            }
479        }
480    }
481}