1use std::io::Read;
9use std::sync::Arc;
10
11use arrow_array::RecordBatch;
12use arrow_schema::{ArrowError, DataType, Schema, SchemaRef};
13use geoarrow_array::GeoArrowArray;
14use geoarrow_array::array::{LargeWktArray, WktArray, WktViewArray};
15use geoarrow_array::cast::from_wkt;
16use geoarrow_schema::error::GeoArrowResult;
17use geoarrow_schema::{GeoArrowType, WktType};
18
/// Options controlling how [`CsvReader`] locates and decodes the geometry column.
#[derive(Debug, Clone)]
pub struct CsvReaderOptions {
    /// Name of the CSV column that holds the geometries.
    ///
    /// When `None`, the reader searches the schema for a column whose name equals
    /// `geometry`, compared case-insensitively.
    pub geometry_column_name: Option<String>,

    /// GeoArrow type that the geometry column is parsed into.
    pub to_type: GeoArrowType,
}
30
/// Wraps an [`arrow_csv::Reader`] and replaces one text column of every batch with a
/// geometry array parsed from WKT.
pub struct CsvReader<R> {
    /// Underlying CSV reader that yields the raw record batches.
    reader: arrow_csv::Reader<R>,
    /// Schema of the batches this reader yields: the CSV schema with the geometry field
    /// replaced by the field derived from `to_type`.
    output_schema: SchemaRef,
    /// Position of the geometry column within the schema.
    geometry_column_index: usize,
    /// Target GeoArrow type passed to the WKT parser for each batch.
    to_type: GeoArrowType,
}
38
39impl<R> CsvReader<R> {
40 pub fn schema(&self) -> SchemaRef {
42 self.output_schema.clone()
43 }
44}
45
46impl<R: Read> CsvReader<R> {
47 pub fn try_new(
49 reader: arrow_csv::Reader<R>,
50 options: CsvReaderOptions,
51 ) -> GeoArrowResult<Self> {
52 let schema = reader.schema();
53 let geometry_column_name =
54 find_geometry_column(&schema, options.geometry_column_name.as_deref())?;
55 let geometry_column_index = schema.index_of(&geometry_column_name)?;
56
57 let mut output_fields = schema.fields().to_vec();
59 output_fields[geometry_column_index] =
60 options.to_type.to_field(geometry_column_name, true).into();
61
62 let output_schema = Arc::new(Schema::new_with_metadata(
63 output_fields,
64 schema.metadata().clone(),
65 ));
66
67 Ok(Self {
68 reader,
69 output_schema,
70 geometry_column_index,
71 to_type: options.to_type,
72 })
73 }
74}
75
76impl<R: Read> Iterator for CsvReader<R> {
77 type Item = Result<RecordBatch, ArrowError>;
78
79 fn next(&mut self) -> Option<Self::Item> {
80 let reader = &mut self.reader;
81 reader.next().map(move |batch| {
82 parse_batch(
83 batch,
84 self.output_schema.clone(),
85 self.geometry_column_index,
86 self.to_type.clone(),
87 )
88 })
89 }
90}
91
92impl<R: Read> arrow_array::RecordBatchReader for CsvReader<R> {
93 fn schema(&self) -> SchemaRef {
94 self.schema()
95 }
96}
97
98fn parse_batch(
99 batch: Result<RecordBatch, ArrowError>,
100 output_schema: SchemaRef,
101 geometry_column_index: usize,
102 to_type: GeoArrowType,
103) -> Result<RecordBatch, ArrowError> {
104 let batch = batch?;
105 let column = batch.column(geometry_column_index);
106
107 let parsed_arr = match column.data_type() {
108 DataType::Utf8 => {
109 let arr = WktArray::try_from((column.as_ref(), WktType::default()))?;
110 from_wkt(&arr, to_type)
111 }
112 DataType::LargeUtf8 => {
113 let arr = LargeWktArray::try_from((column.as_ref(), WktType::default()))?;
114 from_wkt(&arr, to_type)
115 }
116 DataType::Utf8View => {
117 let arr = WktViewArray::try_from((column.as_ref(), WktType::default()))?;
118 from_wkt(&arr, to_type)
119 }
120 _ => unreachable!(),
121 }?;
122
123 let mut columns = batch.columns().to_vec();
125 columns[geometry_column_index] = parsed_arr.into_array_ref();
126
127 RecordBatch::try_new(output_schema, columns)
128}
129
130fn find_geometry_column(
131 schema: &Schema,
132 geometry_column_name: Option<&str>,
133) -> GeoArrowResult<String> {
134 if let Some(geometry_col_name) = geometry_column_name {
135 if schema
136 .fields()
137 .iter()
138 .any(|field| field.name() == geometry_col_name)
139 {
140 Ok(geometry_col_name.to_string())
141 } else {
142 Err(ArrowError::CsvError(format!(
143 "CSV geometry column specified to have name '{}' but no such column found",
144 geometry_col_name
145 ))
146 .into())
147 }
148 } else {
149 let mut field_name: Option<String> = None;
150 for field in schema.fields().iter() {
151 if field.name().to_lowercase().as_str() == "geometry" {
152 field_name = Some(field.name().clone());
153 }
154 }
155 field_name.ok_or(
156 ArrowError::CsvError(
157 "No CSV geometry column name specified and no geometry column found.".to_string(),
158 )
159 .into(),
160 )
161 }
162}
163
#[cfg(test)]
mod tests {

