use std::{fmt, future::Future, net::SocketAddr, time::Duration as StdDuration};
use dbn::{
decode::{
dbn::{
async_decode_metadata_with_fsm, async_decode_record_ref_with_fsm,
fsm::{DbnFsm, ProcessResult},
},
AsyncDynReader,
},
Compression, Metadata, RecordRef, VersionUpgradePolicy,
};
use time::Duration;
use tokio::{
io::{AsyncReadExt, BufReader, ReadHalf, WriteHalf},
net::{TcpStream, ToSocketAddrs},
};
use tracing::{info, info_span, instrument, warn, Span};
use crate::ApiKey;
use super::{
protocol::{self, Protocol, SessionOptions},
ClientBuilder, SlowReaderBehavior, Subscription, TimeoutConf, Unset,
};
/// An asynchronous client for a live session with a DBN-encoding gateway.
///
/// Holds both the configuration used to authenticate (so that `reconnect` can
/// re-establish an equivalent session) and the state of the current connection.
pub struct Client {
/// API key used to authenticate with the gateway.
key: ApiKey,
/// Dataset code for the session; also used to pick the gateway address when no
/// explicit address was configured.
dataset: String,
/// Whether the gateway was asked to send `ts_out` timestamps with each record.
send_ts_out: bool,
/// Policy handed to the DBN decoder for records from older DBN versions.
upgrade_policy: VersionUpgradePolicy,
/// Requested heartbeat interval; also determines the read timeout used by the
/// record-reading methods.
heartbeat_interval: Option<Duration>,
/// Optional user-agent extension sent during authentication.
user_agent_ext: Option<String>,
/// Compression of the record stream; the reader is wrapped accordingly.
compression: Compression,
/// Requested slow-reader behavior, if any, sent during authentication.
slow_reader_behavior: Option<SlowReaderBehavior>,
/// Connect and authentication timeouts, reused on reconnect.
timeout_conf: TimeoutConf,
/// Control-protocol writer over the socket's write half (subscribe, start, shutdown).
protocol: Protocol<WriteHalf<TcpStream>>,
/// Address of the gateway actually connected to; `reconnect` reuses it.
peer_addr: SocketAddr,
/// Highest subscription ID assigned so far.
sub_counter: u32,
/// Subscriptions sent so far; replayed by `resubscribe`.
subscriptions: Vec<Subscription>,
/// Buffered (and, per `compression`, decompressing) reader over the socket's read half.
reader: AsyncDynReader<BufReader<ReadHalf<TcpStream>>>,
/// DBN decoder state machine and its record buffer.
fsm: DbnFsm,
/// Session ID returned by the gateway on successful authentication.
session_id: String,
/// Set once the stream has ended or a heartbeat timeout occurred; cleared by `reconnect`.
is_closed: bool,
/// Tracing span (dataset + session ID) used as the parent of instrumented methods.
span: Span,
}
impl Client {
    /// Connects to the live gateway for `dataset` and authenticates with `key`.
    ///
    /// Equivalent to configuring [`Client::builder`] with the same arguments and
    /// calling `build`.
    ///
    /// # Errors
    /// Returns an error if `key` is rejected by the builder, or if connecting or
    /// authenticating fails.
    #[deprecated(since = "0.27.0", note = "Use the builder instead")]
    pub async fn connect(
        key: String,
        dataset: String,
        send_ts_out: bool,
        upgrade_policy: VersionUpgradePolicy,
        heartbeat_interval: Option<Duration>,
    ) -> crate::Result<Self> {
        let builder = Self::builder()
            .key(key)?
            .dataset(dataset)
            .send_ts_out(send_ts_out)
            .upgrade_policy(upgrade_policy);
        if let Some(heartbeat_interval) = heartbeat_interval {
            builder.heartbeat_interval(heartbeat_interval).build().await
        } else {
            builder.build().await
        }
    }

    /// Connects to the gateway at `addr` and authenticates with `key`.
    ///
    /// Equivalent to configuring [`Client::builder`] with the same arguments and
    /// calling `build`.
    ///
    /// # Errors
    /// Returns an error if `addr` cannot be resolved, `key` is rejected by the
    /// builder, or connecting or authenticating fails.
    #[deprecated(since = "0.27.0", note = "Use the builder instead")]
    pub async fn connect_with_addr(
        addr: impl ToSocketAddrs,
        key: String,
        dataset: String,
        send_ts_out: bool,
        upgrade_policy: VersionUpgradePolicy,
        heartbeat_interval: Option<Duration>,
    ) -> crate::Result<Self> {
        let builder = Self::builder()
            .addr(addr)
            .await?
            .key(key)?
            .dataset(dataset)
            .send_ts_out(send_ts_out)
            .upgrade_policy(upgrade_policy);
        if let Some(heartbeat_interval) = heartbeat_interval {
            builder.heartbeat_interval(heartbeat_interval).build().await
        } else {
            builder.build().await
        }
    }

    /// Connects to and authenticates with the gateway using the configuration
    /// collected by a [`ClientBuilder`].
    ///
    /// The gateway is `addr` when one was configured; otherwise it is derived from
    /// `dataset`.
    ///
    /// # Errors
    /// Returns an error if the decoder rejects the requested `buf_size`, or if
    /// connecting or authenticating fails or exceeds its configured timeout.
    pub(crate) async fn new(
        ClientBuilder {
            addr,
            key,
            dataset,
            send_ts_out,
            upgrade_policy,
            heartbeat_interval,
            buf_size,
            user_agent_ext,
            compression,
            slow_reader_behavior,
            timeout_conf,
        }: ClientBuilder<ApiKey, String>,
    ) -> crate::Result<Self> {
        // Build the decoder before touching the network: `buf_size` is caller-supplied
        // and may be rejected, which should be reported as an error rather than a panic
        // after a session has already been authenticated.
        let fsm = DbnFsm::builder()
            .upgrade_policy(upgrade_policy)
            .buffer_size(buf_size.unwrap_or(DbnFsm::DEFAULT_BUF_SIZE))
            .build()?;
        let connect_fut = async {
            if let Some(addr) = addr {
                TcpStream::connect(addr.as_slice()).await
            } else {
                TcpStream::connect(protocol::determine_gateway(&dataset)).await
            }
        };
        let stream = apply_connect_timeout(connect_fut, timeout_conf.connect).await?;
        // Remember the concrete address so `reconnect` reaches the same gateway.
        let peer_addr = stream.peer_addr()?;
        let (recver, sender) = tokio::io::split(stream);
        let mut recver = BufReader::new(recver);
        let mut protocol = Protocol::new(sender);
        let options = SessionOptions {
            compression,
            send_ts_out,
            heartbeat_interval_s: heartbeat_interval.map(|i| i.whole_seconds()),
            user_agent_ext: user_agent_ext.as_deref(),
            slow_reader_behavior,
        };
        let session_id = apply_auth_timeout(
            protocol.authenticate(&mut recver, &key, &dataset, options),
            timeout_conf.auth,
        )
        .await?;
        let reader = AsyncDynReader::with_buffer(recver, compression);
        let span = info_span!("LiveClient", %dataset, session_id);
        Ok(Self {
            key,
            dataset,
            send_ts_out,
            upgrade_policy,
            heartbeat_interval,
            user_agent_ext,
            compression,
            slow_reader_behavior,
            timeout_conf,
            protocol,
            peer_addr,
            reader,
            fsm,
            session_id,
            is_closed: false,
            span,
            sub_counter: 0,
            subscriptions: Vec::new(),
        })
    }

