use super::{KdbConnection, Sym, SymbolInterner};
use crate::nodes::produce_async;
use crate::types::*;
use anyhow::{Result, bail};
use kdb_plus_fixed::ipc::error::Error as KdbError;
use kdb_plus_fixed::ipc::{ConnectionMethod, K, QStream};
use kdb_plus_fixed::qtype;
use log::info;
use std::rc::Rc;
/// Convenience accessors for reading tables and lists out of a kdb+ `K` object.
///
/// This file provides the implementation for `K`.
pub trait KdbExt {
/// Returns the column names of a table as owned strings.
///
/// The `K` implementation in this file fails when the object is not a table.
fn column_names(&self) -> Result<Vec<String>>;
/// Returns a row-oriented view ([`Rows`]) over the columns of a table.
///
/// The `K` implementation in this file fails when the object is not a table.
fn rows(&self) -> Result<Rows>;
/// Returns the element at position `index` of a list as a new `K` value.
fn element_at(&self, index: usize) -> Result<K, KdbError>;
}
/// Row-oriented view over the column lists of a kdb+ table.
///
/// Stores one `K` list per column together with the number of rows.
pub struct Rows {
    columns: Vec<K>,
    n_rows: usize,
}

impl Rows {
    /// Number of rows in the view.
    pub fn len(&self) -> usize {
        self.n_rows
    }

    /// `true` when the view contains no rows.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns a handle to row `index`, or `None` when `index` is not below the row count.
    pub fn get(&self, index: usize) -> Option<Row<'_>> {
        (index < self.n_rows).then(|| Row {
            columns: self.columns.as_slice(),
            index,
        })
    }

    /// Iterates over every row, starting at row 0.
    pub fn iter(&self) -> RowIter<'_> {
        let columns = self.columns.as_slice();
        RowIter {
            columns,
            n_rows: self.n_rows,
            current: 0,
        }
    }
}

/// Allows `for row in &rows { ... }`.
impl<'a> IntoIterator for &'a Rows {
    type Item = Row<'a>;
    type IntoIter = RowIter<'a>;

    fn into_iter(self) -> RowIter<'a> {
        Rows::iter(self)
    }
}
/// Iterator that yields one [`Row`] per row index, in ascending order.
pub struct RowIter<'a> {
    columns: &'a [K],
    n_rows: usize,
    current: usize,
}

impl<'a> Iterator for RowIter<'a> {
    type Item = Row<'a>;

    /// Yields the row at the current position and advances, or `None` once
    /// every row has been produced.
    fn next(&mut self) -> Option<Row<'a>> {
        if self.current >= self.n_rows {
            return None;
        }
        let index = self.current;
        self.current = index + 1;
        Some(Row {
            columns: self.columns,
            index,
        })
    }

    /// Exact number of rows still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `current` is only advanced while it is below `n_rows`, so this cannot underflow.
        let left = self.n_rows - self.current;
        (left, Some(left))
    }
}

impl ExactSizeIterator for RowIter<'_> {}
/// Handle to a single table row: the shared column slice plus the row index.
#[derive(Clone, Copy)]
pub struct Row<'a> {
    columns: &'a [K],
    index: usize,
}

