use std::io::Cursor;
use polars::prelude::*;
use crate::delimiter::{detect_delimiter, Delimiter};
use crate::error::{DataLoadError, Result};
use crate::options::LoadOptions;
/// Parses in-memory CSV bytes into a [`DataFrame`].
///
/// The separator is taken from `options.delimiter`: for [`Delimiter::Auto`] it is
/// chosen by [`detect_delimiter`]; for any other variant its byte is used, with a
/// comma substituted when the variant has no byte representation.
///
/// `options.max_rows`, `options.skip_rows` and `options.infer_schema_length` are
/// only forwarded to the reader when they are set (respectively `Some`, non-zero,
/// `Some`); otherwise the reader keeps its own defaults.
///
/// # Errors
///
/// * Propagates the UTF-8 decoding error when `content` is not valid UTF-8.
/// * Returns [`DataLoadError::EmptyFile`] when the text is empty or whitespace-only.
/// * Propagates any error reported by the CSV reader.
pub fn load_csv(content: &[u8], options: &LoadOptions) -> Result<DataFrame> {
    let text = std::str::from_utf8(content)?;
    if text.trim().is_empty() {
        return Err(DataLoadError::EmptyFile);
    }

    let separator = match options.delimiter {
        Delimiter::Auto => detect_delimiter(text),
        explicit => explicit.as_byte().unwrap_or(b','),
    };

    let reader_options = CsvReadOptions::default()
        .with_has_header(options.has_header)
        .map_parse_options(|parse| parse.with_separator(separator));

    // Each optional setting is applied only when the caller provided it, so the
    // reader's built-in defaults stay in effect otherwise.
    let reader_options = match options.max_rows {
        Some(limit) => reader_options.with_n_rows(Some(limit)),
        None => reader_options,
    };
    let reader_options = if options.skip_rows > 0 {
        reader_options.with_skip_rows(options.skip_rows)
    } else {
        reader_options
    };
    let reader_options = match options.infer_schema_length {
        Some(sample_rows) => reader_options.with_infer_schema_length(Some(sample_rows)),
        None => reader_options,
    };

    // `content` is the same byte sequence that was just validated as UTF-8.
    let frame = reader_options
        .into_reader_with_file_handle(Cursor::new(content))
        .finish()?;
    Ok(frame)
}
/// Loads CSV bytes by trying each candidate delimiter in order.
///
/// Every entry of `delimiters` is substituted into a copy of `options` and passed
/// to [`load_csv`]. The first candidate that parses successfully into more than
/// one column is returned immediately.
///
/// If no candidate qualifies, the outcome of the *first* candidate (whether a
/// single-column frame or an error) is returned. That outcome is remembered from
/// the loop rather than parsed a second time. When `delimiters` is empty the
/// input is parsed once with [`Delimiter::Comma`].
///
/// # Errors
///
/// * Propagates the UTF-8 decoding error when `content` is not valid UTF-8.
/// * Returns [`DataLoadError::EmptyFile`] when the text is empty or whitespace-only.
/// * Returns the first candidate's error when no candidate produced more than one
///   column and the first candidate failed.
pub fn load_csv_with_fallback(
    content: &[u8],
    options: &LoadOptions,
    delimiters: &[Delimiter],
) -> Result<DataFrame> {
    // Fail fast so undecodable or blank input is not retried once per delimiter.
    let text = std::str::from_utf8(content)?;
    if text.trim().is_empty() {
        return Err(DataLoadError::EmptyFile);
    }

    // One clone is enough: only the delimiter differs between attempts.
    let mut opts = options.clone();
    // Outcome of the first candidate, kept as the fallback answer.
    let mut first_attempt: Option<Result<DataFrame>> = None;

    for &delimiter in delimiters {
        opts.delimiter = delimiter;
        match load_csv(content, &opts) {
            Ok(df) if df.width() > 1 => return Ok(df),
            attempt => {
                if first_attempt.is_none() {
                    first_attempt = Some(attempt);
                }
            }
        }
    }

    match first_attempt {
        Some(outcome) => outcome,
        // No candidates were supplied at all: default to a comma separator.
        None => {
            opts.delimiter = Delimiter::Comma;
            load_csv(content, &opts)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Comma-separated input loaded with default options keeps header names and shape.
    #[test]
    fn test_load_csv_comma() {
        let frame = load_csv(b"a,b,c\n1,2,3\n4,5,6", &LoadOptions::default()).unwrap();
        assert_eq!(frame.shape(), (2, 3));
        assert_eq!(frame.get_column_names(), &["a", "b", "c"]);
    }

    /// Tab-separated input is split into three columns under default options.
    #[test]
    fn test_load_csv_tab() {
        let frame = load_csv(b"a\tb\tc\n1\t2\t3\n4\t5\t6", &LoadOptions::default()).unwrap();
        assert_eq!(frame.shape(), (2, 3));
    }

    /// Semicolon-separated input is split into three columns under default options.
    #[test]
    fn test_load_csv_semicolon() {
        let frame = load_csv(b"a;b;c\n1;2;3", &LoadOptions::default()).unwrap();
        assert_eq!(frame.shape(), (1, 3));
    }

    /// An explicitly requested pipe delimiter is honoured.
    #[test]
    fn test_load_csv_explicit_delimiter() {
        let pipe_options = LoadOptions::new().with_delimiter(Delimiter::Pipe);
        let frame = load_csv(b"a|b|c\n1|2|3", &pipe_options).unwrap();
        assert_eq!(frame.shape(), (1, 3));
    }

    /// With headers disabled, the first line counts as a data row.
    #[test]
    fn test_load_csv_no_header() {
        let headerless = LoadOptions::new().with_header(false);
        let frame = load_csv(b"1,2,3\n4,5,6", &headerless).unwrap();
        assert_eq!(frame.shape(), (2, 3));
    }

    /// `max_rows` caps the number of data rows that are read.
    #[test]
    fn test_load_csv_max_rows() {
        let capped = LoadOptions::new().with_max_rows(Some(2));
        let frame = load_csv(b"a,b\n1,2\n3,4\n5,6\n7,8", &capped).unwrap();
        assert_eq!(frame.shape(), (2, 2));
    }

    /// Empty input is rejected with `DataLoadError::EmptyFile`.
    #[test]
    fn test_load_csv_empty() {
        let outcome = load_csv(b"", &LoadOptions::default());
        assert!(matches!(outcome, Err(DataLoadError::EmptyFile)));
    }

    /// The fallback loader skips candidates that give a single column and settles
    /// on the semicolon candidate.
    #[test]
    fn test_load_csv_with_fallback() {
        let candidates = [Delimiter::Comma, Delimiter::Tab, Delimiter::Semicolon];
        let frame =
            load_csv_with_fallback(b"a;b;c\n1;2;3", &LoadOptions::default(), &candidates).unwrap();
        assert_eq!(frame.shape(), (1, 3));
        assert_eq!(frame.get_column_names(), &["a", "b", "c"]);
    }
}