Skip to main content

nautilus_model/python/data/
trade.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    collections::{HashMap, hash_map::DefaultHasher},
18    hash::{Hash, Hasher},
19    str::FromStr,
20};
21
22use nautilus_core::{
23    UnixNanos,
24    python::{
25        IntoPyObjectNautilusExt,
26        serialization::{from_dict_pyo3, to_dict_pyo3},
27        to_pyvalue_err,
28    },
29    serialization::{
30        Serializable,
31        msgpack::{FromMsgPack, ToMsgPack},
32    },
33};
34use pyo3::{
35    IntoPyObjectExt,
36    prelude::*,
37    pyclass::CompareOp,
38    types::{PyDict, PyInt, PyString, PyTuple},
39};
40
41use super::data_to_pycapsule;
42use crate::{
43    data::{Data, TradeTick},
44    enums::{AggressorSide, FromU8},
45    identifiers::{InstrumentId, TradeId},
46    python::common::PY_MODULE_MODEL,
47    types::{
48        price::{Price, PriceRaw},
49        quantity::{Quantity, QuantityRaw},
50    },
51};
52
53impl TradeTick {
54    /// Creates a new [`TradeTick`] from a Python object.
55    ///
56    /// # Panics
57    ///
58    /// Panics if converting `aggressor_side_u8` to `AggressorSide` fails.
59    ///
60    /// # Errors
61    ///
62    /// Returns a `PyErr` if attribute extraction or type conversion fails.
63    pub fn from_pyobject(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
64        // Fast path: avoid property getters that trigger enum type deadlocks
65        if let Ok(tick) = obj.cast::<Self>() {
66            return Ok(*tick.borrow());
67        }
68
69        let instrument_id_obj: Bound<'_, PyAny> = obj.getattr("instrument_id")?.extract()?;
70        let instrument_id_str: String = instrument_id_obj.getattr("value")?.extract()?;
71        let instrument_id =
72            InstrumentId::from_str(instrument_id_str.as_str()).map_err(to_pyvalue_err)?;
73
74        let price_py: Bound<'_, PyAny> = obj.getattr("price")?.extract()?;
75        let price_raw: PriceRaw = price_py.getattr("raw")?.extract()?;
76        let price_prec: u8 = price_py.getattr("precision")?.extract()?;
77        let price = Price::from_raw(price_raw, price_prec);
78
79        let size_py: Bound<'_, PyAny> = obj.getattr("size")?.extract()?;
80        let size_raw: QuantityRaw = size_py.getattr("raw")?.extract()?;
81        let size_prec: u8 = size_py.getattr("precision")?.extract()?;
82        let size = Quantity::from_raw(size_raw, size_prec);
83
84        let aggressor_side_obj: Bound<'_, PyAny> = obj.getattr("aggressor_side")?.extract()?;
85        let aggressor_side_u8 = aggressor_side_obj.getattr("value")?.extract()?;
86        let aggressor_side = AggressorSide::from_u8(aggressor_side_u8).unwrap();
87
88        let trade_id_obj: Bound<'_, PyAny> = obj.getattr("trade_id")?.extract()?;
89        let trade_id_str: String = trade_id_obj.getattr("value")?.extract()?;
90        let trade_id = TradeId::from(trade_id_str.as_str());
91
92        let ts_event: u64 = obj.getattr("ts_event")?.extract()?;
93        let ts_init: u64 = obj.getattr("ts_init")?.extract()?;
94
95        Ok(Self::new(
96            instrument_id,
97            price,
98            size,
99            aggressor_side,
100            trade_id,
101            ts_event.into(),
102            ts_init.into(),
103        ))
104    }
105}
106
107#[pymethods]
108#[pyo3_stub_gen::derive::gen_stub_pymethods]
109impl TradeTick {
110    /// Represents a trade tick in a market.
111    #[new]
112    fn py_new(
113        instrument_id: InstrumentId,
114        price: Price,
115        size: Quantity,
116        aggressor_side: AggressorSide,
117        trade_id: TradeId,
118        ts_event: u64,
119        ts_init: u64,
120    ) -> PyResult<Self> {
121        Self::new_checked(
122            instrument_id,
123            price,
124            size,
125            aggressor_side,
126            trade_id,
127            ts_event.into(),
128            ts_init.into(),
129        )
130        .map_err(to_pyvalue_err)
131    }
132
133    fn __setstate__(&mut self, state: &Bound<'_, PyAny>) -> PyResult<()> {
134        let py_tuple: &Bound<'_, PyTuple> = state.cast::<PyTuple>()?;
135        let binding = py_tuple.get_item(0)?;
136        let instrument_id_str = binding.cast::<PyString>()?.extract::<&str>()?;
137        let price_raw = py_tuple
138            .get_item(1)?
139            .cast::<PyInt>()?
140            .extract::<PriceRaw>()?;
141        let price_prec = py_tuple.get_item(2)?.cast::<PyInt>()?.extract::<u8>()?;
142        let size_raw = py_tuple
143            .get_item(3)?
144            .cast::<PyInt>()?
145            .extract::<QuantityRaw>()?;
146        let size_prec = py_tuple.get_item(4)?.cast::<PyInt>()?.extract::<u8>()?;
147
148        let aggressor_side_u8 = py_tuple.get_item(5)?.cast::<PyInt>()?.extract::<u8>()?;
149        let binding = py_tuple.get_item(6)?;
150        let trade_id_str = binding.cast::<PyString>()?.extract::<&str>()?;
151        let ts_event = py_tuple.get_item(7)?.cast::<PyInt>()?.extract::<u64>()?;
152        let ts_init = py_tuple.get_item(8)?.cast::<PyInt>()?.extract::<u64>()?;
153
154        self.instrument_id = InstrumentId::from_str(instrument_id_str).map_err(to_pyvalue_err)?;
155        self.price = Price::from_raw(price_raw, price_prec);
156        self.size = Quantity::from_raw(size_raw, size_prec);
157        self.aggressor_side = AggressorSide::from_u8(aggressor_side_u8).unwrap();
158        self.trade_id = TradeId::from(trade_id_str);
159        self.ts_event = ts_event.into();
160        self.ts_init = ts_init.into();
161
162        Ok(())
163    }
164
165    fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
166        (
167            self.instrument_id.to_string(),
168            self.price.raw,
169            self.price.precision,
170            self.size.raw,
171            self.size.precision,
172            self.aggressor_side as u8,
173            self.trade_id.to_string(),
174            self.ts_event.as_u64(),
175            self.ts_init.as_u64(),
176        )
177            .into_py_any(py)
178    }
179
180    fn __reduce__(&self, py: Python) -> PyResult<Py<PyAny>> {
181        let safe_constructor = py.get_type::<Self>().getattr("_safe_constructor")?;
182        let state = self.__getstate__(py)?;
183        (safe_constructor, PyTuple::empty(py), state).into_py_any(py)
184    }
185
186    #[staticmethod]
187    fn _safe_constructor() -> Self {
188        Self::new(
189            InstrumentId::from("NULL.NULL"),
190            Price::zero(0),
191            Quantity::from(1), // size cannot be zero
192            AggressorSide::NoAggressor,
193            TradeId::from("NULL"),
194            UnixNanos::default(),
195            UnixNanos::default(),
196        )
197    }
198
199    fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> Py<PyAny> {
200        match op {
201            CompareOp::Eq => self.eq(other).into_py_any_unwrap(py),
202            CompareOp::Ne => self.ne(other).into_py_any_unwrap(py),
203            _ => py.NotImplemented(),
204        }
205    }
206
207    fn __hash__(&self) -> isize {
208        let mut h = DefaultHasher::new();
209        self.hash(&mut h);
210        h.finish() as isize
211    }
212
213    fn __repr__(&self) -> String {
214        format!("{}({})", stringify!(TradeTick), self)
215    }
216
217    fn __str__(&self) -> String {
218        self.to_string()
219    }
220
221    #[getter]
222    #[pyo3(name = "instrument_id")]
223    fn py_instrument_id(&self) -> InstrumentId {
224        self.instrument_id
225    }
226
227    #[getter]
228    #[pyo3(name = "price")]
229    fn py_price(&self) -> Price {
230        self.price
231    }
232
233    #[getter]
234    #[pyo3(name = "size")]
235    fn py_size(&self) -> Quantity {
236        self.size
237    }
238
239    #[getter]
240    #[pyo3(name = "aggressor_side")]
241    fn py_aggressor_side(&self) -> AggressorSide {
242        self.aggressor_side
243    }
244
245    #[getter]
246    #[pyo3(name = "trade_id")]
247    fn py_trade_id(&self) -> TradeId {
248        self.trade_id
249    }
250
251    #[getter]
252    #[pyo3(name = "ts_event")]
253    fn py_ts_event(&self) -> u64 {
254        self.ts_event.as_u64()
255    }
256
257    #[getter]
258    #[pyo3(name = "ts_init")]
259    fn py_ts_init(&self) -> u64 {
260        self.ts_init.as_u64()
261    }
262
263    #[staticmethod]
264    #[pyo3(name = "fully_qualified_name")]
265    fn py_fully_qualified_name() -> String {
266        format!("{}:{}", PY_MODULE_MODEL, stringify!(TradeTick))
267    }
268
269    /// Returns the metadata for the type, for use with serialization formats.
270    #[staticmethod]
271    #[pyo3(name = "get_metadata")]
272    fn py_get_metadata(
273        instrument_id: &InstrumentId,
274        price_precision: u8,
275        size_precision: u8,
276    ) -> HashMap<String, String> {
277        Self::get_metadata(instrument_id, price_precision, size_precision)
278    }
279
280    /// Returns the field map for the type, for use with Arrow schemas.
281    #[staticmethod]
282    #[pyo3(name = "get_fields")]
283    fn py_get_fields(py: Python<'_>) -> PyResult<Bound<'_, PyDict>> {
284        let py_dict = PyDict::new(py);
285        for (k, v) in Self::get_fields() {
286            py_dict.set_item(k, v)?;
287        }
288
289        Ok(py_dict)
290    }
291
292    #[staticmethod]
293    #[pyo3(name = "from_raw")]
294    #[expect(clippy::too_many_arguments)]
295    fn py_from_raw(
296        instrument_id: InstrumentId,
297        price_raw: PriceRaw,
298        price_prec: u8,
299        size_raw: QuantityRaw,
300        size_prec: u8,
301        aggressor_side: AggressorSide,
302        trade_id: TradeId,
303        ts_event: u64,
304        ts_init: u64,
305    ) -> PyResult<Self> {
306        Self::new_checked(
307            instrument_id,
308            Price::from_raw(price_raw, price_prec),
309            Quantity::from_raw(size_raw, size_prec),
310            aggressor_side,
311            trade_id,
312            ts_event.into(),
313            ts_init.into(),
314        )
315        .map_err(to_pyvalue_err)
316    }
317
318    /// Returns a new object from the given dictionary representation.
319    #[staticmethod]
320    #[pyo3(name = "from_dict")]
321    fn py_from_dict(py: Python<'_>, values: Py<PyDict>) -> PyResult<Self> {
322        from_dict_pyo3(py, values)
323    }
324
325    /// Creates a `PyCapsule` containing a raw pointer to a `Data::Trade` object.
326    ///
327    /// This function takes the current object (assumed to be of a type that can be represented as
328    /// `Data::Trade`), and encapsulates a raw pointer to it within a `PyCapsule`.
329    ///
330    /// # Safety
331    ///
332    /// This function is safe as long as the following conditions are met:
333    /// - The `Data::Trade` object pointed to by the capsule must remain valid for the lifetime of the capsule.
334    /// - The consumer of the capsule must ensure proper handling to avoid dereferencing a dangling pointer.
335    ///
336    /// # Panics
337    ///
338    /// The function will panic if the `PyCapsule` creation fails, which can occur if the
339    /// `Data::Trade` object cannot be converted into a raw pointer.
340    #[pyo3(name = "as_pycapsule")]
341    fn py_as_pycapsule(&self, py: Python<'_>) -> Py<PyAny> {
342        data_to_pycapsule(py, Data::Trade(*self))
343    }
344
345    /// Return a dictionary representation of the object.
346    #[pyo3(name = "to_dict")]
347    fn py_to_dict(&self, py: Python<'_>) -> PyResult<Py<PyDict>> {
348        to_dict_pyo3(py, self)
349    }
350
351    /// Return JSON encoded bytes representation of the object.
352    #[pyo3(name = "to_json_bytes")]
353    fn py_to_json_bytes(&self, py: Python<'_>) -> Py<PyAny> {
354        self.to_json_bytes().unwrap().into_py_any_unwrap(py)
355    }
356
357    /// Return `MsgPack` encoded bytes representation of the object.
358    #[pyo3(name = "to_msgpack_bytes")]
359    fn py_to_msgpack_bytes(&self, py: Python<'_>) -> Py<PyAny> {
360        self.to_msgpack_bytes().unwrap().into_py_any_unwrap(py)
361    }
362}
363
364#[pymethods]
365impl TradeTick {
366    #[staticmethod]
367    #[pyo3(name = "from_json")]
368    fn py_from_json(data: &[u8]) -> PyResult<Self> {
369        Self::from_json_bytes(data).map_err(to_pyvalue_err)
370    }
371
372    #[staticmethod]
373    #[pyo3(name = "from_msgpack")]
374    fn py_from_msgpack(data: &[u8]) -> PyResult<Self> {
375        Self::from_msgpack_bytes(data).map_err(to_pyvalue_err)
376    }
377}
378
#[cfg(test)]
mod tests {
    use nautilus_core::{
        python::IntoPyObjectNautilusExt,
        serialization::{Serializable, msgpack::ToMsgPack},
    };
    use pyo3::Python;
    use rstest::rstest;

