use std::{
collections::HashMap,
hash::{Hash, Hasher},
str::FromStr,
};
use nautilus_core::{
UnixNanos,
python::{IntoPyObjectNautilusExt, to_pykey_err, to_pyvalue_err},
serialization::{
Serializable,
msgpack::{FromMsgPack, ToMsgPack},
},
};
use pyo3::{
prelude::*,
pyclass::CompareOp,
types::{PyString, PyTuple},
};
use rust_decimal::Decimal;
use crate::{data::FundingRateUpdate, identifiers::InstrumentId, python::common::PY_MODULE_MODEL};
#[pymethods]
#[pyo3_stub_gen::derive::gen_stub_pymethods]
impl FundingRateUpdate {
    /// Creates a new funding rate update from Python.
    ///
    /// `ts_event`, `ts_init` and the optional `next_funding_ns` are raw `u64` values that
    /// are wrapped into [`UnixNanos`]; `interval` is passed through unchanged.
    #[new]
    #[pyo3(signature = (instrument_id, rate, ts_event, ts_init, interval=None, next_funding_ns=None))]
    fn py_new(
        instrument_id: InstrumentId,
        rate: Decimal,
        ts_event: u64,
        ts_init: u64,
        interval: Option<u16>,
        next_funding_ns: Option<u64>,
    ) -> Self {
        // `Self::new` takes the optional fields *before* the timestamps, unlike the
        // Python signature, so the arguments are re-ordered here.
        Self::new(
            instrument_id,
            rate,
            interval,
            next_funding_ns.map(UnixNanos::from),
            UnixNanos::from(ts_event),
            UnixNanos::from(ts_init),
        )
    }

    /// Rich comparison: only `==` and `!=` are supported (via `PartialEq`); every other
    /// operator returns `NotImplemented`.
    fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> Py<PyAny> {
        match op {
            CompareOp::Eq => self.eq(other).into_py_any_unwrap(py),
            CompareOp::Ne => self.ne(other).into_py_any_unwrap(py),
            _ => py.NotImplemented(),
        }
    }

    /// Returns the `Debug` representation.
    fn __repr__(&self) -> String {
        format!("{self:?}")
    }

    /// Returns the `Display` representation.
    fn __str__(&self) -> String {
        format!("{self}")
    }

    /// Hashes the update with the type's `Hash` impl and std's `DefaultHasher`.
    fn __hash__(&self) -> isize {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        Hash::hash(self, &mut hasher);
        // Reinterpreting the `u64` digest as `isize` may wrap (or truncate on 32-bit
        // targets); that is acceptable for a hash value.
        Hasher::finish(&hasher) as isize
    }

    /// The instrument this funding rate applies to.
    #[getter]
    #[pyo3(name = "instrument_id")]
    fn py_instrument_id(&self) -> InstrumentId {
        self.instrument_id
    }

    /// The funding rate value.
    #[getter]
    #[pyo3(name = "rate")]
    fn py_rate(&self) -> Decimal {
        self.rate
    }

    /// The optional funding interval, as stored on the update.
    #[getter]
    #[pyo3(name = "interval")]
    fn py_interval(&self) -> Option<u16> {
        self.interval
    }

    /// The optional next funding timestamp as a raw `u64` (UNIX nanoseconds).
    #[getter]
    #[pyo3(name = "next_funding_ns")]
    fn py_next_funding_ns(&self) -> Option<u64> {
        self.next_funding_ns.map(|ts| ts.as_u64())
    }

    /// The event timestamp as a raw `u64` (UNIX nanoseconds).
    #[getter]
    #[pyo3(name = "ts_event")]
    fn py_ts_event(&self) -> u64 {
        self.ts_event.as_u64()
    }

    /// The initialization timestamp as a raw `u64` (UNIX nanoseconds).
    #[getter]
    #[pyo3(name = "ts_init")]
    fn py_ts_init(&self) -> u64 {
        self.ts_init.as_u64()
    }

    /// Returns `"<PY_MODULE_MODEL>:FundingRateUpdate"`.
    #[staticmethod]
    #[pyo3(name = "fully_qualified_name")]
    fn py_fully_qualified_name() -> String {
        format!("{}:{}", PY_MODULE_MODEL, stringify!(FundingRateUpdate))
    }

