use smol_str::SmolStr;
use std::io::Write;
/// A single OHLCV bar: the aggregate of the updates observed over one interval.
///
/// `#[repr(C)]` lays the fields out in declaration order, so reordering fields
/// changes the in-memory layout of the type.
///
/// NOTE(review): prices are plain `i64` values — presumably fixed-point integers,
/// but the scale is not recorded on this type; confirm against the producer.
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct Bar {
/// Timestamp identifying the bar; `BarBuilder::build` copies its `start_time` here.
/// Presumably nanoseconds since some epoch — TODO confirm which one.
pub timestamp_nanos: i64,
/// Price of the first update in the bar.
pub open: i64,
/// Highest price seen in the bar.
pub high: i64,
/// Lowest price seen in the bar.
pub low: i64,
/// Price of the last update in the bar.
pub close: i64,
/// Sum of the volumes of all updates in the bar.
pub volume: i64,
/// Number of updates aggregated into the bar.
pub tick_count: u32,
/// Volume-weighted average price, computed with integer division. The producers
/// in this file fall back to `close` when the total volume is not positive.
pub vwap: i64,
}
/// A sequence of [`Bar`]s for one symbol, all nominally covering the same interval.
pub struct BarSeries {
/// Bars as stored; `push` appends to the end and nothing in this type sorts
/// or deduplicates them.
pub(crate) bars: Vec<Bar>,
/// Symbol the series belongs to, as supplied to `new`.
symbol: SmolStr,
/// Interval covered by each bar; `resample` only accepts target intervals that
/// are a multiple of it.
interval_nanos: i64,
/// Initialised to 8 by `new`; does not influence any value computed in this file.
/// NOTE(review): presumably the number of decimal places of the fixed-point
/// prices — confirm.
_price_scale: u8,
/// Initialised to 0 by `new`; does not influence any value computed in this file.
/// NOTE(review): presumably the number of decimal places of the volumes — confirm.
_volume_scale: u8,
/// Set through `with_timezone_offset` (0 by default); does not influence any
/// value computed in this file.
/// NOTE(review): the unit (seconds? minutes?) is not established here — confirm.
timezone_offset: i32,
}
impl BarSeries {
    /// Creates an empty series for `symbol` whose bars each span `interval_nanos`.
    ///
    /// The price scale starts at 8, the volume scale at 0 and the timezone
    /// offset at 0.
    pub fn new(symbol: impl Into<SmolStr>, interval_nanos: i64) -> Self {
        BarSeries {
            bars: Vec::new(),
            symbol: symbol.into(),
            interval_nanos,
            _price_scale: 8,
            _volume_scale: 0,
            timezone_offset: 0,
        }
    }

    /// Returns the stored bars as a read-only slice.
    pub fn as_slice(&self) -> &[Bar] {
        &self.bars
    }

    /// Returns mutable access to the underlying vector of bars.
    ///
    /// No invariants are re-checked afterwards; callers may reorder, insert or
    /// remove bars freely.
    pub fn bars_mut(&mut self) -> &mut Vec<Bar> {
        &mut self.bars
    }

    /// Consumes the series and returns its bars, dropping the series metadata.
    pub fn into_inner(self) -> Vec<Bar> {
        self.bars
    }

    /// Appends `bar` to the end of the series without any validation.
    pub fn push(&mut self, bar: Bar) {
        self.bars.push(bar);
    }

    /// Returns the symbol this series was created for.
    pub fn symbol(&self) -> &SmolStr {
        &self.symbol
    }

    /// Returns the interval covered by each bar.
    pub fn interval_nanos(&self) -> i64 {
        self.interval_nanos
    }

    /// Builder-style setter that replaces the timezone offset and returns the series.
    pub fn with_timezone_offset(mut self, offset: i32) -> Self {
        self.timezone_offset = offset;
        self
    }

