use std::os::raw::c_char;
use crate::{
macros::{dbn_record, CsvSerialize, JsonSerialize},
record::{transmute_header_bytes, transmute_record_bytes},
rtype, HasRType, RecordHeader, RecordRef, SecurityUpdateAction, UserDefinedInstrument,
VersionUpgradePolicy, WithTsOut,
};
#[cfg(not(feature = "python"))]
use dbn_macros::MockPyo3;
/// Length of the fixed-size symbol C-string fields in DBN version 1 records.
pub const SYMBOL_CSTR_LEN_V1: usize = 22;
/// Length of the fixed-size symbol C-string fields in DBN version 2 records.
pub const SYMBOL_CSTR_LEN_V2: usize = 71;
// Reserved-byte length associated with version 1 metadata (description inferred from the
// name; see its users elsewhere in the crate).
pub(crate) const METADATA_RESERVED_LEN_V1: usize = 47;

/// Returns the symbol C-string length used by DBN `version`.
///
/// Versions 0 and 1 use [`SYMBOL_CSTR_LEN_V1`]; every later version uses
/// [`SYMBOL_CSTR_LEN_V2`].
pub const fn version_symbol_cstr_len(version: u8) -> usize {
    match version {
        0 | 1 => SYMBOL_CSTR_LEN_V1,
        _ => SYMBOL_CSTR_LEN_V2,
    }
}
pub use crate::record::ErrorMsg as ErrorMsgV2;
pub use crate::record::InstrumentDefMsg as InstrumentDefMsgV2;
pub use crate::record::SymbolMappingMsg as SymbolMappingMsgV2;
pub use crate::record::SystemMsg as SystemMsgV2;
pub unsafe fn decode_record_ref<'a>(
version: u8,
upgrade_policy: VersionUpgradePolicy,
ts_out: bool,
compat_buffer: &'a mut [u8; crate::MAX_RECORD_LEN],
input: &'a [u8],
) -> RecordRef<'a> {
if version == 1 && upgrade_policy == VersionUpgradePolicy::Upgrade {
let header = transmute_header_bytes(input).unwrap();
match header.rtype {
rtype::INSTRUMENT_DEF => {
return upgrade_record::<InstrumentDefMsgV1, InstrumentDefMsgV2>(
ts_out,
compat_buffer,
input,
);
}
rtype::SYMBOL_MAPPING => {
return upgrade_record::<SymbolMappingMsgV1, SymbolMappingMsgV2>(
ts_out,
compat_buffer,
input,
);
}
rtype::ERROR => {
return upgrade_record::<ErrorMsgV1, ErrorMsgV2>(ts_out, compat_buffer, input);
}
rtype::SYSTEM => {
return upgrade_record::<SystemMsgV1, SystemMsgV2>(ts_out, compat_buffer, input);
}
_ => (),
}
}
RecordRef::new(input)
}
/// Converts the version 1 record `T` in `input` into its version 2 counterpart `U`,
/// writes the result to the start of `compat_buffer`, and returns a reference to it.
///
/// When `ts_out` is set, `input` is read as a `WithTsOut<T>`; its `ts_out` value is
/// carried over into the `WithTsOut<U>` written to the buffer.
///
/// # Safety
/// `input` must hold a valid `T` (or `WithTsOut<T>` when `ts_out` is set). Its bytes are
/// reinterpreted through `transmute_record_bytes` without further validation.
///
/// # Panics
/// Panics if `transmute_record_bytes` cannot produce a record from `input`, or if the
/// upgraded record does not fit in `compat_buffer`.
unsafe fn upgrade_record<'a, T, U>(
    ts_out: bool,
    compat_buffer: &'a mut [u8; crate::MAX_RECORD_LEN],
    input: &'a [u8],
) -> RecordRef<'a>
where
    T: HasRType,
    U: HasRType + for<'b> From<&'b T>,
{
    if ts_out {
        let rec = transmute_record_bytes::<WithTsOut<T>>(input).unwrap();
        let upgraded = WithTsOut::new(U::from(&rec.rec), rec.ts_out);
        write_to_compat_buffer(compat_buffer, upgraded);
    } else {
        let upgraded = U::from(transmute_record_bytes::<T>(input).unwrap());
        write_to_compat_buffer(compat_buffer, upgraded);
    }
    // NOTE(review): callers later obtain typed references into this byte buffer via
    // `RecordRef`; confirm they keep `compat_buffer` suitably aligned for the record types.
    RecordRef::new(compat_buffer)
}

