// dbz_lib/write/dbz.rs

1use std::{
2    io::{self, SeekFrom, Write},
3    mem,
4    ops::Range,
5    slice,
6};
7
8use anyhow::{anyhow, Context};
9use databento_defs::record::ConstTypeId;
10use streaming_iterator::StreamingIterator;
11use zstd::{stream::AutoFinishEncoder, Encoder};
12
13use crate::{read::SymbolMapping, Metadata};
14
// The DBZ encoding version this library writes.
// NOTE(review): not referenced in this chunk — `Metadata::encode` writes
// `self.version` rather than this constant; confirm where this is consumed.
pub(crate) const SCHEMA_VERSION: u8 = 1;
16
17/// Create a new Zstd encoder with default settings
18fn new_encoder<'a, W: io::Write>(writer: W) -> anyhow::Result<AutoFinishEncoder<'a, W>> {
19    pub(crate) const ZSTD_COMPRESSION_LEVEL: i32 = 0;
20
21    let mut encoder = Encoder::new(writer, ZSTD_COMPRESSION_LEVEL)?;
22    encoder.include_checksum(true)?;
23    Ok(encoder.auto_finish())
24}
25
26impl Metadata {
27    pub(crate) const ZSTD_MAGIC_RANGE: Range<u32> = 0x184D2A50..0x184D2A60;
28    pub(crate) const VERSION_CSTR_LEN: usize = 4;
29    pub(crate) const DATASET_CSTR_LEN: usize = 16;
30    pub(crate) const RESERVED_LEN: usize = 39;
31    pub(crate) const FIXED_METADATA_LEN: usize = 96;
32    pub(crate) const SYMBOL_CSTR_LEN: usize = 22;
33
34    pub fn encode(&self, mut writer: impl io::Write + io::Seek) -> anyhow::Result<()> {
35        writer.write_all(Self::ZSTD_MAGIC_RANGE.start.to_le_bytes().as_slice())?;
36        // write placeholder frame size to filled in at the end
37        writer.write_all(b"0000")?;
38        writer.write_all(b"DBZ")?;
39        writer.write_all(&[self.version])?;
40        Self::encode_fixed_len_cstr::<_, { Self::DATASET_CSTR_LEN }>(&mut writer, &self.dataset)?;
41        writer.write_all((self.schema as u16).to_le_bytes().as_slice())?;
42        Self::encode_range_and_counts(
43            &mut writer,
44            self.start,
45            self.end,
46            self.limit,
47            self.record_count,
48        )?;
49        writer.write_all(&[self.compression as u8])?;
50        writer.write_all(&[self.stype_in as u8])?;
51        writer.write_all(&[self.stype_out as u8])?;
52        // padding
53        writer.write_all(&[0; Self::RESERVED_LEN])?;
54        {
55            // remaining metadata is compressed
56            let mut zstd_encoder = new_encoder(&mut writer)?;
57            // schema_definition_length
58            zstd_encoder.write_all(0u32.to_le_bytes().as_slice())?;
59
60            Self::encode_repeated_symbol_cstr(&mut zstd_encoder, self.symbols.as_slice())
61                .with_context(|| "Failed to encode symbols")?;
62            Self::encode_repeated_symbol_cstr(&mut zstd_encoder, self.partial.as_slice())
63                .with_context(|| "Failed to encode partial")?;
64            Self::encode_repeated_symbol_cstr(&mut zstd_encoder, self.not_found.as_slice())
65                .with_context(|| "Failed to encode not_found")?;
66            Self::encode_symbol_mappings(&mut zstd_encoder, self.mappings.as_slice())?;
67        }
68
69        let raw_size = writer.stream_position()?;
70        // go back and update the size now that we know it
71        writer.seek(SeekFrom::Start(4))?;
72        // magic number and size aren't included in the metadata size
73        let frame_size = (raw_size - 8) as u32;
74        writer.write_all(frame_size.to_le_bytes().as_slice())?;
75        // go back to end to leave `writer` in a place for more data to be written
76        writer.seek(SeekFrom::End(0))?;
77
78        Ok(())
79    }
80
81    pub fn update_encoded(
82        mut writer: impl io::Write + io::Seek,
83        start: u64,
84        end: u64,
85        limit: u64,
86        record_count: u64,
87    ) -> anyhow::Result<()> {
88        /// Byte position of the field `start`
89        const START_SEEK_FROM: SeekFrom =
90            SeekFrom::Start((8 + 4 + Metadata::DATASET_CSTR_LEN + 2) as u64);
91
92        writer
93            .seek(START_SEEK_FROM)
94            .with_context(|| "Failed to seek to write position".to_owned())?;
95        Self::encode_range_and_counts(&mut writer, start, end, limit, record_count)?;
96        writer
97            .seek(SeekFrom::End(0))
98            .with_context(|| "Failed to seek back to end".to_owned())?;
99        Ok(())
100    }
101
102    fn encode_range_and_counts(
103        writer: &mut impl io::Write,
104        start: u64,
105        end: u64,
106        limit: u64,
107        record_count: u64,
108    ) -> anyhow::Result<()> {
109        writer.write_all(start.to_le_bytes().as_slice())?;
110        writer.write_all(end.to_le_bytes().as_slice())?;
111        writer.write_all(limit.to_le_bytes().as_slice())?;
112        writer.write_all(record_count.to_le_bytes().as_slice())?;
113        Ok(())
114    }
115
116    fn encode_repeated_symbol_cstr(
117        writer: &mut impl io::Write,
118        symbols: &[String],
119    ) -> anyhow::Result<()> {
120        writer.write_all((symbols.len() as u32).to_le_bytes().as_slice())?;
121        for symbol in symbols {
122            Self::encode_fixed_len_cstr::<_, { Self::SYMBOL_CSTR_LEN }>(writer, symbol)?;
123        }
124
125        Ok(())
126    }
127
128    fn encode_symbol_mappings(
129        writer: &mut impl io::Write,
130        symbol_mappings: &[SymbolMapping],
131    ) -> anyhow::Result<()> {
132        // encode mappings_count
133        writer.write_all((symbol_mappings.len() as u32).to_le_bytes().as_slice())?;
134        for symbol_mapping in symbol_mappings {
135            Self::encode_symbol_mapping(writer, symbol_mapping)?;
136        }
137        Ok(())
138    }
139
140    fn encode_symbol_mapping(
141        writer: &mut impl io::Write,
142        symbol_mapping: &SymbolMapping,
143    ) -> anyhow::Result<()> {
144        Self::encode_fixed_len_cstr::<_, { Self::SYMBOL_CSTR_LEN }>(
145            writer,
146            &symbol_mapping.native,
147        )?;
148        // encode interval_count
149        writer.write_all(
150            (symbol_mapping.intervals.len() as u32)
151                .to_le_bytes()
152                .as_slice(),
153        )?;
154        for interval in symbol_mapping.intervals.iter() {
155            Self::encode_date(writer, interval.start_date)?;
156            Self::encode_date(writer, interval.end_date)?;
157            Self::encode_fixed_len_cstr::<_, { Self::SYMBOL_CSTR_LEN }>(writer, &interval.symbol)?;
158        }
159        Ok(())
160    }
161
162    // Can't specify const generic with impl trait until Rust 1.63, see
163    // https://github.com/rust-lang/rust/issues/83701
164    fn encode_fixed_len_cstr<W: io::Write, const LEN: usize>(
165        writer: &mut W,
166        string: &str,
167    ) -> anyhow::Result<()> {
168        if !string.is_ascii() {
169            return Err(anyhow!(
170                "'{string}' can't be encoded in DBZ because it contains non-ASCII characters"
171            ));
172        }
173        if string.len() > LEN {
174            return Err(anyhow!(
175                "'{string}' is too long to be encoded in DBZ; it cannot be longer {LEN} characters"
176            ));
177        }
178        writer.write_all(string.as_bytes())?;
179        // pad remaining space with null bytes
180        for _ in string.len()..LEN {
181            writer.write_all(&[0])?;
182        }
183        Ok(())
184    }
185
186    fn encode_date(writer: &mut impl io::Write, date: time::Date) -> anyhow::Result<()> {
187        let mut date_int = date.year() as u32 * 10_000;
188        date_int += date.month() as u32 * 100;
189        date_int += date.day() as u32;
190        writer.write_all(date_int.to_le_bytes().as_slice())?;
191        Ok(())
192    }
193}
194
/// Reinterprets `data` as a byte slice of length `size_of::<T>()`.
///
/// # Safety
/// `T` must be plain old data: every byte of the value, including any padding,
/// must be initialized, and reading its raw bytes must not violate any of
/// `T`'s invariants.
unsafe fn as_u8_slice<T: Sized>(data: &T) -> &[u8] {
    slice::from_raw_parts(data as *const T as *const u8, mem::size_of::<T>())
}
198
199/// Incrementally serializes the records in `iter` in the DBZ format to `writer`.
200pub fn write_dbz_stream<T>(
201    writer: impl io::Write,
202    mut stream: impl StreamingIterator<Item = T>,
203) -> anyhow::Result<()>
204where
205    T: ConstTypeId + Sized,
206{
207    let mut encoder = new_encoder(writer)
208        .with_context(|| "Failed to create Zstd encoder for writing DBZ".to_owned())?;
209    while let Some(record) = stream.next() {
210        let bytes = unsafe {
211            // Safety: all records, types implementing `ConstTypeId` are POD
212            as_u8_slice(record)
213        };
214        match encoder.write_all(bytes) {
215            // closed pipe, should stop writing output
216            Err(e) if e.kind() == io::ErrorKind::BrokenPipe => return Ok(()),
217            r => r,
218        }
219        .with_context(|| "Failed to serialize {record:#?}")?;
220    }
221    encoder.flush()?;
222    Ok(())
223}
224
225/// Incrementally serializes the records in `iter` in the DBZ format to `writer`.
226pub fn write_dbz<'a, T>(
227    writer: impl io::Write,
228    iter: impl Iterator<Item = &'a T>,
229) -> anyhow::Result<()>
230where
231    T: 'a + ConstTypeId + Sized,
232{
233    let mut encoder = new_encoder(writer)
234        .with_context(|| "Failed to create Zstd encoder for writing DBZ".to_owned())?;
235    for record in iter {
236        let bytes = unsafe {
237            // Safety: all records, types implementing `ConstTypeId` are POD
238            as_u8_slice(record)
239        };
240        match encoder.write_all(bytes) {
241            // closed pipe, should stop writing output
242            Err(e) if e.kind() == io::ErrorKind::BrokenPipe => return Ok(()),
243            r => r,
244        }
245        .with_context(|| "Failed to serialize {record:#?}")?;
246    }
247    encoder.flush()?;
248    Ok(())
249}
250
#[cfg(test)]
mod tests {
    use std::{
        ffi::c_char,
        fmt,
        io::{BufWriter, Seek},
        mem,
    };