impl Row<'_> {
    /// Returns the value of column `col` in this row as a new `K`.
    ///
    /// # Errors
    /// `IndexOutOfBounds` when `col` is not a valid column position; otherwise
    /// whatever `element_at` reports for the column list.
    pub fn get(&self, col: usize) -> Result<K, KdbError> {
        match self.columns.get(col) {
            Some(column) => column.element_at(self.index),
            None => Err(KdbError::IndexOutOfBounds {
                index: col,
                length: self.columns.len(),
            }),
        }
    }

    /// Returns the value of column `col` in this row as an interned symbol.
    ///
    /// # Errors
    /// `IndexOutOfBounds` when `col` or the row index is out of range, and
    /// `InvalidOperation` when the column cannot be read as a list of strings.
    pub fn get_sym(&self, col: usize, interner: &mut SymbolInterner) -> Result<Sym, KdbError> {
        let Some(column) = self.columns.get(col) else {
            return Err(KdbError::IndexOutOfBounds {
                index: col,
                length: self.columns.len(),
            });
        };
        let symbols = match column.as_vec::<String>() {
            Ok(list) => list,
            Err(_) => {
                return Err(KdbError::InvalidOperation {
                    operator: "get_sym",
                    operand_type: "K",
                    expected: Some("symbol list"),
                });
            }
        };
        match symbols.get(self.index) {
            Some(name) => Ok(interner.intern(name)),
            None => Err(KdbError::IndexOutOfBounds {
                index: self.index,
                length: symbols.len(),
            }),
        }
    }

    /// Number of columns in the row.
    pub fn len(&self) -> usize {
        self.columns.len()
    }

    /// `true` when the row has no columns.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl KdbExt for K {
    /// Reads the column names from the dictionary underlying a table.
    ///
    /// # Errors
    /// Fails when `self` is not a table, when the dictionary or its parts
    /// cannot be accessed, or when the dictionary has no key part.
    fn column_names(&self) -> Result<Vec<String>> {
        let qt = self.get_type();
        if qt != qtype::TABLE {
            bail!("expected table (qtype 98), got qtype {}", qt);
        }
        let table_dict = self.get_dictionary()?;
        let parts = table_dict.as_vec::<K>()?;
        let Some(keys) = parts.first() else {
            bail!("table dictionary has no keys");
        };
        let names = keys.as_vec::<String>()?;
        Ok(names.clone())
    }

    /// Builds a [`Rows`] view from the value part of a table's dictionary.
    ///
    /// The row count is the length of the first column, or 0 when the table
    /// has no columns.
    ///
    /// # Errors
    /// Fails when `self` is not a table, when the dictionary or its parts
    /// cannot be accessed, or when the dictionary has no value part.
    fn rows(&self) -> Result<Rows> {
        let qt = self.get_type();
        if qt != qtype::TABLE {
            bail!("expected table (qtype 98), got qtype {}", qt);
        }
        let table_dict = self.get_dictionary()?;
        let parts = table_dict.as_vec::<K>()?;
        let Some(values) = parts.get(1) else {
            bail!("table dictionary missing values");
        };
        let columns: Vec<K> = values.as_vec::<K>()?.clone();
        let n_rows = match columns.first() {
            Some(first_column) => first_column.len(),
            None => 0,
        };
        Ok(Rows { columns, n_rows })
    }

    /// Extracts the element at `index` from a list and wraps it in a new `K` atom
    /// (compound lists return a clone of the stored `K`).
    ///
    /// Timestamp and timespan lists come back as long atoms; date and time
    /// lists come back as int atoms.
    ///
    /// # Errors
    /// `InvalidOperation` for types not handled below. `IndexOutOfBounds` when
    /// `index` is past the end of the list; a failing `as_vec` call is also
    /// reported as `IndexOutOfBounds`.
    fn element_at(&self, index: usize) -> Result<K, KdbError> {
        let qt = self.get_type();
        let length = self.len();
        let element: Option<K> = match qt {
            qtype::LONG_LIST | qtype::TIMESTAMP_LIST | qtype::TIMESPAN_LIST => self
                .as_vec::<i64>()
                .ok()
                .and_then(|items| items.get(index).copied())
                .map(|value| K::new_long(value)),
            qtype::FLOAT_LIST => self
                .as_vec::<f64>()
                .ok()
                .and_then(|items| items.get(index).copied())
                .map(|value| K::new_float(value)),
            qtype::SYMBOL_LIST => self
                .as_vec::<String>()
                .ok()
                .and_then(|items| items.get(index))
                .map(|name| K::new_symbol(name.clone())),
            qtype::STRING => self
                .as_vec::<u8>()
                .ok()
                .and_then(|items| items.get(index).copied())
                .map(|value| K::new_byte(value)),
            qtype::INT_LIST | qtype::DATE_LIST | qtype::TIME_LIST => self
                .as_vec::<i32>()
                .ok()
                .and_then(|items| items.get(index).copied())
                .map(|value| K::new_int(value)),
            qtype::SHORT_LIST => self
                .as_vec::<i16>()
                .ok()
                .and_then(|items| items.get(index).copied())
                .map(|value| K::new_short(value)),
            qtype::BOOL_LIST => self
                .as_vec::<bool>()
                .ok()
                .and_then(|items| items.get(index).copied())
                .map(|value| K::new_bool(value)),
            qtype::REAL_LIST => self
                .as_vec::<f32>()
                .ok()
                .and_then(|items| items.get(index).copied())
                .map(|value| K::new_real(value)),
            qtype::COMPOUND_LIST => self
                .as_vec::<K>()
                .ok()
                .and_then(|items| items.get(index))
                .cloned(),
            _ => {
                return Err(KdbError::InvalidOperation {
                    operator: "element_at",
                    operand_type: "K",
                    expected: Some("list type"),
                });
            }
        };
        match element {
            Some(value) => Ok(value),
            None => Err(KdbError::IndexOutOfBounds { index, length }),
        }
    }
}
/// Types that can be constructed from one row of a kdb+ result table.
pub trait KdbDeserialize: Sized {
/// Builds `Self` from `row`.
///
/// As called by `chunk_stream` in this file, `columns` holds the column names
/// of the result table and `interner` is the symbol interner reused for every
/// row of the stream.
fn from_kdb_row(
row: Row<'_>,
columns: &[String],
interner: &mut SymbolInterner,
) -> Result<Self, KdbError>;
}
/// Converts a `NanoTime` into a kdb+ date: whole days relative to 2000-01-01.
///
/// The time is truncated to whole days since the Unix epoch and then shifted
/// by the 10957 days that separate 1970-01-01 from 2000-01-01.
fn nano_to_kdb_date(t: NanoTime) -> i32 {
    const NANOS_PER_SECOND: u64 = 1_000_000_000;
    const SECONDS_PER_DAY: u64 = 86400;
    const UNIX_TO_KDB_EPOCH_DAYS: i64 = 10957;

    let nanos = u64::from(t);
    let days_since_unix_epoch = (nanos / NANOS_PER_SECOND / SECONDS_PER_DAY) as i64;
    (days_since_unix_epoch - UNIX_TO_KDB_EPOCH_DAYS) as i32
}
/// Formats a kdb+ date (days since 2000-01-01) as a q date literal `YYYY.MM.DD`.
///
/// Uses the era-based days-to-civil conversion for the proleptic Gregorian
/// calendar, so dates before 2000 (negative inputs) are handled as well.
fn format_kdb_date(kdb_date: i32) -> String {
    const DAYS_PER_ERA: i64 = 146_097; // one 400-year Gregorian cycle

    // Shift to days since 0000-03-01: 10957 days to the Unix epoch, then 719468 more.
    let shifted = i64::from(kdb_date) + 10957 + 719_468;
    let era = shifted.div_euclid(DAYS_PER_ERA);
    let day_of_era = shifted.rem_euclid(DAYS_PER_ERA);
    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // Month index counted from March (0 = March, ..., 11 = February).
    let month_index = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;
    // January and February belong to the following civil year.
    let (month, year_carry) = if month_index < 10 {
        (month_index + 3, 0)
    } else {
        (month_index - 9, 1)
    };
    let year = year_of_era + era * 400 + year_carry;
    format!("{}.{:02}.{:02}", year, month, day)
}
/// Formats a kdb+ timestamp (nanoseconds since 2000-01-01) as a q timestamp
/// literal `YYYY.MM.DDDhh:mm:ss.nnnnnnnnn`.
///
/// Euclidean division keeps the time-of-day part non-negative for timestamps
/// before the kdb+ epoch.
fn format_kdb_timestamp(kdb_nanos: i64) -> String {
    const NANOS_PER_SECOND: i64 = 1_000_000_000;
    const NANOS_PER_MINUTE: i64 = 60_000_000_000;
    const NANOS_PER_HOUR: i64 = 3_600_000_000_000;
    const NANOS_PER_DAY: i64 = 86_400_000_000_000;

    let day = kdb_nanos.div_euclid(NANOS_PER_DAY) as i32;
    let within_day = kdb_nanos.rem_euclid(NANOS_PER_DAY);

    let hours = within_day / NANOS_PER_HOUR;
    let minutes = (within_day % NANOS_PER_HOUR) / NANOS_PER_MINUTE;
    let seconds = (within_day % NANOS_PER_MINUTE) / NANOS_PER_SECOND;
    let nanos = within_day % NANOS_PER_SECOND;

    let date = format_kdb_date(day);
    format!(
        "{}D{:02}:{:02}:{:02}.{:09}",
        date, hours, minutes, seconds, nanos
    )
}
/// Appends a constraint on `time_col` to a q-SQL `query`.
///
/// * `start` / `end` are kdb+ timestamps (nanoseconds since 2000-01-01).
/// * `end == None` produces `time_col >= start`.
/// * `end == Some(i64::MAX)` produces `time_col within (start;0Wp)`, using q's
///   infinite timestamp as the upper bound.
/// * Any other `end` produces `time_col within (start;end)`.
///
/// When the query already contains a `where` keyword the constraint is appended
/// with `, `; otherwise a new ` where ` clause is started.
fn inject_time_filter(query: &str, time_col: &str, start: i64, end: Option<i64>) -> String {
    let start_ts = format_kdb_timestamp(start);
    let filter = match end {
        Some(i64::MAX) => format!("{} within ({};0Wp)", time_col, start_ts),
        Some(end_ts) => format!(
            "{} within ({};{})",
            time_col,
            start_ts,
            format_kdb_timestamp(end_ts)
        ),
        None => format!("{} >= {}", time_col, start_ts),
    };
    // Detect `where` as a whitespace-delimited token so that a keyword surrounded
    // by tabs or newlines is recognised too; a plain `" where "` substring test
    // would miss it and emit a second `where`.
    let has_where = query
        .split_whitespace()
        .any(|token| token.eq_ignore_ascii_case("where"));
    if has_where {
        format!("{}, {}", query, filter)
    } else {
        format!("{} where {}", query, filter)
    }
}
/// Rewrites `query` so that it returns at most `rows_per_chunk` rows starting
/// at row `offset`.
///
/// * A query starting with the `select` keyword (case-insensitive, leading
///   whitespace ignored) becomes `select[offset,rows_per_chunk] ...`; an existing
///   `[...]` directly after `select` is replaced.
/// * Anything else is wrapped as `(offset;rows_per_chunk) sublist <query>`.
///
/// `select` only counts as the keyword when it is not immediately followed by
/// an identifier character, so expressions such as `selected_trades` or
/// `selectTrades[...]` take the `sublist` path instead of being cut apart.
fn build_chunk_query(query: &str, offset: usize, rows_per_chunk: usize) -> String {
    const KEYWORD: &str = "select";

    let trimmed = query.trim_start();
    let after_keyword = trimmed
        .get(..KEYWORD.len())
        .filter(|head| head.eq_ignore_ascii_case(KEYWORD))
        // `get` succeeded, so `KEYWORD.len()` is a valid char boundary.
        .map(|_| &trimmed[KEYWORD.len()..])
        // Word-boundary check: reject names that merely start with "select".
        .filter(|rest| !rest.starts_with(|c: char| c.is_alphanumeric() || c == '_'));

    match after_keyword {
        Some(rest) => {
            let rest = rest.trim_start();
            let body = if rest.starts_with('[') {
                // Drop the existing bracket expression; without a closing `]`
                // the remainder of the query is dropped, as before.
                let close = rest.find(']').map_or(rest.len(), |i| i + 1);
                rest[close..].trim_start()
            } else {
                rest
            };
            format!("select[{},{}] {}", offset, rows_per_chunk, body)
        }
        None => format!("({};{}) sublist {}", offset, rows_per_chunk, query),
    }
}
/// Appends a constraint on `date_col` to a q-SQL `query`.
///
/// * `start` / `end` are kdb+ dates (days since 2000-01-01).
/// * `end == None` produces `date_col >= start`.
/// * `end == Some(e)` produces `date_col within (start;e)`.
///
/// When the query already contains a `where` keyword the constraint is appended
/// with `, `; otherwise a new ` where ` clause is started.
fn inject_date_filter(query: &str, date_col: &str, start: i32, end: Option<i32>) -> String {
    let start_date = format_kdb_date(start);
    let filter = match end {
        Some(end_date) => format!(
            "{} within ({};{})",
            date_col,
            start_date,
            format_kdb_date(end_date)
        ),
        None => format!("{} >= {}", date_col, start_date),
    };
    // Detect `where` as a whitespace-delimited token so that a keyword surrounded
    // by tabs or newlines is recognised too; a plain `" where "` substring test
    // would miss it and emit a second `where`.
    let has_where = query
        .split_whitespace()
        .any(|token| token.eq_ignore_ascii_case("where"));
    if has_where {
        format!("{}, {}", query, filter)
    } else {
        format!("{} where {}", query, filter)
    }
}
/// Runs the queries produced by `query_fn` one after another on `socket` and
/// streams every row of each result table as a `(NanoTime, T)` pair.
///
/// `query_fn` is called with `None` before the first query and with
/// `Some(row_count)` of the previous result afterwards; returning `None` ends
/// the stream. The stream also ends after the first result that has zero rows.
///
/// Every failure (query error, result that cannot be read as a table, missing
/// `time_col`, unreadable time value, time going backwards within a chunk, row
/// deserialization error) is yielded as one `Err` item, after which the stream
/// stops.
fn chunk_stream<T>(
mut socket: QStream,
time_col: String,
mut query_fn: impl FnMut(Option<usize>) -> Option<String> + Send + 'static,
) -> impl futures::Stream<Item = anyhow::Result<(NanoTime, T)>> + Send + 'static
where
T: KdbDeserialize + Send + 'static,
{
async_stream::stream! {
// Row count of the previous chunk; `None` until the first chunk has been read.
let mut last_count: Option<usize> = None;
// One interner is created up front and reused for every row of every chunk.
let mut interner = SymbolInterner::default();
while let Some(query) = query_fn(last_count) {
let result: K = match socket.send_sync_message(&query.as_str()).await {
Ok(r) => r,
Err(e) => {
yield Err(anyhow::Error::new(e).context(format!("KDB query failed: {}", query)));
break;
}
};
// Both accessors must succeed; the first error encountered is reported.
let (columns, rows) = match (result.column_names(), result.rows()) {
(Ok(cols), Ok(rows)) => (cols, rows),
(Err(e), _) | (_, Err(e)) => {
yield Err(anyhow::anyhow!("{}\nkdb query failed with\n{}", query, e));
break;
}
};
let row_count = rows.len();
info!("KDB query: {} ({} records)", query, row_count);
// An empty result terminates the stream without calling `query_fn` again.
if row_count == 0 {
break;
}
// The time column is looked up by name for each chunk.
let time_col_idx = match columns.iter().position(|c| c == &time_col) {
Some(idx) => idx,
None => {
yield Err(anyhow::anyhow!(
"time column '{}' not found in result columns: {:?}",
time_col, columns
));
break;
}
};
// NOTE(review): `prev_time` starts at `None` for every chunk, so ordering is
// only verified within a chunk, not across chunk boundaries — confirm this
// is intended.
let mut prev_time: Option<i64> = None;
// Set when the row loop stops on an error so the outer loop stops as well.
let mut row_error = false;
for row in &rows {
// The time value is read as a long (raw kdb+ timestamp).
let time_kdb = match row.get(time_col_idx).and_then(|v| v.get_long()) {
Ok(t) => t,
Err(e) => {
yield Err(anyhow::Error::new(e).context(format!("failed to extract time from KDB row: {}", query)));
row_error = true;
break;
}
};
// Equal timestamps are accepted; only a strictly smaller one is an error.
if let Some(prev) = prev_time
&& time_kdb < prev
{
yield Err(anyhow::anyhow!(
"KDB data is not sorted by time column '{}': got {} after {}. \
Add `{} xasc` to your query to sort the data.\nQuery: {}",
time_col, time_kdb, prev, time_col, query
));
row_error = true;
break;
}
prev_time = Some(time_kdb);
let time = NanoTime::from_kdb_timestamp(time_kdb);
let record = match T::from_kdb_row(row, &columns, &mut interner) {
Ok(r) => r,
Err(e) => {
yield Err(anyhow::Error::new(e).context(format!("KDB deserialization failed: {}", query)));
row_error = true;
break;
}
};
yield Ok((time, record));
}
if row_error {
break;
}
// Reported to `query_fn` on the next iteration.
last_count = Some(row_count);
}
}
}
/// Streams records from kdb+ using caller-supplied chunk queries.
///
/// A TCP connection to `connection` is opened when the producer starts. After
/// that `query_fn` drives the read: it is called with `None` first and with the
/// row count of the previous chunk afterwards, and returns the next query text
/// or `None` to stop. `time_col` names the result column holding each row's
/// timestamp.
#[must_use]
pub fn kdb_read_chunks<T, F>(
    connection: KdbConnection,
    query_fn: F,
    time_col: impl Into<String>,
) -> Rc<dyn Stream<Burst<T>>>
where
    T: Element + Send + KdbDeserialize + 'static,
    F: FnMut(Option<usize>) -> Option<String> + Send + 'static,
{
    let time_col: String = time_col.into();
    produce_async(move |_ctx| async move {
        let credentials = connection.credentials_string();
        let socket = QStream::connect(
            ConnectionMethod::TCP,
            &connection.host,
            connection.port,
            &credentials,
        )
        .await?;
        let records = chunk_stream::<T>(socket, time_col, query_fn);
        Ok(records)
    })
}
/// Streams the rows of a q-SQL `query` from kdb+ in chunks of `rows_per_chunk`.
///
/// The query is first restricted to the run's time range: an optional filter on
/// `date_col` is added first, followed by a filter on `time_col`. The result is
/// then fetched chunk by chunk; reading stops once a chunk returns fewer than
/// `rows_per_chunk` rows. When the context reports no end time, the filters are
/// open-ended (`>= start`).
#[must_use]
pub fn kdb_read<T>(
    connection: KdbConnection,
    query: impl Into<String>,
    time_col: impl Into<String>,
    date_col: Option<impl Into<String>>,
    rows_per_chunk: usize,
) -> Rc<dyn Stream<Burst<T>>>
where
    T: Element + Send + KdbDeserialize + 'static,
{
    let query: String = query.into();
    let time_col: String = time_col.into();
    let date_col: Option<String> = date_col.map(|name| name.into());
    produce_async(move |ctx| {
        let start_time = ctx.start_time;
        let end_time = ctx.end_time().ok();
        async move {
            // Date filter first (when requested), then the timestamp filter.
            let dated_query = match &date_col {
                Some(column) => inject_date_filter(
                    &query,
                    column,
                    nano_to_kdb_date(start_time),
                    end_time.map(nano_to_kdb_date),
                ),
                None => query.clone(),
            };
            let base_query = inject_time_filter(
                &dated_query,
                &time_col,
                start_time.to_kdb_timestamp(),
                end_time.map(|t| t.to_kdb_timestamp()),
            );

            let credentials = connection.credentials_string();
            let socket = QStream::connect(
                ConnectionMethod::TCP,
                &connection.host,
                connection.port,
                &credentials,
            )
            .await?;

            // Offset of the next chunk; advanced by the size of each full chunk.
            let mut offset = 0usize;
            let next_query = move |last_count: Option<usize>| -> Option<String> {
                if let Some(received) = last_count {
                    // A short chunk means the result set is exhausted.
                    if received < rows_per_chunk {
                        return None;
                    }
                    offset += received;
                }
                Some(build_chunk_query(&base_query, offset, rows_per_chunk))
            };
            Ok(chunk_stream::<T>(socket, time_col, next_query))
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    // kdb+ timestamp 0 is 2000-01-01, i.e. 946_684_800 seconds after the Unix epoch.
    #[test]
    fn test_nanotime_from_kdb_timestamp() {
        let cases: [(i64, u64); 2] = [
            (0, 946_684_800_000_000_000),
            (1_000_000_000, 946_684_801_000_000_000),
        ];
        for (kdb_time, unix_nanos) in cases {
            let nano = NanoTime::from_kdb_timestamp(kdb_time);
            assert_eq!(u64::from(nano), unix_nanos);
        }
    }

    #[test]
    fn test_nanotime_kdb_timestamp_round_trip() {
        let original = NanoTime::new(1_000_000_000_000_000_000);
        let restored = NanoTime::from_kdb_timestamp(original.to_kdb_timestamp());
        assert_eq!(original, restored);

        let kdb_epoch = NanoTime::new(946_684_800_000_000_000);
        assert_eq!(kdb_epoch.to_kdb_timestamp(), 0);

        let after_epoch = NanoTime::new(946_684_801_000_000_000);
        assert_eq!(after_epoch.to_kdb_timestamp(), 1_000_000_000);
    }

    #[test]
    fn test_format_kdb_date() {
        let cases: [(i32, &str); 6] = [
            (0, "2000.01.01"),
            (1, "2000.01.02"),
            (31, "2000.02.01"),
            (366, "2001.01.01"),
            (-1, "1999.12.31"),
            (-365, "1999.01.01"),
        ];
        for (kdb_date, expected) in cases {
            assert_eq!(format_kdb_date(kdb_date), expected);
        }
    }

    #[test]
    fn test_nano_to_kdb_date() {
        let at_epoch = NanoTime::from_kdb_timestamp(0);
        assert_eq!(nano_to_kdb_date(at_epoch), 0);

        let one_day_later = NanoTime::from_kdb_timestamp(86_400_000_000_000);
        assert_eq!(nano_to_kdb_date(one_day_later), 1);

        assert_eq!(nano_to_kdb_date(NanoTime::ZERO), -10957);
    }

    #[test]
    fn test_build_chunk_query_basic_select() {
        assert_eq!(
            build_chunk_query("select from trades", 0, 10000),
            "select[0,10000] from trades"
        );
    }

    #[test]
    fn test_build_chunk_query_with_offset() {
        assert_eq!(
            build_chunk_query("select from trades", 30000, 10000),
            "select[30000,10000] from trades"
        );
    }

    #[test]
    fn test_build_chunk_query_preserves_where_clause() {
        assert_eq!(
            build_chunk_query("select from trades where sym=`AAPL", 0, 1000),
            "select[0,1000] from trades where sym=`AAPL"
        );
    }

    #[test]
    fn test_build_chunk_query_preserves_columns() {
        assert_eq!(
            build_chunk_query("select price,qty from trades", 5000, 1000),
            "select[5000,1000] price,qty from trades"
        );
    }

    #[test]
    fn test_build_chunk_query_replaces_existing_bracket() {
        assert_eq!(
            build_chunk_query("select[5] from trades", 0, 10000),
            "select[0,10000] from trades"
        );
    }

    #[test]
    fn test_build_chunk_query_case_insensitive() {
        assert_eq!(
            build_chunk_query("SELECT FROM trades", 0, 100),
            "select[0,100] FROM trades"
        );
    }

    #[test]
    fn test_build_chunk_query_non_select_fallback() {
        assert_eq!(
            build_chunk_query("exec price from trades", 0, 100),
            "(0;100) sublist exec price from trades"
        );
    }

    #[test]
    fn test_inject_date_filter_bounded_no_existing_where() {
        let expected = "select from trades where date within (2000.01.01;2000.01.02)";
        let actual = inject_date_filter("select from trades", "date", 0, Some(1));
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_inject_date_filter_bounded_with_existing_where() {
        let expected = "select from trades where sym=`AAPL, date within (2000.01.01;2000.01.02)";
        let actual = inject_date_filter("select from trades where sym=`AAPL", "date", 0, Some(1));
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_inject_date_filter_unbounded_no_existing_where() {
        let expected = "select from trades where date >= 2000.01.01";
        let actual = inject_date_filter("select from trades", "date", 0, None);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_inject_date_filter_unbounded_with_existing_where() {
        let expected = "select from trades where sym=`AAPL, date >= 2000.01.01";
        let actual = inject_date_filter("select from trades where sym=`AAPL", "date", 0, None);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_inject_date_filter_same_start_end() {
        let expected = "select from trades where date within (2000.01.06;2000.01.06)";
        let actual = inject_date_filter("select from trades", "date", 5, Some(5));
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_format_kdb_timestamp_epoch() {
        assert_eq!(format_kdb_timestamp(0), "2000.01.01D00:00:00.000000000");
    }

    #[test]
    fn test_format_kdb_timestamp_one_second() {
        let formatted = format_kdb_timestamp(1_000_000_000);
        assert_eq!(formatted, "2000.01.01D00:00:01.000000000");
    }

    #[test]
    fn test_format_kdb_timestamp_one_day() {
        let formatted = format_kdb_timestamp(86_400_000_000_000);
        assert_eq!(formatted, "2000.01.02D00:00:00.000000000");
    }

    #[test]
    fn test_format_kdb_timestamp_before_epoch() {
        assert_eq!(format_kdb_timestamp(-1), "1999.12.31D23:59:59.999999999");
    }

    #[test]
    fn test_inject_time_filter_unbounded() {
        let expected = "select from trades where time >= 2000.01.01D00:00:00.000000000";
        let actual = inject_time_filter("select from trades", "time", 0, None);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_inject_time_filter_bounded() {
        let expected = "select from trades where time within (2000.01.01D00:00:00.000000000;2000.01.02D00:00:00.000000000)";
        let actual = inject_time_filter("select from trades", "time", 0, Some(86_400_000_000_000));
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_inject_time_filter_max_uses_0wp() {
        let expected = "select from trades where time within (2000.01.01D00:00:00.000000000;0Wp)";
        let actual = inject_time_filter("select from trades", "time", 0, Some(i64::MAX));
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_inject_time_filter_with_existing_where() {
        let expected = "select from trades where sym=`AAPL, time >= 2000.01.01D00:00:00.000000000";
        let actual = inject_time_filter("select from trades where sym=`AAPL", "time", 0, None);
        assert_eq!(actual, expected);
    }
}