use std::collections::HashMap;
use nautilus_core::{UnixNanos, python::to_pyvalue_err};
use nautilus_model::{
data::{
QuoteTick,
option_chain::{OptionChainSlice, OptionGreeks},
},
enums::OptionKind,
identifiers::{InstrumentId, OptionSeriesId},
python::data::option_chain::PyStrikeRange,
types::Price,
};
use pyo3::prelude::*;
use crate::option_chains::{AtmTracker, OptionChainAggregator};
/// Converts the integer option-kind code received from Python into an [`OptionKind`].
///
/// The accepted encoding is local to this binding: `0` selects `OptionKind::Call` and
/// `1` selects `OptionKind::Put`. NOTE(review): this is not necessarily the enum's own
/// discriminant values — confirm the Python caller encodes kinds the same way.
///
/// # Errors
///
/// Returns an error built with `to_pyvalue_err` when `value` is any other number.
fn parse_option_kind(value: u8) -> PyResult<OptionKind> {
    let kind = match value {
        0 => OptionKind::Call,
        1 => OptionKind::Put,
        _ => {
            return Err(to_pyvalue_err(format!(
                "invalid `OptionKind` value, expected 0 (Call) or 1 (Put), received {value}"
            )));
        }
    };
    Ok(kind)
}
/// Python-facing wrapper around an [`OptionChainAggregator`] for a single option series.
///
/// Besides the aggregator itself, it records whether the active instrument set has been
/// bootstrapped and whether it was constructed without a snapshot interval (`raw_mode`).
#[pyclass(
    name = "OptionChainManager",
    module = "nautilus_trader.core.nautilus_pyo3.data"
)]
#[pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.data")]
#[derive(Debug)]
pub struct PyOptionChainManager {
    /// Underlying aggregator: owns the ATM tracker and the instrument catalog, tracks the
    /// active instrument subset, and buffers quote/greeks updates used for snapshots.
    aggregator: OptionChainAggregator,
    /// Identifier of the option series this manager serves.
    series_id: OptionSeriesId,
    /// `true` when the manager was created with `snapshot_interval_ms=None`.
    /// NOTE(review): presumably means data is emitted per update rather than on a
    /// timer — confirm with the Python caller.
    raw_mode: bool,
    /// Whether the active instrument set has been established. Set at construction when
    /// the aggregator already has active instruments (or the catalog is empty), and
    /// otherwise flipped once the ATM tracker first reports a price.
    bootstrapped: bool,
}
#[pymethods]
#[pyo3_stub_gen::derive::gen_stub_pymethods]
impl PyOptionChainManager {
    /// Creates a manager for `series_id` over the given instrument catalog.
    ///
    /// `instruments` maps each instrument ID to its `(strike, kind)` pair, where `kind` is
    /// the integer code accepted by `parse_option_kind` (`0` = Call, `1` = Put). When
    /// `snapshot_interval_ms` is `None` the manager runs in raw mode. When
    /// `initial_atm_price` is given it is set on the ATM tracker before the aggregator is
    /// built.
    ///
    /// # Errors
    ///
    /// Returns an error if any instrument carries an invalid option-kind code.
    #[new]
    #[pyo3(signature = (series_id, strike_range, instruments, snapshot_interval_ms=None, initial_atm_price=None))]
    fn py_new(
        series_id: OptionSeriesId,
        strike_range: PyStrikeRange,
        instruments: HashMap<InstrumentId, (Price, u8)>,
        snapshot_interval_ms: Option<u64>,
        initial_atm_price: Option<Price>,
    ) -> PyResult<Self> {
        // Convert the Python-side kind codes, failing on the first invalid one.
        let rust_instruments: HashMap<InstrumentId, (Price, OptionKind)> = instruments
            .into_iter()
            .map(|(id, (strike, kind_u8))| {
                parse_option_kind(kind_u8).map(|kind| (id, (strike, kind)))
            })
            .collect::<PyResult<_>>()?;

        let mut tracker = AtmTracker::new();
        // The forward precision comes from whichever strike the map yields first
        // (HashMap iteration order is unspecified). NOTE(review): this assumes every
        // strike in the series shares one precision — confirm.
        if let Some((strike, _)) = rust_instruments.values().next() {
            tracker.set_forward_precision(strike.precision);
        }
        if let Some(price) = initial_atm_price {
            tracker.set_initial_price(price);
        }

        let aggregator =
            OptionChainAggregator::new(series_id, strike_range.inner, tracker, rust_instruments);

        // Already bootstrapped if the aggregator produced an active set on construction,
        // or if the catalog is empty and there is nothing to bootstrap.
        let active_ids = aggregator.instrument_ids();
        let all_ids = aggregator.all_instrument_ids();
        let bootstrapped = !active_ids.is_empty() || all_ids.is_empty();
        let raw_mode = snapshot_interval_ms.is_none();

        Ok(Self {
            aggregator,
            series_id,
            raw_mode,
            bootstrapped,
        })
    }

    /// Feeds a quote into the aggregator, then bootstraps the active set if possible.
    ///
    /// Accepts a PyO3 `QuoteTick` directly, falling back to `QuoteTick::from_pyobject`
    /// for other quote objects (presumably the legacy Cython type — confirm).
    ///
    /// Returns `true` only on the call that bootstraps the active instrument set.
    ///
    /// # Errors
    ///
    /// Returns an error if `quote` cannot be converted into a `QuoteTick`.
    #[pyo3(name = "handle_quote")]
    fn py_handle_quote(&mut self, quote: &Bound<'_, PyAny>) -> PyResult<bool> {
        let tick = quote
            .extract::<QuoteTick>()
            .or_else(|_| QuoteTick::from_pyobject(quote))?;
        self.aggregator.update_quote(&tick);
        Ok(self.try_bootstrap())
    }

