//! CSV input/output for `alopex_dataframe` (io/csv.rs): eager reads with
//! schema inference, optional projection/predicate pushdown, and writes.
use std::fs::File;
use std::io::{BufReader, BufWriter, Seek, Write};
use std::path::Path;
use std::sync::Arc;

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use arrow_csv::reader::{Format, ReaderBuilder};
use arrow_csv::WriterBuilder;
use regex::Regex;

use crate::io::options::CsvReadOptions;
use crate::{col, DataFrame, DataFrameError, Result};
14
15/// Read a CSV file eagerly into a `DataFrame` using default `CsvReadOptions`.
16pub fn read_csv(_path: impl AsRef<Path>) -> Result<DataFrame> {
17    read_csv_with_options(_path, &CsvReadOptions::default())
18}
19
20/// Write a `DataFrame` to a CSV file (currently always includes a header row).
21pub fn write_csv(path: impl AsRef<Path>, df: &DataFrame) -> Result<()> {
22    let path = path.as_ref();
23    let file = File::create(path).map_err(|source| DataFrameError::io_with_path(source, path))?;
24    let mut writer = WriterBuilder::new()
25        .with_header(true)
26        .build(BufWriter::new(file));
27
28    for batch in df.to_arrow() {
29        writer
30            .write(&batch)
31            .map_err(|source| DataFrameError::Arrow { source })?;
32    }
33
34    Ok(())
35}
36
/// Read a CSV file eagerly into a `DataFrame` using the provided options.
///
/// The file is read in two passes: first a bounded prefix is scanned to
/// infer an Arrow schema, then the reader is rewound and the whole file is
/// decoded with that schema.
///
/// # Errors
///
/// Returns a configuration error for invalid options or an unbuildable
/// null-value regex, an I/O error (tagged with the path) if the file cannot
/// be opened or rewound, an Arrow error if inference or decoding fails, and
/// a column-not-found error if a projected column is absent from the schema.
pub fn read_csv_with_options(
    path: impl AsRef<Path>,
    options: &CsvReadOptions,
) -> Result<DataFrame> {
    // Reject options the underlying reader cannot handle (NUL bytes) before
    // touching the filesystem.
    validate_csv_read_options(options)?;

    let path = path.as_ref();
    let file = File::open(path).map_err(|source| DataFrameError::io_with_path(source, path))?;
    let mut reader = BufReader::new(file);

    let mut format = Format::default()
        .with_header(options.has_header)
        .with_delimiter(options.delimiter);

    if let Some(quote_char) = options.quote_char {
        format = format.with_quote(quote_char);
    }

    if !options.null_values.is_empty() {
        // A cell is treated as null iff it matches one of the user-supplied
        // sentinels exactly: each value is regex-escaped, alternated, and
        // anchored as ^(?:a|b|...)$.
        let pattern = options
            .null_values
            .iter()
            .map(|s| regex::escape(s))
            .collect::<Vec<_>>()
            .join("|");
        let regex = Regex::new(&format!("^(?:{pattern})$")).map_err(|e| {
            DataFrameError::configuration("null_values", format!("invalid regex: {e}"))
        })?;
        format = format.with_null_regex(regex);
    }

    // Pass 1: infer the schema from at most `infer_schema_length` records.
    let (schema, _) = format
        .infer_schema(&mut reader, Some(options.infer_schema_length))
        .map_err(|source| DataFrameError::Arrow { source })?;
    let schema: SchemaRef = Arc::new(schema);

    // Inference consumed the reader; seek back to the start for decoding.
    reader
        .rewind()
        .map_err(|source| DataFrameError::io_with_path(source, path))?;

    // Resolve projected column names to indices up front so an unknown
    // column fails early, before any batches are decoded.
    let projection_indices =
        projection_indices_from_schema(&schema, options.projection.as_deref())?;

    let csv_reader = ReaderBuilder::new(schema.clone())
        .with_format(format)
        .build(reader)
        .map_err(|source| DataFrameError::Arrow { source })?;

    let mut batches = Vec::new();
    for maybe_batch in csv_reader {
        let batch = maybe_batch.map_err(|source| DataFrameError::Arrow { source })?;
        // When a predicate is present, projection is deferred until after
        // filtering below: the predicate may reference columns the
        // projection would drop.
        let batch = if options.predicate.is_some() {
            batch
        } else {
            project_batch(batch, projection_indices.as_deref())?
        };
        batches.push(batch);
    }

    let mut df = DataFrame::from_batches(batches)?;

    if let Some(predicate) = options.predicate.clone() {
        df = df.filter(predicate)?;
        // Apply the deferred projection now that filtering no longer needs
        // the dropped columns.
        if let Some(projection) = options.projection.as_deref() {
            df = df.select(projection.iter().map(|name| col(name)).collect())?;
        }
    }

    Ok(df)
}
108
109fn validate_csv_read_options(options: &CsvReadOptions) -> Result<()> {
110    if options.delimiter == b'\0' {
111        return Err(DataFrameError::configuration(
112            "delimiter",
113            "delimiter must not be NUL (0x00)",
114        ));
115    }
116    if options.quote_char == Some(b'\0') {
117        return Err(DataFrameError::configuration(
118            "quote_char",
119            "quote_char must not be NUL (0x00)",
120        ));
121    }
122    Ok(())
123}
124
125fn projection_indices_from_schema(
126    schema: &SchemaRef,
127    projection: Option<&[String]>,
128) -> Result<Option<Vec<usize>>> {
129    let Some(projection) = projection else {
130        return Ok(None);
131    };
132
133    let mut indices = Vec::with_capacity(projection.len());
134    for name in projection {
135        let idx = schema
136            .fields()
137            .iter()
138            .position(|f| f.name() == name)
139            .ok_or_else(|| DataFrameError::column_not_found(name.clone()))?;
140        indices.push(idx);
141    }
142
143    Ok(Some(indices))
144}
145
146fn project_batch(batch: RecordBatch, projection: Option<&[usize]>) -> Result<RecordBatch> {
147    let Some(projection) = projection else {
148        return Ok(batch);
149    };
150
151    batch.project(projection).map_err(|e| {
152        DataFrameError::schema_mismatch(format!("failed to project record batch: {e}"))
153    })
154}
155
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;

    use super::{read_csv_with_options, write_csv};
    use crate::io::CsvReadOptions;
    use crate::{col, lit, DataFrame, DataFrameError};

    /// Round-tripping a frame through disk preserves schema and row count.
    #[test]
    fn csv_roundtrip_basic() {
        let fields = Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Float64, true),
            Field::new("c", DataType::Utf8, true),
        ]);
        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![Some(1), None, Some(3)])),
            Arc::new(Float64Array::from(vec![Some(1.5), Some(2.0), None])),
            Arc::new(StringArray::from(vec![Some("x"), None, Some("z")])),
        ];
        let batch = RecordBatch::try_new(Arc::new(fields), columns).unwrap();
        let original = DataFrame::from_batches(vec![batch]).unwrap();

        let tmp = tempfile::tempdir().unwrap();
        let csv_path = tmp.path().join("sample.csv");

        write_csv(&csv_path, &original).unwrap();
        let reloaded = read_csv_with_options(&csv_path, &CsvReadOptions::default()).unwrap();

        assert_eq!(reloaded.schema().as_ref(), original.schema().as_ref());
        assert_eq!(reloaded.height(), original.height());
    }

    /// Projecting a column that does not exist surfaces `ColumnNotFound`.
    #[test]
    fn csv_projection_unknown_column_is_error() {
        let tmp = tempfile::tempdir().unwrap();
        let csv_path = tmp.path().join("sample.csv");
        std::fs::write(&csv_path, "a,b\n1,2\n").unwrap();

        let opts = CsvReadOptions::default().with_projection(["a", "x"]);
        let failure = read_csv_with_options(&csv_path, &opts).unwrap_err();
        assert!(matches!(failure, DataFrameError::ColumnNotFound { .. }));
    }

    /// A row predicate filters rows during the read.
    #[test]
    fn csv_predicate_is_applied() {
        let tmp = tempfile::tempdir().unwrap();
        let csv_path = tmp.path().join("sample.csv");
        std::fs::write(&csv_path, "a,b\n1,x\n2,y\n3,z\n").unwrap();

        let opts = CsvReadOptions::default().with_predicate(col("a").gt(lit(1i64)));
        let filtered = read_csv_with_options(&csv_path, &opts).unwrap();
        assert_eq!(filtered.height(), 2);
    }

    /// A NUL delimiter is rejected up front as a configuration error.
    #[test]
    fn csv_invalid_delimiter_is_configuration_error() {
        let tmp = tempfile::tempdir().unwrap();
        let csv_path = tmp.path().join("sample.csv");
        std::fs::write(&csv_path, "a,b\n1,2\n").unwrap();

        let opts = CsvReadOptions::default().with_delimiter(b'\0');
        let failure = read_csv_with_options(&csv_path, &opts).unwrap_err();
        assert!(matches!(failure, DataFrameError::Configuration { .. }));
    }
}