dbn/python/
metadata.rs

1use std::{collections::HashMap, io, num::NonZeroU64};
2
3use pyo3::{
4    intern,
5    prelude::*,
6    types::{PyBytes, PyDate, PyDict, PyType},
7    Bound,
8};
9
10use crate::{
11    decode::{DbnMetadata, DynDecoder},
12    encode::dbn::MetadataEncoder,
13    enums::{SType, Schema},
14    MappingInterval, Metadata, SymbolMapping, VersionUpgradePolicy,
15};
16
17use super::{py_to_time_date, to_py_err};
18
19#[pymethods]
20impl Metadata {
21    #[new]
22    #[pyo3(signature = (
23        dataset,
24        start,
25        stype_in,
26        stype_out,
27        schema,
28        symbols=None,
29        partial=None,
30        not_found=None,
31        mappings=None,
32        end=None,
33        limit=None,
34        ts_out=None,
35        version=crate::DBN_VERSION,
36    ))]
37    fn py_new(
38        dataset: String,
39        start: u64,
40        stype_in: Option<SType>,
41        stype_out: SType,
42        schema: Option<Schema>,
43        symbols: Option<Vec<String>>,
44        partial: Option<Vec<String>>,
45        not_found: Option<Vec<String>>,
46        mappings: Option<Vec<SymbolMapping>>,
47        end: Option<u64>,
48        limit: Option<u64>,
49        ts_out: Option<bool>,
50        version: u8,
51    ) -> Metadata {
52        Metadata::builder()
53            .dataset(dataset)
54            .start(start)
55            .stype_out(stype_out)
56            .symbols(symbols.unwrap_or_default())
57            .partial(partial.unwrap_or_default())
58            .not_found(not_found.unwrap_or_default())
59            .mappings(mappings.unwrap_or_default())
60            .schema(schema)
61            .stype_in(stype_in)
62            .end(NonZeroU64::new(end.unwrap_or_default()))
63            .limit(NonZeroU64::new(limit.unwrap_or_default()))
64            .ts_out(ts_out.unwrap_or_default())
65            .version(version)
66            .build()
67    }
68
69    fn __repr__(&self) -> String {
70        format!("{self:?}")
71    }
72
73    /// Encodes Metadata back into DBN format.
74    fn __bytes__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
75        self.py_encode(py)
76    }
77
78    #[getter]
79    fn get_mappings<'py>(&self, py: Python<'py>) -> PyResult<HashMap<String, Bound<'py, PyAny>>> {
80        let mut res = HashMap::new();
81        for mapping in self.mappings.iter() {
82            res.insert(
83                mapping.raw_symbol.clone(),
84                mapping.intervals.into_pyobject(py)?,
85            );
86        }
87        Ok(res)
88    }
89
90    #[pyo3(name = "decode", signature = (data, upgrade_policy = VersionUpgradePolicy::default()))]
91    #[classmethod]
92    fn py_decode(
93        _cls: &Bound<PyType>,
94        data: &Bound<PyBytes>,
95        upgrade_policy: VersionUpgradePolicy,
96    ) -> PyResult<Metadata> {
97        let reader = io::BufReader::new(data.as_bytes());
98        let mut metadata = DynDecoder::inferred_with_buffer(reader, upgrade_policy)?
99            .metadata()
100            .clone();
101        metadata.upgrade(upgrade_policy);
102        Ok(metadata)
103    }
104
105    #[pyo3(name = "encode")]
106    fn py_encode<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
107        let mut buffer = Vec::new();
108        let mut encoder = MetadataEncoder::new(&mut buffer);
109        encoder.encode(self)?;
110        Ok(PyBytes::new(py, buffer.as_slice()))
111    }
112}
113
114impl<'py> IntoPyObject<'py> for SymbolMapping {
115    type Target = PyDict;
116    type Output = Bound<'py, PyDict>;
117    type Error = PyErr;
118
119    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
120        let dict = PyDict::new(py);
121        dict.set_item(intern!(py, "raw_symbol"), &self.raw_symbol)?;
122        dict.set_item(intern!(py, "intervals"), &self.intervals)?;
123        Ok(dict)
124    }
125}
126
127impl<'py> FromPyObject<'py> for MappingInterval {
128    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
129        let start_date = ob
130            .getattr(intern!(ob.py(), "start_date"))
131            .map_err(|_| to_py_err("Missing start_date".to_owned()))
132            .and_then(extract_date)?;
133        let end_date = ob
134            .getattr(intern!(ob.py(), "end_date"))
135            .map_err(|_| to_py_err("Missing end_date".to_owned()))
136            .and_then(extract_date)?;
137        let symbol = ob
138            .getattr(intern!(ob.py(), "symbol"))
139            .map_err(|_| to_py_err("Missing symbol".to_owned()))
140            .and_then(|d| d.extract::<String>())?;
141        Ok(Self {
142            start_date,
143            end_date,
144            symbol,
145        })
146    }
147}
148
149impl<'py> IntoPyObject<'py> for &MappingInterval {
150    type Target = PyDict;
151    type Output = Bound<'py, PyDict>;
152    type Error = PyErr;
153
154    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
155        let dict = PyDict::new(py);
156        dict.set_item(
157            intern!(py, "start_date"),
158            PyDate::new(
159                py,
160                self.start_date.year(),
161                self.start_date.month() as u8,
162                self.start_date.day(),
163            )?,
164        )?;
165        dict.set_item(
166            intern!(py, "end_date"),
167            PyDate::new(
168                py,
169                self.end_date.year(),
170                self.end_date.month() as u8,
171                self.end_date.day(),
172            )?,
173        )?;
174        dict.set_item(intern!(py, "symbol"), &self.symbol)?;
175        Ok(dict)
176    }
177}
178
179fn extract_date(any: Bound<'_, PyAny>) -> PyResult<time::Date> {
180    let py_date = any.downcast::<PyDate>().map_err(PyErr::from)?;
181    py_to_time_date(py_date)
182}