use crate::column::Column;
use crate::dataframe::{DataFrame, DataError};
/// Options shared by the CSV readers in this module.
#[derive(Debug, Clone)]
pub struct CsvConfig {
    /// Byte that separates adjacent fields within a row.
    pub delimiter: u8,
    /// Whether the first non-blank row holds column names rather than data.
    pub has_header: bool,
    /// Maximum number of data rows to consume; `None` places no limit.
    pub max_rows: Option<usize>,
    /// Whether readers strip leading and trailing whitespace from fields.
    pub trim_whitespace: bool,
}

impl Default for CsvConfig {
    /// Comma-delimited input with a header row, no row limit, and trimming on.
    fn default() -> Self {
        Self {
            delimiter: b',',
            has_header: true,
            max_rows: None,
            trim_whitespace: true,
        }
    }
}
/// Reader that turns an in-memory byte buffer of delimited text into a `DataFrame`.
pub struct CsvReader {
    /// Parsing options supplied at construction time.
    config: CsvConfig,
}
/// Column type guessed from the textual form of a cell.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InferredType {
    /// Optionally negative decimal integer that fits in an `i64`.
    Int,
    /// Any other text accepted by `f64::from_str`.
    Float,
    /// One of `true`, `false`, `1`, `0`.
    Bool,
    /// Anything else, including the empty string.
    Str,
}

/// Classifies a single cell by its textual form.
///
/// Surrounding whitespace is ignored. Checks run from most to least specific:
///
/// 1. `true`, `false`, `1` and `0` are [`InferredType::Bool`].
/// 2. An optional `-` followed only by ASCII digits is [`InferredType::Int`],
///    provided the value fits in an `i64`.
/// 3. Anything else that `f64::from_str` accepts is [`InferredType::Float`].
///    This includes exponent notation and the words `inf`, `infinity`, `nan`.
/// 4. Everything else is [`InferredType::Str`].
fn infer_type(s: &str) -> InferredType {
    let t = s.trim();
    if matches!(t, "true" | "false" | "1" | "0") {
        return InferredType::Bool;
    }

    let digits = t.strip_prefix('-').unwrap_or(t);
    let looks_integral = !digits.is_empty() && digits.bytes().all(|b| b.is_ascii_digit());
    // A digit string outside the i64 range must not be reported as Int: callers
    // would fail to parse it and lose the value. It falls through to Float.
    if looks_integral && t.parse::<i64>().is_ok() {
        return InferredType::Int;
    }

    // `f64::from_str` already accepts plain decimals such as "1.5", ".5" and
    // "5.", so no separate hand-rolled decimal-point check is needed.
    if t.parse::<f64>().is_ok() {
        return InferredType::Float;
    }
    InferredType::Str
}
/// Splits one row into its fields on every occurrence of `delimiter`.
///
/// A trailing `\r` is removed from the final field only, so rows cut from
/// `\r\n`-terminated input come out clean. A field that is not valid UTF-8 is
/// returned as the empty string. The result always holds at least one field.
fn split_fields<'a>(row: &'a [u8], delimiter: u8) -> Vec<&'a str> {
    let mut pieces = row.split(|&byte| byte == delimiter).peekable();
    let mut out = Vec::new();
    while let Some(piece) = pieces.next() {
        // Only the last piece can carry the carriage return of a CRLF line end.
        let piece = if pieces.peek().is_some() {
            piece
        } else {
            piece.strip_suffix(b"\r").unwrap_or(piece)
        };
        out.push(std::str::from_utf8(piece).unwrap_or(""));
    }
    out
}
impl CsvReader {
    /// Creates a reader that parses input according to `config`.
    pub fn new(config: CsvConfig) -> Self {
        CsvReader { config }
    }

