geoarrow_csv/
reader.rs

1//! Read from CSV files with a geometry column encoded as Well-Known Text.
2//!
3//! The CSV reader implements [`RecordBatchReader`], so you can iterate over the batches of the CSV
4//! without materializing the entire file in memory.
5//!
6//! [`RecordBatchReader`]: arrow_array::RecordBatchReader
7
8use std::io::Read;
9use std::sync::Arc;
10
11use arrow_array::RecordBatch;
12use arrow_schema::{ArrowError, DataType, Schema, SchemaRef};
13use geoarrow_array::GeoArrowArray;
14use geoarrow_array::array::{LargeWktArray, WktArray, WktViewArray};
15use geoarrow_array::cast::from_wkt;
16use geoarrow_schema::error::GeoArrowResult;
17use geoarrow_schema::{GeoArrowType, WktType};
18
19/// Options for the CSV reader.
20#[derive(Debug, Clone)]
21pub struct CsvReaderOptions {
22    /// The name of the geometry column in the CSV
23    ///
24    /// Defaults to `"geometry"`
25    pub geometry_column_name: Option<String>,
26
27    /// The target geometry type to convert the WKT strings to.
28    pub to_type: GeoArrowType,
29}
30
31/// A CSV reader that parses a WKT-encoded geometry column
32pub struct CsvReader<R> {
33    reader: arrow_csv::Reader<R>,
34    output_schema: SchemaRef,
35    geometry_column_index: usize,
36    to_type: GeoArrowType,
37}
38
39impl<R> CsvReader<R> {
40    /// Access the schema of this reader
41    pub fn schema(&self) -> SchemaRef {
42        self.output_schema.clone()
43    }
44}
45
46impl<R: Read> CsvReader<R> {
47    /// Wrap an upstream `arrow_csv::Reader` in an iterator that parses WKT geometries.
48    pub fn try_new(
49        reader: arrow_csv::Reader<R>,
50        options: CsvReaderOptions,
51    ) -> GeoArrowResult<Self> {
52        let schema = reader.schema();
53        let geometry_column_name =
54            find_geometry_column(&schema, options.geometry_column_name.as_deref())?;
55        let geometry_column_index = schema.index_of(&geometry_column_name)?;
56
57        // Transform to output schema
58        let mut output_fields = schema.fields().to_vec();
59        output_fields[geometry_column_index] =
60            options.to_type.to_field(geometry_column_name, true).into();
61
62        let output_schema = Arc::new(Schema::new_with_metadata(
63            output_fields,
64            schema.metadata().clone(),
65        ));
66
67        Ok(Self {
68            reader,
69            output_schema,
70            geometry_column_index,
71            to_type: options.to_type,
72        })
73    }
74}
75
76impl<R: Read> Iterator for CsvReader<R> {
77    type Item = Result<RecordBatch, ArrowError>;
78
79    fn next(&mut self) -> Option<Self::Item> {
80        let reader = &mut self.reader;
81        reader.next().map(move |batch| {
82            parse_batch(
83                batch,
84                self.output_schema.clone(),
85                self.geometry_column_index,
86                self.to_type.clone(),
87            )
88        })
89    }
90}
91
92impl<R: Read> arrow_array::RecordBatchReader for CsvReader<R> {
93    fn schema(&self) -> SchemaRef {
94        self.schema()
95    }
96}
97
98fn parse_batch(
99    batch: Result<RecordBatch, ArrowError>,
100    output_schema: SchemaRef,
101    geometry_column_index: usize,
102    to_type: GeoArrowType,
103) -> Result<RecordBatch, ArrowError> {
104    let batch = batch?;
105    let column = batch.column(geometry_column_index);
106
107    let parsed_arr = match column.data_type() {
108        DataType::Utf8 => {
109            let arr = WktArray::try_from((column.as_ref(), WktType::default()))?;
110            from_wkt(&arr, to_type)
111        }
112        DataType::LargeUtf8 => {
113            let arr = LargeWktArray::try_from((column.as_ref(), WktType::default()))?;
114            from_wkt(&arr, to_type)
115        }
116        DataType::Utf8View => {
117            let arr = WktViewArray::try_from((column.as_ref(), WktType::default()))?;
118            from_wkt(&arr, to_type)
119        }
120        _ => unreachable!(),
121    }?;
122
123    // Replace column in record batch
124    let mut columns = batch.columns().to_vec();
125    columns[geometry_column_index] = parsed_arr.into_array_ref();
126
127    RecordBatch::try_new(output_schema, columns)
128}
129
130fn find_geometry_column(
131    schema: &Schema,
132    geometry_column_name: Option<&str>,
133) -> GeoArrowResult<String> {
134    if let Some(geometry_col_name) = geometry_column_name {
135        if schema
136            .fields()
137            .iter()
138            .any(|field| field.name() == geometry_col_name)
139        {
140            Ok(geometry_col_name.to_string())
141        } else {
142            Err(ArrowError::CsvError(format!(
143                "CSV geometry column specified to have name '{}' but no such column found",
144                geometry_col_name
145            ))
146            .into())
147        }
148    } else {
149        let mut field_name: Option<String> = None;
150        for field in schema.fields().iter() {
151            if field.name().to_lowercase().as_str() == "geometry" {
152                field_name = Some(field.name().clone());
153            }
154        }
155        field_name.ok_or(
156            ArrowError::CsvError(
157                "No CSV geometry column name specified and no geometry column found.".to_string(),
158            )
159            .into(),
160        )
161    }
162}
163
#[cfg(test)]
mod tests {