    /// Python wrapper around `FundingRateUpdate::get_metadata`.
    #[staticmethod]
    #[pyo3(name = "get_metadata")]
    fn py_get_metadata(instrument_id: &InstrumentId) -> HashMap<String, String> {
        Self::get_metadata(instrument_id)
    }

    /// Python wrapper around `FundingRateUpdate::get_fields`, collected into a dict.
    #[staticmethod]
    #[pyo3(name = "get_fields")]
    fn py_get_fields() -> HashMap<String, String> {
        Self::get_fields().into_iter().collect()
    }

    /// Converts the update into a Python `dict`.
    ///
    /// `instrument_id` and `rate` are emitted as strings and timestamps as raw `u64`.
    /// The optional `interval` and `next_funding_ns` keys are omitted when unset, which
    /// `from_dict` accepts.
    #[pyo3(name = "to_dict")]
    fn py_to_dict(&self, py: Python<'_>) -> Py<PyAny> {
        let mut dict = HashMap::new();
        dict.insert(
            "type".to_string(),
            "FundingRateUpdate".into_py_any_unwrap(py),
        );
        dict.insert(
            "instrument_id".to_string(),
            self.instrument_id.to_string().into_py_any_unwrap(py),
        );
        dict.insert(
            "rate".to_string(),
            self.rate.to_string().into_py_any_unwrap(py),
        );
        if let Some(interval) = self.interval {
            dict.insert("interval".to_string(), interval.into_py_any_unwrap(py));
        }
        if let Some(next_funding_ns) = self.next_funding_ns {
            dict.insert(
                "next_funding_ns".to_string(),
                next_funding_ns.as_u64().into_py_any_unwrap(py),
            );
        }
        dict.insert(
            "ts_event".to_string(),
            self.ts_event.as_u64().into_py_any_unwrap(py),
        );
        dict.insert(
            "ts_init".to_string(),
            self.ts_init.as_u64().into_py_any_unwrap(py),
        );
        dict.into_py_any_unwrap(py)
    }

    /// Reconstructs an update from a Python `dict` shaped like the output of `to_dict`.
    ///
    /// # Errors
    ///
    /// - `values` is not a dict.
    /// - A required key (`instrument_id`, `rate`, `ts_event`, `ts_init`) is missing.
    /// - Any value cannot be converted or parsed.
    ///
    /// The optional `interval` / `next_funding_ns` keys may be absent or `None`; a present
    /// value of the wrong type (or out of range) is reported instead of being dropped.
    #[staticmethod]
    #[pyo3(name = "from_dict")]
    // PyO3 passes the argument as an owned `Py<PyAny>`; the body only needs to borrow it.
    #[allow(clippy::needless_pass_by_value)]
    fn py_from_dict(py: Python<'_>, values: Py<PyAny>) -> PyResult<Self> {
        let dict = values.cast_bound::<pyo3::types::PyDict>(py)?;

        let instrument_id_str: String = dict
            .get_item("instrument_id")?
            .ok_or_else(|| to_pykey_err("Missing 'instrument_id' field"))?
            .extract()?;
        let instrument_id = InstrumentId::from_str(&instrument_id_str).map_err(to_pyvalue_err)?;

        let rate_str: String = dict
            .get_item("rate")?
            .ok_or_else(|| to_pykey_err("Missing 'rate' field"))?
            .extract()?;
        let rate = Decimal::from_str(&rate_str).map_err(to_pyvalue_err)?;

        let ts_event: u64 = dict
            .get_item("ts_event")?
            .ok_or_else(|| to_pykey_err("Missing 'ts_event' field"))?
            .extract()?;
        let ts_init: u64 = dict
            .get_item("ts_init")?
            .ok_or_else(|| to_pykey_err("Missing 'ts_init' field"))?
            .extract()?;

        // Missing key -> `None`; explicit Python `None` -> `None` (via `Option` extraction);
        // anything else must convert cleanly or the error is propagated.
        let interval: Option<u16> = dict
            .get_item("interval")?
            .map(|v| v.extract::<Option<u16>>())
            .transpose()?
            .flatten();
        let next_funding_ns: Option<u64> = dict
            .get_item("next_funding_ns")?
            .map(|v| v.extract::<Option<u64>>())
            .transpose()?
            .flatten();

        Ok(Self::new(
            instrument_id,
            rate,
            interval,
            next_funding_ns.map(UnixNanos::from),
            UnixNanos::from(ts_event),
            UnixNanos::from(ts_init),
        ))
    }

