use std::path::Path;
use serde::Deserialize;
use crate::error::{Error, Result};
use wickra_core::Candle;
/// Header names every input CSV must contain. They are matched against the
/// trimmed header cells, exactly and case-sensitively; extra columns are allowed.
const REQUIRED_COLUMNS: [&str; 6] = ["timestamp", "open", "high", "low", "close", "volume"];
/// One data row of the input CSV, deserialized with serde.
///
/// The field names are the same strings as `REQUIRED_COLUMNS`, which the
/// header validation checks for before any row is read.
#[derive(Debug, Clone, Deserialize)]
pub struct DefaultRow {
    /// Row timestamp, handed unchanged to `Candle::new`
    /// (units are not fixed here — presumably whatever the data source uses).
    pub timestamp: i64,
    /// Opening price.
    pub open: f64,
    /// Highest price.
    pub high: f64,
    /// Lowest price.
    pub low: f64,
    /// Closing price.
    pub close: f64,
    /// Traded volume.
    pub volume: f64,
}
impl DefaultRow {
fn into_candle(self) -> Result<Candle> {
Candle::new(
self.open,
self.high,
self.low,
self.close,
self.volume,
self.timestamp,
)
.map_err(Error::from)
}
}
/// A [`std::io::Read`] adapter that drops a leading UTF-8 byte-order mark.
///
/// On the first read it probes the start of `inner`. If the stream begins with
/// exactly `EF BB BF`, those three bytes are discarded. Otherwise every probed
/// byte is replayed to the caller before any further data from `inner`.
#[derive(Debug)]
pub struct BomStripReader<R> {
    /// The wrapped byte source.
    inner: R,
    /// Set once BOM detection has completed successfully.
    checked: bool,
    /// Bytes read while probing that must be handed back to the caller.
    /// Emptied when they turned out to be a BOM.
    leftover: Vec<u8>,
    /// How many bytes of `leftover` have already been returned.
    leftover_pos: usize,
}

impl<R: std::io::Read> BomStripReader<R> {
    /// The UTF-8 encoding of U+FEFF.
    const UTF8_BOM: [u8; 3] = [0xEF, 0xBB, 0xBF];

    /// Wraps `inner`; nothing is read until the first call to `read`.
    pub fn new(inner: R) -> Self {
        Self {
            inner,
            checked: false,
            leftover: Vec::new(),
            leftover_pos: 0,
        }
    }

    /// Reads just enough of `inner` to decide whether the stream starts with a BOM.
    ///
    /// Probed bytes are accumulated directly in `self.leftover`, and `checked`
    /// is only set once detection finishes. An I/O error mid-probe therefore
    /// loses no data: a retried `read` resumes probing where it stopped.
    /// `ErrorKind::Interrupted` is retried here, as the `Read` contract expects.
    /// Probing stops early at EOF or as soon as the bytes seen so far cannot be
    /// a BOM, so a non-BOM stream never waits on an extra read.
    fn check_bom(&mut self) -> std::io::Result<()> {
        if self.checked {
            return Ok(());
        }
        let bom = Self::UTF8_BOM;
        let mut probe = [0u8; 3];
        while self.leftover.len() < bom.len() && bom.starts_with(&self.leftover) {
            let want = bom.len() - self.leftover.len();
            match self.inner.read(&mut probe[..want]) {
                Ok(0) => break,
                Ok(n) => self.leftover.extend_from_slice(&probe[..n]),
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                Err(e) => return Err(e),
            }
        }
        self.checked = true;
        if self.leftover == bom {
            self.leftover.clear();
        }
        Ok(())
    }
}

