use std::{fmt::Debug, io::IoSlice, marker::PhantomData, mem, ptr::NonNull};
use crate::{
record::{HasRType, Record, RecordHeader},
rtype_dispatch, RecordEnum, RecordRefEnum,
};
/// A non-owning, type-erased handle to a record, stored as a single non-null
/// pointer to the record's [`RecordHeader`].
///
/// The `'a` lifetime parameter ties the handle to the data it was created
/// from, so it is borrow-checked like a `&'a RecordHeader`. The type is `Copy`,
/// and copying it duplicates only the pointer.
#[derive(Copy, Clone)]
pub struct RecordRef<'a> {
// Pointer to the header at the start of the referenced record. `NonNull`
// lets `Option<RecordRef>` reuse the null value as its `None` representation.
ptr: NonNull<RecordHeader>,
// Zero-sized marker that makes the compiler treat this struct as if it held
// a `&'a RecordHeader`.
_marker: PhantomData<&'a RecordHeader>,
}
// `NonNull` is neither `Send` nor `Sync`, so both are opted into manually.
// SAFETY: NOTE(review): the code in this file only ever reads through `ptr`
// and hands out shared references, which is what a `&'a RecordHeader` would
// permit — confirm that no other module mutates through a `RecordRef`.
unsafe impl Send for RecordRef<'_> {}
unsafe impl Sync for RecordRef<'_> {}
impl<'a> RecordRef<'a> {
pub unsafe fn new(buffer: &'a [u8]) -> Self {
debug_assert!(
buffer.len() >= mem::size_of::<RecordHeader>(),
"buffer of length {} is too short",
buffer.len()
);
let raw_ptr = buffer.as_ptr() as *mut RecordHeader;
debug_assert_eq!(
raw_ptr.align_offset(std::mem::align_of::<RecordHeader>()),
0,
"unaligned buffer passed to `RecordRef::new`"
);
let ptr = NonNull::new_unchecked(raw_ptr.cast::<RecordHeader>());
Self {
ptr,
_marker: PhantomData,
}
}
pub unsafe fn unchecked_from_header(header: *const RecordHeader) -> Self {
Self {
ptr: NonNull::new_unchecked(header.cast_mut()),
_marker: PhantomData,
}
}
pub fn has<T: HasRType>(&self) -> bool {
T::has_rtype(self.header().rtype)
}
pub fn get<T: HasRType>(&self) -> Option<&'a T> {
if self.has::<T>() {
assert!(
self.record_size() >= mem::size_of::<T>(),
"Malformed `{}` record: expected length of at least {} bytes, found {} bytes. \
Confirm the DBN version in the Metadata header and the version upgrade policy",
std::any::type_name::<T>(),
mem::size_of::<T>(),
self.record_size()
);
Some(unsafe { self.ptr.cast::<T>().as_ref() })
} else {
None
}
}
pub fn try_get<T: HasRType>(&self) -> crate::Result<&'a T> {
if self.has::<T>() {
if self.record_size() >= mem::size_of::<T>() {
Ok(unsafe { self.ptr.cast::<T>().as_ref() })
} else {
Err(crate::Error::conversion::<T>(format!(
"{self:?} has insufficient length, may be an earlier version of this record"
)))
}
} else {
Err(crate::Error::conversion::<T>(format!(
"{self:?} has incorrect rtype"
)))
}
}
pub fn as_enum(&self) -> crate::Result<RecordRefEnum<'_>> {
RecordRefEnum::try_from(*self)
}
pub unsafe fn get_unchecked<T: HasRType>(&self) -> &'a T {
debug_assert!(self.record_size() >= mem::size_of::<T>());
self.ptr.cast::<T>().as_ref()
}
}
impl<'a, R> From<&'a R> for RecordRef<'a>
where
    R: HasRType,
{
    /// Borrows the concrete record `rec` as a type-erased `RecordRef` that
    /// points at its header and lives as long as the borrow of `rec`.
    fn from(rec: &'a R) -> Self {
        Self {
            // A reference is never null, so the safe `NonNull::from`
            // conversion suffices and no `unsafe` is needed. The resulting
            // pointer is nominally mutable, but it is only ever read through.
            ptr: NonNull::from(rec.header()),
            _marker: PhantomData,
        }
    }
}
impl<'a> AsRef<[u8]> for RecordRef<'a> {
    /// Views the referenced record as raw bytes: `record_size()` bytes
    /// starting at the header.
    fn as_ref(&self) -> &'a [u8] {
        let first_byte = self.ptr.as_ptr() as *const u8;
        let byte_len = self.record_size();
        // SAFETY: assumes the length encoded in the header is correct, i.e.
        // `byte_len` readable bytes follow `first_byte` and stay valid for `'a`.
        unsafe { std::slice::from_raw_parts(first_byte, byte_len) }
    }
}
impl<'a> Record for RecordRef<'a> {
/// Returns the header of the referenced record, borrowed for the full `'a`
/// lifetime rather than for the borrow of `self`.
fn header(&self) -> &'a RecordHeader {
// SAFETY: relies on the constructors' invariant that `ptr` points to a
// valid `RecordHeader` that lives for `'a`.
unsafe { self.ptr.as_ref() }
}
/// Returns the raw index timestamp of the referenced record by dispatching
/// to the concrete record type's `raw_index_ts()`. If the dispatch yields an
/// error, the header's `ts_event` is returned instead.
fn raw_index_ts(&self) -> u64 {
// Generic handler named in the `rtype_dispatch!` invocation below.
// NOTE(review): presumably the macro calls it with the concrete record
// selected by rtype — confirm against the macro definition.
fn raw_index_ts<T: HasRType>(t: &T) -> u64 {
t.raw_index_ts()
}
// The error value is ignored and `ts_event` is used as the fallback.
rtype_dispatch!(self, raw_index_ts()).unwrap_or_else(|_| self.header().ts_event)
}
}
impl<'a> From<&'a RecordEnum> for RecordRef<'a> {
fn from(rec_enum: &'a RecordEnum) -> Self {
match rec_enum {
RecordEnum::Mbo(rec) => Self::from(rec),
RecordEnum::Trade(rec) => Self::from(rec),
RecordEnum::Mbp1(rec) => Self::from(rec),
RecordEnum::Mbp10(rec) => Self::from(rec),
RecordEnum::Ohlcv(rec) => Self::from(rec),
RecordEnum::Status(rec) => Self::from(rec),
RecordEnum::InstrumentDef(rec) => Self::from(rec),
RecordEnum::Imbalance(rec) => Self::from(rec),
RecordEnum::Stat(rec) => Self::from(rec),
RecordEnum::Error(rec) => Self::from(rec),
RecordEnum::SymbolMapping(rec) => Self::from(rec),
RecordEnum::System(rec) => Self::from(rec),
RecordEnum::Cmbp1(rec) => Self::from(rec),
RecordEnum::Bbo(rec) => Self::from(rec),
RecordEnum::Cbbo(rec) => Self::from(rec),
}
}
}
impl<'a> From<RecordRefEnum<'a>> for RecordRef<'a> {
fn from(rec_enum: RecordRefEnum<'a>) -> Self {
match rec_enum {
RecordRefEnum::Mbo(rec) => Self::from(rec),
RecordRefEnum::Trade(rec) => Self::from(rec),
RecordRefEnum::Mbp1(rec) => Self::from(rec),
RecordRefEnum::Mbp10(rec) => Self::from(rec),
RecordRefEnum::Ohlcv(rec) => Self::from(rec),
RecordRefEnum::Status(rec) => Self::from(rec),
RecordRefEnum::InstrumentDef(rec) => Self::from(rec),
RecordRefEnum::Imbalance(rec) => Self::from(rec),
RecordRefEnum::Stat(rec) => Self::from(rec),
RecordRefEnum::Error(rec) => Self::from(rec),
RecordRefEnum::SymbolMapping(rec) => Self::from(rec),
RecordRefEnum::System(rec) => Self::from(rec),
RecordRefEnum::Cmbp1(rec) => Self::from(rec),
RecordRefEnum::Bbo(rec) => Self::from(rec),
RecordRefEnum::Cbbo(rec) => Self::from(rec),
}
}
}
impl<'a> From<RecordRef<'a>> for IoSlice<'a> {
    /// Wraps the raw bytes of the referenced record (`record_size()` bytes
    /// starting at the header) in an `IoSlice`.
    fn from(rec: RecordRef<'a>) -> Self {
        let first_byte = rec.ptr.as_ptr() as *const u8;
        let byte_len = rec.record_size();
        // SAFETY: assumes the length encoded in the header is correct, i.e.
        // `byte_len` readable bytes follow `first_byte` and stay valid for `'a`.
        let bytes: &'a [u8] = unsafe { std::slice::from_raw_parts(first_byte, byte_len) };
        Self::new(bytes)
    }
}
impl Debug for RecordRef<'_> {
    /// Formats as `RecordRef { ptr: <address> --> <header> }`, showing both
    /// the pointer value and the header it points to.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let header = self.header();
        let mut repr = f.debug_struct("RecordRef");
        repr.field("ptr", &format_args!("{:?} --> {:?}", self.ptr, header));
        repr.finish()
    }
}
#[cfg(test)]
mod tests {
    use std::ffi::c_char;