    /// Returns a builder for configuring and connecting a new client.
    pub fn builder() -> ClientBuilder<Unset, Unset> {
        ClientBuilder::default()
    }

    /// Returns the API key used to authenticate.
    pub fn key(&self) -> &str {
        &self.key.0
    }

    /// Returns the dataset of the session.
    pub fn dataset(&self) -> &str {
        &self.dataset
    }

    /// Returns the session ID assigned by the gateway during the latest authentication.
    pub fn session_id(&self) -> &str {
        &self.session_id
    }

    /// Returns whether `ts_out` timestamps were requested.
    pub fn send_ts_out(&self) -> bool {
        self.send_ts_out
    }

    /// Returns the DBN version upgrade policy given to the decoder.
    pub fn upgrade_policy(&self) -> VersionUpgradePolicy {
        self.upgrade_policy
    }

    /// Returns the requested heartbeat interval, if one was set.
    pub fn heartbeat_interval(&self) -> Option<Duration> {
        self.heartbeat_interval
    }

    /// Returns the compression of the record stream.
    pub fn compression(&self) -> Compression {
        self.compression
    }

    /// Returns the requested slow-reader behavior, if one was set.
    pub fn slow_reader_behavior(&self) -> Option<SlowReaderBehavior> {
        self.slow_reader_behavior
    }

    /// Returns the connect and authentication timeouts.
    pub fn timeout_conf(&self) -> &TimeoutConf {
        &self.timeout_conf
    }

    /// Returns every subscription sent through [`subscribe`](Self::subscribe).
    pub fn subscriptions(&self) -> &Vec<Subscription> {
        &self.subscriptions
    }

    /// Returns mutable access to the recorded subscriptions, e.g. to change what
    /// [`resubscribe`](Self::resubscribe) replays.
    pub fn subscriptions_mut(&mut self) -> &mut Vec<Subscription> {
        &mut self.subscriptions
    }

    /// Shuts down the gateway connection through the control protocol's writer.
    ///
    /// This does not change [`is_closed`](Self::is_closed); that flag is only set by
    /// the read methods.
    ///
    /// # Errors
    /// Returns any error reported while shutting down the stream.
    pub async fn close(&mut self) -> crate::Result<()> {
        self.protocol.shutdown().await
    }

    /// Sends `sub` to the gateway and records it for [`resubscribe`](Self::resubscribe).
    ///
    /// When `sub.id` is `None`, the next sequential ID is assigned. Once the counter
    /// reaches `u32::MAX` a warning is logged and that ID is reused.
    ///
    /// # Errors
    /// Returns an error if the subscription is invalid or cannot be sent; in that case
    /// it is not recorded.
    #[instrument(parent = &self.span, skip_all)]
    pub async fn subscribe(&mut self, mut sub: Subscription) -> crate::Result<()> {
        if sub.id.is_none() {
            if self.sub_counter == u32::MAX {
                warn!("Exhausted all subscription IDs");
            } else {
                self.sub_counter += 1;
            }
            sub.id = Some(self.sub_counter);
        }
        self.protocol.subscribe(&sub).await?;
        self.subscriptions.push(sub);
        Ok(())
    }

    /// Asks the gateway to start the session and decodes the DBN metadata header.
    ///
    /// # Errors
    /// Returns [`crate::Error::BadArgument`] if the session was already started, or an
    /// error if sending the request or decoding the metadata fails.
    #[instrument(parent = &self.span, skip_all)]
    pub async fn start(&mut self) -> crate::Result<Metadata> {
        if self.fsm.has_decoded_metadata() {
            return Err(crate::Error::BadArgument {
                param_name: "self",
                desc: "ignored request to start session that has already been started".to_owned(),
            });
        }
        info!("Starting session");
        self.protocol.start_session().await?;
        Ok(async_decode_metadata_with_fsm(&mut self.reader, &mut self.fsm).await?)
    }

