alopex_dataframe/io/
csv.rs1use std::fs::File;
2use std::io::{BufReader, BufWriter, Seek};
3use std::path::Path;
4use std::sync::Arc;
5
6use arrow::datatypes::SchemaRef;
7use arrow::record_batch::RecordBatch;
8use arrow_csv::reader::{Format, ReaderBuilder};
9use arrow_csv::WriterBuilder;
10use regex::Regex;
11
12use crate::io::options::CsvReadOptions;
13use crate::{col, DataFrame, DataFrameError, Result};
14
15pub fn read_csv(_path: impl AsRef<Path>) -> Result<DataFrame> {
17 read_csv_with_options(_path, &CsvReadOptions::default())
18}
19
20pub fn write_csv(path: impl AsRef<Path>, df: &DataFrame) -> Result<()> {
22 let path = path.as_ref();
23 let file = File::create(path).map_err(|source| DataFrameError::io_with_path(source, path))?;
24 let mut writer = WriterBuilder::new()
25 .with_header(true)
26 .build(BufWriter::new(file));
27
28 for batch in df.to_arrow() {
29 writer
30 .write(&batch)
31 .map_err(|source| DataFrameError::Arrow { source })?;
32 }
33
34 Ok(())
35}
36
37pub fn read_csv_with_options(
39 path: impl AsRef<Path>,
40 options: &CsvReadOptions,
41) -> Result<DataFrame> {
42 validate_csv_read_options(options)?;
43
44 let path = path.as_ref();
45 let file = File::open(path).map_err(|source| DataFrameError::io_with_path(source, path))?;
46 let mut reader = BufReader::new(file);
47
48 let mut format = Format::default()
49 .with_header(options.has_header)
50 .with_delimiter(options.delimiter);
51
52 if let Some(quote_char) = options.quote_char {
53 format = format.with_quote(quote_char);
54 }
55
56 if !options.null_values.is_empty() {
57 let pattern = options
58 .null_values
59 .iter()
60 .map(|s| regex::escape(s))
61 .collect::<Vec<_>>()
62 .join("|");
63 let regex = Regex::new(&format!("^(?:{pattern})$")).map_err(|e| {
64 DataFrameError::configuration("null_values", format!("invalid regex: {e}"))
65 })?;
66 format = format.with_null_regex(regex);
67 }
68
69 let (schema, _) = format
70 .infer_schema(&mut reader, Some(options.infer_schema_length))
71 .map_err(|source| DataFrameError::Arrow { source })?;
72 let schema: SchemaRef = Arc::new(schema);
73
74 reader
75 .rewind()
76 .map_err(|source| DataFrameError::io_with_path(source, path))?;
77
78 let projection_indices =
79 projection_indices_from_schema(&schema, options.projection.as_deref())?;
80
81 let csv_reader = ReaderBuilder::new(schema.clone())
82 .with_format(format)
83 .build(reader)
84 .map_err(|source| DataFrameError::Arrow { source })?;
85
86 let mut batches = Vec::new();
87 for maybe_batch in csv_reader {
88 let batch = maybe_batch.map_err(|source| DataFrameError::Arrow { source })?;
89 let batch = if options.predicate.is_some() {
90 batch
91 } else {
92 project_batch(batch, projection_indices.as_deref())?
93 };
94 batches.push(batch);
95 }
96
97 let mut df = DataFrame::from_batches(batches)?;
98
99 if let Some(predicate) = options.predicate.clone() {
100 df = df.filter(predicate)?;
101 if let Some(projection) = options.projection.as_deref() {
102 df = df.select(projection.iter().map(|name| col(name)).collect())?;
103 }
104 }
105
106 Ok(df)
107}
108
109fn validate_csv_read_options(options: &CsvReadOptions) -> Result<()> {
110 if options.delimiter == b'\0' {
111 return Err(DataFrameError::configuration(
112 "delimiter",
113 "delimiter must not be NUL (0x00)",
114 ));
115 }
116 if options.quote_char == Some(b'\0') {
117 return Err(DataFrameError::configuration(
118 "quote_char",
119 "quote_char must not be NUL (0x00)",
120 ));
121 }
122 Ok(())
123}
124
125fn projection_indices_from_schema(
126 schema: &SchemaRef,
127 projection: Option<&[String]>,
128) -> Result<Option<Vec<usize>>> {
129 let Some(projection) = projection else {
130 return Ok(None);
131 };
132
133 let mut indices = Vec::with_capacity(projection.len());
134 for name in projection {
135 let idx = schema
136 .fields()
137 .iter()
138 .position(|f| f.name() == name)
139 .ok_or_else(|| DataFrameError::column_not_found(name.clone()))?;
140 indices.push(idx);
141 }
142
143 Ok(Some(indices))
144}
145
146fn project_batch(batch: RecordBatch, projection: Option<&[usize]>) -> Result<RecordBatch> {
147 let Some(projection) = projection else {
148 return Ok(batch);
149 };
150
151 batch.project(projection).map_err(|e| {
152 DataFrameError::schema_mismatch(format!("failed to project record batch: {e}"))
153 })
154}
155
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;

    use super::{projection_indices_from_schema, read_csv_with_options, write_csv};
    use crate::io::CsvReadOptions;
    use crate::{col, lit, DataFrame, DataFrameError};

    // Collects the column names of `df` in schema order.
    fn column_names(df: &DataFrame) -> Vec<String> {
        df.schema()
            .fields()
            .iter()
            .map(|field| field.name().clone())
            .collect()
    }

    #[test]
    fn csv_roundtrip_basic() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Float64, true),
            Field::new("c", DataType::Utf8, true),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![Some(1), None, Some(3)])) as ArrayRef,
                Arc::new(Float64Array::from(vec![Some(1.5), Some(2.0), None])) as ArrayRef,
                Arc::new(StringArray::from(vec![Some("x"), None, Some("z")])) as ArrayRef,
            ],
        )
        .unwrap();

        let df = DataFrame::from_batches(vec![batch]).unwrap();
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("sample.csv");

        write_csv(&path, &df).unwrap();

        // The writer is configured with a header row; make sure it actually landed on disk.
        let written = std::fs::read_to_string(&path).unwrap();
        assert!(written.starts_with("a,b,c"));

        let df2 = read_csv_with_options(&path, &CsvReadOptions::default()).unwrap();

        assert_eq!(df2.schema().as_ref(), df.schema().as_ref());
        assert_eq!(df2.height(), df.height());
    }

    #[test]
    fn csv_projection_unknown_column_is_error() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("sample.csv");

        std::fs::write(&path, "a,b\n1,2\n").unwrap();

        let options = CsvReadOptions::default().with_projection(["a", "x"]);
        let err = read_csv_with_options(&path, &options).unwrap_err();
        assert!(matches!(err, DataFrameError::ColumnNotFound { .. }));
    }

    #[test]
    fn csv_projection_preserves_requested_order() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("sample.csv");

        std::fs::write(&path, "a,b,c\n1,2,3\n4,5,6\n").unwrap();

        let options = CsvReadOptions::default().with_projection(["c", "a"]);
        let df = read_csv_with_options(&path, &options).unwrap();
        assert_eq!(column_names(&df), ["c", "a"]);
        assert_eq!(df.height(), 2);
    }

    #[test]
    fn csv_predicate_is_applied() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("sample.csv");

        std::fs::write(&path, "a,b\n1,x\n2,y\n3,z\n").unwrap();

        let options = CsvReadOptions::default().with_predicate(col("a").gt(lit(1i64)));
        let df = read_csv_with_options(&path, &options).unwrap();
        assert_eq!(df.height(), 2);
    }

    #[test]
    fn csv_predicate_on_unprojected_column_then_projects() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("sample.csv");

        std::fs::write(&path, "a,b\n1,x\n2,y\n3,z\n").unwrap();

        // The predicate references `a`, which is not part of the projection.
        let options = CsvReadOptions::default()
            .with_predicate(col("a").gt(lit(1i64)))
            .with_projection(["b"]);
        let df = read_csv_with_options(&path, &options).unwrap();
        assert_eq!(df.height(), 2);
        assert_eq!(column_names(&df), ["b"]);
    }

    #[test]
    fn csv_invalid_delimiter_is_configuration_error() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("sample.csv");

        std::fs::write(&path, "a,b\n1,2\n").unwrap();

        let options = CsvReadOptions::default().with_delimiter(b'\0');
        let err = read_csv_with_options(&path, &options).unwrap_err();
        assert!(matches!(err, DataFrameError::Configuration { .. }));
    }

    #[test]
    fn projection_indices_follow_requested_order() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Utf8, true),
        ]));

        let requested = vec!["b".to_string(), "a".to_string()];
        let indices = projection_indices_from_schema(&schema, Some(requested.as_slice())).unwrap();
        assert_eq!(indices, Some(vec![1, 0]));

        assert_eq!(projection_indices_from_schema(&schema, None).unwrap(), None);

        let missing = vec!["zzz".to_string()];
        let err = projection_indices_from_schema(&schema, Some(missing.as_slice())).unwrap_err();
        assert!(matches!(err, DataFrameError::ColumnNotFound { .. }));
    }
}