    use databento_defs::{
        enums::{Compression, SType, Schema},
        record::{Mbp1Msg, OhlcvMsg, RecordHeader, StatusMsg, TickMsg, TradeMsg},
    };

    use crate::{
        read::{FromLittleEndianSlice, MappingInterval},
        write::test_data::{VecStream, BID_ASK, RECORD_HEADER},
        DbzStreamIter,
    };

    use super::*;

    // Round-trip test: encoding metadata and reading it back yields an equal value.
    #[test]
    fn test_encode_decode_metadata_identity() {
        // NOTE(review): `extra` is built but never used below — possibly leftover
        // from an earlier version of `Metadata`; confirm whether it can be removed.
        let mut extra = serde_json::Map::default();
        extra.insert(
            "Key".to_owned(),
            serde_json::Value::Number(serde_json::Number::from_f64(4.0).unwrap()),
        );
        let metadata = Metadata {
            version: 1,
            dataset: "GLBX.MDP3".to_owned(),
            schema: Schema::Mbp10,
            stype_in: SType::Native,
            stype_out: SType::ProductId,
            start: 1657230820000000000,
            end: 1658960170000000000,
            limit: 0,
            compression: Compression::ZStd,
            record_count: 14,
            symbols: vec!["ES".to_owned(), "NG".to_owned()],
            partial: vec!["ESM2".to_owned()],
            not_found: vec!["QQQQQ".to_owned()],
            mappings: vec![
                SymbolMapping {
                    native: "ES.0".to_owned(),
                    intervals: vec![MappingInterval {
                        start_date: time::Date::from_calendar_date(2022, time::Month::July, 26)
                            .unwrap(),
                        end_date: time::Date::from_calendar_date(2022, time::Month::September, 1)
                            .unwrap(),
                        symbol: "ESU2".to_owned(),
                    }],
                },
                SymbolMapping {
                    native: "NG.0".to_owned(),
                    intervals: vec![
                        MappingInterval {
                            start_date: time::Date::from_calendar_date(2022, time::Month::July, 26)
                                .unwrap(),
                            end_date: time::Date::from_calendar_date(2022, time::Month::August, 29)
                                .unwrap(),
                            symbol: "NGU2".to_owned(),
                        },
                        MappingInterval {
                            start_date: time::Date::from_calendar_date(
                                2022,
                                time::Month::August,
                                29,
                            )
                            .unwrap(),
                            end_date: time::Date::from_calendar_date(
                                2022,
                                time::Month::September,
                                1,
                            )
                            .unwrap(),
                            symbol: "NGV2".to_owned(),
                        },
                    ],
                },
            ],
        };
        let mut buffer = Vec::new();
        let cursor = io::Cursor::new(&mut buffer);
        metadata.encode(cursor).unwrap();
        dbg!(&buffer);
        let res = Metadata::read(&mut &buffer[..]).unwrap();
        dbg!(&res, &metadata);
        assert_eq!(res, metadata);
    }