    /// Serializes the update to JSON bytes via `to_json_bytes`.
    #[pyo3(name = "to_json")]
    fn py_to_json(&self) -> PyResult<Vec<u8>> {
        self.to_json_bytes()
            .map(|b| b.to_vec())
            .map_err(to_pyvalue_err)
    }

    /// Serializes the update to MessagePack bytes via `to_msgpack_bytes`.
    #[pyo3(name = "to_msgpack")]
    fn py_to_msgpack(&self) -> PyResult<Vec<u8>> {
        self.to_msgpack_bytes()
            .map(|b| b.to_vec())
            .map_err(to_pyvalue_err)
    }

    /// Restores state produced by `__getstate__`.
    ///
    /// Expects the 6-tuple
    /// `(instrument_id: str, rate: str, interval, next_funding_ns, ts_event, ts_init)`.
    /// The two optional entries may be `None`.
    ///
    /// All values are parsed before any field is assigned, so an invalid state leaves
    /// `self` untouched.
    fn __setstate__(&mut self, state: &Bound<'_, PyAny>) -> PyResult<()> {
        let py_tuple: &Bound<'_, PyTuple> = state.cast::<PyTuple>()?;

        let item0 = py_tuple.get_item(0)?;
        let instrument_id_str: String = item0.cast::<PyString>()?.extract()?;
        let item1 = py_tuple.get_item(1)?;
        let rate_str: String = item1.cast::<PyString>()?.extract()?;
        // `Option<T>` extraction maps Python `None` to `None`; malformed values now raise
        // instead of silently becoming `None`.
        let interval: Option<u16> = py_tuple.get_item(2)?.extract()?;
        let next_funding_ns: Option<u64> = py_tuple.get_item(3)?.extract()?;
        let ts_event: u64 = py_tuple.get_item(4)?.extract()?;
        let ts_init: u64 = py_tuple.get_item(5)?.extract()?;

        let instrument_id = InstrumentId::from_str(&instrument_id_str).map_err(to_pyvalue_err)?;
        let rate = Decimal::from_str(&rate_str).map_err(to_pyvalue_err)?;

        self.instrument_id = instrument_id;
        self.rate = rate;
        self.interval = interval;
        self.next_funding_ns = next_funding_ns.map(UnixNanos::from);
        self.ts_event = UnixNanos::from(ts_event);
        self.ts_init = UnixNanos::from(ts_init);
        Ok(())
    }

    /// Returns the pickle state tuple consumed by `__setstate__`.
    fn __getstate__(&self, py: Python<'_>) -> Py<PyAny> {
        (
            self.instrument_id.to_string(),
            self.rate.to_string(),
            self.interval,
            self.next_funding_ns.map(|ts| ts.as_u64()),
            self.ts_event.as_u64(),
            self.ts_init.as_u64(),
        )
            .into_py_any_unwrap(py)
    }

    /// Pickle support.
    ///
    /// Unpickling calls `_safe_constructor()` with no arguments, then applies the state
    /// from `__getstate__` through `__setstate__`.
    fn __reduce__(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
        let safe_constructor = py.get_type::<Self>().getattr("_safe_constructor")?;
        let state = self.__getstate__(py);
        Ok((safe_constructor, PyTuple::empty(py), state).into_py_any_unwrap(py))
    }