    /// Exports the series as an Arrow record batch with one non-nullable column
    /// per [`Bar`] field, in field order.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `RecordBatch::try_new`.
    #[cfg(feature = "arrow-export")]
    pub fn to_arrow(&self) -> Result<arrow::record_batch::RecordBatch, arrow::error::ArrowError> {
        use arrow::array::{Int64Array, UInt32Array};
        use arrow::datatypes::{DataType, Field, Schema};
        use std::sync::Arc;

        let schema = Arc::new(Schema::new(vec![
            Field::new("timestamp_nanos", DataType::Int64, false),
            Field::new("open", DataType::Int64, false),
            Field::new("high", DataType::Int64, false),
            Field::new("low", DataType::Int64, false),
            Field::new("close", DataType::Int64, false),
            Field::new("volume", DataType::Int64, false),
            Field::new("tick_count", DataType::UInt32, false),
            Field::new("vwap", DataType::Int64, false),
        ]));

        // Transpose the row-oriented bars into one vector per column.
        let n = self.bars.len();
        let mut ts = Vec::with_capacity(n);
        let mut open = Vec::with_capacity(n);
        let mut high = Vec::with_capacity(n);
        let mut low = Vec::with_capacity(n);
        let mut close = Vec::with_capacity(n);
        let mut volume = Vec::with_capacity(n);
        let mut tick_count = Vec::with_capacity(n);
        let mut vwap = Vec::with_capacity(n);
        for bar in &self.bars {
            ts.push(bar.timestamp_nanos);
            open.push(bar.open);
            high.push(bar.high);
            low.push(bar.low);
            close.push(bar.close);
            volume.push(bar.volume);
            tick_count.push(bar.tick_count);
            vwap.push(bar.vwap);
        }

        // Column order must match the schema above.
        let columns: Vec<arrow::array::ArrayRef> = vec![
            Arc::new(Int64Array::from(ts)),
            Arc::new(Int64Array::from(open)),
            Arc::new(Int64Array::from(high)),
            Arc::new(Int64Array::from(low)),
            Arc::new(Int64Array::from(close)),
            Arc::new(Int64Array::from(volume)),
            Arc::new(UInt32Array::from(tick_count)),
            Arc::new(Int64Array::from(vwap)),
        ];
        let batch = arrow::record_batch::RecordBatch::try_new(schema, columns)?;
        Ok(batch)
    }

    /// Exports the series as a Polars data frame with one column per [`Bar`]
    /// field, in field order.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `DataFrame::new`.
    #[cfg(feature = "polars-export")]
    pub fn to_polars(&self) -> polars::prelude::PolarsResult<polars::frame::DataFrame> {
        use polars::prelude::*;

        let ts: Vec<i64> = self.bars.iter().map(|b| b.timestamp_nanos).collect();
        let open: Vec<i64> = self.bars.iter().map(|b| b.open).collect();
        let high: Vec<i64> = self.bars.iter().map(|b| b.high).collect();
        let low: Vec<i64> = self.bars.iter().map(|b| b.low).collect();
        let close: Vec<i64> = self.bars.iter().map(|b| b.close).collect();
        let volume: Vec<i64> = self.bars.iter().map(|b| b.volume).collect();
        let tick_count: Vec<u32> = self.bars.iter().map(|b| b.tick_count).collect();
        let vwap: Vec<i64> = self.bars.iter().map(|b| b.vwap).collect();
        let cols = vec![
            Column::new("timestamp_nanos".into(), ts),
            Column::new("open".into(), open),
            Column::new("high".into(), high),
            Column::new("low".into(), low),
            Column::new("close".into(), close),
            Column::new("volume".into(), volume),
            Column::new("tick_count".into(), tick_count),
            Column::new("vwap".into(), vwap),
        ];
        DataFrame::new(self.bars.len(), cols)
    }

