// geodatafusion_geojson/lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![cfg_attr(not(test), warn(unused_crate_dependencies))]
3#![doc(
4    html_logo_url = "https://github.com/geoarrow.png",
5    html_favicon_url = "https://github.com/geoarrow.png?size=32"
6)]
7
8pub mod file_format;
9
10pub use file_format::{GeoJsonFileFactory, GeoJsonFormat, GeoJsonFormatFactory};
11
12#[cfg(test)]
13mod tests {
14    use std::env::temp_dir;
15    use std::fs;
16    use std::sync::Arc;
17
18    use arrow_array::{Int32Array, RecordBatch, StringArray};
19    use arrow_schema::{DataType, Field, Schema};
20    use datafusion::catalog::MemTable;
21    use datafusion::execution::SessionStateBuilder;
22    use datafusion::prelude::SessionContext;
23    use geoarrow_array::GeoArrowArray;
24    use geoarrow_array::builder::PointBuilder;
25    use geoarrow_schema::{Dimension, PointType};
26    use wkt::wkt;
27
28    use super::*;
29
30    fn sample_table() -> (Vec<RecordBatch>, Arc<Schema>) {
31        let mut builder = PointBuilder::new(PointType::new(Dimension::XY, Default::default()));
32        builder.push_point(Some(&wkt!(POINT(1.0 2.0))));
33        builder.push_point(Some(&wkt!(POINT(3.0 4.0))));
34        let geometry = builder.finish();
35
36        let fields = vec![
37            Arc::new(Field::new("id", DataType::Int32, false)),
38            Arc::new(Field::new("name", DataType::Utf8, false)),
39            Arc::new(geometry.data_type().to_field("geometry", true)),
40        ];
41        let schema = Arc::new(Schema::new(fields));
42        let batch = RecordBatch::try_new(
43            schema.clone(),
44            vec![
45                Arc::new(Int32Array::from(vec![1, 2])) as _,
46                Arc::new(StringArray::from(vec!["Point A", "Point B"])) as _,
47                geometry.into_array_ref(),
48            ],
49        )
50        .unwrap();
51
52        (vec![batch], schema)
53    }
54
55    #[tokio::test]
56    async fn test_write_geojsonlines_sink() {
57        let file_format = Arc::new(GeoJsonFileFactory::new()); // Lines format
58        let state = SessionStateBuilder::new()
59            .with_file_formats(vec![file_format])
60            .build();
61        let ctx = SessionContext::new_with_state(state);
62
63        let (batches, schema) = sample_table();
64        let mem_table = Arc::new(MemTable::try_new(schema.clone(), vec![batches]).unwrap());
65        ctx.register_table("mem_table", mem_table).unwrap();
66
67        let file_path = temp_dir().join("test_geojsonlines_sink.geojsonl");
68
69        ctx.sql(&format!("COPY mem_table TO '{}';", file_path.display()))
70            .await
71            .unwrap()
72            .collect()
73            .await
74            .unwrap();
75
76        // Read the file and verify it's valid GeoJSON Lines
77        let contents = fs::read_to_string(&file_path).unwrap();
78        let lines: Vec<&str> = contents.lines().collect();
79
80        // Should have 2 lines for 2 features
81        assert_eq!(lines.len(), 2);
82
83        // Parse each line as a GeoJSON Feature
84        for (i, line) in lines.iter().enumerate() {
85            let feature: geojson::Feature = line.parse().unwrap();
86            let expected_id = i + 1;
87            let expected_name = format!("Point {}", if i == 0 { "A" } else { "B" });
88
89            assert_eq!(
90                feature.id,
91                Some(geojson::feature::Id::Number(serde_json::Number::from(
92                    expected_id
93                )))
94            );
95            assert_eq!(
96                feature.properties.as_ref().unwrap().get("name").unwrap(),
97                &serde_json::Value::String(expected_name)
98            );
99
100            // Check geometry
101            if let Some(ref geometry) = feature.geometry {
102                if let geojson::Value::Point(coords) = &geometry.value {
103                    let expected_x = if i == 0 { 1.0 } else { 3.0 };
104                    let expected_y = if i == 0 { 2.0 } else { 4.0 };
105                    assert_eq!(coords[0], expected_x);
106                    assert_eq!(coords[1], expected_y);
107                } else {
108                    panic!("Expected Point geometry");
109                }
110            } else {
111                panic!("Expected geometry");
112            }
113        }
114
115        fs::remove_file(&file_path).unwrap();
116    }
117
118    #[tokio::test]
119    async fn test_write_geojson_with_id_column() {
120        // Test with explicit ID column
121        let mut builder = PointBuilder::new(PointType::new(Dimension::XY, Default::default()));
122        builder.push_point(Some(&wkt!(POINT(5.0 6.0))));
123        let geometry = builder.finish();
124
125        let fields = vec![
126            Arc::new(Field::new("id", DataType::Int32, false)),
127            Arc::new(geometry.data_type().to_field("geometry", true)),
128            Arc::new(Field::new("value", DataType::Int32, false)),
129        ];
130        let schema = Arc::new(Schema::new(fields));
131        let batch = RecordBatch::try_new(
132            schema.clone(),
133            vec![
134                Arc::new(Int32Array::from(vec![42])) as _,
135                geometry.into_array_ref(),
136                Arc::new(Int32Array::from(vec![100])) as _,
137            ],
138        )
139        .unwrap();
140
141        let file_format = Arc::new(GeoJsonFileFactory::new());
142        let state = SessionStateBuilder::new()
143            .with_file_formats(vec![file_format])
144            .build();
145        let ctx = SessionContext::new_with_state(state);
146
147        let mem_table = Arc::new(MemTable::try_new(schema, vec![vec![batch]]).unwrap());
148        ctx.register_table("mem_table", mem_table).unwrap();
149
150        let file_path = temp_dir().join("test_geojson_with_id.geojsonl");
151
152        ctx.sql(&format!("COPY mem_table TO '{}';", file_path.display()))
153            .await
154            .unwrap()
155            .collect()
156            .await
157            .unwrap();
158
159        // Read and validate - should be a single line since it's one feature
160        let contents = fs::read_to_string(&file_path).unwrap();
161        let lines: Vec<&str> = contents.lines().collect();
162        assert_eq!(lines.len(), 1);
163
164        // Parse the single line as a GeoJSON Feature
165        let feature: geojson::Feature = lines[0].parse().unwrap();
166
167        // Check the id field is at feature level
168        assert_eq!(
169            feature.id,
170            Some(geojson::feature::Id::Number(serde_json::Number::from(42)))
171        );
172
173        // Check properties only contains "value", not "id"
174        let props = feature.properties.as_ref().unwrap();
175        assert_eq!(
176            props.get("value").unwrap(),
177            &serde_json::Value::Number(serde_json::Number::from(100))
178        );
179        assert!(props.get("id").is_none()); // id should not be in properties
180
181        fs::remove_file(&file_path).unwrap();
182    }
183}