1use std::{collections::HashMap, str::FromStr, sync::Arc};
19
20#[allow(unused_imports)]
21use arrow::{
22 array::{
23 Array, BinaryArray, BinaryBuilder, Float64Array, Float64Builder, Int64Array, Int64Builder,
24 StringArray, StringBuilder, UInt8Array, UInt64Array,
25 },
26 datatypes::{DataType, Field, Schema},
27 error::ArrowError,
28 record_batch::RecordBatch,
29};
30#[allow(unused_imports)]
31use nautilus_core::Params;
32use nautilus_model::{
33 identifiers::InstrumentId,
34 instruments::betting::BettingInstrument,
35 types::{price::Price, quantity::Quantity},
36};
37#[allow(unused)]
38use rust_decimal::Decimal;
39#[allow(unused)]
40use serde_json::Value;
41use ustr::Ustr;
42
43use crate::arrow::{
44 ArrowSchemaProvider, EncodeToRecordBatch, EncodingError, KEY_INSTRUMENT_ID,
45 KEY_PRICE_PRECISION, KEY_SIZE_PRECISION, extract_column, extract_column_by_name_or_index,
46 extract_optional_string_column_by_name, optional_ustr_value,
47};
48
49impl ArrowSchemaProvider for BettingInstrument {
50 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
51 let fields = vec![
52 Field::new("id", DataType::Utf8, false),
53 Field::new("venue_name", DataType::Utf8, false),
54 Field::new("currency", DataType::Utf8, false),
55 Field::new("event_type_id", DataType::UInt64, false),
56 Field::new("event_type_name", DataType::Utf8, false),
57 Field::new("competition_id", DataType::UInt64, false),
58 Field::new("competition_name", DataType::Utf8, false),
59 Field::new("event_id", DataType::UInt64, false),
60 Field::new("event_name", DataType::Utf8, false),
61 Field::new("event_country_code", DataType::Utf8, false),
62 Field::new("event_open_date", DataType::UInt64, false),
63 Field::new("betting_type", DataType::Utf8, false),
64 Field::new("market_id", DataType::Utf8, false),
65 Field::new("market_name", DataType::Utf8, false),
66 Field::new("market_type", DataType::Utf8, false),
67 Field::new("market_start_time", DataType::UInt64, false),
68 Field::new("selection_id", DataType::UInt64, false),
69 Field::new("selection_name", DataType::Utf8, false),
70 Field::new("selection_handicap", DataType::Float64, false),
71 Field::new("price_precision", DataType::UInt8, false),
72 Field::new("size_precision", DataType::UInt8, false),
73 Field::new("tick_scheme", DataType::Utf8, true),
74 Field::new("info", DataType::Binary, true), Field::new("ts_event", DataType::UInt64, false),
76 Field::new("ts_init", DataType::UInt64, false),
77 ];
78
79 let mut final_metadata = HashMap::new();
80 final_metadata.insert("class".to_string(), "BettingInstrument".to_string());
81
82 if let Some(meta) = metadata {
83 final_metadata.extend(meta);
84 }
85
86 Schema::new_with_metadata(fields, final_metadata)
87 }
88}
89
90impl EncodeToRecordBatch for BettingInstrument {
91 fn encode_batch(
92 #[allow(unused)] metadata: &HashMap<String, String>,
93 data: &[Self],
94 ) -> Result<RecordBatch, ArrowError> {
95 let mut id_builder = StringBuilder::new();
96 let mut venue_name_builder = StringBuilder::new();
97 let mut currency_builder = StringBuilder::new();
98 let mut event_type_id_builder = UInt64Array::builder(data.len());
99 let mut event_type_name_builder = StringBuilder::new();
100 let mut competition_id_builder = UInt64Array::builder(data.len());
101 let mut competition_name_builder = StringBuilder::new();
102 let mut event_id_builder = UInt64Array::builder(data.len());
103 let mut event_name_builder = StringBuilder::new();
104 let mut event_country_code_builder = StringBuilder::new();
105 let mut event_open_date_builder = UInt64Array::builder(data.len());
106 let mut betting_type_builder = StringBuilder::new();
107 let mut market_id_builder = StringBuilder::new();
108 let mut market_name_builder = StringBuilder::new();
109 let mut market_type_builder = StringBuilder::new();
110 let mut market_start_time_builder = UInt64Array::builder(data.len());
111 let mut selection_id_builder = UInt64Array::builder(data.len());
112 let mut selection_name_builder = StringBuilder::new();
113 let mut selection_handicap_builder = Float64Array::builder(data.len());
114 let mut price_precision_builder = UInt8Array::builder(data.len());
115 let mut size_precision_builder = UInt8Array::builder(data.len());
116 let mut tick_scheme_builder = StringBuilder::new();
117 let mut info_builder = BinaryBuilder::new();
118 let mut ts_event_builder = UInt64Array::builder(data.len());
119 let mut ts_init_builder = UInt64Array::builder(data.len());
120
121 for bi in data {
122 id_builder.append_value(bi.id.to_string());
123 let venue_name = bi.id.venue.to_string();
125 venue_name_builder.append_value(venue_name);
126 currency_builder.append_value(bi.currency.to_string());
127 event_type_id_builder.append_value(bi.event_type_id);
128 event_type_name_builder.append_value(bi.event_type_name);
129 competition_id_builder.append_value(bi.competition_id);
130 competition_name_builder.append_value(bi.competition_name);
131 event_id_builder.append_value(bi.event_id);
132 event_name_builder.append_value(bi.event_name);
133 event_country_code_builder.append_value(bi.event_country_code);
134 event_open_date_builder.append_value(bi.event_open_date.as_u64());
135 betting_type_builder.append_value(bi.betting_type);
136 market_id_builder.append_value(bi.market_id);
137 market_name_builder.append_value(bi.market_name);
138 market_type_builder.append_value(bi.market_type);
139 market_start_time_builder.append_value(bi.market_start_time.as_u64());
140 selection_id_builder.append_value(bi.selection_id);
141 selection_name_builder.append_value(bi.selection_name);
142 selection_handicap_builder.append_value(bi.selection_handicap);
143 price_precision_builder.append_value(bi.price_precision);
144 size_precision_builder.append_value(bi.size_precision);
145
146 if let Some(tick_scheme) = bi.tick_scheme {
147 tick_scheme_builder.append_value(tick_scheme);
148 } else {
149 tick_scheme_builder.append_null();
150 }
151
152 if let Some(ref info) = bi.info {
154 match serde_json::to_vec(info) {
155 Ok(json_bytes) => {
156 info_builder.append_value(json_bytes);
157 }
158 Err(e) => {
159 return Err(ArrowError::InvalidArgumentError(format!(
160 "Failed to serialize info dict to JSON: {e}"
161 )));
162 }
163 }
164 } else {
165 info_builder.append_null();
166 }
167
168 ts_event_builder.append_value(bi.ts_event.as_u64());
169 ts_init_builder.append_value(bi.ts_init.as_u64());
170 }
171
172 let mut final_metadata = metadata.clone();
173 final_metadata.insert("class".to_string(), "BettingInstrument".to_string());
174
175 RecordBatch::try_new(
176 Self::get_schema(Some(final_metadata)).into(),
177 vec![
178 Arc::new(id_builder.finish()),
179 Arc::new(venue_name_builder.finish()),
180 Arc::new(currency_builder.finish()),
181 Arc::new(event_type_id_builder.finish()),
182 Arc::new(event_type_name_builder.finish()),
183 Arc::new(competition_id_builder.finish()),
184 Arc::new(competition_name_builder.finish()),
185 Arc::new(event_id_builder.finish()),
186 Arc::new(event_name_builder.finish()),
187 Arc::new(event_country_code_builder.finish()),
188 Arc::new(event_open_date_builder.finish()),
189 Arc::new(betting_type_builder.finish()),
190 Arc::new(market_id_builder.finish()),
191 Arc::new(market_name_builder.finish()),
192 Arc::new(market_type_builder.finish()),
193 Arc::new(market_start_time_builder.finish()),
194 Arc::new(selection_id_builder.finish()),
195 Arc::new(selection_name_builder.finish()),
196 Arc::new(selection_handicap_builder.finish()),
197 Arc::new(price_precision_builder.finish()),
198 Arc::new(size_precision_builder.finish()),
199 Arc::new(tick_scheme_builder.finish()),
200 Arc::new(info_builder.finish()),
201 Arc::new(ts_event_builder.finish()),
202 Arc::new(ts_init_builder.finish()),
203 ],
204 )
205 }
206
207 fn metadata(&self) -> HashMap<String, String> {
208 let mut metadata = HashMap::new();
209 metadata.insert(KEY_INSTRUMENT_ID.to_string(), self.id.to_string());
210 metadata.insert(
211 KEY_PRICE_PRECISION.to_string(),
212 self.price_precision.to_string(),
213 );
214 metadata.insert(
215 KEY_SIZE_PRECISION.to_string(),
216 self.size_precision.to_string(),
217 );
218 metadata
219 }
220}
221
222pub fn decode_betting_instrument_batch(
229 #[allow(unused)] metadata: &HashMap<String, String>,
230 record_batch: &RecordBatch,
231) -> Result<Vec<BettingInstrument>, EncodingError> {
232 let cols = record_batch.columns();
233 let num_rows = record_batch.num_rows();
234
235 let id_values = extract_column::<StringArray>(cols, "id", 0, DataType::Utf8)?;
236 let _venue_name_values = extract_column::<StringArray>(cols, "venue_name", 1, DataType::Utf8)?; let currency_values = extract_column::<StringArray>(cols, "currency", 2, DataType::Utf8)?;
238 let event_type_id_values =
239 extract_column::<UInt64Array>(cols, "event_type_id", 3, DataType::UInt64)?;
240 let event_type_name_values =
241 extract_column::<StringArray>(cols, "event_type_name", 4, DataType::Utf8)?;
242 let competition_id_values =
243 extract_column::<UInt64Array>(cols, "competition_id", 5, DataType::UInt64)?;
244 let competition_name_values =
245 extract_column::<StringArray>(cols, "competition_name", 6, DataType::Utf8)?;
246 let event_id_values = extract_column::<UInt64Array>(cols, "event_id", 7, DataType::UInt64)?;
247 let event_name_values = extract_column::<StringArray>(cols, "event_name", 8, DataType::Utf8)?;
248 let event_country_code_values =
249 extract_column::<StringArray>(cols, "event_country_code", 9, DataType::Utf8)?;
250 let event_open_date_values =
251 extract_column::<UInt64Array>(cols, "event_open_date", 10, DataType::UInt64)?;
252 let betting_type_values =
253 extract_column::<StringArray>(cols, "betting_type", 11, DataType::Utf8)?;
254 let market_id_values = extract_column::<StringArray>(cols, "market_id", 12, DataType::Utf8)?;
255 let market_name_values =
256 extract_column::<StringArray>(cols, "market_name", 13, DataType::Utf8)?;
257 let market_type_values =
258 extract_column::<StringArray>(cols, "market_type", 14, DataType::Utf8)?;
259 let market_start_time_values =
260 extract_column::<UInt64Array>(cols, "market_start_time", 15, DataType::UInt64)?;
261 let selection_id_values =
262 extract_column::<UInt64Array>(cols, "selection_id", 16, DataType::UInt64)?;
263 let selection_name_values =
264 extract_column::<StringArray>(cols, "selection_name", 17, DataType::Utf8)?;
265 let selection_handicap_values =
266 extract_column::<Float64Array>(cols, "selection_handicap", 18, DataType::Float64)?;
267 let price_precision_values =
268 extract_column::<UInt8Array>(cols, "price_precision", 19, DataType::UInt8)?;
269 let size_precision_values =
270 extract_column::<UInt8Array>(cols, "size_precision", 20, DataType::UInt8)?;
271 let tick_scheme_values = extract_optional_string_column_by_name(record_batch, "tick_scheme")?;
272 let info_values =
273 extract_column_by_name_or_index::<BinaryArray>(record_batch, "info", 21, DataType::Binary)?;
274 let ts_event_values = extract_column_by_name_or_index::<UInt64Array>(
275 record_batch,
276 "ts_event",
277 22,
278 DataType::UInt64,
279 )?;
280 let ts_init_values = extract_column_by_name_or_index::<UInt64Array>(
281 record_batch,
282 "ts_init",
283 23,
284 DataType::UInt64,
285 )?;
286
287 let mut result = Vec::with_capacity(num_rows);
288
289 for i in 0..num_rows {
290 let id = InstrumentId::from_str(id_values.value(i))
291 .map_err(|e| EncodingError::ParseError("id", format!("row {i}: {e}")))?;
292 let currency = super::decode_currency(
293 currency_values.value(i),
294 "currency",
295 "betting_instrument.currency",
296 i,
297 )?;
298 let event_type_id = event_type_id_values.value(i);
299 let event_type_name = Ustr::from(event_type_name_values.value(i));
300 let competition_id = competition_id_values.value(i);
301 let competition_name = Ustr::from(competition_name_values.value(i));
302 let event_id = event_id_values.value(i);
303 let event_name = Ustr::from(event_name_values.value(i));
304 let event_country_code = Ustr::from(event_country_code_values.value(i));
305 let event_open_date = nautilus_core::UnixNanos::from(event_open_date_values.value(i));
306 let betting_type = Ustr::from(betting_type_values.value(i));
307 let market_id = Ustr::from(market_id_values.value(i));
308 let market_name = Ustr::from(market_name_values.value(i));
309 let market_type = Ustr::from(market_type_values.value(i));
310 let market_start_time = nautilus_core::UnixNanos::from(market_start_time_values.value(i));
311 let selection_id = selection_id_values.value(i);
312 let selection_name = Ustr::from(selection_name_values.value(i));
313 let selection_handicap = selection_handicap_values.value(i);
314 let price_prec = price_precision_values.value(i);
315 let size_prec = size_precision_values.value(i);
316
317 let info = if info_values.is_null(i) {
319 None
320 } else {
321 let info_bytes = info_values
322 .as_any()
323 .downcast_ref::<BinaryArray>()
324 .ok_or_else(|| EncodingError::ParseError("info", format!("row {i}: invalid type")))?
325 .value(i);
326
327 match serde_json::from_slice::<Params>(info_bytes) {
328 Ok(info_dict) => Some(info_dict),
329 Err(e) => {
330 return Err(EncodingError::ParseError(
331 "info",
332 format!("row {i}: failed to deserialize JSON: {e}"),
333 ));
334 }
335 }
336 };
337
338 let ts_event = nautilus_core::UnixNanos::from(ts_event_values.value(i));
339 let ts_init = nautilus_core::UnixNanos::from(ts_init_values.value(i));
340
341 let tick_scheme = optional_ustr_value(tick_scheme_values, i);
342
343 let price_increment = Price::new_checked(0.01, price_prec)
347 .map_err(|e| EncodingError::ParseError("price_increment", format!("row {i}: {e}")))?;
348 let size_increment = Quantity::new_checked(1.0, size_prec)
349 .map_err(|e| EncodingError::ParseError("size_increment", format!("row {i}: {e}")))?;
350
351 let raw_symbol = id.symbol;
353
354 let betting_instrument = BettingInstrument::new_checked(
355 id,
356 raw_symbol,
357 event_type_id,
358 event_type_name,
359 competition_id,
360 competition_name,
361 event_id,
362 event_name,
363 event_country_code,
364 event_open_date,
365 betting_type,
366 market_id,
367 market_name,
368 market_type,
369 market_start_time,
370 selection_id,
371 selection_name,
372 selection_handicap,
373 currency,
374 price_prec,
375 size_prec,
376 price_increment,
377 size_increment,
378 None, None, None, None, None, None, None, None, None, None, tick_scheme,
389 info,
390 ts_event,
391 ts_init,
392 )
393 .map_err(|e| super::instrument_validation_error::<BettingInstrument>(i, e))?;
394
395 result.push(betting_instrument);
396 }
397
398 Ok(result)
399}
400
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};

    use arrow::{array::UInt8Array, record_batch::RecordBatch};
    use nautilus_model::instruments::stubs::betting;
    use rstest::rstest;

    use super::*;
    use crate::arrow::EncodeToRecordBatch;

    const PRICE_PRECISION_COLUMN: usize = 19;
    const SIZE_PRECISION_COLUMN: usize = 20;

    // Single-row variant of `betting_batch_with_precision_values`.
    fn betting_batch_with_precision(column_index: usize, precision: u8) -> RecordBatch {
        betting_batch_with_precision_values(column_index, &[precision])
    }

    // Encodes one stub instrument per entry of `precisions`, then swaps the column at
    // `column_index` for a `UInt8Array` holding exactly those values.
    fn betting_batch_with_precision_values(column_index: usize, precisions: &[u8]) -> RecordBatch {
        let stubs = vec![betting(); precisions.len()];
        let encoded = BettingInstrument::encode_batch(&HashMap::new(), &stubs).unwrap();
        let mut patched_columns = encoded.columns().to_vec();
        patched_columns[column_index] = Arc::new(UInt8Array::from(precisions.to_vec()));
        RecordBatch::try_new(encoded.schema(), patched_columns).unwrap()
    }

    // Decodes `batch` with empty metadata and returns the error it must produce.
    fn decode_error(batch: &RecordBatch) -> EncodingError {
        decode_betting_instrument_batch(&HashMap::new(), batch).unwrap_err()
    }

    // Asserts that `error` is a `ParseError` for `expected_field` whose message starts with
    // `row_prefix` and contains every entry of `expected_fragments`. `label` only feeds the
    // panic text emitted for any other error variant.
    fn assert_parse_error(
        error: EncodingError,
        label: &str,
        expected_field: &str,
        row_prefix: &str,
        expected_fragments: &[&str],
    ) {
        match error {
            EncodingError::ParseError(field, message) => {
                assert_eq!(field, expected_field);
                assert!(message.starts_with(row_prefix));
                for fragment in expected_fragments {
                    assert!(message.contains(fragment));
                }
            }
            _ => panic!("Expected {label} parse error, was: {error}"),
        }
    }

    #[rstest]
    fn decode_betting_instrument_invalid_price_precision_returns_error() {
        let batch = betting_batch_with_precision(PRICE_PRECISION_COLUMN, u8::MAX);

        assert_parse_error(
            decode_error(&batch),
            "price_increment",
            "price_increment",
            "row 0:",
            &["precision"],
        );
    }

    #[rstest]
    fn decode_betting_instrument_invalid_second_row_precision_reports_row_index() {
        let batch = betting_batch_with_precision_values(PRICE_PRECISION_COLUMN, &[2, u8::MAX]);

        assert_parse_error(
            decode_error(&batch),
            "price_increment",
            "price_increment",
            "row 1:",
            &["precision"],
        );
    }

    #[rstest]
    fn decode_betting_instrument_invalid_size_precision_returns_error() {
        let batch = betting_batch_with_precision(SIZE_PRECISION_COLUMN, u8::MAX);

        assert_parse_error(
            decode_error(&batch),
            "size_increment",
            "size_increment",
            "row 0:",
            &["precision"],
        );
    }

    #[rstest]
    fn decode_betting_instrument_invalid_default_price_increment_returns_error() {
        // Precision 1 passes the increment construction but is expected to fail later, at
        // instrument validation.
        let batch = betting_batch_with_precision(PRICE_PRECISION_COLUMN, 1);

        assert_parse_error(
            decode_error(&batch),
            "instrument",
            super::super::INSTRUMENT_VALIDATION_FIELD,
            "row 0:",
            &["BettingInstrument", "price_increment"],
        );
    }
}