    /// Feeds option greeks into the ATM tracker and the aggregator, then bootstraps the
    /// active set if possible.
    ///
    /// Accepts a PyO3 `OptionGreeks` directly, falling back to
    /// `OptionGreeks::from_pyobject` for other objects.
    ///
    /// Returns `true` only on the call that bootstraps the active instrument set.
    ///
    /// # Errors
    ///
    /// Returns an error if `greeks_obj` cannot be converted into `OptionGreeks`.
    #[pyo3(name = "handle_greeks")]
    fn py_handle_greeks(&mut self, greeks_obj: &Bound<'_, PyAny>) -> PyResult<bool> {
        let greeks = greeks_obj
            .extract::<OptionGreeks>()
            .or_else(|_| OptionGreeks::from_pyobject(greeks_obj))?;
        // Update the ATM tracker first, so this same update can trigger bootstrap below.
        self.aggregator
            .atm_tracker_mut()
            .update_from_option_greeks(&greeks);
        self.aggregator.update_greeks(&greeks);
        Ok(self.try_bootstrap())
    }

    /// Builds a chain slice for timestamp `ts_ns` (UNIX nanoseconds).
    ///
    /// Returns `None` when the aggregator's buffer is empty.
    #[pyo3(name = "snapshot")]
    fn py_snapshot(&self, ts_ns: u64) -> Option<OptionChainSlice> {
        if self.aggregator.is_buffer_empty() {
            return None;
        }
        Some(self.aggregator.snapshot(UnixNanos::from(ts_ns)))
    }

    /// Checks at `ts_ns` whether the active instrument set should be rebalanced.
    ///
    /// If a rebalance is due, it is applied immediately and `(added, removed)` instrument
    /// IDs are returned. Otherwise returns `None`.
    #[pyo3(name = "check_rebalance")]
    fn py_check_rebalance(&mut self, ts_ns: u64) -> Option<(Vec<InstrumentId>, Vec<InstrumentId>)> {
        let now = UnixNanos::from(ts_ns);
        let action = self.aggregator.check_rebalance(now)?;
        // Copied out before applying, since `apply_rebalance` only borrows the action.
        let add = action.add.clone();
        let remove = action.remove.clone();
        self.aggregator.apply_rebalance(&action, now);
        Some((add, remove))
    }

    /// Returns the IDs of the currently active instruments.
    #[pyo3(name = "active_instrument_ids")]
    fn py_active_instrument_ids(&self) -> Vec<InstrumentId> {
        self.aggregator.instrument_ids()
    }

    /// Returns the IDs of every instrument in the catalog, whether active or not.
    #[pyo3(name = "all_instrument_ids")]
    fn py_all_instrument_ids(&self) -> Vec<InstrumentId> {
        self.aggregator.all_instrument_ids()
    }

    /// Adds an instrument with the given strike and kind code (`0` = Call, `1` = Put).
    ///
    /// Returns the aggregator's result for the addition.
    ///
    /// # Errors
    ///
    /// Returns an error if `kind` is not a valid option-kind code.
    #[pyo3(name = "add_instrument")]
    fn py_add_instrument(
        &mut self,
        instrument_id: InstrumentId,
        strike: Price,
        kind: u8,
    ) -> PyResult<bool> {
        let option_kind = parse_option_kind(kind)?;
        Ok(self
            .aggregator
            .add_instrument(instrument_id, strike, option_kind))
    }

    /// Removes an instrument from the aggregator.
    ///
    /// Returns whether the catalog is empty *after* the removal, not whether the
    /// instrument was present.
    #[pyo3(name = "remove_instrument")]
    fn py_remove_instrument(&mut self, instrument_id: InstrumentId) -> bool {
        // The removal's own result is intentionally discarded; callers only need to know
        // whether the catalog has been emptied.
        let _ = self.aggregator.remove_instrument(&instrument_id);
        self.aggregator.is_catalog_empty()
    }

    /// Identifier of the option series this manager serves.
    #[getter]
    #[pyo3(name = "series_id")]
    fn py_series_id(&self) -> OptionSeriesId {
        self.series_id
    }

    /// Whether the active instrument set has been established.
    #[getter]
    #[pyo3(name = "bootstrapped")]
    fn py_bootstrapped(&self) -> bool {
        self.bootstrapped
    }

    /// Whether the manager was created without a snapshot interval.
    #[getter]
    #[pyo3(name = "raw_mode")]
    fn py_raw_mode(&self) -> bool {
        self.raw_mode
    }

    /// Current ATM price reported by the tracker, if any.
    #[getter]
    #[pyo3(name = "atm_price")]
    fn py_atm_price(&self) -> Option<Price> {
        self.aggregator.atm_tracker().atm_price()
    }

    fn __repr__(&self) -> String {
        format!(
            "OptionChainManager(series_id={}, bootstrapped={}, raw_mode={}, \
             active={}/{})",
            self.series_id,
            self.bootstrapped,
            self.raw_mode,
            self.aggregator.instrument_ids().len(),
            self.aggregator.all_instrument_ids().len(),
        )
    }
}

// Rust-only helpers, kept out of the `#[pymethods]` block so they are not exported to Python.
impl PyOptionChainManager {
    /// Completes bootstrap once the ATM tracker has a price.
    ///
    /// On the first call where the manager is not yet bootstrapped and an ATM price is
    /// available, this recomputes the active instrument set, marks the manager
    /// bootstrapped and returns `true`. Otherwise it does nothing and returns `false`.
    fn try_bootstrap(&mut self) -> bool {
        if self.bootstrapped || self.aggregator.atm_tracker().atm_price().is_none() {
            return false;
        }
        self.aggregator.recompute_active_set();
        self.bootstrapped = true;
        true
    }
}