    /// Parses delimited text into a `DataFrame` with one typed column per header field.
    ///
    /// # Input handling
    ///
    /// * Rows are separated by `\n`; empty rows and rows holding only `\r` are skipped.
    /// * With `has_header`, the first remaining row supplies the column names.
    ///   Otherwise columns are named `col_0`, `col_1`, ... after the width of the
    ///   first row, and that row is treated as data.
    /// * At most `max_rows` data rows are read when a limit is configured.
    /// * A row with fewer fields than the header is padded with empty cells;
    ///   fields beyond the header width are ignored.
    ///
    /// # Type inference
    ///
    /// Each column's type is chosen from all of its non-empty cells, so a value
    /// further down the file can never be coerced into a type picked from the
    /// first row alone:
    ///
    /// * only `true`/`false`/`1`/`0` cells give a `Bool` column;
    /// * integers (optionally mixed with `0`/`1`) give `Int`;
    /// * any float among numeric cells gives `Float`;
    /// * any other mixture, or a column with no non-empty cell, gives `Str`.
    ///
    /// Empty cells in typed columns become `0`, `0.0` or `false`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `DataFrame::from_columns` reports for the
    /// assembled columns.
    pub fn parse(&self, input: &[u8]) -> Result<DataFrame, DataError> {
        let mut rows = input
            .split(|&b| b == b'\n')
            .filter(|r| !r.is_empty() && *r != b"\r")
            .peekable();
        let first = match rows.peek() {
            Some(&row) => row,
            None => return Ok(DataFrame::new()),
        };

        let delim = self.config.delimiter;
        let header_names: Vec<String> = if self.config.has_header {
            // The header row is not data; drop it from the stream.
            rows.next();
            split_fields(first, delim)
                .into_iter()
                .map(|name| self.normalize(name).to_string())
                .collect()
        } else {
            let width = split_fields(first, delim).len();
            (0..width).map(|i| format!("col_{}", i)).collect()
        };
        let ncols = header_names.len();
        if ncols == 0 {
            return Ok(DataFrame::new());
        }

        // Stage cells column-major so each column can be typed from all of its
        // values before any conversion takes place.
        let limit = self.config.max_rows.unwrap_or(usize::MAX);
        let mut cells: Vec<Vec<&str>> = vec![Vec::new(); ncols];
        for row in rows.take(limit) {
            let fields = split_fields(row, delim);
            // Iterating over the header-sized staging area, not over `fields`,
            // pads short rows and discards surplus fields. The previous code
            // indexed per-column buffers by field position and panicked when
            // the first data row was wider than the header.
            for (idx, column) in cells.iter_mut().enumerate() {
                let raw = fields.get(idx).copied().unwrap_or("");
                column.push(self.normalize(raw));
            }
        }

        let columns: Vec<(String, Column)> = header_names
            .into_iter()
            .zip(cells)
            .map(|(name, values)| (name, Self::build_column(&values)))
            .collect();
        DataFrame::from_columns(columns)
    }

    /// Applies the configured whitespace policy to a single field.
    fn normalize<'s>(&self, field: &'s str) -> &'s str {
        if self.config.trim_whitespace {
            field.trim()
        } else {
            field
        }
    }

    /// Picks the narrowest type able to represent every non-empty cell of a column.
    fn resolve_type(values: &[&str]) -> InferredType {
        let mut word_bools = false; // saw "true" / "false"
        let mut bit_bools = false; // saw "0" / "1", which are also valid integers
        let mut ints = false;
        let mut floats = false;

        for cell in values.iter().map(|v| v.trim()).filter(|v| !v.is_empty()) {
            match infer_type(cell) {
                InferredType::Bool if matches!(cell, "0" | "1") => bit_bools = true,
                InferredType::Bool => word_bools = true,
                InferredType::Int if cell.parse::<i64>().is_ok() => ints = true,
                // An integer literal too large for i64 is kept as a float
                // instead of being zeroed.
                InferredType::Int | InferredType::Float => floats = true,
                InferredType::Str => return InferredType::Str,
            }
        }

        let numeric = ints || floats;
        if numeric && word_bools {
            // "true" next to "5" has no common typed representation.
            InferredType::Str
        } else if floats {
            InferredType::Float
        } else if ints {
            InferredType::Int
        } else if word_bools || bit_bools {
            InferredType::Bool
        } else {
            // No non-empty cell at all.
            InferredType::Str
        }
    }

    /// Converts the staged cells of one column into a typed `Column`.
    ///
    /// Numeric and boolean cells are trimmed before conversion even when
    /// `trim_whitespace` is off, matching how `infer_type` classifies them.
    /// String cells are stored exactly as staged.
    fn build_column(values: &[&str]) -> Column {
        match Self::resolve_type(values) {
            InferredType::Int => Column::Int(
                values
                    .iter()
                    .map(|v| v.trim().parse::<i64>().unwrap_or(0))
                    .collect(),
            ),
            InferredType::Float => Column::Float(
                values
                    .iter()
                    .map(|v| v.trim().parse::<f64>().unwrap_or(0.0))
                    .collect(),
            ),
            InferredType::Bool => Column::Bool(
                values
                    .iter()
                    .map(|v| matches!(v.trim(), "true" | "1"))
                    .collect(),
            ),
            InferredType::Str => Column::Str(values.iter().map(|&v| v.to_string()).collect()),
        }
    }
}
/// Computes per-column aggregates over delimited text without building a `DataFrame`.
pub struct StreamingCsvProcessor {
    /// Parsing options supplied at construction time.
    config: CsvConfig,
}
impl StreamingCsvProcessor {
    /// Creates a processor that reads input according to `config`.
    pub fn new(config: CsvConfig) -> Self {
        StreamingCsvProcessor { config }
    }

