1use crate::decode::{AsyncRecordDecoder, RecordDecoder};
2use crate::record_enum::RecordEnum;
3use futures::stream::Stream;
4use std::future::Future;
5use std::io::Read;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tokio::io::AsyncBufRead;
9
/// Synchronous iterator adapter over a [`RecordDecoder`] whose reader is
/// mutably borrowed for the lifetime `'a`.
pub struct DecoderIterator<'a, R> {
    /// Decoder wrapping the borrowed reader; each `next` call pulls one record from it.
    decoder: RecordDecoder<&'a mut R>,
}
13
14impl<'a, R: Read> DecoderIterator<'a, R> {
15 pub fn new(reader: &'a mut R) -> Self {
16 Self {
17 decoder: RecordDecoder::new(reader),
18 }
19 }
20}
21
22impl<'a, R: Read> Iterator for DecoderIterator<'a, R> {
23 type Item = std::io::Result<RecordEnum>;
24
25 fn next(&mut self) -> Option<Self::Item> {
26 match self.decoder.decode_ref() {
27 Ok(Some(record_ref)) => match RecordEnum::from_ref(record_ref) {
28 Ok(record) => Some(Ok(record)),
29 Err(_) => Some(Err(std::io::Error::new(
30 std::io::ErrorKind::InvalidData,
31 "Failed to convert record reference to RecordEnum",
32 ))),
33 },
34 Ok(None) => None,
35 Err(e) => Some(Err(e)),
36 }
37 }
38}
39
/// Asynchronous counterpart of the synchronous decoder iterator: wraps an
/// [`AsyncRecordDecoder`] whose reader is mutably borrowed for the lifetime `'a`.
pub struct AsyncDecoderIterator<'a, R> {
    /// Async decoder wrapping the borrowed reader; polled once per `poll_next` call.
    decoder: AsyncRecordDecoder<&'a mut R>,
}
43
44impl<'a, R: AsyncBufRead + Unpin> AsyncDecoderIterator<'a, R> {
45 pub fn new(reader: &'a mut R) -> Self {
46 Self {
47 decoder: AsyncRecordDecoder::new(reader),
48 }
49 }
50}
51
52impl<'a, R: AsyncBufRead + Unpin> Stream for AsyncDecoderIterator<'a, R> {
53 type Item = std::io::Result<RecordEnum>;
54
55 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
56 let fut = self.decoder.decode_ref();
58 let mut fut = Box::pin(fut); match Future::poll(fut.as_mut(), cx) {
61 Poll::Ready(Ok(Some(record_ref))) => {
62 match RecordEnum::from_ref(record_ref) {
64 Ok(record) => Poll::Ready(Some(Ok(record))),
65 Err(_) => Poll::Ready(Some(Err(std::io::Error::new(
66 std::io::ErrorKind::InvalidData,
67 "Failed to convert record reference to RecordEnum",
68 )))),
69 }
70 }
71 Poll::Ready(Ok(None)) => Poll::Ready(None),
72 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
73 Poll::Pending => Poll::Pending,
74 }
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decode::*;
    use crate::encode::{CombinedEncoder, RecordEncoder};
    use crate::enums::Dataset;
    use crate::enums::Schema;
    use crate::metadata::Metadata;
    use crate::record_enum::RecordEnum;
    use crate::record_ref::*;
    use crate::records::BidAskPair;
    use crate::records::Mbp1Msg;
    use crate::records::OhlcvMsg;
    use crate::records::RecordHeader;
    use crate::symbols::SymbolMap;
    use futures::stream::StreamExt;
    use serial_test::serial;
    use std::io::Cursor;
    use std::path::PathBuf;

    /// Builds the MBP-1 message used as content of the on-disk fixture.
    fn sample_mbp1_msg() -> Mbp1Msg {
        Mbp1Msg {
            hd: RecordHeader::new::<Mbp1Msg>(1, 1622471124, 0),
            price: 12345676543,
            size: 1234543,
            action: 0,
            side: 0,
            depth: 0,
            flags: 0,
            ts_recv: 1231,
            ts_in_delta: 123432,
            sequence: 23432,
            discriminator: 0,
            levels: [BidAskPair {
                bid_px: 10000000,
                ask_px: 200000,
                bid_sz: 3000000,
                ask_sz: 400000000,
                bid_ct: 50000000,
                ask_ct: 60000000,
            }],
        }
    }

    /// Builds the two distinct OHLCV messages used by the in-memory tests.
    fn sample_ohlcv_msgs() -> (OhlcvMsg, OhlcvMsg) {
        let first = OhlcvMsg {
            hd: RecordHeader::new::<OhlcvMsg>(1, 1622471124, 0),
            open: 100,
            high: 200,
            low: 50,
            close: 150,
            volume: 1000,
        };
        let second = OhlcvMsg {
            hd: RecordHeader::new::<OhlcvMsg>(2, 1622471125, 0),
            open: 110,
            high: 210,
            low: 55,
            close: 155,
            volume: 1100,
        };
        (first, second)
    }

    /// Encodes both messages, in order, into a fresh in-memory buffer.
    fn encode_ohlcv_msgs(first: &OhlcvMsg, second: &OhlcvMsg) -> Vec<u8> {
        let mut buffer = Vec::new();
        {
            // Scope the encoder so its borrow of `buffer` ends before the buffer is returned.
            let mut encoder = RecordEncoder::new(&mut buffer);
            let record_ref1: RecordRef = first.into();
            let record_ref2: RecordRef = second.into();
            encoder
                .encode_records(&[record_ref1, record_ref2])
                .expect("Encoding failed");
        }
        buffer
    }

    /// Writes a fixture file holding metadata plus two MBP-1 records and
    /// returns its path.
    ///
    /// Panics if encoding or writing the file fails, so a broken fixture is
    /// reported at its source instead of as a confusing decode failure later.
    async fn create_test_file() -> anyhow::Result<PathBuf> {
        let mut symbol_map = SymbolMap::new();
        symbol_map.add_instrument("AAPL", 1);
        symbol_map.add_instrument("TSLA", 2);

        let metadata = Metadata::new(
            Schema::Mbp1,
            Dataset::Option,
            1234567898765,
            123456765432,
            symbol_map,
        );

        let msg1 = sample_mbp1_msg();
        let msg2 = sample_mbp1_msg();

        let record_ref1: RecordRef = (&msg1).into();
        let record_ref2: RecordRef = (&msg2).into();
        let records = &[record_ref1, record_ref2];

        let mut buffer = Vec::new();
        let mut encoder = CombinedEncoder::new(&mut buffer);
        encoder
            .encode(&metadata, records)
            .expect("Error on encoding");

        let file = PathBuf::from("tests/test_decode_iter.bin");
        encoder
            .write_to_file(&file, false)
            .expect("Failed to write the test file.");

        Ok(file)
    }

    /// Removes the fixture file if it exists; panics if the removal fails.
    async fn delete_test_file(file: PathBuf) -> anyhow::Result<()> {
        if file.exists() {
            std::fs::remove_file(&file).expect("Failed to delete the test file.");
        }
        Ok(())
    }

    /// Decodes the on-disk fixture through the synchronous iterator.
    #[tokio::test]
    #[serial]
    async fn test_record_decoder_iter() -> anyhow::Result<()> {
        let file_path = create_test_file().await?;

        // Scope the decoder so the file handle is released before the fixture is deleted.
        let all_records = {
            let mut decoder =
                Decoder::<std::io::BufReader<std::fs::File>>::from_file(file_path.clone())?;

            let mut records = Vec::new();
            for record_result in decoder.decode_iterator() {
                // A decode error must fail the test rather than be printed and ignored.
                match record_result.expect("Failed to decode record") {
                    RecordEnum::Mbp1(msg) => records.push(msg),
                    other => panic!("Unexpected record type: {:?}", other),
                }
            }
            records
        };

        assert!(!all_records.is_empty());

        delete_test_file(file_path).await?;

        Ok(())
    }

    /// Round-trips two OHLCV records through an in-memory buffer with the
    /// synchronous iterator and checks both content and count.
    #[test]
    #[serial]
    fn test_iter_decode() {
        let (ohlcv_msg1, ohlcv_msg2) = sample_ohlcv_msgs();
        let buffer = encode_ohlcv_msgs(&ohlcv_msg1, &ohlcv_msg2);

        let mut decoder = RecordDecoder::new(Cursor::new(buffer));
        let decoded: Vec<RecordEnum> = decoder
            .decode_iterator()
            .map(|record| record.expect("Failed to decode record"))
            .collect();

        // Comparing the full sequence also catches an iterator that yields
        // too few (or too many) records.
        assert_eq!(
            decoded,
            vec![RecordEnum::Ohlcv(ohlcv_msg1), RecordEnum::Ohlcv(ohlcv_msg2)]
        );
    }

    /// Decodes the on-disk fixture through the asynchronous stream.
    #[tokio::test]
    #[serial]
    async fn test_record_decoder_iter_async() -> anyhow::Result<()> {
        let file_path = create_test_file().await?;

        // Scope the decoder so the file handle is released before the fixture is deleted.
        let all_records = {
            let mut decoder =
                <AsyncDecoder<tokio::io::BufReader<tokio::fs::File>>>::from_file(file_path.clone())
                    .await?;
            let mut decode_iter = decoder.decode_iterator();

            let mut records = Vec::new();
            while let Some(record_result) = decode_iter.next().await {
                // A decode error must fail the test rather than be printed and ignored.
                match record_result.expect("Failed to decode record") {
                    RecordEnum::Mbp1(msg) => records.push(msg),
                    other => panic!("Unexpected record type: {:?}", other),
                }
            }
            records
        };

        assert!(!all_records.is_empty());

        delete_test_file(file_path).await?;

        Ok(())
    }

    /// Round-trips two OHLCV records through an in-memory buffer with the
    /// asynchronous stream and checks both content and count.
    #[tokio::test]
    #[serial]
    async fn test_iter_decode_async() -> anyhow::Result<()> {
        let (ohlcv_msg1, ohlcv_msg2) = sample_ohlcv_msgs();
        let buffer = encode_ohlcv_msgs(&ohlcv_msg1, &ohlcv_msg2);

        let mut decoder = AsyncRecordDecoder::new(Cursor::new(buffer));
        let mut iter = decoder.decode_iterator();

        let mut decoded = Vec::new();
        while let Some(record) = iter.next().await {
            decoded.push(record.expect("Failed to decode record"));
        }

        // Comparing the full sequence also catches a stream that yields
        // too few (or too many) records.
        assert_eq!(
            decoded,
            vec![RecordEnum::Ohlcv(ohlcv_msg1), RecordEnum::Ohlcv(ohlcv_msg2)]
        );
        Ok(())
    }
}