1use std::fs::File;
7use std::path::Path;
8use std::sync::Arc;
9
10use arrow::csv::reader::{Format, Reader, ReaderBuilder};
11use arrow::datatypes::{DataType, SchemaRef};
12use arrow::record_batch::RecordBatch;
13
14use crate::CsvResult;
15use crate::inference;
16
17#[derive(Debug, Clone)]
19pub struct CsvReadOptions {
20 pub has_header: bool,
21 pub delimiter: u8,
22 pub max_read_records: Option<usize>,
23 pub batch_size: Option<usize>,
24 pub null_token: Option<String>,
25}
26
27impl Default for CsvReadOptions {
28 fn default() -> Self {
29 Self {
30 has_header: true,
31 delimiter: b',',
32 max_read_records: None,
33 batch_size: None,
34 null_token: None,
35 }
36 }
37}
38
39impl CsvReadOptions {
40 pub(crate) fn to_format(&self) -> Format {
41 let mut format = Format::default().with_header(self.has_header);
42 if self.delimiter != b',' {
43 format = format.with_delimiter(self.delimiter);
44 }
45 format
46 }
47}
48
49#[derive(Debug, Clone, Default)]
51pub struct CsvReader {
52 options: CsvReadOptions,
53}
54
55impl CsvReader {
56 pub fn new(options: CsvReadOptions) -> Self {
57 Self { options }
58 }
59
60 pub fn with_options(options: CsvReadOptions) -> Self {
61 Self::new(options)
62 }
63
64 pub fn options(&self) -> &CsvReadOptions {
65 &self.options
66 }
67
68 pub fn options_mut(&mut self) -> &mut CsvReadOptions {
69 &mut self.options
70 }
71
72 pub fn into_options(self) -> CsvReadOptions {
73 self.options
74 }
75
76 pub fn infer_schema(&self, path: &Path) -> CsvResult<SchemaRef> {
78 let outcome = inference::infer(path, &self.options)?;
79 Ok(outcome.target_schema)
80 }
81
82 pub fn open(&self, path: &Path) -> CsvResult<CsvReadSession> {
84 let outcome = inference::infer(path, &self.options)?;
85 let file = File::open(path)?;
86
87 let mut builder = ReaderBuilder::new(Arc::clone(&outcome.raw_schema))
88 .with_format(self.options.to_format());
89 if let Some(batch_size) = self.options.batch_size {
90 builder = builder.with_batch_size(batch_size);
91 }
92
93 let reader = builder.build(file)?;
94 Ok(CsvReadSession {
95 schema: outcome.target_schema,
96 type_overrides: outcome.type_overrides,
97 reader,
98 })
99 }
100}
101
102pub struct CsvReadSession {
104 schema: SchemaRef,
105 type_overrides: Vec<Option<DataType>>,
106 reader: Reader<File>,
107}
108
109impl CsvReadSession {
110 pub fn schema(&self) -> SchemaRef {
112 Arc::clone(&self.schema)
113 }
114
115 pub fn type_overrides(&self) -> &[Option<DataType>] {
117 &self.type_overrides
118 }
119
120 pub fn into_parts(self) -> (SchemaRef, Reader<File>, Vec<Option<DataType>>) {
122 (self.schema, self.reader, self.type_overrides)
123 }
124
125 pub fn reader(&mut self) -> &mut Reader<File> {
127 &mut self.reader
128 }
129}
130
131impl Iterator for CsvReadSession {
132 type Item = arrow::error::Result<RecordBatch>;
133
134 fn next(&mut self) -> Option<Self::Item> {
135 self.reader.next()
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::DataType;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes a header plus two data rows covering integer, float, boolean,
    /// RFC 3339 timestamp, and string columns.
    fn write_sample_csv() -> NamedTempFile {
        let mut tmp = NamedTempFile::new().expect("create tmp");
        writeln!(tmp, "id,price,flag,timestamp,text").unwrap();
        writeln!(tmp, "1,3.14,true,2024-01-01T12:34:56Z,hello").unwrap();
        writeln!(tmp, "2,2.71,false,2024-01-02T01:02:03Z,world").unwrap();
        tmp
    }

    #[test]
    fn default_options_describe_comma_separated_file_with_header() {
        let options = CsvReadOptions::default();
        assert!(options.has_header);
        assert_eq!(options.delimiter, b',');
        assert_eq!(options.max_read_records, None);
        assert_eq!(options.batch_size, None);
        assert_eq!(options.null_token, None);
    }

    #[test]
    fn infer_schema_detects_types() {
        let tmp = write_sample_csv();
        let reader = CsvReader::default();
        let schema = reader.infer_schema(tmp.path()).expect("infer");

        assert_eq!(schema.field(0).data_type(), &DataType::Int64);
        assert_eq!(schema.field(1).data_type(), &DataType::Float64);
        assert_eq!(schema.field(2).data_type(), &DataType::Boolean);
        assert!(matches!(
            schema.field(3).data_type(),
            DataType::Timestamp(_, _)
        ));
        assert_eq!(schema.field(4).data_type(), &DataType::Utf8);
    }

    #[test]
    fn reader_streams_batches() {
        let tmp = write_sample_csv();
        let options = CsvReadOptions {
            batch_size: Some(1),
            ..Default::default()
        };
        let reader = CsvReader::new(options);

        let mut session = reader.open(tmp.path()).expect("open reader");
        assert_eq!(session.schema().field(0).data_type(), &DataType::Int64);

        let first = session.next().expect("first batch").expect("batch ok");
        assert_eq!(first.num_rows(), 1);
        assert_eq!(first.column(0).data_type(), &DataType::Int64);

        // The batch size must hold for every batch, not just the first, and the stream
        // must terminate once both data rows have been produced.
        let second = session.next().expect("second batch").expect("batch ok");
        assert_eq!(second.num_rows(), 1);
        assert!(
            session.next().is_none(),
            "expected end of stream after two rows"
        );
    }
}