    use std::io::Cursor;

    use arrow_csv::ReaderBuilder;
    use arrow_csv::reader::Format;
    use geo_traits::{CoordTrait, PointTrait};
    use geoarrow_array::GeoArrowArrayAccessor;
    use geoarrow_array::array::PointArray;
    use geoarrow_schema::{Dimension, PointType};

    use super::*;

    /// Reads a small CSV whose `report location` column holds WKT points and checks that the
    /// column comes back as a point array with the expected schema and coordinates.
    #[test]
    fn read_csv() {
        let csv_text = r#"
address,type,datetime,report location,incident number
904 7th Av,Car Fire,05/22/2019 12:55:00 PM,POINT (-122.329051 47.6069),F190051945
9610 53rd Av S,Aid Response,05/22/2019 12:55:00 PM,POINT (-122.266529 47.515984),F190051946"#;

        // Infer the CSV schema from the data, then build the plain Arrow CSV reader.
        let csv_format = Format::default().with_header(true);
        let (inferred_schema, _num_read_records) = csv_format
            .infer_schema(Cursor::new(csv_text), None)
            .unwrap();
        let csv_reader = ReaderBuilder::new(inferred_schema.into())
            .with_format(csv_format)
            .build(Cursor::new(csv_text))
            .unwrap();

        // Wrap it so that the `report location` column is parsed into XY points.
        let point_type = PointType::new(Dimension::XY, Default::default());
        let to_type = GeoArrowType::Point(point_type.clone());
        let wkt_reader = CsvReader::try_new(
            csv_reader,
            CsvReaderOptions {
                geometry_column_name: Some("report location".to_string()),
                to_type: to_type.clone(),
            },
        )
        .unwrap();

        let mut batches = wkt_reader
            .collect::<Result<Vec<_>, _>>()
            .unwrap()
            .into_iter();
        let batch = batches.next().unwrap();

        // The schema keeps all five columns; the fourth is now a geometry extension field.
        let batch_schema = batch.schema();
        assert_eq!(batch_schema.fields().len(), 5);
        let actual = GeoArrowType::from_extension_field(batch_schema.field(3)).unwrap();
        assert_eq!(actual, to_type);

        let points = PointArray::try_from((batch.column(3).as_ref(), point_type)).unwrap();
        assert_eq!(points.len(), 2);

        let expected_coords = [(-122.329051, 47.6069), (-122.266529, 47.515984)];
        for (row, &(x, y)) in expected_coords.iter().enumerate() {
            let point = points.value(row).unwrap();
            assert_eq!(point.coord().unwrap().x(), x);
            assert_eq!(point.coord().unwrap().y(), y);
        }
    }
}