    use crate::{
        enums::rtype, v1, v3, ErrorMsg, FlagSet, InstrumentDefMsg, MboMsg, Mbp10Msg, Mbp1Msg,
        OhlcvMsg, TradeMsg,
    };

    use super::*;

    /// MBO record shared by the tests as the target of every `RecordRef`.
    const SOURCE_RECORD: MboMsg = MboMsg {
        hd: RecordHeader::new::<MboMsg>(rtype::MBO, 1, 1, 0),
        order_id: 17,
        price: 0,
        size: 32,
        flags: FlagSet::empty(),
        channel_id: 1,
        action: 'A' as c_char,
        side: 'B' as c_char,
        ts_recv: 0,
        ts_in_delta: 160,
        sequence: 1067,
    };

    #[test]
    fn test_header() {
        let target = RecordRef::from(&SOURCE_RECORD);
        assert_eq!(*target.header(), SOURCE_RECORD.hd);
    }

    #[test]
    fn test_fmt_debug() {
        let target = RecordRef::from(&SOURCE_RECORD);
        let string = format!("{target:?}");
        // The address varies between runs, so only the prefix and the header
        // part of the output are checked.
        assert!(string.starts_with("RecordRef { ptr: 0x"));
        assert!(string.ends_with("--> RecordHeader { length: 14, rtype: Mbo, publisher_id: GlbxMdp3Glbx, instrument_id: 1, ts_event: 0 } }"));
    }

