Skip to main content

dbn/
python.rs

1//! Python wrappers around dbn functions. These are implemented here instead of in `python/`
2//! to be able to implement [`pyo3`] traits for DBN types.
3#![allow(clippy::too_many_arguments)]
4
5mod conversions;
6mod enums;
7mod metadata;
8/// Python wrapper types for DBN records that include `ts_out` as a real field.
9pub mod record;
10pub mod repr;
11
12use std::{convert::Infallible, fmt};
13
14use pyo3::{
15    create_exception,
16    exceptions::PyException,
17    prelude::*,
18    types::{PyDate, PyDateAccess, PyInt},
19    IntoPyObjectExt,
20};
21use strum::IntoEnumIterator;
22
23use crate::{Error, FlagSet};
24
25pub use self::repr::WritePyRepr;
26
// Declares `DBNError`, a Python exception class in the `databento_dbn` module that
// derives from `PyException`. The final argument is the class's docstring.
create_exception!(
    databento_dbn,
    DBNError,
    PyException,
    "An exception from databento_dbn Rust code."
);
33
34/// A helper function for converting any type that implements `Debug` to a Python
35/// `ValueError`.
36pub fn to_py_err(e: impl fmt::Display) -> PyErr {
37    DBNError::new_err(format!("{e}"))
38}
39
40impl From<Error> for PyErr {
41    fn from(err: Error) -> Self {
42        DBNError::new_err(format!("{err}"))
43    }
44}
45
/// Python iterator over the variants of an enum.
///
/// Exposed to Python as a class of the `databento_dbn` module.
#[pyclass(module = "databento_dbn")]
pub struct EnumIterator {
    // Type erasure for code reuse. Generic types can't be exposed to Python.
    // Items are Python objects that have already been converted from the enum's variants.
    iter: Box<dyn Iterator<Item = Py<PyAny>> + Send + Sync>,
}
52
53#[pymethods]
54impl EnumIterator {
55    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
56        slf
57    }
58
59    fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<Py<PyAny>> {
60        slf.iter.next()
61    }
62}
63
64impl EnumIterator {
65    fn new<'py, E>(py: Python<'py>) -> PyResult<Self>
66    where
67        E: strum::IntoEnumIterator + IntoPyObject<'py>,
68        <E as IntoEnumIterator>::Iterator: Send + Sync,
69    {
70        Ok(Self {
71            iter: Box::new(
72                E::iter()
73                    .map(|var| var.into_py_any(py))
74                    // force eager evaluation because `py` isn't `Send`
75                    .collect::<PyResult<Vec<_>>>()?
76                    .into_iter(),
77            ),
78        })
79    }
80}
81
82impl<'py> IntoPyObject<'py> for FlagSet {
83    type Target = PyInt;
84    type Output = Bound<'py, Self::Target>;
85    type Error = Infallible;
86
87    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
88        self.raw().into_pyobject(py)
89    }
90}
91
92/// Tries to convert `py_date` to a [`time::Date`].
93///
94/// # Errors
95/// This function returns an error if input has an invalid month.
96pub fn py_to_time_date(py_date: &Bound<'_, PyDate>) -> PyResult<time::Date> {
97    let month =
98        time::Month::try_from(py_date.get_month()).map_err(|e| DBNError::new_err(e.to_string()))?;
99    time::Date::from_calendar_date(py_date.get_year(), month, py_date.get_day())
100        .map_err(|e| DBNError::new_err(e.to_string()))
101}
102
/// A trait for types that can describe their fields.
///
/// Only [`field_dtypes`](Self::field_dtypes) is required. The other methods default
/// to the behavior of a plain field: nothing hidden, no prices, no timestamps, and an
/// ordering that consists of `field_name` alone.
pub(crate) trait PyFieldDesc {
    /// Returns every field paired with its numpy dtype.
    fn field_dtypes(field_name: &str) -> Vec<(String, String)>;

    /// Returns the fields that should be hidden in Python. Empty by default.
    fn hidden_fields(_field_name: &str) -> Vec<String> {
        vec![]
    }

    /// Returns the fixed-precision price fields. Empty by default.
    fn price_fields(_field_name: &str) -> Vec<String> {
        vec![]
    }