/// Moves `rec` into the first `size_of::<R>()` bytes of `buffer`.
///
/// A `[u8; N]` is only guaranteed 1-byte alignment, so the value is written with
/// `write_unaligned`. The previous `copy_nonoverlapping::<R>` required an `R`-aligned
/// destination, which made it undefined behavior on a misaligned buffer.
fn write_to_compat_buffer<R>(buffer: &mut [u8; crate::MAX_RECORD_LEN], rec: R) {
    assert!(
        std::mem::size_of::<R>() <= buffer.len(),
        "upgraded record does not fit in the compatibility buffer"
    );
    // SAFETY: the assertion above guarantees `buffer` is valid for writes of
    // `size_of::<R>()` bytes, and `write_unaligned` imposes no alignment requirement on
    // the destination pointer.
    unsafe { std::ptr::write_unaligned(buffer.as_mut_ptr().cast::<R>(), rec) };
}
/// The DBN version 1 layout of an instrument definition record (`rtype::INSTRUMENT_DEF`).
///
/// Retained for reading version 1 data; it converts into [`InstrumentDefMsgV2`] through
/// `From<&InstrumentDefMsgV1>`. The struct is `#[repr(C)]`, and the tests in this module pin
/// it at 360 bytes with 8-byte alignment and no implicit padding. Fields must therefore not
/// be reordered, resized, or removed.
#[repr(C)]
#[derive(Clone, CsvSerialize, JsonSerialize, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "trivial_copy", derive(Copy))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(dict, module = "databento_dbn"),
    derive(crate::macros::PyFieldDesc)
)]
#[cfg_attr(not(feature = "python"), derive(MockPyo3))]
#[cfg_attr(test, derive(type_layout::TypeLayout))]
#[dbn_record(rtype::INSTRUMENT_DEF)]
pub struct InstrumentDefMsgV1 {
    /// The common record header.
    #[pyo3(get, set)]
    pub hd: RecordHeader,
    /// The `ts_recv` timestamp, formatted as UNIX nanoseconds (`unix_nanos`). It is the
    /// record's index timestamp (`index_ts`) and is encoded first (`encode_order(0)`).
    #[dbn(encode_order(0), index_ts, unix_nanos)]
    #[pyo3(get, set)]
    pub ts_recv: u64,
    // Fields tagged `fixed_price` are formatted as fixed-precision values; fields tagged
    // `unix_nanos` are formatted as UNIX nanosecond timestamps.
    #[dbn(fixed_price)]
    #[pyo3(get, set)]
    pub min_price_increment: i64,
    #[dbn(fixed_price)]
    #[pyo3(get, set)]
    pub display_factor: i64,
    #[dbn(unix_nanos)]
    #[pyo3(get, set)]
    pub expiration: u64,
    #[dbn(unix_nanos)]
    #[pyo3(get, set)]
    pub activation: u64,
    #[dbn(fixed_price)]
    #[pyo3(get, set)]
    pub high_limit_price: i64,
    #[dbn(fixed_price)]
    #[pyo3(get, set)]
    pub low_limit_price: i64,
    #[dbn(fixed_price)]
    #[pyo3(get, set)]
    pub max_price_variation: i64,
    #[dbn(fixed_price)]
    #[pyo3(get, set)]
    pub trading_reference_price: i64,
    #[dbn(fixed_price)]
    #[pyo3(get, set)]
    pub unit_of_measure_qty: i64,
    #[dbn(fixed_price)]
    #[pyo3(get, set)]
    pub min_price_increment_amount: i64,
    #[dbn(fixed_price)]
    #[pyo3(get, set)]
    pub price_ratio: i64,
    #[pyo3(get, set)]
    pub inst_attrib_value: i32,
    #[pyo3(get, set)]
    pub underlying_id: u32,
    #[pyo3(get, set)]
    pub raw_instrument_id: u32,
    #[pyo3(get, set)]
    pub market_depth_implied: i32,
    #[pyo3(get, set)]
    pub market_depth: i32,
    #[pyo3(get, set)]
    pub market_segment_id: u32,
    #[pyo3(get, set)]
    pub max_trade_vol: u32,
    #[pyo3(get, set)]
    pub min_lot_size: i32,
    #[pyo3(get, set)]
    pub min_lot_size_block: i32,
    #[pyo3(get, set)]
    pub min_lot_size_round_lot: i32,
    #[pyo3(get, set)]
    pub min_trade_vol: u32,
    // Explicit reserved bytes. They presumably occupy what would otherwise be implicit
    // padding (the layout test asserts there is none); they are not copied on upgrade.
    #[doc(hidden)]
    pub _reserved2: [u8; 4],
    #[pyo3(get, set)]
    pub contract_multiplier: i32,
    #[pyo3(get, set)]
    pub decay_quantity: i32,
    #[pyo3(get, set)]
    pub original_contract_size: i32,
    // Explicit reserved bytes (see `_reserved2`).
    #[doc(hidden)]
    pub _reserved3: [u8; 4],
    #[pyo3(get, set)]
    pub trading_reference_date: u16,
    #[pyo3(get, set)]
    pub appl_id: i16,
    #[pyo3(get, set)]
    pub maturity_year: u16,
    #[pyo3(get, set)]
    pub decay_start_date: u16,
    #[pyo3(get, set)]
    pub channel_id: u16,
    // Fixed-size `c_char` string fields. `fmt_method` presumably routes their formatting
    // through accessor methods; confirm in the `dbn` macros.
    #[dbn(fmt_method)]
    #[cfg_attr(feature = "serde", serde(with = "crate::record::cstr_serde"))]
    pub currency: [c_char; 4],
    #[dbn(fmt_method)]
    #[cfg_attr(feature = "serde", serde(with = "crate::record::cstr_serde"))]
    pub settl_currency: [c_char; 4],
    #[dbn(fmt_method)]
    #[cfg_attr(feature = "serde", serde(with = "crate::record::cstr_serde"))]
    pub secsubtype: [c_char; 6],
    // NOTE(review): unlike the three fields above, the string fields from `raw_symbol`
    // through `strike_price_currency` have no `cstr_serde` attribute. With the `serde`
    // feature they would (de)serialize as integer arrays; confirm this is intended.
    /// The raw symbol, sized by the version 1 limit [`SYMBOL_CSTR_LEN_V1`] rather than the
    /// version 2 limit [`SYMBOL_CSTR_LEN_V2`]; encoded third (`encode_order(2)`).
    #[dbn(encode_order(2), fmt_method)]
    pub raw_symbol: [c_char; SYMBOL_CSTR_LEN_V1],
    #[dbn(fmt_method)]
    pub group: [c_char; 21],
    #[dbn(fmt_method)]
    pub exchange: [c_char; 5],
    #[dbn(fmt_method)]
    pub asset: [c_char; 7],
    #[dbn(fmt_method)]
    pub cfi: [c_char; 7],
    #[dbn(fmt_method)]
    pub security_type: [c_char; 7],
    #[dbn(fmt_method)]
    pub unit_of_measure: [c_char; 31],
    #[dbn(fmt_method)]
    pub underlying: [c_char; 21],
    #[dbn(fmt_method)]
    pub strike_price_currency: [c_char; 4],
    /// A single `c_char` code (`dbn(c_char)`), encoded fifth (`encode_order(4)`).
    #[dbn(c_char, encode_order(4))]
    #[pyo3(set)]
    pub instrument_class: c_char,
    // Explicit reserved bytes (see `_reserved2`).
    #[doc(hidden)]
    pub _reserved4: [u8; 2],
    #[dbn(fixed_price)]
    #[pyo3(get, set)]
    pub strike_price: i64,
    // Explicit reserved bytes (see `_reserved2`).
    #[doc(hidden)]
    pub _reserved5: [u8; 6],
    /// A single `c_char` code (`dbn(c_char)`).
    #[dbn(c_char)]
    #[pyo3(set)]
    pub match_algorithm: c_char,
    #[pyo3(get, set)]
    pub md_security_trading_status: u8,
    #[pyo3(get, set)]
    pub main_fraction: u8,
    #[pyo3(get, set)]
    pub price_display_format: u8,
    #[pyo3(get, set)]
    pub settl_price_type: u8,
    #[pyo3(get, set)]
    pub sub_fraction: u8,
    #[pyo3(get, set)]
    pub underlying_product: u8,
    /// Typed as [`SecurityUpdateAction`] in version 1; the upgrade to
    /// [`InstrumentDefMsgV2`] casts it to a raw `c_char`. Encoded fourth (`encode_order(3)`).
    #[dbn(encode_order(3))]
    #[pyo3(set)]
    pub security_update_action: SecurityUpdateAction,
    #[pyo3(get, set)]
    pub maturity_month: u8,
    #[pyo3(get, set)]
    pub maturity_day: u8,
    #[pyo3(get, set)]
    pub maturity_week: u8,
    #[pyo3(set)]
    pub user_defined_instrument: UserDefinedInstrument,
    #[pyo3(get, set)]
    pub contract_multiplier_unit: i8,
    #[pyo3(get, set)]
    pub flow_schedule_type: i8,
    #[pyo3(get, set)]
    pub tick_rule: u8,
    // Trailing reserved bytes; not copied on upgrade.
    #[doc(hidden)]
    pub _dummy: [u8; 3],
}
/// The DBN version 1 layout of an error record (`rtype::ERROR`).
///
/// The error text is a 64-element `c_char` array. The record converts into [`ErrorMsgV2`]
/// through `From<&ErrorMsgV1>`, which copies that text into the version 2 record.
#[repr(C)]
#[derive(Clone, CsvSerialize, JsonSerialize, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "trivial_copy", derive(Copy))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(dict, module = "databento_dbn"),
    derive(crate::macros::PyFieldDesc)
)]
#[cfg_attr(not(feature = "python"), derive(MockPyo3))]
#[cfg_attr(test, derive(type_layout::TypeLayout))]
#[dbn_record(rtype::ERROR)]
pub struct ErrorMsgV1 {
    /// The common record header.
    #[pyo3(get, set)]
    pub hd: RecordHeader,
    /// The error message as a fixed-size `c_char` string. With the `serde` feature it
    /// (de)serializes through `crate::record::cstr_serde`.
    #[dbn(fmt_method)]
    #[cfg_attr(feature = "serde", serde(with = "crate::record::cstr_serde"))]
    pub err: [c_char; 64],
}
/// The DBN version 1 layout of a symbol mapping record (`rtype::SYMBOL_MAPPING`).
///
/// Both symbols are limited to [`SYMBOL_CSTR_LEN_V1`] `c_char`s. The tests in this module
/// pin the struct at 80 bytes with 8-byte alignment and no implicit padding. It converts
/// into [`SymbolMappingMsgV2`] through `From<&SymbolMappingMsgV1>`.
#[repr(C)]
#[derive(Clone, CsvSerialize, JsonSerialize, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "trivial_copy", derive(Copy))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(dict, module = "databento_dbn"),
    derive(crate::macros::PyFieldDesc)
)]
#[cfg_attr(not(feature = "python"), derive(MockPyo3))]
#[cfg_attr(test, derive(type_layout::TypeLayout))]
#[dbn_record(rtype::SYMBOL_MAPPING)]
pub struct SymbolMappingMsgV1 {
    /// The common record header.
    #[pyo3(get, set)]
    pub hd: RecordHeader,
    // NOTE(review): neither symbol field carries the `cstr_serde` attribute used by other
    // `c_char` strings in this file. With the `serde` feature they would (de)serialize as
    // integer arrays; confirm this is intended.
    /// The input symbol as a fixed-size `c_char` string of [`SYMBOL_CSTR_LEN_V1`] elements.
    #[dbn(fmt_method)]
    pub stype_in_symbol: [c_char; SYMBOL_CSTR_LEN_V1],
    /// The output symbol as a fixed-size `c_char` string of [`SYMBOL_CSTR_LEN_V1`] elements.
    #[dbn(fmt_method)]
    pub stype_out_symbol: [c_char; SYMBOL_CSTR_LEN_V1],
    // Explicit reserved bytes so the following `u64` fields start without implicit
    // padding (the layout test asserts there is none).
    #[doc(hidden)]
    pub _dummy: [u8; 4],
    /// Start timestamp, formatted as UNIX nanoseconds (`unix_nanos`).
    #[dbn(unix_nanos)]
    #[pyo3(get, set)]
    pub start_ts: u64,
    /// End timestamp, formatted as UNIX nanoseconds (`unix_nanos`).
    #[dbn(unix_nanos)]
    #[pyo3(get, set)]
    pub end_ts: u64,
}
/// The DBN version 1 layout of a system record (`rtype::SYSTEM`).
///
/// The message text is a 64-element `c_char` array. The record converts into
/// [`SystemMsgV2`] through `From<&SystemMsgV1>`, which copies that text into the version 2
/// record.
#[repr(C)]
#[derive(Clone, CsvSerialize, JsonSerialize, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "trivial_copy", derive(Copy))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(dict, module = "databento_dbn"),
    derive(crate::macros::PyFieldDesc)
)]
#[cfg_attr(not(feature = "python"), derive(MockPyo3))]
#[cfg_attr(test, derive(type_layout::TypeLayout))]
#[dbn_record(rtype::SYSTEM)]
pub struct SystemMsgV1 {
    /// The common record header.
    #[pyo3(get, set)]
    pub hd: RecordHeader,
    /// The message as a fixed-size `c_char` string. With the `serde` feature it
    /// (de)serializes through `crate::record::cstr_serde`.
    #[dbn(fmt_method)]
    #[cfg_attr(feature = "serde", serde(with = "crate::record::cstr_serde"))]
    pub msg: [c_char; 64],
}
impl From<&InstrumentDefMsgV1> for InstrumentDefMsgV2 {
    /// Upgrades a version 1 instrument definition to the version 2 layout.
    ///
    /// The header is rebuilt for the version 2 record size while keeping the publisher,
    /// instrument ID, and `ts_event`. Every shared field is copied, and
    /// `security_update_action` is cast to a raw `c_char`. The shorter version 1
    /// `raw_symbol` is copied into the front of the version 2 buffer. All remaining fields
    /// (including reserved bytes) keep their `Default` values.
    fn from(old: &InstrumentDefMsgV1) -> Self {
        let mut res = Self {
            hd: RecordHeader::new::<Self>(
                rtype::INSTRUMENT_DEF,
                old.hd.publisher_id,
                old.hd.instrument_id,
                old.hd.ts_event,
            ),
            ts_recv: old.ts_recv,
            min_price_increment: old.min_price_increment,
            display_factor: old.display_factor,
            expiration: old.expiration,
            activation: old.activation,
            high_limit_price: old.high_limit_price,
            low_limit_price: old.low_limit_price,
            max_price_variation: old.max_price_variation,
            trading_reference_price: old.trading_reference_price,
            unit_of_measure_qty: old.unit_of_measure_qty,
            min_price_increment_amount: old.min_price_increment_amount,
            price_ratio: old.price_ratio,
            inst_attrib_value: old.inst_attrib_value,
            underlying_id: old.underlying_id,
            raw_instrument_id: old.raw_instrument_id,
            market_depth_implied: old.market_depth_implied,
            market_depth: old.market_depth,
            market_segment_id: old.market_segment_id,
            max_trade_vol: old.max_trade_vol,
            min_lot_size: old.min_lot_size,
            min_lot_size_block: old.min_lot_size_block,
            min_lot_size_round_lot: old.min_lot_size_round_lot,
            min_trade_vol: old.min_trade_vol,
            contract_multiplier: old.contract_multiplier,
            decay_quantity: old.decay_quantity,
            original_contract_size: old.original_contract_size,
            trading_reference_date: old.trading_reference_date,
            appl_id: old.appl_id,
            maturity_year: old.maturity_year,
            decay_start_date: old.decay_start_date,
            channel_id: old.channel_id,
            currency: old.currency,
            settl_currency: old.settl_currency,
            secsubtype: old.secsubtype,
            group: old.group,
            exchange: old.exchange,
            asset: old.asset,
            cfi: old.cfi,
            security_type: old.security_type,
            unit_of_measure: old.unit_of_measure,
            underlying: old.underlying,
            strike_price_currency: old.strike_price_currency,
            instrument_class: old.instrument_class,
            strike_price: old.strike_price,
            match_algorithm: old.match_algorithm,
            md_security_trading_status: old.md_security_trading_status,
            main_fraction: old.main_fraction,
            price_display_format: old.price_display_format,
            settl_price_type: old.settl_price_type,
            sub_fraction: old.sub_fraction,
            underlying_product: old.underlying_product,
            security_update_action: old.security_update_action as c_char,
            maturity_month: old.maturity_month,
            maturity_day: old.maturity_day,
            maturity_week: old.maturity_week,
            user_defined_instrument: old.user_defined_instrument,
            contract_multiplier_unit: old.contract_multiplier_unit,
            flow_schedule_type: old.flow_schedule_type,
            tick_rule: old.tick_rule,
            ..Default::default()
        };
        // Copy exactly the version 1 length (`SYMBOL_CSTR_LEN_V1`). The rest of the
        // (presumably `SYMBOL_CSTR_LEN_V2`-long) version 2 buffer keeps its `Default`
        // contents. The bounds-checked slice copy replaces an unnecessary `unsafe`
        // raw-pointer copy.
        res.raw_symbol[..old.raw_symbol.len()].copy_from_slice(&old.raw_symbol);
        res
    }
}
impl From<&ErrorMsgV1> for ErrorMsgV2 {
    /// Upgrades a version 1 error record to the version 2 layout.
    ///
    /// The header is rebuilt for the version 2 record size while keeping the publisher,
    /// instrument ID, and `ts_event`. The 64-element version 1 error text is copied into the
    /// front of the version 2 `err` buffer. Every other field keeps its `Default` value.
    fn from(old: &ErrorMsgV1) -> Self {
        let mut new = Self {
            hd: RecordHeader::new::<Self>(
                rtype::ERROR,
                old.hd.publisher_id,
                old.hd.instrument_id,
                old.hd.ts_event,
            ),
            ..Default::default()
        };
        // Copy only as many elements as the version 1 source holds. Copying
        // `new.err.len()` elements read past the end of `old.err` whenever the version 2
        // buffer is longer. The slice copy is bounds-checked, so a shorter destination
        // would panic rather than corrupt memory.
        new.err[..old.err.len()].copy_from_slice(&old.err);
        new
    }
}
impl From<&SymbolMappingMsgV1> for SymbolMappingMsgV2 {
    /// Upgrades a version 1 symbol mapping record to the version 2 layout.
    ///
    /// The header is rebuilt for the version 2 record size while keeping the publisher,
    /// instrument ID, and `ts_event`. `start_ts` and `end_ts` are copied, and both
    /// `SYMBOL_CSTR_LEN_V1`-long symbols are copied into the front of their version 2
    /// buffers. Every other field keeps its `Default` value.
    fn from(old: &SymbolMappingMsgV1) -> Self {
        let mut res = Self {
            hd: RecordHeader::new::<Self>(
                rtype::SYMBOL_MAPPING,
                old.hd.publisher_id,
                old.hd.instrument_id,
                old.hd.ts_event,
            ),
            start_ts: old.start_ts,
            end_ts: old.end_ts,
            ..Default::default()
        };
        // Bounds-checked slice copies replace the previous `unsafe` raw-pointer copies;
        // exactly the version 1 length is copied.
        res.stype_in_symbol[..old.stype_in_symbol.len()].copy_from_slice(&old.stype_in_symbol);
        res.stype_out_symbol[..old.stype_out_symbol.len()]
            .copy_from_slice(&old.stype_out_symbol);
        res
    }
}
impl From<&SystemMsgV1> for SystemMsgV2 {
    /// Upgrades a version 1 system record to the version 2 layout.
    ///
    /// The header is rebuilt for the version 2 record size while keeping the publisher,
    /// instrument ID, and `ts_event`. The 64-element version 1 message is copied into the
    /// front of the version 2 `msg` buffer. Every other field keeps its `Default` value.
    fn from(old: &SystemMsgV1) -> Self {
        let mut new = Self {
            hd: RecordHeader::new::<Self>(
                rtype::SYSTEM,
                old.hd.publisher_id,
                old.hd.instrument_id,
                old.hd.ts_event,
            ),
            ..Default::default()
        };
        // Copy only as many elements as the version 1 source holds. Copying
        // `new.msg.len()` elements read past the end of `old.msg` whenever the version 2
        // buffer is longer. The slice copy is bounds-checked, so a shorter destination
        // would panic rather than corrupt memory.
        new.msg[..old.msg.len()].copy_from_slice(&old.msg);
        new
    }
}
/// Accessors shared by the version 1 ([`SymbolMappingMsgV1`]) and version 2
/// ([`SymbolMappingMsgV2`]) symbol mapping records, so either layout can be handled
/// generically.
pub trait SymbolMappingRec: HasRType {
    /// Returns the input symbol (`stype_in_symbol` field) as a `&str`, or an error if it
    /// cannot be converted.
    fn stype_in_symbol(&self) -> crate::Result<&str>;
    /// Returns the output symbol (`stype_out_symbol` field) as a `&str`, or an error if it
    /// cannot be converted.
    fn stype_out_symbol(&self) -> crate::Result<&str>;
    /// Returns the `start_ts` field as an `OffsetDateTime`, or `None` when the implementor
    /// cannot produce one.
    fn start_ts(&self) -> Option<time::OffsetDateTime>;
    /// Returns the `end_ts` field as an `OffsetDateTime`, or `None` when the implementor
    /// cannot produce one.
    fn end_ts(&self) -> Option<time::OffsetDateTime>;
}
impl SymbolMappingRec for SymbolMappingMsgV1 {
fn stype_in_symbol(&self) -> crate::Result<&str> {
Self::stype_in_symbol(self)
}
fn stype_out_symbol(&self) -> crate::Result<&str> {
Self::stype_out_symbol(self)
}
fn start_ts(&self) -> Option<time::OffsetDateTime> {
Self::start_ts(self)
}
fn end_ts(&self) -> Option<time::OffsetDateTime> {
Self::end_ts(self)
}
}
impl SymbolMappingRec for SymbolMappingMsgV2 {
fn stype_in_symbol(&self) -> crate::Result<&str> {
Self::stype_in_symbol(self)
}
fn stype_out_symbol(&self) -> crate::Result<&str> {
Self::stype_out_symbol(self)
}
fn start_ts(&self) -> Option<time::OffsetDateTime> {
Self::start_ts(self)
}
fn end_ts(&self) -> Option<time::OffsetDateTime> {
Self::end_ts(self)
}
}
#[cfg(test)]
mod tests {
    use std::{ffi::c_char, mem};

