use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;
use csv as csv_crate;
use crate::array::{ArrayRef, BinaryArray};
use crate::builder::*;
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
use crate::record_batch::RecordBatch;
use self::csv_crate::{StringRecord, StringRecordsIntoIter};
/// CSV file reader that decodes rows into Arrow `RecordBatch`es.
pub struct Reader {
    /// Explicit schema for the CSV file; drives per-column array building in `next()`.
    schema: Arc<Schema>,
    /// Optional indices of columns to materialize; `None` means all schema columns.
    projection: Option<Vec<usize>>,
    /// Owning iterator over raw CSV string records (buffered file reader underneath).
    record_iter: StringRecordsIntoIter<BufReader<File>>,
    /// Maximum number of rows collected into a single `RecordBatch`.
    batch_size: usize,
}
impl Reader {
pub fn new(
file: File,
schema: Arc<Schema>,
has_headers: bool,
batch_size: usize,
projection: Option<Vec<usize>>,
) -> Self {
let csv_reader = csv::ReaderBuilder::new()
.has_headers(has_headers)
.from_reader(BufReader::new(file));
let record_iter = csv_reader.into_records();
Reader {
schema: schema.clone(),
projection,
record_iter,
batch_size,
}
}
}
/// Build a primitive Arrow array of type `T` by parsing column `col_idx`
/// from each CSV record in `rows`.
///
/// A missing field or an empty string becomes a null entry; a non-empty
/// value that fails to parse as `T::Native` returns a `ParseError`.
fn build_primitive_array<T: ArrowPrimitiveType>(
    rows: &[StringRecord],
    col_idx: &usize,
) -> Result<ArrayRef> {
    let mut builder = PrimitiveArrayBuilder::<T>::new(rows.len());
    for row in rows {
        match row.get(*col_idx) {
            Some(s) if !s.is_empty() => match s.parse::<T::Native>() {
                Ok(v) => builder.push(v)?,
                Err(_) => {
                    return Err(ArrowError::ParseError(format!(
                        "Error while parsing value {}",
                        s
                    )));
                }
            },
            // Missing field or empty string: record a null. Propagate builder
            // errors instead of panicking, matching the `push(v)?` branch above.
            _ => builder.push_null()?,
        }
    }
    Ok(Arc::new(builder.finish()) as ArrayRef)
}
impl Reader {
    /// Read the next batch of up to `batch_size` rows from the CSV file.
    ///
    /// Returns `Ok(None)` when the file is exhausted, `Ok(Some(batch))`
    /// otherwise. Returns an error if a CSV record cannot be read or a
    /// value cannot be parsed as the column's declared type.
    pub fn next(&mut self) -> Result<Option<RecordBatch>> {
        // Collect up to batch_size string records.
        let mut rows: Vec<StringRecord> = Vec::with_capacity(self.batch_size);
        for _ in 0..self.batch_size {
            match self.record_iter.next() {
                Some(Ok(r)) => rows.push(r),
                Some(Err(e)) => {
                    // Surface the underlying csv error instead of discarding it.
                    return Err(ArrowError::ParseError(format!(
                        "Error reading CSV file: {:?}",
                        e
                    )));
                }
                // End of file.
                None => break,
            }
        }
        if rows.is_empty() {
            return Ok(None);
        }
        // Resolve the projection: either the user-supplied column indices,
        // or every column in the schema.
        let projection: Vec<usize> = match self.projection {
            Some(ref v) => v.clone(),
            None => (0..self.schema.fields().len()).collect(),
        };
        let rows = &rows[..];
        // Build one Arrow array per projected column; the first column that
        // fails to build short-circuits via `collect::<Result<_>>`.
        let arrays: Result<Vec<ArrayRef>> = projection
            .iter()
            .map(|i| {
                let field = self.schema.field(*i);
                match field.data_type() {
                    &DataType::Boolean => build_primitive_array::<BooleanType>(rows, i),
                    &DataType::Int8 => build_primitive_array::<Int8Type>(rows, i),
                    &DataType::Int16 => build_primitive_array::<Int16Type>(rows, i),
                    &DataType::Int32 => build_primitive_array::<Int32Type>(rows, i),
                    &DataType::Int64 => build_primitive_array::<Int64Type>(rows, i),
                    &DataType::UInt8 => build_primitive_array::<UInt8Type>(rows, i),
                    &DataType::UInt16 => build_primitive_array::<UInt16Type>(rows, i),
                    &DataType::UInt32 => build_primitive_array::<UInt32Type>(rows, i),
                    &DataType::UInt64 => build_primitive_array::<UInt64Type>(rows, i),
                    &DataType::Float32 => build_primitive_array::<Float32Type>(rows, i),
                    &DataType::Float64 => build_primitive_array::<Float64Type>(rows, i),
                    &DataType::Utf8 => {
                        // Strings are stored as a list of UTF-8 bytes, then
                        // wrapped in a BinaryArray.
                        let values_builder: UInt8Builder = UInt8Builder::new(rows.len());
                        let mut list_builder = ListArrayBuilder::new(values_builder);
                        for row in rows {
                            match row.get(*i) {
                                Some(s) => {
                                    list_builder.values().push_slice(s.as_bytes()).unwrap();
                                    list_builder.append(true).unwrap();
                                }
                                // Missing field: append a null slot.
                                _ => {
                                    list_builder.append(false).unwrap();
                                }
                            }
                        }
                        Ok(Arc::new(BinaryArray::from(list_builder.finish())) as ArrayRef)
                    }
                    other => Err(ArrowError::ParseError(format!(
                        "Unsupported data type {:?}",
                        other
                    ))),
                }
            })
            .collect();
        // Propagate the first column error, or wrap the arrays in a batch.
        arrays.map(|arr| Some(RecordBatch::new(self.schema.clone(), arr)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::*;
    use crate::datatypes::Field;

    #[test]
    fn test_csv() {
        // uk_cities.csv: city name, latitude, longitude — no header row.
        let schema = Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]);
        let file = File::open("test/data/uk_cities.csv").unwrap();
        let mut reader = Reader::new(file, Arc::new(schema), false, 1024, None);

        let batch = reader.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(3, batch.num_columns());

        // Latitude column decodes as Float64.
        let lat = batch
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(57.653484, lat.value(0));

        // City column decodes as UTF-8 bytes in a BinaryArray.
        let city = batch
            .column(0)
            .as_any()
            .downcast_ref::<BinaryArray>()
            .unwrap();
        let name = String::from_utf8(city.get_value(13).to_vec()).unwrap();
        assert_eq!("Aberdeen, Aberdeen City, UK", name);
    }

    #[test]
    fn test_csv_with_projection() {
        let schema = Schema::new(vec![
            Field::new("city", DataType::Utf8, false),
            Field::new("lat", DataType::Float64, false),
            Field::new("lng", DataType::Float64, false),
        ]);
        let file = File::open("test/data/uk_cities.csv").unwrap();
        // Project only the first two columns.
        let mut reader = Reader::new(file, Arc::new(schema), false, 1024, Some(vec![0, 1]));

        let batch = reader.next().unwrap().unwrap();
        assert_eq!(37, batch.num_rows());
        assert_eq!(2, batch.num_columns());
    }

    #[test]
    fn test_nulls() {
        let schema = Schema::new(vec![
            Field::new("c_int", DataType::UInt64, false),
            Field::new("c_float", DataType::Float32, false),
            Field::new("c_string", DataType::Utf8, false),
        ]);
        let file = File::open("test/data/null_test.csv").unwrap();
        let mut reader = Reader::new(file, Arc::new(schema), true, 1024, None);

        let batch = reader.next().unwrap().unwrap();
        // Row 2 of the float column is empty in the file, so it must be null;
        // all other rows carry values.
        let expected_nulls = [false, false, true, false, false];
        for (row, expect_null) in expected_nulls.iter().enumerate() {
            assert_eq!(*expect_null, batch.column(1).is_null(row));
        }
    }
}