use crate::error::{CsvError, Result};
use crate::from_csv::config::FromCsvConfig;
use crate::from_csv::parsing::{parse_csv_value, parse_csv_value_with_type};
use crate::from_csv::schema_inference::{infer_column_types, ColumnType};
use crate::from_csv::validation::{validate_cell, validate_headers, CsvSizeTracker};
use hedl_core::{Document, Item, MatrixList, Node};
use std::io::Read;
pub fn from_csv(csv: &str, type_name: &str, schema: &[&str]) -> Result<Document> {
from_csv_with_config(csv, type_name, schema, FromCsvConfig::default())
}
/// Parses CSV text into a HEDL [`Document`] using an explicit [`FromCsvConfig`].
///
/// Convenience wrapper that feeds the string's UTF-8 bytes to
/// [`from_csv_reader_with_config`].
///
/// # Errors
///
/// Propagates every error produced by [`from_csv_reader_with_config`].
pub fn from_csv_with_config(
    csv: &str,
    type_name: &str,
    schema: &[&str],
    config: FromCsvConfig,
) -> Result<Document> {
    let bytes: &[u8] = csv.as_bytes();
    from_csv_reader_with_config(bytes, type_name, schema, config)
}
/// Parses CSV from any [`Read`] source into a HEDL [`Document`] using the default
/// [`FromCsvConfig`].
///
/// # Errors
///
/// Propagates every error produced by [`from_csv_reader_with_config`].
pub fn from_csv_reader<R>(reader: R, type_name: &str, schema: &[&str]) -> Result<Document>
where
    R: Read,
{
    let defaults = FromCsvConfig::default();
    from_csv_reader_with_config(reader, type_name, schema, defaults)
}
pub fn from_csv_reader_with_config<R: Read>(
reader: R,
type_name: &str,
schema: &[&str],
config: FromCsvConfig,
) -> Result<Document> {
let mut csv_reader = csv::ReaderBuilder::new()
.delimiter(config.delimiter)
.has_headers(config.has_headers)
.trim(if config.trim {
csv::Trim::All
} else {
csv::Trim::None
})
.from_reader(reader);
let mut doc = Document::new((2, 0));
let mut full_schema = vec!["id".to_string()];
full_schema.extend(schema.iter().map(|s| (*s).to_string()));
doc.structs
.insert(type_name.to_string(), full_schema.clone());
let mut matrix_list = MatrixList::new(type_name, full_schema.clone());
let headers = csv_reader.headers().map_err(|e| CsvError::ParseError {
line: 0,
message: e.to_string(),
})?;
validate_headers(headers, &config)?;
let mut size_tracker = CsvSizeTracker::new(config.max_total_size);
let header_size: usize = headers.iter().map(str::len).sum();
size_tracker.bytes_read += header_size;
let _inferred_types = if config.infer_schema {
let mut all_records = Vec::new();
for (record_idx, result) in csv_reader.records().enumerate() {
if record_idx >= config.max_rows {
return Err(CsvError::SecurityLimit {
limit: config.max_rows,
actual: record_idx + 1,
});
}
let record = result.map_err(|e| CsvError::ParseError {
line: record_idx + 1,
message: e.to_string(),
})?;
if record.is_empty() {
continue;
}
size_tracker.track_record(&record)?;
for (col_idx, cell) in record.iter().enumerate() {
validate_cell(cell, record_idx + 1, col_idx, &config)?;
}
let row: Vec<String> = record
.iter()
.map(std::string::ToString::to_string)
.collect();
all_records.push(row);
}
let types = infer_column_types(&all_records, config.sample_rows);
for (record_idx, row) in all_records.iter().enumerate() {
let id = row
.first()
.ok_or_else(|| CsvError::MissingColumn("id".to_string()))?;
if id.is_empty() {
return Err(CsvError::EmptyId {
row: record_idx + 1,
});
}
let mut fields = Vec::new();
for (field_idx, field) in row.iter().enumerate() {
let col_type = types.get(field_idx).copied().unwrap_or(ColumnType::String);
let value = parse_csv_value_with_type(field, col_type).map_err(|e| {
e.with_context(format!(
"in column '{}' at line {}",
full_schema.get(field_idx).unwrap_or(&"unknown".to_string()),
record_idx + 1
))
})?;
fields.push(value);
}
if fields.len() != full_schema.len() {
return Err(CsvError::WidthMismatch {
expected: full_schema.len(),
actual: fields.len(),
row: record_idx + 1,
});
}
let node = Node::new(type_name, id, fields);
matrix_list.add_row(node);
}
types
} else {
for (record_idx, result) in csv_reader.records().enumerate() {
if record_idx >= config.max_rows {
return Err(CsvError::SecurityLimit {
limit: config.max_rows,
actual: record_idx + 1,
});
}
let record = result.map_err(|e| CsvError::ParseError {
line: record_idx + 1,
message: e.to_string(),
})?;
if record.is_empty() {
continue;
}
size_tracker.track_record(&record)?;
for (col_idx, cell) in record.iter().enumerate() {
validate_cell(cell, record_idx + 1, col_idx, &config)?;
}
let id = record
.get(0)
.ok_or_else(|| CsvError::MissingColumn("id".to_string()))?;
if id.is_empty() {
return Err(CsvError::EmptyId {
row: record_idx + 1,
});
}
let mut fields = Vec::new();
for (field_idx, field) in record.iter().enumerate() {
let value = parse_csv_value(field).map_err(|e| {
e.with_context(format!(
"in column '{}' at line {}",
full_schema.get(field_idx).unwrap_or(&"unknown".to_string()),
record_idx + 1
))
})?;
fields.push(value);
}
if fields.len() != full_schema.len() {
return Err(CsvError::WidthMismatch {
expected: full_schema.len(),
actual: fields.len(),
row: record_idx + 1,
});
}
let node = Node::new(type_name, id, fields);
matrix_list.add_row(node);
}
Vec::new()
};
let list_key = config
.list_key
.unwrap_or_else(|| format!("{}s", type_name.to_lowercase()));
doc.root.insert(list_key, Item::List(matrix_list));
Ok(doc)
}