    use time::OffsetDateTime;
    use type_layout::{Field, TypeLayout};

    use crate::{Mbp1Msg, Record, Schema, MAX_RECORD_LEN};

    use super::*;

    /// Asserts that `T` is `expected_size` bytes, 8-byte aligned, and free of implicit
    /// padding.
    fn assert_layout<T: TypeLayout>(expected_size: usize) {
        assert_eq!(mem::size_of::<T>(), expected_size);
        let layout = T::type_layout();
        assert_eq!(layout.alignment, 8);
        let padding_free = layout
            .fields
            .iter()
            .all(|field| matches!(field, Field::Field { .. }));
        assert!(padding_free, "Detected padding: {layout}");
    }

    /// Decodes `input` as a version 1 record with `ts_out` and upgrading enabled.
    fn decode_v1_ts_out<'a>(
        compat_buffer: &'a mut [u8; MAX_RECORD_LEN],
        input: &'a [u8],
    ) -> RecordRef<'a> {
        // SAFETY: every caller passes the byte view of a complete `WithTsOut` record.
        unsafe {
            decode_record_ref(
                1,
                VersionUpgradePolicy::Upgrade,
                true,
                compat_buffer,
                input,
            )
        }
    }

    #[cfg(feature = "python")]
    #[test]
    fn test_strike_price_order_didnt_change() {
        use crate::python::PyFieldDesc;

        let v1_fields = InstrumentDefMsgV1::ordered_fields("");
        let v2_fields = InstrumentDefMsgV2::ordered_fields("");
        assert_eq!(v1_fields, v2_fields);
    }

    #[test]
    fn test_default_equivalency() {
        let upgraded = InstrumentDefMsgV2::from(&InstrumentDefMsgV1::default());
        assert_eq!(upgraded, InstrumentDefMsgV2::default());
    }

    #[test]
    fn test_definition_size_alignment_and_padding() {
        assert_layout::<InstrumentDefMsgV1>(360);
    }

    #[test]
    fn test_symbol_mapping_size_alignment_and_padding() {
        assert_layout::<SymbolMappingMsgV1>(80);
    }

    #[test]
    fn upgrade_symbol_mapping_ts_out() -> crate::Result<()> {
        let mapping = SymbolMappingMsgV1::new(1, 2, "ES.c.0", "ESH4", 0, 0)?;
        let ts_out = OffsetDateTime::now_utc().unix_timestamp_nanos() as u64;
        let orig = WithTsOut::new(mapping, ts_out);

        let mut compat_buffer = [0; MAX_RECORD_LEN];
        let res = decode_v1_ts_out(&mut compat_buffer, orig.as_ref());
        let upgraded = res.get::<WithTsOut<SymbolMappingMsgV2>>().unwrap();

        assert_eq!(orig.ts_out, upgraded.ts_out);
        assert_eq!(orig.rec.stype_in_symbol()?, upgraded.rec.stype_in_symbol()?);
        assert_eq!(
            orig.rec.stype_out_symbol()?,
            upgraded.rec.stype_out_symbol()?
        );
        // The upgraded record reports its own size and lives in the compat buffer.
        assert_eq!(upgraded.record_size(), mem::size_of_val(upgraded));
        assert!(std::ptr::addr_eq(upgraded.header(), compat_buffer.as_ptr()));
        Ok(())
    }

    #[test]
    fn upgrade_mbp1_ts_out() -> crate::Result<()> {
        let orig = WithTsOut::new(
            Mbp1Msg {
                price: 1_250_000_000,
                side: b'A' as c_char,
                ..Mbp1Msg::default_for_schema(Schema::Mbp1)
            },
            OffsetDateTime::now_utc().unix_timestamp_nanos() as u64,
        );

        let mut compat_buffer = [0; MAX_RECORD_LEN];
        let res = decode_v1_ts_out(&mut compat_buffer, orig.as_ref());
        let passthrough = res.get::<WithTsOut<Mbp1Msg>>().unwrap();

        // Records whose layout did not change are not copied; they alias the input.
        assert!(std::ptr::eq(orig.header(), passthrough.header()));
        Ok(())
    }
}