1use std::{collections::HashMap, str::FromStr, sync::Arc};
19
20use arrow::{
21 array::{
22 Array, BinaryArray, BinaryBuilder, StringArray, StringBuilder, UInt8Array, UInt64Array,
23 },
24 datatypes::{DataType, Field, Schema},
25 error::ArrowError,
26 record_batch::RecordBatch,
27};
28#[allow(unused_imports)]
29use nautilus_core::Params;
30use nautilus_model::{
31 enums::AssetClass,
32 identifiers::{InstrumentId, Symbol},
33 instruments::binary_option::BinaryOption,
34 types::{price::Price, quantity::Quantity},
35};
36#[allow(unused)]
37use rust_decimal::Decimal;
38#[allow(unused)]
39use serde_json::Value;
40use ustr::Ustr;
41
42use crate::arrow::{
43 ArrowSchemaProvider, EncodeToRecordBatch, EncodingError, KEY_INSTRUMENT_ID,
44 KEY_PRICE_PRECISION, KEY_SIZE_PRECISION, extract_column, extract_column_by_name_or_index,
45 extract_optional_string_column_by_name, optional_ustr_value,
46};
47
48fn asset_class_to_string(ac: AssetClass) -> String {
50 match ac {
51 AssetClass::FX => "FX".to_string(),
52 AssetClass::Equity => "Equity".to_string(),
53 AssetClass::Commodity => "Commodity".to_string(),
54 AssetClass::Debt => "Debt".to_string(),
55 AssetClass::Index => "Index".to_string(),
56 AssetClass::Cryptocurrency => "Cryptocurrency".to_string(),
57 AssetClass::Alternative => "Alternative".to_string(),
58 }
59}
60
61fn asset_class_from_str(s: &str) -> Result<AssetClass, EncodingError> {
63 match s {
64 "FX" => Ok(AssetClass::FX),
65 "Equity" => Ok(AssetClass::Equity),
66 "Commodity" => Ok(AssetClass::Commodity),
67 "Debt" => Ok(AssetClass::Debt),
68 "Index" => Ok(AssetClass::Index),
69 "Cryptocurrency" => Ok(AssetClass::Cryptocurrency),
70 "Alternative" => Ok(AssetClass::Alternative),
71 _ => Err(EncodingError::ParseError(
72 "asset_class",
73 format!("Unknown asset class: {s}"),
74 )),
75 }
76}
77
78impl ArrowSchemaProvider for BinaryOption {
79 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
80 let fields = vec![
81 Field::new("id", DataType::Utf8, false),
82 Field::new("raw_symbol", DataType::Utf8, false),
83 Field::new("asset_class", DataType::Utf8, false),
84 Field::new("currency", DataType::Utf8, false),
85 Field::new("price_precision", DataType::UInt8, false),
86 Field::new("size_precision", DataType::UInt8, false),
87 Field::new("price_increment", DataType::Utf8, false),
88 Field::new("size_increment", DataType::Utf8, false),
89 Field::new("activation_ns", DataType::UInt64, false),
90 Field::new("expiration_ns", DataType::UInt64, false),
91 Field::new("maker_fee", DataType::Utf8, false),
92 Field::new("taker_fee", DataType::Utf8, false),
93 Field::new("max_quantity", DataType::Utf8, true), Field::new("min_quantity", DataType::Utf8, true), Field::new("outcome", DataType::Utf8, true), Field::new("description", DataType::Utf8, true), Field::new("tick_scheme", DataType::Utf8, true),
98 Field::new("info", DataType::Binary, true), Field::new("ts_event", DataType::UInt64, false),
100 Field::new("ts_init", DataType::UInt64, false),
101 ];
102
103 let mut final_metadata = HashMap::new();
104 final_metadata.insert("class".to_string(), "BinaryOption".to_string());
105
106 if let Some(meta) = metadata {
107 final_metadata.extend(meta);
108 }
109
110 Schema::new_with_metadata(fields, final_metadata)
111 }
112}
113
114impl EncodeToRecordBatch for BinaryOption {
115 fn encode_batch(
116 #[allow(unused)] metadata: &HashMap<String, String>,
117 data: &[Self],
118 ) -> Result<RecordBatch, ArrowError> {
119 let mut id_builder = StringBuilder::new();
120 let mut raw_symbol_builder = StringBuilder::new();
121 let mut asset_class_builder = StringBuilder::new();
122 let mut currency_builder = StringBuilder::new();
123 let mut price_precision_builder = UInt8Array::builder(data.len());
124 let mut size_precision_builder = UInt8Array::builder(data.len());
125 let mut price_increment_builder = StringBuilder::new();
126 let mut size_increment_builder = StringBuilder::new();
127 let mut activation_ns_builder = UInt64Array::builder(data.len());
128 let mut expiration_ns_builder = UInt64Array::builder(data.len());
129 let mut maker_fee_builder = StringBuilder::new();
130 let mut taker_fee_builder = StringBuilder::new();
131 let mut max_quantity_builder = StringBuilder::new();
132 let mut min_quantity_builder = StringBuilder::new();
133 let mut outcome_builder = StringBuilder::new();
134 let mut description_builder = StringBuilder::new();
135 let mut tick_scheme_builder = StringBuilder::new();
136 let mut info_builder = BinaryBuilder::new();
137 let mut ts_event_builder = UInt64Array::builder(data.len());
138 let mut ts_init_builder = UInt64Array::builder(data.len());
139
140 for bo in data {
141 id_builder.append_value(bo.id.to_string());
142 raw_symbol_builder.append_value(bo.raw_symbol);
143 asset_class_builder.append_value(asset_class_to_string(bo.asset_class));
144 currency_builder.append_value(bo.currency.to_string());
145 price_precision_builder.append_value(bo.price_precision);
146 size_precision_builder.append_value(bo.size_precision);
147 price_increment_builder.append_value(bo.price_increment.to_string());
148 size_increment_builder.append_value(bo.size_increment.to_string());
149 activation_ns_builder.append_value(bo.activation_ns.as_u64());
150 expiration_ns_builder.append_value(bo.expiration_ns.as_u64());
151 maker_fee_builder.append_value(bo.maker_fee.to_string());
152 taker_fee_builder.append_value(bo.taker_fee.to_string());
153
154 if let Some(max_qty) = bo.max_quantity {
155 max_quantity_builder.append_value(max_qty.to_string());
156 } else {
157 max_quantity_builder.append_null();
158 }
159
160 if let Some(min_qty) = bo.min_quantity {
161 min_quantity_builder.append_value(min_qty.to_string());
162 } else {
163 min_quantity_builder.append_null();
164 }
165
166 if let Some(outcome) = bo.outcome {
167 outcome_builder.append_value(outcome);
168 } else {
169 outcome_builder.append_null();
170 }
171
172 if let Some(desc) = bo.description {
173 description_builder.append_value(desc);
174 } else {
175 description_builder.append_null();
176 }
177
178 if let Some(tick_scheme) = bo.tick_scheme {
179 tick_scheme_builder.append_value(tick_scheme);
180 } else {
181 tick_scheme_builder.append_null();
182 }
183
184 if let Some(ref info) = bo.info {
186 match serde_json::to_vec(info) {
187 Ok(json_bytes) => {
188 info_builder.append_value(json_bytes);
189 }
190 Err(e) => {
191 return Err(ArrowError::InvalidArgumentError(format!(
192 "Failed to serialize info dict to JSON: {e}"
193 )));
194 }
195 }
196 } else {
197 info_builder.append_null();
198 }
199
200 ts_event_builder.append_value(bo.ts_event.as_u64());
201 ts_init_builder.append_value(bo.ts_init.as_u64());
202 }
203
204 let mut final_metadata = metadata.clone();
205 final_metadata.insert("class".to_string(), "BinaryOption".to_string());
206
207 RecordBatch::try_new(
208 Self::get_schema(Some(final_metadata)).into(),
209 vec![
210 Arc::new(id_builder.finish()),
211 Arc::new(raw_symbol_builder.finish()),
212 Arc::new(asset_class_builder.finish()),
213 Arc::new(currency_builder.finish()),
214 Arc::new(price_precision_builder.finish()),
215 Arc::new(size_precision_builder.finish()),
216 Arc::new(price_increment_builder.finish()),
217 Arc::new(size_increment_builder.finish()),
218 Arc::new(activation_ns_builder.finish()),
219 Arc::new(expiration_ns_builder.finish()),
220 Arc::new(maker_fee_builder.finish()),
221 Arc::new(taker_fee_builder.finish()),
222 Arc::new(max_quantity_builder.finish()),
223 Arc::new(min_quantity_builder.finish()),
224 Arc::new(outcome_builder.finish()),
225 Arc::new(description_builder.finish()),
226 Arc::new(tick_scheme_builder.finish()),
227 Arc::new(info_builder.finish()),
228 Arc::new(ts_event_builder.finish()),
229 Arc::new(ts_init_builder.finish()),
230 ],
231 )
232 }
233
234 fn metadata(&self) -> HashMap<String, String> {
235 let mut metadata = HashMap::new();
236 metadata.insert(KEY_INSTRUMENT_ID.to_string(), self.id.to_string());
237 metadata.insert(
238 KEY_PRICE_PRECISION.to_string(),
239 self.price_precision.to_string(),
240 );
241 metadata.insert(
242 KEY_SIZE_PRECISION.to_string(),
243 self.size_precision.to_string(),
244 );
245 metadata
246 }
247}
248
249pub fn decode_binary_option_batch(
256 #[allow(unused)] metadata: &HashMap<String, String>,
257 record_batch: &RecordBatch,
258) -> Result<Vec<BinaryOption>, EncodingError> {
259 let cols = record_batch.columns();
260 let num_rows = record_batch.num_rows();
261
262 let id_values = extract_column::<StringArray>(cols, "id", 0, DataType::Utf8)?;
263 let raw_symbol_values = extract_column::<StringArray>(cols, "raw_symbol", 1, DataType::Utf8)?;
264 let asset_class_values = extract_column::<StringArray>(cols, "asset_class", 2, DataType::Utf8)?;
265 let currency_values = extract_column::<StringArray>(cols, "currency", 3, DataType::Utf8)?;
266 let price_precision_values =
267 extract_column::<UInt8Array>(cols, "price_precision", 4, DataType::UInt8)?;
268 let size_precision_values =
269 extract_column::<UInt8Array>(cols, "size_precision", 5, DataType::UInt8)?;
270 let price_increment_values =
271 extract_column::<StringArray>(cols, "price_increment", 6, DataType::Utf8)?;
272 let size_increment_values =
273 extract_column::<StringArray>(cols, "size_increment", 7, DataType::Utf8)?;
274 let activation_ns_values =
275 extract_column::<UInt64Array>(cols, "activation_ns", 8, DataType::UInt64)?;
276 let expiration_ns_values =
277 extract_column::<UInt64Array>(cols, "expiration_ns", 9, DataType::UInt64)?;
278 let maker_fee_values = extract_column::<StringArray>(cols, "maker_fee", 10, DataType::Utf8)?;
279 let taker_fee_values = extract_column::<StringArray>(cols, "taker_fee", 11, DataType::Utf8)?;
280 let max_quantity_values = cols
281 .get(12)
282 .ok_or_else(|| EncodingError::MissingColumn("max_quantity", 12))?;
283 let min_quantity_values = cols
284 .get(13)
285 .ok_or_else(|| EncodingError::MissingColumn("min_quantity", 13))?;
286 let outcome_values = cols
287 .get(14)
288 .ok_or_else(|| EncodingError::MissingColumn("outcome", 14))?;
289 let description_values = cols
290 .get(15)
291 .ok_or_else(|| EncodingError::MissingColumn("description", 15))?;
292 let tick_scheme_values = extract_optional_string_column_by_name(record_batch, "tick_scheme")?;
293 let info_values =
294 extract_column_by_name_or_index::<BinaryArray>(record_batch, "info", 16, DataType::Binary)?;
295 let ts_event_values = extract_column_by_name_or_index::<UInt64Array>(
296 record_batch,
297 "ts_event",
298 17,
299 DataType::UInt64,
300 )?;
301 let ts_init_values = extract_column_by_name_or_index::<UInt64Array>(
302 record_batch,
303 "ts_init",
304 18,
305 DataType::UInt64,
306 )?;
307
308 let mut result = Vec::with_capacity(num_rows);
309
310 for i in 0..num_rows {
311 let id = InstrumentId::from_str(id_values.value(i))
312 .map_err(|e| EncodingError::ParseError("id", format!("row {i}: {e}")))?;
313 let raw_symbol = Symbol::from(raw_symbol_values.value(i));
314 let asset_class = asset_class_from_str(asset_class_values.value(i))?;
315 let currency = super::decode_currency(
316 currency_values.value(i),
317 "currency",
318 "binary_option.currency",
319 i,
320 )?;
321 let price_prec = price_precision_values.value(i);
322 let size_prec = size_precision_values.value(i);
323
324 let price_increment = Price::from_str(price_increment_values.value(i))
325 .map_err(|e| EncodingError::ParseError("price_increment", format!("row {i}: {e}")))?;
326 let size_increment = Quantity::from_str(size_increment_values.value(i))
327 .map_err(|e| EncodingError::ParseError("size_increment", format!("row {i}: {e}")))?;
328
329 let activation_ns = nautilus_core::UnixNanos::from(activation_ns_values.value(i));
330 let expiration_ns = nautilus_core::UnixNanos::from(expiration_ns_values.value(i));
331
332 let maker_fee = Decimal::from_str(maker_fee_values.value(i))
333 .map_err(|e| EncodingError::ParseError("maker_fee", format!("row {i}: {e}")))?;
334 let taker_fee = Decimal::from_str(taker_fee_values.value(i))
335 .map_err(|e| EncodingError::ParseError("taker_fee", format!("row {i}: {e}")))?;
336
337 let max_quantity =
338 if max_quantity_values.is_null(i) {
339 None
340 } else {
341 let max_qty_str = max_quantity_values
342 .as_any()
343 .downcast_ref::<StringArray>()
344 .ok_or_else(|| {
345 EncodingError::ParseError("max_quantity", format!("row {i}: invalid type"))
346 })?
347 .value(i);
348 Some(Quantity::from_str(max_qty_str).map_err(|e| {
349 EncodingError::ParseError("max_quantity", format!("row {i}: {e}"))
350 })?)
351 };
352
353 let min_quantity =
354 if min_quantity_values.is_null(i) {
355 None
356 } else {
357 let min_qty_str = min_quantity_values
358 .as_any()
359 .downcast_ref::<StringArray>()
360 .ok_or_else(|| {
361 EncodingError::ParseError("min_quantity", format!("row {i}: invalid type"))
362 })?
363 .value(i);
364 Some(Quantity::from_str(min_qty_str).map_err(|e| {
365 EncodingError::ParseError("min_quantity", format!("row {i}: {e}"))
366 })?)
367 };
368
369 let outcome = if outcome_values.is_null(i) {
370 None
371 } else {
372 let outcome_str = outcome_values
373 .as_any()
374 .downcast_ref::<StringArray>()
375 .ok_or_else(|| {
376 EncodingError::ParseError("outcome", format!("row {i}: invalid type"))
377 })?
378 .value(i);
379 Some(Ustr::from(outcome_str))
380 };
381
382 let description = if description_values.is_null(i) {
383 None
384 } else {
385 let desc_str = description_values
386 .as_any()
387 .downcast_ref::<StringArray>()
388 .ok_or_else(|| {
389 EncodingError::ParseError("description", format!("row {i}: invalid type"))
390 })?
391 .value(i);
392 Some(Ustr::from(desc_str))
393 };
394
395 let info = if info_values.is_null(i) {
397 None
398 } else {
399 let info_bytes = info_values
400 .as_any()
401 .downcast_ref::<BinaryArray>()
402 .ok_or_else(|| EncodingError::ParseError("info", format!("row {i}: invalid type")))?
403 .value(i);
404
405 match serde_json::from_slice::<Params>(info_bytes) {
406 Ok(info_dict) => Some(info_dict),
407 Err(e) => {
408 return Err(EncodingError::ParseError(
409 "info",
410 format!("row {i}: failed to deserialize JSON: {e}"),
411 ));
412 }
413 }
414 };
415
416 let ts_event = nautilus_core::UnixNanos::from(ts_event_values.value(i));
417 let ts_init = nautilus_core::UnixNanos::from(ts_init_values.value(i));
418
419 let tick_scheme = optional_ustr_value(tick_scheme_values, i);
420
421 let binary_option = BinaryOption::new_checked(
422 id,
423 raw_symbol,
424 asset_class,
425 currency,
426 activation_ns,
427 expiration_ns,
428 price_prec,
429 size_prec,
430 price_increment,
431 size_increment,
432 outcome,
433 description,
434 max_quantity,
435 min_quantity,
436 None, None, None, None, None, None, Some(maker_fee),
443 Some(taker_fee),
444 tick_scheme,
445 info,
446 ts_event,
447 ts_init,
448 )
449 .map_err(|e| super::instrument_validation_error::<BinaryOption>(i, e))?;
450
451 result.push(binary_option);
452 }
453
454 Ok(result)
455}