    /// Sums every column, treating each cell as an `f64`.
    ///
    /// Returns `(column_names, sums, row_count)`. Cells that are missing or do
    /// not parse as a number contribute `0.0`. Cells are trimmed before
    /// parsing regardless of `trim_whitespace`, because surrounding whitespace
    /// carries no meaning in a number. Input without any non-blank row yields
    /// empty vectors and a count of zero.
    ///
    /// The sums use compensated (Kahan) summation to limit rounding error.
    pub fn sum_columns(&self, input: &[u8]) -> Result<(Vec<String>, Vec<f64>, usize), DataError> {
        let (header_names, data_rows) = match self.header_and_rows(input) {
            Some(parts) => parts,
            None => return Ok((vec![], vec![], 0)),
        };
        let delim = self.config.delimiter;
        let ncols = header_names.len();
        let mut sums: Vec<f64> = vec![0.0; ncols];
        // Running compensation term per column for Kahan summation.
        let mut comp: Vec<f64> = vec![0.0; ncols];
        let mut row_count = 0usize;

        for row_bytes in data_rows {
            let fields = split_fields(row_bytes, delim);
            for (col_idx, (sum, c)) in sums.iter_mut().zip(comp.iter_mut()).enumerate() {
                let v = fields
                    .get(col_idx)
                    .and_then(|f| f.trim().parse::<f64>().ok())
                    .unwrap_or(0.0);
                let y = v - *c;
                let t = *sum + y;
                *c = (t - *sum) - y;
                *sum = t;
            }
            row_count += 1;
        }
        Ok((header_names, sums, row_count))
    }

    /// Finds the minimum and maximum numeric value of every column.
    ///
    /// Returns `(column_names, mins, maxs, row_count)`. Cells that are missing
    /// or do not parse as an `f64` are skipped, and NaN values never replace a
    /// bound. A column without any usable value keeps the initial bounds
    /// `f64::INFINITY` (min) and `f64::NEG_INFINITY` (max). Input without any
    /// non-blank row yields empty vectors and a count of zero.
    pub fn minmax_columns(
        &self,
        input: &[u8],
    ) -> Result<(Vec<String>, Vec<f64>, Vec<f64>, usize), DataError> {
        let (header_names, data_rows) = match self.header_and_rows(input) {
            Some(parts) => parts,
            None => return Ok((vec![], vec![], vec![], 0)),
        };
        let delim = self.config.delimiter;
        let ncols = header_names.len();
        let mut mins: Vec<f64> = vec![f64::INFINITY; ncols];
        let mut maxs: Vec<f64> = vec![f64::NEG_INFINITY; ncols];
        let mut row_count = 0usize;

        for row_bytes in data_rows {
            let fields = split_fields(row_bytes, delim);
            for (col_idx, (lo, hi)) in mins.iter_mut().zip(maxs.iter_mut()).enumerate() {
                let parsed = fields
                    .get(col_idx)
                    .and_then(|f| f.trim().parse::<f64>().ok());
                if let Some(v) = parsed {
                    // `f64::min` / `f64::max` return the non-NaN operand, so a
                    // NaN cell leaves the bounds untouched.
                    *lo = lo.min(v);
                    *hi = hi.max(v);
                }
            }
            row_count += 1;
        }
        Ok((header_names, mins, maxs, row_count))
    }

