1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::prices::IndexPriceUpdate, identifiers::InstrumentId, types::fixed::PRECISION_BYTES,
26};
27
28use super::{
29 DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION, decode_price,
30 extract_column,
31};
32use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
33
34impl ArrowSchemaProvider for IndexPriceUpdate {
35 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
36 let fields = vec![
37 Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
38 Field::new("ts_event", DataType::UInt64, false),
39 Field::new("ts_init", DataType::UInt64, false),
40 ];
41
42 match metadata {
43 Some(metadata) => Schema::new_with_metadata(fields, metadata),
44 None => Schema::new(fields),
45 }
46 }
47}
48
49fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
50 let instrument_id_str = metadata
51 .get(KEY_INSTRUMENT_ID)
52 .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
53 let instrument_id = InstrumentId::from_str(instrument_id_str)
54 .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
55
56 let price_precision = metadata
57 .get(KEY_PRICE_PRECISION)
58 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
59 .parse::<u8>()
60 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
61
62 Ok((instrument_id, price_precision))
63}
64
65impl EncodeToRecordBatch for IndexPriceUpdate {
66 fn encode_batch(
67 metadata: &HashMap<String, String>,
68 data: &[Self],
69 ) -> Result<RecordBatch, ArrowError> {
70 let mut value_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
71 let mut ts_event_builder = UInt64Array::builder(data.len());
72 let mut ts_init_builder = UInt64Array::builder(data.len());
73
74 for update in data {
75 value_builder
76 .append_value(update.value.raw.to_le_bytes())
77 .unwrap();
78 ts_event_builder.append_value(update.ts_event.as_u64());
79 ts_init_builder.append_value(update.ts_init.as_u64());
80 }
81
82 RecordBatch::try_new(
83 Self::get_schema(Some(metadata.clone())).into(),
84 vec![
85 Arc::new(value_builder.finish()),
86 Arc::new(ts_event_builder.finish()),
87 Arc::new(ts_init_builder.finish()),
88 ],
89 )
90 }
91
92 fn metadata(&self) -> HashMap<String, String> {
93 let mut metadata = HashMap::new();
94 metadata.insert(
95 KEY_INSTRUMENT_ID.to_string(),
96 self.instrument_id.to_string(),
97 );
98 metadata.insert(
99 KEY_PRICE_PRECISION.to_string(),
100 self.value.precision.to_string(),
101 );
102 metadata
103 }
104}
105
106impl DecodeFromRecordBatch for IndexPriceUpdate {
107 fn decode_batch(
108 metadata: &HashMap<String, String>,
109 record_batch: RecordBatch,
110 ) -> Result<Vec<Self>, EncodingError> {
111 let (instrument_id, price_precision) = parse_metadata(metadata)?;
112 let cols = record_batch.columns();
113
114 let value_values = extract_column::<FixedSizeBinaryArray>(
115 cols,
116 "value",
117 0,
118 DataType::FixedSizeBinary(PRECISION_BYTES),
119 )?;
120 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 1, DataType::UInt64)?;
121 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 2, DataType::UInt64)?;
122
123 if value_values.value_length() != PRECISION_BYTES {
124 return Err(EncodingError::ParseError(
125 "value",
126 format!(
127 "Invalid value length: expected {PRECISION_BYTES}, found {}",
128 value_values.value_length()
129 ),
130 ));
131 }
132
133 let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
134 .map(|row| {
135 let value = decode_price(value_values.value(row), price_precision, "value", row)?;
136 Ok(Self {
137 instrument_id,
138 value,
139 ts_event: ts_event_values.value(row).into(),
140 ts_init: ts_init_values.value(row).into(),
141 })
142 })
143 .collect();
144
145 result
146 }
147}
148
149impl DecodeDataFromRecordBatch for IndexPriceUpdate {
150 fn decode_data_batch(
151 metadata: &HashMap<String, String>,
152 record_batch: RecordBatch,
153 ) -> Result<Vec<Data>, EncodingError> {
154 let updates: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
155 Ok(updates.into_iter().map(Data::from).collect())
156 }
157}
158
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{Price, fixed::FIXED_SCALAR, price::PriceRaw};
    use rstest::rstest;
    use rust_decimal_macros::dec;

    use super::*;
    use crate::arrow::get_raw_price;

    /// Builds the schema metadata shared by the tests: the given instrument ID
    /// and a price precision of 2.
    fn metadata_for(instrument_id: InstrumentId) -> HashMap<String, String> {
        HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ])
    }

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = metadata_for(instrument_id);
        let schema = IndexPriceUpdate::get_schema(Some(metadata.clone()));

        let expected_fields = vec![
            Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
            Field::new("ts_event", DataType::UInt64, false),
            Field::new("ts_init", DataType::UInt64, false),
        ];

        let expected_schema = Schema::new_with_metadata(expected_fields, metadata);
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let schema_map = IndexPriceUpdate::get_schema_map();

        let expected_map = HashMap::from([
            (
                "value".to_string(),
                format!("FixedSizeBinary({PRECISION_BYTES})"),
            ),
            ("ts_event".to_string(), "UInt64".to_string()),
            ("ts_init".to_string(), "UInt64".to_string()),
        ]);
        assert_eq!(schema_map, expected_map);
    }

    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = metadata_for(instrument_id);

        let update1 = IndexPriceUpdate {
            instrument_id,
            value: Price::from("50000.00"),
            ts_event: 1.into(),
            ts_init: 3.into(),
        };

        let update2 = IndexPriceUpdate {
            instrument_id,
            value: Price::from("51000.00"),
            ts_event: 2.into(),
            ts_init: 4.into(),
        };

        let data = vec![update1, update2];
        let record_batch = IndexPriceUpdate::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        // Check the column count before indexing, so a short batch fails with
        // this assertion rather than an index-out-of-bounds panic.
        assert_eq!(columns.len(), 3);

        let value_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[1].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(value_values.len(), 2);
        assert_eq!(
            get_raw_price(value_values.value(0)),
            Price::from(dec!(50000.00).to_string()).raw
        );
        assert_eq!(
            get_raw_price(value_values.value(1)),
            Price::from(dec!(51000.00).to_string()).raw
        );
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_event_values.value(0), 1);
        assert_eq!(ts_event_values.value(1), 2);
        assert_eq!(ts_init_values.len(), 2);
        assert_eq!(ts_init_values.value(0), 3);
        assert_eq!(ts_init_values.value(1), 4);
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = metadata_for(instrument_id);

        let raw_price1 = (50.00 * FIXED_SCALAR) as PriceRaw;
        let raw_price2 = (51.00 * FIXED_SCALAR) as PriceRaw;
        let value =
            FixedSizeBinaryArray::from(vec![&raw_price1.to_le_bytes(), &raw_price2.to_le_bytes()]);
        let ts_event = UInt64Array::from(vec![1, 2]);
        let ts_init = UInt64Array::from(vec![3, 4]);

        let record_batch = RecordBatch::try_new(
            IndexPriceUpdate::get_schema(Some(metadata.clone())).into(),
            vec![Arc::new(value), Arc::new(ts_event), Arc::new(ts_init)],
        )
        .unwrap();

        let decoded_data = IndexPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded_data.len(), 2);
        assert_eq!(decoded_data[0].instrument_id, instrument_id);
        assert_eq!(decoded_data[0].value, Price::from_raw(raw_price1, 2));
        assert_eq!(decoded_data[0].ts_event.as_u64(), 1);
        assert_eq!(decoded_data[0].ts_init.as_u64(), 3);

        assert_eq!(decoded_data[1].instrument_id, instrument_id);
        assert_eq!(decoded_data[1].value, Price::from_raw(raw_price2, 2));
        assert_eq!(decoded_data[1].ts_event.as_u64(), 2);
        assert_eq!(decoded_data[1].ts_init.as_u64(), 4);
    }

    #[rstest]
    fn test_decode_batch_invalid_value_returns_error() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = metadata_for(instrument_id);

        let invalid_price: PriceRaw = PriceRaw::MAX - 1000;
        let value = FixedSizeBinaryArray::from(vec![&invalid_price.to_le_bytes()]);
        let ts_event = UInt64Array::from(vec![1]);
        let ts_init = UInt64Array::from(vec![2]);

        let record_batch = RecordBatch::try_new(
            IndexPriceUpdate::get_schema(Some(metadata.clone())).into(),
            vec![Arc::new(value), Arc::new(ts_event), Arc::new(ts_init)],
        )
        .unwrap();

        let err = IndexPriceUpdate::decode_batch(&metadata, record_batch)
            .expect_err("decoding an out-of-range raw price should fail");
        let message = err.to_string();
        assert!(
            message.contains("value") && message.contains("row 0"),
            "Expected value error at row 0, got: {err}"
        );
    }

    #[rstest]
    fn test_decode_batch_missing_instrument_id_returns_error() {
        let mut metadata = metadata_for(InstrumentId::from("BTC-USDT.BINANCE"));

        let raw_price = (50.00 * FIXED_SCALAR) as PriceRaw;
        let value = FixedSizeBinaryArray::from(vec![&raw_price.to_le_bytes()]);
        let ts_event = UInt64Array::from(vec![1]);
        let ts_init = UInt64Array::from(vec![2]);

        let record_batch = RecordBatch::try_new(
            IndexPriceUpdate::get_schema(Some(metadata.clone())).into(),
            vec![Arc::new(value), Arc::new(ts_event), Arc::new(ts_init)],
        )
        .unwrap();

        // The batch itself is valid; only the metadata passed to the decoder
        // lacks the instrument ID.
        metadata.remove(KEY_INSTRUMENT_ID);

        let err = IndexPriceUpdate::decode_batch(&metadata, record_batch)
            .expect_err("decoding without an instrument ID should fail");
        assert!(
            err.to_string().contains("instrument_id"),
            "Expected missing instrument_id error, got: {err}"
        );
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from("BTC-USDT.BINANCE");
        let metadata = metadata_for(instrument_id);

        let update1 = IndexPriceUpdate {
            instrument_id,
            value: Price::from("50000.00"),
            ts_event: 1_000_000_000.into(),
            ts_init: 1_000_000_001.into(),
        };

        let update2 = IndexPriceUpdate {
            instrument_id,
            value: Price::from("51000.00"),
            ts_event: 2_000_000_000.into(),
            ts_init: 2_000_000_001.into(),
        };

        let original = vec![update1, update2];
        let record_batch = IndexPriceUpdate::encode_batch(&metadata, &original).unwrap();
        let decoded = IndexPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), original.len());
        for (expected, actual) in original.iter().zip(decoded.iter()) {
            assert_eq!(actual.instrument_id, expected.instrument_id);
            assert_eq!(actual.value, expected.value);
            assert_eq!(actual.ts_event, expected.ts_event);
            assert_eq!(actual.ts_init, expected.ts_init);
        }
    }
}