    /// Writes one CSV record per bar to `writer`, then flushes it.
    ///
    /// Fields are emitted in the order `timestamp_nanos, open, high, low, close,
    /// volume, tick_count, vwap`. Each bar is serialized as a plain tuple, so this
    /// method supplies no column names.
    ///
    /// # Errors
    ///
    /// Returns the first serialization or flush error; records written before the
    /// failure are not rolled back.
    pub fn to_csv<W: Write>(&self, writer: &mut csv::Writer<W>) -> Result<(), csv::Error> {
        for bar in &self.bars {
            writer.serialize((
                bar.timestamp_nanos,
                bar.open,
                bar.high,
                bar.low,
                bar.close,
                bar.volume,
                bar.tick_count,
                bar.vwap,
            ))?;
        }
        writer.flush()?;
        Ok(())
    }

    /// Clamps a 128-bit intermediate result into the `i64` range.
    fn saturate_to_i64(value: i128) -> i64 {
        // The cast is lossless: the value has just been clamped to `i64`'s range.
        value.clamp(i128::from(i64::MIN), i128::from(i64::MAX)) as i64
    }

    /// Aggregates consecutive bars into coarser bars of `new_interval_nanos`.
    ///
    /// Bars are grouped purely by position: every
    /// `new_interval_nanos / interval_nanos` consecutive bars form one output
    /// bar, and a shorter trailing group is still emitted. Timestamps are not
    /// inspected, so gaps in the input are not detected.
    ///
    /// Each output bar takes its timestamp and open from the first bar of the
    /// group, its close from the last, the extreme high/low, the summed volume
    /// and tick count (saturating at the type's maximum), and a volume-weighted
    /// average of the input VWAPs. When the group's total volume is not
    /// positive, the VWAP falls back to the group's close.
    ///
    /// The returned series keeps this series' symbol, scales and timezone offset.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidConfiguration` when either interval is not
    /// positive, or when `new_interval_nanos` is not a multiple of the current
    /// interval.
    pub fn resample(&self, new_interval_nanos: i64) -> Result<BarSeries, crate::Error> {
        // Reject zero/negative intervals up front: they would otherwise divide
        // by zero or produce a nonsensical (zero or wrapped) group size.
        if self.interval_nanos <= 0 || new_interval_nanos <= 0 {
            return Err(crate::Error::InvalidConfiguration(
                "resampling requires positive current and new intervals".into(),
            ));
        }
        if new_interval_nanos % self.interval_nanos != 0 {
            return Err(crate::Error::InvalidConfiguration(
                "new interval must be a multiple of the current interval".into(),
            ));
        }

        // Both intervals are positive and divisible, so `ratio >= 1`. On targets
        // where `usize` is narrower than the ratio, any group size beyond
        // `usize::MAX` is equivalent to "everything in one group".
        let ratio = new_interval_nanos / self.interval_nanos;
        let factor = if ratio as u64 > usize::MAX as u64 {
            usize::MAX
        } else {
            ratio as usize
        };

        let expected_len =
            self.bars.len() / factor + usize::from(self.bars.len() % factor != 0);
        let mut out = BarSeries {
            bars: Vec::with_capacity(expected_len),
            symbol: self.symbol.clone(),
            interval_nanos: new_interval_nanos,
            _price_scale: self._price_scale,
            _volume_scale: self._volume_scale,
            timezone_offset: self.timezone_offset,
        };

        for chunk in self.bars.chunks(factor) {
            // `chunks` never yields an empty slice; the fallback arm only keeps
            // this free of panicking paths.
            let (first, rest) = match chunk.split_first() {
                Some(parts) => parts,
                None => continue,
            };
            let last = rest.last().unwrap_or(first);

            let mut high = first.high;
            let mut low = first.low;
            let mut tick_count = first.tick_count;
            // Accumulate in 128 bits: `vwap * volume` easily exceeds `i64` for
            // fixed-point prices combined with large volumes.
            let mut volume = i128::from(first.volume);
            let mut vwap_num = i128::from(first.vwap) * i128::from(first.volume);
            for bar in rest {
                high = high.max(bar.high);
                low = low.min(bar.low);
                tick_count = tick_count.saturating_add(bar.tick_count);
                volume += i128::from(bar.volume);
                vwap_num =
                    vwap_num.saturating_add(i128::from(bar.vwap) * i128::from(bar.volume));
            }

            let vwap = if volume > 0 {
                Self::saturate_to_i64(vwap_num / volume)
            } else {
                last.close
            };

            out.bars.push(Bar {
                timestamp_nanos: first.timestamp_nanos,
                open: first.open,
                high,
                low,
                close: last.close,
                volume: Self::saturate_to_i64(volume),
                tick_count,
                vwap,
            });
        }
        Ok(out)
    }
}
/// Incrementally accumulates price/volume updates into a single [`Bar`].
#[derive(Debug)]
pub struct BarBuilder {
/// Copied into `Bar::timestamp_nanos` by `build`.
pub start_time: i64,
/// Stored by `new`; neither `update` nor `build` consults it.
/// NOTE(review): presumably the end of the bar's interval — confirm whether it
/// is inclusive or exclusive.
pub end_time: i64,
/// Price of the first update, or `None` while no update has been applied.
pub open: Option<i64>,
/// Highest price seen so far; starts at `i64::MIN`.
pub high: i64,
/// Lowest price seen so far; starts at `i64::MAX`.
pub low: i64,
/// Price of the most recent update; starts at 0.
pub close: i64,
/// Running sum of the update volumes.
pub volume_sum: i64,
/// Running sum of `price * volume`, i.e. the numerator of the VWAP.
pub vwap_numerator: i64,
/// Number of updates applied so far.
pub tick_count: u32,
}
impl BarBuilder {
    /// Creates an empty builder for the interval `start_time..end_time`.
    ///
    /// `high` and `low` start at `i64::MIN` / `i64::MAX` so that the first
    /// update always replaces them.
    pub fn new(start_time: i64, end_time: i64) -> Self {
        BarBuilder {
            start_time,
            end_time,
            open: None,
            high: i64::MIN,
            low: i64::MAX,
            close: 0,
            volume_sum: 0,
            vwap_numerator: 0,
            tick_count: 0,
        }
    }

