mbe/
decode_iterator.rs

1use crate::decode::{AsyncRecordDecoder, RecordDecoder};
2use crate::record_enum::RecordEnum;
3use futures::stream::Stream;
4use std::future::Future;
5use std::io::Read;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tokio::io::AsyncBufRead;
9
10pub struct DecoderIterator<'a, R> {
11    decoder: RecordDecoder<&'a mut R>,
12}
13
14impl<'a, R: Read> DecoderIterator<'a, R> {
15    pub fn new(reader: &'a mut R) -> Self {
16        Self {
17            decoder: RecordDecoder::new(reader),
18        }
19    }
20}
21
22impl<'a, R: Read> Iterator for DecoderIterator<'a, R> {
23    type Item = std::io::Result<RecordEnum>;
24
25    fn next(&mut self) -> Option<Self::Item> {
26        match self.decoder.decode_ref() {
27            Ok(Some(record_ref)) => match RecordEnum::from_ref(record_ref) {
28                Ok(record) => Some(Ok(record)),
29                Err(_) => Some(Err(std::io::Error::new(
30                    std::io::ErrorKind::InvalidData,
31                    "Failed to convert record reference to RecordEnum",
32                ))),
33            },
34            Ok(None) => None,
35            Err(e) => Some(Err(e)),
36        }
37    }
38}
39
40pub struct AsyncDecoderIterator<'a, R> {
41    decoder: AsyncRecordDecoder<&'a mut R>,
42}
43
44impl<'a, R: AsyncBufRead + Unpin> AsyncDecoderIterator<'a, R> {
45    pub fn new(reader: &'a mut R) -> Self {
46        Self {
47            decoder: AsyncRecordDecoder::new(reader),
48        }
49    }
50}
51
52impl<'a, R: AsyncBufRead + Unpin> Stream for AsyncDecoderIterator<'a, R> {
53    type Item = std::io::Result<RecordEnum>;
54
55    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
56        // Poll for the next record asynchronously
57        let fut = self.decoder.decode_ref();
58        let mut fut = Box::pin(fut); // Pin the future
59
60        match Future::poll(fut.as_mut(), cx) {
61            Poll::Ready(Ok(Some(record_ref))) => {
62                // If the record_ref is decoded successfully, convert it to RecordEnum
63                match RecordEnum::from_ref(record_ref) {
64                    Ok(record) => Poll::Ready(Some(Ok(record))),
65                    Err(_) => Poll::Ready(Some(Err(std::io::Error::new(
66                        std::io::ErrorKind::InvalidData,
67                        "Failed to convert record reference to RecordEnum",
68                    )))),
69                }
70            }
71            Poll::Ready(Ok(None)) => Poll::Ready(None),
72            Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
73            Poll::Pending => Poll::Pending,
74        }
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decode::*;
    use crate::encode::{CombinedEncoder, RecordEncoder};
    use crate::enums::Dataset;
    use crate::enums::Schema;
    use crate::metadata::Metadata;
    use crate::record_enum::RecordEnum;
    use crate::record_ref::*;
    use crate::records::BidAskPair;
    use crate::records::Mbp1Msg;
    use crate::records::OhlcvMsg;
    use crate::records::RecordHeader;
    use crate::symbols::SymbolMap;
    use futures::stream::StreamExt;
    use serial_test::serial;
    use std::io::Cursor;
    use std::path::{Path, PathBuf};

    /// Fixed on-disk fixture location; tests that touch it are `#[serial]`.
    const TEST_FILE: &str = "tests/test_decode_iter.bin";

    /// The MBP-1 message written (twice) into the on-disk fixture.
    fn sample_mbp1() -> Mbp1Msg {
        Mbp1Msg {
            hd: RecordHeader::new::<Mbp1Msg>(1, 1622471124, 0),
            price: 12345676543,
            size: 1234543,
            action: 0,
            side: 0,
            depth: 0,
            flags: 0,
            ts_recv: 1231,
            ts_in_delta: 123432,
            sequence: 23432,
            discriminator: 0,
            levels: [BidAskPair {
                bid_px: 10000000,
                ask_px: 200000,
                bid_sz: 3000000,
                ask_sz: 400000000,
                bid_ct: 50000000,
                ask_ct: 60000000,
            }],
        }
    }

    /// Records a decoder is expected to yield for the file written by `create_test_file`.
    fn expected_fixture_records() -> Vec<RecordEnum> {
        vec![
            RecordEnum::Mbp1(sample_mbp1()),
            RecordEnum::Mbp1(sample_mbp1()),
        ]
    }

    /// Writes metadata plus two MBP-1 records to `TEST_FILE` and returns its path.
    fn create_test_file() -> anyhow::Result<PathBuf> {
        // Metadata
        let mut symbol_map = SymbolMap::new();
        symbol_map.add_instrument("AAPL", 1);
        symbol_map.add_instrument("TSLA", 2);

        let metadata = Metadata::new(
            Schema::Mbp1,
            Dataset::Option,
            1234567898765,
            123456765432,
            symbol_map,
        );

        // Records
        let msg1 = sample_mbp1();
        let msg2 = sample_mbp1();
        let record_ref1: RecordRef = (&msg1).into();
        let record_ref2: RecordRef = (&msg2).into();
        let records = &[record_ref1, record_ref2];

        let mut buffer = Vec::new();
        let mut encoder = CombinedEncoder::new(&mut buffer);
        encoder
            .encode(&metadata, records)
            .expect("Error on encoding");

        let file = PathBuf::from(TEST_FILE);
        // Fail here rather than later with a confusing "wrong records" assertion.
        encoder
            .write_to_file(&file, false)
            .expect("Error writing test file");

        Ok(file)
    }

    /// Removes the fixture file if it exists.
    fn delete_test_file(file: &Path) -> anyhow::Result<()> {
        if file.exists() {
            std::fs::remove_file(file)?;
        }
        Ok(())
    }

    /// Encodes two OHLCV records and returns the raw bytes together with the
    /// records a decoder is expected to yield for them, in order.
    fn encoded_ohlcv_fixture() -> (Vec<u8>, Vec<RecordEnum>) {
        let ohlcv_msg1 = OhlcvMsg {
            hd: RecordHeader::new::<OhlcvMsg>(1, 1622471124, 0),
            open: 100,
            high: 200,
            low: 50,
            close: 150,
            volume: 1000,
        };

        let ohlcv_msg2 = OhlcvMsg {
            hd: RecordHeader::new::<OhlcvMsg>(2, 1622471125, 0),
            open: 110,
            high: 210,
            low: 55,
            close: 155,
            volume: 1100,
        };

        let mut buffer = Vec::new();
        {
            let mut encoder = RecordEncoder::new(&mut buffer);
            let record_ref1: RecordRef = (&ohlcv_msg1).into();
            let record_ref2: RecordRef = (&ohlcv_msg2).into();
            encoder
                .encode_records(&[record_ref1, record_ref2])
                .expect("Encoding failed");
        }

        let expected = vec![RecordEnum::Ohlcv(ohlcv_msg1), RecordEnum::Ohlcv(ohlcv_msg2)];
        (buffer, expected)
    }

    /// Unwraps every decoded result, failing the test on the first error
    /// (previously errors were only printed, letting the tests pass vacuously).
    fn unwrap_all<E: std::fmt::Debug>(results: Vec<Result<RecordEnum, E>>) -> Vec<RecordEnum> {
        results
            .into_iter()
            .map(|result| result.expect("Error processing record"))
            .collect()
    }

    // -- Sync --
    #[test]
    #[serial]
    fn test_record_decoder_iter() -> anyhow::Result<()> {
        let file_path = create_test_file()?;

        let mut decoder =
            Decoder::<std::io::BufReader<std::fs::File>>::from_file(file_path.clone())?;
        let results: Vec<_> = decoder.decode_iterator().collect();
        // Close the file before deleting it, and clean up before asserting so a
        // failed assertion does not leave the fixture behind.
        drop(decoder);
        delete_test_file(&file_path)?;

        assert_eq!(unwrap_all(results), expected_fixture_records());
        Ok(())
    }

    #[test]
    #[serial]
    fn test_iter_decode() {
        let (buffer, expected) = encoded_ohlcv_fixture();

        let mut decoder = RecordDecoder::new(Cursor::new(buffer));
        let results: Vec<_> = decoder.decode_iterator().collect();

        assert_eq!(unwrap_all(results), expected);
    }

    // -- Async --
    #[tokio::test]
    #[serial]
    async fn test_record_decoder_iter_async() -> anyhow::Result<()> {
        let file_path = create_test_file()?;

        let mut decoder =
            <AsyncDecoder<tokio::io::BufReader<tokio::fs::File>>>::from_file(file_path.clone())
                .await?;
        let results: Vec<_> = decoder.decode_iterator().collect().await;
        // Close the file before deleting it, and clean up before asserting so a
        // failed assertion does not leave the fixture behind.
        drop(decoder);
        delete_test_file(&file_path)?;

        assert_eq!(unwrap_all(results), expected_fixture_records());
        Ok(())
    }

    #[tokio::test]
    #[serial]
    async fn test_iter_decode_async() -> anyhow::Result<()> {
        let (buffer, expected) = encoded_ohlcv_fixture();

        let mut decoder = AsyncRecordDecoder::new(Cursor::new(buffer));
        let results: Vec<_> = decoder.decode_iterator().collect().await;

        assert_eq!(unwrap_all(results), expected);
        Ok(())
    }
}