nautilus_data/python/
option_chain_manager.rs1use std::collections::HashMap;
24
25use nautilus_core::{UnixNanos, python::to_pyvalue_err};
26use nautilus_model::{
27 data::{
28 QuoteTick,
29 option_chain::{OptionChainSlice, OptionGreeks},
30 },
31 enums::OptionKind,
32 identifiers::{InstrumentId, OptionSeriesId},
33 python::data::option_chain::PyStrikeRange,
34 types::Price,
35};
36use pyo3::prelude::*;
37
38use crate::option_chains::{AtmTracker, OptionChainAggregator};
39
40fn parse_option_kind(value: u8) -> PyResult<OptionKind> {
41 match value {
42 0 => Ok(OptionKind::Call),
43 1 => Ok(OptionKind::Put),
44 _ => Err(to_pyvalue_err(format!(
45 "invalid `OptionKind` value, expected 0 (Call) or 1 (Put), received {value}"
46 ))),
47 }
48}
49
/// Python-facing manager for a single option series.
///
/// Wraps an [`OptionChainAggregator`] and adds a deferred "bootstrap" step. When the
/// instrument catalog is non-empty but no instruments are active at construction time, the
/// active set is recomputed later, once the ATM tracker first reports a price (checked after
/// each `handle_quote` / `handle_greeks` call).
#[pyclass(
    name = "OptionChainManager",
    module = "nautilus_trader.core.nautilus_pyo3.data"
)]
#[pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.data")]
#[derive(Debug)]
pub struct PyOptionChainManager {
    /// Aggregator that receives quotes and greeks, owns the ATM tracker, and tracks the
    /// active and full instrument sets.
    aggregator: OptionChainAggregator,
    /// The option series this manager was constructed for.
    series_id: OptionSeriesId,
    /// `true` when the manager was constructed without a `snapshot_interval_ms`.
    raw_mode: bool,
    /// Whether the active instrument set has been established.
    bootstrapped: bool,
}
72
73#[pymethods]
74#[pyo3_stub_gen::derive::gen_stub_pymethods]
75impl PyOptionChainManager {
76 #[new]
78 #[pyo3(signature = (series_id, strike_range, instruments, snapshot_interval_ms=None, initial_atm_price=None))]
79 fn py_new(
80 series_id: OptionSeriesId,
81 strike_range: PyStrikeRange,
82 instruments: HashMap<InstrumentId, (Price, u8)>,
83 snapshot_interval_ms: Option<u64>,
84 initial_atm_price: Option<Price>,
85 ) -> PyResult<Self> {
86 let rust_instruments: HashMap<InstrumentId, (Price, OptionKind)> = instruments
87 .into_iter()
88 .map(|(id, (strike, kind_u8))| {
89 parse_option_kind(kind_u8).map(|kind| (id, (strike, kind)))
90 })
91 .collect::<PyResult<_>>()?;
92
93 let mut tracker = AtmTracker::new();
94
95 if let Some((strike, _)) = rust_instruments.values().next() {
97 tracker.set_forward_precision(strike.precision);
98 }
99
100 if let Some(price) = initial_atm_price {
101 tracker.set_initial_price(price);
102 }
103
104 let aggregator =
105 OptionChainAggregator::new(series_id, strike_range.inner, tracker, rust_instruments);
106
107 let active_ids = aggregator.instrument_ids();
108 let all_ids = aggregator.all_instrument_ids();
109 let bootstrapped = !active_ids.is_empty() || all_ids.is_empty();
110 let raw_mode = snapshot_interval_ms.is_none();
111
112 Ok(Self {
113 aggregator,
114 series_id,
115 raw_mode,
116 bootstrapped,
117 })
118 }
119
120 #[pyo3(name = "handle_quote")]
125 fn py_handle_quote(&mut self, quote: &Bound<'_, PyAny>) -> PyResult<bool> {
126 let tick = quote
127 .extract::<QuoteTick>()
128 .or_else(|_| QuoteTick::from_pyobject(quote))?;
129 self.aggregator.update_quote(&tick);
130
131 if !self.bootstrapped && self.aggregator.atm_tracker().atm_price().is_some() {
132 self.aggregator.recompute_active_set();
133 self.bootstrapped = true;
134 return Ok(true);
135 }
136 Ok(false)
137 }
138
139 #[pyo3(name = "handle_greeks")]
144 fn py_handle_greeks(&mut self, greeks_obj: &Bound<'_, PyAny>) -> PyResult<bool> {
145 let greeks = greeks_obj
146 .extract::<OptionGreeks>()
147 .or_else(|_| OptionGreeks::from_pyobject(greeks_obj))?;
148
149 self.aggregator
151 .atm_tracker_mut()
152 .update_from_option_greeks(&greeks);
153
154 self.aggregator.update_greeks(&greeks);
156
157 if !self.bootstrapped && self.aggregator.atm_tracker().atm_price().is_some() {
158 self.aggregator.recompute_active_set();
159 self.bootstrapped = true;
160 return Ok(true);
161 }
162 Ok(false)
163 }
164
165 #[pyo3(name = "snapshot")]
169 fn py_snapshot(&self, ts_ns: u64) -> Option<OptionChainSlice> {
170 if self.aggregator.is_buffer_empty() {
171 return None;
172 }
173 Some(self.aggregator.snapshot(UnixNanos::from(ts_ns)))
174 }
175
176 #[pyo3(name = "check_rebalance")]
182 fn py_check_rebalance(&mut self, ts_ns: u64) -> Option<(Vec<InstrumentId>, Vec<InstrumentId>)> {
183 let now = UnixNanos::from(ts_ns);
184 let action = self.aggregator.check_rebalance(now)?;
185 let add = action.add.clone();
186 let remove = action.remove.clone();
187 self.aggregator.apply_rebalance(&action, now);
188 Some((add, remove))
189 }
190
191 #[pyo3(name = "active_instrument_ids")]
193 fn py_active_instrument_ids(&self) -> Vec<InstrumentId> {
194 self.aggregator.instrument_ids()
195 }
196
197 #[pyo3(name = "all_instrument_ids")]
199 fn py_all_instrument_ids(&self) -> Vec<InstrumentId> {
200 self.aggregator.all_instrument_ids()
201 }
202
203 #[pyo3(name = "add_instrument")]
207 fn py_add_instrument(
208 &mut self,
209 instrument_id: InstrumentId,
210 strike: Price,
211 kind: u8,
212 ) -> PyResult<bool> {
213 let option_kind = parse_option_kind(kind)?;
214 Ok(self
215 .aggregator
216 .add_instrument(instrument_id, strike, option_kind))
217 }
218
219 #[pyo3(name = "remove_instrument")]
223 fn py_remove_instrument(&mut self, instrument_id: InstrumentId) -> bool {
224 let _ = self.aggregator.remove_instrument(&instrument_id);
225 self.aggregator.is_catalog_empty()
226 }
227
228 #[getter]
229 #[pyo3(name = "series_id")]
230 fn py_series_id(&self) -> OptionSeriesId {
231 self.series_id
232 }
233
234 #[getter]
235 #[pyo3(name = "bootstrapped")]
236 fn py_bootstrapped(&self) -> bool {
237 self.bootstrapped
238 }
239
240 #[getter]
241 #[pyo3(name = "raw_mode")]
242 fn py_raw_mode(&self) -> bool {
243 self.raw_mode
244 }
245
246 #[getter]
247 #[pyo3(name = "atm_price")]
248 fn py_atm_price(&self) -> Option<Price> {
249 self.aggregator.atm_tracker().atm_price()
250 }
251
252 fn __repr__(&self) -> String {
253 format!(
254 "OptionChainManager(series_id={}, bootstrapped={}, raw_mode={}, \
255 active={}/{})",
256 self.series_id,
257 self.bootstrapped,
258 self.raw_mode,
259 self.aggregator.instrument_ids().len(),
260 self.aggregator.all_instrument_ids().len(),
261 )
262 }
263}