    use crate::{
        data::{TradeTick, stubs::stub_trade_ethusdt_buyer},
        enums::AggressorSide,
        identifiers::{InstrumentId, TradeId},
        types::{Price, Quantity},
    };

    #[rstest]
    fn test_trade_tick_py_new_with_zero_size() {
        let instrument_id = InstrumentId::from("ETH-USDT-SWAP.OKX");
        let price = Price::from("10000.00");
        let zero_size = Quantity::from(0);
        let aggressor_side = AggressorSide::Buyer;
        let trade_id = TradeId::from("123456789");
        let ts_event = 1;
        let ts_init = 2;

        let result = TradeTick::py_new(
            instrument_id,
            price,
            zero_size,
            aggressor_side,
            trade_id,
            ts_event,
            ts_init,
        );

        assert!(result.is_err());
    }

    #[rstest]
    fn test_to_dict(stub_trade_ethusdt_buyer: TradeTick) {
        let trade = stub_trade_ethusdt_buyer;

        Python::initialize();
        Python::attach(|py| {
            let dict_string = trade.py_to_dict(py).unwrap().to_string();
            let expected_string = "{'type': 'TradeTick', 'instrument_id': 'ETHUSDT-PERP.BINANCE', 'price': '10000.0000', 'size': '1.00000000', 'aggressor_side': 'BUYER', 'trade_id': '123456789', 'ts_event': 0, 'ts_init': 1}";
            assert_eq!(dict_string, expected_string);
        });
    }