    /// Returns the UNIX nanosecond timestamp fields. Empty by default.
    fn timestamp_fields(_field_name: &str) -> Vec<String> {
        vec![]
    }

    /// Returns the ordered list of fields, excluding hidden fields. By default this
    /// is just `field_name`.
    fn ordered_fields(field_name: &str) -> Vec<String> {
        vec![String::from(field_name)]
    }
}
124
125impl PyFieldDesc for i64 {
126    fn field_dtypes(field_name: &str) -> Vec<(String, String)> {
127        vec![(field_name.to_owned(), "i8".to_owned())]
128    }
129}
130impl PyFieldDesc for i32 {
131    fn field_dtypes(field_name: &str) -> Vec<(String, String)> {
132        vec![(field_name.to_owned(), "i4".to_owned())]
133    }
134}
135impl PyFieldDesc for i16 {
136    fn field_dtypes(field_name: &str) -> Vec<(String, String)> {
137        vec![(field_name.to_owned(), "i2".to_owned())]
138    }
139}
140impl PyFieldDesc for i8 {
141    fn field_dtypes(field_name: &str) -> Vec<(String, String)> {
142        vec![(field_name.to_owned(), "i1".to_owned())]
143    }
144}
145impl PyFieldDesc for u64 {
146    fn field_dtypes(field_name: &str) -> Vec<(String, String)> {
147        vec![(field_name.to_owned(), "u8".to_owned())]
148    }
149}
150impl PyFieldDesc for u32 {
151    fn field_dtypes(field_name: &str) -> Vec<(String, String)> {
152        vec![(field_name.to_owned(), "u4".to_owned())]
153    }
154}
155impl PyFieldDesc for u16 {
156    fn field_dtypes(field_name: &str) -> Vec<(String, String)> {
157        vec![(field_name.to_owned(), "u2".to_owned())]
158    }
159}
160impl PyFieldDesc for u8 {
161    fn field_dtypes(field_name: &str) -> Vec<(String, String)> {
162        vec![(field_name.to_owned(), "u1".to_owned())]
163    }
164}
165impl<const N: usize> PyFieldDesc for [i8; N] {
166    fn field_dtypes(field_name: &str) -> Vec<(String, String)> {
167        vec![(field_name.to_owned(), format!("S{N}"))]
168    }
169}
170impl<const N: usize> PyFieldDesc for [u8; N] {
171    fn field_dtypes(field_name: &str) -> Vec<(String, String)> {
172        vec![(field_name.to_owned(), format!("S{N}"))]
173    }
174}
175impl PyFieldDesc for FlagSet {
176    fn field_dtypes(field_name: &str) -> Vec<(String, String)> {
177        u8::field_dtypes(field_name)
178    }
179}
180
181#[cfg(test)]
182mod tests {
183    use super::PyFieldDesc;
184    use crate::{
185        record::{Cmbp1Msg, InstrumentDefMsg, MboMsg, Mbp10Msg},
186        ASSET_CSTR_LEN, SYMBOL_CSTR_LEN,
187    };
188
189    fn with_record_header_dtype(dtypes: Vec<(String, String)>) -> Vec<(String, String)> {
190        let mut res = vec![
191            ("length".to_owned(), "u1".to_owned()),
192            ("rtype".to_owned(), "u1".to_owned()),
193            ("publisher_id".to_owned(), "u2".to_owned()),
194            ("instrument_id".to_owned(), "u4".to_owned()),
195            ("ts_event".to_owned(), "u8".to_owned()),
196        ];
197        res.extend(dtypes);
198        res
199    }
200
201    #[test]
202    fn test_mbo_dtypes() {
203        let dtypes = MboMsg::field_dtypes("");
204        let exp = with_record_header_dtype(vec![
205            ("order_id".to_owned(), "u8".to_owned()),
206            ("price".to_owned(), "i8".to_owned()),
207            ("size".to_owned(), "u4".to_owned()),
208            ("flags".to_owned(), "u1".to_owned()),
209            ("channel_id".to_owned(), "u1".to_owned()),
210            ("action".to_owned(), "S1".to_owned()),
211            ("side".to_owned(), "S1".to_owned()),
212            ("ts_recv".to_owned(), "u8".to_owned()),
213            ("ts_in_delta".to_owned(), "i4".to_owned()),
214            ("sequence".to_owned(), "u4".to_owned()),
215        ]);
216        assert_eq!(dtypes, exp);
217    }
218
219    #[test]
220    fn test_mbo_fields() {
221        assert_eq!(MboMsg::price_fields(""), vec!["price".to_owned()]);
222        assert_eq!(MboMsg::hidden_fields(""), vec!["length".to_owned()]);
223        assert_eq!(
224            MboMsg::timestamp_fields(""),
225            vec!["ts_event".to_owned(), "ts_recv".to_owned()]
226        );
227    }
228
229    #[test]
230    fn test_mbo_ordered() {
231        assert_eq!(
232            MboMsg::ordered_fields(""),
233            vec![
234                "ts_recv".to_owned(),
235                "ts_event".to_owned(),
236                "rtype".to_owned(),
237                "publisher_id".to_owned(),
238                "instrument_id".to_owned(),
239                "action".to_owned(),
240                "side".to_owned(),
241                "price".to_owned(),
242                "size".to_owned(),
243                "channel_id".to_owned(),
244                "order_id".to_owned(),
245                "flags".to_owned(),
246                "ts_in_delta".to_owned(),
247                "sequence".to_owned(),
248            ]
249        )
250    }
251
252    #[test]
253    fn test_mbp10_dtypes() {
254        let dtypes = Mbp10Msg::field_dtypes("");
255        let mut exp = with_record_header_dtype(vec![
256            ("price".to_owned(), "i8".to_owned()),
257            ("size".to_owned(), "u4".to_owned()),
258            ("action".to_owned(), "S1".to_owned()),
259            ("side".to_owned(), "S1".to_owned()),
260            ("flags".to_owned(), "u1".to_owned()),
261            ("depth".to_owned(), "u1".to_owned()),
262            ("ts_recv".to_owned(), "u8".to_owned()),
263            ("ts_in_delta".to_owned(), "i4".to_owned()),
264            ("sequence".to_owned(), "u4".to_owned()),
265        ]);
266        for i in 0..10 {
267            exp.push((format!("bid_px_{i:02}"), "i8".to_owned()));
268            exp.push((format!("ask_px_{i:02}"), "i8".to_owned()));
269            exp.push((format!("bid_sz_{i:02}"), "u4".to_owned()));
270            exp.push((format!("ask_sz_{i:02}"), "u4".to_owned()));
271            exp.push((format!("bid_ct_{i:02}"), "u4".to_owned()));
272            exp.push((format!("ask_ct_{i:02}"), "u4".to_owned()));
273        }
274        assert_eq!(dtypes, exp);
275    }
276
277    #[test]
278    fn test_mbp10_fields() {
279        let mut exp_price = vec!["price".to_owned()];
280        for i in 0..10 {
281            exp_price.push(format!("bid_px_{i:02}"));
282            exp_price.push(format!("ask_px_{i:02}"));
283        }
284        assert_eq!(Mbp10Msg::price_fields(""), exp_price);
285        assert_eq!(Mbp10Msg::hidden_fields(""), vec!["length".to_owned()]);
286        assert_eq!(
287            Mbp10Msg::timestamp_fields(""),
288            vec!["ts_event".to_owned(), "ts_recv".to_owned()]
289        );
290    }
291
292    #[test]
293    fn test_mbp10_ordered() {
294        let mut exp = vec![
295            "ts_recv".to_owned(),
296            "ts_event".to_owned(),
297            "rtype".to_owned(),
298            "publisher_id".to_owned(),
299            "instrument_id".to_owned(),
300            "action".to_owned(),
301            "side".to_owned(),
302            "depth".to_owned(),
303            "price".to_owned(),
304            "size".to_owned(),
305            "flags".to_owned(),
306            "ts_in_delta".to_owned(),
307            "sequence".to_owned(),
308        ];
309        for i in 0..10 {
310            exp.push(format!("bid_px_{i:02}"));
311            exp.push(format!("ask_px_{i:02}"));
312            exp.push(format!("bid_sz_{i:02}"));
313            exp.push(format!("ask_sz_{i:02}"));
314            exp.push(format!("bid_ct_{i:02}"));
315            exp.push(format!("ask_ct_{i:02}"));
316        }
317        assert_eq!(Mbp10Msg::ordered_fields(""), exp)
318    }
319
320    #[test]
321    fn test_cmbp1_dtypes() {
322        let dtypes = Cmbp1Msg::field_dtypes("");
323        let exp = with_record_header_dtype(vec![
324            ("price".to_owned(), "i8".to_owned()),
325            ("size".to_owned(), "u4".to_owned()),
326            ("action".to_owned(), "S1".to_owned()),
327            ("side".to_owned(), "S1".to_owned()),
328            ("flags".to_owned(), "u1".to_owned()),
329            ("_reserved1".to_owned(), "S1".to_owned()),
330            ("ts_recv".to_owned(), "u8".to_owned()),
331            ("ts_in_delta".to_owned(), "i4".to_owned()),
332            ("_reserved2".to_owned(), "S4".to_owned()),
333            ("bid_px_00".to_owned(), "i8".to_owned()),
334            ("ask_px_00".to_owned(), "i8".to_owned()),
335            ("bid_sz_00".to_owned(), "u4".to_owned()),
336            ("ask_sz_00".to_owned(), "u4".to_owned()),
337            ("bid_pb_00".to_owned(), "u2".to_owned()),
338            ("_reserved1_00".to_owned(), "S2".to_owned()),
339            ("ask_pb_00".to_owned(), "u2".to_owned()),
340            ("_reserved2_00".to_owned(), "S2".to_owned()),
341        ]);
342        assert_eq!(dtypes, exp);
343    }
344
345    #[test]
346    fn test_cbbo_fields() {
347        let mut exp_price = vec!["price".to_owned()];
348        exp_price.push("bid_px_00".to_owned());
349        exp_price.push("ask_px_00".to_owned());
350        assert_eq!(Cmbp1Msg::price_fields(""), exp_price);
351        assert_eq!(
352            Cmbp1Msg::hidden_fields(""),
353            vec![
354                "length".to_owned(),
355                "_reserved1".to_owned(),
356                "_reserved2".to_owned(),
357                "_reserved1_00".to_owned(),
358                "_reserved2_00".to_owned()
359            ]
360        );
361        assert_eq!(
362            Cmbp1Msg::timestamp_fields(""),
363            vec!["ts_event".to_owned(), "ts_recv".to_owned()]
364        );
365    }
366
367    #[test]
368    fn test_cbbo_ordered() {
369        let exp = vec![
370            "ts_recv".to_owned(),
371            "ts_event".to_owned(),
372            "rtype".to_owned(),
373            "publisher_id".to_owned(),
374            "instrument_id".to_owned(),
375            "action".to_owned(),
376            "side".to_owned(),
377            "price".to_owned(),
378            "size".to_owned(),
379            "flags".to_owned(),
380            "ts_in_delta".to_owned(),
381            "bid_px_00".to_owned(),
382            "ask_px_00".to_owned(),
383            "bid_sz_00".to_owned(),
384            "ask_sz_00".to_owned(),
385            "bid_pb_00".to_owned(),
386            "ask_pb_00".to_owned(),
387        ];
388        assert_eq!(Cmbp1Msg::ordered_fields(""), exp)
389    }
390
391    #[test]
392    fn test_definition_dtypes() {
393        let dtypes = InstrumentDefMsg::field_dtypes("");
394        let exp = with_record_header_dtype(vec![
395            ("ts_recv".to_owned(), "u8".to_owned()),
396            ("min_price_increment".to_owned(), "i8".to_owned()),
397            ("display_factor".to_owned(), "i8".to_owned()),
398            ("expiration".to_owned(), "u8".to_owned()),
399            ("activation".to_owned(), "u8".to_owned()),
400            ("high_limit_price".to_owned(), "i8".to_owned()),
401            ("low_limit_price".to_owned(), "i8".to_owned()),
402            ("max_price_variation".to_owned(), "i8".to_owned()),
403            ("unit_of_measure_qty".to_owned(), "i8".to_owned()),
404            ("min_price_increment_amount".to_owned(), "i8".to_owned()),
405            ("price_ratio".to_owned(), "i8".to_owned()),
406            ("strike_price".to_owned(), "i8".to_owned()),
407            ("raw_instrument_id".to_owned(), "u8".to_owned()),
408            ("leg_price".to_owned(), "i8".to_owned()),
409            ("leg_delta".to_owned(), "i8".to_owned()),
410            ("inst_attrib_value".to_owned(), "i4".to_owned()),
411            ("underlying_id".to_owned(), "u4".to_owned()),
412            ("market_depth_implied".to_owned(), "i4".to_owned()),
413            ("market_depth".to_owned(), "i4".to_owned()),
414            ("market_segment_id".to_owned(), "u4".to_owned()),
415            ("max_trade_vol".to_owned(), "u4".to_owned()),
416            ("min_lot_size".to_owned(), "i4".to_owned()),
417            ("min_lot_size_block".to_owned(), "i4".to_owned()),
418            ("min_lot_size_round_lot".to_owned(), "i4".to_owned()),
419            ("min_trade_vol".to_owned(), "u4".to_owned()),
420            ("contract_multiplier".to_owned(), "i4".to_owned()),
421            ("decay_quantity".to_owned(), "i4".to_owned()),
422            ("original_contract_size".to_owned(), "i4".to_owned()),
423            ("leg_instrument_id".to_owned(), "u4".to_owned()),
424            ("leg_ratio_price_numerator".to_owned(), "i4".to_owned()),
425            ("leg_ratio_price_denominator".to_owned(), "i4".to_owned()),
426            ("leg_ratio_qty_numerator".to_owned(), "i4".to_owned()),
427            ("leg_ratio_qty_denominator".to_owned(), "i4".to_owned()),
428            ("leg_underlying_id".to_owned(), "u4".to_owned()),
429            ("appl_id".to_owned(), "i2".to_owned()),
430            ("maturity_year".to_owned(), "u2".to_owned()),
431            ("decay_start_date".to_owned(), "u2".to_owned()),
432            ("channel_id".to_owned(), "u2".to_owned()),
433            ("leg_count".to_owned(), "u2".to_owned()),
434            ("leg_index".to_owned(), "u2".to_owned()),
435            ("currency".to_owned(), "S4".to_owned()),
436            ("settl_currency".to_owned(), "S4".to_owned()),
437            ("secsubtype".to_owned(), "S6".to_owned()),
438            ("raw_symbol".to_owned(), format!("S{SYMBOL_CSTR_LEN}")),
439            ("group".to_owned(), "S21".to_owned()),
440            ("exchange".to_owned(), "S5".to_owned()),
441            ("asset".to_owned(), format!("S{ASSET_CSTR_LEN}")),
442            ("cfi".to_owned(), "S7".to_owned()),
443            ("security_type".to_owned(), "S7".to_owned()),
444            ("unit_of_measure".to_owned(), "S31".to_owned()),
445            ("underlying".to_owned(), "S21".to_owned()),
446            ("strike_price_currency".to_owned(), "S4".to_owned()),
447            ("leg_raw_symbol".to_owned(), format!("S{SYMBOL_CSTR_LEN}")),
448            ("instrument_class".to_owned(), "S1".to_owned()),
449            ("match_algorithm".to_owned(), "S1".to_owned()),
450            ("main_fraction".to_owned(), "u1".to_owned()),
451            ("price_display_format".to_owned(), "u1".to_owned()),
452            ("sub_fraction".to_owned(), "u1".to_owned()),
453            ("underlying_product".to_owned(), "u1".to_owned()),
454            ("security_update_action".to_owned(), "S1".to_owned()),
455            ("maturity_month".to_owned(), "u1".to_owned()),
456            ("maturity_day".to_owned(), "u1".to_owned()),
457            ("maturity_week".to_owned(), "u1".to_owned()),
458            ("user_defined_instrument".to_owned(), "S1".to_owned()),
459            ("contract_multiplier_unit".to_owned(), "i1".to_owned()),
460            ("flow_schedule_type".to_owned(), "i1".to_owned()),
461            ("tick_rule".to_owned(), "u1".to_owned()),
462            ("leg_instrument_class".to_owned(), "S1".to_owned()),
463            ("leg_side".to_owned(), "S1".to_owned()),
464            ("_reserved".to_owned(), "S17".to_owned()),
465        ]);
466        assert_eq!(dtypes, exp);
467    }
468
469    #[test]
470    fn test_definition_fields() {
471        assert_eq!(
472            InstrumentDefMsg::price_fields(""),
473            vec![
474                "min_price_increment".to_owned(),
475                "display_factor".to_owned(),
476                "high_limit_price".to_owned(),
477                "low_limit_price".to_owned(),
478                "max_price_variation".to_owned(),
479                "unit_of_measure_qty".to_owned(),
480                "min_price_increment_amount".to_owned(),
481                "price_ratio".to_owned(),
482                "strike_price".to_owned(),
483                "leg_price".to_owned(),
484                "leg_delta".to_owned(),
485            ]
486        );
487        assert_eq!(
488            InstrumentDefMsg::hidden_fields(""),
489            vec!["length".to_owned(), "_reserved".to_owned(),]
490        );
491        assert_eq!(
492            InstrumentDefMsg::timestamp_fields(""),
493            vec![
494                "ts_event".to_owned(),
495                "ts_recv".to_owned(),
496                "expiration".to_owned(),
497                "activation".to_owned()
498            ]
499        );
500    }
501
502    #[test]
503    fn test_definition_ordered() {
504        assert_eq!(
505            InstrumentDefMsg::ordered_fields(""),
506            vec![
507                "ts_recv".to_owned(),
508                "ts_event".to_owned(),
509                "rtype".to_owned(),
510                "publisher_id".to_owned(),
511                "instrument_id".to_owned(),
512                "raw_symbol".to_owned(),
513                "security_update_action".to_owned(),
514                "instrument_class".to_owned(),
515                "min_price_increment".to_owned(),
516                "display_factor".to_owned(),
517                "expiration".to_owned(),
518                "activation".to_owned(),
519                "high_limit_price".to_owned(),
520                "low_limit_price".to_owned(),
521                "max_price_variation".to_owned(),
522                "unit_of_measure_qty".to_owned(),
523                "min_price_increment_amount".to_owned(),
524                "price_ratio".to_owned(),
525                "inst_attrib_value".to_owned(),
526                "underlying_id".to_owned(),
527                "raw_instrument_id".to_owned(),
528                "market_depth_implied".to_owned(),
529                "market_depth".to_owned(),
530                "market_segment_id".to_owned(),
531                "max_trade_vol".to_owned(),
532                "min_lot_size".to_owned(),
533                "min_lot_size_block".to_owned(),
534                "min_lot_size_round_lot".to_owned(),
535                "min_trade_vol".to_owned(),
536                "contract_multiplier".to_owned(),
537                "decay_quantity".to_owned(),
538                "original_contract_size".to_owned(),
539                "appl_id".to_owned(),
540                "maturity_year".to_owned(),
541                "decay_start_date".to_owned(),
542                "channel_id".to_owned(),
543                "currency".to_owned(),
544                "settl_currency".to_owned(),
545                "secsubtype".to_owned(),
546                "group".to_owned(),
547                "exchange".to_owned(),
548                "asset".to_owned(),
549                "cfi".to_owned(),
550                "security_type".to_owned(),
551                "unit_of_measure".to_owned(),
552                "underlying".to_owned(),
553                "strike_price_currency".to_owned(),
554                "strike_price".to_owned(),
555                "match_algorithm".to_owned(),
556                "main_fraction".to_owned(),
557                "price_display_format".to_owned(),
558                "sub_fraction".to_owned(),
559                "underlying_product".to_owned(),
560                "maturity_month".to_owned(),
561                "maturity_day".to_owned(),
562                "maturity_week".to_owned(),
563                "user_defined_instrument".to_owned(),
564                "contract_multiplier_unit".to_owned(),
565                "flow_schedule_type".to_owned(),
566                "tick_rule".to_owned(),
567                "leg_count".to_owned(),
568                "leg_index".to_owned(),
569                "leg_instrument_id".to_owned(),
570                "leg_raw_symbol".to_owned(),
571                "leg_instrument_class".to_owned(),
572                "leg_side".to_owned(),
573                "leg_price".to_owned(),
574                "leg_delta".to_owned(),
575                "leg_ratio_price_numerator".to_owned(),
576                "leg_ratio_price_denominator".to_owned(),
577                "leg_ratio_qty_numerator".to_owned(),
578                "leg_ratio_qty_denominator".to_owned(),
579                "leg_underlying_id".to_owned(),
580            ]
581        )
582    }
583}