1use std::path::Path;
13
14use serde::Deserialize;
15
16use crate::error::{Error, Result};
17use wickra_core::Candle;
18
/// Column names that `validate_headers` requires to be present in the CSV
/// header row. The names match the fields of `DefaultRow`.
const REQUIRED_COLUMNS: [&str; 6] = ["timestamp", "open", "high", "low", "close", "volume"];
24
/// One CSV record in the default column layout.
///
/// `CandleReader::candles` deserializes each data row into this struct and
/// then converts it into a `Candle`. The field names are the same as the
/// entries of `REQUIRED_COLUMNS`.
#[derive(Debug, Clone, Deserialize)]
pub struct DefaultRow {
    /// Value of the `timestamp` column.
    // NOTE(review): the unit (seconds / milliseconds) is not visible here — confirm.
    pub timestamp: i64,
    /// Value of the `open` column.
    pub open: f64,
    /// Value of the `high` column.
    pub high: f64,
    /// Value of the `low` column.
    pub low: f64,
    /// Value of the `close` column.
    pub close: f64,
    /// Value of the `volume` column.
    pub volume: f64,
}
39
40impl DefaultRow {
41 fn into_candle(self) -> Result<Candle> {
42 Candle::new(
43 self.open,
44 self.high,
45 self.low,
46 self.close,
47 self.volume,
48 self.timestamp,
49 )
50 .map_err(Error::from)
51 }
52}
53
/// `Read` adapter that drops a leading UTF-8 byte-order mark (`EF BB BF`)
/// from the wrapped stream and passes every other byte through unchanged.
#[derive(Debug)]
pub struct BomStripReader<R> {
    /// Wrapped byte source.
    inner: R,
    /// `true` once the BOM probe has run to completion.
    checked: bool,
    /// Bytes consumed by the probe. While the probe is in progress this is the
    /// accumulation buffer; afterwards it holds the non-BOM bytes that still
    /// have to be replayed to the caller.
    leftover: Vec<u8>,
    /// Index of the next byte of `leftover` to hand out.
    leftover_pos: usize,
}

impl<R: std::io::Read> BomStripReader<R> {
    /// Wraps `inner`. Nothing is read from it until the first `read` call.
    pub fn new(inner: R) -> Self {
        Self {
            inner,
            checked: false,
            leftover: Vec::new(),
            leftover_pos: 0,
        }
    }

    /// Reads up to three bytes from the inner reader (once) and discards them
    /// if they are exactly the UTF-8 BOM; otherwise they stay in `leftover`
    /// to be replayed by `read`.
    ///
    /// The probed bytes are accumulated in `self.leftover` rather than in a
    /// local buffer, and `checked` is only set after the probe has finished.
    /// This way an error in the middle of the probe never loses bytes that
    /// were already consumed: `ErrorKind::Interrupted` is retried in place,
    /// and after any other error the next call resumes where this one stopped.
    ///
    /// # Errors
    /// Any non-`Interrupted` error returned by the inner reader.
    fn check_bom(&mut self) -> std::io::Result<()> {
        const UTF8_BOM: [u8; 3] = [0xEF, 0xBB, 0xBF];

        if self.checked {
            return Ok(());
        }

        let mut scratch = [0u8; 3];
        while self.leftover.len() < UTF8_BOM.len() {
            let wanted = UTF8_BOM.len() - self.leftover.len();
            match self.inner.read(&mut scratch[..wanted]) {
                // End of input before three bytes were available.
                Ok(0) => break,
                Ok(n) => self.leftover.extend_from_slice(&scratch[..n]),
                // Per the `Read` contract an interrupted read is simply retried.
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                Err(e) => return Err(e),
            }
        }

        self.checked = true;
        if self.leftover == UTF8_BOM {
            // A real BOM: drop it so the caller never sees it.
            self.leftover.clear();
        }
        Ok(())
    }
}