impl<R: std::io::Read> std::io::Read for BomStripReader<R> {
    /// Replays any probed non-BOM bytes first, then delegates to `inner`.
    ///
    /// While probed bytes remain, a call returns only those bytes (a short
    /// read). It never mixes them with fresh data from `inner`.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.check_bom()?;
        let pending = &self.leftover[self.leftover_pos..];
        if pending.is_empty() {
            return self.inner.read(buf);
        }
        let n = pending.len().min(buf.len());
        buf[..n].copy_from_slice(&pending[..n]);
        self.leftover_pos += n;
        Ok(n)
    }
}
fn validate_headers<R: std::io::Read>(reader: &mut csv::Reader<R>) -> Result<()> {
let headers = reader.headers()?;
let present: Vec<String> = headers.iter().map(|h| h.trim().to_string()).collect();
let missing: Vec<&str> = REQUIRED_COLUMNS
.iter()
.copied()
.filter(|col| !present.iter().any(|h| h == col))
.collect();
if !missing.is_empty() {
return Err(Error::Malformed(format!(
"CSV header is missing required column(s) [{}]; found [{}] — \
the first line must be a header naming {}",
missing.join(", "),
present.join(", "),
REQUIRED_COLUMNS.join(",")
)));
}
Ok(())
}
/// Streaming reader that turns a headered CSV source into [`Candle`]s.
///
/// Construct it with [`CandleReader::open`], [`CandleReader::from_reader`] or
/// [`CandleReader::from_csv_reader`]. Every constructor validates the header
/// row against `REQUIRED_COLUMNS` before returning.
#[derive(Debug)]
pub struct CandleReader<R: std::io::Read> {
    /// Underlying CSV reader whose header row has already been validated.
    reader: csv::Reader<R>,
}
impl<R: std::io::Read> CandleReader<R> {
    /// Wraps `inner` in a CSV reader that treats the first line as a header
    /// and trims whitespace around every header and field, then validates the
    /// header.
    ///
    /// # Errors
    /// Fails if the header row cannot be read or lacks a required column.
    fn build(inner: R) -> Result<Self> {
        let mut builder = csv::ReaderBuilder::new();
        builder.has_headers(true).trim(csv::Trim::All);
        let mut reader = builder.from_reader(inner);
        validate_headers(&mut reader).map(|()| Self { reader })
    }
}
impl CandleReader<BomStripReader<std::fs::File>> {
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
let file = std::fs::File::open(path)?;
Self::from_reader(file)
}
}
impl<R: std::io::Read> CandleReader<BomStripReader<R>> {
    /// Builds a reader over any byte source.
    ///
    /// A leading UTF-8 BOM is stripped before the data reaches the CSV parser,
    /// so it cannot corrupt the first header name.
    ///
    /// # Errors
    /// Fails if the header row cannot be read or lacks a required column.
    pub fn from_reader(inner: R) -> Result<Self> {
        let without_bom = BomStripReader::new(inner);
        Self::build(without_bom)
    }
}
impl<R: std::io::Read> CandleReader<R> {
pub fn from_csv_reader(mut reader: csv::Reader<R>) -> Result<Self> {
validate_headers(&mut reader)?;
Ok(Self { reader })
}
pub fn candles(&mut self) -> impl Iterator<Item = Result<Candle>> + '_ {
self.reader.deserialize::<DefaultRow>().map(|row_res| {
let row = row_res?;
row.into_candle()
})
}
pub fn read_all(&mut self) -> Result<Vec<Candle>> {
self.candles().collect()
}
}
#[cfg(test)]
mod tests {
    //! Tests for the CSV candle reader and the BOM-stripping adapter.
    use super::*;
    use std::io::Write;

    /// Runs `input` through a fresh `BomStripReader` and returns every byte it yields.
    fn strip_all(input: &[u8]) -> Vec<u8> {
        use std::io::Read;
        let mut out = Vec::new();
        BomStripReader::new(input).read_to_end(&mut out).unwrap();
        out
    }

    #[test]
    fn reads_well_formed_csv() {
        let mut tmp = tempfile::NamedTempFile::new().unwrap();
        writeln!(tmp, "timestamp,open,high,low,close,volume").unwrap();
        writeln!(tmp, "1,10.0,11.0,9.0,10.5,100").unwrap();
        writeln!(tmp, "2,10.5,11.5,10.0,11.0,150").unwrap();
        writeln!(tmp, "3,11.0,12.0,10.5,11.5,200").unwrap();
        tmp.flush().unwrap();
        let mut r = CandleReader::open(tmp.path()).unwrap();
        let candles = r.read_all().unwrap();
        assert_eq!(candles.len(), 3);
        assert_eq!(candles[0].open, 10.0);
        assert_eq!(candles[2].close, 11.5);
        assert_eq!(candles[1].timestamp, 2);
    }

    #[test]
    fn rejects_invalid_ohlc() {
        let mut tmp = tempfile::NamedTempFile::new().unwrap();
        writeln!(tmp, "timestamp,open,high,low,close,volume").unwrap();
        writeln!(tmp, "1,10.0,8.0,9.0,9.5,100").unwrap();
        tmp.flush().unwrap();
        let mut r = CandleReader::open(tmp.path()).unwrap();
        let candles: Result<Vec<Candle>> = r.candles().collect();
        assert!(candles.is_err());
    }

