1use std::{collections::HashMap, io, num::NonZeroU64};
2
3use pyo3::{
4 intern,
5 prelude::*,
6 types::{PyBytes, PyDate, PyDict, PyType},
7 Bound,
8};
9
10use crate::{
11 decode::{DbnMetadata, DynDecoder},
12 encode::dbn::MetadataEncoder,
13 enums::{SType, Schema},
14 MappingInterval, Metadata, SymbolMapping, VersionUpgradePolicy,
15};
16
17use super::{py_to_time_date, to_py_err};
18
19#[pymethods]
20impl Metadata {
21 #[new]
22 #[pyo3(signature = (
23 dataset,
24 start,
25 stype_in,
26 stype_out,
27 schema,
28 symbols=None,
29 partial=None,
30 not_found=None,
31 mappings=None,
32 end=None,
33 limit=None,
34 ts_out=None,
35 version=crate::DBN_VERSION,
36 ))]
37 fn py_new(
38 dataset: String,
39 start: u64,
40 stype_in: Option<SType>,
41 stype_out: SType,
42 schema: Option<Schema>,
43 symbols: Option<Vec<String>>,
44 partial: Option<Vec<String>>,
45 not_found: Option<Vec<String>>,
46 mappings: Option<Vec<SymbolMapping>>,
47 end: Option<u64>,
48 limit: Option<u64>,
49 ts_out: Option<bool>,
50 version: u8,
51 ) -> Metadata {
52 Metadata::builder()
53 .dataset(dataset)
54 .start(start)
55 .stype_out(stype_out)
56 .symbols(symbols.unwrap_or_default())
57 .partial(partial.unwrap_or_default())
58 .not_found(not_found.unwrap_or_default())
59 .mappings(mappings.unwrap_or_default())
60 .schema(schema)
61 .stype_in(stype_in)
62 .end(NonZeroU64::new(end.unwrap_or_default()))
63 .limit(NonZeroU64::new(limit.unwrap_or_default()))
64 .ts_out(ts_out.unwrap_or_default())
65 .version(version)
66 .build()
67 }
68
69 fn __repr__(&self) -> String {
70 format!("{self:?}")
71 }
72
73 fn __bytes__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
75 self.py_encode(py)
76 }
77
78 #[getter]
79 fn get_mappings<'py>(&self, py: Python<'py>) -> PyResult<HashMap<String, Bound<'py, PyAny>>> {
80 let mut res = HashMap::new();
81 for mapping in self.mappings.iter() {
82 res.insert(
83 mapping.raw_symbol.clone(),
84 mapping.intervals.into_pyobject(py)?,
85 );
86 }
87 Ok(res)
88 }
89
90 #[pyo3(name = "decode", signature = (data, upgrade_policy = VersionUpgradePolicy::default()))]
91 #[classmethod]
92 fn py_decode(
93 _cls: &Bound<PyType>,
94 data: &Bound<PyBytes>,
95 upgrade_policy: VersionUpgradePolicy,
96 ) -> PyResult<Metadata> {
97 let reader = io::BufReader::new(data.as_bytes());
98 let mut metadata = DynDecoder::inferred_with_buffer(reader, upgrade_policy)?
99 .metadata()
100 .clone();
101 metadata.upgrade(upgrade_policy);
102 Ok(metadata)
103 }
104
105 #[pyo3(name = "encode")]
106 fn py_encode<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
107 let mut buffer = Vec::new();
108 let mut encoder = MetadataEncoder::new(&mut buffer);
109 encoder.encode(self)?;
110 Ok(PyBytes::new(py, buffer.as_slice()))
111 }
112}
113
114impl<'py> IntoPyObject<'py> for SymbolMapping {
115 type Target = PyDict;
116 type Output = Bound<'py, PyDict>;
117 type Error = PyErr;
118
119 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
120 let dict = PyDict::new(py);
121 dict.set_item(intern!(py, "raw_symbol"), &self.raw_symbol)?;
122 dict.set_item(intern!(py, "intervals"), &self.intervals)?;
123 Ok(dict)
124 }
125}
126
127impl<'py> FromPyObject<'py> for MappingInterval {
128 fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
129 let start_date = ob
130 .getattr(intern!(ob.py(), "start_date"))
131 .map_err(|_| to_py_err("Missing start_date".to_owned()))
132 .and_then(extract_date)?;
133 let end_date = ob
134 .getattr(intern!(ob.py(), "end_date"))
135 .map_err(|_| to_py_err("Missing end_date".to_owned()))
136 .and_then(extract_date)?;
137 let symbol = ob
138 .getattr(intern!(ob.py(), "symbol"))
139 .map_err(|_| to_py_err("Missing symbol".to_owned()))
140 .and_then(|d| d.extract::<String>())?;
141 Ok(Self {
142 start_date,
143 end_date,
144 symbol,
145 })
146 }
147}
148
149impl<'py> IntoPyObject<'py> for &MappingInterval {
150 type Target = PyDict;
151 type Output = Bound<'py, PyDict>;
152 type Error = PyErr;
153
154 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
155 let dict = PyDict::new(py);
156 dict.set_item(
157 intern!(py, "start_date"),
158 PyDate::new(
159 py,
160 self.start_date.year(),
161 self.start_date.month() as u8,
162 self.start_date.day(),
163 )?,
164 )?;
165 dict.set_item(
166 intern!(py, "end_date"),
167 PyDate::new(
168 py,
169 self.end_date.year(),
170 self.end_date.month() as u8,
171 self.end_date.day(),
172 )?,
173 )?;
174 dict.set_item(intern!(py, "symbol"), &self.symbol)?;
175 Ok(dict)
176 }
177}
178
179fn extract_date(any: Bound<'_, PyAny>) -> PyResult<time::Date> {
180 let py_date = any.downcast::<PyDate>().map_err(PyErr::from)?;
181 py_to_time_date(py_date)
182}