impl<R: std::io::Read> std::io::Read for BomStripReader<R> {
    /// Runs the BOM probe on first use, then serves any replayed probe bytes
    /// before delegating to the inner reader.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.check_bom()?;
        if self.leftover_pos < self.leftover.len() {
            let n = (self.leftover.len() - self.leftover_pos).min(buf.len());
            buf[..n].copy_from_slice(&self.leftover[self.leftover_pos..self.leftover_pos + n]);
            self.leftover_pos += n;
            return Ok(n);
        }
        self.inner.read(buf)
    }
}
122
123fn validate_headers<R: std::io::Read>(reader: &mut csv::Reader<R>) -> Result<()> {
125 let headers = reader.headers()?;
126 let present: Vec<String> = headers.iter().map(|h| h.trim().to_string()).collect();
127 let missing: Vec<&str> = REQUIRED_COLUMNS
128 .iter()
129 .copied()
130 .filter(|col| !present.iter().any(|h| h == col))
131 .collect();
132 if !missing.is_empty() {
133 return Err(Error::Malformed(format!(
134 "CSV header is missing required column(s) [{}]; found [{}] — \
135 the first line must be a header naming {}",
136 missing.join(", "),
137 present.join(", "),
138 REQUIRED_COLUMNS.join(",")
139 )));
140 }
141 Ok(())
142}
143
/// Reads `Candle`s from CSV input.
///
/// Every constructor in this file runs `validate_headers` before returning,
/// so a constructed value always wraps a reader whose header row names all
/// required columns.
#[derive(Debug)]
pub struct CandleReader<R: std::io::Read> {
    /// Underlying CSV reader.
    reader: csv::Reader<R>,
}
149
150impl<R: std::io::Read> CandleReader<R> {
151 fn build(inner: R) -> Result<Self> {
153 let mut reader = csv::ReaderBuilder::new()
154 .has_headers(true)
155 .trim(csv::Trim::All)
156 .from_reader(inner);
157 validate_headers(&mut reader)?;
158 Ok(Self { reader })
159 }
160}
161
162impl CandleReader<BomStripReader<std::fs::File>> {
163 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
173 let file = std::fs::File::open(path)?;
174 Self::from_reader(file)
175 }
176}
177
178impl<R: std::io::Read> CandleReader<BomStripReader<R>> {
179 pub fn from_reader(inner: R) -> Result<Self> {
187 Self::build(BomStripReader::new(inner))
188 }
189}
190
191impl<R: std::io::Read> CandleReader<R> {
192 pub fn from_csv_reader(mut reader: csv::Reader<R>) -> Result<Self> {
203 validate_headers(&mut reader)?;
204 Ok(Self { reader })
205 }
206
207 pub fn candles(&mut self) -> impl Iterator<Item = Result<Candle>> + '_ {
209 self.reader.deserialize::<DefaultRow>().map(|row_res| {
210 let row = row_res?;
211 row.into_candle()
212 })
213 }
214
215 pub fn read_all(&mut self) -> Result<Vec<Candle>> {
217 self.candles().collect()
218 }
219}
220
// Unit tests for `CandleReader` and `BomStripReader`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    // Happy path: a header plus three rows read from a temp file yields three candles.
    #[test]
    fn reads_well_formed_csv() {
        let mut tmp = tempfile::NamedTempFile::new().unwrap();
        writeln!(tmp, "timestamp,open,high,low,close,volume").unwrap();
        writeln!(tmp, "1,10.0,11.0,9.0,10.5,100").unwrap();
        writeln!(tmp, "2,10.5,11.5,10.0,11.0,150").unwrap();
        writeln!(tmp, "3,11.0,12.0,10.5,11.5,200").unwrap();
        tmp.flush().unwrap();

        let mut r = CandleReader::open(tmp.path()).unwrap();
        let candles = r.read_all().unwrap();
        assert_eq!(candles.len(), 3);
        assert_eq!(candles[0].open, 10.0);
        assert_eq!(candles[2].close, 11.5);
        assert_eq!(candles[1].timestamp, 2);
    }

    // The single data row has high (8.0) below low (9.0); collecting must fail.
    // NOTE(review): presumably the rejection comes from `Candle::new` — confirm.
    #[test]
    fn rejects_invalid_ohlc() {
        let mut tmp = tempfile::NamedTempFile::new().unwrap();
        writeln!(tmp, "timestamp,open,high,low,close,volume").unwrap();
        writeln!(tmp, "1,10.0,8.0,9.0,9.5,100").unwrap();
        tmp.flush().unwrap();

        let mut r = CandleReader::open(tmp.path()).unwrap();
        let candles: Result<Vec<Candle>> = r.candles().collect();
        assert!(candles.is_err());
    }

    // `from_reader` accepts any `Read`, here an in-memory byte slice.
    #[test]
    fn from_reader_works_on_in_memory_data() {
        let data = "timestamp,open,high,low,close,volume\n1,1,2,0,1,10\n2,1,2,0,1,10\n";
        let mut r = CandleReader::from_reader(data.as_bytes()).unwrap();
        let v = r.read_all().unwrap();
        assert_eq!(v.len(), 2);
    }

    // Without a header line the first data row is taken as the header, so
    // header validation must fail with `Malformed`.
    #[test]
    fn rejects_file_without_header() {
        let data = "1,10.0,11.0,9.0,10.5,100\n2,10.5,11.5,10.0,11.0,150\n";
        let err = CandleReader::from_reader(data.as_bytes()).unwrap_err();
        assert!(matches!(err, Error::Malformed(_)));
    }

    // A header lacking `volume` must be rejected, and the message must name the column.
    #[test]
    fn rejects_header_missing_a_column() {
        let data = "timestamp,open,high,low,close\n1,10.0,11.0,9.0,10.5\n";
        let err = CandleReader::from_reader(data.as_bytes()).unwrap_err();
        assert!(
            matches!(&err, Error::Malformed(msg) if msg.contains("volume")),
            "expected Malformed mentioning 'volume', got {err:?}"
        );
    }

    // A caller-built `csv::Reader` (semicolon-delimited here) is accepted as-is.
    #[test]
    fn from_csv_reader_accepts_a_prebuilt_reader() {
        let data = "timestamp;open;high;low;close;volume\n1;10.0;11.0;9.0;10.5;100\n";
        let inner = csv::ReaderBuilder::new()
            .delimiter(b';')
            .from_reader(data.as_bytes());
        let mut r = CandleReader::from_csv_reader(inner).unwrap();
        let candles = r.read_all().unwrap();
        assert_eq!(candles.len(), 1);
        assert_eq!(candles[0].close, 10.5);
    }

    // A leading U+FEFF (UTF-8 BOM) must not break header validation or parsing.
    #[test]
    fn strips_leading_utf8_bom() {
        let data = "\u{feff}timestamp,open,high,low,close,volume\n1,10.0,11.0,9.0,10.5,100\n";
        let mut r = CandleReader::from_reader(data.as_bytes()).unwrap();
        let v = r.read_all().unwrap();
        assert_eq!(v.len(), 1);
        assert_eq!(v[0].timestamp, 1);
        assert_eq!(v[0].open, 10.0);
    }

    // Whitespace around header names and field values is trimmed by the reader.
    #[test]
    fn tolerates_whitespace_around_fields() {
        let data = " timestamp , open , high , low , close , volume \n\
                    1 , 10.0 , 11.0 , 9.0 , 10.5 , 100 \n";
        let mut r = CandleReader::from_reader(data.as_bytes()).unwrap();
        let v = r.read_all().unwrap();
        assert_eq!(v.len(), 1);
        assert_eq!(v[0].close, 10.5);
        assert_eq!(v[0].volume, 100.0);
    }

    // Input that does not start with a BOM comes out byte-for-byte unchanged.
    #[test]
    fn bom_stripper_passes_through_non_bom_input() {
        use std::io::Read;
        let mut out = String::new();
        BomStripReader::new("hello".as_bytes())
            .read_to_string(&mut out)
            .unwrap();
        assert_eq!(out, "hello");
    }

    // Input shorter than the three-byte BOM is replayed in full.
    #[test]
    fn bom_stripper_handles_short_input() {
        use std::io::Read;
        let mut out = Vec::new();
        BomStripReader::new([0x41u8, 0x42u8].as_slice())
            .read_to_end(&mut out)
            .unwrap();
        assert_eq!(out, vec![0x41, 0x42]);
    }
}