mbinary/
encode.rs

1use crate::metadata::Metadata;
2use crate::record_ref::*;
3use std::fs::OpenOptions;
4use std::io::{self, Write};
5use std::path::Path;
6use tokio::io::{AsyncWrite, AsyncWriteExt};
7
8pub struct CombinedEncoder<W> {
9    writer: W,
10}
11
12impl<W: Write> CombinedEncoder<W> {
13    pub fn new(writer: W) -> Self {
14        CombinedEncoder { writer }
15    }
16
17    pub fn encode_metadata(&mut self, metadata: &Metadata) -> io::Result<()> {
18        let mut metadata_encoder = MetadataEncoder::new(&mut self.writer);
19        metadata_encoder.encode_metadata(metadata)
20    }
21
22    pub fn encode_record(&mut self, record: &RecordRef) -> io::Result<()> {
23        let mut record_encoder = RecordEncoder::new(&mut self.writer);
24        record_encoder.encode_record(record)
25    }
26
27    pub fn encode_records(&mut self, records: &[RecordRef]) -> io::Result<()> {
28        let mut record_encoder = RecordEncoder::new(&mut self.writer);
29        record_encoder.encode_records(records)
30    }
31
32    pub fn encode(&mut self, metadata: &Metadata, records: &[RecordRef]) -> io::Result<()> {
33        self.encode_metadata(metadata)?;
34        self.encode_records(records)?;
35        Ok(())
36    }
37
38    pub fn write_to_file(&self, file_path: &Path, append: bool) -> io::Result<()>
39    where
40        W: AsRef<[u8]>,
41    {
42        let mut options = OpenOptions::new();
43        options.create(true);
44
45        if append {
46            options.append(true);
47        } else {
48            options.write(true).truncate(true);
49        }
50
51        let mut file = options.open(file_path)?;
52
53        file.write_all(self.writer.as_ref())?;
54        file.flush()?;
55        Ok(())
56    }
57}
58
59pub struct MetadataEncoder<W> {
60    writer: W,
61    // buffer: Vec<u8>,
62}
63
64impl<W: Write> MetadataEncoder<W> {
65    pub fn new(writer: W) -> Self {
66        MetadataEncoder { writer }
67    }
68
69    pub fn encode_metadata(&mut self, metadata: &Metadata) -> io::Result<()> {
70        let bytes = metadata.serialize();
71
72        // Calculate and prepend the length
73        let length: u16 = bytes.len() as u16;
74        let mut buffer = Vec::with_capacity(length as usize + 2);
75
76        // Add length as the first 2 bytes
77        buffer.extend_from_slice(&length.to_le_bytes());
78        buffer.extend_from_slice(&bytes);
79
80        // Write the buffer to the writer
81        self.writer.write_all(&buffer)?;
82        self.writer.flush()?;
83        Ok(())
84
85        // self.buffer[..serialized.len()].copy_from_slice(&serialized);
86        // self.writer.write_all(&self.buffer)?;
87        // self.writer.flush()?;
88        // Ok(())
89    }
90
91    pub fn write_to_file(&self, file_path: &Path, append: bool) -> io::Result<()>
92    where
93        W: AsRef<[u8]>,
94    {
95        let mut options = OpenOptions::new();
96        options.create(true);
97
98        if append {
99            options.append(true);
100        } else {
101            options.write(true).truncate(true);
102        }
103
104        let mut file = options.open(file_path)?;
105
106        file.write_all(self.writer.as_ref())?;
107        file.flush()?;
108        Ok(())
109    }
110}
111
112pub struct RecordEncoder<W> {
113    writer: W,
114}
115
116impl<W: Write> RecordEncoder<W> {
117    pub fn new(writer: W) -> Self {
118        RecordEncoder { writer }
119    }
120
121    pub async fn flush(&mut self) -> tokio::io::Result<()> {
122        self.writer.flush()?;
123        Ok(())
124    }
125
126    pub fn encode_record(&mut self, record: &RecordRef) -> io::Result<()> {
127        let bytes = record.as_ref();
128        self.writer.write_all(bytes)?;
129        Ok(())
130    }
131
132    pub fn encode_records(&mut self, records: &[RecordRef]) -> io::Result<()> {
133        for record in records {
134            self.encode_record(record)?;
135        }
136        self.writer.flush()?;
137        Ok(())
138    }
139
140    pub fn write_to_file(&self, file_path: &Path, append: bool) -> io::Result<()>
141    where
142        W: AsRef<[u8]>,
143    {
144        let mut options = OpenOptions::new();
145        options.create(true);
146
147        if append {
148            options.append(true);
149        } else {
150            options.write(true).truncate(true);
151        }
152
153        let mut file = options.open(file_path)?;
154
155        file.write_all(self.writer.as_ref())?;
156        file.flush()?;
157        Ok(())
158    }
159}
160
// -- Async --
162
163pub struct AsyncRecordEncoder<W> {
164    writer: W,
165}
166
167impl<W> AsyncRecordEncoder<W>
168where
169    W: AsyncWrite + Unpin,
170{
171    pub fn new(writer: W) -> Self {
172        AsyncRecordEncoder { writer }
173    }
174
175    pub async fn flush(&mut self) -> tokio::io::Result<()> {
176        self.writer.flush().await?;
177        Ok(())
178    }
179
180    pub async fn encode_record<'a>(&mut self, record: &'a RecordRef<'a>) -> tokio::io::Result<()> {
181        let bytes = record.as_ref();
182        self.writer.write_all(bytes).await?;
183        Ok(())
184    }
185
186    pub async fn encode_records<'a>(
187        &mut self,
188        records: &'a [RecordRef<'a>],
189    ) -> tokio::io::Result<()> {
190        for record in records {
191            self.encode_record(record).await?;
192        }
193        self.writer.flush().await?;
194        Ok(())
195    }
196    pub async fn write_to_file(
197        file_path: &Path,
198        append: bool,
199        buffer: &[u8],
200    ) -> tokio::io::Result<()> {
201        let mut options = tokio::fs::OpenOptions::new();
202        options.create(true);
203
204        if append {
205            options.append(true);
206        } else {
207            options.write(true).truncate(true);
208        }
209
210        let mut file = options.open(file_path).await?;
211
212        file.write_all(buffer).await?;
213        file.flush().await?;
214        Ok(())
215    }
216}
217
218#[cfg(test)]
219mod tests {
220    use serial_test::serial;
221
222    use super::*;
223    use crate::decode::AsyncDecoder;
224    use crate::enums::Dataset;
225    use crate::enums::Schema;
226    use crate::record_enum::RecordEnum;
227    use crate::records::BidAskPair;
228    use crate::records::Mbp1Msg;
229    use crate::records::OhlcvMsg;
230    use crate::records::RecordHeader;
231    use crate::symbols::SymbolMap;
232    use std::io::Cursor;
233    use std::path::PathBuf;
234
235    #[tokio::test]
236    async fn test_async_encode_record() -> anyhow::Result<()> {
237        let ohlcv_msg = OhlcvMsg {
238            hd: RecordHeader::new::<OhlcvMsg>(1, 1622471124, 0),
239            open: 100,
240            high: 200,
241            low: 50,
242            close: 150,
243            volume: 1000,
244        };
245        let record_ref: RecordRef = (&ohlcv_msg).into();
246
247        // Test
248        let mut buffer = Vec::new();
249        let mut encoder = AsyncRecordEncoder::new(&mut buffer);
250        encoder
251            .encode_record(&record_ref)
252            .await
253            .expect("Encoding failed");
254
255        // Validate
256        let cursor = Cursor::new(buffer);
257        let mut decoder = AsyncDecoder::new(cursor).await?;
258        let record_ref = decoder.decode_ref().await?.unwrap();
259        let decoded_record: &OhlcvMsg = record_ref.get().unwrap();
260        assert_eq!(decoded_record, &ohlcv_msg);
261
262        Ok(())
263    }
264
265    #[tokio::test]
266    async fn test_async_encode_records() -> anyhow::Result<()> {
267        let ohlcv_msg1 = OhlcvMsg {
268            hd: RecordHeader::new::<OhlcvMsg>(1, 1622471124, 0),
269            open: 100000000000,
270            high: 200000000000,
271            low: 50000000000,
272            close: 150000000000,
273            volume: 1000,
274        };
275
276        let ohlcv_msg2 = OhlcvMsg {
277            hd: RecordHeader::new::<OhlcvMsg>(2, 1622471125, 0),
278            open: 110000000000,
279            high: 210000000000,
280            low: 55000000000,
281            close: 155000000000,
282            volume: 1100,
283        };
284
285        let record_ref1: RecordRef = (&ohlcv_msg1).into();
286        let record_ref2: RecordRef = (&ohlcv_msg2).into();
287
288        // Test
289        let mut buffer = Vec::new();
290        let mut encoder = AsyncRecordEncoder::new(&mut buffer);
291        encoder
292            .encode_records(&[record_ref1, record_ref2])
293            .await
294            .expect("Encoding failed");
295        // println!("{:?}", buffer);
296
297        // Validate
298        let cursor = Cursor::new(buffer);
299        let mut decoder = AsyncDecoder::new(cursor).await?;
300        let decoded_records = decoder.decode().await?;
301
302        assert_eq!(decoded_records.len(), 2);
303        assert_eq!(decoded_records[0], RecordEnum::Ohlcv(ohlcv_msg1));
304        assert_eq!(decoded_records[1], RecordEnum::Ohlcv(ohlcv_msg2));
305
306        Ok(())
307    }
308
309    #[tokio::test]
310    async fn test_encode_record() -> anyhow::Result<()> {
311        let ohlcv_msg = OhlcvMsg {
312            hd: RecordHeader::new::<OhlcvMsg>(1, 1622471124, 0),
313            open: 100,
314            high: 200,
315            low: 50,
316            close: 150,
317            volume: 1000,
318        };
319        let record_ref: RecordRef = (&ohlcv_msg).into();
320
321        // Test
322        let mut buffer = Vec::new();
323        let mut encoder = RecordEncoder::new(&mut buffer);
324        encoder.encode_record(&record_ref).expect("Encoding failed");
325
326        // Validate
327        let cursor = Cursor::new(buffer);
328        let mut decoder = AsyncDecoder::new(cursor).await?;
329        let record_ref = decoder.decode_ref().await?.unwrap();
330        let decoded_record: &OhlcvMsg = record_ref.get().unwrap();
331        assert_eq!(decoded_record, &ohlcv_msg);
332
333        Ok(())
334    }
335
336    #[tokio::test]
337    async fn test_encode_decode_records() -> anyhow::Result<()> {
338        let ohlcv_msg1 = OhlcvMsg {
339            hd: RecordHeader::new::<OhlcvMsg>(1, 1622471124, 0),
340            open: 100000000000,
341            high: 200000000000,
342            low: 50000000000,
343            close: 150000000000,
344            volume: 1000,
345        };
346
347        let ohlcv_msg2 = OhlcvMsg {
348            hd: RecordHeader::new::<OhlcvMsg>(2, 1622471125, 0),
349            open: 110000000000,
350            high: 210000000000,
351            low: 55000000000,
352            close: 155000000000,
353            volume: 1100,
354        };
355
356        let record_ref1: RecordRef = (&ohlcv_msg1).into();
357        let record_ref2: RecordRef = (&ohlcv_msg2).into();
358
359        // Test
360        let mut buffer = Vec::new();
361        let mut encoder = RecordEncoder::new(&mut buffer);
362        encoder
363            .encode_records(&[record_ref1, record_ref2])
364            .expect("Encoding failed");
365        // println!("{:?}", buffer);
366
367        // Validate
368        let cursor = Cursor::new(buffer);
369        let mut decoder = AsyncDecoder::new(cursor).await?;
370        let decoded_records = decoder.decode().await?;
371
372        assert_eq!(decoded_records.len(), 2);
373        assert_eq!(decoded_records[0], RecordEnum::Ohlcv(ohlcv_msg1));
374        assert_eq!(decoded_records[1], RecordEnum::Ohlcv(ohlcv_msg2));
375
376        Ok(())
377    }
378
379    #[tokio::test]
380    async fn test_encode_metadata() -> anyhow::Result<()> {
381        let mut symbol_map = SymbolMap::new();
382        symbol_map.add_instrument("AAPL", 1);
383        symbol_map.add_instrument("TSLA", 2);
384
385        let metadata = Metadata::new(
386            Schema::Ohlcv1S,
387            Dataset::Equities,
388            1234567898765,
389            123456765432,
390            symbol_map,
391        );
392
393        // Test
394        let mut buffer = Vec::new();
395        let mut encoder = MetadataEncoder::new(&mut buffer);
396        encoder
397            .encode_metadata(&metadata)
398            .expect("Error metadata encoding.");
399
400        // Validate
401        let length_buffer: [u8; 2] = buffer[..2].try_into()?;
402        let metadata_length = u16::from_le_bytes(length_buffer) as usize;
403        let bytes = &buffer[2..2 + metadata_length];
404        let decoded = Metadata::deserialize(&bytes)?;
405        assert_eq!(decoded.schema, metadata.schema);
406        assert_eq!(decoded.start, metadata.start);
407        assert_eq!(decoded.end, metadata.end);
408        assert_eq!(decoded.mappings, metadata.mappings);
409        Ok(())
410    }
411
412    #[test]
413    fn test_encode() {
414        // Metadata
415        let mut symbol_map = SymbolMap::new();
416        symbol_map.add_instrument("AAPL", 1);
417        symbol_map.add_instrument("TSLA", 2);
418
419        let metadata = Metadata::new(
420            Schema::Ohlcv1S,
421            Dataset::Equities,
422            1234567898765,
423            123456765432,
424            symbol_map,
425        );
426
427        // Record
428        let ohlcv_msg1 = OhlcvMsg {
429            hd: RecordHeader::new::<OhlcvMsg>(1, 1724287878000000000, 0),
430            open: 100000000000,
431            high: 200000000000,
432            low: 50000000000,
433            close: 150000000000,
434            volume: 1000000000000,
435        };
436
437        let ohlcv_msg2 = OhlcvMsg {
438            hd: RecordHeader::new::<OhlcvMsg>(2, 1724289878000000000, 0),
439            open: 110000000000,
440            high: 210000000000,
441            low: 55000000000,
442            close: 155000000000,
443            volume: 1100000000000,
444        };
445
446        let record_ref1: RecordRef = (&ohlcv_msg1).into();
447        let record_ref2: RecordRef = (&ohlcv_msg2).into();
448        let records = &[record_ref1, record_ref2];
449
450        // Test
451        let mut buffer = Vec::new();
452        let mut encoder = CombinedEncoder::new(&mut buffer);
453        encoder
454            .encode(&metadata, records)
455            .expect("Error on encoding");
456
457        // Validate
458        assert!(buffer.len() > 0);
459    }
460
461    #[tokio::test]
462    #[serial]
463    async fn test_encode_metadata_and_records_seperate_to_same_file() -> anyhow::Result<()> {
464        // Metadata
465        let mut symbol_map = SymbolMap::new();
466        symbol_map.add_instrument("AAPL", 1);
467        symbol_map.add_instrument("TSLA", 2);
468
469        let metadata = Metadata::new(
470            Schema::Ohlcv1S,
471            Dataset::Equities,
472            1234567898765,
473            123456765432,
474            symbol_map,
475        );
476
477        // Record
478        let ohlcv_msg1 = OhlcvMsg {
479            hd: RecordHeader::new::<OhlcvMsg>(1, 1724287878000000000, 0),
480            open: 100000000000,
481            high: 200000000000,
482            low: 50000000000,
483            close: 150000000000,
484            volume: 1000000000000,
485        };
486
487        let ohlcv_msg2 = OhlcvMsg {
488            hd: RecordHeader::new::<OhlcvMsg>(2, 1724289878000000000, 0),
489            open: 110000000000,
490            high: 210000000000,
491            low: 55000000000,
492            close: 155000000000,
493            volume: 1100000000000,
494        };
495
496        let record_ref1: RecordRef = (&ohlcv_msg1).into();
497        let record_ref2: RecordRef = (&ohlcv_msg2).into();
498        let records = &[record_ref1, record_ref2];
499
500        // Test
501        let file = PathBuf::from("tests/mbp_encoded_seperatly.bin");
502        let mut buffer = Vec::new();
503        let mut m_encoder = MetadataEncoder::new(&mut buffer);
504        m_encoder.encode_metadata(&metadata)?;
505        let _ = m_encoder.write_to_file(&file, true);
506
507        let mut buffer = Vec::new();
508        let mut r_encoder = RecordEncoder::new(&mut buffer);
509        r_encoder.encode_records(records)?;
510        let _ = r_encoder.write_to_file(&file, true);
511
512        // Validate
513        let mut decoder =
514            <AsyncDecoder<tokio::io::BufReader<tokio::fs::File>>>::from_file(file.clone()).await?;
515        let metadata_decoded = decoder.metadata().unwrap();
516        let records = decoder.decode().await?;
517        let expected = vec![
518            RecordEnum::from_ref(record_ref1)?,
519            RecordEnum::from_ref(record_ref2)?,
520        ];
521        assert!(metadata == metadata_decoded);
522        assert!(expected == records);
523
524        // Cleanup
525        if file.exists() {
526            std::fs::remove_file(&file).expect("Failed to delete the test file.");
527        }
528        Ok(())
529    }
530
531    #[tokio::test]
532    #[serial]
533    async fn test_encode_to_file_w_metadata() -> anyhow::Result<()> {
534        // Metadata
535        let mut symbol_map = SymbolMap::new();
536        symbol_map.add_instrument("AAPL", 1);
537        symbol_map.add_instrument("TSLA", 2);
538
539        let metadata = Metadata::new(
540            Schema::Mbp1,
541            Dataset::Futures,
542            1234567898765,
543            123456765432,
544            symbol_map,
545        );
546
547        // Record
548        let msg1 = Mbp1Msg {
549            hd: RecordHeader::new::<Mbp1Msg>(1, 1622471124, 0),
550            price: 12345676543,
551            size: 1234543,
552            action: 0,
553            side: 0,
554            depth: 0,
555            flags: 0,
556            ts_recv: 1231,
557            ts_in_delta: 123432,
558            sequence: 23432,
559            discriminator: 0,
560            levels: [BidAskPair {
561                bid_px: 10000000,
562                ask_px: 200000,
563                bid_sz: 3000000,
564                ask_sz: 400000000,
565                bid_ct: 50000000,
566                ask_ct: 60000000,
567            }],
568        };
569        let msg2 = Mbp1Msg {
570            hd: RecordHeader::new::<Mbp1Msg>(1, 1622471124, 0),
571            price: 12345676543,
572            size: 1234543,
573            action: 0,
574            side: 0,
575            depth: 0,
576            flags: 0,
577            ts_recv: 1231,
578            ts_in_delta: 123432,
579            sequence: 23432,
580            discriminator: 0,
581            levels: [BidAskPair {
582                bid_px: 10000000,
583                ask_px: 200000,
584                bid_sz: 3000000,
585                ask_sz: 400000000,
586                bid_ct: 50000000,
587                ask_ct: 60000000,
588            }],
589        };
590
591        let record_ref1: RecordRef = (&msg1).into();
592        let record_ref2: RecordRef = (&msg2).into();
593        let records = &[record_ref1, record_ref2];
594
595        let mut buffer = Vec::new();
596        let mut encoder = CombinedEncoder::new(&mut buffer);
597        encoder
598            .encode(&metadata, records)
599            .expect("Error on encoding");
600
601        // Test
602        let file = PathBuf::from("tests/mbp_w_metadata.bin");
603        let _ = encoder.write_to_file(&file, false);
604
605        // Validate
606        let mut decoder =
607            <AsyncDecoder<tokio::io::BufReader<tokio::fs::File>>>::from_file(file.clone()).await?;
608        let records = decoder.decode().await?;
609        let expected = vec![
610            RecordEnum::from_ref(record_ref1)?,
611            RecordEnum::from_ref(record_ref2)?,
612        ];
613
614        assert!(expected == records);
615
616        // Cleanup
617        if file.exists() {
618            std::fs::remove_file(&file).expect("Failed to delete the test file.");
619        }
620        Ok(())
621    }
622
623    #[tokio::test]
624    #[serial]
625    // #[ignore]
626    async fn test_encode_to_file_wout_metadata() -> anyhow::Result<()> {
627        // Record
628        let msg1 = Mbp1Msg {
629            hd: RecordHeader::new::<Mbp1Msg>(1, 1622471124, 0),
630            price: 12345676543,
631            size: 1234543,
632            action: 0,
633            side: 0,
634            depth: 0,
635            flags: 0,
636            ts_recv: 1231,
637            ts_in_delta: 123432,
638            sequence: 23432,
639            discriminator: 0,
640            levels: [BidAskPair {
641                bid_px: 10000000,
642                ask_px: 200000,
643                bid_sz: 3000000,
644                ask_sz: 400000000,
645                bid_ct: 50000000,
646                ask_ct: 60000000,
647            }],
648        };
649        let msg2 = Mbp1Msg {
650            hd: RecordHeader::new::<Mbp1Msg>(1, 1622471124, 0),
651            price: 12345676543,
652            size: 1234543,
653            action: 0,
654            side: 0,
655            depth: 0,
656            flags: 0,
657            ts_recv: 1231,
658            ts_in_delta: 123432,
659            sequence: 23432,
660            discriminator: 0,
661            levels: [BidAskPair {
662                bid_px: 10000000,
663                ask_px: 200000,
664                bid_sz: 3000000,
665                ask_sz: 400000000,
666                bid_ct: 50000000,
667                ask_ct: 60000000,
668            }],
669        };
670
671        let record_ref1: RecordRef = (&msg1).into();
672        let record_ref2: RecordRef = (&msg2).into();
673
674        let mut buffer = Vec::new();
675        let mut encoder = RecordEncoder::new(&mut buffer);
676        encoder
677            .encode_records(&[record_ref1, record_ref2])
678            .expect("Encoding failed");
679
680        // Test
681        let file = PathBuf::from("tests/mbp_wout_metadata.bin");
682        let _ = encoder.write_to_file(&file, false);
683
684        // Validate
685        let mut decoder =
686            <AsyncDecoder<tokio::io::BufReader<tokio::fs::File>>>::from_file(file.clone()).await?;
687        let records = decoder.decode().await?;
688        let expected = vec![
689            RecordEnum::from_ref(record_ref1)?,
690            RecordEnum::from_ref(record_ref2)?,
691        ];
692
693        assert!(expected == records);
694
695        // Cleanup
696        if file.exists() {
697            std::fs::remove_file(&file).expect("Failed to delete the test file.");
698        }
699        Ok(())
700    }
701}