use crate::{
de::{
val_reader::{AsciiValReader, BinValReader, ScalarReader},
RowDeserializer,
},
DeserializeError, ElementDef, PlyFormat, PlyHeader,
};
use byteorder::{BigEndian, LittleEndian};
use serde::{
de::{DeserializeSeed, Error, SeqAccess, Visitor},
Deserialize, Deserializer,
};
use std::{io::Cursor, marker::PhantomData};
/// Incremental PLY reader that decodes a file from bytes pushed in piece by piece.
///
/// Append raw bytes through [`PlyChunkedReader::buffer_mut`], then call
/// [`PlyChunkedReader::next_chunk`] to deserialize the rows of the current element
/// that are fully present in the buffer. Bytes that have been consumed (the header,
/// decoded rows) are removed from the front of the buffer.
pub struct PlyChunkedReader {
    /// Bytes received but not yet consumed by header or row parsing.
    data_buffer: Vec<u8>,
    /// Cached header; stays `None` until `PlyHeader::parse` succeeds on the buffer.
    header: Option<PlyHeader>,
    /// Index into the header's `elem_defs` of the element currently being read.
    current_element_index: usize,
    /// Rows of the current element already handed out; reset to 0 per element.
    rows_parsed: usize,
}
impl PlyChunkedReader {
pub fn new() -> Self {
Self {
header: None,
current_element_index: 0,
rows_parsed: 0,
data_buffer: Vec::new(),
}
}
pub fn buffer_mut(&mut self) -> &mut Vec<u8> {
&mut self.data_buffer
}
pub fn header(&mut self) -> Option<&PlyHeader> {
if self.header.is_none() {
let available_data = &self.data_buffer;
let mut cursor = Cursor::new(available_data);
let header = PlyHeader::parse(&mut cursor);
if let Ok(header) = header {
self.header = Some(header);
self.data_buffer.drain(..cursor.position() as usize);
}
}
self.header.as_ref()
}
pub fn next_chunk<T>(&mut self) -> Result<T, DeserializeError>
where
T: for<'de> Deserialize<'de>,
{
T::deserialize(self)
}
pub fn current_element(&mut self) -> Option<&ElementDef> {
let ind = self.current_element_index;
self.header().and_then(|e| e.elem_defs.get(ind))
}
pub fn rows_done(&self) -> usize {
self.rows_parsed
}
}
impl<'de> Deserializer<'de> for &'_ mut PlyChunkedReader {
type Error = DeserializeError;
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
self.deserialize_seq(visitor)
}
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: serde::de::Visitor<'de>,
{
let _ = self.header();
let Some(header) = &self.header else {
return visitor.visit_seq(EmptySeq);
};
if self.current_element_index >= header.elem_defs.len() {
return Err(DeserializeError::custom("Ran out of elements"));
}
let elem_def = &header.elem_defs[self.current_element_index];
let mut cursor = Cursor::new(&self.data_buffer);
let remaining = elem_def.count - self.rows_parsed;
let (res, rows_remaining) = match header.format {
PlyFormat::Ascii => {
let mut seq = ChunkPlyReaderSeqVisitor {
remaining,
row: RowDeserializer::<_, AsciiValReader>::new(
&mut cursor,
&elem_def.properties,
),
};
let res = visitor.visit_seq(&mut seq)?;
(res, seq.remaining)
}
PlyFormat::BinaryLittleEndian => {
let mut seq = ChunkPlyReaderSeqVisitor {
remaining,
row: RowDeserializer::<_, BinValReader<LittleEndian>>::new(
&mut cursor,
&elem_def.properties,
),
};
let res = visitor.visit_seq(&mut seq)?;
(res, seq.remaining)
}
PlyFormat::BinaryBigEndian => {
let mut seq = ChunkPlyReaderSeqVisitor {
remaining,
row: RowDeserializer::<_, BinValReader<BigEndian>>::new(
&mut cursor,
&elem_def.properties,
),
};
let res = visitor.visit_seq(&mut seq)?;
(res, seq.remaining)
}
};
self.rows_parsed = elem_def.count - rows_remaining;
self.data_buffer.drain(..cursor.position() as usize);
if self.rows_parsed >= elem_def.count {
self.rows_parsed = 0;
self.current_element_index += 1;
}
Ok(res)
}
fn deserialize_newtype_struct<V>(
self,
_name: &'static str,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}
serde::forward_to_deserialize_any! {
bool i8 u8 i16 u16 i32 u32 f32 f64 i128 i64 u128 u64 char str string
bytes byte_buf unit unit_struct tuple
tuple_struct map struct enum identifier ignored_any option
}
}
struct EmptySeq;
impl<'de> SeqAccess<'de> for EmptySeq {
type Error = DeserializeError;
fn next_element_seed<T>(&mut self, _seed: T) -> Result<Option<T::Value>, Self::Error>
where
T: serde::de::DeserializeSeed<'de>,
{
Ok(None)
}
}
/// `SeqAccess` adapter that yields rows of a single element from a byte cursor.
struct ChunkPlyReaderSeqVisitor<'a, D: AsRef<[u8]>, S: ScalarReader> {
    /// Row decoder over the shared cursor; `S` picks ASCII or binary scalar decoding.
    row: RowDeserializer<'a, Cursor<D>, S>,
    /// Rows of the element that have not been yielded yet.
    remaining: usize,
}
impl<'de, D: AsRef<[u8]>, S: ScalarReader> SeqAccess<'de>
    for &mut ChunkPlyReaderSeqVisitor<'_, D, S>
{
    type Error = DeserializeError;

    /// Decodes the next row, or ends the sequence early if the buffered bytes run out.
    ///
    /// Returns `Ok(None)` once `remaining` reaches zero. If decoding a row fails with
    /// `UnexpectedEof`, the cursor is rewound to where that row started and the
    /// sequence ends. This leaves the partial row's bytes unconsumed. Any other
    /// error is returned as-is.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
    where
        T: serde::de::DeserializeSeed<'de>,
    {
        if self.remaining == 0 {
            return Ok(None);
        }
        let row_start = self.row.reader.position();
        match seed.deserialize(&mut self.row) {
            Ok(element) => {
                self.remaining -= 1;
                Ok(Some(element))
            }
            Err(e) if e.0.kind() == std::io::ErrorKind::UnexpectedEof => {
                // Row is incomplete: un-read it so a later call can retry with more data.
                self.row.reader.set_position(row_start);
                Ok(None)
            }
            Err(e) => Err(e),
        }
    }

    /// Upper bound on the rows left in this element.
    ///
    /// Fewer rows may be yielded if the buffer runs out first.
    fn size_hint(&self) -> Option<usize> {
        Some(self.remaining)
    }
}
impl Default for PlyChunkedReader {
fn default() -> Self {
Self::new()
}
}
/// Visitor that deserializes each row of a sequence as `T` and hands it to a callback.
///
/// Constructing it reads nothing. It must be driven as a [`DeserializeSeed`] (or
/// [`Visitor`]) against a deserializer for any rows to be produced.
pub struct RowVisitor<T, F: FnMut(T)> {
    /// Marks the row type `T` without storing a value of it.
    _row: PhantomData<T>,
    /// Called once per deserialized row, in sequence order.
    row_callback: F,
}
impl<T, F: FnMut(T)> RowVisitor<T, F> {
#[must_use = "Please call deserialize(&mut file) to actually deserialize data"]
pub fn new(row_callback: F) -> Self {
Self {
row_callback,
_row: PhantomData,
}
}
}
impl<'de, T: Deserialize<'de>, F: FnMut(T)> DeserializeSeed<'de> for &mut RowVisitor<T, F> {
    type Value = ();

    /// Asks `deserializer` for a sequence and visits it with this `RowVisitor`.
    fn deserialize<D: Deserializer<'de>>(self, deserializer: D) -> Result<(), D::Error> {
        D::deserialize_seq(deserializer, self)
    }
}
impl<'de, T: Deserialize<'de>, F: FnMut(T)> Visitor<'de> for &mut RowVisitor<T, F> {
    type Value = ();

    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        formatter.write_str("a sequence of rows")
    }

    /// Pulls rows from `seq` until it is exhausted, passing each to the callback.
    ///
    /// The first deserialization error is returned immediately. Rows already passed
    /// to the callback before that point are not retracted.
    fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<(), A::Error> {
        let callback = &mut self.row_callback;
        while let Some(row) = seq.next_element::<T>()? {
            callback(row);
        }
        Ok(())
    }
}