    #[rstest]
    fn test_from_dict(stub_trade_ethusdt_buyer: TradeTick) {
        let trade = stub_trade_ethusdt_buyer;

        Python::initialize();
        Python::attach(|py| {
            let dict = trade.py_to_dict(py).unwrap();
            let parsed = TradeTick::py_from_dict(py, dict).unwrap();
            assert_eq!(parsed, trade);
        });
    }

    #[rstest]
    fn test_from_pyobject(stub_trade_ethusdt_buyer: TradeTick) {
        let trade = stub_trade_ethusdt_buyer;

        Python::initialize();
        Python::attach(|py| {
            let tick_pyobject = trade.into_py_any_unwrap(py);
            let parsed_tick = TradeTick::from_pyobject(tick_pyobject.bind(py)).unwrap();
            assert_eq!(parsed_tick, trade);
        });
    }

    /// `__getstate__` followed by `__setstate__` on the placeholder from `_safe_constructor`
    /// (the path `__reduce__` sets up for pickle) must reproduce the original tick.
    #[rstest]
    fn test_pickle_state_round_trip(stub_trade_ethusdt_buyer: TradeTick) {
        let trade = stub_trade_ethusdt_buyer;

        Python::initialize();
        Python::attach(|py| {
            let state = trade.__getstate__(py).unwrap();
            let mut restored = TradeTick::_safe_constructor();
            restored.__setstate__(state.bind(py)).unwrap();
            assert_eq!(restored, trade);
        });
    }

    #[rstest]
    fn test_from_json_round_trip(stub_trade_ethusdt_buyer: TradeTick) {
        let trade = stub_trade_ethusdt_buyer;

        let bytes = trade.to_json_bytes().unwrap();
        let parsed = TradeTick::py_from_json(&bytes).unwrap();
        assert_eq!(parsed, trade);
    }

    #[rstest]
    fn test_from_msgpack_round_trip(stub_trade_ethusdt_buyer: TradeTick) {
        let trade = stub_trade_ethusdt_buyer;

        let bytes = trade.to_msgpack_bytes().unwrap();
        let parsed = TradeTick::py_from_msgpack(&bytes).unwrap();
        assert_eq!(parsed, trade);
    }
}