1use std::{io::Cursor, pin::Pin};
2
3use async_compression::tokio::bufread::*;
4use futures::{Stream, StreamExt};
5use quick_protobuf::{MessageWrite, Writer};
6use thiserror::Error;
7use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
8
9use crate::{
10 protos::fileformat::{mod_Blob::OneOfdata as Data, Blob, BlobHeader},
11 FileBlock, BLOB_HEADER_MAX_LEN, BLOB_MAX_LEN, OSM_DATA_TYPE, OSM_HEADER_TYPE,
12};
13
/// Compression codec applied to blob payloads.
///
/// Each variant only exists when the matching cargo feature is enabled.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Encoder {
    /// zlib compression; the payload is stored as the Blob's `zlib_data`.
    #[cfg(feature = "zlib")]
    Zlib,
    /// Zstandard compression; the payload is stored as the Blob's `zstd_data`.
    #[cfg(feature = "zstd")]
    Zstd,
    /// LZMA compression; the payload is stored as the Blob's `lzma_data`.
    #[cfg(feature = "lzma")]
    Lzma,
}
26
27#[derive(Error, Debug)]
29pub enum SerializeError {
30 #[error(transparent)]
31 Io(#[from] tokio::io::Error),
32 #[error("BlobHeader is too long: {0} bytes")]
33 BlobHeaderExceedsMaxLength(usize),
34 #[error("Blob is too long: {0} bytes")]
35 BlobExceedsMaxLength(usize),
36 #[error("Failed to serialize protobuf message: {0}")]
37 Proto(#[from] quick_protobuf::Error),
38}
39
40fn encode<R: AsyncBufRead + Unpin + Send + 'static>(
42 reader: R,
43 encoder: Option<Encoder>,
44) -> Pin<Box<dyn AsyncRead + Send>> {
45 if let Some(encoder) = encoder {
46 match encoder {
47 #[cfg(feature = "zlib")]
48 Encoder::Zlib => Box::pin(ZlibEncoder::new(reader)),
49 #[cfg(feature = "zstd")]
50 Encoder::Zstd => Box::pin(ZstdEncoder::new(reader)),
51 #[cfg(feature = "lzma")]
52 Encoder::Lzma => Box::pin(LzmaEncoder::new(reader)),
53 }
54 } else {
55 Box::pin(reader)
56 }
57}
58
59pub async fn serialize_osm_pbf<W: AsyncWrite + Unpin + Send>(
61 mut blocks: impl Stream<Item = FileBlock> + Unpin + Send,
62 mut out: W,
63 encoder: Option<Encoder>,
64) -> Result<(), SerializeError> {
65 while let Some(block) = blocks.next().await {
66 let raw_size = match &block {
67 FileBlock::Header(header) => header.get_size(),
68 FileBlock::Primitive(primitive) => primitive.get_size(),
69 FileBlock::Other { bytes, .. } => bytes.len(),
70 };
71 if raw_size > BLOB_MAX_LEN {
72 return Err(SerializeError::BlobExceedsMaxLength(raw_size));
73 }
74
75 let (blob_bytes, type_pb) = match block {
76 FileBlock::Header(header) => {
77 (serialize_into_vec(&header)?, OSM_HEADER_TYPE.to_string())
78 }
79 FileBlock::Primitive(primitive) => {
80 (serialize_into_vec(&primitive)?, OSM_DATA_TYPE.to_string())
81 }
82 FileBlock::Other { r#type, bytes } => (bytes, r#type),
83 };
84
85 let blob_encoded = {
86 let cursor = Cursor::new(blob_bytes);
87 let mut encoded = encode(cursor, encoder);
88 let mut buf = vec![];
89 encoded.read_to_end(&mut buf).await?;
90 buf
91 };
92
93 let blob = Blob {
94 raw_size: if encoder.is_some() {
96 Some(raw_size as i32)
97 } else {
98 None
99 },
100 data: match encoder {
101 #[cfg(feature = "zlib")]
102 Some(Encoder::Zlib) => Data::zlib_data(blob_encoded),
103 #[cfg(feature = "zstd")]
104 Some(Encoder::Zstd) => Data::zstd_data(blob_encoded),
105 #[cfg(feature = "lzma")]
106 Some(Encoder::Lzma) => Data::lzma_data(blob_encoded),
107 None => Data::raw(blob_encoded),
108 },
109 };
110
111 let blob_header = BlobHeader {
112 type_pb,
113 indexdata: None,
114 datasize: blob.get_size() as i32,
115 };
116
117 let blob_header_size = blob_header.get_size();
118 if blob_header_size > BLOB_HEADER_MAX_LEN {
119 return Err(SerializeError::BlobHeaderExceedsMaxLength(blob_header_size));
120 }
121
122 out.write_i32(blob_header_size as i32).await?;
123 out.write_all(&serialize_into_vec(&blob_header)?).await?;
124 out.write_all(&serialize_into_vec(&blob)?).await?;
125 }
126 Ok(())
127}
128
129fn serialize_into_vec<M: MessageWrite>(message: &M) -> Result<Vec<u8>, quick_protobuf::Error> {
131 let len = message.get_size();
132 let mut v = Vec::with_capacity(len);
133
134 {
135 let mut writer = Writer::new(&mut v);
136 message.write_message(&mut writer)?;
137 }
138 Ok(v)
139}