    /// Reads and decodes the next record, waiting at most the heartbeat timeout.
    ///
    /// Returns `Ok(None)` when the decoder yields no record, which this client treats
    /// as the connection having ended and flags via [`is_closed`](Self::is_closed).
    ///
    /// # Errors
    /// - [`crate::Error::BadArgument`] if [`start`](Self::start) hasn't completed.
    /// - [`crate::Error::HeartbeatTimeout`] if nothing arrived within the heartbeat
    ///   timeout; the client is then flagged as closed.
    /// - Any I/O or decoding error.
    #[instrument(parent = &self.span, level = "debug", skip_all)]
    pub async fn next_record(&mut self) -> crate::Result<Option<RecordRef<'_>>> {
        self.ensure_started("next_record")?;
        let timeout = self.heartbeat_timeout();
        let record = tokio::time::timeout(
            timeout,
            async_decode_record_ref_with_fsm(&mut self.reader, &mut self.fsm),
        )
        .await
        .map_err(|_elapsed| {
            self.is_closed = true;
            Self::heartbeat_timeout_error(timeout)
        })??;
        if record.is_none() {
            self.is_closed = true;
        }
        Ok(record)
    }

    /// Reads available bytes from the socket into the decoder buffer without decoding.
    ///
    /// Pair with [`try_next_record`](Self::try_next_record) to decode. Returns the
    /// number of bytes read; `0` means the connection ended (an `UnexpectedEof` read
    /// error is treated the same way) and flags the client as closed.
    ///
    /// # Errors
    /// - [`crate::Error::BadArgument`] if [`start`](Self::start) hasn't completed.
    /// - [`crate::Error::HeartbeatTimeout`] if nothing arrived within the heartbeat
    ///   timeout; the client is then flagged as closed.
    /// - [`crate::Error::Io`] for any other read error.
    #[instrument(parent = &self.span, level = "debug", skip_all)]
    pub async fn fill_buf(&mut self) -> crate::Result<usize> {
        self.ensure_started("fill_buf")?;
        let timeout = self.heartbeat_timeout();
        let read_result = tokio::time::timeout(timeout, self.reader.read(self.fsm.space()))
            .await
            .map_err(|_elapsed| {
                self.is_closed = true;
                Self::heartbeat_timeout_error(timeout)
            })?;
        match read_result {
            Ok(nbytes) => {
                if nbytes == 0 {
                    self.is_closed = true;
                } else {
                    self.fsm.fill(nbytes);
                }
                Ok(nbytes)
            }
            Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {
                self.is_closed = true;
                Ok(0)
            }
            Err(err) => Err(crate::Error::Io(err)),
        }
    }

    /// Decodes the next record from already-buffered data without performing I/O.
    ///
    /// Returns `Ok(None)` when the buffer doesn't hold a complete record; call
    /// [`fill_buf`](Self::fill_buf) and try again.
    ///
    /// # Errors
    /// Returns [`crate::Error::BadArgument`] if [`start`](Self::start) hasn't
    /// completed, or a decoding error.
    pub fn try_next_record(&mut self) -> crate::Result<Option<RecordRef<'_>>> {
        self.ensure_started("try_next_record")?;
        match self.fsm.process() {
            ProcessResult::Record(_) => Ok(self.fsm.last_record()),
            ProcessResult::ReadMore(_) => Ok(None),
            ProcessResult::Err(err) => Err(err.into()),
            ProcessResult::Metadata(_) => unreachable!("metadata already decoded"),
        }
    }

    /// Returns `true` once the stream has ended or a heartbeat timeout occurred.
    pub fn is_closed(&self) -> bool {
        self.is_closed
    }

    /// Returns how long a read may wait for data: the heartbeat interval plus 5 seconds
    /// of grace, or 35 seconds when no interval was requested.
    ///
    /// NOTE(review): the 35 s default presumably mirrors a 30 s gateway-default heartbeat
    /// plus the same grace; confirm against the gateway documentation.
    fn heartbeat_timeout(&self) -> StdDuration {
        self.heartbeat_interval
            .map(|i| StdDuration::from_secs(i.whole_seconds().unsigned_abs() + 5))
            .unwrap_or_else(|| StdDuration::from_secs(35))
    }

    /// Returns a `BadArgument` error naming `method` if the session hasn't been started.
    fn ensure_started(&self, method: &str) -> crate::Result<()> {
        if self.fsm.has_decoded_metadata() {
            Ok(())
        } else {
            Err(crate::Error::BadArgument {
                param_name: "self",
                desc: format!("Can't call LiveClient::{method} before starting session"),
            })
        }
    }

    /// Builds the error reported when no data arrived within `timeout`.
    fn heartbeat_timeout_error(timeout: StdDuration) -> crate::Error {
        // `timeout` only exceeds `i64::MAX` seconds for a pathological heartbeat interval
        // (e.g. `i64::MIN` seconds); saturate instead of panicking in that case.
        crate::Error::HeartbeatTimeout(Duration::try_from(timeout).unwrap_or(Duration::MAX))
    }

    /// Re-establishes a session with the same gateway address, credentials and session
    /// options, e.g. after the previous connection ended.
    ///
    /// Subscriptions are *not* re-sent. Call [`resubscribe`](Self::resubscribe)
    /// afterwards, then [`start`](Self::start).
    ///
    /// The client's protocol, reader, session ID and counters are only replaced once the
    /// new session has authenticated. A failed reconnect therefore leaves them untouched
    /// and can simply be retried.
    ///
    /// # Errors
    /// Returns an error if connecting or authenticating fails or times out.
    pub async fn reconnect(&mut self) -> crate::Result<()> {
        info!("Reconnecting");
        if let Err(err) = self.close().await {
            warn!(
                ?err,
                "Failed to close connection before reconnect. Proceeding"
            );
        }
        let connect_fut = TcpStream::connect(self.peer_addr);
        let stream = apply_connect_timeout(connect_fut, self.timeout_conf.connect).await?;
        let (recver, sender) = tokio::io::split(stream);
        let mut recver = BufReader::new(recver);
        // Authenticate on a local protocol so `self` isn't left half-updated (new writer,
        // old reader and session ID) if authentication fails.
        let mut protocol = Protocol::new(sender);
        let options = SessionOptions {
            compression: self.compression,
            send_ts_out: self.send_ts_out,
            heartbeat_interval_s: self.heartbeat_interval.map(|i| i.whole_seconds()),
            user_agent_ext: self.user_agent_ext.as_deref(),
            slow_reader_behavior: self.slow_reader_behavior,
        };
        let session_id = apply_auth_timeout(
            protocol.authenticate(&mut recver, &self.key, &self.dataset, options),
            self.timeout_conf.auth,
        )
        .await?;
        self.protocol = protocol;
        self.session_id = session_id;
        // `resubscribe` raises the counter back to the highest recorded ID.
        self.sub_counter = 0;
        self.reader = AsyncDynReader::with_buffer(recver, self.compression);
        // Forget the old metadata so the new session must be started again.
        self.fsm.reset();
        self.is_closed = false;
        self.span = info_span!("LiveClient", dataset = %self.dataset, session_id = self.session_id);
        Ok(())
    }

    /// Re-sends every recorded subscription, e.g. after [`reconnect`](Self::reconnect).
    ///
    /// Each subscription's `start` is cleared first, permanently in the stored list, so
    /// the original start time is not requested again. The ID counter is raised to at
    /// least the highest recorded ID.
    ///
    /// # Errors
    /// Returns the first error from sending a subscription; later subscriptions are
    /// not sent.
    pub async fn resubscribe(&mut self) -> crate::Result<()> {
        for sub in self.subscriptions.iter_mut() {
            sub.start = None;
            self.sub_counter = self.sub_counter.max(sub.id.unwrap_or(0));
            self.protocol.subscribe(sub).await?;
        }
        Ok(())
    }
}
impl fmt::Debug for Client {
    /// Formats the client's configuration and session state as `LiveClient { .. }`.
    ///
    /// The socket halves, decoder and tracing span are omitted, which
    /// `finish_non_exhaustive` signals with a trailing `..`.
    ///
    /// NOTE(review): `key` is printed through `ApiKey`'s own `Debug` impl; confirm that
    /// impl redacts the secret.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("LiveClient")
            .field("key", &self.key)
            .field("dataset", &self.dataset)
            .field("send_ts_out", &self.send_ts_out)
            .field("upgrade_policy", &self.upgrade_policy)
            .field("heartbeat_interval", &self.heartbeat_interval)
            .field("user_agent_ext", &self.user_agent_ext)
            .field("compression", &self.compression)
            // Label matches the struct field (was mislabeled "slow_reader_warning").
            .field("slow_reader_behavior", &self.slow_reader_behavior)
            .field("timeout_conf", &self.timeout_conf)
            .field("peer_addr", &self.peer_addr)
            .field("sub_counter", &self.sub_counter)
            .field("subscriptions", &self.subscriptions)
            .field("session_id", &self.session_id)
            .field("is_closed", &self.is_closed)
            .finish_non_exhaustive()
    }
}
/// Converts a [`time::Duration`] into a [`std::time::Duration`] for use with tokio timers.
///
/// The sign is discarded, so a negative duration maps to its absolute value, as before.
/// Sub-second precision is kept, so e.g. a 500 ms timeout no longer collapses into a
/// zero-length timeout that fires immediately.
fn to_std_duration(d: Duration) -> StdDuration {
    // `whole_seconds` and `subsec_nanoseconds` carry the same sign, so the absolute value
    // of each gives the magnitude; the nanosecond part is always below one second.
    StdDuration::new(
        d.whole_seconds().unsigned_abs(),
        d.subsec_nanoseconds().unsigned_abs(),
    )
}
async fn apply_connect_timeout(
connect_fut: impl Future<Output = std::io::Result<TcpStream>>,
timeout: Option<Duration>,
) -> crate::Result<TcpStream> {
if let Some(t) = timeout {
Ok(tokio::time::timeout(to_std_duration(t), connect_fut)
.await
.map_err(|_| crate::Error::ConnectTimeout(t))??)
} else {
Ok(connect_fut.await?)
}
}
async fn apply_auth_timeout(
auth_fut: impl Future<Output = crate::Result<String>>,
timeout: Option<Duration>,
) -> crate::Result<String> {
if let Some(t) = timeout {
tokio::time::timeout(to_std_duration(t), auth_fut)
.await
.map_err(|_| crate::Error::AuthTimeout(t))?
} else {
auth_fut.await
}
}
#[cfg(test)]
mod tests {
    use std::{ffi::c_char, fmt};