    use std::io::Cursor;

    use arrow_csv::ReaderBuilder;
    use arrow_csv::reader::Format;
    use geo_traits::{CoordTrait, PointTrait};
    use geoarrow_array::GeoArrowArrayAccessor;
    use geoarrow_array::array::PointArray;
    use geoarrow_schema::{Dimension, PointType};

    use super::*;

    /// Reads a two-row CSV whose `report location` column holds WKT points and checks that
    /// the column comes back as a GeoArrow point array with the expected coordinates.
    #[test]
    fn read_csv() {
        let csv_text = r#"
address,type,datetime,report location,incident number
904 7th Av,Car Fire,05/22/2019 12:55:00 PM,POINT (-122.329051 47.6069),F190051945
9610 53rd Av S,Aid Response,05/22/2019 12:55:00 PM,POINT (-122.266529 47.515984),F190051946"#;

        // Infer the raw CSV schema first, then build the plain Arrow CSV reader from it.
        let csv_format = Format::default().with_header(true);
        let (inferred_schema, _num_read_records) = csv_format
            .infer_schema(Cursor::new(csv_text), None)
            .unwrap();
        let csv_reader = ReaderBuilder::new(inferred_schema.into())
            .with_format(csv_format)
            .build(Cursor::new(csv_text))
            .unwrap();

        let point_type = PointType::new(Dimension::XY, Default::default());
        let to_type = GeoArrowType::Point(point_type.clone());
        let geo_reader = CsvReader::try_new(
            csv_reader,
            CsvReaderOptions {
                geometry_column_name: Some("report location".to_string()),
                to_type: to_type.clone(),
            },
        )
        .unwrap();

        let batches: Vec<_> = geo_reader.collect::<Result<Vec<_>, _>>().unwrap();
        let first_batch = batches.into_iter().next().unwrap();
        let batch_schema = first_batch.schema();
        assert_eq!(batch_schema.fields().len(), 5);

        // The geometry field must carry the requested GeoArrow extension type.
        let actual_type = GeoArrowType::from_extension_field(batch_schema.field(3)).unwrap();
        assert_eq!(actual_type, to_type);

        let point_arr =
            PointArray::try_from((first_batch.column(3).as_ref(), point_type)).unwrap();

        let expected_coords = [(-122.329051, 47.6069), (-122.266529, 47.515984)];
        assert_eq!(point_arr.len(), 2);
        for (row, &(expected_x, expected_y)) in expected_coords.iter().enumerate() {
            let point = point_arr.value(row).unwrap();
            assert_eq!(point.coord().unwrap().x(), expected_x);
            assert_eq!(point.coord().unwrap().y(), expected_y);
        }
    }
}