//! osm_pbf/serialize.rs — serialization of OSM PBF file blocks

1use std::{io::Cursor, pin::Pin};
2
3use async_compression::tokio::bufread::*;
4use futures::{Stream, StreamExt};
5use quick_protobuf::{MessageWrite, Writer};
6use thiserror::Error;
7use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
8
9use crate::{
10    protos::fileformat::{mod_Blob::OneOfdata as Data, Blob, BlobHeader},
11    FileBlock, BLOB_HEADER_MAX_LEN, BLOB_MAX_LEN, OSM_DATA_TYPE, OSM_HEADER_TYPE,
12};
13
/// Encoder to use when writing to a PBF file
///
/// This will only include encoders from enabled features.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Encoder {
    /// zlib (DEFLATE) compression
    #[cfg(feature = "zlib")]
    Zlib,
    /// Zstandard compression
    #[cfg(feature = "zstd")]
    Zstd,
    /// LZMA compression
    #[cfg(feature = "lzma")]
    Lzma,
}
26
/// Any error encountered in [serialize_osm_pbf]
#[derive(Error, Debug)]
pub enum SerializeError {
    /// An I/O failure from the underlying writer (or compressor stream)
    #[error(transparent)]
    Io(#[from] tokio::io::Error),
    /// A serialized [BlobHeader] exceeded [BLOB_HEADER_MAX_LEN] bytes
    #[error("BlobHeader is too long: {0} bytes")]
    BlobHeaderExceedsMaxLength(usize),
    /// A block's uncompressed payload exceeded [BLOB_MAX_LEN] bytes
    #[error("Blob is too long: {0} bytes")]
    BlobExceedsMaxLength(usize),
    /// quick-protobuf failed to encode a message
    #[error("Failed to serialize protobuf message: {0}")]
    Proto(#[from] quick_protobuf::Error),
}
39
40/// Applies the requested compression scheme, if any
41fn encode<R: AsyncBufRead + Unpin + Send + 'static>(
42    reader: R,
43    encoder: Option<Encoder>,
44) -> Pin<Box<dyn AsyncRead + Send>> {
45    if let Some(encoder) = encoder {
46        match encoder {
47            #[cfg(feature = "zlib")]
48            Encoder::Zlib => Box::pin(ZlibEncoder::new(reader)),
49            #[cfg(feature = "zstd")]
50            Encoder::Zstd => Box::pin(ZstdEncoder::new(reader)),
51            #[cfg(feature = "lzma")]
52            Encoder::Lzma => Box::pin(LzmaEncoder::new(reader)),
53        }
54    } else {
55        Box::pin(reader)
56    }
57}
58
/// Serialize a stream of [FileBlock]s in the PBF format
///
/// Each block is framed on the wire as: a 4-byte big-endian [BlobHeader]
/// length, the serialized [BlobHeader], then the serialized [Blob] carrying
/// the (optionally compressed) block payload.
///
/// # Errors
/// Returns [SerializeError::BlobExceedsMaxLength] or
/// [SerializeError::BlobHeaderExceedsMaxLength] when a block violates the
/// size limits, and propagates I/O and protobuf-encoding failures.
pub async fn serialize_osm_pbf<W: AsyncWrite + Unpin + Send>(
    mut blocks: impl Stream<Item = FileBlock> + Unpin + Send,
    mut out: W,
    encoder: Option<Encoder>,
) -> Result<(), SerializeError> {
    while let Some(block) = blocks.next().await {
        // Size of the payload *before* compression; also recorded in the
        // Blob's `raw_size` field when compression is applied.
        let raw_size = match &block {
            FileBlock::Header(header) => header.get_size(),
            FileBlock::Primitive(primitive) => primitive.get_size(),
            FileBlock::Other { bytes, .. } => bytes.len(),
        };
        // Enforce the size cap on the uncompressed payload up front, before
        // spending time serializing/compressing it.
        if raw_size > BLOB_MAX_LEN {
            return Err(SerializeError::BlobExceedsMaxLength(raw_size));
        }

        // Encode the block body to protobuf bytes and pick the BlobHeader
        // `type` string identifying the block kind.
        let (blob_bytes, type_pb) = match block {
            FileBlock::Header(header) => {
                (serialize_into_vec(&header)?, OSM_HEADER_TYPE.to_string())
            }
            FileBlock::Primitive(primitive) => {
                (serialize_into_vec(&primitive)?, OSM_DATA_TYPE.to_string())
            }
            // Unknown block kinds are passed through as-is with their
            // original type string.
            FileBlock::Other { r#type, bytes } => (bytes, r#type),
        };

        // Run the payload through the selected compressor (or leave it raw
        // when `encoder` is None) and collect the result into memory.
        let blob_encoded = {
            let cursor = Cursor::new(blob_bytes);
            let mut encoded = encode(cursor, encoder);
            let mut buf = vec![];
            encoded.read_to_end(&mut buf).await?;
            buf
        };

        let blob = Blob {
            // Only set if this is compressed
            raw_size: if encoder.is_some() {
                Some(raw_size as i32)
            } else {
                None
            },
            // Store the bytes in the oneof variant matching the encoder used.
            data: match encoder {
                #[cfg(feature = "zlib")]
                Some(Encoder::Zlib) => Data::zlib_data(blob_encoded),
                #[cfg(feature = "zstd")]
                Some(Encoder::Zstd) => Data::zstd_data(blob_encoded),
                #[cfg(feature = "lzma")]
                Some(Encoder::Lzma) => Data::lzma_data(blob_encoded),
                None => Data::raw(blob_encoded),
            },
        };

        let blob_header = BlobHeader {
            type_pb,
            indexdata: None,
            // Size in bytes of the serialized Blob that follows the header.
            datasize: blob.get_size() as i32,
        };

        let blob_header_size = blob_header.get_size();
        if blob_header_size > BLOB_HEADER_MAX_LEN {
            return Err(SerializeError::BlobHeaderExceedsMaxLength(blob_header_size));
        }

        // tokio's `write_i32` writes big-endian, matching the PBF framing.
        out.write_i32(blob_header_size as i32).await?;
        out.write_all(&serialize_into_vec(&blob_header)?).await?;
        out.write_all(&serialize_into_vec(&blob)?).await?;
    }
    Ok(())
}
128
129/// [quick_protobuf::writer::serialize_into_vec] but doesn't write length
130fn serialize_into_vec<M: MessageWrite>(message: &M) -> Result<Vec<u8>, quick_protobuf::Error> {
131    let len = message.get_size();
132    let mut v = Vec::with_capacity(len);
133
134    {
135        let mut writer = Writer::new(&mut v);
136        message.write_message(&mut writer)?;
137    }
138    Ok(v)
139}