    use async_compression::tokio::write::ZstdEncoder;
    use dbn::{
        encode::AsyncDbnMetadataEncoder,
        enums::rtype,
        publishers::Dataset,
        record::{HasRType, OhlcvMsg, RecordHeader, TradeMsg, WithTsOut},
        FlagSet, Mbp10Msg, MetadataBuilder, Record, SType, Schema,
    };
    use time::{Duration, OffsetDateTime};
    use tokio::{
        io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
        join,
        net::{TcpListener, TcpStream},
        select,
        sync::mpsc::UnboundedSender,
        task::JoinHandle,
    };
    use tracing::level_filters::LevelFilter;

    use super::*;
    use crate::USER_AGENT;

    const TEST_KEY: &str = "32-character-with-lots-of-filler";

    /// In-process stand-in for a live gateway.
    ///
    /// Listens on an ephemeral localhost port, scripts the server side of the text
    /// protocol, and asserts on what the client sends.
    struct MockGateway {
        dataset: String,
        send_ts_out: bool,
        slow_reader_behavior: Option<SlowReaderBehavior>,
        listener: TcpListener,
        /// The accepted client connection, once `accept` has run.
        stream: Option<BufReader<TcpStream>>,
    }

    impl MockGateway {
        async fn new(dataset: String, send_ts_out: bool) -> Self {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            Self {
                dataset,
                send_ts_out,
                slow_reader_behavior: None,
                listener,
                stream: None,
            }
        }

        async fn accept(&mut self) {
            let stream = self.listener.accept().await.unwrap().0;
            stream.set_nodelay(true).unwrap();
            self.stream = Some(BufReader::new(stream));
        }

        /// Accepts a connection, runs the challenge/response handshake and asserts the
        /// auth request carries the expected session options.
        async fn authenticate(&mut self, heartbeat_interval: Option<Duration>) {
            self.accept().await;
            self.send("lsg-test\n").await;
            self.send("cram=t7kNhwj4xqR0QYjzFKtBEG2ec2pXJ4FK\n").await;
            let auth_line = self.read_line().await;
            let auth_start = auth_line.find("auth=").unwrap() + 5;
            let auth_end = auth_line[auth_start..].find('|').unwrap();
            let auth = &auth_line[auth_start..auth_start + auth_end];
            let (auth, bucket) = auth.split_once('-').unwrap();
            assert!(
                auth.chars().all(|c| c.is_ascii_hexdigit()),
                "Expected '{auth}' to be composed of only hex characters"
            );
            // The bucket ID is the last 5 characters of the API key.
            assert_eq!(bucket, "iller");
            assert!(auth_line.contains(&format!("dataset={}", self.dataset)));
            assert!(auth_line.contains("encoding=dbn"));
            assert!(auth_line.contains(&format!("ts_out={}", if self.send_ts_out { 1 } else { 0 })));
            assert!(auth_line.contains(&format!("client={}", *USER_AGENT)));
            if let Some(heartbeat_interval) = heartbeat_interval {
                assert!(auth_line.contains(&format!(
                    "heartbeat_interval_s={}",
                    heartbeat_interval.whole_seconds()
                )));
            } else {
                assert!(!auth_line.contains("heartbeat_interval_s="));
            }
            if let Some(slow_reader_behavior) = self.slow_reader_behavior {
                assert!(auth_line.contains(&format!("slow_reader_behavior={slow_reader_behavior}")));
            } else {
                assert!(!auth_line.contains("slow_reader_behavior="));
            }
            self.send("success=1|session_id=5\n").await;
        }

        /// Reads one subscription request and asserts it matches `subscription`.
        async fn subscribe(&mut self, subscription: Subscription, is_last: bool) {
            let sub_line = self.read_line().await;
            assert!(sub_line.contains(&format!("symbols={}", subscription.symbols.to_api_string())));
            assert!(sub_line.contains(&format!("schema={}", subscription.schema)));
            assert!(sub_line.contains(&format!("stype_in={}", subscription.stype_in)));
            assert!(sub_line.contains("id="));
            if let Some(start) = subscription.start {
                assert!(sub_line.contains(&format!("start={}", start.unix_timestamp_nanos())))
            }
            assert!(sub_line.contains(&format!("snapshot={}", subscription.use_snapshot as u8)));
            assert!(sub_line.contains(&format!("is_last={}", is_last as u8)));
        }

        /// Expects a `start_session` request and answers with uncompressed DBN metadata.
        async fn start(&mut self) {
            let start_line = self.read_line().await;
            assert_eq!(start_line, "start_session\n");
            let dataset = self.dataset.clone();
            let stream = self.stream();
            let mut encoder = AsyncDbnMetadataEncoder::new(stream);
            encoder
                .encode(
                    &MetadataBuilder::new()
                        .dataset(dataset)
                        .start(time::OffsetDateTime::now_utc().unix_timestamp_nanos() as u64)
                        .schema(None)
                        .stype_in(None)
                        .stype_out(SType::InstrumentId)
                        .build(),
                )
                .await
                .unwrap();
        }