    // Checks the exact layout of a repeated-symbol section: u32 count followed
    // by fixed-length null-padded symbol strings.
    #[test]
    fn test_encode_repeated_symbol_cstr() {
        let mut buffer = Vec::new();
        let symbols = vec![
            "NG".to_owned(),
            "HP".to_owned(),
            "HPQ".to_owned(),
            "LNQ".to_owned(),
        ];
        Metadata::encode_repeated_symbol_cstr(&mut buffer, symbols.as_slice()).unwrap();
        assert_eq!(
            buffer.len(),
            mem::size_of::<u32>() + symbols.len() * Metadata::SYMBOL_CSTR_LEN
        );
        assert_eq!(u32::from_le_slice(&buffer[..4]), 4);
        for (i, symbol) in symbols.iter().enumerate() {
            let offset = i * Metadata::SYMBOL_CSTR_LEN;
            assert_eq!(
                &buffer[4 + offset..4 + offset + symbol.len()],
                symbol.as_bytes()
            );
        }
    }

    // Short strings are null-padded out to exactly SYMBOL_CSTR_LEN bytes.
    #[test]
    fn test_encode_fixed_len_cstr() {
        let mut buffer = Vec::new();
        Metadata::encode_fixed_len_cstr::<_, { Metadata::SYMBOL_CSTR_LEN }>(&mut buffer, "NG")
            .unwrap();
        assert_eq!(buffer.len(), Metadata::SYMBOL_CSTR_LEN);
        assert_eq!(&buffer[..2], b"NG");
        for b in buffer[2..].iter() {
            assert_eq!(*b, 0);
        }
    }

