nautilus_serialization/arrow/
index_price.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19    array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, UInt64Array},
20    datatypes::{DataType, Field, Schema},
21    error::ArrowError,
22    record_batch::RecordBatch,
23};
24use nautilus_model::{
25    data::prices::IndexPriceUpdate, identifiers::InstrumentId, types::fixed::PRECISION_BYTES,
26};
27
28use super::{
29    DecodeDataFromRecordBatch, EncodingError, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION, decode_price,
30    extract_column,
31};
32use crate::arrow::{ArrowSchemaProvider, Data, DecodeFromRecordBatch, EncodeToRecordBatch};
33
34impl ArrowSchemaProvider for IndexPriceUpdate {
35    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
36        let fields = vec![
37            Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
38            Field::new("ts_event", DataType::UInt64, false),
39            Field::new("ts_init", DataType::UInt64, false),
40        ];
41
42        match metadata {
43            Some(metadata) => Schema::new_with_metadata(fields, metadata),
44            None => Schema::new(fields),
45        }
46    }
47}
48
49fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(InstrumentId, u8), EncodingError> {
50    let instrument_id_str = metadata
51        .get(KEY_INSTRUMENT_ID)
52        .ok_or_else(|| EncodingError::MissingMetadata(KEY_INSTRUMENT_ID))?;
53    let instrument_id = InstrumentId::from_str(instrument_id_str)
54        .map_err(|e| EncodingError::ParseError(KEY_INSTRUMENT_ID, e.to_string()))?;
55
56    let price_precision = metadata
57        .get(KEY_PRICE_PRECISION)
58        .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
59        .parse::<u8>()
60        .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
61
62    Ok((instrument_id, price_precision))
63}
64
65impl EncodeToRecordBatch for IndexPriceUpdate {
66    fn encode_batch(
67        metadata: &HashMap<String, String>,
68        data: &[Self],
69    ) -> Result<RecordBatch, ArrowError> {
70        let mut value_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
71        let mut ts_event_builder = UInt64Array::builder(data.len());
72        let mut ts_init_builder = UInt64Array::builder(data.len());
73
74        for update in data {
75            value_builder
76                .append_value(update.value.raw.to_le_bytes())
77                .unwrap();
78            ts_event_builder.append_value(update.ts_event.as_u64());
79            ts_init_builder.append_value(update.ts_init.as_u64());
80        }
81
82        RecordBatch::try_new(
83            Self::get_schema(Some(metadata.clone())).into(),
84            vec![
85                Arc::new(value_builder.finish()),
86                Arc::new(ts_event_builder.finish()),
87                Arc::new(ts_init_builder.finish()),
88            ],
89        )
90    }
91
92    fn metadata(&self) -> HashMap<String, String> {
93        let mut metadata = HashMap::new();
94        metadata.insert(
95            KEY_INSTRUMENT_ID.to_string(),
96            self.instrument_id.to_string(),
97        );
98        metadata.insert(
99            KEY_PRICE_PRECISION.to_string(),
100            self.value.precision.to_string(),
101        );
102        metadata
103    }
104}
105
106impl DecodeFromRecordBatch for IndexPriceUpdate {
107    fn decode_batch(
108        metadata: &HashMap<String, String>,
109        record_batch: RecordBatch,
110    ) -> Result<Vec<Self>, EncodingError> {
111        let (instrument_id, price_precision) = parse_metadata(metadata)?;
112        let cols = record_batch.columns();
113
114        let value_values = extract_column::<FixedSizeBinaryArray>(
115            cols,
116            "value",
117            0,
118            DataType::FixedSizeBinary(PRECISION_BYTES),
119        )?;
120        let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 1, DataType::UInt64)?;
121        let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 2, DataType::UInt64)?;
122
123        if value_values.value_length() != PRECISION_BYTES {
124            return Err(EncodingError::ParseError(
125                "value",
126                format!(
127                    "Invalid value length: expected {PRECISION_BYTES}, found {}",
128                    value_values.value_length()
129                ),
130            ));
131        }
132
133        let result: Result<Vec<Self>, EncodingError> = (0..record_batch.num_rows())
134            .map(|row| {
135                let value = decode_price(value_values.value(row), price_precision, "value", row)?;
136                Ok(Self {
137                    instrument_id,
138                    value,
139                    ts_event: ts_event_values.value(row).into(),
140                    ts_init: ts_init_values.value(row).into(),
141                })
142            })
143            .collect();
144
145        result
146    }
147}
148
149impl DecodeDataFromRecordBatch for IndexPriceUpdate {
150    fn decode_data_batch(
151        metadata: &HashMap<String, String>,
152        record_batch: RecordBatch,
153    ) -> Result<Vec<Data>, EncodingError> {
154        let updates: Vec<Self> = Self::decode_batch(metadata, record_batch)?;
155        Ok(updates.into_iter().map(Data::from).collect())
156    }
157}
158
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::Array, record_batch::RecordBatch};
    use nautilus_model::types::{Price, fixed::FIXED_SCALAR, price::PriceRaw};
    use rstest::rstest;
    use rust_decimal_macros::dec;

    use super::*;
    use crate::arrow::get_raw_price;

    /// Instrument ID string shared by every test case.
    const INSTRUMENT: &str = "BTC-USDT.BINANCE";

    /// Builds the metadata map used by the tests: the given instrument ID and
    /// a price precision of `2`.
    fn metadata_for(instrument_id: &str) -> HashMap<String, String> {
        HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
        ])
    }

    /// Wraps hand-built columns in a record batch using the schema under test.
    fn batch_of(
        metadata: &HashMap<String, String>,
        value: FixedSizeBinaryArray,
        ts_event: UInt64Array,
        ts_init: UInt64Array,
    ) -> RecordBatch {
        RecordBatch::try_new(
            IndexPriceUpdate::get_schema(Some(metadata.clone())).into(),
            vec![Arc::new(value), Arc::new(ts_event), Arc::new(ts_init)],
        )
        .unwrap()
    }

    /// Builds an update from a decimal price string and raw timestamps.
    fn update_of(
        instrument_id: InstrumentId,
        price: &str,
        ts_event: u64,
        ts_init: u64,
    ) -> IndexPriceUpdate {
        IndexPriceUpdate {
            instrument_id,
            value: Price::from(price),
            ts_event: ts_event.into(),
            ts_init: ts_init.into(),
        }
    }

    #[rstest]
    fn test_get_schema() {
        let instrument_id = InstrumentId::from(INSTRUMENT);
        let metadata = metadata_for(&instrument_id.to_string());

        let expected_schema = Schema::new_with_metadata(
            vec![
                Field::new("value", DataType::FixedSizeBinary(PRECISION_BYTES), false),
                Field::new("ts_event", DataType::UInt64, false),
                Field::new("ts_init", DataType::UInt64, false),
            ],
            metadata.clone(),
        );

        let schema = IndexPriceUpdate::get_schema(Some(metadata));
        assert_eq!(schema, expected_schema);
    }

    #[rstest]
    fn test_get_schema_map() {
        let expected_map: HashMap<String, String> = [
            ("value", format!("FixedSizeBinary({PRECISION_BYTES})")),
            ("ts_event", "UInt64".to_string()),
            ("ts_init", "UInt64".to_string()),
        ]
        .into_iter()
        .map(|(name, data_type)| (name.to_string(), data_type))
        .collect();

        let schema_map = IndexPriceUpdate::get_schema_map();
        assert_eq!(schema_map, expected_map);
    }

    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from(INSTRUMENT);
        let metadata = metadata_for(&instrument_id.to_string());

        let data = vec![
            update_of(instrument_id, "50000.00", 1, 3),
            update_of(instrument_id, "51000.00", 2, 4),
        ];
        let record_batch = IndexPriceUpdate::encode_batch(&metadata, &data).unwrap();

        let columns = record_batch.columns();
        let value_values = columns[0]
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap();
        let ts_event_values = columns[1].as_any().downcast_ref::<UInt64Array>().unwrap();
        let ts_init_values = columns[2].as_any().downcast_ref::<UInt64Array>().unwrap();

        assert_eq!(columns.len(), 3);
        assert_eq!(value_values.len(), 2);
        assert_eq!(ts_event_values.len(), 2);
        assert_eq!(ts_init_values.len(), 2);

        // Expected rows: (price, ts_event, ts_init).
        let expected_rows = [
            (dec!(50000.00), 1_u64, 3_u64),
            (dec!(51000.00), 2_u64, 4_u64),
        ];
        for (row, (price, ts_event, ts_init)) in expected_rows.into_iter().enumerate() {
            assert_eq!(
                get_raw_price(value_values.value(row)),
                Price::from(price.to_string()).raw
            );
            assert_eq!(ts_event_values.value(row), ts_event);
            assert_eq!(ts_init_values.value(row), ts_init);
        }
    }

    #[rstest]
    fn test_decode_batch() {
        let instrument_id = InstrumentId::from(INSTRUMENT);
        let metadata = metadata_for(&instrument_id.to_string());

        let raw_low = (50.00 * FIXED_SCALAR) as PriceRaw;
        let raw_high = (51.00 * FIXED_SCALAR) as PriceRaw;
        let value =
            FixedSizeBinaryArray::from(vec![&raw_low.to_le_bytes(), &raw_high.to_le_bytes()]);
        let record_batch = batch_of(
            &metadata,
            value,
            UInt64Array::from(vec![1, 2]),
            UInt64Array::from(vec![3, 4]),
        );

        let decoded_data = IndexPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        // Expected rows: (raw price, ts_event, ts_init).
        let expected_rows = [(raw_low, 1_u64, 3_u64), (raw_high, 2_u64, 4_u64)];
        assert_eq!(decoded_data.len(), 2);
        for (decoded, (raw, ts_event, ts_init)) in decoded_data.iter().zip(expected_rows) {
            assert_eq!(decoded.instrument_id, instrument_id);
            assert_eq!(decoded.value, Price::from_raw(raw, 2));
            assert_eq!(decoded.ts_event.as_u64(), ts_event);
            assert_eq!(decoded.ts_init.as_u64(), ts_init);
        }
    }

    #[rstest]
    fn test_decode_batch_invalid_value_returns_error() {
        let instrument_id = InstrumentId::from(INSTRUMENT);
        let metadata = metadata_for(&instrument_id.to_string());

        // A raw value close to the type maximum, which the decoder is expected to reject.
        let invalid_price: PriceRaw = PriceRaw::MAX - 1000;
        let record_batch = batch_of(
            &metadata,
            FixedSizeBinaryArray::from(vec![&invalid_price.to_le_bytes()]),
            UInt64Array::from(vec![1]),
            UInt64Array::from(vec![2]),
        );

        let result = IndexPriceUpdate::decode_batch(&metadata, record_batch);
        assert!(result.is_err());

        let err = result.unwrap_err();
        let message = err.to_string();
        assert!(
            message.contains("value") && message.contains("row 0"),
            "Expected value error at row 0, got: {err}"
        );
    }

    #[rstest]
    fn test_decode_batch_missing_instrument_id_returns_error() {
        let mut metadata = metadata_for(INSTRUMENT);

        let raw_price = (50.00 * FIXED_SCALAR) as PriceRaw;
        let record_batch = batch_of(
            &metadata,
            FixedSizeBinaryArray::from(vec![&raw_price.to_le_bytes()]),
            UInt64Array::from(vec![1]),
            UInt64Array::from(vec![2]),
        );

        // The batch schema still carries the key; only the map handed to the decoder lacks it.
        metadata.remove(KEY_INSTRUMENT_ID);

        let result = IndexPriceUpdate::decode_batch(&metadata, record_batch);
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("instrument_id"),
            "Expected missing instrument_id error, got: {err}"
        );
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from(INSTRUMENT);
        let metadata = metadata_for(&instrument_id.to_string());

        let original = vec![
            update_of(instrument_id, "50000.00", 1_000_000_000, 1_000_000_001),
            update_of(instrument_id, "51000.00", 2_000_000_000, 2_000_000_001),
        ];
        let record_batch = IndexPriceUpdate::encode_batch(&metadata, &original).unwrap();
        let decoded = IndexPriceUpdate::decode_batch(&metadata, record_batch).unwrap();

        assert_eq!(decoded.len(), original.len());
        for (expected, actual) in original.iter().zip(decoded.iter()) {
            assert_eq!(actual.instrument_id, expected.instrument_id);
            assert_eq!(actual.value, expected.value);
            assert_eq!(actual.ts_event, expected.ts_event);
            assert_eq!(actual.ts_init, expected.ts_init);
        }
    }
}