    /// Folds one trade (`price`, `volume`) into the bar.
    ///
    /// The first update fixes `open`; every update refreshes `high`, `low` and
    /// `close` and adds to the volume, VWAP numerator and tick count.
    ///
    /// The running sums use plain `i64`/`u32` arithmetic, so extreme inputs
    /// overflow according to the build profile's overflow checks.
    #[inline(always)]
    pub fn update(&mut self, price: i64, volume: i64) {
        if self.open.is_none() {
            self.open = Some(price);
        }
        self.high = self.high.max(price);
        self.low = self.low.min(price);
        self.close = price;
        self.volume_sum += volume;
        self.vwap_numerator += price * volume;
        self.tick_count += 1;
    }

    /// Produces the [`Bar`] for everything accumulated so far.
    ///
    /// The builder is left untouched, so it can keep accumulating afterwards.
    /// `open` falls back to `close` when no update has set it, and the VWAP
    /// falls back to `close` when the accumulated volume is not positive.
    ///
    /// If no price has been recorded, `high`/`low` still hold their inverted
    /// `i64::MIN`/`i64::MAX` sentinels. Those are never emitted: the bar is
    /// reported flat at `close` instead.
    pub fn build(&self) -> Bar {
        let open = self.open.unwrap_or(self.close);

        // After any update `high >= low`; an inverted range means the
        // sentinels are still in place and must not leak into the bar.
        let (high, low) = if self.high >= self.low {
            (self.high, self.low)
        } else {
            (self.close, self.close)
        };

        let vwap = if self.volume_sum > 0 {
            self.vwap_numerator / self.volume_sum
        } else {
            self.close
        };

        Bar {
            timestamp_nanos: self.start_time,
            open,
            high,
            low,
            close: self.close,
            volume: self.volume_sum,
            tick_count: self.tick_count,
            vwap,
        }
    }

