1use std::{
2 io::{self, SeekFrom, Write},
3 mem,
4 ops::Range,
5 slice,
6};
7
8use anyhow::{anyhow, Context};
9use databento_defs::record::ConstTypeId;
10use streaming_iterator::StreamingIterator;
11use zstd::{stream::AutoFinishEncoder, Encoder};
12
13use crate::{read::SymbolMapping, Metadata};
14
/// Version of the DBZ encoding schema produced by this module (written into
/// the metadata header's version byte on the read side — confirm against `Metadata::read`).
pub(crate) const SCHEMA_VERSION: u8 = 1;
16
17fn new_encoder<'a, W: io::Write>(writer: W) -> anyhow::Result<AutoFinishEncoder<'a, W>> {
19 pub(crate) const ZSTD_COMPRESSION_LEVEL: i32 = 0;
20
21 let mut encoder = Encoder::new(writer, ZSTD_COMPRESSION_LEVEL)?;
22 encoder.include_checksum(true)?;
23 Ok(encoder.auto_finish())
24}
25
26impl Metadata {
27 pub(crate) const ZSTD_MAGIC_RANGE: Range<u32> = 0x184D2A50..0x184D2A60;
28 pub(crate) const VERSION_CSTR_LEN: usize = 4;
29 pub(crate) const DATASET_CSTR_LEN: usize = 16;
30 pub(crate) const RESERVED_LEN: usize = 39;
31 pub(crate) const FIXED_METADATA_LEN: usize = 96;
32 pub(crate) const SYMBOL_CSTR_LEN: usize = 22;
33
34 pub fn encode(&self, mut writer: impl io::Write + io::Seek) -> anyhow::Result<()> {
35 writer.write_all(Self::ZSTD_MAGIC_RANGE.start.to_le_bytes().as_slice())?;
36 writer.write_all(b"0000")?;
38 writer.write_all(b"DBZ")?;
39 writer.write_all(&[self.version])?;
40 Self::encode_fixed_len_cstr::<_, { Self::DATASET_CSTR_LEN }>(&mut writer, &self.dataset)?;
41 writer.write_all((self.schema as u16).to_le_bytes().as_slice())?;
42 Self::encode_range_and_counts(
43 &mut writer,
44 self.start,
45 self.end,
46 self.limit,
47 self.record_count,
48 )?;
49 writer.write_all(&[self.compression as u8])?;
50 writer.write_all(&[self.stype_in as u8])?;
51 writer.write_all(&[self.stype_out as u8])?;
52 writer.write_all(&[0; Self::RESERVED_LEN])?;
54 {
55 let mut zstd_encoder = new_encoder(&mut writer)?;
57 zstd_encoder.write_all(0u32.to_le_bytes().as_slice())?;
59
60 Self::encode_repeated_symbol_cstr(&mut zstd_encoder, self.symbols.as_slice())
61 .with_context(|| "Failed to encode symbols")?;
62 Self::encode_repeated_symbol_cstr(&mut zstd_encoder, self.partial.as_slice())
63 .with_context(|| "Failed to encode partial")?;
64 Self::encode_repeated_symbol_cstr(&mut zstd_encoder, self.not_found.as_slice())
65 .with_context(|| "Failed to encode not_found")?;
66 Self::encode_symbol_mappings(&mut zstd_encoder, self.mappings.as_slice())?;
67 }
68
69 let raw_size = writer.stream_position()?;
70 writer.seek(SeekFrom::Start(4))?;
72 let frame_size = (raw_size - 8) as u32;
74 writer.write_all(frame_size.to_le_bytes().as_slice())?;
75 writer.seek(SeekFrom::End(0))?;
77
78 Ok(())
79 }
80
81 pub fn update_encoded(
82 mut writer: impl io::Write + io::Seek,
83 start: u64,
84 end: u64,
85 limit: u64,
86 record_count: u64,
87 ) -> anyhow::Result<()> {
88 const START_SEEK_FROM: SeekFrom =
90 SeekFrom::Start((8 + 4 + Metadata::DATASET_CSTR_LEN + 2) as u64);
91
92 writer
93 .seek(START_SEEK_FROM)
94 .with_context(|| "Failed to seek to write position".to_owned())?;
95 Self::encode_range_and_counts(&mut writer, start, end, limit, record_count)?;
96 writer
97 .seek(SeekFrom::End(0))
98 .with_context(|| "Failed to seek back to end".to_owned())?;
99 Ok(())
100 }
101
102 fn encode_range_and_counts(
103 writer: &mut impl io::Write,
104 start: u64,
105 end: u64,
106 limit: u64,
107 record_count: u64,
108 ) -> anyhow::Result<()> {
109 writer.write_all(start.to_le_bytes().as_slice())?;
110 writer.write_all(end.to_le_bytes().as_slice())?;
111 writer.write_all(limit.to_le_bytes().as_slice())?;
112 writer.write_all(record_count.to_le_bytes().as_slice())?;
113 Ok(())
114 }
115
116 fn encode_repeated_symbol_cstr(
117 writer: &mut impl io::Write,
118 symbols: &[String],
119 ) -> anyhow::Result<()> {
120 writer.write_all((symbols.len() as u32).to_le_bytes().as_slice())?;
121 for symbol in symbols {
122 Self::encode_fixed_len_cstr::<_, { Self::SYMBOL_CSTR_LEN }>(writer, symbol)?;
123 }
124
125 Ok(())
126 }
127
128 fn encode_symbol_mappings(
129 writer: &mut impl io::Write,
130 symbol_mappings: &[SymbolMapping],
131 ) -> anyhow::Result<()> {
132 writer.write_all((symbol_mappings.len() as u32).to_le_bytes().as_slice())?;
134 for symbol_mapping in symbol_mappings {
135 Self::encode_symbol_mapping(writer, symbol_mapping)?;
136 }
137 Ok(())
138 }
139
140 fn encode_symbol_mapping(
141 writer: &mut impl io::Write,
142 symbol_mapping: &SymbolMapping,
143 ) -> anyhow::Result<()> {
144 Self::encode_fixed_len_cstr::<_, { Self::SYMBOL_CSTR_LEN }>(
145 writer,
146 &symbol_mapping.native,
147 )?;
148 writer.write_all(
150 (symbol_mapping.intervals.len() as u32)
151 .to_le_bytes()
152 .as_slice(),
153 )?;
154 for interval in symbol_mapping.intervals.iter() {
155 Self::encode_date(writer, interval.start_date)?;
156 Self::encode_date(writer, interval.end_date)?;
157 Self::encode_fixed_len_cstr::<_, { Self::SYMBOL_CSTR_LEN }>(writer, &interval.symbol)?;
158 }
159 Ok(())
160 }
161
162 fn encode_fixed_len_cstr<W: io::Write, const LEN: usize>(
165 writer: &mut W,
166 string: &str,
167 ) -> anyhow::Result<()> {
168 if !string.is_ascii() {
169 return Err(anyhow!(
170 "'{string}' can't be encoded in DBZ because it contains non-ASCII characters"
171 ));
172 }
173 if string.len() > LEN {
174 return Err(anyhow!(
175 "'{string}' is too long to be encoded in DBZ; it cannot be longer {LEN} characters"
176 ));
177 }
178 writer.write_all(string.as_bytes())?;
179 for _ in string.len()..LEN {
181 writer.write_all(&[0])?;
182 }
183 Ok(())
184 }
185
186 fn encode_date(writer: &mut impl io::Write, date: time::Date) -> anyhow::Result<()> {
187 let mut date_int = date.year() as u32 * 10_000;
188 date_int += date.month() as u32 * 100;
189 date_int += date.day() as u32;
190 writer.write_all(date_int.to_le_bytes().as_slice())?;
191 Ok(())
192 }
193}
194
/// Reinterprets `data` as a byte slice over its full in-memory representation
/// (`size_of::<T>()` bytes).
///
/// # Safety
/// The caller must ensure every one of `T`'s `size_of::<T>()` bytes is
/// initialized — in particular that `T` has no padding holes — because the
/// returned slice exposes all of them as `u8`.
unsafe fn as_u8_slice<T: Sized>(data: &T) -> &[u8] {
    slice::from_raw_parts(data as *const T as *const u8, mem::size_of::<T>())
}
198
199pub fn write_dbz_stream<T>(
201 writer: impl io::Write,
202 mut stream: impl StreamingIterator<Item = T>,
203) -> anyhow::Result<()>
204where
205 T: ConstTypeId + Sized,
206{
207 let mut encoder = new_encoder(writer)
208 .with_context(|| "Failed to create Zstd encoder for writing DBZ".to_owned())?;
209 while let Some(record) = stream.next() {
210 let bytes = unsafe {
211 as_u8_slice(record)
213 };
214 match encoder.write_all(bytes) {
215 Err(e) if e.kind() == io::ErrorKind::BrokenPipe => return Ok(()),
217 r => r,
218 }
219 .with_context(|| "Failed to serialize {record:#?}")?;
220 }
221 encoder.flush()?;
222 Ok(())
223}
224
225pub fn write_dbz<'a, T>(
227 writer: impl io::Write,
228 iter: impl Iterator<Item = &'a T>,
229) -> anyhow::Result<()>
230where
231 T: 'a + ConstTypeId + Sized,
232{
233 let mut encoder = new_encoder(writer)
234 .with_context(|| "Failed to create Zstd encoder for writing DBZ".to_owned())?;
235 for record in iter {
236 let bytes = unsafe {
237 as_u8_slice(record)
239 };
240 match encoder.write_all(bytes) {
241 Err(e) if e.kind() == io::ErrorKind::BrokenPipe => return Ok(()),
243 r => r,
244 }
245 .with_context(|| "Failed to serialize {record:#?}")?;
246 }
247 encoder.flush()?;
248 Ok(())
249}
250
251#[cfg(test)]
252mod tests {
253 use std::{
254 ffi::c_char,
255 fmt,
256 io::{BufWriter, Seek},
257 mem,
258 };
259
260 use databento_defs::{
261 enums::{Compression, SType, Schema},
262 record::{Mbp1Msg, OhlcvMsg, RecordHeader, StatusMsg, TickMsg, TradeMsg},
263 };
264
265 use crate::{
266 read::{FromLittleEndianSlice, MappingInterval},
267 write::test_data::{VecStream, BID_ASK, RECORD_HEADER},
268 DbzStreamIter,
269 };
270
271 use super::*;
272
273 #[test]
274 fn test_encode_decode_metadata_identity() {
275 let mut extra = serde_json::Map::default();
276 extra.insert(
277 "Key".to_owned(),
278 serde_json::Value::Number(serde_json::Number::from_f64(4.0).unwrap()),
279 );
280 let metadata = Metadata {
281 version: 1,
282 dataset: "GLBX.MDP3".to_owned(),
283 schema: Schema::Mbp10,
284 stype_in: SType::Native,
285 stype_out: SType::ProductId,
286 start: 1657230820000000000,
287 end: 1658960170000000000,
288 limit: 0,
289 compression: Compression::ZStd,
290 record_count: 14,
291 symbols: vec!["ES".to_owned(), "NG".to_owned()],
292 partial: vec!["ESM2".to_owned()],
293 not_found: vec!["QQQQQ".to_owned()],
294 mappings: vec![
295 SymbolMapping {
296 native: "ES.0".to_owned(),
297 intervals: vec![MappingInterval {
298 start_date: time::Date::from_calendar_date(2022, time::Month::July, 26)
299 .unwrap(),
300 end_date: time::Date::from_calendar_date(2022, time::Month::September, 1)
301 .unwrap(),
302 symbol: "ESU2".to_owned(),
303 }],
304 },
305 SymbolMapping {
306 native: "NG.0".to_owned(),
307 intervals: vec![
308 MappingInterval {
309 start_date: time::Date::from_calendar_date(2022, time::Month::July, 26)
310 .unwrap(),
311 end_date: time::Date::from_calendar_date(2022, time::Month::August, 29)
312 .unwrap(),
313 symbol: "NGU2".to_owned(),
314 },
315 MappingInterval {
316 start_date: time::Date::from_calendar_date(
317 2022,
318 time::Month::August,
319 29,
320 )
321 .unwrap(),
322 end_date: time::Date::from_calendar_date(
323 2022,
324 time::Month::September,
325 1,
326 )
327 .unwrap(),
328 symbol: "NGV2".to_owned(),
329 },
330 ],
331 },
332 ],
333 };
334 let mut buffer = Vec::new();
335 let cursor = io::Cursor::new(&mut buffer);
336 metadata.encode(cursor).unwrap();
337 dbg!(&buffer);
338 let res = Metadata::read(&mut &buffer[..]).unwrap();
339 dbg!(&res, &metadata);
340 assert_eq!(res, metadata);
341 }
342
343 #[test]
344 fn test_encode_repeated_symbol_cstr() {
345 let mut buffer = Vec::new();
346 let symbols = vec![
347 "NG".to_owned(),
348 "HP".to_owned(),
349 "HPQ".to_owned(),
350 "LNQ".to_owned(),
351 ];
352 Metadata::encode_repeated_symbol_cstr(&mut buffer, symbols.as_slice()).unwrap();
353 assert_eq!(
354 buffer.len(),
355 mem::size_of::<u32>() + symbols.len() * Metadata::SYMBOL_CSTR_LEN
356 );
357 assert_eq!(u32::from_le_slice(&buffer[..4]), 4);
358 for (i, symbol) in symbols.iter().enumerate() {
359 let offset = i * Metadata::SYMBOL_CSTR_LEN;
360 assert_eq!(
361 &buffer[4 + offset..4 + offset + symbol.len()],
362 symbol.as_bytes()
363 );
364 }
365 }
366
367 #[test]
368 fn test_encode_fixed_len_cstr() {
369 let mut buffer = Vec::new();
370 Metadata::encode_fixed_len_cstr::<_, { Metadata::SYMBOL_CSTR_LEN }>(&mut buffer, "NG")
371 .unwrap();
372 assert_eq!(buffer.len(), Metadata::SYMBOL_CSTR_LEN);
373 assert_eq!(&buffer[..2], b"NG");
374 for b in buffer[2..].iter() {
375 assert_eq!(*b, 0);
376 }
377 }
378
379 #[test]
380 fn test_encode_date() {
381 let date = time::Date::from_calendar_date(2020, time::Month::May, 17).unwrap();
382 let mut buffer = Vec::new();
383 Metadata::encode_date(&mut buffer, date).unwrap();
384 assert_eq!(buffer.len(), mem::size_of::<u32>());
385 assert_eq!(buffer.as_slice(), 20200517u32.to_le_bytes().as_slice());
386 }
387
    /// `update_encoded` must overwrite only the range/count fields of an
    /// existing encoding in place, and leave the cursor at the end.
    #[test]
    fn test_update_encoded() {
        let orig_metadata = Metadata {
            version: 1,
            dataset: "GLBX.MDP3".to_owned(),
            schema: Schema::Mbo,
            stype_in: SType::Smart,
            stype_out: SType::Native,
            start: 1657230820000000000,
            end: 1658960170000000000,
            limit: 0,
            record_count: 1_450_000,
            compression: Compression::ZStd,
            symbols: vec![],
            partial: vec![],
            not_found: vec![],
            mappings: vec![],
        };
        let mut buffer = Vec::new();
        let cursor = io::Cursor::new(&mut buffer);
        orig_metadata.encode(cursor).unwrap();
        // Sanity check: the unmodified encoding round-trips first.
        let orig_res = Metadata::read(&mut &buffer[..]).unwrap();
        assert_eq!(orig_metadata, orig_res);
        let mut cursor = io::Cursor::new(&mut buffer);
        assert_eq!(cursor.position(), 0);
        // Park the cursor at the end to verify update_encoded restores it.
        cursor.seek(SeekFrom::End(0)).unwrap();
        let before_pos = cursor.position();
        assert!(before_pos != 0);
        let new_start = 1697240529000000000;
        let new_end = 17058980170000000000;
        let new_limit = 10;
        let new_record_count = 100_678;
        Metadata::update_encoded(&mut cursor, new_start, new_end, new_limit, new_record_count)
            .unwrap();
        // update_encoded seeks back to the end when done.
        assert_eq!(before_pos, cursor.position());
        let res = Metadata::read(&mut &buffer[..]).unwrap();
        // Only the four updated fields should differ from the original.
        assert!(res != orig_res);
        assert_eq!(res.start, new_start);
        assert_eq!(res.end, new_end);
        assert_eq!(res.limit, new_limit);
        assert_eq!(res.record_count, new_record_count);
    }
430
    /// Encodes `records` as a DBZ byte stream and pairs the buffer with stub
    /// metadata describing `schema`, for use by the round-trip tests.
    fn encode_records_and_stub_metadata<T>(schema: Schema, records: Vec<T>) -> (Vec<u8>, Metadata)
    where
        T: ConstTypeId + Clone,
    {
        let mut buffer = Vec::new();
        let writer = BufWriter::new(&mut buffer);
        write_dbz_stream(writer, VecStream::new(records.clone())).unwrap();
        dbg!(&buffer);
        // Minimal metadata: only schema and record_count reflect the input;
        // everything else is zeroed/empty.
        let metadata = Metadata {
            version: 1,
            dataset: "GLBX.MDP3".to_owned(),
            schema,
            start: 0,
            end: 0,
            limit: 0,
            record_count: records.len() as u64,
            compression: Compression::None,
            stype_in: SType::Native,
            stype_out: SType::ProductId,
            symbols: vec![],
            partial: vec![],
            not_found: vec![],
            mappings: vec![],
        };
        (buffer, metadata)
    }
457
    /// Asserts that encoding `records` and decoding them back through
    /// `DbzStreamIter` yields the original records unchanged.
    fn assert_encode_decode_record_identity<T>(schema: Schema, records: Vec<T>)
    where
        T: ConstTypeId + Clone + fmt::Debug + PartialEq,
    {
        let (buffer, metadata) = encode_records_and_stub_metadata(schema, records.clone());
        let mut iter: DbzStreamIter<&[u8], T> =
            DbzStreamIter::new(buffer.as_slice(), metadata).unwrap();
        let mut res = Vec::new();
        // Iterated manually via next() — presumably a streaming iterator
        // rather than a std Iterator, so a `for` loop can't be used here.
        while let Some(rec) = iter.next() {
            res.push(rec.to_owned());
        }
        dbg!(&res, &records);
        assert_eq!(res, records);
    }
472
    /// Round-trip identity for MBO tick records.
    #[test]
    fn test_encode_decode_mbo_identity() {
        let records = vec![
            TickMsg {
                hd: RecordHeader {
                    rtype: TickMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                order_id: 2,
                price: 9250000000,
                size: 25,
                flags: -128,
                channel_id: 1,
                action: 'B' as i8,
                side: 67,
                ts_recv: 1658441891000000000,
                ts_in_delta: 1000,
                sequence: 98,
            },
            TickMsg {
                hd: RecordHeader {
                    rtype: TickMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                order_id: 3,
                price: 9350000000,
                size: 800,
                flags: 0,
                channel_id: 1,
                action: 'C' as i8,
                side: 67,
                ts_recv: 1658441991000000000,
                ts_in_delta: 750,
                sequence: 101,
            },
        ];
        assert_encode_decode_record_identity(Schema::Mbo, records);
    }
511
    /// Round-trip identity for MBP-1 records (one book level per record).
    #[test]
    fn test_encode_decode_mbp1_identity() {
        let records = vec![
            Mbp1Msg {
                hd: RecordHeader {
                    rtype: Mbp1Msg::TYPE_ID,
                    ..RECORD_HEADER
                },
                price: 925000000000,
                size: 300,
                action: 'S' as i8,
                side: 67,
                flags: -128,
                depth: 1,
                ts_recv: 1658442001000000000,
                ts_in_delta: 750,
                sequence: 100,
                booklevel: [BID_ASK; 1],
            },
            Mbp1Msg {
                hd: RecordHeader {
                    rtype: Mbp1Msg::TYPE_ID,
                    ..RECORD_HEADER
                },
                price: 925000000000,
                size: 50,
                action: 'B' as i8,
                side: 67,
                flags: -128,
                depth: 1,
                ts_recv: 1658542001000000000,
                ts_in_delta: 787,
                sequence: 101,
                booklevel: [BID_ASK; 1],
            },
        ];
        assert_encode_decode_record_identity(Schema::Mbp1, records);
    }
550
    /// Round-trip identity for trade records (empty book-level array).
    #[test]
    fn test_encode_decode_trade_identity() {
        let records = vec![
            TradeMsg {
                hd: RecordHeader {
                    rtype: TradeMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                price: 925000000000,
                size: 1,
                action: 'T' as i8,
                side: 'B' as i8,
                flags: 0,
                depth: 4,
                ts_recv: 1658441891000000000,
                ts_in_delta: 234,
                sequence: 1005,
                booklevel: [],
            },
            TradeMsg {
                hd: RecordHeader {
                    rtype: TradeMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                price: 925000000000,
                size: 10,
                action: 'T' as i8,
                side: 'S' as i8,
                flags: 0,
                depth: 1,
                ts_recv: 1659441891000000000,
                ts_in_delta: 10358,
                sequence: 1010,
                booklevel: [],
            },
        ];
        assert_encode_decode_record_identity(Schema::Trades, records);
    }
589
    /// Round-trip identity for daily OHLCV bar records.
    #[test]
    fn test_encode_decode_ohlcv_identity() {
        let records = vec![
            OhlcvMsg {
                hd: RecordHeader {
                    rtype: OhlcvMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                open: 92500000000,
                high: 95200000000,
                low: 91200000000,
                close: 91600000000,
                volume: 6785,
            },
            OhlcvMsg {
                hd: RecordHeader {
                    rtype: OhlcvMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                open: 91600000000,
                high: 95100000000,
                low: 91600000000,
                close: 92300000000,
                volume: 7685,
            },
        ];
        assert_encode_decode_record_identity(Schema::Ohlcv1D, records);
    }
618
    /// Round-trip identity for status records, including the fixed-length
    /// `group` C-char array field.
    #[test]
    fn test_encode_decode_status_identity() {
        // Build a NUL-padded fixed-length field containing "group".
        let mut group = [0; 21];
        for (i, c) in "group".chars().enumerate() {
            group[i] = c as c_char;
        }
        let records = vec![
            StatusMsg {
                hd: RecordHeader {
                    rtype: StatusMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                ts_recv: 1658441891000000000,
                group,
                trading_status: 3,
                halt_reason: 4,
                trading_event: 5,
            },
            StatusMsg {
                hd: RecordHeader {
                    rtype: StatusMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                ts_recv: 1658541891000000000,
                group,
                trading_status: 4,
                halt_reason: 5,
                trading_event: 6,
            },
        ];
        assert_encode_decode_record_identity(Schema::Status, records);
    }
651
    /// Decoding with a record type that doesn't match the encoded schema
    /// should yield no records (repeatedly `None`) rather than bad data.
    #[test]
    fn test_decode_malformed_encoded_dbz() {
        let records = vec![
            OhlcvMsg {
                hd: RecordHeader {
                    rtype: OhlcvMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                open: 92500000000,
                high: 95200000000,
                low: 91200000000,
                close: 91600000000,
                volume: 6785,
            },
            OhlcvMsg {
                hd: RecordHeader {
                    rtype: OhlcvMsg::TYPE_ID,
                    ..RECORD_HEADER
                },
                open: 91600000000,
                high: 95100000000,
                low: 91600000000,
                close: 92300000000,
                volume: 7685,
            },
        ];
        // Deliberately claim the wrong schema for the OHLCV payload.
        let wrong_schema = Schema::Mbo;
        let (buffer, metadata) = encode_records_and_stub_metadata(wrong_schema, records);
        type WrongRecord = TickMsg;
        let mut iter: DbzStreamIter<&[u8], WrongRecord> =
            DbzStreamIter::new(buffer.as_slice(), metadata).unwrap();
        // End-of-stream must be stable across repeated calls.
        assert!(iter.next().is_none());
        assert!(iter.next().is_none());
    }
687}