1use std::{collections::HashMap, str::FromStr, sync::Arc};
19
20use arrow::{
21 array::{
22 Array, BinaryArray, BinaryBuilder, StringArray, StringBuilder, UInt8Array, UInt64Array,
23 },
24 datatypes::{DataType, Field, Schema},
25 error::ArrowError,
26 record_batch::RecordBatch,
27};
28use nautilus_core::Params;
29use nautilus_model::{
30 identifiers::{InstrumentId, Symbol},
31 instruments::equity::Equity,
32 types::{price::Price, quantity::Quantity},
33};
34#[allow(unused)]
35use rust_decimal::Decimal;
36#[allow(unused)]
37use serde_json::Value;
38use ustr::Ustr;
39
40use crate::arrow::{
41 ArrowSchemaProvider, EncodeToRecordBatch, EncodingError, KEY_INSTRUMENT_ID,
42 KEY_PRICE_PRECISION, extract_column, extract_column_by_name_or_index,
43 extract_optional_string_column_by_name, optional_ustr_value,
44};
45
46impl ArrowSchemaProvider for Equity {
47 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
48 let fields = vec![
49 Field::new("id", DataType::Utf8, false),
50 Field::new("raw_symbol", DataType::Utf8, false),
51 Field::new("currency", DataType::Utf8, false),
52 Field::new("price_precision", DataType::UInt8, false),
53 Field::new("price_increment", DataType::Utf8, false),
54 Field::new("lot_size", DataType::Utf8, true), Field::new("isin", DataType::Utf8, true), Field::new("max_quantity", DataType::Utf8, true), Field::new("min_quantity", DataType::Utf8, true), Field::new("max_price", DataType::Utf8, true), Field::new("min_price", DataType::Utf8, true), Field::new("margin_init", DataType::Utf8, false),
61 Field::new("margin_maint", DataType::Utf8, false),
62 Field::new("maker_fee", DataType::Utf8, false),
63 Field::new("taker_fee", DataType::Utf8, false),
64 Field::new("tick_scheme", DataType::Utf8, true),
65 Field::new("info", DataType::Binary, true), Field::new("ts_event", DataType::UInt64, false),
67 Field::new("ts_init", DataType::UInt64, false),
68 ];
69
70 let mut final_metadata = HashMap::new();
71 final_metadata.insert("class".to_string(), "Equity".to_string());
72
73 if let Some(meta) = metadata {
74 final_metadata.extend(meta);
75 }
76
77 Schema::new_with_metadata(fields, final_metadata)
78 }
79}
80
81impl EncodeToRecordBatch for Equity {
82 fn encode_batch(
83 #[allow(unused)] metadata: &HashMap<String, String>,
84 data: &[Self],
85 ) -> Result<RecordBatch, ArrowError> {
86 let mut id_builder = StringBuilder::new();
87 let mut raw_symbol_builder = StringBuilder::new();
88 let mut currency_builder = StringBuilder::new();
89 let mut price_precision_builder = UInt8Array::builder(data.len());
90 let mut price_increment_builder = StringBuilder::new();
91 let mut lot_size_builder = StringBuilder::new();
92 let mut isin_builder = StringBuilder::new();
93 let mut max_quantity_builder = StringBuilder::new();
94 let mut min_quantity_builder = StringBuilder::new();
95 let mut max_price_builder = StringBuilder::new();
96 let mut min_price_builder = StringBuilder::new();
97 let mut margin_init_builder = StringBuilder::new();
98 let mut margin_maint_builder = StringBuilder::new();
99 let mut maker_fee_builder = StringBuilder::new();
100 let mut taker_fee_builder = StringBuilder::new();
101 let mut tick_scheme_builder = StringBuilder::new();
102 let mut info_builder = BinaryBuilder::new();
103 let mut ts_event_builder = UInt64Array::builder(data.len());
104 let mut ts_init_builder = UInt64Array::builder(data.len());
105
106 for equity in data {
107 id_builder.append_value(equity.id.to_string());
108 raw_symbol_builder.append_value(equity.raw_symbol);
109 currency_builder.append_value(equity.currency.to_string());
110 price_precision_builder.append_value(equity.price_precision);
111 price_increment_builder.append_value(equity.price_increment.to_string());
112
113 if let Some(lot_size) = equity.lot_size {
114 lot_size_builder.append_value(lot_size.to_string());
115 } else {
116 lot_size_builder.append_null();
117 }
118
119 if let Some(isin) = equity.isin {
120 isin_builder.append_value(isin);
121 } else {
122 isin_builder.append_null();
123 }
124
125 if let Some(max_qty) = equity.max_quantity {
126 max_quantity_builder.append_value(max_qty.to_string());
127 } else {
128 max_quantity_builder.append_null();
129 }
130
131 if let Some(min_qty) = equity.min_quantity {
132 min_quantity_builder.append_value(min_qty.to_string());
133 } else {
134 min_quantity_builder.append_null();
135 }
136
137 if let Some(max_p) = equity.max_price {
138 max_price_builder.append_value(max_p.to_string());
139 } else {
140 max_price_builder.append_null();
141 }
142
143 if let Some(min_p) = equity.min_price {
144 min_price_builder.append_value(min_p.to_string());
145 } else {
146 min_price_builder.append_null();
147 }
148
149 margin_init_builder.append_value(equity.margin_init.to_string());
150 margin_maint_builder.append_value(equity.margin_maint.to_string());
151 maker_fee_builder.append_value(equity.maker_fee.to_string());
152 taker_fee_builder.append_value(equity.taker_fee.to_string());
153
154 if let Some(tick_scheme) = equity.tick_scheme {
155 tick_scheme_builder.append_value(tick_scheme);
156 } else {
157 tick_scheme_builder.append_null();
158 }
159
160 if let Some(ref info) = equity.info {
162 match serde_json::to_vec(info) {
163 Ok(json_bytes) => {
164 info_builder.append_value(json_bytes);
165 }
166 Err(e) => {
167 return Err(ArrowError::InvalidArgumentError(format!(
168 "Failed to serialize info dict to JSON: {e}"
169 )));
170 }
171 }
172 } else {
173 info_builder.append_null();
174 }
175
176 ts_event_builder.append_value(equity.ts_event.as_u64());
177 ts_init_builder.append_value(equity.ts_init.as_u64());
178 }
179
180 let mut final_metadata = metadata.clone();
181 final_metadata.insert("class".to_string(), "Equity".to_string());
182
183 RecordBatch::try_new(
184 Self::get_schema(Some(final_metadata)).into(),
185 vec![
186 Arc::new(id_builder.finish()),
187 Arc::new(raw_symbol_builder.finish()),
188 Arc::new(currency_builder.finish()),
189 Arc::new(price_precision_builder.finish()),
190 Arc::new(price_increment_builder.finish()),
191 Arc::new(lot_size_builder.finish()),
192 Arc::new(isin_builder.finish()),
193 Arc::new(max_quantity_builder.finish()),
194 Arc::new(min_quantity_builder.finish()),
195 Arc::new(max_price_builder.finish()),
196 Arc::new(min_price_builder.finish()),
197 Arc::new(margin_init_builder.finish()),
198 Arc::new(margin_maint_builder.finish()),
199 Arc::new(maker_fee_builder.finish()),
200 Arc::new(taker_fee_builder.finish()),
201 Arc::new(tick_scheme_builder.finish()),
202 Arc::new(info_builder.finish()),
203 Arc::new(ts_event_builder.finish()),
204 Arc::new(ts_init_builder.finish()),
205 ],
206 )
207 }
208
209 fn metadata(&self) -> HashMap<String, String> {
210 let mut metadata = HashMap::new();
211 metadata.insert(KEY_INSTRUMENT_ID.to_string(), self.id.to_string());
212 metadata.insert(
213 KEY_PRICE_PRECISION.to_string(),
214 self.price_precision.to_string(),
215 );
216 metadata
217 }
218}
219
220pub fn decode_equity_batch(
227 #[allow(unused)] metadata: &HashMap<String, String>,
228 record_batch: &RecordBatch,
229) -> Result<Vec<Equity>, EncodingError> {
230 let cols = record_batch.columns();
231 let num_rows = record_batch.num_rows();
232
233 let id_values = extract_column::<StringArray>(cols, "id", 0, DataType::Utf8)?;
235 let raw_symbol_values = extract_column::<StringArray>(cols, "raw_symbol", 1, DataType::Utf8)?;
236 let currency_values = extract_column::<StringArray>(cols, "currency", 2, DataType::Utf8)?;
237 let price_precision_values =
238 extract_column::<UInt8Array>(cols, "price_precision", 3, DataType::UInt8)?;
239 let price_increment_values =
240 extract_column::<StringArray>(cols, "price_increment", 4, DataType::Utf8)?;
241 let lot_size_values = cols
242 .get(5)
243 .ok_or_else(|| EncodingError::MissingColumn("lot_size", 5))?;
244 let isin_values = cols
245 .get(6)
246 .ok_or_else(|| EncodingError::MissingColumn("isin", 6))?;
247 let max_quantity_values = cols
248 .get(7)
249 .ok_or_else(|| EncodingError::MissingColumn("max_quantity", 7))?;
250 let min_quantity_values = cols
251 .get(8)
252 .ok_or_else(|| EncodingError::MissingColumn("min_quantity", 8))?;
253 let max_price_values = cols
254 .get(9)
255 .ok_or_else(|| EncodingError::MissingColumn("max_price", 9))?;
256 let min_price_values = cols
257 .get(10)
258 .ok_or_else(|| EncodingError::MissingColumn("min_price", 10))?;
259 let margin_init_values =
260 extract_column::<StringArray>(cols, "margin_init", 11, DataType::Utf8)?;
261 let margin_maint_values =
262 extract_column::<StringArray>(cols, "margin_maint", 12, DataType::Utf8)?;
263 let maker_fee_values = extract_column::<StringArray>(cols, "maker_fee", 13, DataType::Utf8)?;
264 let taker_fee_values = extract_column::<StringArray>(cols, "taker_fee", 14, DataType::Utf8)?;
265 let tick_scheme_values = extract_optional_string_column_by_name(record_batch, "tick_scheme")?;
266 let info_values =
267 extract_column_by_name_or_index::<BinaryArray>(record_batch, "info", 15, DataType::Binary)?;
268 let ts_event_values = extract_column_by_name_or_index::<UInt64Array>(
269 record_batch,
270 "ts_event",
271 16,
272 DataType::UInt64,
273 )?;
274 let ts_init_values = extract_column_by_name_or_index::<UInt64Array>(
275 record_batch,
276 "ts_init",
277 17,
278 DataType::UInt64,
279 )?;
280
281 let mut result = Vec::with_capacity(num_rows);
282
283 for i in 0..num_rows {
284 let id = InstrumentId::from_str(id_values.value(i))
285 .map_err(|e| EncodingError::ParseError("id", format!("row {i}: {e}")))?;
286 let raw_symbol = Symbol::from(raw_symbol_values.value(i));
287 let currency =
288 super::decode_currency(currency_values.value(i), "currency", "equity.currency", i)?;
289 let price_prec = price_precision_values.value(i);
290
291 let price_increment = Price::from_str(price_increment_values.value(i))
292 .map_err(|e| EncodingError::ParseError("price_increment", format!("row {i}: {e}")))?;
293
294 let lot_size = if lot_size_values.is_null(i) {
295 None
296 } else {
297 let lot_size_str = lot_size_values
298 .as_any()
299 .downcast_ref::<StringArray>()
300 .ok_or_else(|| {
301 EncodingError::ParseError("lot_size", format!("row {i}: invalid type"))
302 })?
303 .value(i);
304 Some(
305 Quantity::from_str(lot_size_str)
306 .map_err(|e| EncodingError::ParseError("lot_size", format!("row {i}: {e}")))?,
307 )
308 };
309
310 let isin = if isin_values.is_null(i) {
311 None
312 } else {
313 let isin_str = isin_values
314 .as_any()
315 .downcast_ref::<StringArray>()
316 .ok_or_else(|| EncodingError::ParseError("isin", format!("row {i}: invalid type")))?
317 .value(i);
318 Some(Ustr::from(isin_str))
319 };
320
321 let max_quantity =
322 if max_quantity_values.is_null(i) {
323 None
324 } else {
325 let max_qty_str = max_quantity_values
326 .as_any()
327 .downcast_ref::<StringArray>()
328 .ok_or_else(|| {
329 EncodingError::ParseError("max_quantity", format!("row {i}: invalid type"))
330 })?
331 .value(i);
332 Some(Quantity::from_str(max_qty_str).map_err(|e| {
333 EncodingError::ParseError("max_quantity", format!("row {i}: {e}"))
334 })?)
335 };
336
337 let min_quantity =
338 if min_quantity_values.is_null(i) {
339 None
340 } else {
341 let min_qty_str = min_quantity_values
342 .as_any()
343 .downcast_ref::<StringArray>()
344 .ok_or_else(|| {
345 EncodingError::ParseError("min_quantity", format!("row {i}: invalid type"))
346 })?
347 .value(i);
348 Some(Quantity::from_str(min_qty_str).map_err(|e| {
349 EncodingError::ParseError("min_quantity", format!("row {i}: {e}"))
350 })?)
351 };
352
353 let max_price = if max_price_values.is_null(i) {
354 None
355 } else {
356 let max_p_str = max_price_values
357 .as_any()
358 .downcast_ref::<StringArray>()
359 .ok_or_else(|| {
360 EncodingError::ParseError("max_price", format!("row {i}: invalid type"))
361 })?
362 .value(i);
363 Some(
364 Price::from_str(max_p_str)
365 .map_err(|e| EncodingError::ParseError("max_price", format!("row {i}: {e}")))?,
366 )
367 };
368
369 let min_price = if min_price_values.is_null(i) {
370 None
371 } else {
372 let min_p_str = min_price_values
373 .as_any()
374 .downcast_ref::<StringArray>()
375 .ok_or_else(|| {
376 EncodingError::ParseError("min_price", format!("row {i}: invalid type"))
377 })?
378 .value(i);
379 Some(
380 Price::from_str(min_p_str)
381 .map_err(|e| EncodingError::ParseError("min_price", format!("row {i}: {e}")))?,
382 )
383 };
384
385 let margin_init = Decimal::from_str(margin_init_values.value(i))
386 .map_err(|e| EncodingError::ParseError("margin_init", format!("row {i}: {e}")))?;
387 let margin_maint = Decimal::from_str(margin_maint_values.value(i))
388 .map_err(|e| EncodingError::ParseError("margin_maint", format!("row {i}: {e}")))?;
389 let maker_fee = Decimal::from_str(maker_fee_values.value(i))
390 .map_err(|e| EncodingError::ParseError("maker_fee", format!("row {i}: {e}")))?;
391 let taker_fee = Decimal::from_str(taker_fee_values.value(i))
392 .map_err(|e| EncodingError::ParseError("taker_fee", format!("row {i}: {e}")))?;
393
394 let info = if info_values.is_null(i) {
396 None
397 } else {
398 let info_bytes = info_values
399 .as_any()
400 .downcast_ref::<BinaryArray>()
401 .ok_or_else(|| EncodingError::ParseError("info", format!("row {i}: invalid type")))?
402 .value(i);
403
404 match serde_json::from_slice::<Params>(info_bytes) {
405 Ok(info_dict) => Some(info_dict),
406 Err(e) => {
407 return Err(EncodingError::ParseError(
408 "info",
409 format!("row {i}: failed to deserialize JSON: {e}"),
410 ));
411 }
412 }
413 };
414
415 let ts_event = nautilus_core::UnixNanos::from(ts_event_values.value(i));
416 let ts_init = nautilus_core::UnixNanos::from(ts_init_values.value(i));
417
418 let tick_scheme = optional_ustr_value(tick_scheme_values, i);
419
420 let equity = Equity::new_checked(
421 id,
422 raw_symbol,
423 isin,
424 currency,
425 price_prec,
426 price_increment,
427 lot_size,
428 max_quantity,
429 min_quantity,
430 max_price,
431 min_price,
432 Some(margin_init),
433 Some(margin_maint),
434 Some(maker_fee),
435 Some(taker_fee),
436 tick_scheme,
437 info,
438 ts_event,
439 ts_init,
440 )
441 .map_err(|e| super::instrument_validation_error::<Equity>(i, e))?;
442
443 result.push(equity);
444 }
445
446 Ok(result)
447}