    /// Placeholder instance used by `__reduce__` during unpickling.
    ///
    /// The values set here are overwritten by `__setstate__`.
    #[staticmethod]
    #[pyo3(name = "_safe_constructor")]
    fn py_safe_constructor() -> Self {
        Self::new(
            InstrumentId::from("NULL.NULL"),
            Decimal::ZERO,
            None,
            None,
            UnixNanos::default(),
            UnixNanos::default(),
        )
    }
}
// NOTE: a second `#[pymethods]` block on the same type relies on PyO3's
// `multiple-pymethods` feature.
#[pymethods]
impl FundingRateUpdate {
    /// Decodes a funding rate update from JSON-encoded bytes.
    ///
    /// Decoding failures are converted into a Python exception with `to_pyvalue_err`.
    #[staticmethod]
    #[pyo3(name = "from_json")]
    fn py_from_json(data: &[u8]) -> PyResult<Self> {
        let decoded = Self::from_json_bytes(data);
        decoded.map_err(to_pyvalue_err)
    }

    /// Decodes a funding rate update from MessagePack-encoded bytes.
    ///
    /// Decoding failures are converted into a Python exception with `to_pyvalue_err`.
    #[staticmethod]
    #[pyo3(name = "from_msgpack")]
    fn py_from_msgpack(data: &[u8]) -> PyResult<Self> {
        let decoded = Self::from_msgpack_bytes(data);
        decoded.map_err(to_pyvalue_err)
    }
}
impl FundingRateUpdate {
    /// Builds a [`FundingRateUpdate`] from an arbitrary Python object with matching attributes.
    ///
    /// The object must expose:
    /// - `instrument_id`, whose `value` attribute is the instrument ID string;
    /// - `rate`, `ts_event` and `ts_init`.
    ///
    /// `interval` and `next_funding_ns` are optional: they may be missing or `None`.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - a required attribute is missing or fails to convert;
    /// - the instrument ID string fails to parse;
    /// - an optional attribute is present but holds an incompatible or out-of-range value.
    pub fn from_pyobject(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
        let instrument_id_str: String = obj.getattr("instrument_id")?.getattr("value")?.extract()?;
        let instrument_id =
            InstrumentId::from_str(instrument_id_str.as_str()).map_err(to_pyvalue_err)?;
        let rate: Decimal = obj.getattr("rate")?.extract()?;
        let ts_event: u64 = obj.getattr("ts_event")?.extract()?;
        let ts_init: u64 = obj.getattr("ts_init")?.extract()?;

        // `getattr_opt` yields `None` only for a missing attribute (AttributeError);
        // `Option` extraction maps Python `None` to `None`. Any other failure is
        // propagated instead of silently discarding the value.
        let interval: Option<u16> = obj
            .getattr_opt("interval")?
            .map(|v| v.extract::<Option<u16>>())
            .transpose()?
            .flatten();
        let next_funding_ns: Option<u64> = obj
            .getattr_opt("next_funding_ns")?
            .map(|v| v.extract::<Option<u64>>())
            .transpose()?
            .flatten();

        Ok(Self::new(
            instrument_id,
            rate,
            interval,
            next_funding_ns.map(UnixNanos::from),
            UnixNanos::from(ts_event),
            UnixNanos::from(ts_init),
        ))
    }
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// `py_new` stores each argument in the matching field and leaves the optional
    /// fields unset when `None` is passed for them.
    #[rstest]
    fn test_py_funding_rate_update_new() {
        Python::initialize();
        Python::attach(|_py| {
            let expected_id = InstrumentId::from("BTCUSDT-PERP.BINANCE");
            // mantissa 1, scale 4 => 0.0001
            let expected_rate = Decimal::new(1, 4);
            let event_ts = UnixNanos::from(1_640_000_000_000_000_000_u64);
            let init_ts = UnixNanos::from(1_640_000_000_000_000_000_u64);

            let update = FundingRateUpdate::py_new(
                expected_id,
                expected_rate,
                event_ts.as_u64(),
                init_ts.as_u64(),
                None,
                None,
            );

            assert_eq!(update.instrument_id, expected_id);
            assert_eq!(update.rate, expected_rate);
            assert_eq!(update.interval, None);
            assert_eq!(update.next_funding_ns, None);
            assert_eq!(update.ts_event, event_ts);
            assert_eq!(update.ts_init, init_ts);
        });
    }
}