    // Dates are encoded as a little-endian u32 in YYYYMMDD form.
    #[test]
    fn test_encode_date() {
        let date = time::Date::from_calendar_date(2020, time::Month::May, 17).unwrap();
        let mut buffer = Vec::new();
        Metadata::encode_date(&mut buffer, date).unwrap();
        assert_eq!(buffer.len(), mem::size_of::<u32>());
        assert_eq!(buffer.as_slice(), 20200517u32.to_le_bytes().as_slice());
    }

    // `update_encoded` rewrites the range/count fields in place without
    // disturbing the rest of the encoded metadata or the cursor's end position.
    #[test]
    fn test_update_encoded() {
        let orig_metadata = Metadata {
            version: 1,
            dataset: "GLBX.MDP3".to_owned(),
            schema: Schema::Mbo,
            stype_in: SType::Smart,
            stype_out: SType::Native,
            start: 1657230820000000000,
            end: 1658960170000000000,
            limit: 0,
            record_count: 1_450_000,
            compression: Compression::ZStd,
            symbols: vec![],
            partial: vec![],
            not_found: vec![],
            mappings: vec![],
        };
        let mut buffer = Vec::new();
        let cursor = io::Cursor::new(&mut buffer);
        orig_metadata.encode(cursor).unwrap();
        let orig_res = Metadata::read(&mut &buffer[..]).unwrap();
        assert_eq!(orig_metadata, orig_res);
        let mut cursor = io::Cursor::new(&mut buffer);
        assert_eq!(cursor.position(), 0);
        cursor.seek(SeekFrom::End(0)).unwrap();
        let before_pos = cursor.position();
        assert!(before_pos != 0);
        let new_start = 1697240529000000000;
        let new_end = 17058980170000000000;
        let new_limit = 10;
        let new_record_count = 100_678;
        Metadata::update_encoded(&mut cursor, new_start, new_end, new_limit, new_record_count)
            .unwrap();
        // update_encoded must leave the cursor back at the end
        assert_eq!(before_pos, cursor.position());
        let res = Metadata::read(&mut &buffer[..]).unwrap();
        assert!(res != orig_res);
        assert_eq!(res.start, new_start);
        assert_eq!(res.end, new_end);
        assert_eq!(res.limit, new_limit);
        assert_eq!(res.record_count, new_record_count);
    }