    #[test]
    fn test_has_and_get() {
        let target = RecordRef::from(&SOURCE_RECORD);
        assert!(!target.has::<Mbp1Msg>());
        assert!(!target.has::<Mbp10Msg>());
        assert!(!target.has::<TradeMsg>());
        assert!(!target.has::<ErrorMsg>());
        assert!(!target.has::<OhlcvMsg>());
        assert!(!target.has::<InstrumentDefMsg>());
        assert!(target.has::<MboMsg>());
        assert_eq!(*target.get::<MboMsg>().unwrap(), SOURCE_RECORD);
    }

    #[test]
    fn test_as_ref() {
        let target = RecordRef::from(&SOURCE_RECORD);
        let byte_slice = target.as_ref();
        assert_eq!(SOURCE_RECORD.record_size(), byte_slice.len());
        assert_eq!(target.record_size(), byte_slice.len());
    }

    #[test]
    fn test_io_slice() {
        let target = RecordRef::from(&SOURCE_RECORD);
        let io_slice = IoSlice::from(target);
        assert_eq!(SOURCE_RECORD.record_size(), io_slice.len());
    }

    #[test]
    fn test_as_enum() {
        let target = RecordRef::from(&SOURCE_RECORD);
        assert!(matches!(target.as_enum(), Ok(RecordRefEnum::Mbo(_))));
    }

    // `expected` pins the panic to the length assertion in `get`, so an
    // unrelated panic does not make these tests pass by accident.
    #[should_panic(expected = "Malformed")]
    #[test]
    fn test_get_too_short() {
        let mut src = SOURCE_RECORD;
        src.hd.length -= 1;
        let target = RecordRef::from(&src);
        target.get::<MboMsg>();
    }

    #[should_panic(expected = "Malformed")]
    #[test]
    fn test_get_previous_ver() {
        let src = v1::InstrumentDefMsg::default();
        let target = RecordRef::from(&src);
        target.get::<v3::InstrumentDefMsg>();
    }

    #[test]
    fn test_try_get_previous_ver() {
        let src = v1::InstrumentDefMsg::default();
        let target = RecordRef::from(&src);
        assert!(
            matches!(target.try_get::<v3::InstrumentDefMsg>(), Err(e) if e.to_string().contains("has insufficient length"))
        );
    }

    #[test]
    fn test_try_get_wrong_rtype() {
        let target = RecordRef::from(&SOURCE_RECORD);
        assert!(
            matches!(target.try_get::<Mbp1Msg>(), Err(e) if e.to_string().contains("has incorrect rtype"))
        );
    }

    /// `RecordRef` must stay pointer-sized, and `Option<RecordRef>` must not
    /// add any space on top of that.
    #[test]
    fn niche() {
        assert_eq!(
            std::mem::size_of::<RecordRef>(),
            std::mem::size_of::<Option<RecordRef>>()
        );
        assert_eq!(
            std::mem::size_of::<RecordRef>(),
            std::mem::size_of::<usize>()
        );
    }
}