        async fn send(&mut self, bytes: &str) {
            self.stream().write_all(bytes.as_bytes()).await.unwrap();
            // `trim_end_matches` doesn't panic on empty input, unlike slicing off a byte.
            info!("Sent: {}", bytes.trim_end_matches('\n'))
        }

        /// Writes `record` in two flushed halves to exercise partial reads on the client.
        async fn send_record(&mut self, record: Box<dyn AsRef<[u8]> + Send>) {
            let bytes = (*record).as_ref();
            let half = bytes.len() / 2;
            self.stream().write_all(&bytes[..half]).await.unwrap();
            self.stream().flush().await.unwrap();
            self.stream().write_all(&bytes[half..]).await.unwrap();
        }

        /// Like `start`, but zstd-compresses the metadata and returns the encoder so the
        /// caller can continue writing compressed records.
        async fn start_compressed(&mut self) -> ZstdEncoder<&mut BufReader<TcpStream>> {
            let start_line = self.read_line().await;
            assert_eq!(start_line, "start_session\n");
            let dataset = self.dataset.clone();
            let stream = self.stream.as_mut().unwrap();
            let mut encoder = ZstdEncoder::new(stream);
            let mut meta_encoder = AsyncDbnMetadataEncoder::new(&mut encoder);
            meta_encoder
                .encode(
                    &MetadataBuilder::new()
                        .dataset(dataset)
                        .start(time::OffsetDateTime::now_utc().unix_timestamp_nanos() as u64)
                        .schema(None)
                        .stype_in(None)
                        .stype_out(SType::InstrumentId)
                        .build(),
                )
                .await
                .unwrap();
            encoder.flush().await.unwrap();
            encoder
        }

        async fn read_line(&mut self) -> String {
            let mut res = String::new();
            self.stream().read_line(&mut res).await.unwrap();
            // On EOF `res` is empty; slicing off the last byte would panic here and hide
            // the caller's more informative assertion failure.
            info!("Read: {}", res.trim_end_matches('\n'));
            res
        }

        fn stream(&mut self) -> &mut BufReader<TcpStream> {
            self.stream.as_mut().unwrap()
        }

        fn addr(&self) -> String {
            format!("127.0.0.1:{}", self.listener.local_addr().unwrap().port())
        }

        /// Shuts down and drops the current client connection, if any.
        async fn close(&mut self) {
            if let Some(stream) = self.stream.as_mut() {
                stream.shutdown().await.unwrap();
            }
            self.stream = None;
        }
    }

    /// Runs a `MockGateway` on a background task, driven by `Event`s sent over a
    /// channel, so tests can queue gateway actions without awaiting them.
    struct Fixture {
        send: UnboundedSender<Event>,
        port: u16,
        task: JoinHandle<()>,
    }

    /// Actions for the background `MockGateway`, processed in order.
    enum Event {
        Exit,
        Accept,
        Authenticate(Option<Duration>),
        Send(String),
        Subscribe(Subscription, bool),
        Start,
        SendRecord(Box<dyn AsRef<[u8]> + Send>),
        Disconnect,
    }