    #[test]
    fn from_reader_works_on_in_memory_data() {
        let data = "timestamp,open,high,low,close,volume\n1,1,2,0,1,10\n2,1,2,0,1,10\n";
        let mut r = CandleReader::from_reader(data.as_bytes()).unwrap();
        let v = r.read_all().unwrap();
        assert_eq!(v.len(), 2);
    }

    #[test]
    fn rejects_file_without_header() {
        let data = "1,10.0,11.0,9.0,10.5,100\n2,10.5,11.5,10.0,11.0,150\n";
        let err = CandleReader::from_reader(data.as_bytes()).unwrap_err();
        assert!(matches!(err, Error::Malformed(_)));
    }

    #[test]
    fn rejects_empty_input() {
        let err = CandleReader::from_reader("".as_bytes()).unwrap_err();
        assert!(matches!(err, Error::Malformed(_)));
    }

    #[test]
    fn rejects_header_missing_a_column() {
        let data = "timestamp,open,high,low,close\n1,10.0,11.0,9.0,10.5\n";
        let err = CandleReader::from_reader(data.as_bytes()).unwrap_err();
        match err {
            Error::Malformed(msg) => assert!(msg.contains("volume"), "msg: {msg}"),
            other => panic!("expected Malformed, got {other:?}"),
        }
    }

    #[test]
    fn from_csv_reader_accepts_valid_header() {
        let data = "timestamp,open,high,low,close,volume\n1,1,2,0,1,10\n";
        let rdr = csv::Reader::from_reader(data.as_bytes());
        let mut r = CandleReader::from_csv_reader(rdr).unwrap();
        assert_eq!(r.read_all().unwrap().len(), 1);
    }

    #[test]
    fn from_csv_reader_rejects_missing_column() {
        let data = "timestamp,open,high,low,close\n1,1,2,0,1\n";
        let rdr = csv::Reader::from_reader(data.as_bytes());
        let err = CandleReader::from_csv_reader(rdr).unwrap_err();
        assert!(matches!(err, Error::Malformed(_)));
    }

    #[test]
    fn strips_leading_utf8_bom() {
        let data = "\u{feff}timestamp,open,high,low,close,volume\n1,10.0,11.0,9.0,10.5,100\n";
        let mut r = CandleReader::from_reader(data.as_bytes()).unwrap();
        let v = r.read_all().unwrap();
        assert_eq!(v.len(), 1);
        assert_eq!(v[0].timestamp, 1);
        assert_eq!(v[0].open, 10.0);
    }

    #[test]
    fn tolerates_whitespace_around_fields() {
        let data = " timestamp , open , high , low , close , volume \n\
                    1 , 10.0 , 11.0 , 9.0 , 10.5 , 100 \n";
        let mut r = CandleReader::from_reader(data.as_bytes()).unwrap();
        let v = r.read_all().unwrap();
        assert_eq!(v.len(), 1);
        assert_eq!(v[0].close, 10.5);
        assert_eq!(v[0].volume, 100.0);
    }

    #[test]
    fn bom_stripper_passes_through_non_bom_input() {
        use std::io::Read;
        let mut out = String::new();
        BomStripReader::new("hello".as_bytes())
            .read_to_string(&mut out)
            .unwrap();
        assert_eq!(out, "hello");
    }

    #[test]
    fn bom_stripper_handles_short_input() {
        use std::io::Read;
        let mut out = Vec::new();
        BomStripReader::new([0x41u8, 0x42u8].as_slice())
            .read_to_end(&mut out)
            .unwrap();
        assert_eq!(out, vec![0x41, 0x42]);
    }

    #[test]
    fn bom_stripper_keeps_partial_bom_prefix() {
        // Two BOM bytes followed by something else are data, not a BOM.
        assert_eq!(strip_all(&[0xEF, 0xBB, 0x41]), vec![0xEF, 0xBB, 0x41]);
    }

    #[test]
    fn bom_stripper_drops_bom_only_input() {
        assert_eq!(strip_all(&[0xEF, 0xBB, 0xBF]), Vec::<u8>::new());
    }

    #[test]
    fn bom_stripper_handles_empty_input() {
        assert_eq!(strip_all(b""), Vec::<u8>::new());
    }
}