//! cityjson-arrow 0.6.1
//!
//! Arrow IPC and Parquet transport for CityJSON 2.0 city models.
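//!
//! Stream container layout (version 3):
//!
//! 1. magic bytes `CITYJSON_ARROW_STREAM_V3\0`
//! 2. a little-endian `u64` prelude length, then a JSON prelude carrying the
//!    [`CityArrowHeader`] and [`ProjectionLayout`]
//! 3. zero or more frames, each a table tag byte, a little-endian `u64` row
//!    count, and a self-contained Arrow IPC stream holding one batch
//! 4. the end tag `0xFF`
//!
//! A minimal round trip, assuming a `parts: CityModelArrowParts` value is
//! already in scope (hypothetical setup):
//!
//! ```ignore
//! let mut buf = Vec::new();
//! write_parts_stream(&parts, &mut buf)?;
//! let decoded = read_parts_stream(buf.as_slice())?;
//! ```
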
use crate::codec::WriteReport;
use crate::convert::emit_tables;
use crate::error::{Error, Result};
use crate::internal::{build_parts_from_tables, emit_part_tables};
use crate::schema::{CityArrowHeader, CityModelArrowParts, ProjectionLayout, canonical_schema_set};
use crate::transport::{
    CanonicalTable, CanonicalTableSink, schema_for_table, single_or_concat_batches, validate_schema,
};
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;
use cityjson::relational::ModelRelationalView;
use serde::{Deserialize, Serialize};
use std::io::{ErrorKind, Read, Write};

/// Magic bytes identifying version 3 of the cityjson-arrow stream container.
const STREAM_MAGIC: &[u8] = b"CITYJSON_ARROW_STREAM_V3\0";
/// Sentinel tag (`u8::MAX`) marking the end of the frame sequence.
const STREAM_END_TAG: u8 = u8::MAX;

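/// JSON prelude written once after the magic bytes; it carries the model
/// header and the projection layout from which the canonical table schemas
/// are derived on the reading side.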
#[derive(Debug, Clone, Serialize, Deserialize)]
struct StreamPrelude {
    header: CityArrowHeader,
    projection: ProjectionLayout,
}

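/// Frames in stream order: the table identity, the declared row count, and
/// the decoded batch.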
type StreamFrames = Vec<(CanonicalTable, usize, RecordBatch)>;

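/// Writes a relational view of a city model as a tagged frame stream and
/// reports the batch, row, and payload-byte totals.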
pub(crate) fn write_model_stream<W: Write>(
    relational: &ModelRelationalView<'_>,
    writer: W,
) -> Result<WriteReport> {
    let mut sink = StreamSink::new(writer);
    emit_tables(relational, &mut sink)?;
    sink.finish()
}

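/// Writes pre-built Arrow parts as a tagged frame stream; the write report
/// is computed but deliberately discarded.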
pub(crate) fn write_parts_stream<W: Write>(parts: &CityModelArrowParts, writer: W) -> Result<()> {
    let mut sink = StreamSink::new(writer);
    emit_part_tables(parts, &mut sink)?;
    let _ = sink.finish()?;
    Ok(())
}

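/// Reads a complete parts stream, checking that each frame decodes exactly
/// the row count it declared before the parts are rebuilt.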
pub(crate) fn read_parts_stream<R: Read>(mut reader: R) -> Result<CityModelArrowParts> {
    let (prelude, tables) = read_stream_frames(&mut reader)?;
    let ordered_tables = tables
        .into_iter()
        .map(|(table, expected_rows, batch)| {
            if batch.num_rows() == expected_rows {
                Ok((table, batch))
            } else {
                Err(Error::Conversion(format!(
                    "{} frame declared {expected_rows} rows but decoded {} rows",
                    table.as_str(),
                    batch.num_rows()
                )))
            }
        })
        .collect::<Result<Vec<_>>>()?;
    build_parts_from_tables(&prelude.header, &prelude.projection, ordered_tables)
}

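/// [`CanonicalTableSink`] that frames table batches onto a byte stream while
/// tracking the batch, row, and byte totals for the final [`WriteReport`].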
struct StreamSink<W> {
    writer: CountingWriter<W>,
    started: bool,
    batch_count: usize,
    row_count: usize,
}

impl<W> StreamSink<W> {
    fn new(writer: W) -> Self {
        Self {
            writer: CountingWriter::new(writer),
            started: false,
            batch_count: 0,
            row_count: 0,
        }
    }
}

impl<W: Write> StreamSink<W> {
    fn finish(&mut self) -> Result<WriteReport> {
        // Only a started stream has a prologue to terminate; an unstarted
        // sink writes nothing and reports zero batches and rows.
        if self.started {
            self.writer.write_all(&[STREAM_END_TAG])?;
        }
        Ok(WriteReport {
            batch_count: self.batch_count,
            row_count: self.row_count,
            payload_bytes: self.writer.bytes_written(),
        })
    }
}

impl<W: Write> CanonicalTableSink for StreamSink<W> {
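    // The prologue: magic bytes, a little-endian prelude length, then the
    // JSON prelude itself.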
    fn start(&mut self, header: &CityArrowHeader, projection: &ProjectionLayout) -> Result<()> {
        let prelude_bytes = serde_json::to_vec(&StreamPrelude {
            header: header.clone(),
            projection: projection.clone(),
        })?;
        let prelude_len = u64::try_from(prelude_bytes.len())
            .map_err(|_| Error::Conversion("stream prelude length overflow".to_string()))?;

        self.writer.write_all(STREAM_MAGIC)?;
        self.writer.write_all(&prelude_len.to_le_bytes())?;
        self.writer.write_all(&prelude_bytes)?;
        self.started = true;
        Ok(())
    }

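    // Each frame is one tag byte, a little-endian u64 row count, and a
    // self-contained Arrow IPC stream holding the batch.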
    fn push_batch(&mut self, table: CanonicalTable, batch: RecordBatch) -> Result<()> {
        if !self.started {
            return Err(Error::Unsupported(
                "stream sink must be started before writing table batches".to_string(),
            ));
        }
        self.writer.write_all(&[table.stream_tag()])?;
        self.writer.write_all(
            &u64::try_from(batch.num_rows())
                .map_err(|_| Error::Conversion("stream row count overflow".to_string()))?
                .to_le_bytes(),
        )?;
        write_stream_batch(&mut self.writer, &batch)?;
        self.batch_count += 1;
        self.row_count += batch.num_rows();
        Ok(())
    }
}

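/// Reads and validates the magic bytes, then the length-prefixed JSON
/// prelude.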
fn read_stream_prelude<R: Read>(reader: &mut R) -> Result<StreamPrelude> {
    let mut magic = vec![0_u8; STREAM_MAGIC.len()];
    reader.read_exact(&mut magic)?;
    if magic != STREAM_MAGIC {
        return Err(Error::Unsupported(
            "stream header magic is invalid".to_string(),
        ));
    }

    let prelude_len = usize::try_from(read_u64(reader)?).map_err(|_| {
        Error::Conversion("stream prelude length does not fit in memory".to_string())
    })?;
    // The declared length is trusted as-is, so a malformed or hostile stream
    // can request an arbitrarily large allocation here.
    let mut prelude_bytes = vec![0_u8; prelude_len];
    reader.read_exact(&mut prelude_bytes)?;
    serde_json::from_slice(&prelude_bytes).map_err(Error::from)
}

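/// Reads every frame from a stream, returning the header, projection, and
/// the raw `(table, declared rows, batch)` triples without row validation.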
pub(crate) fn read_stream_batches<R: Read>(
    mut reader: R,
) -> Result<(CityArrowHeader, ProjectionLayout, StreamFrames)> {
    let (prelude, tables) = read_stream_frames(&mut reader)?;
    Ok((prelude.header, prelude.projection, tables))
}

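/// Core frame loop: after the prelude, frames are consumed until the end
/// tag, and each batch is decoded against its table's canonical schema.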
fn read_stream_frames<R: Read>(reader: &mut R) -> Result<(StreamPrelude, StreamFrames)> {
    let prelude = read_stream_prelude(reader)?;
    let schemas = canonical_schema_set(&prelude.projection);
    let mut tables = Vec::new();
    loop {
        let tag = read_u8(reader)?;
        if tag == STREAM_END_TAG {
            break;
        }
        let table = CanonicalTable::from_stream_tag(tag)?;
        let expected_rows = usize::try_from(read_u64(reader)?).map_err(|_| {
            Error::Conversion("stream row count does not fit in memory".to_string())
        })?;
        let batch = deserialize_stream_batch(reader, schema_for_table(&schemas, table), table)?;
        tables.push((table, expected_rows, batch));
    }
    Ok((prelude, tables))
}

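/// Reads a single frame tag; an EOF here means the stream was truncated
/// before the end-tag marker, which gets a dedicated error message.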
fn read_u8<R: Read>(reader: &mut R) -> Result<u8> {
    let mut byte = [0_u8; 1];
    reader.read_exact(&mut byte).map_err(|error| {
        if error.kind() == ErrorKind::UnexpectedEof {
            Error::Unsupported("stream ended before the final frame marker".to_string())
        } else {
            Error::from(error)
        }
    })?;
    Ok(byte[0])
}

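/// Reads a little-endian `u64` length or row-count field.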
fn read_u64<R: Read>(reader: &mut R) -> Result<u64> {
    let mut bytes = [0_u8; 8];
    reader.read_exact(&mut bytes)?;
    Ok(u64::from_le_bytes(bytes))
}

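/// Serializes one batch as a complete Arrow IPC stream (schema message,
/// batch, end-of-stream marker) so that each frame can be decoded
/// independently.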
fn write_stream_batch<W: Write>(writer: &mut W, batch: &RecordBatch) -> Result<()> {
    let mut stream = StreamWriter::try_new(writer, &batch.schema())?;
    stream.write(batch)?;
    stream.finish()?;
    Ok(())
}

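/// Decodes one embedded Arrow IPC stream, validates its schema against the
/// canonical schema for `table`, and collapses its batches into a single
/// [`RecordBatch`].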
fn deserialize_stream_batch<R: Read>(
    reader: &mut R,
    expected_schema: &arrow::datatypes::SchemaRef,
    table: CanonicalTable,
) -> Result<RecordBatch> {
    // The reader must consume exactly the bytes of the embedded IPC stream;
    // buffering here would swallow the start of the following frame.
    let mut stream = StreamReader::try_new(reader.by_ref(), None)?;
    let schema = stream.schema();
    validate_schema(expected_schema, &schema, table)?;
    let batch = single_or_concat_batches(expected_schema, &mut stream)?;
    Ok(batch)
}

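/// `Write` adapter that counts the bytes accepted by the inner writer,
/// supplying the payload size for the [`WriteReport`].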
struct CountingWriter<W> {
    inner: W,
    bytes_written: u64,
}

impl<W> CountingWriter<W> {
    const fn new(inner: W) -> Self {
        Self {
            inner,
            bytes_written: 0,
        }
    }

    const fn bytes_written(&self) -> u64 {
        self.bytes_written
    }
}

impl<W: Write> Write for CountingWriter<W> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let written = self.inner.write(buf)?;
        // `usize` always fits in `u64` on supported targets, so the
        // conversion cannot fail.
        self.bytes_written += u64::try_from(written).expect("write count fits into u64");
        Ok(written)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.inner.flush()
    }
}