1use std::{collections::HashMap, io, num::NonZeroU64};
2
3use pyo3::{
4 intern,
5 prelude::*,
6 types::{PyBytes, PyDate, PyDict, PyType},
7 Bound,
8};
9
10use crate::{
11 decode::{DbnMetadata, DynDecoder},
12 encode::dbn::MetadataEncoder,
13 enums::{SType, Schema},
14 MappingInterval, Metadata, SymbolMapping, VersionUpgradePolicy,
15};
16
17use super::{py_to_time_date, to_py_err};
18
19#[pymethods]
20impl Metadata {
21 #[new]
22 #[pyo3(signature = (
23 dataset,
24 start,
25 stype_in,
26 stype_out,
27 schema,
28 symbols=None,
29 partial=None,
30 not_found=None,
31 mappings=None,
32 end=None,
33 limit=None,
34 ts_out=None,
35 version=crate::DBN_VERSION,
36 ))]
37 fn py_new(
38 dataset: String,
39 start: u64,
40 stype_in: Option<SType>,
41 stype_out: SType,
42 schema: Option<Schema>,
43 symbols: Option<Vec<String>>,
44 partial: Option<Vec<String>>,
45 not_found: Option<Vec<String>>,
46 mappings: Option<Vec<SymbolMapping>>,
47 end: Option<u64>,
48 limit: Option<u64>,
49 ts_out: Option<bool>,
50 version: u8,
51 ) -> Metadata {
52 Metadata::builder()
53 .dataset(dataset)
54 .start(start)
55 .stype_out(stype_out)
56 .symbols(symbols.unwrap_or_default())
57 .partial(partial.unwrap_or_default())
58 .not_found(not_found.unwrap_or_default())
59 .mappings(mappings.unwrap_or_default())
60 .schema(schema)
61 .stype_in(stype_in)
62 .end(NonZeroU64::new(end.unwrap_or_default()))
63 .limit(NonZeroU64::new(limit.unwrap_or_default()))
64 .ts_out(ts_out.unwrap_or_default())
65 .version(version)
66 .build()
67 }
68
69 fn __repr__(&self) -> String {
70 format!("{self:?}")
71 }
72
73 fn __bytes__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
75 self.py_encode(py)
76 }
77
78 #[getter]
79 fn get_mappings<'py>(&self, py: Python<'py>) -> PyResult<HashMap<String, Bound<'py, PyAny>>> {
80 let mut res = HashMap::new();
81 for mapping in self.mappings.iter() {
82 res.insert(
83 mapping.raw_symbol.clone(),
84 mapping.intervals.into_pyobject(py)?,
85 );
86 }
87 Ok(res)
88 }
89
90 #[pyo3(name = "decode", signature = (data, upgrade_policy = VersionUpgradePolicy::default()))]
91 #[classmethod]
92 fn py_decode(
93 _cls: &Bound<PyType>,
94 data: &Bound<PyBytes>,
95 upgrade_policy: VersionUpgradePolicy,
96 ) -> PyResult<Metadata> {
97 let reader = io::BufReader::new(data.as_bytes());
98 let mut metadata = DynDecoder::inferred_with_buffer(reader, upgrade_policy)?
99 .metadata()
100 .clone();
101 metadata.upgrade(upgrade_policy);
102 Ok(metadata)
103 }
104
105 #[pyo3(name = "encode")]
106 fn py_encode<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
107 let mut buffer = Vec::new();
108 let mut encoder = MetadataEncoder::new(&mut buffer);
109 encoder.encode(self)?;
110 Ok(PyBytes::new(py, buffer.as_slice()))
111 }
112}
113
114impl<'py> IntoPyObject<'py> for SymbolMapping {
115 type Target = PyDict;
116 type Output = Bound<'py, PyDict>;
117 type Error = PyErr;
118
119 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
120 let dict = PyDict::new(py);
121 dict.set_item(intern!(py, "raw_symbol"), &self.raw_symbol)?;
122 dict.set_item(intern!(py, "intervals"), &self.intervals)?;
123 Ok(dict)
124 }
125}
126
127impl<'a, 'py> FromPyObject<'a, 'py> for MappingInterval {
128 type Error = PyErr;
129
130 fn extract(ob: Borrowed<'a, 'py, PyAny>) -> Result<Self, Self::Error> {
131 let start_date = ob
132 .getattr(intern!(ob.py(), "start_date"))
133 .map_err(|_| to_py_err("Missing start_date".to_owned()))
134 .and_then(extract_date)?;
135 let end_date = ob
136 .getattr(intern!(ob.py(), "end_date"))
137 .map_err(|_| to_py_err("Missing end_date".to_owned()))
138 .and_then(extract_date)?;
139 let symbol = ob
140 .getattr(intern!(ob.py(), "symbol"))
141 .map_err(|_| to_py_err("Missing symbol".to_owned()))
142 .and_then(|d| d.extract::<String>())?;
143 Ok(Self {
144 start_date,
145 end_date,
146 symbol,
147 })
148 }
149}
150
151impl<'py> IntoPyObject<'py> for &MappingInterval {
152 type Target = PyDict;
153 type Output = Bound<'py, PyDict>;
154 type Error = PyErr;
155
156 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
157 let dict = PyDict::new(py);
158 dict.set_item(
159 intern!(py, "start_date"),
160 PyDate::new(
161 py,
162 self.start_date.year(),
163 self.start_date.month() as u8,
164 self.start_date.day(),
165 )?,
166 )?;
167 dict.set_item(
168 intern!(py, "end_date"),
169 PyDate::new(
170 py,
171 self.end_date.year(),
172 self.end_date.month() as u8,
173 self.end_date.day(),
174 )?,
175 )?;
176 dict.set_item(intern!(py, "symbol"), &self.symbol)?;
177 Ok(dict)
178 }
179}
180
181fn extract_date(any: Bound<'_, PyAny>) -> PyResult<time::Date> {
182 let py_date = any.cast::<PyDate>().map_err(PyErr::from)?;
183 py_to_time_date(py_date)
184}