1use std::{collections::HashMap, str::FromStr, sync::Arc};
19
20use arrow::{
21 array::{
22 Array, BinaryArray, BinaryBuilder, StringArray, StringBuilder, UInt8Array, UInt64Array,
23 },
24 datatypes::{DataType, Field, Schema},
25 error::ArrowError,
26 record_batch::RecordBatch,
27};
28use nautilus_core::Params;
29use nautilus_model::{
30 identifiers::{InstrumentId, Symbol},
31 instruments::currency_pair::CurrencyPair,
32 types::{money::Money, price::Price, quantity::Quantity},
33};
34#[allow(unused)]
35use rust_decimal::Decimal;
36#[allow(unused)]
37use serde_json::Value;
38
39use crate::arrow::{
40 ArrowSchemaProvider, EncodeToRecordBatch, EncodingError, KEY_INSTRUMENT_ID,
41 KEY_PRICE_PRECISION, KEY_SIZE_PRECISION, extract_column, extract_column_by_name_or_index,
42 extract_optional_string_column_by_name, optional_ustr_value,
43};
44
45impl ArrowSchemaProvider for CurrencyPair {
46 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
47 let fields = vec![
48 Field::new("id", DataType::Utf8, false),
49 Field::new("raw_symbol", DataType::Utf8, false),
50 Field::new("base_currency", DataType::Utf8, false),
51 Field::new("quote_currency", DataType::Utf8, false),
52 Field::new("price_precision", DataType::UInt8, false),
53 Field::new("size_precision", DataType::UInt8, false),
54 Field::new("price_increment", DataType::Utf8, false),
55 Field::new("size_increment", DataType::Utf8, false),
56 Field::new("multiplier", DataType::Utf8, false),
57 Field::new("lot_size", DataType::Utf8, true), Field::new("max_quantity", DataType::Utf8, true), Field::new("min_quantity", DataType::Utf8, true), Field::new("max_notional", DataType::Utf8, true), Field::new("min_notional", DataType::Utf8, true), Field::new("max_price", DataType::Utf8, true), Field::new("min_price", DataType::Utf8, true), Field::new("margin_init", DataType::Utf8, false),
65 Field::new("margin_maint", DataType::Utf8, false),
66 Field::new("maker_fee", DataType::Utf8, false),
67 Field::new("taker_fee", DataType::Utf8, false),
68 Field::new("tick_scheme", DataType::Utf8, true),
69 Field::new("info", DataType::Binary, true), Field::new("ts_event", DataType::UInt64, false),
71 Field::new("ts_init", DataType::UInt64, false),
72 ];
73
74 let mut final_metadata = HashMap::new();
75 final_metadata.insert("class".to_string(), "CurrencyPair".to_string());
76
77 if let Some(meta) = metadata {
78 final_metadata.extend(meta);
79 }
80
81 Schema::new_with_metadata(fields, final_metadata)
82 }
83}
84
85impl EncodeToRecordBatch for CurrencyPair {
86 fn encode_batch(
87 #[allow(unused)] metadata: &HashMap<String, String>,
88 data: &[Self],
89 ) -> Result<RecordBatch, ArrowError> {
90 let mut id_builder = StringBuilder::new();
91 let mut raw_symbol_builder = StringBuilder::new();
92 let mut base_currency_builder = StringBuilder::new();
93 let mut quote_currency_builder = StringBuilder::new();
94 let mut price_precision_builder = UInt8Array::builder(data.len());
95 let mut size_precision_builder = UInt8Array::builder(data.len());
96 let mut price_increment_builder = StringBuilder::new();
97 let mut size_increment_builder = StringBuilder::new();
98 let mut multiplier_builder = StringBuilder::new();
99 let mut lot_size_builder = StringBuilder::new();
100 let mut max_quantity_builder = StringBuilder::new();
101 let mut min_quantity_builder = StringBuilder::new();
102 let mut max_notional_builder = StringBuilder::new();
103 let mut min_notional_builder = StringBuilder::new();
104 let mut max_price_builder = StringBuilder::new();
105 let mut min_price_builder = StringBuilder::new();
106 let mut margin_init_builder = StringBuilder::new();
107 let mut margin_maint_builder = StringBuilder::new();
108 let mut maker_fee_builder = StringBuilder::new();
109 let mut taker_fee_builder = StringBuilder::new();
110 let mut tick_scheme_builder = StringBuilder::new();
111 let mut info_builder = BinaryBuilder::new();
112 let mut ts_event_builder = UInt64Array::builder(data.len());
113 let mut ts_init_builder = UInt64Array::builder(data.len());
114
115 for cp in data {
116 id_builder.append_value(cp.id.to_string());
117 raw_symbol_builder.append_value(cp.raw_symbol);
118 base_currency_builder.append_value(cp.base_currency.to_string());
119 quote_currency_builder.append_value(cp.quote_currency.to_string());
120 price_precision_builder.append_value(cp.price_precision);
121 size_precision_builder.append_value(cp.size_precision);
122 price_increment_builder.append_value(cp.price_increment.to_string());
123 size_increment_builder.append_value(cp.size_increment.to_string());
124 multiplier_builder.append_value(cp.multiplier.to_string());
125
126 if let Some(lot_size) = cp.lot_size {
127 lot_size_builder.append_value(lot_size.to_string());
128 } else {
129 lot_size_builder.append_null();
130 }
131
132 if let Some(max_qty) = cp.max_quantity {
133 max_quantity_builder.append_value(max_qty.to_string());
134 } else {
135 max_quantity_builder.append_null();
136 }
137
138 if let Some(min_qty) = cp.min_quantity {
139 min_quantity_builder.append_value(min_qty.to_string());
140 } else {
141 min_quantity_builder.append_null();
142 }
143
144 if let Some(max_not) = cp.max_notional {
145 max_notional_builder.append_value(max_not.to_string());
146 } else {
147 max_notional_builder.append_null();
148 }
149
150 if let Some(min_not) = cp.min_notional {
151 min_notional_builder.append_value(min_not.to_string());
152 } else {
153 min_notional_builder.append_null();
154 }
155
156 if let Some(max_p) = cp.max_price {
157 max_price_builder.append_value(max_p.to_string());
158 } else {
159 max_price_builder.append_null();
160 }
161
162 if let Some(min_p) = cp.min_price {
163 min_price_builder.append_value(min_p.to_string());
164 } else {
165 min_price_builder.append_null();
166 }
167
168 margin_init_builder.append_value(cp.margin_init.to_string());
169 margin_maint_builder.append_value(cp.margin_maint.to_string());
170 maker_fee_builder.append_value(cp.maker_fee.to_string());
171 taker_fee_builder.append_value(cp.taker_fee.to_string());
172
173 if let Some(tick_scheme) = cp.tick_scheme {
174 tick_scheme_builder.append_value(tick_scheme);
175 } else {
176 tick_scheme_builder.append_null();
177 }
178
179 if let Some(ref info) = cp.info {
181 match serde_json::to_vec(info) {
182 Ok(json_bytes) => {
183 info_builder.append_value(json_bytes);
184 }
185 Err(e) => {
186 return Err(ArrowError::InvalidArgumentError(format!(
187 "Failed to serialize info dict to JSON: {e}"
188 )));
189 }
190 }
191 } else {
192 info_builder.append_null();
193 }
194
195 ts_event_builder.append_value(cp.ts_event.as_u64());
196 ts_init_builder.append_value(cp.ts_init.as_u64());
197 }
198
199 let mut final_metadata = metadata.clone();
200 final_metadata.insert("class".to_string(), "CurrencyPair".to_string());
201
202 RecordBatch::try_new(
203 Self::get_schema(Some(final_metadata)).into(),
204 vec![
205 Arc::new(id_builder.finish()),
206 Arc::new(raw_symbol_builder.finish()),
207 Arc::new(base_currency_builder.finish()),
208 Arc::new(quote_currency_builder.finish()),
209 Arc::new(price_precision_builder.finish()),
210 Arc::new(size_precision_builder.finish()),
211 Arc::new(price_increment_builder.finish()),
212 Arc::new(size_increment_builder.finish()),
213 Arc::new(multiplier_builder.finish()),
214 Arc::new(lot_size_builder.finish()),
215 Arc::new(max_quantity_builder.finish()),
216 Arc::new(min_quantity_builder.finish()),
217 Arc::new(max_notional_builder.finish()),
218 Arc::new(min_notional_builder.finish()),
219 Arc::new(max_price_builder.finish()),
220 Arc::new(min_price_builder.finish()),
221 Arc::new(margin_init_builder.finish()),
222 Arc::new(margin_maint_builder.finish()),
223 Arc::new(maker_fee_builder.finish()),
224 Arc::new(taker_fee_builder.finish()),
225 Arc::new(tick_scheme_builder.finish()),
226 Arc::new(info_builder.finish()),
227 Arc::new(ts_event_builder.finish()),
228 Arc::new(ts_init_builder.finish()),
229 ],
230 )
231 }
232
233 fn metadata(&self) -> HashMap<String, String> {
234 let mut metadata = HashMap::new();
235 metadata.insert(KEY_INSTRUMENT_ID.to_string(), self.id.to_string());
236 metadata.insert(
237 KEY_PRICE_PRECISION.to_string(),
238 self.price_precision.to_string(),
239 );
240 metadata.insert(
241 KEY_SIZE_PRECISION.to_string(),
242 self.size_precision.to_string(),
243 );
244 metadata
245 }
246}
247
248pub fn decode_currency_pair_batch(
255 #[allow(unused)] metadata: &HashMap<String, String>,
256 record_batch: &RecordBatch,
257) -> Result<Vec<CurrencyPair>, EncodingError> {
258 let cols = record_batch.columns();
259 let num_rows = record_batch.num_rows();
260
261 let id_values = extract_column::<StringArray>(cols, "id", 0, DataType::Utf8)?;
263 let raw_symbol_values = extract_column::<StringArray>(cols, "raw_symbol", 1, DataType::Utf8)?;
264 let base_currency_values =
265 extract_column::<StringArray>(cols, "base_currency", 2, DataType::Utf8)?;
266 let quote_currency_values =
267 extract_column::<StringArray>(cols, "quote_currency", 3, DataType::Utf8)?;
268 let price_precision_values =
269 extract_column::<UInt8Array>(cols, "price_precision", 4, DataType::UInt8)?;
270 let size_precision_values =
271 extract_column::<UInt8Array>(cols, "size_precision", 5, DataType::UInt8)?;
272
273 let price_increment_values =
275 extract_column::<StringArray>(cols, "price_increment", 6, DataType::Utf8)?;
276 let size_increment_values =
277 extract_column::<StringArray>(cols, "size_increment", 7, DataType::Utf8)?;
278 let multiplier_values = extract_column::<StringArray>(cols, "multiplier", 8, DataType::Utf8)?;
279 let lot_size_values = cols
280 .get(9)
281 .ok_or_else(|| EncodingError::MissingColumn("lot_size", 9))?;
282 let max_quantity_values = cols
283 .get(10)
284 .ok_or_else(|| EncodingError::MissingColumn("max_quantity", 10))?;
285 let min_quantity_values = cols
286 .get(11)
287 .ok_or_else(|| EncodingError::MissingColumn("min_quantity", 11))?;
288 let max_notional_values = cols
289 .get(12)
290 .ok_or_else(|| EncodingError::MissingColumn("max_notional", 12))?;
291 let min_notional_values = cols
292 .get(13)
293 .ok_or_else(|| EncodingError::MissingColumn("min_notional", 13))?;
294 let max_price_values = cols
295 .get(14)
296 .ok_or_else(|| EncodingError::MissingColumn("max_price", 14))?;
297 let min_price_values = cols
298 .get(15)
299 .ok_or_else(|| EncodingError::MissingColumn("min_price", 15))?;
300 let margin_init_values =
301 extract_column::<StringArray>(cols, "margin_init", 16, DataType::Utf8)?;
302 let margin_maint_values =
303 extract_column::<StringArray>(cols, "margin_maint", 17, DataType::Utf8)?;
304 let maker_fee_values = extract_column::<StringArray>(cols, "maker_fee", 18, DataType::Utf8)?;
305 let taker_fee_values = extract_column::<StringArray>(cols, "taker_fee", 19, DataType::Utf8)?;
306 let tick_scheme_values = extract_optional_string_column_by_name(record_batch, "tick_scheme")?;
307 let info_values =
308 extract_column_by_name_or_index::<BinaryArray>(record_batch, "info", 20, DataType::Binary)?;
309 let ts_event_values = extract_column_by_name_or_index::<UInt64Array>(
310 record_batch,
311 "ts_event",
312 21,
313 DataType::UInt64,
314 )?;
315 let ts_init_values = extract_column_by_name_or_index::<UInt64Array>(
316 record_batch,
317 "ts_init",
318 22,
319 DataType::UInt64,
320 )?;
321
322 let mut result = Vec::with_capacity(num_rows);
323
324 for i in 0..num_rows {
325 let id = InstrumentId::from_str(id_values.value(i))
326 .map_err(|e| EncodingError::ParseError("id", format!("row {i}: {e}")))?;
327 let raw_symbol = Symbol::from(raw_symbol_values.value(i));
328 let base_currency = super::decode_currency(
329 base_currency_values.value(i),
330 "base_currency",
331 "currency_pair.base_currency",
332 i,
333 )?;
334 let quote_currency = super::decode_currency(
335 quote_currency_values.value(i),
336 "quote_currency",
337 "currency_pair.quote_currency",
338 i,
339 )?;
340 let price_prec = price_precision_values.value(i);
341 let size_prec = size_precision_values.value(i);
342
343 let price_increment = Price::from_str(price_increment_values.value(i))
344 .map_err(|e| EncodingError::ParseError("price_increment", format!("row {i}: {e}")))?;
345 let size_increment = Quantity::from_str(size_increment_values.value(i))
346 .map_err(|e| EncodingError::ParseError("size_increment", format!("row {i}: {e}")))?;
347 let multiplier = Quantity::from_str(multiplier_values.value(i))
348 .map_err(|e| EncodingError::ParseError("multiplier", format!("row {i}: {e}")))?;
349
350 let lot_size = if lot_size_values.is_null(i) {
351 None
352 } else {
353 let lot_size_str = lot_size_values
354 .as_any()
355 .downcast_ref::<StringArray>()
356 .ok_or_else(|| {
357 EncodingError::ParseError("lot_size", format!("row {i}: invalid type"))
358 })?
359 .value(i);
360 Some(
361 Quantity::from_str(lot_size_str)
362 .map_err(|e| EncodingError::ParseError("lot_size", format!("row {i}: {e}")))?,
363 )
364 };
365
366 let max_quantity =
367 if max_quantity_values.is_null(i) {
368 None
369 } else {
370 let max_qty_str = max_quantity_values
371 .as_any()
372 .downcast_ref::<StringArray>()
373 .ok_or_else(|| {
374 EncodingError::ParseError("max_quantity", format!("row {i}: invalid type"))
375 })?
376 .value(i);
377 Some(Quantity::from_str(max_qty_str).map_err(|e| {
378 EncodingError::ParseError("max_quantity", format!("row {i}: {e}"))
379 })?)
380 };
381
382 let min_quantity =
383 if min_quantity_values.is_null(i) {
384 None
385 } else {
386 let min_qty_str = min_quantity_values
387 .as_any()
388 .downcast_ref::<StringArray>()
389 .ok_or_else(|| {
390 EncodingError::ParseError("min_quantity", format!("row {i}: invalid type"))
391 })?
392 .value(i);
393 Some(Quantity::from_str(min_qty_str).map_err(|e| {
394 EncodingError::ParseError("min_quantity", format!("row {i}: {e}"))
395 })?)
396 };
397
398 let max_notional =
399 if max_notional_values.is_null(i) {
400 None
401 } else {
402 let max_not_str = max_notional_values
403 .as_any()
404 .downcast_ref::<StringArray>()
405 .ok_or_else(|| {
406 EncodingError::ParseError("max_notional", format!("row {i}: invalid type"))
407 })?
408 .value(i);
409 Some(Money::from_str(max_not_str).map_err(|e| {
410 EncodingError::ParseError("max_notional", format!("row {i}: {e}"))
411 })?)
412 };
413
414 let min_notional =
415 if min_notional_values.is_null(i) {
416 None
417 } else {
418 let min_not_str = min_notional_values
419 .as_any()
420 .downcast_ref::<StringArray>()
421 .ok_or_else(|| {
422 EncodingError::ParseError("min_notional", format!("row {i}: invalid type"))
423 })?
424 .value(i);
425 Some(Money::from_str(min_not_str).map_err(|e| {
426 EncodingError::ParseError("min_notional", format!("row {i}: {e}"))
427 })?)
428 };
429
430 let max_price = if max_price_values.is_null(i) {
431 None
432 } else {
433 let max_p_str = max_price_values
434 .as_any()
435 .downcast_ref::<StringArray>()
436 .ok_or_else(|| {
437 EncodingError::ParseError("max_price", format!("row {i}: invalid type"))
438 })?
439 .value(i);
440 Some(
441 Price::from_str(max_p_str)
442 .map_err(|e| EncodingError::ParseError("max_price", format!("row {i}: {e}")))?,
443 )
444 };
445
446 let min_price = if min_price_values.is_null(i) {
447 None
448 } else {
449 let min_p_str = min_price_values
450 .as_any()
451 .downcast_ref::<StringArray>()
452 .ok_or_else(|| {
453 EncodingError::ParseError("min_price", format!("row {i}: invalid type"))
454 })?
455 .value(i);
456 Some(
457 Price::from_str(min_p_str)
458 .map_err(|e| EncodingError::ParseError("min_price", format!("row {i}: {e}")))?,
459 )
460 };
461
462 let margin_init = Decimal::from_str(margin_init_values.value(i))
463 .map_err(|e| EncodingError::ParseError("margin_init", format!("row {i}: {e}")))?;
464 let margin_maint = Decimal::from_str(margin_maint_values.value(i))
465 .map_err(|e| EncodingError::ParseError("margin_maint", format!("row {i}: {e}")))?;
466 let maker_fee = Decimal::from_str(maker_fee_values.value(i))
467 .map_err(|e| EncodingError::ParseError("maker_fee", format!("row {i}: {e}")))?;
468 let taker_fee = Decimal::from_str(taker_fee_values.value(i))
469 .map_err(|e| EncodingError::ParseError("taker_fee", format!("row {i}: {e}")))?;
470
471 let info = if info_values.is_null(i) {
473 None
474 } else {
475 let info_bytes = info_values
476 .as_any()
477 .downcast_ref::<BinaryArray>()
478 .ok_or_else(|| EncodingError::ParseError("info", format!("row {i}: invalid type")))?
479 .value(i);
480
481 match serde_json::from_slice::<Params>(info_bytes) {
482 Ok(info_dict) => Some(info_dict),
483 Err(e) => {
484 return Err(EncodingError::ParseError(
485 "info",
486 format!("row {i}: failed to deserialize JSON: {e}"),
487 ));
488 }
489 }
490 };
491
492 let ts_event = nautilus_core::UnixNanos::from(ts_event_values.value(i));
493 let ts_init = nautilus_core::UnixNanos::from(ts_init_values.value(i));
494
495 let tick_scheme = optional_ustr_value(tick_scheme_values, i);
496
497 let currency_pair = CurrencyPair::new_checked(
498 id,
499 raw_symbol,
500 base_currency,
501 quote_currency,
502 price_prec,
503 size_prec,
504 price_increment,
505 size_increment,
506 Some(multiplier),
507 lot_size,
508 max_quantity,
509 min_quantity,
510 max_notional,
511 min_notional,
512 max_price,
513 min_price,
514 Some(margin_init),
515 Some(margin_maint),
516 Some(maker_fee),
517 Some(taker_fee),
518 tick_scheme,
519 info,
520 ts_event,
521 ts_init,
522 )
523 .map_err(|e| super::instrument_validation_error::<CurrencyPair>(i, e))?;
524
525 result.push(currency_pair);
526 }
527
528 Ok(result)
529}