    // Helper: serializes `records` via `write_dbz_stream` and pairs the encoded
    // bytes with a stub `Metadata` describing them.
    fn encode_records_and_stub_metadata<T>(schema: Schema, records: Vec<T>) -> (Vec<u8>, Metadata)
    where
        T: ConstTypeId + Clone,
    {
        let mut buffer = Vec::new();
        let writer = BufWriter::new(&mut buffer);
        write_dbz_stream(writer, VecStream::new(records.clone())).unwrap();
        dbg!(&buffer);
        let metadata = Metadata {
            version: 1,
            dataset: "GLBX.MDP3".to_owned(),
            schema,
            start: 0,
            end: 0,
            limit: 0,
            record_count: records.len() as u64,
            compression: Compression::None,
            stype_in: SType::Native,
            stype_out: SType::ProductId,
            symbols: vec![],
            partial: vec![],
            not_found: vec![],
            mappings: vec![],
        };
        (buffer, metadata)
    }

    // Helper: asserts that encoding then decoding `records` yields the same
    // records back.
    fn assert_encode_decode_record_identity<T>(schema: Schema, records: Vec<T>)
    where
        T: ConstTypeId + Clone + fmt::Debug + PartialEq,
    {
        let (buffer, metadata) = encode_records_and_stub_metadata(schema, records.clone());
        let mut iter: DbzStreamIter<&[u8], T> =
            DbzStreamIter::new(buffer.as_slice(), metadata).unwrap();
        let mut res = Vec::new();
        while let Some(rec) = iter.next() {
            res.push(rec.to_owned());
        }
        dbg!(&res, &records);
        assert_eq!(res, records);
    }

