use super::*;
use crate::{RunFor, RunMode, nodes::*, types::*};
use anyhow::{Context, Result};
use derive_new::new;
use kdb_plus_fixed::ipc::{ConnectionMethod, K, QStream};
use log::{Level, LevelFilter};
use tokio::runtime::Runtime;
/// Name of the table exercised by the read-path tests.
pub(super) const TABLE_NAME: &str = "test_trades";
/// Name of the table exercised by the write-path tests.
const WRITE_TABLE_NAME: &str = "test_trades_write";
/// Row type for the read-side test table (`test_trades`).
/// The table's date/time columns are consumed during deserialization and
/// are not stored on this struct.
#[derive(Debug, Clone, Default)]
pub struct TestTrade {
    // Interned ticker symbol.
    sym: Sym,
    // Trade price.
    price: f64,
    // Trade quantity.
    qty: i64,
}
impl KdbSerialize for TestTrade {
    /// Serialize one trade as a compound list `(sym; price; qty)`.
    ///
    /// NOTE(review): no time column is emitted here even though the write
    /// table's first column is a timestamp — presumably `kdb_write`
    /// supplies the event time itself; confirm against its implementation.
    fn to_kdb_row(&self) -> K {
        K::new_compound_list(vec![
            K::new_symbol(self.sym.to_string()),
            K::new_float(self.price),
            K::new_long(self.qty),
        ])
    }
}
impl KdbDeserialize for TestTrade {
    /// Decode one row of the read table.
    ///
    /// Column layout matches `TestDataBuilder::create_table`:
    /// 0 = date (skipped), 1 = time, 2 = sym, 3 = price, 4 = qty.
    fn from_kdb_row(
        row: Row<'_>,
        _columns: &[String],
        interner: &mut SymbolInterner,
    ) -> Result<(NanoTime, Self), KdbError> {
        // The event time comes from the `time` column; the `date` column
        // is only used server-side for slicing.
        let time = row.get_timestamp(1)?;
        Ok((
            time,
            TestTrade {
                sym: row.get_sym(2, interner)?,
                price: row.get(3)?.get_float()?,
                qty: row.get(4)?.get_long()?,
            },
        ))
    }
}
/// Row type for reading back the write-side table (`test_trades_write`),
/// whose schema has no `date` column (time, sym, price, qty).
#[derive(Debug, Clone, Default)]
pub struct TestTradeWrite {
    // Interned ticker symbol.
    sym: Sym,
    // Trade price.
    price: f64,
    // Trade quantity.
    qty: i64,
}
impl KdbDeserialize for TestTradeWrite {
    /// Decode one row of the write table.
    ///
    /// Column layout matches `TestDataBuilder::create_write_table`:
    /// 0 = time, 1 = sym, 2 = price, 3 = qty (no date column).
    fn from_kdb_row(
        row: Row<'_>,
        _columns: &[String],
        interner: &mut SymbolInterner,
    ) -> Result<(NanoTime, Self), KdbError> {
        let time = row.get_timestamp(0)?;
        Ok((
            time,
            TestTradeWrite {
                sym: row.get_sym(1, interner)?,
                price: row.get(2)?.get_float()?,
                qty: row.get(3)?.get_long()?,
            },
        ))
    }
}
/// Creates, seeds, and drops the KDB+ test tables.
/// `#[derive(new)]` generates `TestDataBuilder::new(connection, tokio)`.
#[derive(new)]
struct TestDataBuilder {
    // Connection parameters for the KDB+ instance under test.
    connection: KdbConnection,
    // Runtime used to drive async setup/teardown from synchronous tests.
    tokio: Runtime,
}
impl TestDataBuilder {
    /// Build the test connection from `KDB_TEST_HOST` / `KDB_TEST_PORT`,
    /// falling back to localhost:5000 when either is unset or unparsable.
    fn connection() -> KdbConnection {
        let port = std::env::var("KDB_TEST_PORT")
            .ok()
            .and_then(|p| p.parse().ok())
            .unwrap_or(5000);
        let host = std::env::var("KDB_TEST_HOST").unwrap_or_else(|_| "localhost".to_string());
        KdbConnection::new(host, port)
    }
    /// Open a fresh TCP connection to the KDB+ test instance.
    async fn socket(&self) -> Result<QStream> {
        let creds = self.connection.credentials_string();
        QStream::connect(
            ConnectionMethod::TCP,
            &self.connection.host,
            self.connection.port,
            &creds,
        )
        .await
        .context("Failed to connect to KDB+")
    }
    /// Send a q expression synchronously and discard its value.
    ///
    /// Type -128 is the KDB+ error atom; the IPC layer hands it back as a
    /// normal message, so it is explicitly mapped to an `Err` here.
    async fn execute(&self, query: &str) -> Result<()> {
        let result = self
            .socket()
            .await?
            .send_sync_message(&query)
            .await
            .context("Failed to send query to KDB+")?;
        if result.get_type() == -128 {
            anyhow::bail!("KDB+ query error: {:?}", result);
        }
        Ok(())
    }
    /// Create the empty read-side table: date, time, sym, price, qty.
    async fn create_table(&self) -> Result<()> {
        self.execute(&format!(
            "{}:([]date:`date$();time:`timestamp$();sym:`symbol$();price:`float$();qty:`long$())",
            TABLE_NAME
        ))
        .await?;
        Ok(())
    }
    /// Populate the read table with `records_per_day * num_days` rows.
    ///
    /// When `sorted`, dates start at 2000.01.01 and each day's timestamps
    /// are spaced evenly across the day in ascending order. When unsorted,
    /// every row shares one date and gets a random (non-monotonic) time.
    async fn write_rows(
        &self,
        records_per_day: usize,
        num_days: usize,
        sorted: bool,
    ) -> Result<()> {
        let n = records_per_day * num_days;
        let (date_expr, time_expr) = if sorted {
            (
                // Each successive date repeated `records_per_day` times.
                format!(
                    "raze {{{}#2000.01.01+x}} each til {}",
                    records_per_day, num_days
                ),
                // Even spacing: 86400000000000 ns per day divided by
                // `records_per_day`, starting at each day's midnight.
                format!(
                    "raze {{(`timestamp$2000.01.01+x)+(86400000000000j div {}j)*til {}}} each til {}",
                    records_per_day, records_per_day, num_days
                ),
            )
        } else {
            (
                format!("{}#2000.01.01", n),
                // Random negated second offsets produce unsorted timestamps.
                format!("2000.01.01D00:00:00.000000000+1000000000j*neg {}?{}", n, n),
            )
        };
        let query = format!(
            "insert[`{table};({dates};{times};{n}?`AAPL`GOOG`MSFT;{n}?100.0;{n}?1000j)]",
            table = TABLE_NAME,
            dates = date_expr,
            times = time_expr,
            n = n,
        );
        self.execute(&query).await?;
        Ok(())
    }
    /// Remove the read table from the root namespace.
    async fn drop_table(&self) -> Result<()> {
        self.execute(&format!("delete {} from `.", TABLE_NAME))
            .await?;
        Ok(())
    }
    /// Create the empty write-side table: time, sym, price, qty (no date).
    async fn create_write_table(&self) -> Result<()> {
        self.execute(&format!(
            "{}:([]time:`timestamp$();sym:`symbol$();price:`float$();qty:`long$())",
            WRITE_TABLE_NAME
        ))
        .await?;
        Ok(())
    }
    /// Seed the write table with `n` rows at one-second intervals from
    /// the KDB epoch, with random symbols/prices/quantities.
    async fn write_rows_to_write_table(&self, n: usize) -> Result<()> {
        let query = format!(
            "insert[`{table};(2000.01.01D00:00:00.000000000+1000000000j*til {n};{n}?`AAPL`GOOG`MSFT;{n}?100.0;{n}?1000j)]",
            table = WRITE_TABLE_NAME,
            n = n,
        );
        self.execute(&query).await?;
        Ok(())
    }
    /// Remove the write table from the root namespace.
    async fn drop_write_table(&self) -> Result<()> {
        self.execute(&format!("delete {} from `.", WRITE_TABLE_NAME))
            .await?;
        Ok(())
    }
    /// Synchronously create and populate the read table.
    fn setup(&self, records_per_day: usize, num_days: usize, sorted: bool) -> Result<()> {
        self.tokio.block_on(async {
            self.create_table().await?;
            self.write_rows(records_per_day, num_days, sorted).await?;
            Ok(())
        })
    }
    /// Synchronously drop the read table.
    fn teardown(&self) -> Result<()> {
        self.tokio.block_on(async { self.drop_table().await })
    }
}
/// Create and seed the read table, run `test` with the total row count and
/// a connection, then drop the table.
///
/// Teardown runs whether or not `test` fails; the test's error takes
/// precedence over a teardown error. If setup itself fails partway (table
/// created but population failed), a best-effort teardown is attempted so
/// the table does not leak into subsequent test runs.
pub(super) fn with_test_data<F>(
    records_per_day: usize,
    num_days: usize,
    sorted: bool,
    test: F,
) -> anyhow::Result<()>
where
    F: FnOnce(usize, KdbConnection) -> anyhow::Result<()>,
{
    let conn = TestDataBuilder::connection();
    let rt = tokio::runtime::Runtime::new()?;
    let builder = TestDataBuilder::new(conn.clone(), rt);
    if let Err(setup_err) = builder.setup(records_per_day, num_days, sorted) {
        // Setup may have created the table before failing to populate it;
        // drop it if possible, but report the original setup error.
        let _ = builder.teardown();
        return Err(setup_err);
    }
    let test_result = test(records_per_day * num_days, conn);
    let teardown_result = builder.teardown();
    test_result?;
    teardown_result?;
    Ok(())
}
/// Run `test` against a freshly created (and empty) read table, then drop
/// the table. A test failure is reported ahead of any teardown failure.
fn with_empty_table<F>(test: F) -> Result<()>
where
    F: FnOnce(KdbConnection) -> Result<()>,
{
    let connection = TestDataBuilder::connection();
    let runtime = tokio::runtime::Runtime::new()?;
    let builder = TestDataBuilder::new(connection.clone(), runtime);
    builder.tokio.block_on(builder.create_table())?;
    // Run the test, then always attempt teardown before propagating errors.
    let outcome = test(connection);
    let cleanup = builder.teardown();
    outcome?;
    cleanup?;
    Ok(())
}
/// Run `test` against a freshly created (and empty) write table, then drop
/// it. A test failure is reported ahead of any teardown failure.
fn with_empty_write_table<F>(test: F) -> Result<()>
where
    F: FnOnce(KdbConnection) -> Result<()>,
{
    let connection = TestDataBuilder::connection();
    let runtime = tokio::runtime::Runtime::new()?;
    let builder = TestDataBuilder::new(connection.clone(), runtime);
    builder.tokio.block_on(builder.create_write_table())?;
    // Run the test, then always attempt teardown before propagating errors.
    let outcome = test(connection);
    let cleanup = builder.tokio.block_on(builder.drop_write_table());
    outcome?;
    cleanup?;
    Ok(())
}
/// Build a q query for one time slice of the read table: rows on the day
/// `2000.01.01 + date` whose time lies in the half-open window `[t0, t1)`.
pub(super) fn slice_query(date: i32, t0: NanoTime, t1: NanoTime) -> String {
    let start_ns = t0.to_kdb_timestamp();
    let end_ns = t1.to_kdb_timestamp();
    format!(
        "select from {table} where date=2000.01.01+{offset}, time >= (`timestamp$){start}j, time < (`timestamp$){end}j",
        table = TABLE_NAME,
        offset = date,
        start = start_ns,
        end = end_ns,
    )
}
/// Reading sorted data with one 24h slice per day must deliver every row.
#[test]
fn test_kdb_sorted_data() -> Result<()> {
    let _ = env_logger::try_init();
    with_test_data(3, 2, true, |_n, conn| {
        let one_day = std::time::Duration::from_secs(24 * 3600);
        let reader = kdb_read::<TestTrade, _>(conn, one_day, |within, date, _| {
            slice_query(date, within.0, within.1)
        });
        let sink = reader.collapse().collect();
        sink.clone().run(
            RunMode::HistoricalFrom(NanoTime::from_kdb_timestamp(0)),
            RunFor::Duration(std::time::Duration::from_secs(2 * 86400)),
        )?;
        assert_eq!(
            sink.peek_value().len(),
            6,
            "Should read all 6 rows (3 per day × 2 days)"
        );
        Ok(())
    })
}
/// Deliberately mis-typed record: declares the symbol column as a long so
/// that deserializing real data forces a type error.
#[derive(Debug, Clone, Default)]
#[allow(dead_code)]
struct BadTrade {
    // Wrong on purpose — the actual column holds symbols, not longs.
    sym: i64,
}
impl KdbDeserialize for BadTrade {
    /// Reads the read table's `sym` column (index 2) as a long; against the
    /// symbol-typed column this is expected to fail at deserialization.
    fn from_kdb_row(
        row: Row<'_>,
        _columns: &[String],
        _interner: &mut SymbolInterner,
    ) -> Result<(NanoTime, Self), KdbError> {
        let time = row.get_timestamp(1)?;
        Ok((
            time,
            BadTrade {
                sym: row.get(2)?.get_long()?,
            },
        ))
    }
}
/// A query against a nonexistent table must surface as a run error rather
/// than an empty result. No test data is seeded — the query fails first.
#[test]
fn test_kdb_bad_query() -> Result<()> {
    let _ = env_logger::try_init();
    let conn = TestDataBuilder::connection();
    let stream = kdb_read::<TestTrade, _>(
        conn,
        std::time::Duration::from_secs(24 * 3600),
        |_, _, _| "select from nonexistent_table_xyz".to_string(),
    );
    let collected = stream.collapse().collect();
    let result = collected.run(
        RunMode::HistoricalFrom(NanoTime::from_kdb_timestamp(0)),
        RunFor::Duration(std::time::Duration::from_secs(86400)),
    );
    assert!(result.is_err(), "Bad query should return an error");
    Ok(())
}
/// `BadTrade` reads the symbol column as a long; the run must fail with a
/// deserialization error instead of silently dropping rows.
#[test]
fn test_kdb_deserialization_error() -> Result<()> {
    let _ = env_logger::try_init();
    let outcome = with_test_data(3, 1, true, |_n, conn| {
        let reader = kdb_read::<BadTrade, _>(
            conn,
            std::time::Duration::from_secs(24 * 3600),
            |within, date, _| slice_query(date, within.0, within.1),
        );
        reader.collapse().collect().run(
            RunMode::HistoricalFrom(NanoTime::from_kdb_timestamp(0)),
            RunFor::Duration(std::time::Duration::from_secs(86400)),
        )?;
        Ok(())
    });
    assert!(
        outcome.is_err(),
        "Type mismatch should return a deserialization error"
    );
    Ok(())
}
/// Rough read-throughput measurement: reads 1M rows (100k/day × 10 days)
/// once per slice period and prints wall time per period. Logging is
/// disabled so log output does not skew the timing.
#[test]
fn test_read_read_perf() -> Result<()> {
    log::set_max_level(LevelFilter::Off);
    let records_per_day = 100_000;
    let num_days = 10;
    with_test_data(records_per_day, num_days, true, |n, conn| {
        // Slice sizes under test: 1h, 6h, and a full day.
        let periods = [
            std::time::Duration::from_secs(3600),
            std::time::Duration::from_secs(6 * 3600),
            std::time::Duration::from_secs(24 * 3600),
        ];
        println!("\n{:<15} {:>12}", "Period (secs)", "Time");
        println!("{}", "-".repeat(30));
        for &period in &periods {
            let start = std::time::Instant::now();
            let stream = kdb_read::<TestTrade, _>(conn.clone(), period, |within, date, _| {
                slice_query(date, within.0, within.1)
            });
            let counter = stream.collapse().count();
            counter.clone().run(
                RunMode::HistoricalFrom(NanoTime::from_kdb_timestamp(0)),
                RunFor::Duration(std::time::Duration::from_secs(num_days as u64 * 86400)),
            )?;
            // Sanity check: every row must be counted regardless of period.
            assert_eq!(counter.peek_value() as usize, n);
            println!("{:<15} {:?}", period.as_secs(), start.elapsed());
        }
        Ok(())
    })
}
/// Connecting to a port with no listener must produce a run error, not a
/// hang or a silently empty result.
#[test]
fn test_kdb_connection_refused() -> Result<()> {
    let _ = env_logger::try_init();
    // Port 59999 is assumed to have no KDB+ instance listening.
    let unreachable = KdbConnection::new("localhost", 59999);
    let reader = kdb_read::<TestTrade, _>(
        unreachable,
        std::time::Duration::from_secs(24 * 3600),
        |_, _, _| format!("select from {}", TABLE_NAME),
    );
    let run_result = reader.collapse().collect().run(
        RunMode::HistoricalFrom(NanoTime::from_kdb_timestamp(0)),
        RunFor::Duration(std::time::Duration::from_secs(86400)),
    );
    assert!(run_result.is_err(), "Connection refused should return an error");
    Ok(())
}
/// Reading a table that exists but has no rows must succeed with an empty
/// collection rather than error.
#[test]
fn test_kdb_empty_table_returns_zero_rows() -> Result<()> {
    let _ = env_logger::try_init();
    with_empty_table(|conn| {
        let stream = kdb_read::<TestTrade, _>(
            conn,
            std::time::Duration::from_secs(24 * 3600),
            |within, date, _| slice_query(date, within.0, within.1),
        );
        let collected = stream.collapse().collect();
        collected.clone().run(
            RunMode::HistoricalFrom(NanoTime::from_kdb_timestamp(0)),
            RunFor::Duration(std::time::Duration::from_secs(86400)),
        )?;
        assert_eq!(
            collected.peek_value().len(),
            0,
            "Empty table should return 0 rows"
        );
        Ok(())
    })
}
/// With a 12h slice period over 2 days of data, the reader issues 4 slice
/// queries; all 6 rows must still arrive exactly once.
#[test]
fn test_kdb_read_works() -> Result<()> {
    let _ = env_logger::try_init();
    with_test_data(3, 2, true, |_n, conn| {
        let start = NanoTime::from_kdb_timestamp(0);
        let stream = kdb_read::<TestTrade, _>(
            conn,
            std::time::Duration::from_secs(12 * 3600),
            move |(slice_start, slice_end), date, _iteration| {
                slice_query(date, slice_start, slice_end)
            },
        );
        let collected = stream.collapse().collect();
        collected.clone().run(
            RunMode::HistoricalFrom(start),
            RunFor::Duration(std::time::Duration::from_secs(2 * 86400)),
        )?;
        let rows = collected.peek_value();
        assert_eq!(
            rows.len(),
            6,
            "Should read all 6 rows (3 per day × 2 days) across 4 time slices, got {}",
            rows.len()
        );
        Ok(())
    })
}
/// Stream `trades` into the write table (timestamps one second apart from
/// the KDB epoch), then count the table's rows over a separate direct
/// connection and return that count.
fn write_and_verify(conn: KdbConnection, trades: Vec<TestTrade>) -> Result<usize> {
    let n = trades.len();
    let write_conn = conn.clone();
    // Source node: emits each trade with a synthetic monotonic timestamp.
    let stream = produce_async(move |_ctx| {
        let trades = trades;
        async move {
            Ok(async_stream::stream! {
                for (i, trade) in trades.into_iter().enumerate() {
                    let time = NanoTime::from_kdb_timestamp(i as i64 * 1_000_000_000);
                    yield Ok((time, trade));
                }
            })
        }
    });
    let writer = kdb_write(write_conn, WRITE_TABLE_NAME, &stream);
    // NOTE(review): assumes the run terminates once the finite source is
    // exhausted, despite RunFor::Forever — confirm against kdb_write.
    writer.run(RunMode::HistoricalFrom(NanoTime::ZERO), RunFor::Forever)?;
    // Verify via a fresh connection rather than the writer's bookkeeping.
    let rt = tokio::runtime::Runtime::new()?;
    let verify_conn = conn;
    let count = rt.block_on(async {
        let creds = verify_conn.credentials_string();
        let mut socket = QStream::connect(
            ConnectionMethod::TCP,
            &verify_conn.host,
            verify_conn.port,
            &creds,
        )
        .await?;
        let query = format!("count {}", WRITE_TABLE_NAME);
        let result = socket.send_sync_message(&query.as_str()).await?;
        let count = result.get_long()?;
        Ok::<i64, anyhow::Error>(count)
    })?;
    println!("Wrote {} trades, verified {} in KDB", n, count);
    Ok(count as usize)
}
/// Build `n` deterministic trades: symbols cycle AAPL/GOOG/MSFT,
/// price is `100 + i`, quantity is `10*i + 1`.
fn make_test_trades(n: usize) -> Vec<TestTrade> {
    const SYMS: [&str; 3] = ["AAPL", "GOOG", "MSFT"];
    let mut interner = SymbolInterner::default();
    let mut trades = Vec::with_capacity(n);
    for i in 0..n {
        trades.push(TestTrade {
            sym: interner.intern(SYMS[i % SYMS.len()]),
            price: 100.0 + i as f64,
            qty: (i * 10 + 1) as i64,
        });
    }
    trades
}
/// End-to-end round trip: write 5 trades, confirm the count in KDB+, then
/// read them back and spot-check the first row's fields.
#[test]
fn test_kdb_write_round_trip() -> Result<()> {
    let _ = env_logger::try_init();
    let trades = make_test_trades(5);
    with_empty_write_table(|conn| {
        let count = write_and_verify(conn.clone(), trades)?;
        assert_eq!(count, 5, "Should have written 5 trades");
        // Read back via the write-table schema (time, sym, price, qty).
        let read_stream = kdb_read::<TestTradeWrite, _>(
            conn,
            std::time::Duration::from_secs(24 * 3600),
            move |(t0, t1), _, _| {
                format!(
                    "select from {} where time >= (`timestamp$){}j, time < (`timestamp$){}j",
                    WRITE_TABLE_NAME,
                    t0.to_kdb_timestamp(),
                    t1.to_kdb_timestamp(),
                )
            },
        );
        let collected = read_stream
            .collapse()
            .logged("readback", Level::Info)
            .collect();
        collected.clone().run(
            RunMode::HistoricalFrom(NanoTime::from_kdb_timestamp(0)),
            RunFor::Duration(std::time::Duration::from_secs(86400)),
        )?;
        let rows = collected.peek_value();
        assert_eq!(rows.len(), 5, "Should read back 5 rows");
        // First trade from make_test_trades: AAPL, price 100 + 0, qty 10*0 + 1.
        let first = &rows[0].value;
        assert_eq!(first.sym.to_string(), "AAPL");
        assert!((first.price - 100.0).abs() < 0.001);
        assert_eq!(first.qty, 1);
        Ok(())
    })
}
/// Appending must not clobber existing rows: seed 3 rows directly via the
/// builder, stream 2 more through the writer, and expect 5 total.
#[test]
fn test_kdb_write_append() -> Result<()> {
    let _ = env_logger::try_init();
    let conn = TestDataBuilder::connection();
    let rt = tokio::runtime::Runtime::new()?;
    let builder = TestDataBuilder::new(conn.clone(), rt);
    builder.tokio.block_on(async {
        builder.create_write_table().await?;
        builder.write_rows_to_write_table(3).await
    })?;
    // Run the body inside a closure so the teardown below still executes
    // when an error path returns early with `?`.
    let test_result: anyhow::Result<()> = (|| {
        let new_trades = make_test_trades(2);
        let write_conn = conn.clone();
        let stream = produce_async(move |_ctx| {
            let trades = new_trades;
            async move {
                Ok(async_stream::stream! {
                    for (i, trade) in trades.into_iter().enumerate() {
                        // Offset by 10s so the appended rows come after the
                        // seed rows (seeded at 0s, 1s, 2s).
                        let time = NanoTime::from_kdb_timestamp((10 + i as i64) * 1_000_000_000);
                        yield Ok((time, trade));
                    }
                })
            }
        });
        let writer = kdb_write(write_conn, WRITE_TABLE_NAME, &stream);
        writer.run(RunMode::HistoricalFrom(NanoTime::ZERO), RunFor::Forever)?;
        // Count rows over a fresh connection/runtime.
        let rt = tokio::runtime::Runtime::new()?;
        let count = rt.block_on(async {
            let creds = conn.credentials_string();
            let mut socket =
                QStream::connect(ConnectionMethod::TCP, &conn.host, conn.port, &creds).await?;
            let query = format!("count {}", WRITE_TABLE_NAME);
            let result = socket.send_sync_message(&query.as_str()).await?;
            result.get_long().map_err(anyhow::Error::new)
        })?;
        assert_eq!(count, 5, "Should have 3 original + 2 appended = 5 rows");
        println!("Append test: 3 + 2 = {} rows", count);
        Ok(())
    })();
    let teardown_result = builder.tokio.block_on(builder.drop_write_table());
    test_result?;
    teardown_result?;
    Ok(())
}