use arrow_array::RecordBatch;
use arrow_schema::SchemaRef;
use crate::Gpkg;
use crate::error::{GpkgError, Result};
use crate::gpkg::GpkgAttributeTable;
use crate::types::{ColumnSpec, ColumnType};
/// Writes Arrow [`RecordBatch`]es into a GeoPackage attribute table.
///
/// The table is not created up front: the first call to `write` derives the column
/// definitions from that batch's schema, creates the table, and builds the INSERT
/// statement reused for every subsequent batch.
pub struct ArrowGpkgAttributeWriter<'a> {
    // GeoPackage handle the table is created in and rows are inserted through.
    gpkg: &'a Gpkg,
    // Name of the attribute table to create and append to.
    table_name: String,
    // `None` until the first batch is written; afterwards holds the INSERT SQL and the
    // batch-column-to-parameter mapping derived from that first batch's schema.
    state: Option<AttributeWriterState>,
}
/// Per-table state computed once, from the schema of the first batch written.
struct AttributeWriterState {
    /// INSERT statement for the table, produced by `GpkgAttributeTable::build_insert_sql`.
    insert_sql: String,
    /// Indices of the batch columns whose values are bound, in this order, as the INSERT
    /// statement's parameters.
    col_indices: Vec<usize>,
}
impl<'a> ArrowGpkgAttributeWriter<'a> {
pub fn new(gpkg: &'a Gpkg, table_name: &str) -> Result<Self> {
Ok(Self {
gpkg,
table_name: table_name.to_string(),
state: None,
})
}
pub fn write(&mut self, batch: &RecordBatch) -> Result<()> {
if self.state.is_none() {
self.initialize(batch.schema())?;
}
self.write_batch(batch)
}
fn initialize(&mut self, schema: SchemaRef) -> Result<()> {
let property_columns = build_property_columns(&schema)?;
self.gpkg
.create_attribute_table(&self.table_name, &property_columns)?;
let insert_sql = GpkgAttributeTable::build_insert_sql(&self.table_name, &property_columns);
let col_indices: Vec<usize> = (0..schema.fields().len()).collect();
self.state = Some(AttributeWriterState {
insert_sql,
col_indices,
});
Ok(())
}
fn write_batch(&self, batch: &RecordBatch) -> Result<()> {
let state = self
.state
.as_ref()
.expect("initialize must be called first");
let mut stmt = self.gpkg.conn.prepare_cached(&state.insert_sql)?;
let num_params = state.col_indices.len();
let mut params: Vec<rusqlite::types::Value> = Vec::with_capacity(num_params);
for row_idx in 0..batch.num_rows() {
params.clear();
for &col_idx in &state.col_indices {
let array = batch.column(col_idx);
params.push(extract_value(array, row_idx)?);
}
stmt.execute(rusqlite::params_from_iter(¶ms))?;
}
Ok(())
}
}
/// Maps every field of `schema` to a GeoPackage column specification, preserving order.
///
/// # Errors
///
/// Fails on the first field whose Arrow data type has no GeoPackage column mapping.
fn build_property_columns(schema: &SchemaRef) -> Result<Vec<ColumnSpec>> {
    schema
        .fields()
        .iter()
        .map(|field| -> Result<ColumnSpec> {
            let column_type = arrow_type_to_column_type(field.data_type())?;
            Ok(ColumnSpec {
                name: field.name().clone(),
                column_type,
            })
        })
        .collect()
}
/// Chooses the GeoPackage column type used to store values of Arrow type `dt`.
///
/// Booleans map to `Boolean`, every signed and unsigned integer width maps to `Integer`,
/// both float widths map to `Double`, all UTF-8 string layouts map to `Varchar`, and
/// all binary layouts map to `Blob`.
///
/// # Errors
///
/// Returns `GpkgError::GeoArrow` for any other Arrow data type.
fn arrow_type_to_column_type(dt: &arrow_schema::DataType) -> Result<ColumnType> {
    use arrow_schema::DataType as Dt;

    let column_type = match dt {
        Dt::Boolean => ColumnType::Boolean,
        Dt::Int8
        | Dt::Int16
        | Dt::Int32
        | Dt::Int64
        | Dt::UInt8
        | Dt::UInt16
        | Dt::UInt32
        | Dt::UInt64 => ColumnType::Integer,
        Dt::Float32 | Dt::Float64 => ColumnType::Double,
        Dt::Utf8 | Dt::LargeUtf8 | Dt::Utf8View => ColumnType::Varchar,
        Dt::Binary | Dt::LargeBinary | Dt::BinaryView => ColumnType::Blob,
        other => {
            return Err(GpkgError::GeoArrow(format!(
                "Unsupported Arrow data type for GeoPackage column: {other:?}"
            )));
        }
    };
    Ok(column_type)
}
fn extract_value(array: &dyn arrow_array::Array, row_idx: usize) -> Result<rusqlite::types::Value> {
if array.is_null(row_idx) {
return Ok(rusqlite::types::Value::Null);
}
if let Some(a) = array.as_any().downcast_ref::<arrow_array::BooleanArray>() {
return Ok(rusqlite::types::Value::Integer(if a.value(row_idx) {
1
} else {
0
}));
}
if let Some(a) = array.as_any().downcast_ref::<arrow_array::Int8Array>() {
return Ok(rusqlite::types::Value::Integer(a.value(row_idx) as i64));
}
if let Some(a) = array.as_any().downcast_ref::<arrow_array::Int16Array>() {
return Ok(rusqlite::types::Value::Integer(a.value(row_idx) as i64));
}
if let Some(a) = array.as_any().downcast_ref::<arrow_array::Int32Array>() {
return Ok(rusqlite::types::Value::Integer(a.value(row_idx) as i64));
}
if let Some(a) = array.as_any().downcast_ref::<arrow_array::Int64Array>() {
return Ok(rusqlite::types::Value::Integer(a.value(row_idx)));
}
if let Some(a) = array.as_any().downcast_ref::<arrow_array::UInt8Array>() {
return Ok(rusqlite::types::Value::Integer(a.value(row_idx) as i64));
}
if let Some(a) = array.as_any().downcast_ref::<arrow_array::UInt16Array>() {
return Ok(rusqlite::types::Value::Integer(a.value(row_idx) as i64));
}
if let Some(a) = array.as_any().downcast_ref::<arrow_array::UInt32Array>() {
return Ok(rusqlite::types::Value::Integer(a.value(row_idx) as i64));
}
if let Some(a) = array.as_any().downcast_ref::<arrow_array::UInt64Array>() {
return Ok(rusqlite::types::Value::Integer(a.value(row_idx) as i64));
}
if let Some(a) = array.as_any().downcast_ref::<arrow_array::Float32Array>() {
return Ok(rusqlite::types::Value::Real(a.value(row_idx) as f64));
}
if let Some(a) = array.as_any().downcast_ref::<arrow_array::Float64Array>() {
return Ok(rusqlite::types::Value::Real(a.value(row_idx)));
}
if let Some(a) = array.as_any().downcast_ref::<arrow_array::StringArray>() {
return Ok(rusqlite::types::Value::Text(a.value(row_idx).to_string()));
}
if let Some(a) = array
.as_any()
.downcast_ref::<arrow_array::LargeStringArray>()
{
return Ok(rusqlite::types::Value::Text(a.value(row_idx).to_string()));
}
if let Some(a) = array
.as_any()
.downcast_ref::<arrow_array::StringViewArray>()
{
return Ok(rusqlite::types::Value::Text(a.value(row_idx).to_string()));
}
if let Some(a) = array.as_any().downcast_ref::<arrow_array::BinaryArray>() {
return Ok(rusqlite::types::Value::Blob(a.value(row_idx).to_vec()));
}
if let Some(a) = array
.as_any()
.downcast_ref::<arrow_array::LargeBinaryArray>()
{
return Ok(rusqlite::types::Value::Blob(a.value(row_idx).to_vec()));
}
if let Some(a) = array
.as_any()
.downcast_ref::<arrow_array::BinaryViewArray>()
{
return Ok(rusqlite::types::Value::Blob(a.value(row_idx).to_vec()));
}
Err(GpkgError::GeoArrow(format!(
"Unsupported Arrow array type: {:?}",
array.data_type()
)))
}
#[cfg(all(test, feature = "arrow"))]
mod tests {
    use super::ArrowGpkgAttributeWriter;
    use crate::Result;
    use crate::arrow::attribute_reader::ArrowGpkgAttributeReader;
    use crate::gpkg::Gpkg;
    use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Writes one two-column batch and checks the attribute reader returns the same values.
    #[test]
    fn write_and_read_back_attribute_table() -> Result<()> {
        let gpkg = Gpkg::open_in_memory()?;

        let schema = Arc::new(Schema::new(vec![
            Arc::new(Field::new("name", DataType::Utf8, true)),
            Arc::new(Field::new("value", DataType::Int64, true)),
        ]));
        let columns: Vec<ArrayRef> = vec![
            Arc::new(StringArray::from(vec!["alpha", "beta"])),
            Arc::new(Int64Array::from(vec![10, 20])),
        ];
        let input = RecordBatch::try_new(schema, columns).expect("valid batch");

        let mut attr_writer = ArrowGpkgAttributeWriter::new(&gpkg, "test_attrs")?;
        attr_writer.write(&input)?;

        let mut attr_reader = ArrowGpkgAttributeReader::new(&gpkg, "test_attrs", 100)?;
        let output = attr_reader.next().unwrap()?;
        assert_eq!((output.num_rows(), output.num_columns()), (2, 2));

        let name_col = output
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!([name_col.value(0), name_col.value(1)], ["alpha", "beta"]);

        let value_col = output
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!([value_col.value(0), value_col.value(1)], [10, 20]);

        Ok(())
    }

    /// Writes three single-row batches through one writer and checks all rows land.
    #[test]
    fn write_multiple_batches_attribute_table() -> Result<()> {
        let gpkg = Gpkg::open_in_memory()?;
        let schema = Arc::new(Schema::new(vec![Arc::new(Field::new(
            "rank",
            DataType::Int64,
            true,
        ))]));

        let mut attr_writer = ArrowGpkgAttributeWriter::new(&gpkg, "multi_batch_attrs")?;
        for rank in 0_i64..3 {
            let columns: Vec<ArrayRef> = vec![Arc::new(Int64Array::from(vec![rank]))];
            let single_row =
                RecordBatch::try_new(Arc::clone(&schema), columns).expect("valid batch");
            attr_writer.write(&single_row)?;
        }

        let stored = gpkg.get_attribute_table("multi_batch_attrs")?;
        let stored_rows = stored.rows()?;
        assert_eq!(stored_rows.len(), 3);
        Ok(())
    }
}