    /// Splits `input` into column names and a lazy iterator over its data rows.
    ///
    /// Rows are separated by `\n`; empty rows and rows holding only `\r` are
    /// skipped. Returns `None` when no row remains. With `has_header` the
    /// first row supplies the names (trimmed when `trim_whitespace` is set);
    /// otherwise names are `col_0`, `col_1`, ... and the first row is data.
    /// The iterator yields at most `max_rows` rows when a limit is configured.
    fn header_and_rows<'a>(
        &self,
        input: &'a [u8],
    ) -> Option<(Vec<String>, impl Iterator<Item = &'a [u8]>)> {
        let mut lines = input
            .split(|&b| b == b'\n')
            .filter(|r| !r.is_empty() && *r != b"\r")
            .peekable();
        let first = *lines.peek()?;
        let delim = self.config.delimiter;

        let names: Vec<String> = if self.config.has_header {
            // The header row is not data; drop it from the stream.
            lines.next();
            let trim = self.config.trim_whitespace;
            split_fields(first, delim)
                .into_iter()
                .map(|name| {
                    let name = if trim { name.trim() } else { name };
                    name.to_string()
                })
                .collect()
        } else {
            let ncols = split_fields(first, delim).len();
            (0..ncols).map(|i| format!("col_{}", i)).collect()
        };

        let limit = self.config.max_rows.unwrap_or(usize::MAX);
        Some((names, lines.take(limit)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `csv` with `config`, panicking if parsing fails.
    fn read_with(config: CsvConfig, csv: &[u8]) -> DataFrame {
        CsvReader::new(config).parse(csv).unwrap()
    }

    /// Parses `csv` with the default configuration.
    fn read(csv: &[u8]) -> DataFrame {
        read_with(CsvConfig::default(), csv)
    }

    /// Asserts that two floats agree to within `1e-10`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-10);
    }

    /// A header plus two data rows yields a 2 x 3 frame.
    #[test]
    fn test_parse_basic_csv() {
        let frame = read(b"name,age,score\nAlice,30,9.5\nBob,25,8.1");
        assert_eq!(frame.nrows(), 2);
        assert_eq!(frame.ncols(), 3);
    }

    /// Empty input produces a frame without rows.
    #[test]
    fn test_parse_empty() {
        let frame = read(b"");
        assert_eq!(frame.nrows(), 0);
    }

    /// A lone header row keeps its columns but has no rows.
    #[test]
    fn test_parse_header_only() {
        let frame = read(b"x,y,z\n");
        assert_eq!(frame.nrows(), 0);
        assert_eq!(frame.ncols(), 3);
    }

    /// Each column gets the variant matching its contents.
    #[test]
    fn test_parse_type_inference() {
        let frame = read(b"a,b,c,d\n42,3.14,true,hello\n10,2.71,false,world");
        assert_eq!(frame.nrows(), 2);
        assert!(matches!(frame.get_column("a"), Some(Column::Int(_))));
        assert!(matches!(frame.get_column("b"), Some(Column::Float(_))));
        assert!(matches!(frame.get_column("c"), Some(Column::Bool(_))));
        assert!(matches!(frame.get_column("d"), Some(Column::Str(_))));
    }

    /// A tab delimiter is honoured.
    #[test]
    fn test_parse_tsv() {
        let tab_separated = CsvConfig {
            delimiter: b'\t',
            ..Default::default()
        };
        let frame = read_with(tab_separated, b"x\ty\n1\t2\n3\t4");
        assert_eq!(frame.nrows(), 2);
        assert_eq!(frame.ncols(), 2);
    }

    /// `max_rows` caps the number of data rows read.
    #[test]
    fn test_parse_max_rows() {
        let capped = CsvConfig {
            max_rows: Some(3),
            ..Default::default()
        };
        let frame = read_with(capped, b"x\n1\n2\n3\n4\n5");
        assert_eq!(frame.nrows(), 3);
    }

    /// Column sums and the row count are reported together with the header.
    #[test]
    fn test_streaming_sum() {
        let processor = StreamingCsvProcessor::new(CsvConfig::default());
        let (headers, sums, count) = processor
            .sum_columns(b"x,y\n1.0,2.0\n3.0,4.0\n5.0,6.0")
            .unwrap();
        assert_eq!(headers, vec!["x", "y"]);
        assert_eq!(count, 3);
        assert_close(sums[0], 9.0);
        assert_close(sums[1], 12.0);
    }

    /// Per-column minima and maxima are tracked independently.
    #[test]
    fn test_streaming_minmax() {
        let processor = StreamingCsvProcessor::new(CsvConfig::default());
        let (headers, mins, maxs, count) = processor
            .minmax_columns(b"x,y\n1.0,6.0\n3.0,2.0\n5.0,4.0")
            .unwrap();
        assert_eq!(headers, vec!["x", "y"]);
        assert_eq!(count, 3);
        assert_close(mins[0], 1.0);
        assert_close(maxs[0], 5.0);
        assert_close(mins[1], 2.0);
        assert_close(maxs[1], 6.0);
    }

    /// CRLF line endings do not create extra rows.
    #[test]
    fn test_parse_windows_line_endings() {
        let frame = read(b"a,b\r\n1,2\r\n3,4\r\n");
        assert_eq!(frame.nrows(), 2);
    }

    /// Without a header, columns receive generated names and the first row is data.
    #[test]
    fn test_no_header() {
        let headerless = CsvConfig {
            has_header: false,
            ..Default::default()
        };
        let frame = read_with(headerless, b"1,2,3\n4,5,6");
        assert_eq!(frame.nrows(), 2);
        assert_eq!(frame.column_names(), vec!["col_0", "col_1", "col_2"]);
    }
}