1use std::{
17 collections::{HashMap, hash_map::DefaultHasher},
18 hash::{Hash, Hasher},
19 str::FromStr,
20};
21
22use nautilus_core::{
23 UnixNanos,
24 python::{
25 IntoPyObjectNautilusExt,
26 serialization::{from_dict_pyo3, to_dict_pyo3},
27 to_pyvalue_err,
28 },
29 serialization::{
30 Serializable,
31 msgpack::{FromMsgPack, ToMsgPack},
32 },
33};
34use pyo3::{
35 IntoPyObjectExt,
36 prelude::*,
37 pyclass::CompareOp,
38 types::{PyDict, PyInt, PyString, PyTuple},
39};
40
41use crate::{
42 data::{IndexPriceUpdate, MarkPriceUpdate},
43 identifiers::InstrumentId,
44 python::common::PY_MODULE_MODEL,
45 types::price::{Price, PriceRaw},
46};
47
48impl MarkPriceUpdate {
49 pub fn from_pyobject(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
55 let instrument_id_obj: Bound<'_, PyAny> = obj.getattr("instrument_id")?.extract()?;
56 let instrument_id_str: String = instrument_id_obj.getattr("value")?.extract()?;
57 let instrument_id =
58 InstrumentId::from_str(instrument_id_str.as_str()).map_err(to_pyvalue_err)?;
59
60 let value_py: Bound<'_, PyAny> = obj.getattr("value")?.extract()?;
61 let value_raw: PriceRaw = value_py.getattr("raw")?.extract()?;
62 let value_prec: u8 = value_py.getattr("precision")?.extract()?;
63 let value = Price::from_raw(value_raw, value_prec);
64
65 let ts_event: u64 = obj.getattr("ts_event")?.extract()?;
66 let ts_init: u64 = obj.getattr("ts_init")?.extract()?;
67
68 Ok(Self::new(
69 instrument_id,
70 value,
71 ts_event.into(),
72 ts_init.into(),
73 ))
74 }
75}
76
77#[pymethods]
78impl MarkPriceUpdate {
79 #[new]
80 fn py_new(
81 instrument_id: InstrumentId,
82 value: Price,
83 ts_event: u64,
84 ts_init: u64,
85 ) -> PyResult<Self> {
86 Ok(Self::new(
87 instrument_id,
88 value,
89 ts_event.into(),
90 ts_init.into(),
91 ))
92 }
93
94 fn __setstate__(&mut self, state: &Bound<'_, PyAny>) -> PyResult<()> {
95 let py_tuple: &Bound<'_, PyTuple> = state.cast::<PyTuple>()?;
96 let binding = py_tuple.get_item(0)?;
97 let instrument_id_str = binding.cast::<PyString>()?.extract::<&str>()?;
98 let value_raw = py_tuple
99 .get_item(1)?
100 .cast::<PyInt>()?
101 .extract::<PriceRaw>()?;
102 let value_prec = py_tuple.get_item(2)?.cast::<PyInt>()?.extract::<u8>()?;
103
104 let ts_event = py_tuple.get_item(7)?.cast::<PyInt>()?.extract::<u64>()?;
105 let ts_init = py_tuple.get_item(8)?.cast::<PyInt>()?.extract::<u64>()?;
106
107 self.instrument_id = InstrumentId::from_str(instrument_id_str).map_err(to_pyvalue_err)?;
108 self.value = Price::from_raw(value_raw, value_prec);
109 self.ts_event = ts_event.into();
110 self.ts_init = ts_init.into();
111
112 Ok(())
113 }
114
115 fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
116 (
117 self.instrument_id.to_string(),
118 self.value.raw,
119 self.value.precision,
120 self.ts_event.as_u64(),
121 self.ts_init.as_u64(),
122 )
123 .into_py_any(py)
124 }
125
126 fn __reduce__(&self, py: Python) -> PyResult<Py<PyAny>> {
127 let safe_constructor = py.get_type::<Self>().getattr("_safe_constructor")?;
128 let state = self.__getstate__(py)?;
129 (safe_constructor, PyTuple::empty(py), state).into_py_any(py)
130 }
131
132 #[staticmethod]
133 fn _safe_constructor() -> Self {
134 Self::new(
135 InstrumentId::from("NULL.NULL"),
136 Price::zero(0),
137 UnixNanos::default(),
138 UnixNanos::default(),
139 )
140 }
141
142 fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> Py<PyAny> {
143 match op {
144 CompareOp::Eq => self.eq(other).into_py_any_unwrap(py),
145 CompareOp::Ne => self.ne(other).into_py_any_unwrap(py),
146 _ => py.NotImplemented(),
147 }
148 }
149
150 fn __hash__(&self) -> isize {
151 let mut h = DefaultHasher::new();
152 self.hash(&mut h);
153 h.finish() as isize
154 }
155
156 fn __repr__(&self) -> String {
157 format!("{}({})", stringify!(MarkPriceUpdate), self)
158 }
159
160 fn __str__(&self) -> String {
161 self.to_string()
162 }
163
164 #[getter]
165 #[pyo3(name = "instrument_id")]
166 fn py_instrument_id(&self) -> InstrumentId {
167 self.instrument_id
168 }
169
170 #[getter]
171 #[pyo3(name = "value")]
172 fn py_value(&self) -> Price {
173 self.value
174 }
175
176 #[getter]
177 #[pyo3(name = "ts_event")]
178 fn py_ts_event(&self) -> u64 {
179 self.ts_event.as_u64()
180 }
181
182 #[getter]
183 #[pyo3(name = "ts_init")]
184 fn py_ts_init(&self) -> u64 {
185 self.ts_init.as_u64()
186 }
187
188 #[staticmethod]
189 #[pyo3(name = "fully_qualified_name")]
190 fn py_fully_qualified_name() -> String {
191 format!("{}:{}", PY_MODULE_MODEL, stringify!(MarkPriceUpdate))
192 }
193
194 #[staticmethod]
195 #[pyo3(name = "get_metadata")]
196 fn py_get_metadata(
197 instrument_id: &InstrumentId,
198 price_precision: u8,
199 ) -> PyResult<HashMap<String, String>> {
200 Ok(Self::get_metadata(instrument_id, price_precision))
201 }
202
203 #[staticmethod]
204 #[pyo3(name = "get_fields")]
205 fn py_get_fields(py: Python<'_>) -> PyResult<Bound<'_, PyDict>> {
206 let py_dict = PyDict::new(py);
207 for (k, v) in Self::get_fields() {
208 py_dict.set_item(k, v)?;
209 }
210
211 Ok(py_dict)
212 }
213
214 #[staticmethod]
216 #[pyo3(name = "from_dict")]
217 fn py_from_dict(py: Python<'_>, values: Py<PyDict>) -> PyResult<Self> {
218 from_dict_pyo3(py, values)
219 }
220
221 #[staticmethod]
222 #[pyo3(name = "from_json")]
223 fn py_from_json(data: Vec<u8>) -> PyResult<Self> {
224 Self::from_json_bytes(&data).map_err(to_pyvalue_err)
225 }
226
227 #[staticmethod]
228 #[pyo3(name = "from_msgpack")]
229 fn py_from_msgpack(data: Vec<u8>) -> PyResult<Self> {
230 Self::from_msgpack_bytes(&data).map_err(to_pyvalue_err)
231 }
232
233 #[pyo3(name = "to_dict")]
235 fn py_to_dict(&self, py: Python<'_>) -> PyResult<Py<PyDict>> {
236 to_dict_pyo3(py, self)
237 }
238
239 #[pyo3(name = "to_json_bytes")]
241 fn py_to_json_bytes(&self, py: Python<'_>) -> Py<PyAny> {
242 self.to_json_bytes().unwrap().into_py_any_unwrap(py)
243 }
244
245 #[pyo3(name = "to_msgpack_bytes")]
247 fn py_to_msgpack_bytes(&self, py: Python<'_>) -> Py<PyAny> {
248 self.to_msgpack_bytes().unwrap().into_py_any_unwrap(py)
249 }
250}
251
252impl IndexPriceUpdate {
253 pub fn from_pyobject(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
259 let instrument_id_obj: Bound<'_, PyAny> = obj.getattr("instrument_id")?.extract()?;
260 let instrument_id_str: String = instrument_id_obj.getattr("value")?.extract()?;
261 let instrument_id =
262 InstrumentId::from_str(instrument_id_str.as_str()).map_err(to_pyvalue_err)?;
263
264 let value_py: Bound<'_, PyAny> = obj.getattr("value")?.extract()?;
265 let value_raw: PriceRaw = value_py.getattr("raw")?.extract()?;
266 let value_prec: u8 = value_py.getattr("precision")?.extract()?;
267 let value = Price::from_raw(value_raw, value_prec);
268
269 let ts_event: u64 = obj.getattr("ts_event")?.extract()?;
270 let ts_init: u64 = obj.getattr("ts_init")?.extract()?;
271
272 Ok(Self::new(
273 instrument_id,
274 value,
275 ts_event.into(),
276 ts_init.into(),
277 ))
278 }
279}
280
281#[pymethods]
282impl IndexPriceUpdate {
283 #[new]
284 fn py_new(
285 instrument_id: InstrumentId,
286 value: Price,
287 ts_event: u64,
288 ts_init: u64,
289 ) -> PyResult<Self> {
290 Ok(Self::new(
291 instrument_id,
292 value,
293 ts_event.into(),
294 ts_init.into(),
295 ))
296 }
297
298 fn __setstate__(&mut self, state: &Bound<'_, PyAny>) -> PyResult<()> {
299 let py_tuple: &Bound<'_, PyTuple> = state.cast::<PyTuple>()?;
300 let binding = py_tuple.get_item(0)?;
301 let instrument_id_str = binding.cast::<PyString>()?.extract::<&str>()?;
302 let value_raw = py_tuple
303 .get_item(1)?
304 .cast::<PyInt>()?
305 .extract::<PriceRaw>()?;
306 let value_prec = py_tuple.get_item(2)?.cast::<PyInt>()?.extract::<u8>()?;
307
308 let ts_event = py_tuple.get_item(7)?.cast::<PyInt>()?.extract::<u64>()?;
309 let ts_init = py_tuple.get_item(8)?.cast::<PyInt>()?.extract::<u64>()?;
310
311 self.instrument_id = InstrumentId::from_str(instrument_id_str).map_err(to_pyvalue_err)?;
312 self.value = Price::from_raw(value_raw, value_prec);
313 self.ts_event = ts_event.into();
314 self.ts_init = ts_init.into();
315
316 Ok(())
317 }
318
319 fn __getstate__(&self, py: Python) -> PyResult<Py<PyAny>> {
320 (
321 self.instrument_id.to_string(),
322 self.value.raw,
323 self.value.precision,
324 self.ts_event.as_u64(),
325 self.ts_init.as_u64(),
326 )
327 .into_py_any(py)
328 }
329
330 fn __reduce__(&self, py: Python) -> PyResult<Py<PyAny>> {
331 let safe_constructor = py.get_type::<Self>().getattr("_safe_constructor")?;
332 let state = self.__getstate__(py)?;
333 (safe_constructor, PyTuple::empty(py), state).into_py_any(py)
334 }
335
336 #[staticmethod]
337 fn _safe_constructor() -> Self {
338 Self::new(
339 InstrumentId::from("NULL.NULL"),
340 Price::zero(0),
341 UnixNanos::default(),
342 UnixNanos::default(),
343 )
344 }
345
346 fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> Py<PyAny> {
347 match op {
348 CompareOp::Eq => self.eq(other).into_py_any_unwrap(py),
349 CompareOp::Ne => self.ne(other).into_py_any_unwrap(py),
350 _ => py.NotImplemented(),
351 }
352 }
353
354 fn __hash__(&self) -> isize {
355 let mut h = DefaultHasher::new();
356 self.hash(&mut h);
357 h.finish() as isize
358 }
359
360 fn __repr__(&self) -> String {
361 format!("{}({})", stringify!(IndexPriceUpdate), self)
362 }
363
364 fn __str__(&self) -> String {
365 self.to_string()
366 }
367
368 #[getter]
369 #[pyo3(name = "instrument_id")]
370 fn py_instrument_id(&self) -> InstrumentId {
371 self.instrument_id
372 }
373
374 #[getter]
375 #[pyo3(name = "value")]
376 fn py_value(&self) -> Price {
377 self.value
378 }
379
380 #[getter]
381 #[pyo3(name = "ts_event")]
382 fn py_ts_event(&self) -> u64 {
383 self.ts_event.as_u64()
384 }
385
386 #[getter]
387 #[pyo3(name = "ts_init")]
388 fn py_ts_init(&self) -> u64 {
389 self.ts_init.as_u64()
390 }
391
392 #[staticmethod]
393 #[pyo3(name = "fully_qualified_name")]
394 fn py_fully_qualified_name() -> String {
395 format!("{}:{}", PY_MODULE_MODEL, stringify!(IndexPriceUpdate))
396 }
397
398 #[staticmethod]
399 #[pyo3(name = "get_metadata")]
400 fn py_get_metadata(
401 instrument_id: &InstrumentId,
402 price_precision: u8,
403 ) -> PyResult<HashMap<String, String>> {
404 Ok(Self::get_metadata(instrument_id, price_precision))
405 }
406
407 #[staticmethod]
408 #[pyo3(name = "get_fields")]
409 fn py_get_fields(py: Python<'_>) -> PyResult<Bound<'_, PyDict>> {
410 let py_dict = PyDict::new(py);
411 for (k, v) in Self::get_fields() {
412 py_dict.set_item(k, v)?;
413 }
414
415 Ok(py_dict)
416 }
417
418 #[staticmethod]
420 #[pyo3(name = "from_dict")]
421 fn py_from_dict(py: Python<'_>, values: Py<PyDict>) -> PyResult<Self> {
422 from_dict_pyo3(py, values)
423 }
424
425 #[staticmethod]
426 #[pyo3(name = "from_json")]
427 fn py_from_json(data: Vec<u8>) -> PyResult<Self> {
428 Self::from_json_bytes(&data).map_err(to_pyvalue_err)
429 }
430
431 #[staticmethod]
432 #[pyo3(name = "from_msgpack")]
433 fn py_from_msgpack(data: Vec<u8>) -> PyResult<Self> {
434 Self::from_msgpack_bytes(&data).map_err(to_pyvalue_err)
435 }
436
437 #[pyo3(name = "to_dict")]
439 fn py_to_dict(&self, py: Python<'_>) -> PyResult<Py<PyDict>> {
440 to_dict_pyo3(py, self)
441 }
442
443 #[pyo3(name = "to_json_bytes")]
445 fn py_to_json_bytes(&self, py: Python<'_>) -> Py<PyAny> {
446 self.to_json_bytes().unwrap().into_py_any_unwrap(py)
447 }
448
449 #[pyo3(name = "to_msgpack_bytes")]
451 fn py_to_msgpack_bytes(&self, py: Python<'_>) -> Py<PyAny> {
452 self.to_msgpack_bytes().unwrap().into_py_any_unwrap(py)
453 }
454}
455
#[cfg(test)]
mod tests {
    use nautilus_core::python::IntoPyObjectNautilusExt;
    use pyo3::Python;
    use rstest::{fixture, rstest};

    use super::*;
    use crate::{identifiers::InstrumentId, types::Price};

    /// Instrument shared by every fixture in this module.
    const FIXTURE_INSTRUMENT: &str = "BTC-USDT.OKX";

    /// Price string shared by every fixture in this module.
    const FIXTURE_PRICE: &str = "100_000.00";

    /// Mark price update for the shared instrument with `ts_event = 1` and `ts_init = 2`.
    #[fixture]
    fn mark_price() -> MarkPriceUpdate {
        let id = InstrumentId::from(FIXTURE_INSTRUMENT);
        let px = Price::from(FIXTURE_PRICE);
        MarkPriceUpdate::new(id, px, UnixNanos::from(1), UnixNanos::from(2))
    }

    /// Index price update for the shared instrument with `ts_event = 1` and `ts_init = 2`.
    #[fixture]
    fn index_price() -> IndexPriceUpdate {
        let id = InstrumentId::from(FIXTURE_INSTRUMENT);
        let px = Price::from(FIXTURE_PRICE);
        IndexPriceUpdate::new(id, px, UnixNanos::from(1), UnixNanos::from(2))
    }

    #[rstest]
    fn test_mark_price_to_dict(mark_price: MarkPriceUpdate) {
        Python::initialize();
        Python::attach(|py| {
            let expected = "{'type': 'MarkPriceUpdate', 'instrument_id': 'BTC-USDT.OKX', 'value': '100000.00', 'ts_event': 1, 'ts_init': 2}";
            let as_dict = mark_price.py_to_dict(py).unwrap();
            assert_eq!(as_dict.to_string(), expected);
        });
    }

    #[rstest]
    fn test_mark_price_from_dict(mark_price: MarkPriceUpdate) {
        Python::initialize();
        Python::attach(|py| {
            let as_dict = mark_price.py_to_dict(py).unwrap();
            let round_tripped = MarkPriceUpdate::py_from_dict(py, as_dict).unwrap();
            assert_eq!(round_tripped, mark_price);
        });
    }

    #[rstest]
    fn test_mark_price_from_pyobject(mark_price: MarkPriceUpdate) {
        Python::initialize();
        Python::attach(|py| {
            let py_obj = mark_price.into_py_any_unwrap(py);
            let restored = MarkPriceUpdate::from_pyobject(py_obj.bind(py)).unwrap();
            assert_eq!(restored, mark_price);
        });
    }

    #[rstest]
    fn test_index_price_to_dict(index_price: IndexPriceUpdate) {
        Python::initialize();
        Python::attach(|py| {
            let expected = "{'type': 'IndexPriceUpdate', 'instrument_id': 'BTC-USDT.OKX', 'value': '100000.00', 'ts_event': 1, 'ts_init': 2}";
            let as_dict = index_price.py_to_dict(py).unwrap();
            assert_eq!(as_dict.to_string(), expected);
        });
    }

    #[rstest]
    fn test_index_price_from_dict(index_price: IndexPriceUpdate) {
        Python::initialize();
        Python::attach(|py| {
            let as_dict = index_price.py_to_dict(py).unwrap();
            let round_tripped = IndexPriceUpdate::py_from_dict(py, as_dict).unwrap();
            assert_eq!(round_tripped, index_price);
        });
    }

    #[rstest]
    fn test_index_price_from_pyobject(index_price: IndexPriceUpdate) {
        Python::initialize();
        Python::attach(|py| {
            let py_obj = index_price.into_py_any_unwrap(py);
            let restored = IndexPriceUpdate::from_pyobject(py_obj.bind(py)).unwrap();
            assert_eq!(restored, index_price);
        });
    }
}