    // Round-trip for MBO (tick) records.
    #[test]
    fn test_encode_decode_mbo_identity() {
        let records = vec![
            TickMsg {
                hd: RecordHeader {
                    rtype: TickMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                order_id: 2,
                price: 9250000000,
                size: 25,
                flags: -128,
                channel_id: 1,
                action: 'B' as i8,
                side: 67,
                ts_recv: 1658441891000000000,
                ts_in_delta: 1000,
                sequence: 98,
            },
            TickMsg {
                hd: RecordHeader {
                    rtype: TickMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                order_id: 3,
                price: 9350000000,
                size: 800,
                flags: 0,
                channel_id: 1,
                action: 'C' as i8,
                side: 67,
                ts_recv: 1658441991000000000,
                ts_in_delta: 750,
                sequence: 101,
            },
        ];
        assert_encode_decode_record_identity(Schema::Mbo, records);
    }

    // Round-trip for MBP-1 records (single book level).
    #[test]
    fn test_encode_decode_mbp1_identity() {
        let records = vec![
            Mbp1Msg {
                hd: RecordHeader {
                    rtype: Mbp1Msg::TYPE_ID,
                    ..RECORD_HEADER
                },
                price: 925000000000,
                size: 300,
                action: 'S' as i8,
                side: 67,
                flags: -128,
                depth: 1,
                ts_recv: 1658442001000000000,
                ts_in_delta: 750,
                sequence: 100,
                booklevel: [BID_ASK; 1],
            },
            Mbp1Msg {
                hd: RecordHeader {
                    rtype: Mbp1Msg::TYPE_ID,
                    ..RECORD_HEADER
                },
                price: 925000000000,
                size: 50,
                action: 'B' as i8,
                side: 67,
                flags: -128,
                depth: 1,
                ts_recv: 1658542001000000000,
                ts_in_delta: 787,
                sequence: 101,
                booklevel: [BID_ASK; 1],
            },
        ];
        assert_encode_decode_record_identity(Schema::Mbp1, records);
    }

    // Round-trip for trade records (no book levels).
    #[test]
    fn test_encode_decode_trade_identity() {
        let records = vec![
            TradeMsg {
                hd: RecordHeader {
                    rtype: TradeMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                price: 925000000000,
                size: 1,
                action: 'T' as i8,
                side: 'B' as i8,
                flags: 0,
                depth: 4,
                ts_recv: 1658441891000000000,
                ts_in_delta: 234,
                sequence: 1005,
                booklevel: [],
            },
            TradeMsg {
                hd: RecordHeader {
                    rtype: TradeMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                price: 925000000000,
                size: 10,
                action: 'T' as i8,
                side: 'S' as i8,
                flags: 0,
                depth: 1,
                ts_recv: 1659441891000000000,
                ts_in_delta: 10358,
                sequence: 1010,
                booklevel: [],
            },
        ];
        assert_encode_decode_record_identity(Schema::Trades, records);
    }

    // Round-trip for OHLCV bar records.
    #[test]
    fn test_encode_decode_ohlcv_identity() {
        let records = vec![
            OhlcvMsg {
                hd: RecordHeader {
                    rtype: OhlcvMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                open: 92500000000,
                high: 95200000000,
                low: 91200000000,
                close: 91600000000,
                volume: 6785,
            },
            OhlcvMsg {
                hd: RecordHeader {
                    rtype: OhlcvMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                open: 91600000000,
                high: 95100000000,
                low: 91600000000,
                close: 92300000000,
                volume: 7685,
            },
        ];
        assert_encode_decode_record_identity(Schema::Ohlcv1D, records);
    }

    // Round-trip for status records, including a fixed-size c_char group field.
    #[test]
    fn test_encode_decode_status_identity() {
        let mut group = [0; 21];
        for (i, c) in "group".chars().enumerate() {
            group[i] = c as c_char;
        }
        let records = vec![
            StatusMsg {
                hd: RecordHeader {
                    rtype: StatusMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                ts_recv: 1658441891000000000,
                group,
                trading_status: 3,
                halt_reason: 4,
                trading_event: 5,
            },
            StatusMsg {
                hd: RecordHeader {
                    rtype: StatusMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                ts_recv: 1658541891000000000,
                group,
                trading_status: 4,
                halt_reason: 5,
                trading_event: 6,
            },
        ];
        assert_encode_decode_record_identity(Schema::Status, records);
    }

    // Decoding with the wrong record type must yield no records, not panic.
    #[test]
    fn test_decode_malformed_encoded_dbz() {
        let records = vec![
            OhlcvMsg {
                hd: RecordHeader {
                    rtype: OhlcvMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                open: 92500000000,
                high: 95200000000,
                low: 91200000000,
                close: 91600000000,
                volume: 6785,
            },
            OhlcvMsg {
                hd: RecordHeader {
                    rtype: OhlcvMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                open: 91600000000,
                high: 95100000000,
                low: 91600000000,
                close: 92300000000,
                volume: 7685,
            },
        ];
        let wrong_schema = Schema::Mbo;
        let (buffer, metadata) = encode_records_and_stub_metadata(wrong_schema, records);
        type WrongRecord = TickMsg;
        let mut iter: DbzStreamIter<&[u8], WrongRecord> =
            DbzStreamIter::new(buffer.as_slice(), metadata).unwrap();
        // check doesn't panic
        assert!(iter.next().is_none());
        assert!(iter.next().is_none());
    }
}
687}