    /// Returns `true` while no update has been applied.
    pub fn is_empty(&self) -> bool {
        self.tick_count == 0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One minute in nanoseconds: the bar interval used by every test below.
    const MINUTE_NANOS: i64 = 60_000_000_000;

    /// Two updates produce the expected open/high/low/close, totals and VWAP.
    #[test]
    fn test_bar_builder_basic() {
        let mut builder = BarBuilder::new(0, MINUTE_NANOS);
        builder.update(100, 1000);
        builder.update(200, 500);

        let Bar {
            open,
            high,
            low,
            close,
            volume,
            tick_count,
            vwap,
            ..
        } = builder.build();

        assert_eq!(open, 100);
        assert_eq!(high, 200);
        assert_eq!(low, 100);
        assert_eq!(close, 200);
        assert_eq!(volume, 1500);
        assert_eq!(tick_count, 2);
        assert_eq!(vwap, (100 * 1000 + 200 * 500) / 1500);
    }

    /// A freshly created builder reports itself as empty.
    #[test]
    fn test_bar_builder_empty() {
        assert!(BarBuilder::new(0, MINUTE_NANOS).is_empty());
    }

    /// With a single update every price field collapses to that price.
    #[test]
    fn test_bar_builder_single_tick() {
        let mut builder = BarBuilder::new(0, MINUTE_NANOS);
        builder.update(150, 2000);

        let built = builder.build();
        assert_eq!(built.open, 150);
        assert_eq!(built.close, 150);
        assert_eq!(built.high, 150);
        assert_eq!(built.low, 150);
        assert_eq!(built.vwap, 150);
    }

    /// Pushing grows the slice; the accessors echo the constructor arguments.
    #[test]
    fn test_bar_series_push_and_slice() {
        let mut series = BarSeries::new("AAPL", MINUTE_NANOS);
        assert_eq!(series.as_slice().len(), 0);

        series.push(Bar {
            timestamp_nanos: 0,
            open: 100,
            high: 110,
            low: 90,
            close: 105,
            volume: 5000,
            tick_count: 10,
            vwap: 102,
        });

        assert_eq!(series.as_slice().len(), 1);
        assert_eq!(series.symbol(), "AAPL");
        assert_eq!(series.interval_nanos(), MINUTE_NANOS);
    }

    /// Four one-minute bars collapse into two two-minute bars.
    #[test]
    fn test_bar_series_resample() {
        let mut series = BarSeries::new("TEST", MINUTE_NANOS);
        for step in 0..4 {
            series.push(Bar {
                timestamp_nanos: step * MINUTE_NANOS,
                open: 100 + step,
                high: 110 + step,
                low: 90 + step,
                close: 105 + step,
                volume: 1000,
                tick_count: 5,
                vwap: 102 + step,
            });
        }

        let resampled = series.resample(2 * MINUTE_NANOS).unwrap();
        let merged = resampled.as_slice();
        assert_eq!(merged.len(), 2);

        let head = &merged[0];
        assert_eq!(head.open, 100);
        assert_eq!(head.close, 106);
        assert_eq!(head.volume, 2000);
        assert_eq!(head.tick_count, 10);
    }

    /// A target interval that is not a multiple of the source interval is rejected.
    #[test]
    fn test_resample_invalid_interval() {
        let series = BarSeries::new("TEST", MINUTE_NANOS);
        assert!(series.resample(90_000_000_000).is_err());
    }

    /// The CSV output contains the bar's fields in declaration order.
    #[test]
    fn test_to_csv() {
        let mut series = BarSeries::new("TEST", MINUTE_NANOS);
        series.push(Bar {
            timestamp_nanos: 0,
            open: 100,
            high: 110,
            low: 90,
            close: 105,
            volume: 1000,
            tick_count: 5,
            vwap: 102,
        });

        let mut raw = Vec::new();
        {
            // Scope the writer so its borrow of `raw` ends before the buffer is read.
            let mut writer = csv::Writer::from_writer(&mut raw);
            series.to_csv(&mut writer).unwrap();
        }

        let text = String::from_utf8(raw).unwrap();
        assert!(text.contains("100,110,90,105,1000,5,102"));
    }
}