    // Manual impl because the boxed record payload isn't `Debug`.
    impl fmt::Debug for Event {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self {
                Event::Exit => write!(f, "Exit"),
                Event::Accept => write!(f, "Accept"),
                Event::Authenticate(hb_int) => write!(f, "Authenticate({hb_int:?})"),
                Event::Send(msg) => write!(f, "Send({msg:?})"),
                Event::Subscribe(sub, is_last) => write!(f, "Subscribe({sub:?}, {is_last:?})"),
                Event::Start => write!(f, "Start"),
                Event::SendRecord(_) => write!(f, "SendRecord"),
                Event::Disconnect => write!(f, "Disconnect"),
            }
        }
    }

    impl Fixture {
        pub async fn new(dataset: String, send_ts_out: bool) -> Self {
            let (send, mut recv) = tokio::sync::mpsc::unbounded_channel();
            let mut mock = MockGateway::new(dataset, send_ts_out).await;
            let port = mock.listener.local_addr().unwrap().port();
            let task = tokio::task::spawn(async move {
                loop {
                    match recv.recv().await {
                        Some(Event::Authenticate(hb_int)) => mock.authenticate(hb_int).await,
                        Some(Event::Accept) => mock.accept().await,
                        Some(Event::Send(msg)) => mock.send(&msg).await,
                        Some(Event::Subscribe(sub, is_last)) => mock.subscribe(sub, is_last).await,
                        Some(Event::Start) => mock.start().await,
                        Some(Event::SendRecord(rec)) => mock.send_record(rec).await,
                        Some(Event::Disconnect) => mock.close().await,
                        Some(Event::Exit) | None => break,
                    }
                }
            });
            Self { task, port, send }
        }

        pub fn accept(&mut self) {
            self.send.send(Event::Accept).unwrap();
        }

        pub fn authenticate(&mut self, heartbeat_interval: Option<Duration>) {
            self.send
                .send(Event::Authenticate(heartbeat_interval))
                .unwrap();
        }

        pub fn expect_subscribe(&mut self, subscription: Subscription, is_last: bool) {
            self.send
                .send(Event::Subscribe(subscription, is_last))
                .unwrap();
        }

        pub fn start(&mut self) {
            self.send.send(Event::Start).unwrap();
        }

        pub fn send(&mut self, msg: String) {
            self.send.send(Event::Send(msg)).unwrap();
        }

        pub fn send_record<R>(&mut self, record: R)
        where
            R: HasRType + AsRef<[u8]> + Send + 'static,
        {
            // `record` is already owned; box it directly instead of cloning.
            self.send.send(Event::SendRecord(Box::new(record))).unwrap();
        }

        pub fn addr(&self) -> String {
            format!("127.0.0.1:{}", self.port)
        }

        pub fn disconnect(&mut self) {
            self.send.send(Event::Disconnect).unwrap()
        }

        /// Stops the background gateway and propagates any panic from its assertions.
        pub async fn stop(self) {
            self.send.send(Event::Exit).unwrap();
            self.task.await.unwrap()
        }
    }

    /// Starts a fixture and connects an authenticated client to it.
    async fn setup(
        dataset: Dataset,
        send_ts_out: bool,
        heartbeat_interval: Option<Duration>,
    ) -> (Fixture, Client) {
        let _ = tracing_subscriber::FmtSubscriber::builder()
            .with_max_level(LevelFilter::DEBUG)
            .with_test_writer()
            .try_init();
        let mut fixture = Fixture::new(dataset.to_string(), send_ts_out).await;
        fixture.authenticate(heartbeat_interval);
        let builder = Client::builder()
            .addr(fixture.addr())
            .await
            .unwrap()
            .key(TEST_KEY.to_owned())
            .unwrap()
            .dataset(dataset.to_string())
            .send_ts_out(send_ts_out)
            .timeout_conf(TimeoutConf {
                connect: None,
                auth: None,
            });
        let target = if let Some(heartbeat_interval) = heartbeat_interval {
            builder.heartbeat_interval(heartbeat_interval)
        } else {
            builder
        }
        .build()
        .await
        .unwrap();
        (fixture, target)
    }

    #[tokio::test]
    async fn test_subscribe() {
        let (mut fixture, mut client) = setup(Dataset::XnasItch, false, None).await;
        let subscription = Subscription::builder()
            .symbols(vec!["MSFT", "TSLA", "QQQ"])
            .schema(Schema::Ohlcv1M)
            .stype_in(SType::RawSymbol)
            .build();
        fixture.expect_subscribe(subscription.clone(), true);
        client.subscribe(subscription).await.unwrap();
        fixture.stop().await;
    }

    #[tokio::test]
    async fn test_subscribe_snapshot() {
        let (mut fixture, mut client) =
            setup(Dataset::XnasItch, false, Some(Duration::MINUTE)).await;
        let subscription = Subscription::builder()
            .symbols(vec!["MSFT", "TSLA", "QQQ"])
            .schema(Schema::Ohlcv1M)
            .stype_in(SType::RawSymbol)
            .use_snapshot()
            .build();
        fixture.expect_subscribe(subscription.clone(), true);
        client.subscribe(subscription).await.unwrap();
        fixture.stop().await;
    }

    #[tokio::test]
    async fn test_subscribe_snapshot_failed() {
        let (fixture, mut client) =
            setup(Dataset::XnasItch, false, Some(Duration::seconds(5))).await;
        let err = client
            .subscribe(
                Subscription::builder()
                    .symbols(vec!["MSFT", "TSLA", "QQQ"])
                    .schema(Schema::Ohlcv1M)
                    .stype_in(SType::RawSymbol)
                    .start(time::OffsetDateTime::now_utc())
                    .use_snapshot()
                    .build(),
            )
            .await
            .unwrap_err();
        assert!(err
            .to_string()
            .contains("cannot request snapshot with start time"));
        fixture.stop().await;
    }

    #[tokio::test]
    async fn test_subscription_chunking() {
        const SYMBOL: &str = "TEST";
        const SYMBOL_COUNT: usize = 1001;
        let (mut fixture, mut client) = setup(Dataset::XnasItch, false, None).await;
        let sub_base = Subscription::builder()
            .schema(Schema::Ohlcv1M)
            .stype_in(SType::RawSymbol);
        let subscription = sub_base.clone().symbols(vec![SYMBOL; SYMBOL_COUNT]).build();
        client.subscribe(subscription).await.unwrap();
        // The client splits large symbol lists into chunks of at most 500.
        let mut i = 0;
        while i < SYMBOL_COUNT {
            let chunk_size = 500.min(SYMBOL_COUNT - i);
            fixture.expect_subscribe(
                sub_base.clone().symbols(vec![SYMBOL; chunk_size]).build(),
                i + chunk_size == SYMBOL_COUNT,
            );
            i += chunk_size;
        }
        fixture.stop().await;
    }

    #[tokio::test]
    async fn test_next_record() {
        const REC: OhlcvMsg = OhlcvMsg {
            hd: RecordHeader::new::<OhlcvMsg>(rtype::OHLCV_1M, 1, 2, 3),
            open: 1,
            high: 2,
            low: 3,
            close: 4,
            volume: 5,
        };
        let (mut fixture, mut client) =
            setup(Dataset::GlbxMdp3, false, Some(Duration::minutes(5))).await;
        fixture.start();
        let metadata = client.start().await.unwrap();
        assert_eq!(metadata.version, dbn::DBN_VERSION);
        assert!(metadata.schema.is_none());
        assert_eq!(metadata.dataset, Dataset::GlbxMdp3.as_str());
        fixture.send_record(REC);
        let rec = client.next_record().await.unwrap().unwrap();
        assert_eq!(*rec.get::<OhlcvMsg>().unwrap(), REC);
        fixture.stop().await;
    }

    #[tokio::test]
    async fn test_next_record_with_ts_out() {
        let expected = WithTsOut::new(
            TradeMsg {
                hd: RecordHeader::new::<TradeMsg>(rtype::MBP_0, 1, 2, 3),
                price: 1,
                size: 2,
                action: b'A' as c_char,
                side: b'A' as c_char,
                flags: FlagSet::default(),
                depth: 1,
                ts_recv: 0,
                ts_in_delta: 0,
                sequence: 2,
            },
            time::OffsetDateTime::now_utc().unix_timestamp_nanos() as u64,
        );
        let (mut fixture, mut client) = setup(Dataset::GlbxMdp3, true, None).await;
        fixture.start();
        let metadata = client.start().await.unwrap();
        assert_eq!(metadata.version, dbn::DBN_VERSION);
        assert!(metadata.schema.is_none());
        assert_eq!(metadata.dataset, Dataset::GlbxMdp3.as_str());
        fixture.send_record(expected.clone());
        let rec = client.next_record().await.unwrap().unwrap();
        assert_eq!(*rec.get::<WithTsOut<TradeMsg>>().unwrap(), expected);
        fixture.stop().await;
    }

    #[tokio::test]
    async fn test_close() {
        let (mut fixture, mut client) =
            setup(Dataset::GlbxMdp3, true, Some(Duration::seconds(45))).await;
        fixture.start();
        client.start().await.unwrap();
        client.close().await.unwrap();
        fixture.stop().await;
    }

    #[tokio::test]
    async fn test_error_without_success() {
        const DATASET: Dataset = Dataset::OpraPillar;
        let mut fixture = Fixture::new(DATASET.to_string(), false).await;
        let addr = fixture.addr();
        let client_task = tokio::spawn(async move {
            let res = Client::builder()
                .addr(addr)
                .await
                .unwrap()
                .key(TEST_KEY.to_owned())
                .unwrap()
                .dataset(DATASET.to_string())
                .build()
                .await;
            assert!(
                matches!(&res, Err(e) if e.to_string().contains("Unknown failure")),
                "Expected 'Unknown failure' error, got {res:?}"
            );
        });
        let fixture_task = tokio::spawn(async move {
            fixture.accept();
            fixture.send("lsg-test\n".to_owned());
            fixture.send("cram=t7kNhwj4xqR0QYjzFKtBEG2ec2pXJ4FK\n".to_owned());
            fixture.send("Unknown failure\n".to_owned());
        });
        let (r1, r2) = join!(client_task, fixture_task);
        r1.unwrap();
        r2.unwrap();
    }

    /// `next_record` is raced against timers in `select!`; dropping its future
    /// mid-record must not lose or corrupt data.
    #[tokio::test]
    async fn test_cancellation_safety() {
        let (mut fixture, mut client) = setup(Dataset::GlbxMdp3, true, None).await;
        fixture.start();
        let metadata = client.start().await.unwrap();
        assert_eq!(metadata.version, dbn::DBN_VERSION);
        assert!(metadata.schema.is_none());
        assert_eq!(metadata.dataset, Dataset::GlbxMdp3.as_str());
        fixture.send_record(Mbp10Msg::default());
        let mut int_1 = tokio::time::interval(std::time::Duration::from_millis(1));
        let mut int_2 = tokio::time::interval(std::time::Duration::from_millis(1));
        let mut int_3 = tokio::time::interval(std::time::Duration::from_millis(1));
        let mut int_4 = tokio::time::interval(std::time::Duration::from_millis(1));
        let mut int_5 = tokio::time::interval(std::time::Duration::from_millis(1));
        let mut int_6 = tokio::time::interval(std::time::Duration::from_millis(1));
        for _ in 0..5_000 {
            select! {
                _ = int_1.tick() => {
                    fixture.send_record(Mbp10Msg::default());
                }
                _ = int_2.tick() => {
                    fixture.send_record(Mbp10Msg::default());
                }
                _ = int_3.tick() => {
                    fixture.send_record(Mbp10Msg::default());
                }
                _ = int_4.tick() => {
                    fixture.send_record(Mbp10Msg::default());
                }
                _ = int_5.tick() => {
                    fixture.send_record(Mbp10Msg::default());
                }
                _ = int_6.tick() => {
                    fixture.send_record(Mbp10Msg::default());
                }
                res = client.next_record() => {
                    let rec = res.unwrap().unwrap();
                    assert_eq!(*rec.get::<Mbp10Msg>().unwrap(), Mbp10Msg::default());
                }
            }
        }
        fixture.stop().await;
    }

    #[tokio::test]
    async fn test_reconnect() {
        let (mut fixture, mut client) = setup(Dataset::EqusMini, true, None).await;
        let sub = Subscription::builder()
            .symbols(["SPY", "QQQ"])
            .schema(Schema::Trades)
            .start(OffsetDateTime::UNIX_EPOCH)
            .build();
        fixture.expect_subscribe(sub.clone(), true);
        client.subscribe(sub.clone()).await.unwrap();
        fixture.start();
        let metadata = client.start().await.unwrap();
        assert_eq!(metadata.version, dbn::DBN_VERSION);
        assert!(metadata.schema.is_none());
        assert_eq!(metadata.dataset, Dataset::EqusMini.as_str());
        let trade = TradeMsg {
            hd: RecordHeader::default::<TradeMsg>(rtype::MBP_0),
            price: 1,
            size: 2,
            action: 'T' as c_char,
            side: 'B' as c_char,
            flags: FlagSet::default(),
            depth: 0,
            ts_recv: 3,
            ts_in_delta: 4,
            sequence: 5,
        };
        fixture.send_record(trade.clone());
        assert_eq!(
            *client
                .next_record()
                .await
                .unwrap()
                .unwrap()
                .get::<TradeMsg>()
                .unwrap(),
            trade
        );
        fixture.disconnect();
        assert!(client.next_record().await.unwrap().is_none());
        fixture.authenticate(None);
        client.reconnect().await.unwrap();
        // Resubscribing clears the replay start time.
        let mut resub = sub.clone();
        resub.start = None;
        fixture.expect_subscribe(resub, true);
        client.resubscribe().await.unwrap();
        fixture.start();
        client.start().await.unwrap();
        fixture.send_record(trade.clone());
        assert_eq!(
            *client
                .next_record()
                .await
                .unwrap()
                .unwrap()
                .get::<TradeMsg>()
                .unwrap(),
            trade
        );
        fixture.stop().await;
    }

    #[tokio::test]
    async fn test_next_record_with_zstd_compression() {
        const REC: OhlcvMsg = OhlcvMsg {
            hd: RecordHeader::new::<OhlcvMsg>(rtype::OHLCV_1M, 1, 2, 3),
            open: 1,
            high: 2,
            low: 3,
            close: 4,
            volume: 5,
        };
        let _ = tracing_subscriber::FmtSubscriber::builder()
            .with_max_level(LevelFilter::DEBUG)
            .with_test_writer()
            .try_init();
        let mut mock = MockGateway::new(Dataset::GlbxMdp3.to_string(), false).await;
        let addr = mock.addr();
        let mock_task = tokio::spawn(async move {
            mock.authenticate(None).await;
            let mut encoder = mock.start_compressed().await;
            encoder.write_all(REC.as_ref()).await.unwrap();
            encoder.write_all(REC.as_ref()).await.unwrap();
            encoder.flush().await.unwrap();
            encoder.shutdown().await.unwrap();
        });
        let mut client = Client::builder()
            .addr(addr)
            .await
            .unwrap()
            .key(TEST_KEY.to_owned())
            .unwrap()
            .dataset(Dataset::GlbxMdp3.to_string())
            .compression(Compression::Zstd)
            .build()
            .await
            .unwrap();
        let metadata = client.start().await.unwrap();
        assert_eq!(metadata.dataset, Dataset::GlbxMdp3.to_string());
        let rec1 = client.next_record().await.unwrap().unwrap();
        assert_eq!(*rec1.get::<OhlcvMsg>().unwrap(), REC);
        let rec2 = client.next_record().await.unwrap().unwrap();
        assert_eq!(*rec2.get::<OhlcvMsg>().unwrap(), REC);
        mock_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_try_next_record_and_fill_buf() {
        const REC1: OhlcvMsg = OhlcvMsg {
            hd: RecordHeader::new::<OhlcvMsg>(rtype::OHLCV_1M, 1, 2, 3),
            open: 1,
            high: 2,
            low: 3,
            close: 4,
            volume: 5,
        };
        const REC2: OhlcvMsg = OhlcvMsg {
            hd: RecordHeader::new::<OhlcvMsg>(rtype::OHLCV_1M, 4, 5, 6),
            open: 10,
            high: 20,
            low: 30,
            close: 40,
            volume: 50,
        };
        let (mut fixture, mut client) = setup(Dataset::GlbxMdp3, false, None).await;
        fixture.start();
        client.start().await.unwrap();
        assert!(!client.is_closed());
        assert!(client.try_next_record().unwrap().is_none());
        fixture.send_record(REC1);
        fixture.send_record(REC2);
        let rec1 = loop {
            if let Some(rec) = client.try_next_record().unwrap() {
                break rec;
            }
            let nbytes = client.fill_buf().await.unwrap();
            assert!(nbytes > 0);
            assert!(!client.is_closed());
        };
        assert_eq!(*rec1.get::<OhlcvMsg>().unwrap(), REC1);
        let rec2 = loop {
            if let Some(rec) = client.try_next_record().unwrap() {
                break rec;
            }
            let nbytes = client.fill_buf().await.unwrap();
            assert!(nbytes > 0);
        };
        assert_eq!(*rec2.get::<OhlcvMsg>().unwrap(), REC2);
        assert!(client.try_next_record().unwrap().is_none());
        fixture.disconnect();
        let nbytes = client.fill_buf().await.unwrap();
        assert_eq!(nbytes, 0);
        assert!(client.is_closed());
        fixture.stop().await;
    }

    #[tokio::test]
    async fn test_slow_reader_behavior_warn() {
        let _ = tracing_subscriber::FmtSubscriber::builder()
            .with_max_level(LevelFilter::DEBUG)
            .with_test_writer()
            .try_init();
        let mut mock = MockGateway::new(Dataset::XnasItch.to_string(), false).await;
        mock.slow_reader_behavior = Some(SlowReaderBehavior::Warn);
        let addr = mock.addr();
        let mock_task = tokio::spawn(async move {
            mock.authenticate(None).await;
        });
        let _client = Client::builder()
            .addr(addr)
            .await
            .unwrap()
            .key(TEST_KEY.to_owned())
            .unwrap()
            .dataset(Dataset::XnasItch.to_string())
            .slow_reader_behavior(SlowReaderBehavior::Warn)
            .build()
            .await
            .unwrap();
        mock_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_slow_reader_behavior_skip() {
        let _ = tracing_subscriber::FmtSubscriber::builder()
            .with_max_level(LevelFilter::DEBUG)
            .with_test_writer()
            .try_init();
        let mut mock = MockGateway::new(Dataset::XnasItch.to_string(), false).await;
        mock.slow_reader_behavior = Some(SlowReaderBehavior::Skip);
        let addr = mock.addr();
        let mock_task = tokio::spawn(async move {
            mock.authenticate(None).await;
        });
        let _client = Client::builder()
            .addr(addr)
            .await
            .unwrap()
            .key(TEST_KEY.to_owned())
            .unwrap()
            .dataset(Dataset::XnasItch.to_string())
            .slow_reader_behavior(SlowReaderBehavior::Skip)
            .build()
            .await
            .unwrap();
        mock_task.await.unwrap();
    }

    #[tokio::test(start_paused = true)]
    async fn test_heartbeat_timeout_next_record() {
        let (mut fixture, mut client) =
            setup(Dataset::GlbxMdp3, false, Some(Duration::seconds(1))).await;
        fixture.start();
        client.start().await.unwrap();
        let err = client.next_record().await.unwrap_err();
        assert!(
            matches!(&err, crate::Error::HeartbeatTimeout(_)),
            "Expected HeartbeatTimeout, got: {err}"
        );
        assert!(client.is_closed());
        fixture.stop().await;
    }

    #[tokio::test(start_paused = true)]
    async fn test_heartbeat_timeout_fill_buf() {
        let (mut fixture, mut client) =
            setup(Dataset::GlbxMdp3, false, Some(Duration::seconds(1))).await;
        fixture.start();
        client.start().await.unwrap();
        let err = client.fill_buf().await.unwrap_err();
        assert!(
            matches!(&err, crate::Error::HeartbeatTimeout(_)),
            "Expected HeartbeatTimeout, got: {err}"
        );
        assert!(client.is_closed());
        fixture.stop().await;
    }

    #[tokio::test(start_paused = true)]
    async fn test_reconnect_resets_heartbeat_timer() {
        let (mut fixture, mut client) =
            setup(Dataset::GlbxMdp3, false, Some(Duration::seconds(1))).await;
        fixture.start();
        client.start().await.unwrap();
        // Real time is needed while waiting on socket I/O so the paused clock doesn't
        // auto-advance past the heartbeat timeout first.
        tokio::time::resume();
        fixture.disconnect();
        let rec = client.next_record().await.unwrap();
        assert!(rec.is_none());
        assert!(client.is_closed());
        tokio::time::pause();
        fixture.authenticate(Some(Duration::seconds(1)));
        client.reconnect().await.unwrap();
        assert!(!client.is_closed());
        fixture.start();
        client.start().await.unwrap();
        tokio::time::resume();
        fixture.disconnect();
        let rec = client.next_record().await.unwrap();
        assert!(rec.is_none());
        fixture.stop().await;
    }

    #[tokio::test(start_paused = true)]
    async fn test_connect_timeout() {
        // 192.0.2.0/24 (TEST-NET-1) is reserved for documentation and isn't routable.
        let result = Client::builder()
            .addr("192.0.2.1:13000")
            .await
            .unwrap()
            .key(TEST_KEY.to_owned())
            .unwrap()
            .dataset("GLBX.MDP3")
            .timeout_conf(TimeoutConf {
                connect: Some(Duration::seconds(1)),
                auth: None,
            })
            .build()
            .await;
        assert!(
            matches!(result, Err(crate::Error::ConnectTimeout(_))),
            "Expected ConnectTimeout, got {result:?}"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_auth_timeout() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        // Accept the connection but never speak, so authentication stalls.
        let server = tokio::spawn(async move {
            let (_stream, _addr) = listener.accept().await.unwrap();
            tokio::time::sleep(StdDuration::from_secs(60)).await;
        });
        let result = Client::builder()
            .addr(addr)
            .await
            .unwrap()
            .key(TEST_KEY.to_owned())
            .unwrap()
            .dataset("GLBX.MDP3")
            .timeout_conf(TimeoutConf {
                connect: None,
                auth: Some(Duration::seconds(1)),
            })
            .build()
            .await;
        assert!(
            matches!(result, Err(crate::Error::AuthTimeout(_))),
            "Expected AuthTimeout, got {result:?}"
        );
        server.abort();
    }

    #[test]
    fn test_default_timeout_conf() {
        let builder = Client::builder();
        let timeout_conf = builder.timeout_conf;
        assert_eq!(timeout_conf.connect, Some(Duration::seconds(10)));
        assert_eq!(timeout_conf.auth, Some(Duration::seconds(30)));
    }
}